package com.yahoo.labs.samoa.learners.classifiers; /* * #%L * SAMOA * %% * Copyright (C) 2013 Yahoo! Inc. * %% * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * #L% */ /** * License */ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.yahoo.labs.samoa.core.ContentEvent; import com.yahoo.labs.samoa.core.Processor; import com.yahoo.labs.samoa.instances.Instance; import com.yahoo.labs.samoa.learners.InstanceContentEvent; import com.yahoo.labs.samoa.learners.ResultContentEvent; import com.yahoo.labs.samoa.moa.classifiers.core.driftdetection.ChangeDetector; import com.yahoo.labs.samoa.topology.Stream; import static com.yahoo.labs.samoa.moa.core.Utils.maxIndex; /** * The Class LearnerProcessor. */ final public class LocalLearnerProcessor implements Processor { /** * */ private static final long serialVersionUID = -1577910988699148691L; private static final Logger logger = LoggerFactory.getLogger(LocalLearnerProcessor.class); private LocalLearner model; private Stream outputStream; private int modelId; private long instancesCount = 0; /** * Sets the learner. * * @param model the model to set */ public void setLearner(LocalLearner model) { this.model = model; } /** * Gets the learner. * * @return the model */ public LocalLearner getLearner() { return model; } /** * Set the output streams. * * @param outputStream the new output stream */ public void setOutputStream(Stream outputStream) { this.outputStream = outputStream; } /** * Gets the output stream. * * @return the output stream */ public Stream getOutputStream() { return outputStream; } /** * Gets the instances count. * * @return number of observation vectors used in training iteration. */ public long getInstancesCount() { return instancesCount; } /** * Update stats. * * @param event the event */ private void updateStats(InstanceContentEvent event) { Instance inst = event.getInstance(); this.model.trainOnInstance(inst); this.instancesCount++; if (this.changeDetector != null) { boolean correctlyClassifies = this.correctlyClassifies(inst); double oldEstimation = this.changeDetector.getEstimation(); this.changeDetector.input(correctlyClassifies ? 0 : 1); if (this.changeDetector.getChange() && this.changeDetector.getEstimation() > oldEstimation) { //Start a new classifier this.model.resetLearning(); this.changeDetector.resetLearning(); } } } /** * Gets whether this classifier correctly classifies an instance. Uses * getVotesForInstance to obtain the prediction and the instance to obtain * its true class. * * * @param inst the instance to be classified * @return true if the instance is correctly classified */ private boolean correctlyClassifies(Instance inst) { return maxIndex(model.getVotesForInstance(inst)) == (int) inst.classValue(); } /** The test. */ protected int test; //to delete /** * On event. * * @param event the event * @return true, if successful */ @Override public boolean process(ContentEvent event) { InstanceContentEvent inEvent = (InstanceContentEvent) event; Instance instance = inEvent.getInstance(); if (inEvent.getInstanceIndex() < 0) { //end learning ResultContentEvent outContentEvent = new ResultContentEvent(-1, instance, 0, new double[0], inEvent.isLastEvent()); outContentEvent.setClassifierIndex(this.modelId); outContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex()); outputStream.put(outContentEvent); return false; } if (inEvent.isTesting()){ double[] dist = model.getVotesForInstance(instance); ResultContentEvent outContentEvent = new ResultContentEvent(inEvent.getInstanceIndex(), instance, inEvent.getClassId(), dist, inEvent.isLastEvent()); outContentEvent.setClassifierIndex(this.modelId); outContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex()); logger.trace(inEvent.getInstanceIndex() + " {} {}", modelId, dist); outputStream.put(outContentEvent); } if (inEvent.isTraining()) { updateStats(inEvent); } return false; } /* (non-Javadoc) * @see samoa.core.Processor#onCreate(int) */ @Override public void onCreate(int id) { this.modelId = id; model = model.create(); } /* (non-Javadoc) * @see samoa.core.Processor#newProcessor(samoa.core.Processor) */ @Override public Processor newProcessor(Processor sourceProcessor) { LocalLearnerProcessor newProcessor = new LocalLearnerProcessor(); LocalLearnerProcessor originProcessor = (LocalLearnerProcessor) sourceProcessor; if (originProcessor.getLearner() != null){ newProcessor.setLearner(originProcessor.getLearner().create()); } if (originProcessor.getChangeDetector() != null){ newProcessor.setChangeDetector(originProcessor.getChangeDetector()); } newProcessor.setOutputStream(originProcessor.getOutputStream()); return newProcessor; } protected ChangeDetector changeDetector; public ChangeDetector getChangeDetector() { return this.changeDetector; } public void setChangeDetector(ChangeDetector cd) { this.changeDetector = cd; } }