/** * Copyright 2014 Marco Cornolti * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package it.acubelab.smaph.learn; import it.acubelab.smaph.SmaphUtils; import java.io.*; import java.util.*; import libsvm.*; /** * A gatherer of examples for a binary classifier. An example is a pair * <features, expected result> generated from an instance, where * expected_result is either 'positive' or 'negative'. An instance may generate * zero or more pairs. Pairs generated from the same instance form a group. */ public class BinaryExampleGatherer { private Vector<Vector<double[]>> positiveFeatureVectors = new Vector<>(); private Vector<Vector<double[]>> negativeFeatureVectors = new Vector<>(); private int ftrCount = -1; /** * Add all examples of an instance, forming a new group. * * @param posVectors * feature vectors of the positive examples. * @param negVectors * feature vectors of the negative examples. */ public void addExample(Vector<double[]> posVectors, Vector<double[]> negVectors) { { Vector<double[]> mergedFtrVects = new Vector<>(); mergedFtrVects.addAll(posVectors); mergedFtrVects.addAll(negVectors); for (double[] ftrVect : mergedFtrVects) { if (ftrCount == -1) ftrCount = ftrVect.length; if (ftrCount != ftrVect.length) throw new RuntimeException( "Adding feature of a wrong size. ftrCount=" + ftrCount + " passed array size=" + ftrVect.length); } } positiveFeatureVectors.add(posVectors); negativeFeatureVectors.add(negVectors); } /** * @return a libsvm problem (that is, a list of examples) including all * features. */ public svm_problem generateLibSvmProblem() { return generateLibSvmProblem(SmaphUtils.getAllFtrVect(this .getFtrCount())); } /** * @param pickedFtrs * the list of features to pick. * @return a libsvm problem (that is, a list of examples) including only * features given in pickedFtrs. */ public svm_problem generateLibSvmProblem(Vector<Integer> pickedFtrs) { Vector<Double> targets = new Vector<Double>(); Vector<svm_node[]> ftrVectors = new Vector<svm_node[]>(); for (double[] posVect : getPlain(positiveFeatureVectors)) { ftrVectors .add(LibSvmUtils.featuresArrayToNode(posVect, pickedFtrs)); targets.add(1.0); } for (double[] negVect : getPlain(negativeFeatureVectors)) { ftrVectors .add(LibSvmUtils.featuresArrayToNode(negVect, pickedFtrs)); targets.add(-1.0); } svm_problem problem = new svm_problem(); problem.l = targets.size(); problem.x = new svm_node[problem.l][]; for (int i = 0; i < problem.l; i++) problem.x[i] = ftrVectors.elementAt(i); problem.y = new double[problem.l]; for (int i = 0; i < problem.l; i++) problem.y[i] = targets.elementAt(i); return problem; } /** * @param pickedFtrsI * the list of features to pick. * @return a list of libsvm problems, one per instance. */ public List<svm_problem> generateLibSvmProblemOnePerInstance( Vector<Integer> pickedFtrsI) { Vector<svm_problem> result = new Vector<>(); for (int i = 0; i < positiveFeatureVectors.size(); i++) { Vector<double[]> posFtrVect = positiveFeatureVectors.get(i); Vector<double[]> negFtrVect = negativeFeatureVectors.get(i); Vector<Double> targets = new Vector<Double>(); Vector<svm_node[]> ftrVectors = new Vector<svm_node[]>(); for (double[] posVect : posFtrVect) { ftrVectors.add(LibSvmUtils.featuresArrayToNode(posVect, pickedFtrsI)); targets.add(1.0); } for (double[] negVect : negFtrVect) { ftrVectors.add(LibSvmUtils.featuresArrayToNode(negVect, pickedFtrsI)); targets.add(-1.0); } svm_problem problem = new svm_problem(); problem.l = targets.size(); problem.x = new svm_node[problem.l][]; for (int j = 0; j < problem.l; j++) problem.x[j] = ftrVectors.elementAt(j); problem.y = new double[problem.l]; for (int j = 0; j < problem.l; j++) problem.y[j] = targets.elementAt(j); result.add(problem); } return result; } /** * @return the number of examples. */ public int getExamplesCount() { int count = 0; for (Vector<double[]> positiveFeatureVector : positiveFeatureVectors) count += positiveFeatureVector.size(); for (Vector<double[]> negativeFeatureVector : negativeFeatureVectors) count += negativeFeatureVector.size(); return count; } private static Vector<double[]> getPlain(Vector<Vector<double[]>> vectVect) { Vector<double[]> res = new Vector<>(); for (Vector<double[]> vect : vectVect) res.addAll(vect); return res; } /** * Dump the examples to a file. * * @param filename * where to write the dump. * @throws IOException * in case of error while writing the file. */ public void dumpExamplesLibSvm(String filename) throws IOException { BufferedWriter wr = new BufferedWriter(new FileWriter(filename, false)); for (double[] posVect : getPlain(positiveFeatureVectors)) writeLine(posVect, wr, true); for (double[] negVect : getPlain(negativeFeatureVectors)) writeLine(negVect, wr, false); wr.close(); } private void writeLine(double[] ftrVect, BufferedWriter wr, boolean positive) throws IOException { String line = positive ? "+1 " : "-1 "; for (int ftr = 0; ftr < ftrVect.length; ftr++) line += String.format("%d:%.9f ", ftr + 1, ftrVect[ftr]); wr.write(line + "\n"); } /** * @return the number of features of the gathered examples. */ public int getFtrCount() { return ftrCount; } }