/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.df.builder;

import com.google.common.collect.Sets;
import org.apache.mahout.classifier.df.data.Data;
import org.apache.mahout.classifier.df.data.Dataset;
import org.apache.mahout.classifier.df.data.Instance;
import org.apache.mahout.classifier.df.data.conditions.Condition;
import org.apache.mahout.classifier.df.node.CategoricalNode;
import org.apache.mahout.classifier.df.node.Leaf;
import org.apache.mahout.classifier.df.node.Node;
import org.apache.mahout.classifier.df.node.NumericalNode;
import org.apache.mahout.classifier.df.split.IgSplit;
import org.apache.mahout.classifier.df.split.OptIgSplit;
import org.apache.mahout.classifier.df.split.RegressionSplit;
import org.apache.mahout.classifier.df.split.Split;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Collection;
import java.util.Random;

/**
 * Builds a classification tree or a regression tree.<br>
 * A classification tree is built when the label attribute is categorical.<br>
 * A regression tree is built when the label attribute is numerical.
 * <p>
 * The builder keeps mutable per-instance state ({@code selected}, {@code m}, {@code igSplit},
 * {@code fullSet}, {@code minVariance}) that is lazily initialised by the first call to
 * {@link #build(Random, Data)} and is kept on the instance afterwards.
 */
public class DecisionTreeBuilder implements TreeBuilder {

  private static final Logger log = LoggerFactory.getLogger(DecisionTreeBuilder.class);

  private static final int[] NO_ATTRIBUTES = new int[0];

  /** splits whose information gain is below this threshold are not performed */
  private static final double EPSILON = 1.0e-6;

  /**
   * indicates which CATEGORICAL attributes have already been selected in the parent nodes
   */
  private boolean[] selected;
  /**
   * number of attributes to select randomly at each node
   */
  private int m;
  /**
   * IgSplit implementation
   */
  private IgSplit igSplit;
  /**
   * tree is complemented
   */
  private boolean complemented = true;
  /**
   * minimum number for split
   */
  private double minSplitNum = 2.0;
  /**
   * minimum proportion of the total variance for split
   */
  private double minVarianceProportion = 1.0e-3;
  /**
   * full set data
   */
  private Data fullSet;
  /**
   * minimum variance for split; NaN until it is computed by the first regression build
   */
  private double minVariance = Double.NaN;

  public void setM(int m) {
    this.m = m;
  }

  public void setIgSplit(IgSplit igSplit) {
    this.igSplit = igSplit;
  }

  public void setComplemented(boolean complemented) {
    this.complemented = complemented;
  }

  public void setMinSplitNum(int minSplitNum) {
    this.minSplitNum = minSplitNum;
  }

  public void setMinVarianceProportion(double minVarianceProportion) {
    this.minVarianceProportion = minVarianceProportion;
  }

  /**
   * Recursively builds a tree for the given data.
   *
   * @param rng  random-numbers generator, used to pick candidate attributes and passed to
   *             {@code Data.majorityLabel}
   * @param data instances to build the (sub)tree from
   * @return the root node of the (sub)tree; a {@code Leaf(-1)} when {@code data} is empty
   */
  @Override
  public Node build(Random rng, Data data) {
    Dataset dataset = data.getDataset();

    if (selected == null) {
      selected = new boolean[dataset.nbAttributes()];
      selected[dataset.getLabelId()] = true; // never select the label
    }

    // a numerical label means regression, otherwise classification
    boolean regression = dataset.isNumerical(dataset.getLabelId());

    if (m == 0) {
      // set default m
      double e = dataset.nbAttributes() - 1;
      if (regression) {
        m = (int) Math.ceil(e / 3.0);
      } else {
        m = (int) Math.ceil(Math.sqrt(e));
      }
    }

    if (data.isEmpty()) {
      return new Leaf(-1);
    }

    // sum of the labels; only computed (and only used) for regression
    double sum = 0.0;
    if (regression) {
      // sum and sum squared of a label is computed
      double sumSquared = 0.0;
      for (int i = 0; i < data.size(); i++) {
        double label = dataset.getLabel(data.get(i));
        sum += label;
        sumSquared += label * label;
      }

      // computes the variance
      double variance = (sumSquared - (sum * sum) / data.size()) / data.size();

      // computes the minimum variance, once, from the first data set seen
      if (Double.isNaN(minVariance)) {
        minVariance = variance * minVarianceProportion;
        log.debug("minVariance:{}", minVariance);
      }

      // variance is compared with minimum variance
      if (variance < minVariance) {
        if (log.isDebugEnabled()) {
          log.debug("variance(" + variance + ") < minVariance(" + minVariance + ") Leaf("
              + (sum / data.size()) + ')');
        }
        return new Leaf(sum / data.size());
      }
    } else {
      if (isIdentical(data)) {
        return new Leaf(data.majorityLabel(rng));
      }
      if (data.identicalLabel()) {
        return new Leaf(dataset.getLabel(data.get(0)));
      }
    }

    // store full set data
    if (fullSet == null) {
      fullSet = data;
    }

    int[] attributes = randomAttributes(rng, selected, m);
    if (attributes == null || attributes.length == 0) {
      // we tried all the attributes and could not split the data anymore
      double label = leafLabel(rng, data, regression, sum);
      log.warn("attribute which can be selected is not found Leaf({})", label);
      return new Leaf(label);
    }

    if (igSplit == null) {
      if (regression) {
        igSplit = new RegressionSplit();
      } else {
        igSplit = new OptIgSplit();
      }
    }

    // find the best split
    Split best = null;
    for (int attr : attributes) {
      Split split = igSplit.computeSplit(data, attr);
      if (best == null || best.getIg() < split.getIg()) {
        best = split;
      }
    }

    // information gain is near to zero.
    if (best.getIg() < EPSILON) {
      double label = leafLabel(rng, data, regression, sum);
      log.debug("ig is near to zero Leaf({})", label);
      return new Leaf(label);
    }

    if (log.isDebugEnabled()) {
      log.debug("best split attr:" + best.getAttr() + ", split:" + best.getSplit() + ", ig:"
          + best.getIg());
    }

    boolean alreadySelected = selected[best.getAttr()];
    if (alreadySelected) {
      // attribute already selected
      log.warn("attribute {} already selected in a parent node", best.getAttr());
    }

    Node childNode;
    if (dataset.isNumerical(best.getAttr())) {
      Data loSubset = data.subset(Condition.lesser(best.getAttr(), best.getSplit()));
      Data hiSubset = data.subset(Condition.greaterOrEquals(best.getAttr(), best.getSplit()));

      // size of the subset is less than the minSplitNum.
      // This check must run BEFORE the selection state is modified: returning from here after
      // changing `selected` would leave the modified state visible to sibling and parent nodes.
      if (loSubset.size() < minSplitNum || hiSubset.size() < minSplitNum) {
        // branch is not split
        double label = leafLabel(rng, data, regression, sum);
        log.debug("branch is not split Leaf({})", label);
        return new Leaf(label);
      }

      boolean[] temp = null;
      if (loSubset.isEmpty() || hiSubset.isEmpty()) {
        // the selected attribute did not change the data, avoid using it in the child nodes
        // (only reachable when minSplitNum is not positive)
        selected[best.getAttr()] = true;
      } else {
        // the data changed, so we can unselect all previously selected NUMERICAL attributes
        temp = selected;
        selected = cloneCategoricalAttributes(dataset, selected);
      }

      Node loChild = build(rng, loSubset);
      Node hiChild = build(rng, hiSubset);

      // restore the selection state of the attributes
      if (temp != null) {
        selected = temp;
      } else {
        selected[best.getAttr()] = alreadySelected;
      }

      childNode = new NumericalNode(best.getAttr(), best.getSplit(), loChild, hiChild);
    } else { // CATEGORICAL attribute
      double[] values = data.values(best.getAttr());

      // tree is complemented: branch on every value found in the full set, and remember which
      // of them actually occur in the current data
      Collection<Double> subsetValues = null;
      if (complemented) {
        subsetValues = Sets.newHashSet();
        for (double value : values) {
          subsetValues.add(value);
        }
        values = fullSet.values(best.getAttr());
      }

      // count the subsets that are large enough; entries for values absent from the current
      // data stay null and are replaced by complemented leaves below
      int cnt = 0;
      Data[] subsets = new Data[values.length];
      for (int index = 0; index < values.length; index++) {
        if (complemented && !subsetValues.contains(values[index])) {
          continue;
        }
        subsets[index] = data.subset(Condition.equals(best.getAttr(), values[index]));
        if (subsets[index].size() >= minSplitNum) {
          cnt++;
        }
      }

      // size of the subset is less than the minSplitNum
      if (cnt < 2) {
        // branch is not split
        double label = leafLabel(rng, data, regression, sum);
        log.debug("branch is not split Leaf({})", label);
        return new Leaf(label);
      }

      selected[best.getAttr()] = true;

      Node[] children = new Node[values.length];
      for (int index = 0; index < values.length; index++) {
        if (complemented && (subsetValues == null || !subsetValues.contains(values[index]))) {
          // tree is complemented
          double label = leafLabel(rng, data, regression, sum);
          log.debug("complemented Leaf({})", label);
          children[index] = new Leaf(label);
          continue;
        }
        children[index] = build(rng, subsets[index]);
      }

      selected[best.getAttr()] = alreadySelected;

      childNode = new CategoricalNode(best.getAttr(), values, children);
    }

    return childNode;
  }

  /**
   * Computes the label of a leaf that stands for the whole of {@code data}.
   *
   * @param rng        random-numbers generator passed to {@code Data.majorityLabel}
   * @param data       instances the leaf stands for
   * @param regression true when the label attribute is numerical
   * @param sum        sum of the labels of {@code data}; only used when {@code regression} is true
   * @return the mean label for regression, the majority label for classification
   */
  private static double leafLabel(Random rng, Data data, boolean regression, double sum) {
    if (regression) {
      return sum / data.size();
    }
    return data.majorityLabel(rng);
  }

  /**
   * checks if all the vectors have identical attribute values. Ignore selected attributes.
   *
   * @return true is all the vectors are identical or the data is empty<br>
   *         false otherwise
   */
  private boolean isIdentical(Data data) {
    if (data.isEmpty()) {
      return true;
    }

    // every instance is compared against the first one
    Instance instance = data.get(0);
    for (int attr = 0; attr < selected.length; attr++) {
      if (selected[attr]) {
        continue;
      }

      for (int index = 1; index < data.size(); index++) {
        if (data.get(index).get(attr) != instance.get(attr)) {
          return false;
        }
      }
    }

    return true;
  }

  /**
   * Make a copy of the selection state of the attributes, unselect all numerical attributes.
   * The label attribute is always marked as selected in the copy.
   *
   * @param dataset  dataset used to tell numerical attributes apart
   * @param selected selection state to clone
   * @return cloned selection state
   */
  private static boolean[] cloneCategoricalAttributes(Dataset dataset, boolean[] selected) {
    boolean[] cloned = new boolean[selected.length];

    for (int i = 0; i < selected.length; i++) {
      cloned[i] = !dataset.isNumerical(i) && selected[i];
    }
    cloned[dataset.getLabelId()] = true;

    return cloned;
  }

  /**
   * Randomly selects up to m attributes to consider for split. Only attributes whose entry in
   * {@code selected} is false are candidates.
   *
   * @param rng      random-numbers generator
   * @param selected attributes' state (selected or not); modified temporarily while drawing but
   *                 restored before returning
   * @param m        number of attributes to choose
   * @return indices of the chosen attributes: all non selected attributes when there are at most
   *         m of them, otherwise m distinct randomly drawn ones; an empty array if all attributes
   *         have already been selected
   */
  private static int[] randomAttributes(Random rng, boolean[] selected, int m) {
    int nbNonSelected = 0; // number of non selected attributes
    for (boolean sel : selected) {
      if (!sel) {
        nbNonSelected++;
      }
    }

    if (nbNonSelected == 0) {
      log.warn("All attributes are selected !");
      return NO_ATTRIBUTES;
    }

    int[] result;
    if (nbNonSelected <= m) {
      // return all non selected attributes
      result = new int[nbNonSelected];
      int index = 0;
      for (int attr = 0; attr < selected.length; attr++) {
        if (!selected[attr]) {
          result[index++] = attr;
        }
      }
    } else {
      result = new int[m];
      for (int index = 0; index < m; index++) {
        // randomly choose a "non selected" attribute
        int rind;
        do {
          rind = rng.nextInt(selected.length);
        } while (selected[rind]);

        result[index] = rind;
        selected[rind] = true; // temporarily set the chosen attribute to be selected
      }

      // the chosen attributes are not yet selected
      for (int attr : result) {
        selected[attr] = false;
      }
    }

    return result;
  }
}