package de.jungblut.classification.eval;
import de.jungblut.classification.eval.Evaluator.EvaluationResult;
import de.jungblut.math.DoubleVector;
import de.jungblut.online.minimizer.PassFinishedCallback;
import de.jungblut.online.minimizer.ValidationFinishedCallback;
import de.jungblut.online.ml.FeatureOutcomePair;
import de.jungblut.online.regression.RegressionClassifier;
import de.jungblut.online.regression.RegressionLearner;
import de.jungblut.online.regression.RegressionModel;
import java.util.Objects;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
/**
 * Callback that evaluates the current weights of a {@link RegressionLearner}
 * on validation examples as a two-label classifier.
 *
 * <p>Each validation example is scored and accumulated into the current
 * {@link EvaluationResult}. When a pass finishes, the accumulated result is
 * logged and a fresh result is started for the next pass.
 */
public class RegressionValidationCallback implements
    ValidationFinishedCallback, PassFinishedCallback {

  private static final Logger LOG = LogManager
      .getLogger(RegressionValidationCallback.class);

  /** Evaluation statistics accumulated since the last finished pass. */
  private EvaluationResult currentResult;

  /** Learner used to turn a weight vector into a {@link RegressionModel}. */
  private final RegressionLearner learner;

  /**
   * Creates a callback that evaluates models produced by the given learner.
   *
   * @param learner the learner whose {@code createModel} builds the model to
   *          evaluate; must not be null
   * @throws NullPointerException if {@code learner} is null
   */
  public RegressionValidationCallback(RegressionLearner learner) {
    // Fail fast here instead of with an NPE on the first validation example.
    this.learner = Objects.requireNonNull(learner, "learner");
    setupNewEvaluationResult();
  }

  /**
   * Logs the evaluation accumulated during the finished pass, then starts a
   * new, empty evaluation for the next pass.
   *
   * @param pass the pass that just finished
   * @param iteration the iteration counter at the end of the pass
   * @param cost unused by this callback
   * @param currentWeights unused by this callback
   * @return always {@code true} (presumably "continue optimizing" — confirm
   *         against the {@link PassFinishedCallback} contract)
   */
  @Override
  public boolean onPassFinished(int pass, long iteration, double cost,
      DoubleVector currentWeights) {
    LOG.info("Evaluation | Pass {} | Iteration {}", pass, iteration);
    currentResult.print(LOG);
    setupNewEvaluationResult();
    return true;
  }

  /**
   * Scores one validation example with a model built from the current weights
   * and records the outcome in the running evaluation.
   *
   * @param pass the current pass
   * @param iteration the current iteration
   * @param cost unused by this callback
   * @param currentWeights weights used to build the model for this example
   * @param pair the validation example (features and expected outcome)
   */
  @Override
  public void onValidationFinished(int pass, long iteration, double cost,
      DoubleVector currentWeights, FeatureOutcomePair pair) {
    // NOTE(review): a model/classifier is rebuilt for every validation
    // example; caching is only safe if currentWeights is not mutated in place.
    RegressionModel model = learner.createModel(currentWeights);
    RegressionClassifier classifier = new RegressionClassifier(model);
    currentResult.testSize++;
    DoubleVector predict = classifier.predict(pair.getFeature());
    // The null argument is passed through unchanged; its meaning is defined by
    // Evaluator.observeBinaryClassificationElement (TODO confirm).
    Evaluator.observeBinaryClassificationElement(classifier, null,
        currentResult, pair.getOutcome(), predict);
  }

  /** Replaces the running result with a fresh one configured for 2 labels. */
  private void setupNewEvaluationResult() {
    currentResult = new EvaluationResult();
    currentResult.numLabels = 2;
  }
}