/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.classifier.df.split;
import org.apache.mahout.classifier.df.data.Data;
import org.apache.mahout.classifier.df.data.Instance;
import java.util.Arrays;
/**
* Regression problem implementation of IgSplit.
* This class can be used when the criterion variable is the numerical attribute.
*/
public class RegressionSplit extends IgSplit {
/**
* Comparator for Instance sort
*/
private static class InstanceComparator implements java.util.Comparator<Instance> {
private final int attr;
InstanceComparator(int attr) {
this.attr = attr;
}
@Override
public int compare(Instance arg0, Instance arg1) {
return Double.compare(arg0.get(attr), arg1.get(attr));
}
}
@Override
public Split computeSplit(Data data, int attr) {
if (data.getDataset().isNumerical(attr)) {
return numericalSplit(data, attr);
} else {
return categoricalSplit(data, attr);
}
}
/**
* Computes the split for a CATEGORICAL attribute
*/
private static Split categoricalSplit(Data data, int attr) {
double[] sums = new double[data.getDataset().nbValues(attr)];
double[] sumSquared = new double[data.getDataset().nbValues(attr)];
double[] counts = new double[data.getDataset().nbValues(attr)];
double totalSum = 0;
double totalSumSquared = 0;
// sum and sum of squares
for (int i = 0; i < data.size(); i++) {
Instance instance = data.get(i);
int value = (int) instance.get(attr);
double label = data.getDataset().getLabel(instance);
double square = label * label;
sums[value] += label;
sumSquared[value] += square;
counts[value]++;
totalSum += label;
totalSumSquared += square;
}
// computes the variance
double totalVar = totalSumSquared - (totalSum * totalSum) / data.size();
double var = variance(sums, sumSquared, counts);
double ig = totalVar - var;
return new Split(attr, ig);
}
/**
* Computes the best split for a NUMERICAL attribute
*/
static Split numericalSplit(Data data, int attr) {
// Instance sort
Instance[] instances = new Instance[data.size()];
for (int i = 0; i < data.size(); i++) {
instances[i] = data.get(i);
}
Arrays.sort(instances, new InstanceComparator(attr));
// sum and sum of squares
double totalSum = 0.0;
double totalSumSquared = 0.0;
for (Instance instance : instances) {
double label = data.getDataset().getLabel(instance);
totalSum += label;
totalSumSquared += label * label;
}
double[] sums = new double[2];
double[] curSums = new double[2];
sums[1] = curSums[1] = totalSum;
double[] sumSquared = new double[2];
double[] curSumSquared = new double[2];
sumSquared[1] = curSumSquared[1] = totalSumSquared;
double[] counts = new double[2];
double[] curCounts = new double[2];
counts[1] = curCounts[1] = data.size();
// find the best split point
double curSplit = instances[0].get(attr);
double bestVal = Double.MAX_VALUE;
double split = Double.NaN;
for (Instance instance : instances) {
if (instance.get(attr) > curSplit) {
double curVal = variance(curSums, curSumSquared, curCounts);
if (curVal < bestVal) {
bestVal = curVal;
split = (instance.get(attr) + curSplit) / 2.0;
for (int j = 0; j < 2; j++) {
sums[j] = curSums[j];
sumSquared[j] = curSumSquared[j];
counts[j] = curCounts[j];
}
}
}
curSplit = instance.get(attr);
double label = data.getDataset().getLabel(instance);
double square = label * label;
curSums[0] += label;
curSumSquared[0] += square;
curCounts[0]++;
curSums[1] -= label;
curSumSquared[1] -= square;
curCounts[1]--;
}
// computes the variance
double totalVar = totalSumSquared - (totalSum * totalSum) / data.size();
double var = variance(sums, sumSquared, counts);
double ig = totalVar - var;
return new Split(attr, ig, split);
}
/**
* Computes the variance
*
* @param s
* data
* @param ss
* squared data
* @param dataSize
* numInstances
*/
private static double variance(double[] s, double[] ss, double[] dataSize) {
double var = 0;
for (int i = 0; i < s.length; i++) {
if (dataSize[i] > 0) {
var += ss[i] - ((s[i] * s[i]) / dataSize[i]);
}
}
return var;
}
}