/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ /* * MetricLearner.java * Copyright (C) 2002 Mikhail Bilenko * */ package weka.core.metrics; import java.util.*; import java.io.*; import java.text.SimpleDateFormat; import weka.core.*; /** * MatlabMetricLearner - learns metric parameters by constructing * "difference instances" and then learning weights that classify same-class * instances as positive, and different-class instances as negative using an * external Matlab program. * * @author Mikhail Bilenko (mbilenko@cs.utexas.edu) * @version $Revision: 1.1 $ */ public class MatlabMetricLearner extends MetricLearner implements Serializable { /** Matlab program that is used for learning metric weights */ protected String m_scriptFilename = new String("/tmp/matlab1.m"); /** Name of the temporary file where the matrix representing the same-class diff. instances is going to be */ protected String m_posMatrixFilename = new String("/tmp/posMatrix.txt"); /** Name of the temporary file where the matrix representing the diff-class diff. instances is going to be */ protected String m_negMatrixFilename = new String("/tmp/negMatrix.txt"); /** Name of the temporary file where the weights will be stored by Matlab after calculation */ protected String m_weightsFilename = new String("/tmp/weights.txt"); /** Debugging output */ protected boolean m_debug = true; /** Create a new matlab metric learner */ public MatlabMetricLearner() { } /** * Train a given metric using given training instances * * @param metric the metric to train * @param instances data to train the metric on * @exception Exception if training has gone bad. */ public void trainMetric(LearnableMetric metric, Instances instances) throws Exception { // If the data doesn't have a class attribute, bail if (instances.classIndex() < 0) { return; } // First, create positive and negative diff-instances ArrayList[] diffInstanceLists = createDiffInstanceLists(instances, metric, metric.getNumPosDiffInstances(), metric.getPosNegDiffInstanceRatio()); ArrayList posDiffInstanceList = diffInstanceLists[0]; ArrayList negDiffInstanceList = diffInstanceLists[1]; prepareMatlabScript(); dumpInstanceList(posDiffInstanceList, m_posMatrixFilename); dumpInstanceList(negDiffInstanceList, m_negMatrixFilename); runMatlab(m_scriptFilename, "matlab.out"); double[] coefficients = readVector(m_weightsFilename); if (m_debug) System.out.println(getTimestamp() + " Read " + coefficients.length + " coefficients"); for (int i = 0; i < coefficients.length; i++) { // coefficients[i] = (coefficients[i]+1)/2; } metric.setWeights(coefficients); } /** Create matlab m-file for PCA * @param filename file where matlab script is created */ public void prepareMatlabScript() { try{ PrintWriter writer = new PrintWriter(new BufferedOutputStream(new FileOutputStream(m_scriptFilename))); // writer.println("function w = fitMetricWeights() "); writer.println("S = load('" + m_posMatrixFilename + "'); "); writer.println("D = load('" + m_negMatrixFilename + "'); "); writer.println("[mD,n] = size(D); "); writer.println("[mS,n] = size(S); "); writer.println(""); writer.println("lb = zeros(n, 1); "); writer.println("ub = ones(n, 1); "); writer.println("x0 = ones(n, 1)/sqrt(n); "); writer.println(""); writer.println("b = 2* norm(S*x0)/mS * ones(mD, 1); "); writer.println("w = fmincon(inline('1/norm(S*x)', 'x', 'S'), x0, D, b, [], [], lb, ub, [],[],S);"); writer.println("w = w/norm(w)"); writer.println("save " + m_weightsFilename + " w -ASCII -DOUBLE;"); writer.close(); } catch (Exception e) { System.err.println("Could not create matlab file: " + e); } } /** Run matlab in command line with a given argument * @param inFile file to be input to Matlab * @param outFile file where results are stored */ public void runMatlab(String inFile, String outFile) { // call matlab to do the dirty work try { int exitValue; do { if (m_debug) System.out.println(getTimestamp() + " starting Matlab"); Process proc = Runtime.getRuntime().exec("matlab -tty < " + inFile + " > " + outFile); exitValue = proc.waitFor(); if (exitValue != 0) { System.err.println(getTimestamp() + " WARNING!!!!! Matlab returned exit value 1, trying again later!"); Thread.sleep(300000); } } while (exitValue != 0); if (m_debug) System.out.println(getTimestamp() + " Matlab done"); } catch (Exception e) { System.err.println("Problems running matlab: " + e); } } /** * Gets a string containing current date and time. * * @return a string containing the date and time. */ protected static String getTimestamp() { return (new SimpleDateFormat("HH:mm:ss:")).format(new Date()); } /** Read a column vector from a text file * @param name file name * @returns double[] array corresponding to a vector */ public double[] readVector(String name) throws Exception { BufferedReader r = new BufferedReader(new FileReader(name)); int numAttributes = -1; ArrayList vectorList = new ArrayList(); String s; while ((s = r.readLine()) != null) { try { vectorList.add(new Double(s)); } catch (Exception e) { System.err.println("Couldn't parse " + s + " as double"); } } int length = vectorList.size(); double [] vector = new double[length]; for (int i = 0; i < length; i++) { vector[i] = ((Double) vectorList.get(i)).doubleValue(); } return vector; } /** Dump a list of instances as a matrix of attribute values * @param instanceList a list of instances * @param filename name of the file where the matrix is saved */ public void dumpInstanceList(ArrayList instanceList, String filename) { try { PrintWriter writer = new PrintWriter(new BufferedOutputStream(new FileOutputStream(filename))); int numInstances = instanceList.size(); for (int i = 0; i < numInstances; i++) { Instance instance = (Instance) instanceList.get(i); int numAttributes = instance.numAttributes(); int classIdx = instance.classIndex(); for (int j = 0; j < numAttributes; j++) { if (j != classIdx) { writer.print(instance.value(j) + " "); } } writer.println(); } writer.close(); } catch (Exception e) { System.err.println("Could not create a temporary file for dumping the instance list: " + e); } } /** * Use Matlab for an estimation of similarity * @param instance1 first instance of a pair * @param instance2 second instance of a pair * @returns sim an approximate similarity obtained from the classifier */ public double getSimilarity(Instance instance1, Instance instance2) throws Exception{ throw new Exception("MatlabMetricLearner cannot be used as an external distance metric!"); } /** * Use Matlab for an estimation of distance * @param instance1 first instance of a pair * @param instance2 second instance of a pair * @returns sim an approximate distance obtained from the classifier */ public double getDistance(Instance instance1, Instance instance2) throws Exception{ throw new Exception("MatlabMetricLearner cannot be used as an external distance metric!"); } }