/*
* Apache License
* Version 2.0, January 2004
* http://www.apache.org/licenses/
*
* Copyright 2013 Aurelian Tutuianu
* Copyright 2014 Aurelian Tutuianu
* Copyright 2015 Aurelian Tutuianu
* Copyright 2016 Aurelian Tutuianu
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package rapaio.ml.classifier.tree.ctree;
import org.junit.Before;
import org.junit.Test;
import rapaio.data.Frame;
import rapaio.data.Numeric;
import rapaio.data.SolidFrame;
import rapaio.data.Var;
import rapaio.ml.classifier.tree.CTreeCandidate;
import rapaio.ml.classifier.tree.CTreeMissingHandler;
import rapaio.util.Pair;
import java.util.List;
import static org.junit.Assert.*;
/**
* Tests the {@link CTreeMissingHandler} implementations (missing value handling on split) for CTree
* <p>
* Created by <a href="mailto:padreati@yahoo.com">Aurelian Tutuianu</a> on 9/29/15.
*/
public class CTreeMissingHandlerTest {

    /**
     * Tolerance used when comparing summed weights. It is far tighter than any
     * meaningful error, but wide enough to absorb floating point rounding
     * (a delta below one ulp of the expected value would demand bit-exact sums).
     */
    private static final double TOL = 1e-12;

    /** Frame with a single variable "x": 4 positive values, 3 missing values, 3 negative values. */
    private Frame df;
    /** Weights named "w": the absolute value of "x", with missing values left as missing (NaN). */
    private Var w;
    /** Candidate with two groups: rows with x &gt; 0 and rows with x &lt; 0. */
    private CTreeCandidate c;

    /**
     * Builds the shared fixture: the data frame, the weights and the split candidate.
     */
    @Before
    public void setUp() throws Exception {
        Numeric values = Numeric.wrap(1, 2, 3, 4, Double.NaN, Double.NaN, Double.NaN, -3, -2, -1);
        df = SolidFrame.byVars(values.solidCopy().withName("x"));
        // keep NaN untouched so missing rows stay missing in the weights as well
        w = values.solidCopy().stream().transValue(x -> Double.isNaN(x) ? x : Math.abs(x)).toMappedVar().withName("w");

        c = new CTreeCandidate(1, "test");
        c.addGroup("> 0", s -> s.value("x") > 0);
        c.addGroup("< 0", s -> s.value("x") < 0);
    }

    /**
     * With {@code Ignored}, the first group is expected to hold the 4 positive rows
     * and the second group the 3 negative rows, for both frames and weights.
     */
    @Test
    public void testIgnored() {
        Pair<List<Frame>, List<Var>> pairs = CTreeMissingHandler.Ignored.performSplit(df, w, c);

        assertEquals(2, pairs._1.size());
        assertEquals(2, pairs._2.size());

        assertEquals(4, pairs._1.get(0).stream().filter(s -> s.value("x") > 0).count());
        assertEquals(4, pairs._2.get(0).stream().filter(s -> s.value() > 0).count());

        assertEquals(3, pairs._1.get(1).stream().filter(s -> s.value("x") < 0).count());
        assertEquals(3, pairs._2.get(1).stream().filter(s -> s.value() > 0).count());
    }

    /**
     * With {@code ToMajority}, the first (larger) group is expected to hold 7 rows
     * which are either missing or positive, while the second group keeps its 3 negative rows.
     */
    @Test
    public void testMajority() {
        Pair<List<Frame>, List<Var>> pairs = CTreeMissingHandler.ToMajority.performSplit(df, w, c);

        assertEquals(2, pairs._1.size());
        assertEquals(2, pairs._2.size());

        assertEquals(7, pairs._1.get(0).stream().filter(s -> s.missing() || s.value("x") > 0).count());
        assertEquals(7, pairs._2.get(0).stream().filter(s -> s.missing() || s.value() > 0).count());

        assertEquals(3, pairs._1.get(1).stream().filter(s -> s.value("x") < 0).count());
        assertEquals(3, pairs._2.get(1).stream().filter(s -> s.value() > 0).count());
    }

    /**
     * With {@code ToAllWeighted}, both groups are expected to also receive the 3 missing rows
     * (7 and 6 rows respectively). The expected weight totals are the weights of the
     * non-missing rows plus {@code 3 * 4 / 7} for the first group and {@code 3 * 3 / 7}
     * for the second one.
     */
    @Test
    public void testToAllWeighted() {
        Pair<List<Frame>, List<Var>> pairs = CTreeMissingHandler.ToAllWeighted.performSplit(df, w, c);

        assertEquals(2, pairs._1.size());
        assertEquals(2, pairs._2.size());

        assertEquals(7, pairs._1.get(0).stream().filter(s -> s.missing() || s.value("x") > 0).count());
        assertEquals(7, pairs._2.get(0).stream().filter(s -> s.missing() || s.value() > 0).count());

        assertEquals(6, pairs._1.get(1).stream().filter(s -> s.missing() || s.value("x") < 0).count());
        assertEquals(6, pairs._2.get(1).stream().filter(s -> s.missing() || s.value() > 0).count());

        assertEquals(1 + 2 + 3 + 4 + 3 * 4 / 7.0, pairs._2.get(0).stream().mapToDouble().sum(), TOL);
        assertEquals(1 + 2 + 3 + 3 * 3 / 7.0, pairs._2.get(1).stream().mapToDouble().sum(), TOL);
    }

    /**
     * With {@code ToRandom}, the exact group sizes are not asserted; only bounds are checked:
     * the first group must hold between 4 and 7 rows and the second between 3 and 6 rows,
     * for both frames and weights.
     */
    @Test
    public void testToRandom() {
        Pair<List<Frame>, List<Var>> pairs = CTreeMissingHandler.ToRandom.performSplit(df, w, c);

        assertEquals(2, pairs._1.size());
        assertEquals(2, pairs._2.size());

        long firstFrameCount = pairs._1.get(0).stream().filter(s -> s.missing() || s.value("x") > 0).count();
        assertBetween("first frame row count", 4, 7, firstFrameCount);

        long secondFrameCount = pairs._1.get(1).stream().filter(s -> s.missing() || s.value("x") < 0).count();
        assertBetween("second frame row count", 3, 6, secondFrameCount);

        long firstWeightCount = pairs._2.get(0).stream().count();
        assertBetween("first weights row count", 4, 7, firstWeightCount);

        long secondWeightCount = pairs._2.get(1).stream().count();
        assertBetween("second weights row count", 3, 6, secondWeightCount);
    }

    /**
     * Asserts that {@code actual} lies in the closed interval {@code [min, max]},
     * reporting the label, the bounds and the actual value on failure.
     */
    private static void assertBetween(String label, long min, long max, long actual) {
        assertTrue(label + " expected in [" + min + ", " + max + "] but was " + actual,
                min <= actual && actual <= max);
    }
}