package de.jungblut.online.regression.multinomial;
import java.util.function.IntFunction;
import java.util.function.Supplier;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import com.google.common.base.Preconditions;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.activation.ActivationFunction;
import de.jungblut.math.dense.SingleEntryDoubleVector;
import de.jungblut.math.loss.LossFunction;
import de.jungblut.online.minimizer.StochasticMinimizer;
import de.jungblut.online.ml.AbstractOnlineLearner;
import de.jungblut.online.ml.FeatureOutcomePair;
import de.jungblut.online.regression.RegressionLearner;
import de.jungblut.online.regression.RegressionModel;
/**
* A regression learner that learns multiple independent regression models and
* blends them into a single model.
*
* @author thomas.jungblut
*
*/
public class MultinomialRegressionLearner extends
AbstractOnlineLearner<MultinomialRegressionModel> {
private static final Logger LOG = LogManager
.getLogger(MultinomialRegressionLearner.class);
private static final SingleEntryDoubleVector POSITIVE = new SingleEntryDoubleVector(
1d);
private static final SingleEntryDoubleVector NEGATIVE = new SingleEntryDoubleVector(
0d);
private final IntFunction<RegressionLearner> learnerFactory;
private RegressionModel[] trainedModels;
public MultinomialRegressionLearner(StochasticMinimizer minimizer,
ActivationFunction activationFunction, LossFunction lossFunction) {
this((i) -> new RegressionLearner(minimizer, activationFunction,
lossFunction));
}
public MultinomialRegressionLearner(
IntFunction<RegressionLearner> learnerFactory) {
this.learnerFactory = Preconditions.checkNotNull(learnerFactory,
"learnerFactory");
}
@Override
public MultinomialRegressionModel train(
Supplier<Stream<FeatureOutcomePair>> streamSupplier) {
init(streamSupplier);
trainedModels = new RegressionModel[numOutcomeClasses];
// train the models in parallel
IntStream
.range(0, numOutcomeClasses)
.parallel()
.forEach(
i -> {
if (verbose) {
LOG.info("Training class " + i);
}
RegressionLearner learner = learnerFactory.apply(i);
final int k = i;
trainedModels[i] = learner.train(() -> streamSupplier.get().map(
(pair) -> makeBinary(pair, k)));
if (verbose) {
LOG.info("Done training class " + i);
}
});
return new MultinomialRegressionModel(trainedModels);
}
private static FeatureOutcomePair makeBinary(FeatureOutcomePair input,
int targetClassIndex) {
DoubleVector outcome = input.getOutcome();
if (outcome.maxIndex() == targetClassIndex) {
return new FeatureOutcomePair(input.getFeature(), POSITIVE);
}
return new FeatureOutcomePair(input.getFeature(), NEGATIVE);
}
}