/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ /* * MetricLearner.java * Copyright (C) 2002 Mikhail Bilenko * */ package weka.core.metrics; import java.util.*; import java.io.Serializable; import java.io.*; import java.text.SimpleDateFormat; import weka.classifiers.*; import weka.classifiers.sparse.*; import weka.classifiers.functions.*; import weka.core.*; /** * ClassifierMetricLearner - learns metric parameters by constructing * "difference instances" and then learning weights that classify same-class * instances as positive, and different-class instances as negative. * * @author Mikhail Bilenko (mbilenko@cs.utexas.edu) * @version $Revision: 1.4 $ */ public class ClassifierMetricLearner extends MetricLearner implements Serializable, OptionHandler { /** Classifier that is used for learning metric weights */ protected Classifier m_classifier = new SVMlight(); /** Class attribute for diff-instances can be either nominal or numeric */ protected boolean m_isDiffClassNominal = true; /** The metric that the classifier was used to learn, useful for external-calculation based metrics */ protected LearnableMetric m_metric = null; /** The pairwise selector used by the metric */ protected PairwiseSelector m_selector = new HardPairwiseSelector(); protected int m_numPosPairs = 200; protected int m_numNegPairs = 200; /** Create a new classifier metric learner */ public ClassifierMetricLearner() { } /** * Train a given metric using given training instances * * @param metric the metric to train * @param instances data to train the metric on * @exception Exception if training has gone bad. */ public void trainMetric(LearnableMetric metric, Instances instances) throws Exception { // If the data doesn't have a class attribute, bail if (instances.classIndex() < 0 || instances.numInstances() < 2) { metric.m_trained = false; System.out.println("Problem with training data"); return; } if (metric.getExternal()) { m_metric = metric; } ArrayList pairList = m_selector.createPairList(instances, m_numPosPairs, m_numNegPairs, metric); Instances diffInstances = createDiffInstances(pairList, metric); if (diffInstances == null) { metric.m_trained = false; System.out.println("null diffInstances"); return; } // BEGIN SANITY CHECK if (true) { // dump diff-instances into a temporary file try { File diffDir = new File("/tmp/diff"); diffDir.mkdir(); String diffName = instances.relationName() + "." + Utils.removeSubstring(metric.getClass().getName(), "weka.core.metrics.") + "." + Utils.removeSubstring(m_selector.getClass().getName(), "weka.core.metrics."); if (m_selector instanceof HardPairwiseSelector) { diffName = diffName + ((HardPairwiseSelector)m_selector).getNegativesMode().getSelectedTag().getReadable(); diffName = diffName + ((HardPairwiseSelector)m_selector).getPositivesMode().getSelectedTag().getReadable(); } diffInstances.setRelationName(diffName); PrintWriter writer = new PrintWriter(new BufferedOutputStream (new FileOutputStream(diffDir.getPath() + "/" + diffName + ".arff"))); writer.println(diffInstances.toString()); writer.close(); // Do a sanity check - dump out the diffInstances, and // evaluation classification with an SVM. long trainTimeStart = System.currentTimeMillis(); SVMlight classifier = new SVMlight(); Evaluation eval = new Evaluation(diffInstances); eval.crossValidateModel(classifier, diffInstances, 3); writer = new PrintWriter(new BufferedOutputStream (new FileOutputStream(diffDir.getPath() + "/" + diffName + ".dat", true))); writer.println(eval.pctCorrect()); writer.close(); System.out.println("** Record Sanity:" + (System.currentTimeMillis() - trainTimeStart) + " ms; " + eval.pctCorrect() + "% correct\t" + eval.numFalseNegatives(0) + "(" + (eval.falseNegativeRate(0)*100) + "%) false negatives\t" + eval.numFalsePositives(0) + "(" + (eval.falsePositiveRate(0)*100) + "%) false positives\t"); } catch (Exception e) { e.printStackTrace(); System.out.println(e.toString()); } } // END SANITY CHECK System.out.println(getTimestamp()+ " Building classifier: " + Utils.removeSubstring(m_classifier.getClass().getName(), "weka.classifiers.") + " " + concatStringArray(((OptionHandler)m_classifier).getOptions())); m_classifier.buildClassifier(diffInstances); // if we are learning coefficients, put them back into the distance metric if (!metric.getExternal()) { String classifierName = Utils.removeSubstring(m_classifier.getClass().getName(), "weka.classifiers"); if (classifierName.equals("functions.LinearRegression")) { double[] coefficients = ((LinearRegression)m_classifier).coefficients(); System.out.println("Learned coefficients " + coefficients.length); metric.setWeights(coefficients); } else if (classifierName.equals("functions.SMO")) { FastVector weights = ((SMO)m_classifier).weights(); double[] m_sparseWeights = (double[]) weights.elementAt(0); int[] m_sparseIndices = (int[]) weights.elementAt(1); double[] coefficients = new double[metric.getNumAttributes()]; for (int i = 0; i < m_sparseIndices.length; i++) { coefficients[m_sparseIndices[i]] = m_sparseWeights[i]; } metric.setWeights(coefficients); } } System.out.println(getTimestamp() + " Done building " + Utils.removeSubstring(m_classifier.getClass().getName(), "weka.classifiers.")); metric.m_trained = true; } /** * Set the classifier * * @param classifier the classifier */ public void setClassifier (Classifier classifier) { m_classifier = classifier; } /** * Get the classifier * * @returns the classifierthat this metric employs */ public Classifier getClassifier () { return m_classifier; } /** Set the pairwise selector * @param selector the selector for training pairs */ public void setSelector (PairwiseSelector selector) { m_selector = selector; } /** Get the pairwise selector * @return the selector for training pairs */ public PairwiseSelector getSelector() { return m_selector; } /** Set the number of same-class training pairs * @param numPosPairs the number of same-class training pairs to create for training */ public void setNumPosPairs(int numPosPairs) { m_numPosPairs = numPosPairs; } /** Get the number of same-class training pairs * @return the number of same-class training pairs to create for training */ public int getNumPosPairs() { return m_numPosPairs; } /** Set the number of different-class training pairs * @param numNegPairs the number of different-class training pairs to create for training */ public void setNumNegPairs(int numNegPairs) { m_numNegPairs = numNegPairs; } /** Get the number of different-class training pairs * @return the number of different-class training pairs to create for training */ public int getNumNegPairs() { return m_numNegPairs; } /** * Use the Classifier for an estimation of similarity * @param instance1 first instance of a pair * @param instance2 second instance of a pair * @returns sim an approximate similarity obtained from the classifier */ public double getSimilarity(Instance instance1, Instance instance2) throws Exception{ Instance diffInstance = m_metric.createDiffInstance(instance1, instance2); double d = (((DistributionClassifier)m_classifier).distributionForInstance(diffInstance))[1]; return Math.exp(-d); } /** * Use the Classifier for an estimation of distance * @param instance1 first instance of a pair * @param instance2 second instance of a pair * @returns an approximate distance obtained from the classifier */ public double getDistance(Instance instance1, Instance instance2) throws Exception{ Instance diffInstance = m_metric.createDiffInstance(instance1, instance2); double d = (((DistributionClassifier)m_classifier).distributionForInstance(diffInstance))[1]; return d; } /** * Gets the current settings of WeightedDotP. * * @return an array of strings suitable for passing to setOptions() */ public String [] getOptions() { String [] options = new String [40]; int current = 0; options[current++] = "-p"; options[current++] = "" + m_numPosPairs; options[current++] = "-n"; options[current++] = "" + m_numNegPairs; options[current++] = "-S"; options[current++] = Utils.removeSubstring(m_selector.getClass().getName(), "weka.core.metrics."); if (m_selector instanceof OptionHandler) { String[] selectorOptions = ((OptionHandler)m_selector).getOptions(); for (int i = 0; i < selectorOptions.length; i++) { options[current++] = selectorOptions[i]; } } options[current++] = "-B"; options[current++] = Utils.removeSubstring(m_classifier.getClass().getName(), "weka.classifiers."); if (m_classifier instanceof OptionHandler) { String[] classifierOptions = ((OptionHandler)m_classifier).getOptions(); for (int i = 0; i < classifierOptions.length; i++) { options[current++] = classifierOptions[i]; } } while (current < options.length) { options[current++] = ""; } return options; } /** * Parses a given list of options. Valid options are:<p> * * -B classifierstring */ public void setOptions(String[] options) throws Exception { String classifierString = Utils.getOption('B', options); if (classifierString.length() == 0) { throw new Exception("A classifier must be specified" + " with the -B option."); } String [] classifierSpec = Utils.splitOptions(classifierString); if (classifierSpec.length == 0) { throw new Exception("Invalid classifier specification string"); } String classifierName = classifierSpec[0]; classifierSpec[0] = ""; System.out.println("Classifier name: " + classifierName + "\nClassifier parameters: " + weka.classifiers.sparse.IBkMetric.concatStringArray(classifierSpec)); setClassifier(Classifier.forName(classifierName, classifierSpec)); } /** * Gets a string containing current date and time. * * @return a string containing the date and time. */ protected static String getTimestamp() { return (new SimpleDateFormat("HH:mm:ss:")).format(new Date()); } /** * Returns an enumeration describing the available options. * * @return an enumeration of all the available options. */ public Enumeration listOptions() { Vector newVector = new Vector(1); newVector.addElement(new Option( "\tFull class name of classifier to use, followed\n" + "\tby scheme options. (required)\n" + "\teg: \"weka.classifiers.bayes.NaiveBayes -D\"", "B", 1, "-B <classifier specification>")); return newVector.elements(); } /** Obtain a textual description of the metriclearner * @return a textual description of the metric learner */ public String toString() { return new String("ClassifierMetricLearner " + concatStringArray(getOptions())); } /** A little helper to create a single String from an array of Strings * @param strings an array of strings * @returns a single concatenated string, separated by commas */ public static String concatStringArray(String[] strings) { String result = new String(); for (int i = 0; i < strings.length; i++) { result = result + "\"" + strings[i] + "\" "; } return result; } }