/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ /* * LPAssigner.java * LP-based assignment for K-Means following Kleinberg&Tardos * Copyright (C) 2004 Misha Bilenko * */ package weka.clusterers.assigners; import java.io.*; import java.util.*; import weka.core.*; import weka.core.metrics.*; import weka.clusterers.*; import weka.clusterers.assigners.*; import jmatlink.JMatLink; public class LPAssigner extends MPCKMeansAssigner { /** fields to be initialized from m_clusterer */ protected Instances m_instances = null; protected HashMap m_constraintHash = null; protected int m_numInstances = 0; protected int m_numClusters = 0; protected int m_numConstraints = 0; protected int m_numCLConstraints = 0; protected int m_numMLConstraints = 0; protected int m_numLabelVars = 0; protected int m_numConstraintVars = 0; protected int m_numVars = 0; protected boolean m_useMultipleMetrics = false; protected Metric m_metric = null; protected LearnableMetric[] m_metrics = null; protected double[] m_maxCLDistances = null; protected Instances m_centroids = null; /** Different engines that can be used to solve the LP */ public static final int ENGINE_JMATLINK = 1; public static final int ENGINE_OCTAVE = 2; public static final int ENGINE_MATLAB = 4; public static final int ENGINE_TOMLAB = 8; public static final Tag[] TAGS_ENGINE_TYPE = { new Tag(ENGINE_JMATLINK, "Matlab via JMatLink"), new Tag(ENGINE_OCTAVE, "Octave"), new Tag(ENGINE_MATLAB, "Matlab"), new Tag(ENGINE_TOMLAB, "TomLab via Matlab") }; /** The engine*/ protected int m_engineType = ENGINE_MATLAB; /** The matlab engine */ protected JMatLink m_engine = null; /** Engine auxiliary files */ /** Path to the directory where temporary files will be stored */ protected String m_tempDirPath = new String("/tmp/"); protected File m_tempDirFile = null; protected String m_progFilename = new String(m_tempDirPath + "LPAssigner.m"); protected String m_dataFilenameBase = new String("data"); protected String m_dataFilename = null; protected String m_outFilenameBase = new String("output"); protected String m_outFilename = null; /** This is a sequential assignment method */ public boolean isSequential() { return false; } /** Initialize fields from the current clustererer */ protected void initialize() throws Exception { if (m_clusterer != null) { m_instances = m_clusterer.getInstances(); m_numInstances = m_instances.numInstances(); m_constraintHash = m_clusterer.getConstraintsHash(); m_numConstraints = m_constraintHash.size(); m_numMLConstraints = 0; m_numCLConstraints = 0; // go through the constraints and count ML and CL Iterator pairItr = ((Set) m_constraintHash.keySet()).iterator(); while(pairItr.hasNext()) { InstancePair pair = (InstancePair) pairItr.next(); int linkType = ((Integer) m_constraintHash.get(pair)).intValue(); if (linkType == InstancePair.MUST_LINK) { m_numMLConstraints++; } else if (linkType == InstancePair.CANNOT_LINK) { m_numCLConstraints++; } } System.out.println(m_numConstraints +" total constraints: " + m_numMLConstraints + " must-links and " + m_numCLConstraints + " cannot-links"); m_numClusters = m_clusterer.getNumClusters(); m_useMultipleMetrics = m_clusterer.getUseMultipleMetrics(); m_metric = m_clusterer.getMetric(); m_metrics = m_clusterer.getMetrics(); m_centroids = m_clusterer.getClusterCentroids(); if (m_clusterer.m_maxCLPoints != null) { m_maxCLDistances = calculateMaxDistances(m_clusterer.m_maxCLPoints); } } else { System.err.println("\n******Clusterer is null in LPAssigner.initialize()!\n******"); } } /** The main method * @return the number of points that changed assignment */ public int assign() throws Exception { int moved = 0; initialize(); // open the engine if (m_engineType == ENGINE_JMATLINK) { if (m_engine == null) { m_engine = new JMatLink(); } m_engine.engOpen(); } /** formulate the LP **/ // Coefficients of the objective function. Consist of the following: // 1) distortion coeffs x_{ij} - indexed as currCluster*numInstances+currInstance; // x_{ij}=1 iff i-th instance belongs to j-th cluster // 2) constraint coeffs WRT cluster j - indexed as currConstraint*numClusters+currCluster // y_{ij}=1 iff i-th constraint is violated and either 1st or 2nd instance belongs to j-th cluster m_numLabelVars = m_numInstances * m_numClusters; m_numConstraintVars = m_numConstraints * m_numClusters; m_numVars = m_numLabelVars + m_numConstraintVars; System.out.println("m_numLabelVars=" + m_numLabelVars + "\tm_numConstraintVars=" + m_numConstraintVars); double [] objCoeffs = new double[m_numVars]; accumulateDistortionCoeffs(objCoeffs); accumulateConstraintCoeffs(objCoeffs); // create the array of equality constraints (sum of probs for each instance is 1) double[][] A_eq = new double[m_numInstances][m_numVars]; for (int instanceIdx = 0; instanceIdx < m_numInstances; instanceIdx++) { for (int clusterIdx = 0; clusterIdx < m_numClusters; clusterIdx++) { A_eq[instanceIdx][clusterIdx * m_numInstances + instanceIdx] = 1; } } double[] b_eq = new double[m_numInstances]; for (int instanceIdx = 0; instanceIdx < m_numInstances; instanceIdx++) { b_eq[instanceIdx] = 1; } // create the array of inequality constraints (positivity + 2perML + 2perCL) System.out.println("allocating for A: " + (m_numVars + 2*m_numConstraints*m_numClusters) + "x" + m_numVars + " (numConstraints=" + m_numConstraints); double[][] A = new double[m_numVars + 2*m_numMLConstraints*m_numClusters + 2*m_numCLConstraints*m_numClusters][m_numVars]; System.out.println("done allocating for A: " + A.length + "x" + A[0].length); double [] b = new double[m_numVars + 2*m_numMLConstraints*m_numClusters + 2*m_numCLConstraints*m_numClusters]; // positivity for (int i = 0; i < m_numVars; i++) { A[i][i] = -1; b[i] = 0; } // Constraint vars Iterator pairItr = ((Set) m_constraintHash.keySet()).iterator(); int idx = 0; int offset = m_numVars; while(pairItr.hasNext()) { InstancePair pair = (InstancePair) pairItr.next(); int linkType = ((Integer) m_constraintHash.get(pair)).intValue(); if (linkType == InstancePair.MUST_LINK) { for (int centroidIdx = 0; centroidIdx < m_numClusters; centroidIdx++) { A[offset+2*idx*m_numClusters + centroidIdx][centroidIdx * m_numInstances + pair.first] = 1; A[offset+2*idx*m_numClusters + centroidIdx][centroidIdx * m_numInstances + pair.second] = -1; A[offset+2*idx*m_numClusters + centroidIdx][m_numLabelVars + idx * m_numClusters + centroidIdx] = -1; A[offset+2*idx*m_numClusters + m_numClusters + centroidIdx][centroidIdx * m_numInstances + pair.first] = -1; A[offset+2*idx*m_numClusters + m_numClusters + centroidIdx][centroidIdx * m_numInstances + pair.second] = 1; A[offset+2*idx*m_numClusters + m_numClusters + centroidIdx][m_numLabelVars + idx * m_numClusters + centroidIdx] = -1; } } else if (linkType == InstancePair.CANNOT_LINK) { for (int centroidIdx = 0; centroidIdx < m_numClusters; centroidIdx++) { A[offset+2*idx*m_numClusters + centroidIdx][centroidIdx * m_numInstances + pair.first] = -1; A[offset+2*idx*m_numClusters + centroidIdx][centroidIdx * m_numInstances + pair.second] = -1; A[offset+2*idx*m_numClusters + centroidIdx][m_numLabelVars + idx * m_numClusters + centroidIdx] = 1; A[offset+2*idx*m_numClusters + m_numClusters + centroidIdx][centroidIdx * m_numInstances + pair.first] = 1; A[offset+2*idx*m_numClusters + m_numClusters + centroidIdx][centroidIdx * m_numInstances + pair.second] = 1; A[offset+2*idx*m_numClusters + m_numClusters + centroidIdx][m_numLabelVars + idx * m_numClusters + centroidIdx] = -1; b[offset+2*idx*m_numClusters + m_numClusters + centroidIdx] = 1; } } idx++; } /** Send the LP to the engine and get back the solution **/ double[][] probs = null; if (m_engineType == ENGINE_OCTAVE || m_engineType == ENGINE_MATLAB || m_engineType == ENGINE_TOMLAB ) { dumpData(objCoeffs, A_eq, b_eq, A, b); prepareEngine(); runEngine(); probs = getSolution(); } else if (m_engineType == ENGINE_JMATLINK) { m_engine.engPutArray("f", objCoeffs); m_engine.engPutArray("Aeq", A_eq); m_engine.engPutArray("beq", b_eq); m_engine.engPutArray("A", A); m_engine.engPutArray("b", b); // solve the LP m_engine.engEvalString("x = linprog(f,A,b,Aeq,beq)"); // get the solution back probs = m_engine.engGetArray("x"); m_engine.engClose(); } else { throw new Exception("Unknown engine type: " + m_engineType); } if (m_clusterer.getVerbose()) { for (int i = 0; i < probs.length; i++) { for (int j = 0; j < probs[i].length; j++) { System.out.print(((float)probs[i][j]) + "\t"); } } } /** Get cluster assignments from the solution probabilistically */ int [] assignments = new int [m_numInstances]; Arrays.fill(assignments, -1); int numAssigned = 0; Random r = new Random(m_clusterer.getRandomSeed()); int phase = 0; int m_maxPhases = 5000; while (numAssigned < m_numInstances && phase < m_maxPhases) { // pick a random label int clusterIdx = r.nextInt(m_numClusters); double alpha = r.nextDouble(); for (int i = 0; i < m_numInstances; i++) { if (assignments[i] == -1) { if (probs[clusterIdx * m_numInstances + i][0] >= alpha) { assignments[i] = clusterIdx; numAssigned++; } } } phase++; } /****/ /**** Compare to default assigner */ /****/ SimpleAssigner simple = new SimpleAssigner(m_clusterer); int [] clusterAssignments = m_clusterer.getClusterAssignments(); int [] oldAssignments = new int[m_numInstances]; int [] simpleAssignments = new int[m_numInstances]; // backup assignments before E-step for (int i = 0; i < m_numInstances; i++) { oldAssignments[i] = clusterAssignments[i]; } // get assignments with default E-step simple.assign(); for (int i = 0; i < m_numInstances; i++) { simpleAssignments[i] = clusterAssignments[i]; // restore assignments to state before E-step clusterAssignments[i] = oldAssignments[i]; } // number of differences between default and RMN assignments int numDiff = 0; int numSame = 0; int totalDiff = 0; boolean invalidAssignments = false; // Make new cluster assignments, count num moved System.out.println(phase + " phases; " + numAssigned + "/" + m_numInstances + " assigned"); double ratioMissassigned = 0; double ratioNonMissassigned = 0; for (int i = 0; i < m_numInstances; i++) { if (clusterAssignments[i] != assignments[i]) { // System.out.println("Moving instance " + i + " from cluster " + clusterAssignments[i] + " to cluster " + assignments[i]); clusterAssignments[i] = assignments[i]; moved++; } // count number of constraint violations for this point HashMap instanceConstraintHash = m_clusterer.getInstanceConstraintsHash(); int numViolated = 0; int numTotal = 0; Object list = instanceConstraintHash.get(new Integer(i)); if (list != null) { // there are constraints associated with this instance ArrayList constraintList = (ArrayList) list; numTotal = constraintList.size(); for (int j = 0; j < constraintList.size(); j++) { InstancePair pair = (InstancePair) constraintList.get(j); int firstIdx = pair.first; int secondIdx = pair.second; int centroidIdx = (firstIdx == i) ? clusterAssignments[firstIdx] : clusterAssignments[secondIdx]; int otherIdx = (firstIdx == i) ? clusterAssignments[secondIdx] : clusterAssignments[firstIdx]; // check whether the constraint is violated if (otherIdx != -1 && otherIdx < m_numClusters) { if (otherIdx != centroidIdx && pair.linkType == InstancePair.MUST_LINK) { numViolated++; } else if (otherIdx == centroidIdx && pair.linkType == InstancePair.CANNOT_LINK) { numViolated++; } } } } // compare to simpleAssignments if (clusterAssignments[i] != simpleAssignments[i]) { totalDiff++; } double ratio = (numTotal == 0) ? 0 : ((numViolated+0.0)/numTotal); if (numTotal > 0) { if (clusterAssignments[i] != simpleAssignments[i]) { numDiff++; // System.out.println("MISSASSIGNED; violated/total = " + numViolated + "/" + numTotal + "\t=" + ((float) ratio)); // check where it would be assigned without taking constraints into account: // KLUDGE-ish, assuming a single metric double closestDistance = Double.MAX_VALUE; int centroidIdx = -1; Instance instance = m_instances.instance(i); for (int j = 0; j < m_numClusters; j++) { Instance centroid = m_clusterer.getClusterCentroids().instance(j); double distance = m_clusterer.getMetric().distance(centroid, instance); if (distance < closestDistance) { closestDistance = distance; centroidIdx = j; } } System.out.println("ASSIGNED to: " + clusterAssignments[i] + "; SimpleAssigner assigns to: " + simpleAssignments[i] + "; without constraints closest centroid: " + centroidIdx); ratioMissassigned += ratio; } else { // System.out.println("NOT MISASSIGNED; violated/total = " + numViolated + "/" + numTotal + "\t=" + ((float) ratio)); ratioNonMissassigned += ratio; numSame++; } } } System.out.println("Total missassigned: " + totalDiff); System.out.println("\tAVG for misassigned: " + ((float) (ratioMissassigned/numDiff)) + "\n\tAVG for non-misassigned: " + ((float) (ratioNonMissassigned/numSame))); System.out.println("Moved " + moved + " points in RMN inference E-step"); /****/ /**** End of comparing to default assigner */ /****/ return moved; } /** go through all instances and all clusters and accumulate the distortion contributions */ protected void accumulateDistortionCoeffs(double [] objCoeffs) throws Exception { for (int centroidIdx = 0; centroidIdx < m_numClusters; centroidIdx++) { Instance centroid = m_centroids.instance(centroidIdx); for (int instanceIdx = 0; instanceIdx < m_numInstances; instanceIdx++) { Instance instance = m_instances.instance(instanceIdx); int coeffIdx = centroidIdx * m_numInstances + instanceIdx; if (!m_clusterer.isObjFunDecreasing()) { // increasing obj. function if (m_useMultipleMetrics) { // multiple metrics objCoeffs[coeffIdx] = m_metrics[centroidIdx].similarity(instance, centroid); } else { objCoeffs[coeffIdx] = m_metric.similarity(instance, centroid); } } else { // decreasing obj. function if (m_useMultipleMetrics) { // multiple metrics objCoeffs[coeffIdx] = m_metrics[centroidIdx].distance(instance, centroid); } else { objCoeffs[coeffIdx] = m_metric.distance(instance, centroid); } } } } } /** Accumulate contribution from constraints */ protected void accumulateConstraintCoeffs(double [] objCoeffs) throws Exception{ if (m_constraintHash != null) { Set pointPairs = (Set) m_constraintHash.keySet(); Iterator pairItr = pointPairs.iterator(); int idx = 0; while( pairItr.hasNext() ){ InstancePair pair = (InstancePair) pairItr.next(); addPairPenalties(pair, idx, objCoeffs); idx++; } } } /** accumulate penalties associated with a given constraint */ protected void addPairPenalties(InstancePair pair, int idx, double[] objCoeffs) throws Exception { int instance1Idx = pair.first; int instance2Idx = pair.second; Instance instance1 = m_instances.instance(instance1Idx); Instance instance2 = m_instances.instance(instance2Idx); int linkType = ((Integer) m_constraintHash.get(pair)).intValue(); double cost = 0; if (linkType == InstancePair.MUST_LINK) { cost = m_clusterer.getMustLinkWeight(); } else if (linkType == InstancePair.CANNOT_LINK) { cost = m_clusterer.getCannotLinkWeight(); } // if a single metric is used, we don't need to calculate separately for each cluster if (!m_useMultipleMetrics) { // MAJOR KLUDGE. TODO: create penalty(InstancePair) method in MPCKMeans; use both internally and here; // avoid iterating through constraints inside individual calculateConstraintPenalties methods double penalty = 0; // add the penalty for different types of metrics if (m_metric instanceof WeightedDotP) { double sim = m_metric.similarity(instance1, instance2); if (linkType == InstancePair.MUST_LINK) { penalty = -cost * (1 - sim); } else if (linkType == InstancePair.CANNOT_LINK) { penalty = -cost * sim; } } else if (m_metric instanceof KL) { double distance = ((KL) m_metric).distanceJS(instance1, instance2); if (linkType == InstancePair.MUST_LINK) { penalty = cost * distance; } else if (linkType == InstancePair.CANNOT_LINK) { penalty = cost * (2.0 - distance); } } else if (m_metric instanceof WeightedEuclidean || m_metric instanceof WeightedMahalanobis) { double distance = m_metric.distance(instance1, instance2); if (linkType == InstancePair.MUST_LINK) { penalty = cost * distance * distance; } else if (linkType == InstancePair.CANNOT_LINK) { penalty = cost * (m_maxCLDistances[0] * m_maxCLDistances[0] - distance * distance); } } else { throw new Exception("Unknown metric: " + m_metric.getClass().getName()); } // y_m = 0.5 sum_j (y_{mj}) if (linkType == InstancePair.MUST_LINK) { penalty = 0.5 * penalty; } else { // penalty = -0.5 * penalty; } int offset = m_numLabelVars; for (int centroidIdx = 0; centroidIdx < m_numClusters; centroidIdx++) { objCoeffs[offset + idx * m_numClusters + centroidIdx] += penalty; } } else { // MULTIPLE METRICS // KLUDGE - TODO - CURRENTLY WRONG! // for (int centroidIdx1 = 0; centroidIdx1 < m_numClusters; centroidIdx1++) { // for (int centroidIdx2 = 0; centroidIdx2 < m_numClusters; centroidIdx2++) { // double penalty = 0; // if (m_metric instanceof WeightedDotP) { // double sim1 = m_metrics[centroidIdx1].similarity(instance1, instance2); // double sim2 = m_metrics[centroidIdx2].similarity(instance1, instance2); // penalty = 0.5 * cost * (1 - sim2) + 0.5 * cost * (1 - sim1); // } else if (m_metric instanceof KL) { // double penalty1 = ((KL) m_metrics[centroidIdx1]).distanceJS(instance1, instance2); // double penalty2 = ((KL) m_metrics[centroidIdx2]).distanceJS(instance1, instance2); // penalty = 0.5 * cost * (penalty1 + penalty2); // } else if (m_metric instanceof WeightedEuclidean || m_metric instanceof WeightedMahalanobis) { // double distance1 = m_metrics[centroidIdx1].distance(instance1, instance2); // double distance2 = m_metrics[centroidIdx2].distance(instance1, instance2); // penalty = 0.5 * cost * (distance1*distance1 + distance2*distance2); // } else { // throw new Exception("Unknown metric: " + m_metric.getClass().getName()); // } // objCoeffs[centroidIdx1 * m_numInstances + instance1Idx] += penalty; // objCoeffs[centroidIdx1 * m_numInstances + instance2Idx] += penalty; // objCoeffs[centroidIdx2 * m_numInstances + instance1Idx] += penalty; // objCoeffs[centroidIdx2 * m_numInstances + instance2Idx] += penalty; // } // } } } /** * Dump data matrix into a file */ protected void dumpData(double[] objCoeffs, double[][] A_eq, double[] b_eq, double[][] A, double[] b) { if (m_engineType == ENGINE_TOMLAB) { dumpDataTomLab(objCoeffs,A_eq, b_eq, A,b); } else { try { File dataFile = File.createTempFile(m_dataFilenameBase, ".m", m_tempDirFile); m_dataFilename = dataFile.getPath(); if (!m_clusterer.getVerbose()) { dataFile.deleteOnExit(); } PrintWriter writer = new PrintWriter(new BufferedOutputStream(new FileOutputStream(dataFile))); // dump f writer.print("f = ["); for (int i = 0; i < objCoeffs.length; i++) { writer.print(objCoeffs[i] + "; "); } writer.println("];"); // dump Aeq if (m_engineType != ENGINE_OCTAVE) { writer.print("Aeq = ["); for (int i = 0; i < A_eq.length; i++) { for (int j = 0; j < A_eq[i].length; j++) { writer.print(A_eq[i][j] + ", "); } writer.flush(); writer.println(";"); } writer.println("];"); } else { // for octave, we dump into a separate file... PrintWriter writerAeq = new PrintWriter(new BufferedOutputStream(new FileOutputStream(m_tempDirPath + "Aeq"))); for (int i = 0; i < A_eq.length; i++) { for (int j = 0; j < A_eq[i].length; j++) { writerAeq.print(A_eq[i][j] + " "); } writerAeq.flush(); writerAeq.println(); } writerAeq.close(); } // dump b writer.print("beq = ["); for (int i = 0; i < b_eq.length; i++) { writer.print(b_eq[i] + "; "); } writer.println("];"); // dump A PrintWriter writerA = new PrintWriter(new BufferedOutputStream(new FileOutputStream(m_dataFilename + ".A"))); for (int i = 0; i < A.length; i++) { for (int j = 0; j < A[i].length; j++) { writerA.print(A[i][j] + " "); } writerA.println(); } writerA.flush(); writerA.close(); // dump b writer.print("b = ["); for (int i = 0; i < b.length; i++) { writer.print(b[i] + "; "); } writer.println("];"); writer.close(); } catch (Exception e) { System.err.println("Could not create temporary file \'" + m_dataFilename + "\' for dumping the LP: " + e); } } } /** * Dump data matrix into a file */ protected void dumpDataTomLab(double[] objCoeffs, double[][] A_eq, double[] b_eq, double[][] A, double[] b) { try { File dataFile = File.createTempFile(m_dataFilenameBase, ".m", m_tempDirFile); m_dataFilename = dataFile.getPath(); if (!m_clusterer.getVerbose()) { dataFile.deleteOnExit(); } PrintWriter writer = new PrintWriter(new BufferedOutputStream(new FileOutputStream(dataFile))); // dump f writer.print("f = ["); for (int i = 0; i < objCoeffs.length; i++) { writer.print(objCoeffs[i] + "; "); } writer.println("];"); // dump xl and xu writer.println("xl = zeros(" + m_numVars + ",1);"); writer.println("xu = ones(" + m_numVars + ",1);"); // dump bu writer.print("bu = [ones(" + m_numInstances + ",1); "); for (int i = m_numVars; i < b.length; i++) { writer.print(b[i] + "; "); } writer.println("];"); // dump bl writer.println("bl = [ones(" + m_numInstances + ",1); zeros(" + b.length + "-" + m_numVars + ",1)];"); writer.println("bl(" + (m_numInstances+1) + ":" + (m_numInstances + b.length - m_numVars) + ",:)=-Inf;"); writer.close(); // dump A into a separate file PrintWriter writerA = new PrintWriter(new BufferedOutputStream(new FileOutputStream(m_dataFilename + ".A"))); File aFile = new File(m_dataFilename + ".A"); aFile.deleteOnExit(); // first, dump Aeq for (int i = 0; i < A_eq.length; i++) { for (int j = 0; j < A_eq[i].length; j++) { writerA.print(A_eq[i][j] + " "); } writerA.flush(); writerA.println(); } // next, dump constraints from A for (int i = m_numVars; i < A.length; i++) { for (int j = 0; j < A[i].length; j++) { writerA.print(A[i][j] + " "); } writerA.flush(); writerA.println(); } writerA.close(); } catch (Exception e) { System.err.println("Could not create temporary file \'" + m_dataFilename + "\' for dumping the LP: " + e); } } /** Read the solution from the output file of Octave */ protected double[][] getSolution() { double[][] probs = new double[m_numLabelVars][1]; try { BufferedReader r = new BufferedReader(new FileReader(m_outFilename)); String s = null; int i = 0; while ((s = r.readLine()) != null && i < m_numLabelVars) { probs[i++][0] = Double.parseDouble(s); } } catch (Exception e) { System.out.println("Problems reading the solution from the engine: " + e); e.printStackTrace(); } File aFile = new File(m_dataFilename + ".A"); aFile.delete(); File dataFile = new File(m_dataFilename); dataFile.delete(); return probs; } /** Create octave m-file * @param filename file where the script is created */ public void prepareEngine() { try{ PrintWriter writer = new PrintWriter(new BufferedOutputStream(new FileOutputStream(m_progFilename))); writer.println("cd " + m_tempDirPath + ";"); String dataFilename = Utils.removeSubstring(m_dataFilename, m_tempDirPath); dataFilename = Utils.removeSubstring(dataFilename, ".m"); writer.println(dataFilename + ";"); switch (m_engineType) { case ENGINE_MATLAB: writer.println("A = load(\'" + m_dataFilename + ".A" + "\');"); writer.println("x = linprog(f,A,b,Aeq,beq);"); break; case ENGINE_TOMLAB: writer.println("cd /u/ml/software/tomlab;"); writer.println("startup;"); writer.println("A = load(\'" + m_dataFilename + ".A" + "\');"); writer.println("Prob = lpAssign(f,A,bl,bu,xl,xu,[],'test');"); writer.println("Result = tomRun('pdco', Prob,[],1);"); writer.println("x = Result.x_k;"); break; case ENGINE_OCTAVE: // load A and Aeq stored in auxiliary files writer.println("load A;"); writer.println("load Aeq;"); } File outFile = File.createTempFile(m_outFilenameBase, ".out", m_tempDirFile); m_outFilename = outFile.getPath(); if (!m_clusterer.getVerbose()) { outFile.deleteOnExit(); File outFileDump = new File(m_outFilename + ".dump"); outFileDump.deleteOnExit(); } writer.println("x"); writer.println("save " + m_outFilename + " x -ascii;"); writer.close(); } catch (Exception e) { System.err.println("Could not create script file \'" + m_progFilename + "\': " + e); } } /** Run octave in command line with a given argument * @param inFile file to be input * @param outFile file where results are stored */ public int runEngine() { int exitValue = -1; try { String cmd = ""; if (m_engineType == ENGINE_OCTAVE) { cmd = "octave " + m_progFilename + " > " + m_outFilename; } else if (m_engineType == ENGINE_MATLAB || m_engineType == ENGINE_TOMLAB) { cmd = "matlab -nodesktop -nosplash < " + m_progFilename + " > " + m_outFilename + ".dump"; } System.out.println("Starting to run engine " + m_engineType + cmd); Process proc = Runtime.getRuntime().exec(cmd); System.out.println("Waiting for process ..."); // read the error if (proc != null){ BufferedReader procError = new BufferedReader(new InputStreamReader(proc.getErrorStream())); try { String line; while ((line = procError.readLine()) != null){ System.out.println("ERROR: " + line); } } catch (Exception e) { System.err.println("Problems trapping error stream in debug mode: " + e); e.printStackTrace(); } } // read the output if (proc != null){ BufferedReader procOutput = new BufferedReader(new InputStreamReader(proc.getInputStream())); try { String line; while ((line = procOutput.readLine()) != null){ System.out.println("OUTPUT: " + line); } } catch (Exception e) { System.err.println("Problems trapping output in debug mode: " + e); e.printStackTrace(); } } exitValue = proc.waitFor(); System.out.println("End of running engine, exitValue = " + exitValue); } catch (Exception e) { System.err.println("Problems running engine: " + e); e.printStackTrace(); } return exitValue; } protected double[] calculateMaxDistances(Instance maxCLPoints[][]) throws Exception { double [] maxCLDistances = new double[maxCLPoints.length]; for (int i = 0; i < maxCLDistances.length; i++) { if (m_useMultipleMetrics) { maxCLDistances[i] = m_metrics[i].distance(maxCLPoints[i][0], maxCLPoints[i][1]); } else { maxCLDistances[i] = m_metric.distance(maxCLPoints[0][0], maxCLPoints[0][1]); } } return maxCLDistances; } /** Set the engine type * @param type one of the kernel types */ public void setEngineType(SelectedTag engineType) { if (engineType.getTags() == TAGS_ENGINE_TYPE) { m_engineType = engineType.getSelectedTag().getID(); } } /** Get the engine type * @return engine type */ public SelectedTag getEngineType() { return new SelectedTag(m_engineType, TAGS_ENGINE_TYPE); } public void setOptions (String[] options) throws Exception { // TODO } public Enumeration listOptions () { // TODO return null; } public String [] getOptions () { String[] options = new String[1]; int current = 0; switch (m_engineType) { case ENGINE_JMATLINK: options[current++] = "jmatlink"; break; case ENGINE_OCTAVE: options[current++] = "octave"; break; case ENGINE_MATLAB: options[current++] = "matlab"; break; case ENGINE_TOMLAB: options[current++] = "tomlab"; break; default: options[current++] = "unknown"; } return options; } }