package de.jungblut.online.ml;
import java.util.Random;
import java.util.function.Supplier;
import java.util.stream.Stream;
import com.google.common.base.Preconditions;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.minimize.CostGradientTuple;
import de.jungblut.math.sparse.SequentialSparseDoubleVector;
import de.jungblut.online.minimizer.StochasticMinimizer;
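/**
 * Base class for online learners that fit their weights by minimizing a cost
 * function with a {@link StochasticMinimizer}. Subclasses only have to supply
 * the per-example cost/gradient computation in
 * {@link #observeExample(FeatureOutcomePair, DoubleVector)} and turn the
 * minimized weight vector into a concrete model in
 * {@link #createModel(DoubleVector)}.
 *
 * <p>
 * A minimal usage sketch (the concrete learner, the minimizer setup and the
 * data loading below are illustrative and not part of this class):
 *
 * <pre>
 * MyLearner learner = new MyLearner(minimizer);
 * learner.setNumPasses(5);
 * MyModel model = learner.train(() -> loadPairs().stream());
 * </pre>
 */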
public abstract class AbstractMinimizingOnlineLearner<M extends Model> extends
AbstractOnlineLearner<M> {
protected final StochasticMinimizer minimizer;
protected Random random = new Random();
protected int numPasses = 1;
protected boolean sparseWeights;
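/**
 * Creates a new learner that delegates the weight updates to the given
 * minimizer.
 *
 * @param minimizer the stochastic minimizer used during training.
 */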
public AbstractMinimizingOnlineLearner(StochasticMinimizer minimizer) {
this.minimizer = minimizer;
}
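/**
 * Trains a model on the supplied stream: initializes the feature/outcome
 * dimensions, randomly initializes the weight vector, minimizes the cost
 * function with the configured minimizer for the set number of passes and
 * finally turns the minimized weights into a model via
 * {@link #createModel(DoubleVector)}.
 */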
@Override
public M train(Supplier<Stream<FeatureOutcomePair>> streamSupplier) {
init(streamSupplier);
DoubleVector weights = randomInitialize(featureDimension);
DoubleVector minimized = minimizer.minimize(weights, streamSupplier,
this::observeExampleSafe, numPasses, verbose);
return createModel(minimized);
}
/**
 * Observes the next example: computes the cost and the gradient of the cost
 * function for this single feature/outcome pair at the current weights.
 *
 * @param next the next feature/outcome pair.
 * @param weights the current weights.
 * @return a cost/gradient tuple that the minimizer uses to update the weights.
 */
protected abstract CostGradientTuple observeExample(FeatureOutcomePair next,
DoubleVector weights);
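// A sketch of what an implementation could look like, e.g. a simple
// squared-error regressor. The prediction/gradient math is illustrative and
// not prescribed by this class; the vector operations are assumed from
// de.jungblut.math's DoubleVector.
//
// protected CostGradientTuple observeExample(FeatureOutcomePair next,
//     DoubleVector weights) {
//   // predict with a linear model and compute the residual
//   double prediction = next.getFeature().dot(weights);
//   double diff = prediction - next.getOutcome().get(0);
//   // squared error cost and its gradient with respect to the weights
//   double cost = 0.5d * diff * diff;
//   DoubleVector gradient = next.getFeature().multiply(diff);
//   return new CostGradientTuple(cost, gradient);
// }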
/**
* Creates a model with the given minimized weights.
*
* @param weights the learned weights.
* @return a model that describes the weights.
*/
protected abstract M createModel(DoubleVector weights);
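/**
 * Wraps {@link #observeExample(FeatureOutcomePair, DoubleVector)} with sanity
 * checks on the weight, feature and outcome dimensions before delegating the
 * actual cost/gradient computation.
 */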
protected CostGradientTuple observeExampleSafe(FeatureOutcomePair next,
DoubleVector weights) {
// do some sanity checks before we actually do the computation
Preconditions.checkArgument(weights.getDimension() == featureDimension,
"Feature dimension must match the weight dimension! Expected: "
+ featureDimension + ", given " + weights.getDimension());
Preconditions.checkArgument(featureDimension == next.getFeature()
.getDimension(),
"Feature dimension must match the initially set dimension! Expected: "
+ featureDimension + ", given " + next.getFeature().getDimension());
Preconditions.checkArgument(outcomeDimension == next.getOutcome()
.getDimension(),
"Outcome dimension must match the initially set dimension! Expected: "
+ outcomeDimension + ", given " + next.getOutcome().getDimension());
return observeExample(next, weights);
}
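/**
 * Creates the initial weight vector: a zeroed sparse vector if sparse weights
 * were requested, otherwise a dense vector filled with uniform random values
 * in [-1, 1) for small dimensions and left at zero for very large ones.
 */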
protected DoubleVector randomInitialize(int dimension) {
if (sparseWeights) {
return new SequentialSparseDoubleVector(dimension);
} else {
// only spend time on random initialization for reasonably small vectors
// (dimension < 2 << 15 = 65,536); larger vectors simply start at zero
if (dimension < (2 << 15)) {
double[] array = new double[dimension];
for (int i = 0; i < array.length; i++) {
array[i] = (random.nextDouble() * 2) - 1d;
}
return new DenseDoubleVector(array);
}
return new DenseDoubleVector(dimension);
}
}
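/**
 * Sets the random number generator used for the weight initialization, e.g. a
 * seeded {@link Random} for reproducible runs.
 *
 * @param random the non-null random number generator to use.
 */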
public void setRandom(Random random) {
this.random = Preconditions.checkNotNull(random,
"Supplied random was null!");
}
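/**
 * Switches the weight initialization to a sparse vector representation, which
 * is useful for high dimensional feature spaces.
 */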
public void useSparseWeights() {
sparseWeights = true;
}
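/**
 * Sets the number of passes over the training stream, must be at least one.
 *
 * @param passes the number of passes (> 0).
 */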
public void setNumPasses(int passes) {
Preconditions.checkArgument(passes > 0,
"Iterative algorithms need at least a single pass. Supplied: " + passes);
this.numPasses = passes;
}
}