/*
* Apache License
* Version 2.0, January 2004
* http://www.apache.org/licenses/
*
* Copyright 2013 Aurelian Tutuianu
* Copyright 2014 Aurelian Tutuianu
* Copyright 2015 Aurelian Tutuianu
* Copyright 2016 Aurelian Tutuianu
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package rapaio.experiment.selection;
import rapaio.core.tools.DVector;
import rapaio.data.*;
import rapaio.printer.Printable;
import rapaio.sys.WS;
import rapaio.util.Pair;
import java.util.*;
import java.util.function.BiPredicate;
import java.util.stream.Collectors;
/**
* Created by <a href="mailto:padreati@yahoo.com">Aurelian Tutuianu</a> on 12/23/15.
*/
public class Apriori implements Printable {
private Var targetVar;
private Frame inputDf;
private BiPredicate<Integer, DVector> filter;
private List<List<Pair<AprioriRule, DVector>>> P;
private List<AprioriRule> rules;
private double coverage;
public String[] inputVarNames() {
return inputDf.varNames();
}
public void train(Frame df, String target, BiPredicate<Integer, DVector> filter) {
List<Var> inputVars = df.varStream()
.filter(var -> var.type().equals(VarType.NOMINAL))
.filter(var -> !var.name().equals(target))
.collect(Collectors.toList());
this.inputDf = SolidFrame.byVars(inputVars);
this.targetVar = df.var(target);
this.filter = filter;
List<AprioriRule> C = new ArrayList<>();
P = new ArrayList<>();
// build dictionary of rules $C0$
for (int i = 0; i < inputDf.varCount(); i++) {
Var input = inputDf.var(i);
for (String level : input.levels()) {
AprioriRuleClause clause = new AprioriRuleClause(input.name(), level);
AprioriRule rule = new AprioriRule();
rule.addClause(clause);
C.add(rule);
}
}
List<Pair<AprioriRule, DVector>> counts = C.stream().map(rule -> Pair.from(rule,
DVector.empty(false, targetVar.levels())))
.collect(Collectors.toList());
for (int i = 0; i < df.rowCount(); i++) {
for (Pair<AprioriRule, DVector> cnt : counts) {
if (cnt._1.matchRow(df, i)) {
cnt._2.increment(targetVar.index(i), 1);
}
}
}
List<Pair<AprioriRule, DVector>> list = counts.stream()
.filter(pair -> filter.test(df.rowCount(), pair._2))
.collect(Collectors.toList());
list.sort((o1, o2) -> -Double.compare(o1._2.sum(), o2._2.sum()));
P.add(list);
// do iterations
List<AprioriRule> base = P.get(0).stream().map(pair -> pair._1).collect(Collectors.toList());
while (true) {
int k = P.size();
Map<String, Pair<AprioriRule, DVector>> cnts = new HashMap<>();
// loop for all possibilities
for (int i = 0; i < df.rowCount(); i++) {
for (AprioriRule b : base) {
if (!b.matchRow(df, i))
continue;
for (Pair<AprioriRule, DVector> tPrev : P.get(k - 1)) {
if (!tPrev._1.isExtention(b))
continue;
if (!tPrev._1.matchRow(df, i))
continue;
AprioriRule next = tPrev._1.extend(b);
if (!cnts.containsKey(next.toString())) {
cnts.put(next.toString(), Pair.from(next, DVector.empty(false, targetVar.levels())));
}
cnts.get(next.toString())._2.increment(targetVar.index(i), 1);
}
}
}
// keep only survivors
List<Pair<AprioriRule, DVector>> top = cnts.values().stream()
.filter(pair -> filter.test(df.rowCount(), pair._2))
.collect(Collectors.toList());
if (top.isEmpty()) {
break;
}
top.sort((o1, o2) -> -Double.compare(o1._2.sum(), o2._2.sum()));
P.add(top);
}
// eliminate redundant tasks
for (int i = 0; i < P.size() - 1; i++) {
Iterator<Pair<AprioriRule, DVector>> it = P.get(i).iterator();
while (it.hasNext()) {
Pair<AprioriRule, DVector> next = it.next();
boolean out = false;
for (int j = i + 1; j < P.size(); j++) {
for (Pair<AprioriRule, DVector> pair : P.get(j)) {
if (pair._1.contains(next._1)) {
out = true;
break;
}
}
if (out)
break;
}
if (out)
it.remove();
}
}
// create final rules
rules = new ArrayList<>();
for (List<Pair<AprioriRule, DVector>> aP : P) {
rules.addAll(aP.stream().map(pair -> pair._1).collect(Collectors.toSet()));
}
// create coverage
double count = 0;
for (int i = 0; i < df.rowCount(); i++) {
for (int j = 0; j < rules.size(); j++) {
if (rules.get(j).matchRow(df, i)) {
count++;
break;
}
}
}
coverage = count / df.rowCount();
}
@Override
public String summary() {
StringBuilder sb = new StringBuilder();
// print a list of rules
sb.append("# Apriori\n");
for (int i = 0; i < P.size(); i++) {
sb.append("## rules of size: ").append(i + 1).append("\n");
for (int j = 0; j < P.get(i).size(); j++) {
sb.append(j + 1).append(". ").append(P.get(i).get(j)._1.toString())
.append(" ")
.append(WS.formatFlex(P.get(i).get(j)._2.sum())).append(" [");
for (int k = 1; k < targetVar.levels().length; k++) {
sb.append(WS.formatShort(P.get(i).get(j)._2.get(k))).append(",");
}
sb.append("]\n");
}
}
sb.append("\n");
sb.append("Rules: ").append(rules.size()).append("\n");
sb.append("Coverage: ").append(WS.formatFlex(coverage)).append("\n");
sb.append("\n");
return sb.toString();
}
public Frame buildFeatures(Frame df) {
List<Var> vars = rules.stream().map(r -> Nominal.empty(df.rowCount(), "?", "1", "0")).collect(Collectors.toList());
for (int i = 0; i < vars.size(); i++) {
vars.get(i).withName("Apriori_" + (i + 1));
}
for (int i = 0; i < df.rowCount(); i++) {
for (int j = 0; j < rules.size(); j++) {
vars.get(j).setIndex(i, rules.get(j).matchRow(df, i) ? 1 : 2);
}
}
return SolidFrame.byVars(vars);
}
}
/**
 * A conjunction of {@link AprioriRuleClause}s: a row matches the rule when it matches every clause.
 * <p>
 * When {@link #extend(AprioriRule)} is only applied after {@link #isExtention(AprioriRule)} returned
 * true, the clauses stay in strictly increasing order of {@link AprioriRuleClause#full}, so each
 * conjunction is generated only once.
 */
class AprioriRule {

    public final List<AprioriRuleClause> clauses = new ArrayList<>();

    /** Appends a clause to this conjunction. */
    public void addClause(AprioriRuleClause clause) {
        clauses.add(clause);
    }

    /**
     * @return true if, for every clause, the label of {@code row} in the clause variable equals the
     * clause level (vacuously true for a rule without clauses)
     */
    public boolean matchRow(Frame df, int row) {
        for (AprioriRuleClause clause : clauses) {
            if (!df.label(row, clause.varName).equals(clause.level)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Tells whether {@code rule} is a valid one-clause extension of this rule: it must have exactly
     * one clause, and that clause must sort strictly after the last clause of this rule.
     */
    public boolean isExtention(AprioriRule rule) {
        if (rule.clauses.size() != 1) {
            return false;
        }
        if (clauses.isEmpty()) {
            // any single clause may start an empty rule (previously an IndexOutOfBoundsException)
            return true;
        }
        String candidate = rule.clauses.get(0).full;
        String last = clauses.get(clauses.size() - 1).full;
        return candidate.compareTo(last) > 0;
    }

    /**
     * @return a new rule with the clauses of this rule followed by the single clause of {@code rule},
     * or {@code null} if {@code rule} does not have exactly one clause
     */
    public AprioriRule extend(AprioriRule rule) {
        if (rule.clauses.size() != 1) {
            return null;
        }
        AprioriRule next = new AprioriRule();
        next.clauses.addAll(clauses);
        next.addClause(rule.clauses.get(0));
        return next;
    }

    /**
     * @return true if every clause of {@code rule} also appears in this rule; clauses are compared by
     * variable name and level, not by object identity
     */
    public boolean contains(AprioriRule rule) {
        Set<List<String>> own = new HashSet<>();
        for (AprioriRuleClause c : clauses) {
            own.add(clauseKey(c));
        }
        for (AprioriRuleClause c : rule.clauses) {
            if (!own.contains(clauseKey(c))) {
                return false;
            }
        }
        return true;
    }

    /** Structural key of a clause; List equality/hashing is value based and null tolerant. */
    private static List<String> clauseKey(AprioriRuleClause clause) {
        return Arrays.asList(clause.varName, clause.level);
    }

    @Override
    public String toString() {
        return clauses.stream().map(AprioriRuleClause::toString).collect(Collectors.joining(", "));
    }
}
class AprioriRuleClause {
public final String varName;
public final String level;
public final String full;
public AprioriRuleClause(String varName, String level) {
this.varName = varName;
this.level = level;
this.full = varName + ":" + level;
}
public String toString() {
return full;
}
}