/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package hivemall.smile.regression;
import hivemall.smile.data.Attribute;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.junit.Assert;
import org.junit.Test;
import smile.math.Math;
import smile.validation.LOOCV;
public class RegressionTreeTest {
@Test
public void testPredict() {
double[][] longley = { {234.289, 235.6, 159.0, 107.608, 1947, 60.323},
{259.426, 232.5, 145.6, 108.632, 1948, 61.122},
{258.054, 368.2, 161.6, 109.773, 1949, 60.171},
{284.599, 335.1, 165.0, 110.929, 1950, 61.187},
{328.975, 209.9, 309.9, 112.075, 1951, 63.221},
{346.999, 193.2, 359.4, 113.270, 1952, 63.639},
{365.385, 187.0, 354.7, 115.094, 1953, 64.989},
{363.112, 357.8, 335.0, 116.219, 1954, 63.761},
{397.469, 290.4, 304.8, 117.388, 1955, 66.019},
{419.180, 282.2, 285.7, 118.734, 1956, 67.857},
{442.769, 293.6, 279.8, 120.445, 1957, 68.169},
{444.546, 468.1, 263.7, 121.950, 1958, 66.513},
{482.704, 381.3, 255.2, 123.366, 1959, 68.655},
{502.601, 393.1, 251.4, 125.368, 1960, 69.564},
{518.173, 480.6, 257.2, 127.852, 1961, 69.331},
{554.894, 400.7, 282.7, 130.081, 1962, 70.551}};
double[] y = {83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8,
112.6, 114.2, 115.7, 116.9};
Attribute[] attrs = new Attribute[longley[0].length];
for (int i = 0; i < attrs.length; i++) {
attrs[i] = new Attribute.NumericAttribute(i);
}
int n = longley.length;
LOOCV loocv = new LOOCV(n);
double rss = 0.0;
for (int i = 0; i < n; i++) {
double[][] trainx = Math.slice(longley, loocv.train[i]);
double[] trainy = Math.slice(y, loocv.train[i]);
int maxLeafs = 10;
smile.math.Random rand = new smile.math.Random(i);
RegressionTree tree = new RegressionTree(attrs, trainx, trainy, maxLeafs, rand);
double r = y[loocv.test[i]] - tree.predict(longley[loocv.test[i]]);
rss += r * r;
}
Assert.assertTrue("MSE = " + (rss / n), (rss / n) < 42);
}
@Test
public void testSerPredict() throws HiveException {
double[][] longley = { {234.289, 235.6, 159.0, 107.608, 1947, 60.323},
{259.426, 232.5, 145.6, 108.632, 1948, 61.122},
{258.054, 368.2, 161.6, 109.773, 1949, 60.171},
{284.599, 335.1, 165.0, 110.929, 1950, 61.187},
{328.975, 209.9, 309.9, 112.075, 1951, 63.221},
{346.999, 193.2, 359.4, 113.270, 1952, 63.639},
{365.385, 187.0, 354.7, 115.094, 1953, 64.989},
{363.112, 357.8, 335.0, 116.219, 1954, 63.761},
{397.469, 290.4, 304.8, 117.388, 1955, 66.019},
{419.180, 282.2, 285.7, 118.734, 1956, 67.857},
{442.769, 293.6, 279.8, 120.445, 1957, 68.169},
{444.546, 468.1, 263.7, 121.950, 1958, 66.513},
{482.704, 381.3, 255.2, 123.366, 1959, 68.655},
{502.601, 393.1, 251.4, 125.368, 1960, 69.564},
{518.173, 480.6, 257.2, 127.852, 1961, 69.331},
{554.894, 400.7, 282.7, 130.081, 1962, 70.551}};
double[] y = {83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8,
112.6, 114.2, 115.7, 116.9};
Attribute[] attrs = new Attribute[longley[0].length];
for (int i = 0; i < attrs.length; i++) {
attrs[i] = new Attribute.NumericAttribute(i);
}
int n = longley.length;
LOOCV loocv = new LOOCV(n);
for (int i = 0; i < n; i++) {
double[][] trainx = Math.slice(longley, loocv.train[i]);
double[] trainy = Math.slice(y, loocv.train[i]);
int maxLeafs = Integer.MAX_VALUE;
RegressionTree tree = new RegressionTree(attrs, trainx, trainy, maxLeafs);
byte[] b = tree.predictSerCodegen(true);
RegressionTree.Node node = RegressionTree.deserializeNode(b, b.length, true);
double expected = tree.predict(longley[loocv.test[i]]);
double actual = node.predict(longley[loocv.test[i]]);
Assert.assertEquals(expected, actual, 0.d);
}
}
}