package de.jungblut.classification.eval; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import de.jungblut.math.DoubleVector; import de.jungblut.online.minimizer.IterationFinishedCallback; import de.jungblut.online.minimizer.PassFinishedCallback; public class ErrorCountingCallback implements IterationFinishedCallback, PassFinishedCallback { private static final Logger LOG = LogManager .getLogger(ErrorCountingCallback.class); private long errors; private long seen; @Override public void onIterationFinished(int pass, long iteration, double cost, DoubleVector currentWeights, boolean validation) { if (cost != 0d) { errors++; } seen++; } @Override public boolean onPassFinished(int pass, long iteration, double cost, DoubleVector currentWeights) { LOG.info("Errors | Pass " + pass + " | Iteration " + iteration + " | #Errors " + errors + " | Accuracy " + ((seen - errors) / (double) seen)); boolean continueComputation = errors != 0; errors = 0; // reset the errors seen = 0; return continueComputation; } }