package org.apache.samoa.learners.classifiers.trees; /* * #%L * SAMOA * %% * Copyright (C) 2014 - 2015 Apache Software Foundation * %% * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * #L% */ import java.util.HashMap; import java.util.Map; import org.apache.samoa.instances.Instance; import org.apache.samoa.moa.classifiers.core.AttributeSplitSuggestion; import org.slf4j.Logger; import org.slf4j.LoggerFactory; final class ActiveLearningNode extends LearningNode { /** * */ private static final long serialVersionUID = -2892102872646338908L; private static final Logger logger = LoggerFactory.getLogger(ActiveLearningNode.class); private double weightSeenAtLastSplitEvaluation; private final Map<Integer, String> attributeContentEventKeys; private AttributeSplitSuggestion bestSuggestion; private AttributeSplitSuggestion secondBestSuggestion; private final long id; private final int parallelismHint; private int suggestionCtr; private int thrownAwayInstance; private boolean isSplitting; ActiveLearningNode(double[] classObservation, int parallelismHint) { super(classObservation); this.weightSeenAtLastSplitEvaluation = this.getWeightSeen(); this.id = VerticalHoeffdingTree.LearningNodeIdGenerator.generate(); this.attributeContentEventKeys = new HashMap<>(); this.isSplitting = false; this.parallelismHint = parallelismHint; } long getId() { return id; } protected AttributeBatchContentEvent[] attributeBatchContentEvent; public AttributeBatchContentEvent[] getAttributeBatchContentEvent() { return this.attributeBatchContentEvent; } public void setAttributeBatchContentEvent(AttributeBatchContentEvent[] attributeBatchContentEvent) { this.attributeBatchContentEvent = attributeBatchContentEvent; } @Override void learnFromInstance(Instance inst, ModelAggregatorProcessor proc) { // TODO: what statistics should we keep for unused instance? if (isSplitting) { // currently throw all instance will splitting this.thrownAwayInstance++; return; } this.observedClassDistribution.addToValue((int) inst.classValue(), inst.weight()); // done: parallelize by sending attributes one by one // TODO: meanwhile, we can try to use the ThreadPool to execute it // separately // TODO: parallelize by sending in batch, i.e. split the attributes into // chunk instead of send the attribute one by one for (int i = 0; i < inst.numAttributes() - 1; i++) { int instAttIndex = modelAttIndexToInstanceAttIndex(i, inst); Integer obsIndex = i; String key = attributeContentEventKeys.get(obsIndex); if (key == null) { key = this.generateKey(i); attributeContentEventKeys.put(obsIndex, key); } AttributeContentEvent ace = new AttributeContentEvent.Builder( this.id, i, key) .attrValue(inst.value(instAttIndex)) .classValue((int) inst.classValue()) .weight(inst.weight()) .isNominal(inst.attribute(instAttIndex).isNominal()) .build(); if (this.attributeBatchContentEvent == null) { this.attributeBatchContentEvent = new AttributeBatchContentEvent[inst.numAttributes() - 1]; } if (this.attributeBatchContentEvent[i] == null) { this.attributeBatchContentEvent[i] = new AttributeBatchContentEvent.Builder( this.id, i, key) // .attrValue(inst.value(instAttIndex)) // .classValue((int) inst.classValue()) // .weight(inst.weight()] .isNominal(inst.attribute(instAttIndex).isNominal()) .build(); } this.attributeBatchContentEvent[i].add(ace); // proc.sendToAttributeStream(ace); } } @Override double[] getClassVotes(Instance inst, ModelAggregatorProcessor map) { return this.observedClassDistribution.getArrayCopy(); } double getWeightSeen() { return this.observedClassDistribution.sumOfValues(); } void setWeightSeenAtLastSplitEvaluation(double weight) { this.weightSeenAtLastSplitEvaluation = weight; } double getWeightSeenAtLastSplitEvaluation() { return this.weightSeenAtLastSplitEvaluation; } void requestDistributedSuggestions(long splitId, ModelAggregatorProcessor modelAggrProc) { this.isSplitting = true; this.suggestionCtr = 0; this.thrownAwayInstance = 0; ComputeContentEvent cce = new ComputeContentEvent(splitId, this.id, this.getObservedClassDistribution()); modelAggrProc.sendToControlStream(cce); } void addDistributedSuggestions(AttributeSplitSuggestion bestSuggestion, AttributeSplitSuggestion secondBestSuggestion) { // starts comparing from the best suggestion if (bestSuggestion != null) { if ((this.bestSuggestion == null) || (bestSuggestion.compareTo(this.bestSuggestion) > 0)) { this.secondBestSuggestion = this.bestSuggestion; this.bestSuggestion = bestSuggestion; if (secondBestSuggestion != null) { if ((this.secondBestSuggestion == null) || (secondBestSuggestion.compareTo(this.secondBestSuggestion) > 0)) { this.secondBestSuggestion = secondBestSuggestion; } } } else { if ((this.secondBestSuggestion == null) || (bestSuggestion.compareTo(this.secondBestSuggestion) > 0)) { this.secondBestSuggestion = bestSuggestion; } } } // TODO: optimize the code to use less memory this.suggestionCtr++; } boolean isSplitting() { return this.isSplitting; } void endSplitting() { this.isSplitting = false; logger.trace("wasted instance: {}", this.thrownAwayInstance); this.thrownAwayInstance = 0; this.bestSuggestion = null; this.secondBestSuggestion = null; } AttributeSplitSuggestion getDistributedBestSuggestion() { return this.bestSuggestion; } AttributeSplitSuggestion getDistributedSecondBestSuggestion() { return this.secondBestSuggestion; } boolean isAllSuggestionsCollected() { return (this.suggestionCtr == this.parallelismHint); } private static int modelAttIndexToInstanceAttIndex(int index, Instance inst) { return inst.classIndex() > index ? index : index + 1; } private String generateKey(int obsIndex) { final int prime = 31; int result = 1; result = prime * result + (int) (this.id ^ (this.id >>> 32)); result = prime * result + obsIndex; return Integer.toString(result); } }