package com.yahoo.labs.samoa.moa.classifiers.core.attributeclassobservers; /* * #%L * SAMOA * %% * Copyright (C) 2007 University of Waikato, Hamilton, New Zealand * %% * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * #L% */ import com.yahoo.labs.samoa.moa.classifiers.core.AttributeSplitSuggestion; import com.yahoo.labs.samoa.moa.classifiers.core.conditionaltests.NominalAttributeBinaryTest; import com.yahoo.labs.samoa.moa.classifiers.core.conditionaltests.NominalAttributeMultiwayTest; import com.yahoo.labs.samoa.moa.classifiers.core.splitcriteria.SplitCriterion; import com.yahoo.labs.samoa.moa.core.ObjectRepository; import com.yahoo.labs.samoa.moa.tasks.TaskMonitor; import com.yahoo.labs.samoa.moa.core.Utils; import com.yahoo.labs.samoa.moa.core.AutoExpandVector; import com.yahoo.labs.samoa.moa.core.DoubleVector; import com.yahoo.labs.samoa.moa.options.AbstractOptionHandler; /** * Class for observing the class data distribution for a nominal attribute. * This observer monitors the class distribution of a given attribute. * Used in naive Bayes and decision trees to monitor data statistics on leaves. * * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz) * @version $Revision: 7 $ */ public class NominalAttributeClassObserver extends AbstractOptionHandler implements DiscreteAttributeClassObserver { private static final long serialVersionUID = 1L; protected double totalWeightObserved = 0.0; protected double missingWeightObserved = 0.0; public AutoExpandVector<DoubleVector> attValDistPerClass = new AutoExpandVector<>(); @Override public void observeAttributeClass(double attVal, int classVal, double weight) { if (Utils.isMissingValue(attVal)) { this.missingWeightObserved += weight; } else { int attValInt = (int) attVal; DoubleVector valDist = this.attValDistPerClass.get(classVal); if (valDist == null) { valDist = new DoubleVector(); this.attValDistPerClass.set(classVal, valDist); } valDist.addToValue(attValInt, weight); } this.totalWeightObserved += weight; } @Override public double probabilityOfAttributeValueGivenClass(double attVal, int classVal) { DoubleVector obs = this.attValDistPerClass.get(classVal); return obs != null ? (obs.getValue((int) attVal) + 1.0) / (obs.sumOfValues() + obs.numValues()) : 0.0; } public double totalWeightOfClassObservations() { return this.totalWeightObserved; } public double weightOfObservedMissingValues() { return this.missingWeightObserved; } @Override public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion( SplitCriterion criterion, double[] preSplitDist, int attIndex, boolean binaryOnly) { AttributeSplitSuggestion bestSuggestion = null; int maxAttValsObserved = getMaxAttValsObserved(); if (!binaryOnly) { double[][] postSplitDists = getClassDistsResultingFromMultiwaySplit(maxAttValsObserved); double merit = criterion.getMeritOfSplit(preSplitDist, postSplitDists); bestSuggestion = new AttributeSplitSuggestion( new NominalAttributeMultiwayTest(attIndex), postSplitDists, merit); } for (int valIndex = 0; valIndex < maxAttValsObserved; valIndex++) { double[][] postSplitDists = getClassDistsResultingFromBinarySplit(valIndex); double merit = criterion.getMeritOfSplit(preSplitDist, postSplitDists); if ((bestSuggestion == null) || (merit > bestSuggestion.merit)) { bestSuggestion = new AttributeSplitSuggestion( new NominalAttributeBinaryTest(attIndex, valIndex), postSplitDists, merit); } } return bestSuggestion; } public int getMaxAttValsObserved() { int maxAttValsObserved = 0; for (DoubleVector attValDist : this.attValDistPerClass) { if ((attValDist != null) && (attValDist.numValues() > maxAttValsObserved)) { maxAttValsObserved = attValDist.numValues(); } } return maxAttValsObserved; } public double[][] getClassDistsResultingFromMultiwaySplit( int maxAttValsObserved) { DoubleVector[] resultingDists = new DoubleVector[maxAttValsObserved]; for (int i = 0; i < resultingDists.length; i++) { resultingDists[i] = new DoubleVector(); } for (int i = 0; i < this.attValDistPerClass.size(); i++) { DoubleVector attValDist = this.attValDistPerClass.get(i); if (attValDist != null) { for (int j = 0; j < attValDist.numValues(); j++) { resultingDists[j].addToValue(i, attValDist.getValue(j)); } } } double[][] distributions = new double[maxAttValsObserved][]; for (int i = 0; i < distributions.length; i++) { distributions[i] = resultingDists[i].getArrayRef(); } return distributions; } public double[][] getClassDistsResultingFromBinarySplit(int valIndex) { DoubleVector equalsDist = new DoubleVector(); DoubleVector notEqualDist = new DoubleVector(); for (int i = 0; i < this.attValDistPerClass.size(); i++) { DoubleVector attValDist = this.attValDistPerClass.get(i); if (attValDist != null) { for (int j = 0; j < attValDist.numValues(); j++) { if (j == valIndex) { equalsDist.addToValue(i, attValDist.getValue(j)); } else { notEqualDist.addToValue(i, attValDist.getValue(j)); } } } } return new double[][]{equalsDist.getArrayRef(), notEqualDist.getArrayRef()}; } @Override public void getDescription(StringBuilder sb, int indent) { // TODO Auto-generated method stub } @Override protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) { // TODO Auto-generated method stub } @Override public void observeAttributeTarget(double attVal, double target) { throw new UnsupportedOperationException("Not supported yet."); } }