/* Project Knowledge Discovery from Data Streams, FCT LIAAD-INESC TEC, * * Contact: jgama@fep.up.pt */ package org.apache.samoa.moa.classifiers.core.attributeclassobservers; /* * #%L * SAMOA * %% * Copyright (C) 2014 - 2015 Apache Software Foundation * %% * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * #L% */ import java.io.Serializable; import org.apache.samoa.moa.classifiers.core.AttributeSplitSuggestion; import org.apache.samoa.moa.classifiers.core.conditionaltests.NumericAttributeBinaryTest; import org.apache.samoa.moa.classifiers.core.splitcriteria.SplitCriterion; import org.apache.samoa.moa.core.DoubleVector; import org.apache.samoa.moa.core.ObjectRepository; import org.apache.samoa.moa.tasks.TaskMonitor; public class FIMTDDNumericAttributeClassObserver extends BinaryTreeNumericAttributeClassObserver implements NumericAttributeClassObserver { private static final long serialVersionUID = 1L; protected class Node implements Serializable { private static final long serialVersionUID = 1L; // The split point to use public double cut_point; // E-BST statistics public DoubleVector leftStatistics = new DoubleVector(); public DoubleVector rightStatistics = new DoubleVector(); // Child nodes public Node left; public Node right; public Node(double val, double label, double weight) { this.cut_point = val; this.leftStatistics.addToValue(0, 1); this.leftStatistics.addToValue(1, label); this.leftStatistics.addToValue(2, label * label); } /** * Insert a new value into the tree, updating both the sum of values and sum of squared values arrays */ public void insertValue(double val, double label, double weight) { // If the new value equals the value stored in a node, update // the left (<=) node information if (val == this.cut_point) { this.leftStatistics.addToValue(0, 1); this.leftStatistics.addToValue(1, label); this.leftStatistics.addToValue(2, label * label); } // If the new value is less than the value in a node, update the // left distribution and send the value down to the left child node. // If no left child exists, create one else if (val <= this.cut_point) { this.leftStatistics.addToValue(0, 1); this.leftStatistics.addToValue(1, label); this.leftStatistics.addToValue(2, label * label); if (this.left == null) { this.left = new Node(val, label, weight); } else { this.left.insertValue(val, label, weight); } } // If the new value is greater than the value in a node, update the // right (>) distribution and send the value down to the right child node. // If no right child exists, create one else { // val > cut_point this.rightStatistics.addToValue(0, 1); this.rightStatistics.addToValue(1, label); this.rightStatistics.addToValue(2, label * label); if (this.right == null) { this.right = new Node(val, label, weight); } else { this.right.insertValue(val, label, weight); } } } } // Root node of the E-BST structure for this attribute public Node root = null; // Global variables for use in the FindBestSplit algorithm double sumTotalLeft; double sumTotalRight; double sumSqTotalLeft; double sumSqTotalRight; double countRightTotal; double countLeftTotal; public void observeAttributeClass(double attVal, double classVal, double weight) { if (!Double.isNaN(attVal)) { if (this.root == null) { this.root = new Node(attVal, classVal, weight); } else { this.root.insertValue(attVal, classVal, weight); } } } @Override public double probabilityOfAttributeValueGivenClass(double attVal, int classVal) { // TODO: NaiveBayes broken until implemented return 0.0; } @Override public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion(SplitCriterion criterion, double[] preSplitDist, int attIndex, boolean binaryOnly) { // Initialise global variables sumTotalLeft = 0; sumTotalRight = preSplitDist[1]; sumSqTotalLeft = 0; sumSqTotalRight = preSplitDist[2]; countLeftTotal = 0; countRightTotal = preSplitDist[0]; return searchForBestSplitOption(this.root, null, criterion, attIndex); } /** * Implementation of the FindBestSplit algorithm from E.Ikonomovska et al. */ protected AttributeSplitSuggestion searchForBestSplitOption(Node currentNode, AttributeSplitSuggestion currentBestOption, SplitCriterion criterion, int attIndex) { // Return null if the current node is null or we have finished looking // through all the possible splits if (currentNode == null || countRightTotal == 0.0) { return currentBestOption; } if (currentNode.left != null) { currentBestOption = searchForBestSplitOption(currentNode.left, currentBestOption, criterion, attIndex); } sumTotalLeft += currentNode.leftStatistics.getValue(1); sumTotalRight -= currentNode.leftStatistics.getValue(1); sumSqTotalLeft += currentNode.leftStatistics.getValue(2); sumSqTotalRight -= currentNode.leftStatistics.getValue(2); countLeftTotal += currentNode.leftStatistics.getValue(0); countRightTotal -= currentNode.leftStatistics.getValue(0); double[][] postSplitDists = new double[][] { { countLeftTotal, sumTotalLeft, sumSqTotalLeft }, { countRightTotal, sumTotalRight, sumSqTotalRight } }; double[] preSplitDist = new double[] { (countLeftTotal + countRightTotal), (sumTotalLeft + sumTotalRight), (sumSqTotalLeft + sumSqTotalRight) }; double merit = criterion.getMeritOfSplit(preSplitDist, postSplitDists); if ((currentBestOption == null) || (merit > currentBestOption.merit)) { currentBestOption = new AttributeSplitSuggestion( new NumericAttributeBinaryTest(attIndex, currentNode.cut_point, true), postSplitDists, merit); } if (currentNode.right != null) { currentBestOption = searchForBestSplitOption(currentNode.right, currentBestOption, criterion, attIndex); } sumTotalLeft -= currentNode.leftStatistics.getValue(1); sumTotalRight += currentNode.leftStatistics.getValue(1); sumSqTotalLeft -= currentNode.leftStatistics.getValue(2); sumSqTotalRight += currentNode.leftStatistics.getValue(2); countLeftTotal -= currentNode.leftStatistics.getValue(0); countRightTotal += currentNode.leftStatistics.getValue(0); return currentBestOption; } /** * A method to remove all nodes in the E-BST in which it and all it's children represent 'bad' split points */ public void removeBadSplits(SplitCriterion criterion, double lastCheckRatio, double lastCheckSDR, double lastCheckE) { removeBadSplitNodes(criterion, this.root, lastCheckRatio, lastCheckSDR, lastCheckE); } /** * Recursive method that first checks all of a node's children before deciding if it is 'bad' and may be removed */ private boolean removeBadSplitNodes(SplitCriterion criterion, Node currentNode, double lastCheckRatio, double lastCheckSDR, double lastCheckE) { boolean isBad = false; if (currentNode == null) { return true; } if (currentNode.left != null) { isBad = removeBadSplitNodes(criterion, currentNode.left, lastCheckRatio, lastCheckSDR, lastCheckE); } if (currentNode.right != null && isBad) { isBad = removeBadSplitNodes(criterion, currentNode.left, lastCheckRatio, lastCheckSDR, lastCheckE); } if (isBad) { double[][] postSplitDists = new double[][] { { currentNode.leftStatistics.getValue(0), currentNode.leftStatistics.getValue(1), currentNode.leftStatistics.getValue(2) }, { currentNode.rightStatistics.getValue(0), currentNode.rightStatistics.getValue(1), currentNode.rightStatistics.getValue(2) } }; double[] preSplitDist = new double[] { (currentNode.leftStatistics.getValue(0) + currentNode.rightStatistics.getValue(0)), (currentNode.leftStatistics.getValue(1) + currentNode.rightStatistics.getValue(1)), (currentNode.leftStatistics.getValue(2) + currentNode.rightStatistics.getValue(2)) }; double merit = criterion.getMeritOfSplit(preSplitDist, postSplitDists); if ((merit / lastCheckSDR) < (lastCheckRatio - (2 * lastCheckE))) { currentNode = null; return true; } } return false; } @Override public void getDescription(StringBuilder sb, int indent) { // TODO Auto-generated method stub } @Override protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) { // TODO Auto-generated method stub } }