/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.ftvec.selection;

import hivemall.utils.hadoop.WritableUtils;

import java.util.ArrayList;
import java.util.List;

import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.junit.Assert;
import org.junit.Test;

/**
 * Unit test for {@link ChiSquareUDF}.
 */
public class ChiSquareUDFTest {

    /**
     * Feeds a 3x4 "observed" matrix and a 3x4 "expected" matrix to {@link ChiSquareUDF} and
     * checks the two returned lists against reference values taken from scikit-learn.
     *
     * <p>NOTE(review): judging by the test name and the reference values, the inputs are
     * presumably per-class feature sums of the Iris dataset, and the two outputs are presumably
     * the chi-square statistics and their p-values; confirm against {@link ChiSquareUDF}.
     *
     * @throws Exception if the UDF fails to initialize, evaluate or close
     */
    @Test
    public void testIris() throws Exception {
        final ChiSquareUDF chi2 = new ChiSquareUDF();

        // The deferred objects wrap these (initially empty) lists by reference; the lists are
        // populated below, before evaluate() is called.
        final List<List<DoubleWritable>> observed = new ArrayList<List<DoubleWritable>>();
        final List<List<DoubleWritable>> expected = new ArrayList<List<DoubleWritable>>();
        final GenericUDF.DeferredObject[] dObjs = new GenericUDF.DeferredObject[] {
                new GenericUDF.DeferredJavaObject(observed),
                new GenericUDF.DeferredJavaObject(expected)};

        final double[][] matrix0 = new double[][] {
                {250.29999999999998, 170.90000000000003, 73.2, 12.199999999999996},
                {296.8, 138.50000000000003, 212.99999999999997, 66.3},
                {329.3999999999999, 148.7, 277.59999999999997, 101.29999999999998}};
        final double[][] matrix1 = new double[][] {
                {292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589},
                {292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589},
                {292.1666753739119, 152.70000455081467, 187.93333893418327, 59.93333511948589}};

        for (double[] row : matrix0) {
            observed.add(WritableUtils.toWritableList(row));
        }
        for (double[] row : matrix1) {
            expected.add(WritableUtils.toWritableList(row));
        }

        // Both arguments are declared as list<list<double>> backed by DoubleWritable.
        final ObjectInspector matrixOI = ObjectInspectorFactory.getStandardListObjectInspector(
            ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));

        // Close the UDF even when initialization, evaluation or an assertion fails.
        try {
            chi2.initialize(new ObjectInspector[] {matrixOI, matrixOI});

            final List<DoubleWritable>[] result = chi2.evaluate(dObjs);

            // Unwrap one value per input column from each of the two returned lists.
            final double[] result0 = new double[matrix0[0].length];
            final double[] result1 = new double[matrix0[0].length];
            for (int i = 0; i < result0.length; i++) {
                result0[i] = result[0].get(i).get();
            }
            for (int i = 0; i < result1.length; i++) {
                result1[i] = result[1].get(i).get();
            }

            // Compare the results against those computed by scikit-learn.
            final double[] answer0 =
                    new double[] {10.81782088, 3.59449902, 116.16984746, 67.24482759};
            final double[] answer1 = new double[] {4.47651499e-03, 1.65754167e-01,
                    5.94344354e-26, 2.50017968e-15};

            Assert.assertArrayEquals(answer0, result0, 1e-5);
            // NOTE(review): the absolute tolerance of 1e-5 is far larger than the last two
            // expected values, so for those entries this only verifies "close to zero".
            Assert.assertArrayEquals(answer1, result1, 1e-5);
        } finally {
            chi2.close();
        }
    }

}