/*
 * Logic for classification algorithm,
 * based on associative rules.
 */
package weka.classifiers.functions;

import weka.classifiers.Classifier;
import weka.core.FastVector;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.Capabilities.Capability;
import weka.filters.Filter;
import weka.filters.supervised.attribute.NominalToBinary;
import java.util.*;
import weka.core.*;
import weka.associations.*;

/**
 * Classifier that mines association rules with {@link Apriori} and uses them
 * to re-weight the class prior of an instance.
 *
 * <p>Training computes the weighted class distribution of the learning set,
 * mines up to {@code -N} rules with minimum metric {@code -C}, and for every
 * rule and class stores the weighted fraction of labelled training instances
 * of that class that satisfy both the rule's premise and its consequence.
 * Prediction starts from the class distribution and multiplies in the stored
 * per-class scores of every rule the instance satisfies; the class with the
 * highest resulting score is returned.
 */
public class Assocify extends Classifier
    implements WeightedInstancesHandler, OptionHandler {

  /** number of possible classes */
  private int m_ClassCount;

  /** class Attribute */
  private Attribute m_Class;

  /** contains the distributions of the classes in the learning set */
  private double[] m_ClassDistribution;

  /** contains the instances in the learning set */
  private Instances m_LearningSet;

  /** contains the association rules mined ([0] premises, [1] consequences) */
  private FastVector[] m_AssociationRules;

  /**
   * contains each association class coverage used to evaluate instance class;
   * indexed as [rule][class]
   */
  private double[][] m_ClassScores;

  /** number of rules requested from the association algorithm (option -N) */
  private int m_numRules;

  /** min metric, required to mine the rules (option -C) */
  private double m_minMetric;

  /** are there any association rules mined */
  private Boolean m_RulesMined;

  /** association rules miner */
  private Apriori m_Apriori;

  /** init options */
  public Assocify() {
    m_numRules = 70;
    m_minMetric = 0.9;
    m_RulesMined = false;
  }

  /**
   * Returns a string describing this classifier
   *
   * @return a description of the classifier suitable for
   *         displaying in the explorer/experimenter gui
   */
  public String globalInfo() {
    return "Algorithm, using assciation rules to predict instance class";
  }

  /**
   * Lists the options understood by this classifier: {@code -N} and {@code -C}.
   *
   * @return an enumeration of {@link Option} objects
   */
  public Enumeration listOptions() {
    String numRulesText =
        "\tThe number of rules to be mined by the association algorithm (default: "
            + m_numRules + ")";
    String minMetricText =
        "\tThe min metric, to be satisfied (default: " + m_minMetric + ")";

    FastVector optionsVector = new FastVector(2);
    optionsVector.addElement(
        new Option(numRulesText, "N", 1, "-N <required number of rules to be mined>"));
    optionsVector.addElement(
        new Option(minMetricText, "C", 1, "-C <minimum metric score for a rule>"));
    return optionsVector.elements();
  }

  /**
   * Returns the current option settings.
   *
   * @return {@code {"-N", <numRules>, "-C", <minMetric>}}
   */
  public String[] getOptions() {
    return new String[] {"-N", "" + m_numRules, "-C", "" + m_minMetric};
  }

  /**
   * Parses {@code -N} (number of rules) and {@code -C} (minimum metric).
   * An option that is absent leaves the corresponding setting untouched.
   *
   * @param options the option array to parse
   * @throws Exception if an option value cannot be read or parsed
   */
  public void setOptions(String[] options) throws Exception {
    String numRulesText = Utils.getOption('N', options);
    String minMetricText = Utils.getOption('C', options);

    if (numRulesText.length() != 0) {
      m_numRules = Integer.parseInt(numRulesText);
    }
    if (minMetricText.length() != 0) {
      m_minMetric = Double.parseDouble(minMetricText);
    }
  }

  /**
   * Describes the model: the miner's own description when rules were mined,
   * otherwise a note that none were found.
   *
   * @return a textual description of the classifier
   */
  @Override
  public String toString() {
    String result = "Association rules classifier initiates classifier\n\n";
    result += "Association rules classifier tries to mine association rules\n\n";
    if (m_RulesMined) {
      result += m_Apriori.toString();
    } else {
      result += "No Rules were mined with minMetric: " + m_minMetric;
    }
    return result;
  }

  /**
   * Returns the capabilities of this classifier: nominal attributes and a
   * nominal class, with missing values allowed in both.
   *
   * @return the capabilities
   */
  @Override
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();
    result.disableAll();

    // attributes
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.enable(Capability.NOMINAL_CLASS);
    result.enable(Capability.MISSING_CLASS_VALUES);

    return result;
  }

  /**
   * Trains the classifier: computes the class distribution, mines the
   * association rules and evaluates each rule's per-class score.
   *
   * @param instances the training data
   * @throws Exception if the data fails the capability test or mining fails
   */
  @Override
  public void buildClassifier(Instances instances) throws Exception {
    getCapabilities().testWithFail(instances);

    m_LearningSet = instances;
    m_ClassCount = m_LearningSet.numClasses();
    m_Class = m_LearningSet.classAttribute();

    evaluateClassDistribution();

    m_Apriori = new Apriori();
    m_Apriori.setNumRules(this.m_numRules);
    m_Apriori.setMinMetric(this.m_minMetric);
    m_Apriori.buildAssociations(m_LearningSet);
    m_AssociationRules = m_Apriori.getAllTheRules();

    // Recompute on every build so a retrain that finds nothing is reported as such.
    m_RulesMined = m_AssociationRules[0].size() > 0;

    evaluateRulesClassDistribution();
  }

  /**
   * Predicts the class of an instance. Scores start at the training class
   * distribution; for every mined rule whose premise and consequence are both
   * contained by the instance, each class score is multiplied by that rule's
   * stored score for the class. The first class with the highest strictly
   * positive score wins; if no score is positive, class 0 is returned.
   *
   * @param instance the instance to classify
   * @return the index of the predicted class
   * @throws IllegalStateException if the classifier has not been built
   * @throws Exception declared by the overridden method
   */
  @Override
  public double classifyInstance(Instance instance) throws Exception {
    if (m_AssociationRules == null || m_ClassScores == null) {
      throw new IllegalStateException("Classifier has not been built yet");
    }

    double[] instanceScores = new double[m_ClassCount];
    System.arraycopy(m_ClassDistribution, 0, instanceScores, 0, m_ClassCount);

    FastVector premises = m_AssociationRules[0];
    FastVector consequences = m_AssociationRules[1];

    // Iterate over the rules actually stored, not over the -N option, which
    // may exceed the number mined or may have been changed after training.
    int ruleCount = m_ClassScores.length;
    for (int rule = 0; rule < ruleCount; rule++) {
      AprioriItemSet premise = (AprioriItemSet) premises.elementAt(rule);
      AprioriItemSet consequence = (AprioriItemSet) consequences.elementAt(rule);
      if (premise.containedBy(instance) && consequence.containedBy(instance)) {
        for (int cls = 0; cls < m_ClassCount; cls++) {
          instanceScores[cls] *= m_ClassScores[rule][cls];
        }
      }
    }

    int bestClass = 0;
    double bestScore = 0;
    for (int cls = 0; cls < m_ClassCount; cls++) {
      if (bestScore < instanceScores[cls]) {
        bestClass = cls;
        bestScore = instanceScores[cls];
      }
    }
    return (double) bestClass;
  }

  /**
   * Fills {@code m_ClassScores}: for each mined rule and each class, the
   * weighted fraction of labelled training instances of that class which
   * satisfy both the rule's premise and its consequence. A class without any
   * labelled weight gets score 0 instead of NaN.
   */
  private void evaluateRulesClassDistribution() {
    FastVector premises = m_AssociationRules[0];
    FastVector consequences = m_AssociationRules[1];

    // Size by what was mined; the -N option itself must stay untouched.
    int ruleCount = premises.size();
    m_ClassScores = new double[ruleCount][];

    int instanceCount = m_LearningSet.numInstances();

    // Per-class totals do not depend on the rule, so compute them once.
    double[] classTotals = new double[m_ClassCount];
    for (int n = 0; n < instanceCount; n++) {
      Instance instance = m_LearningSet.instance(n);
      if (!instance.classIsMissing()) {
        classTotals[(int) instance.classValue()] += instance.weight();
      }
    }

    for (int rule = 0; rule < ruleCount; rule++) {
      AprioriItemSet premise = (AprioriItemSet) premises.elementAt(rule);
      AprioriItemSet consequence = (AprioriItemSet) consequences.elementAt(rule);

      double[] matched = new double[m_ClassCount];
      for (int n = 0; n < instanceCount; n++) {
        Instance instance = m_LearningSet.instance(n);
        if (instance.classIsMissing()) {
          continue;
        }
        if (premise.containedBy(instance) && consequence.containedBy(instance)) {
          matched[(int) instance.classValue()] += instance.weight();
        }
      }

      double[] classScores = new double[m_ClassCount];
      for (int cls = 0; cls < m_ClassCount; cls++) {
        classScores[cls] = classTotals[cls] > 0 ? matched[cls] / classTotals[cls] : 0.0;
      }
      m_ClassScores[rule] = classScores;
    }
  }

  /**
   * Fills {@code m_ClassDistribution} with the weight-normalised class
   * frequencies of the labelled training instances. If no instance carries
   * a class (total weight 0) the distribution is left at all zeros.
   */
  private void evaluateClassDistribution() {
    m_ClassDistribution = new double[m_ClassCount];

    double sumOfWeights = 0.0;
    int instanceCount = m_LearningSet.numInstances();
    for (int n = 0; n < instanceCount; n++) {
      Instance instance = m_LearningSet.instance(n);
      if (!instance.classIsMissing()) {
        m_ClassDistribution[(int) instance.classValue()] += instance.weight();
        sumOfWeights += instance.weight();
      }
    }

    if (sumOfWeights > 0) {
      for (int cls = 0; cls < m_ClassCount; cls++) {
        m_ClassDistribution[cls] /= sumOfWeights;
      }
    }
  }

  /**
   * Command-line entry point.
   *
   * @param argv the command-line options
   */
  public static void main(String argv[]) {
    runClassifier(new Assocify(), argv);
  }
}