/*
 * Apache License
 * Version 2.0, January 2004
 * http://www.apache.org/licenses/
 *
 * Copyright 2013 Aurelian Tutuianu
 * Copyright 2014 Aurelian Tutuianu
 * Copyright 2015 Aurelian Tutuianu
 * Copyright 2016 Aurelian Tutuianu
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package rapaio.ml.classifier.rule;

import org.junit.Test;
import rapaio.core.RandomSource;
import rapaio.data.*;
import rapaio.datasets.Datasets;
import rapaio.ml.classifier.CFit;

import java.io.IOException;
import java.net.URISyntaxException;
import java.util.Arrays;

import static org.junit.Assert.*;

/**
 * Tests for {@link OneRule}.
 * <p>
 * Covers three things:
 * <ul>
 * <li>training and prediction on a small hand-built numeric data set, for several {@code minCount} values;</li>
 * <li>the exact {@code summary()} text produced on the iris and mushrooms data sets;</li>
 * <li>fitting with both boolean flags of {@code fit(Frame, boolean, boolean)} enabled.</li>
 * </ul>
 * <p>
 * User: Aurelian Tutuianu <paderati@yahoo.com>
 */
public class OneRuleTest {

    /** Number of rows in the hand-built fixture. */
    private static final int SIZE = 6;

    /**
     * Expected {@code summary()} of a default OneRule trained on iris with target "class".
     * Compared verbatim, so every character (including spacing) matters.
     */
    private static final String EXPECTED_IRIS_SUMMARY = "OneRule model\n" +
            "================\n" +
            "\n" +
            "Description:\n" +
            "OneRule (minCount=6)\n" +
            "\n" +
            "Capabilities:\n" +
            "types inputs/targets: BINARY,INDEX,NOMINAL,NUMERIC,ORDINAL,STAMP/NOMINAL\n" +
            "counts inputs/targets: [1,1000000] / [1,1]\n" +
            "missing inputs/targets: true/false\n" +
            "\n" +
            "Learned model:\n" +
            "input vars: \n" +
            "\n" +
            " 0. sepal-length : NUMERIC | 1. sepal-width : NUMERIC | 2. petal-length : NUMERIC | 3. petal-width : NUMERIC |\n" +
            "\n" +
            "target vars:\n" +
            "> class : NOMINAL [?,setosa,versicolor,virginica]\n" +
            "BestRuleSet {var=petal-length, acc=0.9533333}\n" +
            "> NumericRule {min=-Infinity, max=2.45, class=setosa, errors=0, total=50, acc=1 }\n" +
            "> NumericRule {min=2.45, max=4.75, class=versicolor, errors=1, total=45, acc=0.9777778 }\n" +
            "> NumericRule {min=4.75, max=Infinity, class=virginica, errors=6, total=55, acc=0.8909091 }\n" +
            "\n";

    /**
     * Expected {@code summary()} of a default OneRule trained on mushrooms with target "classes"
     * (random seed 1). Compared verbatim, so every character (including spacing) matters.
     */
    private static final String EXPECTED_MUSHROOMS_SUMMARY = "OneRule model\n" +
            "================\n" +
            "\n" +
            "Description:\n" +
            "OneRule (minCount=6)\n" +
            "\n" +
            "Capabilities:\n" +
            "types inputs/targets: BINARY,INDEX,NOMINAL,NUMERIC,ORDINAL,STAMP/NOMINAL\n" +
            "counts inputs/targets: [1,1000000] / [1,1]\n" +
            "missing inputs/targets: true/false\n" +
            "\n" +
            "Learned model:\n" +
            "input vars: \n" +
            "\n" +
            " 0. cap-shape : NOMINAL | 6. gill-spacing : NOMINAL | 12. stalk-surface-below-ring : NOMINAL | 18. ring-type : NOMINAL |\n" +
            " 1. cap-surface : NOMINAL | 7. gill-size : NOMINAL | 13. stalk-color-above-ring : NOMINAL | 19. spore-print-color : NOMINAL |\n" +
            " 2. cap-color : NOMINAL | 8. gill-color : NOMINAL | 14. stalk-color-below-ring : NOMINAL | 20. population : NOMINAL |\n" +
            " 3. bruises : NOMINAL | 9. stalk-shape : NOMINAL | 15. veil-type : NOMINAL | 21. habitat : NOMINAL |\n" +
            " 4. odor : NOMINAL | 10. stalk-root : NOMINAL | 16. veil-color : NOMINAL | \n" +
            " 5. gill-attachment : NOMINAL | 11. stalk-surface-above-ring : NOMINAL | 17. ring-number : NOMINAL | \n" +
            "\n" +
            "target vars:\n" +
            "> classes : NOMINAL [?,p,e]\n" +
            "BestRuleSet {var=odor, acc=0.985229}\n" +
            "> NominalRule {value=?, class=e, errors=0, total=0, acc=0}\n" +
            "> NominalRule {value=p, class=p, errors=0, total=256, acc=1}\n" +
            "> NominalRule {value=a, class=e, errors=0, total=400, acc=1}\n" +
            "> NominalRule {value=l, class=e, errors=0, total=400, acc=1}\n" +
            "> NominalRule {value=n, class=e, errors=120, total=3,528, acc=0.9659864}\n" +
            "> NominalRule {value=f, class=p, errors=0, total=2,160, acc=1}\n" +
            "> NominalRule {value=c, class=p, errors=0, total=192, acc=1}\n" +
            "> NominalRule {value=y, class=p, errors=0, total=576, acc=1}\n" +
            "> NominalRule {value=s, class=p, errors=0, total=576, acc=1}\n" +
            "> NominalRule {value=m, class=p, errors=0, total=36, acc=1}\n" +
            "\n";

    private final Var classVar;
    private final Var heightVar;

    /**
     * Builds the fixture used by {@link #testSimpleNumeric()}.
     * <p>
     * The three smallest heights are labelled "True" and the three largest "False",
     * so the rows are in class order along the height axis.
     */
    public OneRuleTest() {
        classVar = Nominal.empty(SIZE, "False", "True").withName("class");
        classVar.setLabel(0, "True");
        classVar.setLabel(1, "True");
        classVar.setLabel(2, "True");
        classVar.setLabel(3, "False");
        classVar.setLabel(4, "False");
        classVar.setLabel(5, "False");

        heightVar = Numeric.copy(0.1, 0.3, 0.5, 10, 10.3, 10.5).withName("height");
    }

    /**
     * Trains on the fixture with increasing {@code minCount} values and checks the
     * predictions made on the training data itself.
     */
    @Test
    public void testSimpleNumeric() {
        Frame df = SolidFrame.byVars(SIZE, heightVar, classVar);

        // The result of withMinCount is always reassigned, so the test does not rely
        // on the builder mutating the instance in place.

        // minCount=1: predictions must reproduce the training labels exactly.
        OneRule oneRule = new OneRule().withMinCount(1);
        oneRule.train(df, "class");
        assertLabels(oneRule.fit(df), "True", "True", "True", "False", "False", "False");

        // minCount=2: rows 2 and 3 sit on either side of the class change,
        // so either label is accepted there.
        oneRule = oneRule.withMinCount(2);
        oneRule.train(df, "class");
        assertLabelsIn(oneRule.fit(df), new String[][]{
                {"True"}, {"True"}, {"True", "False"}, {"True", "False"}, {"False"}, {"False"}});

        // minCount=3: predictions must again reproduce the training labels exactly.
        oneRule = oneRule.withMinCount(3);
        oneRule.train(df, "class");
        assertLabels(oneRule.fit(df), "True", "True", "True", "False", "False", "False");

        // minCount=4: the same label must be predicted for every row.
        oneRule = oneRule.withMinCount(4);
        oneRule.train(df, "class");
        CFit fit = oneRule.fit(df);
        String firstLabel = fit.firstClasses().label(0);
        for (int row = 1; row < SIZE; row++) {
            assertEquals("row " + row, firstLabel, fit.firstClasses().label(row));
        }
    }

    /**
     * Checks the exact summary text of a default OneRule on iris and on mushrooms.
     *
     * @throws IOException        if a data set cannot be read
     * @throws URISyntaxException if a data set location is malformed
     */
    @Test
    public void testSummary() throws IOException, URISyntaxException {
        Frame df1 = Datasets.loadIrisDataset();
        OneRule oneRule1 = new OneRule();
        oneRule1.train(df1, "class");
        oneRule1.printSummary();
        assertEquals(EXPECTED_IRIS_SUMMARY, oneRule1.summary());

        Frame df2 = Datasets.loadMushrooms();
        // fixed seed so the mushrooms summary is reproducible
        RandomSource.setSeed(1);
        OneRule oneRule2 = new OneRule();
        oneRule2.train(df2, "classes");
        oneRule2.printSummary();
        assertEquals(EXPECTED_MUSHROOMS_SUMMARY, oneRule2.summary());
    }

    /**
     * Smoke test: train and fit with both {@code fit} flags enabled on mushrooms and on iris.
     *
     * @throws IOException        if a data set cannot be read
     * @throws URISyntaxException if a data set location is malformed
     */
    @Test
    public void testFit() throws IOException, URISyntaxException {
        checkFit(Datasets.loadMushrooms(), "classes");
        checkFit(Datasets.loadIrisDataset(), "class");
    }

    /**
     * Trains a default OneRule on {@code df} and fits it back with both boolean flags of
     * {@code fit(Frame, boolean, boolean)} set to {@code true}. Asserts that a fit with
     * predicted classes is produced, and prints both summaries.
     *
     * @param df     data set used for both training and fitting
     * @param target name of the target variable
     */
    private static void checkFit(Frame df, String target) {
        OneRule oneRule = new OneRule();
        oneRule.train(df, target);
        oneRule.printSummary();

        CFit fit = oneRule.fit(df, true, true);
        assertTrue("fit returned null for target " + target, fit != null);
        assertTrue("fit has no predicted classes for target " + target, fit.firstClasses() != null);
        fit.printSummary();
    }

    /**
     * Asserts that the predicted label of row {@code i} equals {@code expected[i]} for every
     * index of {@code expected}.
     *
     * @param fit      fit whose first-target predictions are checked
     * @param expected expected label per row, starting at row 0
     */
    private static void assertLabels(CFit fit, String... expected) {
        for (int row = 0; row < expected.length; row++) {
            assertEquals("row " + row, expected[row], fit.firstClasses().label(row));
        }
    }

    /**
     * Asserts that the predicted label of row {@code i} is one of {@code allowed[i]} for every
     * index of {@code allowed}.
     *
     * @param fit     fit whose first-target predictions are checked
     * @param allowed acceptable labels per row, starting at row 0
     */
    private static void assertLabelsIn(CFit fit, String[][] allowed) {
        for (int row = 0; row < allowed.length; row++) {
            String actual = fit.firstClasses().label(row);
            assertTrue("row " + row + ": label " + actual + " not in " + Arrays.toString(allowed[row]),
                    Arrays.asList(allowed[row]).contains(actual));
        }
    }
}