/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.classifier.df.split;
import org.apache.commons.lang.ArrayUtils;
import org.apache.mahout.classifier.df.data.Data;
import org.apache.mahout.classifier.df.data.DataUtils;
import org.apache.mahout.classifier.df.data.Dataset;
import org.apache.mahout.classifier.df.data.Instance;
import java.util.Arrays;
/**
 * Optimized implementation of IgSplit.<br>
 * This class can be used when the criterion variable (the label) is categorical: labels are cast
 * to {@code int} and used directly as indices into the per-label counting arrays.
 * <p>
 * Numerical splits keep their working arrays in instance fields, so concurrent calls to
 * {@link #computeSplit(Data, int)} on the same instance would overwrite each other's counts.
 */
public class OptIgSplit extends IgSplit {

  /** counts[v][l] = number of instances whose attribute value is values[v] and whose label is l. */
  private int[][] counts;

  /**
   * Per-label instance counts. During {@link #numericalSplit(Data, int)} it is progressively
   * decremented so that it holds the counts of the instances on the "&gt;= threshold" side.
   */
  private int[] countAll;

  /** Per-label counts of the instances on the "&lt; threshold" side of the current candidate. */
  private int[] countLess;

  /**
   * Computes the split of {@code data} on attribute {@code attr}, dispatching on whether the
   * dataset declares the attribute as numerical.
   *
   * @param data instances to split
   * @param attr index of the attribute to split on
   * @return the computed split
   */
  @Override
  public Split computeSplit(Data data, int attr) {
    if (data.getDataset().isNumerical(attr)) {
      return numericalSplit(data, attr);
    } else {
      return categoricalSplit(data, attr);
    }
  }

  /**
   * Computes the split for a CATEGORICAL attribute.
   * The information gain is H(Y) - H(Y|X), where X ranges over the values returned by
   * {@code data.values(attr)}.
   */
  private static Split categoricalSplit(Data data, int attr) {
    double[] values = data.values(attr);
    Dataset dataset = data.getDataset();
    int numLabels = dataset.nblabels();
    int[][] counts = new int[values.length][numLabels];
    int[] countAll = new int[numLabels];

    // compute frequencies
    // (values is not sorted here, so the lookup stays a linear scan)
    int dataSize = data.size();
    for (int index = 0; index < dataSize; index++) {
      Instance instance = data.get(index);
      int label = (int) dataset.getLabel(instance);
      counts[ArrayUtils.indexOf(values, instance.get(attr))][label]++;
      countAll[label]++;
    }

    double hy = entropy(countAll, dataSize); // H(Y)
    double hyx = 0.0; // H(Y|X)
    double invDataSize = 1.0 / dataSize;

    // weight the entropy of each value's subset by the fraction of instances it holds
    for (int[] valueCounts : counts) {
      int subsetSize = DataUtils.sum(valueCounts);
      hyx += subsetSize * invDataSize * entropy(valueCounts, subsetSize);
    }

    double ig = hy - hyx;
    return new Split(attr, ig);
  }

  /**
   * Return the sorted list of distinct values for the given attribute.
   * The array returned by {@code data.values(attr)} is sorted in place.
   */
  private static double[] sortedValues(Data data, int attr) {
    double[] values = data.values(attr);
    Arrays.sort(values);
    return values;
  }

  /**
   * Instantiates the counting arrays (all zeros): one row of {@code counts} per entry of
   * {@code values}, and one slot per label in every per-label array.
   */
  void initCounts(Data data, double[] values) {
    int numLabels = data.getDataset().nblabels();
    counts = new int[values.length][numLabels];
    countAll = new int[numLabels];
    countLess = new int[numLabels];
  }

  /**
   * Fills {@code counts} and {@code countAll} with the label frequencies of {@code data}.
   * {@link #initCounts(Data, double[])} must have been called first with the same
   * {@code values}.
   *
   * @param data   instances to count
   * @param attr   index of the attribute whose value selects the row of {@code counts}
   * @param values attribute values; expected to be sorted in ascending order, which lets the
   *               row lookup use a binary search instead of a linear scan per instance
   */
  void computeFrequencies(Data data, int attr, double[] values) {
    Dataset dataset = data.getDataset();
    int dataSize = data.size();

    for (int index = 0; index < dataSize; index++) {
      Instance instance = data.get(index);
      int label = (int) dataset.getLabel(instance);
      counts[indexOfValue(values, instance.get(attr))][label]++;
      countAll[label]++;
    }
  }

  /**
   * Finds the position of {@code value} in {@code values}.
   * For an array sorted in ascending order this returns the same index as
   * {@code ArrayUtils.indexOf(values, value)}, but in logarithmic time for the common case.
   *
   * @return the index of the first element that is {@code ==} to {@code value}, or -1 if none
   */
  private static int indexOfValue(double[] values, double value) {
    int pos = Arrays.binarySearch(values, value);
    if (pos < 0 || values[pos] != value) {
      // binarySearch orders -0.0 before 0.0 and treats NaN as equal to itself, whereas the
      // linear scan compares with ==; defer to the scan whenever the two could disagree
      return ArrayUtils.indexOf(values, value);
    }
    // step back to the first element of a run of ==-equal values (e.g. -0.0 followed by 0.0)
    while (pos > 0 && values[pos - 1] == value) {
      pos--;
    }
    return pos;
  }

  /**
   * Computes the best split for a NUMERICAL attribute.
   * Every distinct value is tried as a threshold separating the instances with a smaller value
   * from the others; the threshold with the highest information gain wins, the smallest such
   * threshold being kept on ties.
   *
   * @throws IllegalStateException if no candidate threshold was evaluated
   */
  Split numericalSplit(Data data, int attr) {
    double[] values = sortedValues(data, attr);

    initCounts(data, values);
    computeFrequencies(data, attr, values);

    int dataSize = data.size();
    double hy = entropy(countAll, dataSize);
    double invDataSize = 1.0 / dataSize;

    int best = -1;
    double bestIg = -1.0;

    // try each possible split value; countLess starts out empty, so the first candidate
    // leaves every instance on the ">=" side
    for (int index = 0; index < values.length; index++) {
      double ig = hy;

      // instance with attribute value < values[index]
      int lessSize = DataUtils.sum(countLess);
      ig -= lessSize * invDataSize * entropy(countLess, lessSize);

      // instance with attribute value >= values[index]
      int restSize = DataUtils.sum(countAll);
      ig -= restSize * invDataSize * entropy(countAll, restSize);

      if (ig > bestIg) {
        bestIg = ig;
        best = index;
      }

      // move the instances equal to values[index] to the "<" side for the next candidate
      // (presumably element-wise add / subtract; see DataUtils)
      DataUtils.add(countLess, counts[index]);
      DataUtils.dec(countAll, counts[index]);
    }

    if (best == -1) {
      throw new IllegalStateException("no best split found !");
    }
    return new Split(attr, bestIg, values[best]);
  }

  /**
   * Computes the Entropy, in bits.
   *
   * @param counts   counts[i] = numInstances with label i
   * @param dataSize numInstances
   * @return the entropy of the label distribution, or 0.0 when {@code dataSize} is 0
   */
  private static double entropy(int[] counts, int dataSize) {
    if (dataSize == 0) {
      return 0.0;
    }

    double entropy = 0.0;
    double invDataSize = 1.0 / dataSize;

    for (int count : counts) {
      if (count == 0) {
        continue; // otherwise we get a NaN
      }
      double p = count * invDataSize;
      entropy += -p * Math.log(p) / LOG2;
    }

    return entropy;
  }
}