/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

package weka.classifiers.bayes;

import weka.classifiers.*;
import weka.classifiers.sparse.*;
import weka.core.*;

import java.io.*;
import java.util.*;

/**
 * Semi-supervised learner that uses EM initialized with labeled data and then
 * runs EM iterations on the unlabeled data to improve the model.
 *
 * See: Kamal Nigam, Andrew McCallum, Sebastian Thrun and Tom Mitchell.
 * Text Classification from Labeled and Unlabeled Documents using EM.
 * Machine Learning, 39(2/3). pp. 103-134. 2000.
 *
 * Assumes use of a base classifier that is a SoftClassifier that accepts
 * training data with a soft class distribution rather than a hard
 * assignment, i.e. SoftClassifiedInstances. Sample soft classifiers are
 * NaiveBayesSimpleSoft and NaiveBayesSimpleSparseSoft.
 *
 * @author Ray Mooney (mooney@cs.utexas.edu)
 */
public class SemiSupEM extends DistributionClassifier
  implements SemiSupClassifier, OptionHandler {

  /** Original set of unlabeled Instances */
  protected Instances m_UnlabeledData;

  /** Soft labeled version of unlabeled data */
  protected SoftClassifiedInstances m_UnlabeledInstances;

  /** Hard labeled data */
  protected Instances m_LabeledInstances;

  /** Complete set of labeled and unlabeled instances for EM */
  protected SoftClassifiedInstances m_AllInstances;

  /** Base classifier that supports soft classified instances */
  protected SoftClassifier m_Classifier = new NaiveBayesSimpleSoft();

  /** Weight of unlabeled examples during EM training versus labeled
      examples (see Nigam et al.) */
  protected double m_Lambda = 1.0;

  /** Random number generator */
  protected Random m_Random;

  /** Random number seed */
  protected int m_rseed;

  /** Maximum EM iterations to perform */
  protected int m_max_iterations;

  /** Create soft labeled seed for unseen classes */
  protected boolean m_seedUnseenClasses;

  /** Verbose? */
  protected boolean m_verbose;

  /** Minimum increase in log likelihood required to continue iterating */
  protected static double m_minLogLikelihoodIncr = 1e-6;

  /** The minimum values for numeric attributes. */
  protected double[] m_MinArray;

  /** The maximum values for numeric attributes. */
  protected double[] m_MaxArray;

  /**
   * Returns a string describing this classifier.
   *
   * @return a description of the evaluator suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {
    return "Classifier trained using both labeled and unlabeled data using EM";
  }

  /**
   * Returns an enumeration describing the available options. <p>
   *
   * Valid options are:<p>
   *
   * -W &lt;class name&gt; <br>
   * Full name of the base soft classifier. <p>
   *
   * -L &lt;num&gt; <br>
   * Lambda weight for unlabeled data (default 1). <p>
   *
   * -I &lt;max iterations&gt; <br>
   * Terminate after this many iterations if EM has not converged. <p>
   *
   * -S &lt;seed&gt; <br>
   * Specify random number seed. <p>
   *
   * -V <br>
   * Verbose. <p>
   *
   * -U <br>
   * Seed unseen classes. <p>
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration listOptions() {
    Vector newVector = new Vector(7);
    // NOTE(review): base classifier must implement SoftClassifier; the
    // former example (weka.classifiers.bayes.NaiveBayes) would fail the
    // cast in setOptions().
    newVector.addElement(new Option(
        "\tFull name of base soft classifier.\n"
        + "\teg: weka.classifiers.bayes.NaiveBayesSimpleSoft",
        "W", 1, "-W <class name>"));
    newVector.addElement(new Option(
        "\tLambda weight for unlabeled data.\n(default 1)",
        "L", 1, "-L <num>"));
    newVector.addElement(new Option(
        "\tmax iterations.\n(default 100)",
        "I", 1, "-I <num>"));
    // resetOptions() sets the seed to 100, so document that default.
    newVector.addElement(new Option(
        "\trandom number seed.\n(default 100)",
        "S", 1, "-S <num>"));
    newVector.addElement(new Option("\tverbose.", "V", 0, "-V"));
    newVector.addElement(new Option("\tSeed unseen classes.", "U", 0, "-U"));

    if ((m_Classifier != null) && (m_Classifier instanceof OptionHandler)) {
      newVector.addElement(new Option(
          "", "", 0,
          "\nOptions specific to classifier "
          + m_Classifier.getClass().getName() + ":"));
      // 'enum' is a reserved word since Java 5; use a different name.
      Enumeration en = ((OptionHandler) m_Classifier).listOptions();
      while (en.hasMoreElements()) {
        newVector.addElement(en.nextElement());
      }
    }
    return newVector.elements();
  }

  /**
   * Parses a given list of options.
   *
   * @param options the list of options as an array of strings
   * @exception Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {
    resetOptions();
    String classifierName = Utils.getOption('W', options);
    if (classifierName.length() == 0) {
      throw new Exception("A classifier must be specified with"
                          + " the -W option.");
    }
    setClassifier((SoftClassifier) Classifier.forName(
        classifierName, Utils.partitionOptions(options)));
    setDebug(Utils.getFlag('V', options));
    setSeedUnseenClasses(Utils.getFlag('U', options));
    String optionString = Utils.getOption('I', options);
    if (optionString.length() != 0) {
      setMaxIterations(Integer.parseInt(optionString));
    }
    optionString = Utils.getOption('S', options);
    if (optionString.length() != 0) {
      setSeed(Integer.parseInt(optionString));
    }
    optionString = Utils.getOption('L', options);
    if (optionString.length() != 0) {
      setLambda(Double.parseDouble(optionString));
    }
  }

  /**
   * Reset to default options.
   */
  protected void resetOptions() {
    m_max_iterations = 100;
    m_rseed = 100;
    m_verbose = false;
    m_seedUnseenClasses = false;
    m_Classifier = new NaiveBayesSimpleSoft();
    m_Lambda = 1.0;
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String seedTipText() {
    return "random number seed";
  }

  /**
   * Set the random number seed.
   *
   * @param s the seed
   */
  public void setSeed(int s) {
    m_rseed = s;
  }

  /**
   * Get the random number seed.
   *
   * @return the seed
   */
  public int getSeed() {
    return m_rseed;
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String maxIterationsTipText() {
    return "maximum number of EM iterations";
  }

  /**
   * Set the maximum number of iterations to perform.
   *
   * @param i the number of iterations
   * @exception Exception if i is less than 1
   */
  public void setMaxIterations(int i) throws Exception {
    if (i < 1) {
      throw new Exception("Maximum number of iterations must be > 0!");
    }
    m_max_iterations = i;
  }

  /**
   * Get the maximum number of iterations.
   *
   * @return the number of iterations
   */
  public int getMaxIterations() {
    return m_max_iterations;
  }

  /**
   * Set debug mode - verbose output.
   *
   * @param v true for verbose output
   */
  public void setDebug(boolean v) {
    m_verbose = v;
  }

  /**
   * Get debug mode.
   *
   * @return true if debug mode is set
   */
  public boolean getDebug() {
    return m_verbose;
  }

  /**
   * Set whether unseen classes get a soft labeled seed instance.
   *
   * @param v true to seed unseen classes
   */
  public void setSeedUnseenClasses(boolean v) {
    m_seedUnseenClasses = v;
  }

  /**
   * Get whether unseen classes get a soft labeled seed instance.
   *
   * @return true if unseen classes are seeded
   */
  public boolean getSeedUnseenClasses() {
    return m_seedUnseenClasses;
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String seedUnseenClassesTipText() {
    return "create soft seeds for unseen classes using farthest-first";
  }

  /**
   * Set the weight of unlabeled examples relative to labeled ones.
   *
   * @param v the lambda weight
   */
  public void setLambda(double v) {
    m_Lambda = v;
  }

  /**
   * Get the weight of unlabeled examples relative to labeled ones.
   *
   * @return the lambda weight
   */
  public double getLambda() {
    return m_Lambda;
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String lambdaTipText() {
    return "set weight of unlabeled examples vs. labeled";
  }

  /**
   * Set the base soft classifier.
   *
   * @param newClassifier the SoftClassifier to use.
   */
  public void setClassifier(SoftClassifier newClassifier) {
    m_Classifier = newClassifier;
  }

  /**
   * Get the base soft classifier.
   *
   * @return the classifier used as the base classifier
   */
  public SoftClassifier getClassifier() {
    return m_Classifier;
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String classifierTipText() {
    return "Base SoftClassifier to use for underlying probabilistic classification";
  }

  /**
   * Gets the current settings of SemiSupEM.
   *
   * @return an array of strings suitable for passing to setOptions()
   */
  public String[] getOptions() {
    String[] classifierOptions = new String[0];
    if ((m_Classifier != null) && (m_Classifier instanceof OptionHandler)) {
      classifierOptions = ((OptionHandler) m_Classifier).getOptions();
    }
    // 11 fixed slots: -V, -U, -I <n>, -S <n>, -L <n>, -W <name>, "--"
    // (the old "+ 10" overflowed when both -V and -U were set).
    String[] options = new String[classifierOptions.length + 11];
    int current = 0;
    if (m_verbose) {
      options[current++] = "-V";
    }
    if (m_seedUnseenClasses) {
      options[current++] = "-U";
    }
    options[current++] = "-I";
    options[current++] = "" + m_max_iterations;
    options[current++] = "-S";
    options[current++] = "" + m_rseed;
    options[current++] = "-L";
    options[current++] = "" + m_Lambda;
    if (getClassifier() != null) {
      options[current++] = "-W";
      options[current++] = getClassifier().getClass().getName();
    }
    options[current++] = "--";
    System.arraycopy(classifierOptions, 0, options, current,
                     classifierOptions.length);
    current += classifierOptions.length;
    // Pad any unused slots so every element is non-null.
    while (current < options.length) {
      options[current++] = "";
    }
    return options;
  }

  /**
   * Provide unlabeled data to the classifier.
   *
   * @param unlabeled the unlabeled Instances
   */
  public void setUnlabeled(Instances unlabeled) {
    m_UnlabeledData = unlabeled;
  }

  /** Simple constructor, must set options using command line or GUI */
  public SemiSupEM() {
    resetOptions();
  }

  /**
   * Generates the classifier: trains the base soft classifier on the
   * labeled data, soft labels the unlabeled data, then runs EM.
   *
   * @param data set of instances serving as labeled training data
   * @exception Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances data) throws Exception {
    if (data.checkForStringAttributes()) {
      throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
    }
    if (data.classAttribute().isNumeric()) {
      throw new UnsupportedClassTypeException("Can't handle a numeric class!");
    }
    if (m_Classifier == null) {
      throw new Exception("A base classifier has not been specified!");
    }
    m_LabeledInstances = data;
    // Add "hard" soft-labeled instances of labeled data to the data for EM
    m_AllInstances = new SoftClassifiedInstances(data);
    // Assign the field (the original declared a shadowing local here,
    // leaving the m_Random member forever null).
    m_Random = new Random(m_rseed);
    // Make random soft-labeled instances for unlabeled data
    m_UnlabeledInstances = new SoftClassifiedInstances(m_UnlabeledData, m_Random);
    if (m_Lambda != 1.0)
      weightInstances(m_UnlabeledInstances, m_Lambda);
    // Add the unlabeled data to the complete data set
    m_AllInstances.addInstances(m_UnlabeledInstances);
    initModel();
    if (m_verbose) {
      System.out.println("Labeled Data Classes: ");
      Enumeration enumInsts = m_LabeledInstances.enumerateInstances();
      while (enumInsts.hasMoreElements()) {
        Instance instance = (Instance) enumInsts.nextElement();
        System.out.print(m_AllInstances.classAttribute().value((int) instance.classValue()) + " ");
      }
      System.out.println("\nNum Unlabeled: " + m_UnlabeledInstances.numInstances());
    }
    if (m_UnlabeledInstances.numInstances() != 0)
      iterate();
  }

  /**
   * Weights all given instances with the given weight.
   *
   * @param insts the instances to reweight
   * @param weight the weight to assign to each instance
   */
  protected void weightInstances(Instances insts, double weight) {
    Enumeration enumInsts = insts.enumerateInstances();
    while (enumInsts.hasMoreElements()) {
      Instance instance = (Instance) enumInsts.nextElement();
      instance.setWeight(weight);
    }
  }

  /**
   * Initialize the model using the labeled data, optionally adding a
   * soft labeled seed instance for classes unseen in the labeled data.
   *
   * @exception Exception if the base classifier cannot be built
   */
  protected void initModel() throws Exception {
    SoftClassifiedInstances seedInstances =
      new SoftClassifiedInstances(m_LabeledInstances);
    if (m_seedUnseenClasses && m_UnlabeledInstances.numInstances() != 0) {
      List unseenClasses = unseenClasses(seedInstances);
      if (!unseenClasses.isEmpty()) {
        if (m_verbose)
          System.out.println("Unseen classes: " + unseenClasses);
        // Add a seed instance for the unseen classes that is soft labeled
        // equally in all unknown classes.
        Instance farthest = farthestInstance(m_UnlabeledInstances, seedInstances);
        softLabelClasses((SoftClassifiedInstance) farthest, unseenClasses);
        if (m_verbose)
          System.out.println("Seeded Instance: "
              + classDistributionString((SoftClassifiedInstance) farthest));
        seedInstances.add(farthest);
      }
    }
    m_Classifier.buildClassifier(seedInstances);
  }

  /**
   * Return a list of class values (as Integers) for which there are no
   * instances in insts.
   *
   * @param insts the instances to scan
   * @return list of unseen class indices
   */
  protected ArrayList unseenClasses(Instances insts) {
    int[] classCounts = new int[insts.numClasses()];
    Enumeration enumInsts = insts.enumerateInstances();
    while (enumInsts.hasMoreElements()) {
      Instance inst = (Instance) enumInsts.nextElement();
      classCounts[(int) inst.classValue()]++;
    }
    ArrayList result = new ArrayList();
    for (int i = 0; i < insts.numClasses(); i++) {
      if (classCounts[i] == 0) {
        result.add(new Integer(i));
      }
    }
    return result;
  }

  /**
   * Return the instance in candidateInsts that is farthest from any
   * instance in insts.
   *
   * @param candidateInsts candidate instances to choose from
   * @param insts the reference instances
   * @return the farthest candidate, or null if candidateInsts is empty
   */
  protected Instance farthestInstance(Instances candidateInsts, Instances insts) {
    double maxDist = Double.NEGATIVE_INFINITY;
    Instance farthestInst = null;
    double dist;
    setMinMax(m_AllInstances);
    Enumeration enumInsts = candidateInsts.enumerateInstances();
    while (enumInsts.hasMoreElements()) {
      Instance candidate = (Instance) enumInsts.nextElement();
      dist = minimumDistance(candidate, insts);
      if (dist > maxDist) {
        maxDist = dist;
        farthestInst = candidate;
      }
    }
    return farthestInst;
  }

  /**
   * Return the distance from inst to the closest instance in insts.
   *
   * @param inst the query instance
   * @param insts the instances to compare against
   * @return the minimum distance
   */
  protected double minimumDistance(Instance inst, Instances insts) {
    double minDist = Double.POSITIVE_INFINITY;
    double dist;
    Enumeration enumInsts = insts.enumerateInstances();
    while (enumInsts.hasMoreElements()) {
      Instance X = (Instance) enumInsts.nextElement();
      dist = distance(inst, X);
      if (dist < minDist) {
        minDist = dist;
      }
    }
    return minDist;
  }

  /**
   * Soft label inst as being equally likely to be in any of the given
   * classes.
   *
   * @param inst the instance to soft label
   * @param classes list of class indices (Integers) to spread mass over
   * @exception Exception if the class distribution cannot be set
   */
  protected void softLabelClasses(SoftClassifiedInstance inst, List classes)
    throws Exception {
    double prob = 1.0 / classes.size();
    double[] dist = new double[((Instance) inst).dataset().numClasses()];
    for (int i = 0; i < classes.size(); i++) {
      dist[((Integer) classes.get(i)).intValue()] = prob;
    }
    inst.setClassDistribution(dist);
  }

  /**
   * Run EM iterations until likelihood stops increasing significantly or
   * max iterations are exhausted.
   *
   * @exception Exception if an E or M step fails
   */
  protected void iterate() throws Exception {
    double logLikelihood, oldLogLikelihood;
    logLikelihood = 0;
    oldLogLikelihood = 0;
    for (int i = 0; i < m_max_iterations; i++) {
      oldLogLikelihood = logLikelihood;
      logLikelihood = eStep();
      if (m_verbose) {
        System.out.println("\nIteration " + i + ": LogLikelihood = "
                           + logLikelihood + "\n\n");
      }
      // Converged once the likelihood gain falls below the threshold
      // (skip the check on the first iteration, where there is no
      // meaningful previous value).
      if ((i > 0)
          && ((logLikelihood - oldLogLikelihood) < m_minLogLikelihoodIncr))
        break;
      mStep();
    }
  }

  /**
   * E step: soft relabel the unlabeled data from the current model and
   * compute the average log likelihood over all instances.
   *
   * @return the mean log likelihood of the complete data set
   * @exception Exception if classification fails
   */
  protected double eStep() throws Exception {
    double logLikelihood = 0;
    double classifiedCorrect = 0;
    double[] dist;
    Enumeration enumInsts = m_UnlabeledInstances.enumerateInstances();
    while (enumInsts.hasMoreElements()) {
      Instance instance = (Instance) enumInsts.nextElement();
      dist = m_Classifier.unNormalizedDistributionForInstance(instance);
      logLikelihood += logSum(dist);
      NaiveBayesSimple.normalizeLogs(dist);
      ((SoftClassifiedInstance) instance).setClassDistribution(dist);
      if (m_verbose) {
        // Track accuracy against the (held-out) true labels for reporting.
        if (Utils.maxIndex(dist) == (int) instance.classValue()) {
          classifiedCorrect++;
        }
      }
    }
    if (m_verbose) {
      System.out.println("\nAccuracy on Unlabeled: "
          + classifiedCorrect / m_UnlabeledInstances.numInstances());
    }
    enumInsts = m_LabeledInstances.enumerateInstances();
    while (enumInsts.hasMoreElements()) {
      Instance instance = (Instance) enumInsts.nextElement();
      dist = m_Classifier.unNormalizedDistributionForInstance(instance);
      logLikelihood += logSum(dist);
    }
    return logLikelihood / m_AllInstances.numInstances();
  }

  /**
   * Sums probabilities given in log space, using the max-shift trick to
   * avoid underflow.
   *
   * @param logProbs the log probabilities to sum
   * @return log of the sum of the exponentiated inputs
   */
  public double logSum(double[] logProbs) {
    double sum = 0;
    double max = logProbs[Utils.maxIndex(logProbs)];
    for (int i = 0; i < logProbs.length; i++) {
      sum += Math.exp(logProbs[i] - max);
    }
    return max + Math.log(sum);
  }

  /**
   * Format an instance's hard class label and soft class distribution for
   * verbose output.
   *
   * @param inst the instance whose distribution is printed
   * @return a human-readable string of class : probability pairs
   */
  protected String classDistributionString(SoftClassifiedInstance inst) {
    double[] dist = inst.getClassDistribution();
    StringBuffer text = new StringBuffer();
    Attribute classAtt = m_AllInstances.classAttribute();
    text.append(classAtt.value((int) ((Instance) inst).classValue()) + " | ");
    for (int i = 0; i < m_AllInstances.numClasses(); i++) {
      text.append(classAtt.value(i) + ":" + dist[i] + " ");
    }
    return text.toString();
  }

  /**
   * M step: rebuild the base classifier from the soft labeled union of
   * labeled and unlabeled data.
   *
   * @exception Exception if the base classifier cannot be built
   */
  protected void mStep() throws Exception {
    m_Classifier.buildClassifier(m_AllInstances);
  }

  /**
   * Calculates the class membership probabilities for the given test
   * instance.
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @exception Exception if distribution can't be computed
   */
  public double[] distributionForInstance(Instance instance) throws Exception {
    double[] dist = m_Classifier.unNormalizedDistributionForInstance(instance);
    NaiveBayesSimple.normalizeLogs(dist);
    return dist;
  }

  /**
   * Calculates the distance between two instances (squared Euclidean over
   * min-max normalized numerics, 0/1 mismatch for nominals).
   *
   * @param first the first instance
   * @param second the second instance
   * @return the distance between the two given instances
   */
  protected double distance(Instance first, Instance second) {
    double diff, distance = 0;
    Instances dataset = first.dataset();
    for (int i = 0; i < dataset.numAttributes(); i++) {
      if (i == dataset.classIndex()) {
        continue;
      }
      if (dataset.attribute(i).isNominal()) {
        // If attribute is nominal: count 1 for mismatch or any missing.
        if (first.isMissing(i) || second.isMissing(i)
            || ((int) first.value(i) != (int) second.value(i))) {
          distance += 1;
        }
      } else {
        // If attribute is numeric
        if (first.isMissing(i) || second.isMissing(i)) {
          if (first.isMissing(i) && second.isMissing(i)) {
            diff = 1;
          } else {
            // One value present: assume the worst-case difference.
            if (second.isMissing(i)) {
              diff = norm(first.value(i), i);
            } else {
              diff = norm(second.value(i), i);
            }
            if (diff < 0.5) {
              diff = 1.0 - diff;
            }
          }
        } else {
          diff = norm(first.value(i), i) - norm(second.value(i), i);
        }
        distance += diff * diff;
      }
    }
    return distance;
  }

  /**
   * Normalizes a given value of a numeric attribute into [0, 1] using the
   * stored min/max.
   *
   * @param x the value to be normalized
   * @param i the attribute's index
   * @return the normalized value (0 if no range is known)
   */
  protected double norm(double x, int i) {
    if (Double.isNaN(m_MinArray[i])
        || Utils.eq(m_MaxArray[i], m_MinArray[i])) {
      return 0;
    } else {
      return (x - m_MinArray[i]) / (m_MaxArray[i] - m_MinArray[i]);
    }
  }

  /**
   * Compute and store min and max values for each numeric attribute.
   *
   * @param insts the instances to scan
   */
  protected void setMinMax(Instances insts) {
    m_MinArray = new double[insts.numAttributes()];
    m_MaxArray = new double[insts.numAttributes()];
    for (int i = 0; i < insts.numAttributes(); i++) {
      m_MinArray[i] = m_MaxArray[i] = Double.NaN;
    }
    // 'enum' is a reserved word since Java 5; use a different name.
    Enumeration en = insts.enumerateInstances();
    while (en.hasMoreElements()) {
      updateMinMax((Instance) en.nextElement());
    }
  }

  /**
   * Updates the minimum and maximum values for all the attributes
   * based on a new instance.
   *
   * @param instance the new instance
   */
  protected void updateMinMax(Instance instance) {
    Instances dataset = instance.dataset();
    for (int j = 0; j < dataset.numAttributes(); j++) {
      if ((dataset.attribute(j).isNumeric()) && (!instance.isMissing(j))) {
        if (Double.isNaN(m_MinArray[j])) {
          m_MinArray[j] = instance.value(j);
          m_MaxArray[j] = instance.value(j);
        } else {
          if (instance.value(j) < m_MinArray[j]) {
            m_MinArray[j] = instance.value(j);
          } else {
            if (instance.value(j) > m_MaxArray[j]) {
              m_MaxArray[j] = instance.value(j);
            }
          }
        }
      }
    }
  }

  /**
   * Main method for testing this class.
   *
   * Usage: SemiSupEM &lt;arff file&gt; &lt;num labeled&gt; &lt;seed&gt;
   * [classifier options]
   *
   * @param argv the options
   */
  public static void main(String[] argv) {
    try {
      Instances instances =
        new Instances(new BufferedReader(new FileReader(argv[0])));
      instances.setClassIndex(instances.numAttributes() - 1);
      Random random = new Random(Integer.parseInt(argv[2]));
      instances.randomize(random);
      int numLabeled = Integer.parseInt(argv[1]);
      Instances labeledInsts = new Instances(instances, 0, numLabeled);
      Instances unlabeledInsts =
        new Instances(instances, numLabeled,
                      (instances.numInstances() - numLabeled));
      SemiSupEM emClassifier = new SemiSupEM();
      emClassifier.setOptions(argv);
      emClassifier.setUnlabeled(unlabeledInsts);
      emClassifier.buildClassifier(labeledInsts);
    } catch (Exception e) {
      // Preserve the message output, but don't swallow the stack trace.
      System.err.println(e.getMessage());
      e.printStackTrace();
    }
  }
}