package tr.gov.ulakbim.jDenetX.experiments.wrappers; import tr.gov.ulakbim.jDenetX.classifiers.AbstractClassifier; import tr.gov.ulakbim.jDenetX.evaluation.BasicClassificationPerformanceEvaluator; import tr.gov.ulakbim.jDenetX.evaluation.ClassificationPerformanceEvaluator; import tr.gov.ulakbim.jDenetX.evaluation.LearningEvaluation; import tr.gov.ulakbim.jDenetX.streams.InstanceStream; import weka.core.Instance; public class EvalModel { private int MaxInstances = 1000000; private static final int INSTANCES_BETWEEN_MONITOR_UPDATES = 10; public int getMaxInstances() { return MaxInstances; } public void setMaxInstances(int maxInstances) { MaxInstances = maxInstances; } protected void trainModel(InstanceStream trainStream, AbstractClassifier model) { ClassificationPerformanceEvaluator evaluator = new BasicClassificationPerformanceEvaluator(); evaluator.reset(); long instancesProcessed = 0; System.out.println("Started learning the model..."); while (trainStream.hasMoreInstances() && ((MaxInstances < 0) || (instancesProcessed < MaxInstances))) { model.trainOnInstance(trainStream.nextInstance()); instancesProcessed++; if (instancesProcessed % INSTANCES_BETWEEN_MONITOR_UPDATES == 0) { long estimatedRemainingInstances = trainStream .estimatedRemainingInstances(); if (MaxInstances > 0) { long maxRemaining = MaxInstances - instancesProcessed; if ((estimatedRemainingInstances < 0) || (maxRemaining < estimatedRemainingInstances)) { estimatedRemainingInstances = maxRemaining; } } } } } public LearningEvaluation evalModel (InstanceStream trainStream, InstanceStream testStream, AbstractClassifier model) { ClassificationPerformanceEvaluator evaluator = new BasicClassificationPerformanceEvaluator(); evaluator.reset(); long instancesProcessed = 0; System.out.println("Evaluating model..."); trainModel(trainStream, model); while (testStream.hasMoreInstances() && ((MaxInstances < 0) || (instancesProcessed < MaxInstances))) { Instance testInst = (Instance) testStream.nextInstance().copy(); int trueClass = (int) testInst.classValue(); testInst.setClassMissing(); double[] prediction = model.getVotesForInstance(testInst); evaluator.addClassificationAttempt(trueClass, prediction, testInst .weight()); instancesProcessed++; if (instancesProcessed % INSTANCES_BETWEEN_MONITOR_UPDATES == 0) { long estimatedRemainingInstances = testStream .estimatedRemainingInstances(); if (MaxInstances > 0) { long maxRemaining = MaxInstances - instancesProcessed; if ((estimatedRemainingInstances < 0) || (maxRemaining < estimatedRemainingInstances)) { estimatedRemainingInstances = maxRemaining; } } } } return new LearningEvaluation(evaluator.getPerformanceMeasurements()); } }