/* * Apache License * Version 2.0, January 2004 * http://www.apache.org/licenses/ * * Copyright 2013 Aurelian Tutuianu * Copyright 2014 Aurelian Tutuianu * Copyright 2015 Aurelian Tutuianu * Copyright 2016 Aurelian Tutuianu * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package rapaio.ml.classifier.rule; import rapaio.core.tools.DVector; import rapaio.data.*; import rapaio.data.filter.var.VFRefSort; import rapaio.ml.classifier.AbstractClassifier; import rapaio.ml.classifier.CFit; import rapaio.ml.classifier.rule.onerule.NominalRule; import rapaio.ml.classifier.rule.onerule.NumericRule; import rapaio.ml.classifier.rule.onerule.Rule; import rapaio.ml.classifier.rule.onerule.RuleSet; import rapaio.ml.common.Capabilities; import rapaio.sys.WS; import rapaio.util.Pair; import java.util.ArrayList; import java.util.List; import java.util.logging.Logger; import java.util.stream.IntStream; /** * @author <a href="mailto:padreati@yahoo.com">Aurelian Tutuianu</a> */ public class OneRule extends AbstractClassifier { private static final long serialVersionUID = 6220103690711818091L; private static final Logger log = Logger.getLogger(OneRule.class.getName()); private double minCount = 6; private RuleSet bestRuleSet; @Override public String name() { return "OneRule"; } @Override public String fullName() { return String.format("OneRule (minCount=%s)", WS.formatFlex(minCount)); } @Override public OneRule newInstance() { return new OneRule().withMinCount(minCount); } public OneRule withMinCount(double minCount) { this.minCount = minCount; return this; } @Override public Capabilities capabilities() { return new Capabilities() .withInputCount(1, 1_000_000) .withTargetCount(1, 1) .withInputTypes(VarType.BINARY, VarType.INDEX, VarType.NOMINAL, VarType.NUMERIC, VarType.ORDINAL, VarType.STAMP) .withTargetTypes(VarType.NOMINAL) .withAllowMissingInputValues(true) .withAllowMissingTargetValues(false); } @Override protected boolean coreTrain(Frame df, Var weights) { bestRuleSet = null; for (String testCol : inputNames()) { RuleSet ruleSet; switch (df.var(testCol).type()) { case BINARY: case INDEX: case NUMERIC: case ORDINAL: case STAMP: ruleSet = buildNumeric(testCol, df, weights); break; default: ruleSet = buildNominal(testCol, df, weights); } if (bestRuleSet == null || ruleSet.getAccuracy() > bestRuleSet.getAccuracy()) { bestRuleSet = ruleSet; } } return true; } @Override protected CFit coreFit(final Frame test, final boolean withClasses, final boolean withDensities) { CFit pred = CFit.build(this, test, withClasses, withDensities); for (int i = 0; i < test.rowCount(); i++) { Pair<String, DVector> p = predict(test, i); if (withClasses) { pred.firstClasses().setLabel(i, p._1); } if (withDensities) { String[] dict = firstTargetLevels(); DVector dv = p._2.solidCopy(); dv.normalize(); for (int j = 0; j < dict.length; j++) { pred.firstDensity().setValue(i, j, dv.get(j)); } } } return pred; } private Pair<String, DVector> predict(Frame df, int row) { if (bestRuleSet == null) { log.severe("Best rule not found. Either the classifier was not trained, either something went wrong."); return Pair.from("?", DVector.empty(true, firstTargetLevels().length)); } String testVar = bestRuleSet.getVarName(); switch (df.var(testVar).type()) { case BINARY: case INDEX: case NUMERIC: case ORDINAL: case STAMP: boolean missing = df.var(testVar).missing(row); double value = df.value(row, testVar); for (Rule oneRule : bestRuleSet.getRules()) { NumericRule numRule = (NumericRule) oneRule; if (missing && numRule.isMissingValue()) { return Pair.from(numRule.getTargetClass(), numRule.getDV()); } if (!missing && !numRule.isMissingValue() && value >= numRule.getMinValue() && value <= numRule.getMaxValue()) { return Pair.from(numRule.getTargetClass(), numRule.getDV()); } } break; default: String label = df.label(row, testVar); for (Rule oneRule : bestRuleSet.getRules()) { NominalRule nomRule = (NominalRule) oneRule; if (nomRule.getTestLabel().equals(label)) { return Pair.from(nomRule.getTargetClass(), nomRule.getDV()); } } } return Pair.from("?", DVector.empty(true, firstTargetLevels().length)); } private RuleSet buildNominal(String testVar, Frame df, Var weights) { RuleSet set = new RuleSet(testVar); String[] testDict = df.var(testVar).levels(); String[] targetDict = firstTargetLevels(); DVector[] dvs = IntStream.range(0, testDict.length).boxed().map(i -> DVector.empty(false, targetDict)).toArray(DVector[]::new); df.stream().forEach(s -> dvs[df.index(s.row(), testVar)].increment(df.index(s.row(), firstTargetName()), weights.value(s.row()))); for (int i = 0; i < testDict.length; i++) { DVector dv = dvs[i]; int bestIndex = dv.findBestIndex(); set.getRules().add(new NominalRule(testDict[i], bestIndex, dv)); } return set; } private RuleSet buildNumeric(String testCol, Frame df, Var weights) { RuleSet set = new RuleSet(testCol); Var sort = new VFRefSort(RowComparators.numeric(df.var(testCol), true), RowComparators.nominal(df.var(firstTargetName()), true)).fitApply(Index.seq(weights.rowCount())); int pos = 0; while (pos < sort.rowCount()) { if (df.missing(sort.index(pos), testCol)) { pos++; continue; } break; } // first process missing values if (pos > 0) { DVector hist = DVector.empty(true, firstTargetLevels()); for (int i = 0; i < pos; i++) { hist.increment(df.index(sort.index(i), firstTargetName()), weights.value(sort.index(i))); } List<Integer> best = new ArrayList<>(); double max = Double.MIN_VALUE; int next = hist.findBestIndex(); set.getRules().add(new NumericRule(Double.NaN, Double.NaN, true, next, hist)); } // now learn numeric intervals List<NumericRule> candidates = new ArrayList<>(); //splits from same value int i = pos; int index; while (i < sort.rowCount()) { // start a new bucket int startIndex = i; DVector hist = DVector.empty(true, firstTargetLevels()); do { // fill it until it has enough of the majority class index = df.index(sort.index(i), firstTargetName()); hist.increment(index, weights.value(sort.index(i))); i++; } while (hist.get(index) < minCount && i < sort.rowCount()); // while class remains the same, keep on filling while (i < sort.rowCount()) { index = sort.index(i); if (df.index(sort.index(i), firstTargetName()) == index) { hist.increment(index, weights.value(sort.index(i))); i++; continue; } break; } // keep on while attr value is the same while (i < sort.rowCount() && df.value(sort.index(i - 1), testCol) == df.value(sort.index(i), testCol)) { index = df.index(sort.index(i), firstTargetName()); hist.increment(index, weights.value(sort.index(i))); i++; } int next = hist.findBestIndex(); double minValue = Double.NEGATIVE_INFINITY; if (startIndex != pos) { minValue = (df.value(sort.index(startIndex), testCol) + df.value(sort.index(startIndex - 1), testCol)) / 2.; } double maxValue = Double.POSITIVE_INFINITY; if (i != sort.rowCount()) { maxValue = (df.value(sort.index(i - 1), testCol) + df.value(sort.index(i), testCol)) / 2; } candidates.add(new NumericRule(minValue, maxValue, false, next, hist)); } NumericRule last = null; for (NumericRule rule : candidates) { if (last == null) { last = rule; continue; } if (last.getTargetClass().equals(rule.getTargetClass())) { DVector dv = last.getDV().solidCopy(); dv.increment(rule.getDV()); last = new NumericRule(last.getMinValue(), rule.getMaxValue(), false, last.getTargetIndex(), dv); } else { set.getRules().add(last); last = rule; } } set.getRules().add(last); return set; } @Override public String summary() { StringBuilder sb = new StringBuilder(); sb.append("OneRule model\n"); sb.append("================\n\n"); sb.append("Description:\n"); sb.append(fullName()).append("\n\n"); sb.append("Capabilities:\n"); sb.append(capabilities().summary()).append("\n"); sb.append("Learned model:\n"); if (!hasLearned()) { sb.append("Learning phase not called\n\n"); return sb.toString(); } sb.append(baseSummary()); sb.append("Best").append(bestRuleSet.toString()).append("\n"); for (Rule rule : bestRuleSet.getRules()) { sb.append("> ").append(rule.toString()).append("\n"); } sb.append("\n"); return sb.toString(); } }