package org.apache.samoa.evaluation; /* * #%L * SAMOA * %% * Copyright (C) 2014 - 2015 Apache Software Foundation * %% * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * #L% */ import org.apache.samoa.core.ContentEvent; import org.apache.samoa.core.Processor; import org.apache.samoa.learners.ResultContentEvent; import org.apache.samoa.moa.core.Measurement; import org.apache.samoa.moa.evaluation.LearningCurve; import org.apache.samoa.moa.evaluation.LearningEvaluation; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.File; import java.io.FileNotFoundException; import java.io.FileOutputStream; import java.io.PrintStream; import java.util.*; import java.util.concurrent.TimeUnit; public class EvaluatorCVProcessor implements Processor { /** * */ private static final long serialVersionUID = -2778051819116753612L; private static final Logger logger = LoggerFactory.getLogger(EvaluatorCVProcessor.class); private static final String ORDERING_MEASUREMENT_NAME = "evaluation instances"; private final PerformanceEvaluator[] evaluators; private final int samplingFrequency; private final File dumpFile; private transient PrintStream immediateResultStream = null; private transient boolean firstDump = true; private long totalCount = 0; private long experimentStart = 0; private long sampleStart = 0; private LearningCurve learningCurve; private int id; private int foldNumber = 10; private EvaluatorCVProcessor(Builder builder) { evaluators = new PerformanceEvaluator[builder.foldNumber]; for (int i = 0; i < this.evaluators.length; i++) { evaluators[i] = (PerformanceEvaluator) builder.evaluator.copy(); } this.samplingFrequency = builder.samplingFrequency; this.dumpFile = builder.dumpFile; this.foldNumber = builder.foldNumber; } private boolean initiated = false; @Override public boolean process(ContentEvent event) { if (this.initiated == false) { sampleStart = System.nanoTime(); experimentStart = sampleStart; this.initiated = true; } ResultContentEvent result = (ResultContentEvent) event; int instanceIndex = (int) result.getInstanceIndex(); addStatisticsForInstanceReceived(instanceIndex, result.getEvaluationIndex(), 1); evaluators[result.getEvaluationIndex()].addResult(result.getInstance(), result.getClassVotes()); if (hasAllVotesArrivedInstance(instanceIndex)) { totalCount += 1; if (result.isLastEvent()) { this.concludeMeasurement(); return true; } //this.mapCountsforInstanceReceived.remove(instanceIndex); if ((totalCount > 0) && (totalCount % samplingFrequency) == 0) { long sampleEnd = System.nanoTime(); long sampleDuration = TimeUnit.SECONDS.convert(sampleEnd - sampleStart, TimeUnit.NANOSECONDS); sampleStart = sampleEnd; logger.info("{} seconds for {} instances", sampleDuration, samplingFrequency); this.addMeasurement(); } } return false; } protected Map<Integer, Integer> mapCountsforInstanceReceived; private boolean hasAllVotesArrivedInstance(int instanceIndex) { Map<Integer, Integer> map = this.mapCountsforInstanceReceived; int count = map.get(instanceIndex); return (count == this.foldNumber); } protected void addStatisticsForInstanceReceived(int instanceIndex, int evaluationIndex, int add) { if (this.mapCountsforInstanceReceived == null) { this.mapCountsforInstanceReceived = new HashMap<>(); } Integer count = this.mapCountsforInstanceReceived.get(instanceIndex); if (count == null) { count = 0; } this.mapCountsforInstanceReceived.put(instanceIndex, count + add); } @Override public void onCreate(int id) { this.id = id; this.learningCurve = new LearningCurve(ORDERING_MEASUREMENT_NAME); if (this.dumpFile != null) { try { if (dumpFile.exists()) { this.immediateResultStream = new PrintStream( new FileOutputStream(dumpFile, true), true); } else { this.immediateResultStream = new PrintStream( new FileOutputStream(dumpFile), true); } } catch (FileNotFoundException e) { this.immediateResultStream = null; logger.error("File not found exception for {}:{}", this.dumpFile.getAbsolutePath(), e.toString()); } catch (Exception e) { this.immediateResultStream = null; logger.error("Exception when creating {}:{}", this.dumpFile.getAbsolutePath(), e.toString()); } } this.firstDump = true; } @Override public Processor newProcessor(Processor p) { EvaluatorCVProcessor originalProcessor = (EvaluatorCVProcessor) p; EvaluatorCVProcessor newProcessor = new EvaluatorCVProcessor.Builder(originalProcessor).build(); if (originalProcessor.learningCurve != null) { newProcessor.learningCurve = originalProcessor.learningCurve; } return newProcessor; } @Override public String toString() { StringBuilder report = new StringBuilder(); report.append(EvaluatorCVProcessor.class.getCanonicalName()); report.append("id = ").append(this.id); report.append('\n'); if (learningCurve.numEntries() > 0) { report.append(learningCurve.toString()); report.append('\n'); } return report.toString(); } private void addMeasurement() { List<Measurement> measurements = new Vector<>(); measurements.add(new Measurement(ORDERING_MEASUREMENT_NAME, totalCount )); Measurement[] finalMeasurements = getEvaluationMeasurements( measurements.toArray(new Measurement[measurements.size()]), evaluators); LearningEvaluation learningEvaluation = new LearningEvaluation(finalMeasurements); learningCurve.insertEntry(learningEvaluation); logger.debug("evaluator id = {}", this.id); logger.info(learningEvaluation.toString()); if (immediateResultStream != null) { if (firstDump) { immediateResultStream.println(learningCurve.headerToString()); firstDump = false; } immediateResultStream.println(learningCurve.entryToString(learningCurve.numEntries() - 1)); immediateResultStream.flush(); } } private void concludeMeasurement() { logger.info("last event is received!"); logger.info("total count: {}", this.totalCount); String learningCurveSummary = this.toString(); logger.info(learningCurveSummary); long experimentEnd = System.nanoTime(); long totalExperimentTime = TimeUnit.SECONDS.convert(experimentEnd - experimentStart, TimeUnit.NANOSECONDS); logger.info("total evaluation time: {} seconds for {} instances", totalExperimentTime, totalCount ); if (immediateResultStream != null) { immediateResultStream.println("# COMPLETED"); immediateResultStream.flush(); } // logger.info("average throughput rate: {} instances/seconds", // (totalCount/totalExperimentTime)); } public static class Builder { private final PerformanceEvaluator evaluator; private int samplingFrequency = 100000; private File dumpFile = null; private int foldNumber = 10; public Builder(PerformanceEvaluator evaluator) { this.evaluator = evaluator; } public Builder(EvaluatorCVProcessor oldProcessor) { this.evaluator = oldProcessor.evaluators[0]; this.samplingFrequency = oldProcessor.samplingFrequency; this.dumpFile = oldProcessor.dumpFile; } public Builder samplingFrequency(int samplingFrequency) { this.samplingFrequency = samplingFrequency; return this; } public Builder dumpFile(File file) { this.dumpFile = file; return this; } public Builder foldNumber(int foldNumber){ this.foldNumber = foldNumber; return this; } public EvaluatorCVProcessor build() { return new EvaluatorCVProcessor(this); } } public Measurement[] getEvaluationMeasurements(Measurement[] modelMeasurements, PerformanceEvaluator[] subEvaluators) { List<Measurement> measurementList = new LinkedList<Measurement>(); if (modelMeasurements != null) { measurementList.addAll(Arrays.asList(modelMeasurements)); } // add average of sub-model measurements if ((subEvaluators != null) && (subEvaluators.length > 0)) { List<Measurement[]> subMeasurements = new LinkedList<Measurement[]>(); for (PerformanceEvaluator subEvaluator : subEvaluators) { if (subEvaluator != null) { subMeasurements.add(subEvaluator.getPerformanceMeasurements()); } } Measurement[] avgMeasurements = Measurement.averageMeasurements(subMeasurements.toArray(new Measurement[subMeasurements.size()][])); measurementList.addAll(Arrays.asList(avgMeasurements)); } return measurementList.toArray(new Measurement[measurementList.size()]); } }