/** * (C) Copyright IBM Corp. 2010, 2015 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. *  */ package com.ibm.bi.dml.test.integration.applications; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.HashMap; import java.util.List; import org.junit.Assert; import org.junit.runners.Parameterized.Parameters; import com.ibm.bi.dml.api.DMLScript.RUNTIME_PLATFORM; import com.ibm.bi.dml.runtime.matrix.data.MatrixValue.CellIndex; import com.ibm.bi.dml.test.integration.AutomatedTestBase; import com.ibm.bi.dml.test.utils.TestUtils; import com.ibm.bi.dml.utils.Statistics; public abstract class ID3Test extends AutomatedTestBase { protected final static String TEST_DIR = "applications/id3/"; protected final static String TEST_NAME = "id3"; protected int numRecords, numFeatures; public ID3Test(int numRecords, int numFeatures) { this.numRecords = numRecords; this.numFeatures = numFeatures; } @Parameters public static Collection<Object[]> data() { //TODO fix R script (values in 'nodes' for different settings incorrect, e.g., with minSplit=10 instead of 2) Object[][] data = new Object[][] { {100, 50}, {1000, 50} }; return Arrays.asList(data); } @Override public void setUp() { addTestConfiguration(TEST_DIR, TEST_NAME); } protected void testID3(ScriptType scriptType) { System.out.println("------------ BEGIN " + TEST_NAME + " " + scriptType + " TEST {" + numRecords + ", " + numFeatures + "} ------------"); this.scriptType = scriptType; int rows = numRecords; // # of rows in the training data int cols = numFeatures; getAndLoadTestConfiguration(TEST_NAME); List<String> proArgs = new ArrayList<String>(); if (scriptType == ScriptType.PYDML) { proArgs.add("-python"); } proArgs.add("-explain"); proArgs.add("-args"); proArgs.add(input("X")); proArgs.add(input("y")); proArgs.add(output("nodes")); proArgs.add(output("edges")); programArgs = proArgs.toArray(new String[proArgs.size()]); fullDMLScriptName = getScript(); rCmd = getRCmd(inputDir(), expectedDir()); // prepare training data set double[][] X = round(getRandomMatrix(rows, cols, 1, 10, 1.0, 3)); double[][] y = round(getRandomMatrix(rows, 1, 1, 10, 1.0, 7)); writeInputMatrixWithMTD("X", X, true); writeInputMatrixWithMTD("y", y, true); //run tests //(changed expected MR from 62 to 66 because we now also count MR jobs in predicates) //(changed expected MR from 66 to 68 because we now rewrite sum(v1*v2) to t(v1)%*%v2 which rarely creates more jobs due to MMCJ incompatibility of other operations) runTest(true, EXCEPTION_NOT_EXPECTED, null, 68); //max 68 compiled jobs runRScript(true); //check also num actually executed jobs if(AutomatedTestBase.rtplatform != RUNTIME_PLATFORM.SPARK) { long actualMR = Statistics.getNoOfExecutedMRJobs(); Assert.assertEquals("Wrong number of executed jobs: expected 0 but executed "+actualMR+".", 0, actualMR); } //compare results HashMap<CellIndex, Double> nR = readRMatrixFromFS("nodes"); HashMap<CellIndex, Double> nSYSTEMML= readDMLMatrixFromHDFS("nodes"); HashMap<CellIndex, Double> eR = readRMatrixFromFS("edges"); HashMap<CellIndex, Double> eSYSTEMML= readDMLMatrixFromHDFS("edges"); TestUtils.compareMatrices(nR, nSYSTEMML, Math.pow(10, -14), "nR", "nSYSTEMML"); TestUtils.compareMatrices(eR, eSYSTEMML, Math.pow(10, -14), "eR", "eSYSTEMML"); } private double[][] round( double[][] data ) { for( int i=0; i<data.length; i++ ) for( int j=0; j<data[i].length; j++ ) data[i][j] = Math.round(data[i][j]); return data; } }