/**
* Copyright (C) 2001-2017 by RapidMiner and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapidminer.com
*
* This program is free software: you can redistribute it and/or modify it under the terms of the
* GNU Affero General Public License as published by the Free Software Foundation, either version 3
* of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without
* even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License along with this program.
* If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.learner.tree.criterions;
import com.rapidminer.example.Attribute;
import com.rapidminer.operator.learner.tree.ColumnExampleTable;
/**
* This class represents the weight distribution of a sorted attribute column on two sides of a
* split while going along the column. The left label weights contain the weighted label values to
* the left of a split point and the right label values contain the ones to the right of the split
* point. At the start the left label weights are all zero and the right label weights are maximal.
* At each step from left to right the left and right label weights are updated. If there are
* missing attribute values, they are counted separately.
*
* @author Gisa Schaefer
*
*/
public class WeightDistribution {

    /** the weighted total occurrences of each label value */
    private double[] totalLabelWeights;

    /** the weighted occurrences of each label value to the left of the split point */
    private final double[] leftLabelWeights;

    /** the weighted occurrences of each label value to the right of the split point */
    private final double[] rightLabelWeights;

    /** the weighted occurrences of each label value among the missing values */
    private double[] missingsLabelWeights;

    /** the sum of all leftLabelWeights */
    private double leftWeight;

    /** the sum of all rightLabelWeights */
    private double rightWeight;

    /** the sum of all totalLabelWeights */
    private final double totalWeight;

    /** the sum of all missingsLabelWeights */
    private final double missingsWeight;

    /** indicates whether there are missing attribute values in the example selection */
    private boolean hasMissings = false;

    /**
     * Initializes the counting arrays with the start distribution: nothing lies to the left of the
     * split point, and everything that is not a missing value lies to the right of it.
     *
     * @param columnTable
     *            the table providing the label, weight and numerical attribute columns
     * @param selection
     *            the row indices of the examples to take into account
     * @param attributeNumber
     *            the number of the numerical attribute whose column is checked for missing (NaN)
     *            values
     */
    public WeightDistribution(ColumnExampleTable columnTable, int[] selection, int attributeNumber) {
        calculateLabelWeights(columnTable, selection, attributeNumber);

        leftLabelWeights = new double[totalLabelWeights.length];
        leftWeight = 0;
        totalWeight = getTotalWeight(totalLabelWeights);

        if (hasMissings) {
            // examples with a missing attribute value belong to neither side of the split
            missingsWeight = getTotalWeight(missingsLabelWeights);
            rightWeight = totalWeight - missingsWeight;
            rightLabelWeights = arrayDifference(totalLabelWeights, missingsLabelWeights);
        } else {
            missingsWeight = 0;
            rightWeight = totalWeight;
            // independent copy, since the right weights are modified by increment(int, double)
            rightLabelWeights = totalLabelWeights.clone();
        }
    }

    /**
     * Calculates the start distributions: fills {@link #totalLabelWeights} and
     * {@link #missingsLabelWeights} and sets {@link #hasMissings} if at least one selected example
     * has a NaN entry in the attribute column.
     */
    private void calculateLabelWeights(ColumnExampleTable columnTable, int[] selection, int attributeNumber) {
        Attribute label = columnTable.getLabel();
        int[] labelColumn = columnTable.getLabelColumn();
        Attribute weightAttribute = columnTable.getWeight();
        double[] weightColumn = columnTable.getWeightColumn();

        totalLabelWeights = new double[label.getMapping().size()];
        missingsLabelWeights = new double[totalLabelWeights.length];

        if (selection.length == 0) {
            // nothing to count; the attribute column is never looked up for an empty selection
            return;
        }

        // loop-invariant: resolve the attribute column once instead of once per selected row
        double[] attributeColumn = columnTable.getNumericalAttributeColumn(attributeNumber);
        boolean weighted = weightAttribute != null;

        for (int row : selection) {
            int labelIndex = labelColumn[row];
            // without a weight attribute every example counts with weight 1
            double weight = weighted ? weightColumn[row] : 1.0d;

            totalLabelWeights[labelIndex] += weight;
            if (Double.isNaN(attributeColumn[row])) {
                hasMissings = true;
                missingsLabelWeights[labelIndex] += weight;
            }
        }
    }

    /**
     * Increments the left label weights at the given position by the given weight and decrements
     * the right label weights. Updates the sum of all left and right label weights respectively.
     *
     * @param position
     *            the index into the label weight arrays, i.e. the label value whose weight is
     *            moved
     * @param weight
     *            the weight that is moved from the right side to the left side
     */
    public void increment(int position, double weight) {
        leftLabelWeights[position] += weight;
        rightLabelWeights[position] -= weight;
        leftWeight += weight;
        rightWeight -= weight;
    }

    /**
     * @return the sum of the weighted label value occurrences to the left of the split point.
     */
    public double getLeftWeigth() {
        return leftWeight;
    }

    /**
     * @return the sum of the weighted label value occurrences to the right of the split point.
     */
    public double getRightWeigth() {
        return rightWeight;
    }

    /**
     * @return the total sum of the weighted label value occurrences.
     */
    public double getTotalWeigth() {
        return totalWeight;
    }

    /**
     * @return the sum of the weighted label value occurrences at missing values of the attribute.
     */
    public double getMissingsWeigth() {
        return missingsWeight;
    }

    /**
     * @return the weighted occurrences of each label value to the left of the split point (the
     *         internal array, not a copy)
     */
    public double[] getLeftLabelWeigths() {
        return leftLabelWeights;
    }

    /**
     * @return the weighted occurrences of each label value to the right of the split point (the
     *         internal array, not a copy)
     */
    public double[] getRightLabelWeigths() {
        return rightLabelWeights;
    }

    /**
     * @return the weighted total occurrences of each label value (the internal array, not a copy)
     */
    public double[] getTotalLabelWeigths() {
        return totalLabelWeights;
    }

    /**
     * @return the weighted occurrences of each label value among the missing values (the internal
     *         array, not a copy)
     */
    public double[] getMissingsLabelWeigths() {
        return missingsLabelWeights;
    }

    /**
     * @return <code>true</code> if the attribute has missing values among the current selection
     */
    public boolean hasMissingValues() {
        return hasMissings;
    }

    /** Returns the sum of the given weights. */
    private double getTotalWeight(double[] weights) {
        double sum = 0.0d;
        for (double w : weights) {
            sum += w;
        }
        return sum;
    }

    /**
     * Creates an array containing the differences of the entries of the given arrays.
     *
     * @param array1
     *            the minuend entries; determines the length of the result
     * @param array2
     *            the subtrahend entries; must be at least as long as <code>array1</code>
     * @return a new array whose i-th entry is <code>array1[i] - array2[i]</code>
     */
    private double[] arrayDifference(double[] array1, double[] array2) {
        double[] difference = new double[array1.length];
        for (int i = 0; i < array1.length; i++) {
            difference[i] = array1[i] - array2[i];
        }
        return difference;
    }
}