package org.apache.samoa.learners.classifiers.rules.distributed; /* * #%L * SAMOA * %% * Copyright (C) 2014 - 2015 Apache Software Foundation * %% * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * #L% */ import org.apache.samoa.core.ContentEvent; import org.apache.samoa.core.Processor; import org.apache.samoa.instances.Instance; import org.apache.samoa.instances.Instances; import org.apache.samoa.learners.InstanceContentEvent; import org.apache.samoa.learners.ResultContentEvent; import org.apache.samoa.learners.classifiers.rules.common.ActiveRule; import org.apache.samoa.learners.classifiers.rules.common.Perceptron; import org.apache.samoa.learners.classifiers.rules.common.RuleActiveRegressionNode; import org.apache.samoa.moa.classifiers.rules.core.attributeclassobservers.FIMTDDNumericAttributeClassLimitObserver; import org.apache.samoa.topology.Stream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Default Rule Learner Processor (HAMR). * * @author Anh Thu Vu * */ public class AMRDefaultRuleProcessor implements Processor { /** * */ private static final long serialVersionUID = 23702084591044447L; private static final Logger logger = LoggerFactory.getLogger(AMRDefaultRuleProcessor.class); private int processorId; // Default rule protected transient ActiveRule defaultRule; protected transient int ruleNumberID; protected transient double[] statistics; // SAMOA Stream private Stream ruleStream; private Stream resultStream; // Options protected int pageHinckleyThreshold; protected double pageHinckleyAlpha; protected boolean driftDetection; protected int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2 protected boolean constantLearningRatioDecay; protected double learningRatio; protected double splitConfidence; protected double tieThreshold; protected int gracePeriod; protected FIMTDDNumericAttributeClassLimitObserver numericObserver; /* * Constructor */ public AMRDefaultRuleProcessor(Builder builder) { this.pageHinckleyThreshold = builder.pageHinckleyThreshold; this.pageHinckleyAlpha = builder.pageHinckleyAlpha; this.driftDetection = builder.driftDetection; this.predictionFunction = builder.predictionFunction; this.constantLearningRatioDecay = builder.constantLearningRatioDecay; this.learningRatio = builder.learningRatio; this.splitConfidence = builder.splitConfidence; this.tieThreshold = builder.tieThreshold; this.gracePeriod = builder.gracePeriod; this.numericObserver = builder.numericObserver; } @Override public boolean process(ContentEvent event) { InstanceContentEvent instanceEvent = (InstanceContentEvent) event; // predict if (instanceEvent.isTesting()) { this.predictOnInstance(instanceEvent); } // train if (instanceEvent.isTraining()) { this.trainOnInstance(instanceEvent); } return false; } /* * Prediction */ private void predictOnInstance(InstanceContentEvent instanceEvent) { double[] vote = defaultRule.getPrediction(instanceEvent.getInstance()); ResultContentEvent rce = newResultContentEvent(vote, instanceEvent); resultStream.put(rce); } private ResultContentEvent newResultContentEvent(double[] prediction, InstanceContentEvent inEvent) { ResultContentEvent rce = new ResultContentEvent(inEvent.getInstanceIndex(), inEvent.getInstance(), inEvent.getClassId(), prediction, inEvent.isLastEvent()); rce.setClassifierIndex(this.processorId); rce.setEvaluationIndex(inEvent.getEvaluationIndex()); return rce; } /* * Training */ private void trainOnInstance(InstanceContentEvent instanceEvent) { this.trainOnInstanceImpl(instanceEvent.getInstance()); } public void trainOnInstanceImpl(Instance instance) { defaultRule.updateStatistics(instance); if (defaultRule.getInstancesSeen() % this.gracePeriod == 0.0) { if (defaultRule.tryToExpand(this.splitConfidence, this.tieThreshold) == true) { ActiveRule newDefaultRule = newRule(defaultRule.getRuleNumberID(), (RuleActiveRegressionNode) defaultRule.getLearningNode(), ((RuleActiveRegressionNode) defaultRule.getLearningNode()).getStatisticsOtherBranchSplit()); // other branch defaultRule.split(); defaultRule.setRuleNumberID(++ruleNumberID); // send out the new rule sendAddRuleEvent(defaultRule.getRuleNumberID(), this.defaultRule); defaultRule = newDefaultRule; } } } /* * Create new rules */ private ActiveRule newRule(int ID, RuleActiveRegressionNode node, double[] statistics) { ActiveRule r = newRule(ID); if (node != null) { if (node.getPerceptron() != null) { r.getLearningNode().setPerceptron(new Perceptron(node.getPerceptron())); r.getLearningNode().getPerceptron().setLearningRatio(this.learningRatio); } if (statistics == null) { double mean; if (node.getNodeStatistics().getValue(0) > 0) { mean = node.getNodeStatistics().getValue(1) / node.getNodeStatistics().getValue(0); r.getLearningNode().getTargetMean().reset(mean, 1); } } } if (statistics != null && ((RuleActiveRegressionNode) r.getLearningNode()).getTargetMean() != null) { double mean; if (statistics[0] > 0) { mean = statistics[1] / statistics[0]; ((RuleActiveRegressionNode) r.getLearningNode()).getTargetMean().reset(mean, (long) statistics[0]); } } return r; } private ActiveRule newRule(int ID) { ActiveRule r = new ActiveRule.Builder(). threshold(this.pageHinckleyThreshold). alpha(this.pageHinckleyAlpha). changeDetection(this.driftDetection). predictionFunction(this.predictionFunction). statistics(new double[3]). learningRatio(this.learningRatio). numericObserver(numericObserver). id(ID).build(); return r; } @Override public void onCreate(int id) { this.processorId = id; this.statistics = new double[] { 0.0, 0, 0 }; this.ruleNumberID = 0; this.defaultRule = newRule(++this.ruleNumberID); } /* * Clone processor */ @Override public Processor newProcessor(Processor p) { AMRDefaultRuleProcessor oldProcessor = (AMRDefaultRuleProcessor) p; Builder builder = new Builder(oldProcessor); AMRDefaultRuleProcessor newProcessor = builder.build(); newProcessor.resultStream = oldProcessor.resultStream; newProcessor.ruleStream = oldProcessor.ruleStream; return newProcessor; } /* * Send events */ private void sendAddRuleEvent(int ruleID, ActiveRule rule) { RuleContentEvent rce = new RuleContentEvent(ruleID, rule, false); this.ruleStream.put(rce); } /* * Output streams */ public void setRuleStream(Stream ruleStream) { this.ruleStream = ruleStream; } public Stream getRuleStream() { return this.ruleStream; } public void setResultStream(Stream resultStream) { this.resultStream = resultStream; } public Stream getResultStream() { return this.resultStream; } /* * Builder */ public static class Builder { private int pageHinckleyThreshold; private double pageHinckleyAlpha; private boolean driftDetection; private int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2 private boolean constantLearningRatioDecay; private double learningRatio; private double splitConfidence; private double tieThreshold; private int gracePeriod; private FIMTDDNumericAttributeClassLimitObserver numericObserver; private Instances dataset; public Builder(Instances dataset) { this.dataset = dataset; } public Builder(AMRDefaultRuleProcessor processor) { this.pageHinckleyThreshold = processor.pageHinckleyThreshold; this.pageHinckleyAlpha = processor.pageHinckleyAlpha; this.driftDetection = processor.driftDetection; this.predictionFunction = processor.predictionFunction; this.constantLearningRatioDecay = processor.constantLearningRatioDecay; this.learningRatio = processor.learningRatio; this.splitConfidence = processor.splitConfidence; this.tieThreshold = processor.tieThreshold; this.gracePeriod = processor.gracePeriod; this.numericObserver = processor.numericObserver; } public Builder threshold(int threshold) { this.pageHinckleyThreshold = threshold; return this; } public Builder alpha(double alpha) { this.pageHinckleyAlpha = alpha; return this; } public Builder changeDetection(boolean changeDetection) { this.driftDetection = changeDetection; return this; } public Builder predictionFunction(int predictionFunction) { this.predictionFunction = predictionFunction; return this; } public Builder constantLearningRatioDecay(boolean constantDecay) { this.constantLearningRatioDecay = constantDecay; return this; } public Builder learningRatio(double learningRatio) { this.learningRatio = learningRatio; return this; } public Builder splitConfidence(double splitConfidence) { this.splitConfidence = splitConfidence; return this; } public Builder tieThreshold(double tieThreshold) { this.tieThreshold = tieThreshold; return this; } public Builder gracePeriod(int gracePeriod) { this.gracePeriod = gracePeriod; return this; } public Builder numericObserver(FIMTDDNumericAttributeClassLimitObserver numericObserver) { this.numericObserver = numericObserver; return this; } public AMRDefaultRuleProcessor build() { return new AMRDefaultRuleProcessor(this); } } }