package org.apache.samoa.learners.classifiers.ensemble; /* * #%L * SAMOA * %% * Copyright (C) 2014 - 2015 Apache Software Foundation * %% * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * #L% */ /** * License */ import java.util.HashMap; import java.util.Map; import java.util.Random; import org.apache.samoa.core.ContentEvent; import org.apache.samoa.instances.Instance; import org.apache.samoa.learners.InstanceContentEvent; import org.apache.samoa.learners.ResultContentEvent; import org.apache.samoa.moa.core.DoubleVector; import org.apache.samoa.moa.core.Utils; import org.apache.samoa.topology.Stream; /** * The Class BoostingPredictionCombinerProcessor. */ public class BoostingPredictionCombinerProcessor extends PredictionCombinerProcessor { private static final long serialVersionUID = -1606045723451191232L; // Weigths classifier protected double[] scms; // Weights instance protected double[] swms; /** * On event. * * @param event * the event * @return true, if successful */ @Override public boolean process(ContentEvent event) { ResultContentEvent inEvent = (ResultContentEvent) event; double[] prediction = inEvent.getClassVotes(); int instanceIndex = (int) inEvent.getInstanceIndex(); addStatisticsForInstanceReceived(instanceIndex, inEvent.getClassifierIndex(), prediction, 1); // Boosting addPredictions(instanceIndex, inEvent, prediction); if (inEvent.isLastEvent() || hasAllVotesArrivedInstance(instanceIndex)) { DoubleVector combinedVote = this.mapVotesforInstanceReceived.get(instanceIndex); if (combinedVote == null) { combinedVote = new DoubleVector(); } ResultContentEvent outContentEvent = new ResultContentEvent(inEvent.getInstanceIndex(), inEvent.getInstance(), inEvent.getClassId(), combinedVote.getArrayCopy(), inEvent.isLastEvent()); outContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex()); outputStream.put(outContentEvent); clearStatisticsInstance(instanceIndex); // Boosting computeBoosting(inEvent, instanceIndex); return true; } return false; } protected Random random; protected int trainingWeightSeenByModel; @Override protected double getEnsembleMemberWeight(int i) { double em = this.swms[i] / (this.scms[i] + this.swms[i]); if ((em == 0.0) || (em > 0.5)) { return 0.0; } double Bm = em / (1.0 - em); return Math.log(1.0 / Bm); } @Override public void reset() { this.random = new Random(); this.trainingWeightSeenByModel = 0; this.scms = new double[this.ensembleSize]; this.swms = new double[this.ensembleSize]; } private boolean correctlyClassifies(int i, Instance inst, int instanceIndex) { int predictedClass = (int) mapPredictions.get(instanceIndex).getValue(i); return predictedClass == (int) inst.classValue(); } protected Map<Integer, DoubleVector> mapPredictions; private void addPredictions(int instanceIndex, ResultContentEvent inEvent, double[] prediction) { if (this.mapPredictions == null) { this.mapPredictions = new HashMap<>(); } DoubleVector predictions = this.mapPredictions.get(instanceIndex); if (predictions == null) { predictions = new DoubleVector(); } predictions.setValue(inEvent.getClassifierIndex(), Utils.maxIndex(prediction)); this.mapPredictions.put(instanceIndex, predictions); } private void computeBoosting(ResultContentEvent inEvent, int instanceIndex) { // Starts code for Boosting // Send instances to train double lambda_d = 1.0; for (int i = 0; i < this.ensembleSize; i++) { double k = lambda_d; Instance inst = inEvent.getInstance(); if (k > 0.0) { Instance weightedInst = inst.copy(); weightedInst.setWeight(inst.weight() * k); // this.ensemble[i].trainOnInstance(weightedInst); InstanceContentEvent instanceContentEvent = new InstanceContentEvent( inEvent.getInstanceIndex(), weightedInst, true, false); instanceContentEvent.setClassifierIndex(i); instanceContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex()); trainingStream.put(instanceContentEvent); } if (this.correctlyClassifies(i, inst, instanceIndex)) { this.scms[i] += lambda_d; lambda_d *= this.trainingWeightSeenByModel / (2 * this.scms[i]); } else { this.swms[i] += lambda_d; lambda_d *= this.trainingWeightSeenByModel / (2 * this.swms[i]); } } } /** * Gets the training stream. * * @return the training stream */ public Stream getTrainingStream() { return trainingStream; } /** * Sets the training stream. * * @param trainingStream * the new training stream */ public void setTrainingStream(Stream trainingStream) { this.trainingStream = trainingStream; } /** The training stream. */ private Stream trainingStream; }