/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.df;

import java.util.List;
import java.util.Random;

import org.apache.mahout.classifier.df.builder.DecisionTreeBuilder;
import org.apache.mahout.classifier.df.data.Data;
import org.apache.mahout.classifier.df.data.DataLoader;
import org.apache.mahout.classifier.df.data.Dataset;
import org.apache.mahout.classifier.df.data.DescriptorException;
import org.apache.mahout.classifier.df.data.Instance;
import org.apache.mahout.classifier.df.node.Node;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.RandomUtils;
import org.junit.Test;

import com.google.common.collect.Lists;

/**
 * Tests {@link DecisionForest} on small hand-built datasets: single-instance
 * classification, batch classification (one prediction per tree), and
 * regression with leave-one-partition-out forests.
 */
public final class DecisionForestTest extends MahoutTestCase {

  /**
   * Weather dataset: outlook (categorical), temperature (numerical),
   * humidity (numerical), windy (categorical), play (label).
   */
  private static final String[] TRAIN_DATA = {"sunny,85,85,FALSE,no", "sunny,80,90,TRUE,no",
      "overcast,83,86,FALSE,yes", "rainy,70,96,FALSE,yes", "rainy,68,80,FALSE,yes",
      "rainy,65,70,TRUE,no", "overcast,64,65,TRUE,yes", "sunny,72,95,FALSE,no",
      "sunny,69,70,FALSE,yes", "rainy,75,80,FALSE,yes", "sunny,75,70,TRUE,yes",
      "overcast,72,90,TRUE,yes", "overcast,81,75,FALSE,yes", "rainy,71,91,TRUE,no"};

  /** Unlabeled instances ('-' in the label position) used for classification. */
  private static final String[] TEST_DATA = {"rainy,70,96,TRUE,-", "overcast,64,65,TRUE,-",
      "sunny,75,90,TRUE,-",};

  private Random rng;

  @Override
  public void setUp() throws Exception {
    super.setUp();
    rng = RandomUtils.getRandom();
  }

  /**
   * Partitions the weather data on the first (categorical) attribute:
   * instances encoding to 0.0 go into the first partition, all others into
   * the second. Returns three {@code Data} partitions over a shared dataset.
   *
   * NOTE(review): the third partition is never populated, so the forest built
   * from these partitions contains one tree trained on empty data — this
   * appears deliberate (it exercises tie/NaN prediction cases below), but
   * confirm before "fixing" it.
   *
   * @throws DescriptorException if the attribute descriptor cannot be parsed
   */
  private static Data[] generateTrainingDataA() throws DescriptorException {
    // Dataset descriptor: Categorical, Numerical, Numerical, Categorical, Label
    Dataset dataset = DataLoader.generateDataset("C N N C L", false, TRAIN_DATA);

    // Training data
    Data data = DataLoader.loadData(dataset, TRAIN_DATA);

    // A typed list-of-lists avoids the unchecked generic-array creation
    // (new List[3]) that required a @SuppressWarnings("unchecked").
    List<List<Instance>> instances = Lists.newArrayList();
    for (int i = 0; i < 3; i++) {
      instances.add(Lists.<Instance>newArrayList());
    }
    for (int i = 0; i < data.size(); i++) {
      if (data.get(i).get(0) == 0.0d) {
        instances.get(0).add(data.get(i));
      } else {
        instances.get(1).add(data.get(i));
      }
    }

    Data[] datas = new Data[instances.size()];
    for (int i = 0; i < datas.length; i++) {
      datas[i] = new Data(dataset, instances.get(i));
    }
    return datas;
  }

  /**
   * Builds three regression datasets over a shared dataset descriptor
   * ("C N L", regression = true): a three-class mixture of 20 instances,
   * a two-class mixture of 20 instances, and 10 single-class instances.
   *
   * @throws DescriptorException if the attribute descriptor cannot be parsed
   */
  private static Data[] generateTrainingDataB() throws DescriptorException {
    // Training data: cycle A/B/C rows with linearly varying attribute/target values
    String[] trainData = new String[20];
    for (int i = 0; i < trainData.length; i++) {
      if (i % 3 == 0) {
        trainData[i] = "A," + (40 - i) + ',' + (i + 20);
      } else if (i % 3 == 1) {
        trainData[i] = "B," + (i + 20) + ',' + (40 - i);
      } else {
        trainData[i] = "C," + (i + 20) + ',' + (i + 20);
      }
    }
    // Dataset descriptor: Categorical, Numerical, Label (regression)
    Dataset dataset = DataLoader.generateDataset("C N L", true, trainData);
    Data[] datas = new Data[3];
    datas[0] = DataLoader.loadData(dataset, trainData);

    // Training data: alternate A/B rows
    trainData = new String[20];
    for (int i = 0; i < trainData.length; i++) {
      if (i % 2 == 0) {
        trainData[i] = "A," + (50 - i) + ',' + (i + 10);
      } else {
        trainData[i] = "B," + (i + 10) + ',' + (50 - i);
      }
    }
    datas[1] = DataLoader.loadData(dataset, trainData);

    // Training data: all A rows
    trainData = new String[10];
    for (int i = 0; i < trainData.length; i++) {
      trainData[i] = "A," + (40 - i) + ',' + (i + 20);
    }
    datas[2] = DataLoader.loadData(dataset, trainData);

    return datas;
  }

  /**
   * Builds one unpruned tree per partition (m = all non-label attributes,
   * no minimum split, non-complemented) and bundles them into a forest.
   */
  private DecisionForest buildForest(Data[] datas) {
    List<Node> trees = Lists.newArrayList();
    for (Data data : datas) {
      // build tree
      DecisionTreeBuilder builder = new DecisionTreeBuilder();
      builder.setM(data.getDataset().nbAttributes() - 1);
      builder.setMinSplitNum(0);
      builder.setComplemented(false);
      trees.add(builder.build(rng, data));
    }
    return new DecisionForest(trees);
  }

  @Test
  public void testClassify() throws DescriptorException {
    // Training data
    Data[] datas = generateTrainingDataA();
    // Build Forest
    DecisionForest forest = buildForest(datas);
    // Test data
    Dataset dataset = datas[0].getDataset();
    Data testData = DataLoader.loadData(dataset, TEST_DATA);

    double noValue = dataset.valueOf(4, "no");

    assertEquals(noValue, forest.classify(testData.getDataset(), rng, testData.get(0)), EPSILON);
    // This one is tie-broken -- 1 is OK too
    //assertEquals(dataset.valueOf(4, "yes"),
    //    forest.classify(testData.getDataset(), rng, testData.get(1)), EPSILON);
    assertEquals(noValue, forest.classify(testData.getDataset(), rng, testData.get(2)), EPSILON);
  }

  @Test
  public void testClassifyData() throws DescriptorException {
    // Training data
    Data[] datas = generateTrainingDataA();
    // Build Forest
    DecisionForest forest = buildForest(datas);
    // Test data
    Dataset dataset = datas[0].getDataset();
    Data testData = DataLoader.loadData(dataset, TEST_DATA);

    // One prediction per tree; the tree built on the empty third partition
    // yields NaN for every instance.
    double[][] predictions = new double[testData.size()][];
    forest.classify(testData, predictions);

    double noValue = dataset.valueOf(4, "no");
    double yesValue = dataset.valueOf(4, "yes");
    assertArrayEquals(new double[][]{{noValue, Double.NaN, Double.NaN},
        {noValue, yesValue, Double.NaN}, {noValue, noValue, Double.NaN}}, predictions);
  }

  @Test
  public void testRegression() throws DescriptorException {
    Data[] datas = generateTrainingDataB();
    DecisionForest[] forests = new DecisionForest[datas.length];
    // Leave-one-out: forests[i] is trained on every partition except datas[i].
    for (int i = 0; i < datas.length; i++) {
      Data[] subDatas = new Data[datas.length - 1];
      int k = 0;
      for (int j = 0; j < datas.length; j++) {
        if (j != i) {
          subDatas[k] = datas[j];
          k++;
        }
      }
      forests[i] = buildForest(subDatas);
    }

    // Per-tree regression predictions (NaN where a tree cannot predict).
    double[][] predictions = new double[datas[0].size()][];
    forests[0].classify(datas[0], predictions);
    assertArrayEquals(new double[]{20.0, 20.0}, predictions[0], EPSILON);
    assertArrayEquals(new double[]{39.0, 29.0}, predictions[1], EPSILON);
    assertArrayEquals(new double[]{Double.NaN, 29.0}, predictions[2], EPSILON);
    assertArrayEquals(new double[]{Double.NaN, 23.0}, predictions[17], EPSILON);

    predictions = new double[datas[1].size()][];
    forests[1].classify(datas[1], predictions);
    assertArrayEquals(new double[]{30.0, 29.0}, predictions[19], EPSILON);

    predictions = new double[datas[2].size()][];
    forests[2].classify(datas[2], predictions);
    assertArrayEquals(new double[]{29.0, 28.0}, predictions[9], EPSILON);

    // Forest-level (averaged) regression on single instances.
    assertEquals(20.0, forests[0].classify(datas[0].getDataset(), rng, datas[0].get(0)), EPSILON);
    assertEquals(34.0, forests[0].classify(datas[0].getDataset(), rng, datas[0].get(1)), EPSILON);
    assertEquals(29.0, forests[0].classify(datas[0].getDataset(), rng, datas[0].get(2)), EPSILON);
  }
}