/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package hivemall.smile.classification; import static org.junit.Assert.assertEquals; import hivemall.smile.ModelType; import hivemall.smile.classification.DecisionTree.Node; import hivemall.smile.data.Attribute; import hivemall.smile.tools.TreePredictUDF; import hivemall.smile.utils.SmileExtUtils; import hivemall.smile.vm.StackMachine; import hivemall.utils.lang.ArrayUtils; import java.io.BufferedInputStream; import java.io.IOException; import java.io.InputStream; import java.net.URL; import java.text.ParseException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredJavaObject; import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.io.IntWritable; import org.junit.Assert; import org.junit.Test; import smile.data.AttributeDataset; import smile.data.parser.ArffParser; import smile.math.Math; import smile.validation.LOOCV; public class DecisionTreeTest { private static final boolean DEBUG = false; /** * Test of learn method, of class DecisionTree. * * @throws ParseException * @throws IOException */ @Test public void testWeather() throws IOException, ParseException { URL url = new URL( "https://gist.githubusercontent.com/myui/2c9df50db3de93a71b92/raw/3f6b4ecfd4045008059e1a2d1c4064fb8a3d372a/weather.nominal.arff"); InputStream is = new BufferedInputStream(url.openStream()); ArffParser arffParser = new ArffParser(); arffParser.setResponseIndex(4); AttributeDataset weather = arffParser.parse(is); double[][] x = weather.toArray(new double[weather.size()][]); int[] y = weather.toArray(new int[weather.size()]); int n = x.length; LOOCV loocv = new LOOCV(n); int error = 0; for (int i = 0; i < n; i++) { double[][] trainx = Math.slice(x, loocv.train[i]); int[] trainy = Math.slice(y, loocv.train[i]); Attribute[] attrs = SmileExtUtils.convertAttributeTypes(weather.attributes()); DecisionTree tree = new DecisionTree(attrs, trainx, trainy, 3); if (y[loocv.test[i]] != tree.predict(x[loocv.test[i]])) error++; } debugPrint("Decision Tree error = " + error); assertEquals(5, error); } @Test public void testIris() throws IOException, ParseException { URL url = new URL( "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff"); InputStream is = new BufferedInputStream(url.openStream()); ArffParser arffParser = new ArffParser(); arffParser.setResponseIndex(4); AttributeDataset iris = arffParser.parse(is); double[][] x = iris.toArray(new double[iris.size()][]); int[] y = iris.toArray(new int[iris.size()]); int n = x.length; LOOCV loocv = new LOOCV(n); int error = 0; for (int i = 0; i < n; i++) { double[][] trainx = Math.slice(x, loocv.train[i]); int[] trainy = Math.slice(y, loocv.train[i]); Attribute[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes()); smile.math.Random rand = new smile.math.Random(i); DecisionTree tree = new DecisionTree(attrs, trainx, trainy, Integer.MAX_VALUE, rand); if (y[loocv.test[i]] != tree.predict(x[loocv.test[i]])) error++; } debugPrint("Decision Tree error = " + error); assertEquals(8, error); } @Test public void testIrisDepth4() throws IOException, ParseException { URL url = new URL( "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff"); InputStream is = new BufferedInputStream(url.openStream()); ArffParser arffParser = new ArffParser(); arffParser.setResponseIndex(4); AttributeDataset iris = arffParser.parse(is); double[][] x = iris.toArray(new double[iris.size()][]); int[] y = iris.toArray(new int[iris.size()]); int n = x.length; LOOCV loocv = new LOOCV(n); int error = 0; for (int i = 0; i < n; i++) { double[][] trainx = Math.slice(x, loocv.train[i]); int[] trainy = Math.slice(y, loocv.train[i]); Attribute[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes()); DecisionTree tree = new DecisionTree(attrs, trainx, trainy, 4); if (y[loocv.test[i]] != tree.predict(x[loocv.test[i]])) error++; } debugPrint("Decision Tree error = " + error); assertEquals(7, error); } @Test public void testIrisStackmachine() throws IOException, ParseException, HiveException { URL url = new URL( "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff"); InputStream is = new BufferedInputStream(url.openStream()); ArffParser arffParser = new ArffParser(); arffParser.setResponseIndex(4); AttributeDataset iris = arffParser.parse(is); double[][] x = iris.toArray(new double[iris.size()][]); int[] y = iris.toArray(new int[iris.size()]); int n = x.length; LOOCV loocv = new LOOCV(n); for (int i = 0; i < n; i++) { double[][] trainx = Math.slice(x, loocv.train[i]); int[] trainy = Math.slice(y, loocv.train[i]); Attribute[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes()); DecisionTree tree = new DecisionTree(attrs, trainx, trainy, 4); assertEquals(tree.predict(x[loocv.test[i]]), predictByStackMachine(tree, x[loocv.test[i]])); } } @Test public void testIrisJavascript() throws IOException, ParseException, HiveException { URL url = new URL( "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff"); InputStream is = new BufferedInputStream(url.openStream()); ArffParser arffParser = new ArffParser(); arffParser.setResponseIndex(4); AttributeDataset iris = arffParser.parse(is); double[][] x = iris.toArray(new double[iris.size()][]); int[] y = iris.toArray(new int[iris.size()]); int n = x.length; LOOCV loocv = new LOOCV(n); for (int i = 0; i < n; i++) { double[][] trainx = Math.slice(x, loocv.train[i]); int[] trainy = Math.slice(y, loocv.train[i]); Attribute[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes()); DecisionTree tree = new DecisionTree(attrs, trainx, trainy, 4); assertEquals(tree.predict(x[loocv.test[i]]), predictByJavascript(tree, x[loocv.test[i]])); } } @Test public void testIrisSerializedObj() throws IOException, ParseException, HiveException { URL url = new URL( "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff"); InputStream is = new BufferedInputStream(url.openStream()); ArffParser arffParser = new ArffParser(); arffParser.setResponseIndex(4); AttributeDataset iris = arffParser.parse(is); double[][] x = iris.toArray(new double[iris.size()][]); int[] y = iris.toArray(new int[iris.size()]); int n = x.length; LOOCV loocv = new LOOCV(n); for (int i = 0; i < n; i++) { double[][] trainx = Math.slice(x, loocv.train[i]); int[] trainy = Math.slice(y, loocv.train[i]); Attribute[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes()); DecisionTree tree = new DecisionTree(attrs, trainx, trainy, 4); byte[] b = tree.predictSerCodegen(false); Node node = DecisionTree.deserializeNode(b, b.length, false); assertEquals(tree.predict(x[loocv.test[i]]), node.predict(x[loocv.test[i]])); } } @Test public void testIrisSerializeObjCompressed() throws IOException, ParseException, HiveException { URL url = new URL( "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff"); InputStream is = new BufferedInputStream(url.openStream()); ArffParser arffParser = new ArffParser(); arffParser.setResponseIndex(4); AttributeDataset iris = arffParser.parse(is); double[][] x = iris.toArray(new double[iris.size()][]); int[] y = iris.toArray(new int[iris.size()]); int n = x.length; LOOCV loocv = new LOOCV(n); for (int i = 0; i < n; i++) { double[][] trainx = Math.slice(x, loocv.train[i]); int[] trainy = Math.slice(y, loocv.train[i]); Attribute[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes()); DecisionTree tree = new DecisionTree(attrs, trainx, trainy, 4); byte[] b1 = tree.predictSerCodegen(true); byte[] b2 = tree.predictSerCodegen(false); Assert.assertTrue("b1.length = " + b1.length + ", b2.length = " + b2.length, b1.length < b2.length); Node node = DecisionTree.deserializeNode(b1, b1.length, true); assertEquals(tree.predict(x[loocv.test[i]]), node.predict(x[loocv.test[i]])); } } private static int predictByStackMachine(DecisionTree tree, double[] x) throws HiveException, IOException { String script = tree.predictOpCodegen(StackMachine.SEP); debugPrint(script); TreePredictUDF udf = new TreePredictUDF(); udf.initialize(new ObjectInspector[] { PrimitiveObjectInspectorFactory.javaStringObjectInspector, PrimitiveObjectInspectorFactory.javaIntObjectInspector, PrimitiveObjectInspectorFactory.javaStringObjectInspector, ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaDoubleObjectInspector), ObjectInspectorUtils.getConstantObjectInspector( PrimitiveObjectInspectorFactory.javaBooleanObjectInspector, true)}); DeferredObject[] arguments = new DeferredObject[] {new DeferredJavaObject("model_id#1"), new DeferredJavaObject(ModelType.opscode.getId()), new DeferredJavaObject(script), new DeferredJavaObject(ArrayUtils.toList(x)), new DeferredJavaObject(true)}; IntWritable result = (IntWritable) udf.evaluate(arguments); result = (IntWritable) udf.evaluate(arguments); udf.close(); return result.get(); } private static int predictByJavascript(DecisionTree tree, double[] x) throws HiveException, IOException { String script = tree.predictJsCodegen(); debugPrint(script); TreePredictUDF udf = new TreePredictUDF(); udf.initialize(new ObjectInspector[] { PrimitiveObjectInspectorFactory.javaStringObjectInspector, PrimitiveObjectInspectorFactory.javaIntObjectInspector, PrimitiveObjectInspectorFactory.javaStringObjectInspector, ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaDoubleObjectInspector), ObjectInspectorUtils.getConstantObjectInspector( PrimitiveObjectInspectorFactory.javaBooleanObjectInspector, true)}); DeferredObject[] arguments = new DeferredObject[] {new DeferredJavaObject("model_id#1"), new DeferredJavaObject(ModelType.javascript.getId()), new DeferredJavaObject(script), new DeferredJavaObject(ArrayUtils.toList(x)), new DeferredJavaObject(true)}; IntWritable result = (IntWritable) udf.evaluate(arguments); result = (IntWritable) udf.evaluate(arguments); udf.close(); return result.get(); } private static void debugPrint(String msg) { if (DEBUG) { System.out.println(msg); } } }