/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* ClassifierInstanceMetric.java
* Copyright (C) 2003 Mikhail Bilenko
*
*/
package weka.deduping.metrics;
import java.util.ArrayList;
import java.util.Vector;
import java.util.Enumeration;
import java.util.Date;
import java.text.SimpleDateFormat;
import java.io.*;
import weka.deduping.*;
import weka.core.*;
import weka.classifiers.DistributionClassifier;
import weka.classifiers.sparse.SVMlight;
import weka.classifiers.Evaluation;
/**
* ClassifierInstanceMetric class employs a classifier that uses
* values returned by various StringMetric's on individual fields
* as features and outputs a confidence value that corresponds to
* similarity between records
*
* @author Mikhail Bilenko (mbilenko@cs.utexas.edu)
* @version $Revision: 1.5 $
*/
public class ClassifierInstanceMetric extends InstanceMetric implements OptionHandler, Serializable {

  /** Classifier that is used for estimating similarity between records */
  protected DistributionClassifier m_classifier = new SVMlight();

  /** A selector object that will create training sets */
  PairwiseSelector m_selector = new PairwiseSelector();

  /** The desired number of same-class (positive) training pairs */
  protected int m_numPosPairs = 200;

  /** The desired number of different-class (negative) training pairs */
  protected int m_numNegPairs = 200;

  /** StringMetric prototypes; each one is cloned once per used attribute */
  protected StringMetric [] m_stringMetrics = new StringMetric[0];

  /** The actual array of metrics: m_fieldMetrics[i][j] is the clone of
   * m_stringMetrics[i] that operates on attribute m_attrIdxs[j] */
  protected StringMetric [][] m_fieldMetrics = null;

  /** A temporary dataset that contains diff-instances for training the classifier */
  protected Instances m_diffInstances = null;

  /** Minimum length of the array returned by getOptions(); shorter option
   * lists are padded with empty strings. */
  private static final int MIN_OPTIONS_LENGTH = 200;

  /** Debugging aid: when true, trainInstanceMetric() dumps the
   * diff-instances to /tmp/diff and reports a 5-fold cross-validation of
   * an SVMlight on them before building the actual classifier. */
  private static final boolean SANITY_CHECK = false;

  /** A default constructor */
  public ClassifierInstanceMetric() {
  }

  /**
   * Initializes the metric for the specified attributes: every prototype
   * in m_stringMetrics is cloned once per attribute, so that each
   * (metric, attribute) pair can be built and trained independently.
   *
   * @param attrIdxs the indices of attributes that the metric will use
   * @exception Exception if the distance metric has not been
   * generated successfully (e.g. a prototype cannot be cloned)
   */
  public void buildInstanceMetric(int[] attrIdxs) throws Exception {
    m_attrIdxs = attrIdxs;
    m_fieldMetrics = new StringMetric[m_stringMetrics.length][m_attrIdxs.length];
    for (int i = 0; i < m_stringMetrics.length; i++) {
      for (int j = 0; j < m_attrIdxs.length; j++) {
        m_fieldMetrics[i][j] = (StringMetric) m_stringMetrics[i].clone();
      }
    }
  }

  /**
   * Trains the metric. First, data-dependent field metrics are built from
   * the strings of both folds, and learnable field metrics are trained on
   * string pairs from the training fold. Then the field metrics are used to
   * create diff-instances, on which the classifier is trained.
   *
   * @param trainData instances for training the metric
   * @param testData instances that will be used for testing (may be null)
   * @exception IllegalStateException if buildInstanceMetric() has not been
   * called for the current set of string metrics
   * @exception Exception if training any component fails
   */
  public void trainInstanceMetric(Instances trainData, Instances testData) throws Exception {
    // fail fast instead of hitting an NPE/AIOOBE deep inside the loops below
    if (m_fieldMetrics == null || m_fieldMetrics.length != m_stringMetrics.length) {
      throw new IllegalStateException("buildInstanceMetric() must be called after setting the"
                                      + " string metrics and before trainInstanceMetric()");
    }
    m_selector.initSelector(trainData);

    // if we have data-dependent or trainable metrics
    // (e.g. vector-space or learnable ED), build them with available
    // test/train data
    ArrayList [] attrStringLists = null;
    for (int i = 0; i < m_stringMetrics.length; i++) {
      if (m_stringMetrics[i] instanceof DataDependentStringMetric) {
        // populate the string lists lazily: only needed for data-dependent metrics
        if (attrStringLists == null) {
          attrStringLists = new ArrayList[m_attrIdxs.length];
          for (int j = 0; j < m_attrIdxs.length; j++) {
            attrStringLists[j] = getStringList(trainData, testData, m_attrIdxs[j]);
          }
        }
        // initialize the data-dependent metric for each attribute
        for (int j = 0; j < m_attrIdxs.length; j++) {
          ((DataDependentStringMetric) m_fieldMetrics[i][j]).buildMetric(attrStringLists[j]);
        }
      }

      // if the metric is learnable, train it on pairs from the training fold
      if (m_stringMetrics[i] instanceof LearnableStringMetric) {
        for (int j = 0; j < m_attrIdxs.length; j++) {
          ArrayList strPairList = m_selector.getStringPairList(trainData, m_attrIdxs[j],
                                                               m_numPosPairs, m_numNegPairs,
                                                               m_fieldMetrics[i][j]);
          ((LearnableStringMetric) m_fieldMetrics[i][j]).trainMetric(strPairList);
        }
      }
    }

    // create the classifier's training data
    m_diffInstances = m_selector.getInstances(m_attrIdxs, m_fieldMetrics, m_numPosPairs, m_numNegPairs);

    // record how many pairs of each class were actually obtained;
    // class index 0 is counted as positive, index 1 as negative
    AttributeStats classStats = m_diffInstances.attributeStats(m_diffInstances.classIndex());
    m_numActualPosPairs = classStats.nominalCounts[0];
    m_numActualNegPairs = classStats.nominalCounts[1];

    if (SANITY_CHECK) {
      runSanityCheck(trainData);
    }

    System.out.println(getTimestamp() + ": Building " + m_classifier.getClass().getName());
    m_classifier.buildClassifier(m_diffInstances);
    System.out.println(getTimestamp() + ": Done building " + m_classifier.getClass().getName());
  }

  /**
   * Debugging aid: dumps m_diffInstances into /tmp/diff/NAME.arff, runs a
   * 5-fold cross-validation of an SVMlight on them, appends the accuracy to
   * /tmp/diff/NAME.dat and prints a summary. Failures are reported, not
   * propagated, since this must never break training.
   *
   * @param trainData the training fold (used only for naming the output)
   */
  private void runSanityCheck(Instances trainData) {
    PrintWriter writer = null;
    try {
      File diffDir = new File("/tmp/diff");
      diffDir.mkdir();

      // name the dump after the dataset and the string metrics in use
      StringBuffer metricNames = new StringBuffer();
      for (int i = 0; i < m_stringMetrics.length; i++) {
        if (i > 0) {
          metricNames.append("_");
        }
        metricNames.append(Utils.removeSubstring(m_stringMetrics[i].getClass().getName(),
                                                 "weka.deduping.metrics."));
      }
      String diffName = trainData.relationName() + "." + metricNames.toString();
      m_diffInstances.setRelationName(diffName);

      writer = new PrintWriter(new BufferedOutputStream(new FileOutputStream(diffDir.getPath() + "/" +
                                                                             diffName + ".arff")));
      writer.println(m_diffInstances.toString());
      writer.close();
      writer = null;

      // evaluate classification of the diff-instances with an SVM
      long trainTimeStart = System.currentTimeMillis();
      SVMlight classifier = new SVMlight();
      Evaluation eval = new Evaluation(m_diffInstances);
      eval.crossValidateModel(classifier, m_diffInstances, 5);

      writer = new PrintWriter(new BufferedOutputStream(new FileOutputStream(diffDir.getPath() + "/" +
                                                                             diffName + ".dat", true)));
      writer.println(eval.pctCorrect());
      writer.close();
      writer = null;

      System.out.println("** Record Sanity:" + (System.currentTimeMillis() - trainTimeStart) + " ms; " +
                         eval.pctCorrect() + "% correct\t" +
                         eval.numFalseNegatives(0) + "(" + eval.falseNegativeRate(0) + "%) false negatives\t" +
                         eval.numFalsePositives(0) + "(" + eval.falsePositiveRate(0) + "%) false positives\t");
    } catch (Exception e) {
      e.printStackTrace();
      System.out.println(e.toString());
    } finally {
      // do not leak the file handle if writing or evaluation failed
      if (writer != null) {
        writer.close();
      }
    }
  }

  /** An internal method for creating a list of strings for a
   * particular attribute from two sets of instances: training and
   * test data
   * @param trainData a dataset of records in the training fold (may be null)
   * @param testData a dataset of records in the testing fold (may be null)
   * @param attrIdx the index of the attribute for which strings are to be collected
   * @return a list of strings that occur for this attribute; duplicates are allowed
   */
  protected ArrayList getStringList(Instances trainData, Instances testData, int attrIdx) {
    ArrayList stringList = new ArrayList();
    addStringValues(stringList, trainData, attrIdx);
    addStringValues(stringList, testData, attrIdx);
    return stringList;
  }

  /** Appends the string value of attribute attrIdx of every instance in
   * data to list; a null dataset contributes nothing. */
  private static void addStringValues(ArrayList list, Instances data, int attrIdx) {
    if (data == null) {
      return;
    }
    for (int i = 0; i < data.numInstances(); i++) {
      list.add(data.instance(i).stringValue(attrIdx));
    }
  }

  /**
   * Returns distance between two records. For every used attribute (outer
   * loop) and every string metric (inner loop) the field metric's distance
   * (or similarity, for similarity-based metrics) is computed; the
   * resulting feature vector is classified and the probability assigned to
   * class index 1 (the class counted as negative pairs during training) is
   * returned. The feature vector has one extra trailing slot left at 0;
   * NOTE(review): presumably this is the class attribute of m_diffInstances,
   * and the feature order must match PairwiseSelector.getInstances -- verify.
   *
   * @param instance1 First record.
   * @param instance2 Second record.
   * @return the classifier's probability for class index 1
   * @exception Exception if distance could not be calculated.
   */
  public double distance(Instance instance1, Instance instance2) throws Exception {
    double[] distances = new double[m_attrIdxs.length * m_stringMetrics.length + 1];
    int counter = 0;
    for (int i = 0; i < m_attrIdxs.length; i++) {
      String str1 = instance1.stringValue(m_attrIdxs[i]);
      String str2 = instance2.stringValue(m_attrIdxs[i]);
      for (int j = 0; j < m_stringMetrics.length; j++) {
        if (m_stringMetrics[j].isDistanceBased()) {
          distances[counter++] = m_fieldMetrics[j][i].distance(str1, str2);
        } else {
          distances[counter++] = m_fieldMetrics[j][i].similarity(str1, str2);
        }
      }
    }
    Instance diffInstance = new Instance(1.0, distances);
    diffInstance.setDataset(m_diffInstances);
    return m_classifier.distributionForInstance(diffInstance)[1];
  }

  /**
   * Returns similarity between two records, computed as exp(-distance).
   * @param instance1 First instance.
   * @param instance2 Second instance.
   * @return a similarity value in (0, 1]
   * @exception Exception if similarity could not be calculated.
   */
  public double similarity(Instance instance1, Instance instance2) throws Exception {
    double d = distance(instance1, instance2);
    return Math.exp(-d);
  }

  /** The computation can be either based on distance, or on similarity
   * @return true if the underlying metric computes distance, false if similarity
   */
  public boolean isDistanceBased() {
    return true;
  }

  /**
   * Set the classifier
   *
   * @param classifier the classifier
   */
  public void setClassifier (DistributionClassifier classifier) {
    m_classifier = classifier;
  }

  /**
   * Get the classifier
   *
   * @return the classifier that this metric employs
   */
  public DistributionClassifier getClassifier () {
    return m_classifier;
  }

  /**
   * Set the baseline string metrics. buildInstanceMetric() must be called
   * again afterwards so that the per-field clones are recreated.
   *
   * @param metrics string metrics that will be used on each string attribute
   */
  public void setStringMetrics (StringMetric[] metrics) {
    m_stringMetrics = metrics;
  }

  /**
   * Get the baseline string metrics
   *
   * @return the string metrics that are used for each field
   */
  public StringMetric[] getStringMetrics () {
    return m_stringMetrics;
  }

  /** Set the pairwise selector for this metric
   * @param selector a new pairwise selector
   */
  public void setSelector(PairwiseSelector selector) {
    m_selector = selector;
  }

  /** Get the pairwise selector for this metric
   * @return the pairwise selector
   */
  public PairwiseSelector getSelector() {
    return m_selector;
  }

  /** Set the number of same-class training pairs that is desired
   * @param numPosPairs the number of same-class training pairs to be
   * created for training the classifier
   */
  public void setNumPosPairs(int numPosPairs) {
    m_numPosPairs = numPosPairs;
  }

  /** Get the number of same-class training pairs
   * @return the number of same-class training pairs to create for
   * training the classifier
   */
  public int getNumPosPairs() {
    return m_numPosPairs;
  }

  /** Set the number of different-class training pairs
   * @param numNegPairs the number of different-class training pairs
   * to create for training the classifier
   */
  public void setNumNegPairs(int numNegPairs) {
    m_numNegPairs = numNegPairs;
  }

  /** Get the number of different-class training pairs
   * @return the number of different-class training pairs to create
   * for training the classifier
   */
  public int getNumNegPairs() {
    return m_numNegPairs;
  }

  /**
   * Gets a string containing the current time of day (HH:mm:ss:).
   *
   * @return a string containing the time.
   */
  protected static String getTimestamp() {
    return (new SimpleDateFormat("HH:mm:ss:")).format(new Date());
  }

  /** A little helper to create a single String from an array of Strings
   * @param strings an array of strings
   * @return the strings concatenated, each followed by a single space
   */
  public static String concatStringArray(String[] strings) {
    StringBuffer buffer = new StringBuffer();
    for (int i = 0; i < strings.length; i++) {
      buffer.append(strings[i]);
      buffer.append(" ");
    }
    return buffer.toString();
  }

  /**
   * Returns an enumeration describing the available options
   *
   * @return an enumeration of all the available options
   **/
  public Enumeration listOptions() {
    Vector newVector = new Vector(2);
    // NOTE: -M is described here but not yet parsed by setOptions()
    newVector.addElement(new Option("\tMetric.\n"
                                    +"\t(default=AffineMetric)", "M", 1,"-M metric_name metric_options"));
    newVector.addElement(new Option("\tClassifier.\n"
                                    +"\t(default=weka.classifiers.sparse.SVMlight)", "C", 1,"-C classifierName classifierOptions"));
    return newVector.elements();
  }

  /**
   * Parses a given list of options.
   *
   * Valid options are:<p>
   *
   * -C classifier options <p>
   * Classifier used (required) <p>
   *
   * The -M (string metric) option is not implemented yet.
   *
   * @param options the list of options as an array of strings
   * @exception Exception if an option is not supported or no classifier is given
   *
   **/
  public void setOptions(String[] options) throws Exception {
    // TODO: implement the -M option for choosing the string metrics

    String classifierString = Utils.getOption('C', options);
    if (classifierString.length() == 0) {
      throw new Exception("A classifier must be specified"
                          + " with the -C option.");
    }
    String [] classifierSpec = Utils.splitOptions(classifierString);
    if (classifierSpec.length == 0) {
      throw new Exception("Invalid classifier specification string");
    }
    String classifierName = classifierSpec[0];
    // the remaining tokens are the classifier's own options
    classifierSpec[0] = "";
    System.out.println("Classifier name: " + classifierName + "\nClassifier parameters: " +
                       concatStringArray(classifierSpec));
    setClassifier((DistributionClassifier) DistributionClassifier.forName(classifierName, classifierSpec));
  }

  /**
   * Gets the current settings of this metric: selector options, pair
   * counts, string metrics and classifier. The returned array has at least
   * 200 entries; unused trailing entries are empty strings.
   *
   * @return an array of strings describing the current settings
   */
  public String [] getOptions() {
    // collect into a growable list so that long option lists cannot overflow
    ArrayList options = new ArrayList();

    if (m_selector instanceof OptionHandler) {
      addAll(options, ((OptionHandler) m_selector).getOptions());
    }
    options.add("-p");
    options.add("" + m_numPosPairs);
    options.add("-n");
    options.add("" + m_numNegPairs);

    options.add("-M" + m_stringMetrics.length);
    for (int i = 0; i < m_stringMetrics.length; i++) {
      options.add(Utils.removeSubstring(m_stringMetrics[i].getClass().getName(), "weka.deduping.metrics."));
      if (m_stringMetrics[i] instanceof OptionHandler) {
        addAll(options, ((OptionHandler) m_stringMetrics[i]).getOptions());
      }
    }

    options.add("-C");
    options.add(Utils.removeSubstring(m_classifier.getClass().getName(), "weka.classifiers."));
    if (m_classifier instanceof OptionHandler) {
      addAll(options, ((OptionHandler) m_classifier).getOptions());
    }

    // keep the historical minimum length, padded with empty strings
    while (options.size() < MIN_OPTIONS_LENGTH) {
      options.add("");
    }
    return (String[]) options.toArray(new String[options.size()]);
  }

  /** Appends every element of values to list, in order. */
  private static void addAll(ArrayList list, String[] values) {
    for (int i = 0; i < values.length; i++) {
      list.add(values[i]);
    }
  }
}