/** * Filename: SVMSupport.java (in org.redpin.server.standalone.svm) * This file is part of the Redpin project. * * Redpin is free software: you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as published * by the Free Software Foundation, either version 3 of the License, or * any later version. * * Redpin is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public License * along with Redpin. If not, see <http://www.gnu.org/licenses/>. * * (c) Copyright ETH Zurich, Luba Rogoleva, Philipp Bolliger, 2010, ALL RIGHTS RESERVED. * * www.redpin.org */ package org.redpin.server.standalone.svm; import java.io.BufferedWriter; import java.io.File; import java.io.FileNotFoundException; import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStreamWriter; import java.util.Collections; import java.util.Hashtable; import java.util.List; import java.util.Vector; import java.util.logging.Level; import org.libsvm.core.svm_predict; import org.libsvm.core.svm_scale; import org.libsvm.core.svm_train; import org.redpin.server.standalone.core.Fingerprint; import org.redpin.server.standalone.core.Location; import org.redpin.server.standalone.core.Measurement; import org.redpin.server.standalone.core.measure.WiFiReading; import org.redpin.server.standalone.db.HomeFactory; import org.redpin.server.standalone.util.Log; /** * * @author Luba Rogoleva (lubar@student.ethz.ch) * */ public class SVMSupport { public final static String TRAIN = "train.1"; public final static String TEST = "test.1"; public final static String TEMP = "temp"; public final static String TRAIN_SCALE = "train.1.scale"; public final static String TEST_SCALE = "test.1.scale"; public final static String RANGE = "range1"; public final static String OUT = "out"; public final static String MODEL_EXT = ".model"; public final static String TRAIN_SCRIPT = "train.sh"; private static int ACTIVE_MODEL = 0; private static boolean trained = false; /** * Train SVM */ public static void train() { Log.getLogger().log(Level.FINE, "Starting SVM train."); Log.getLogger().log(Level.FINE, "Building categories..."); CategorizerFactory.buildCategories(); List<Measurement> setupdata = HomeFactory.getMeasurementHome().getAll(); if (setupdata == null || setupdata.size() == 0) return; Log.getLogger().log(Level.FINE, "Transforming data to the format of an SVM package..."); transformToSVMFormat(setupdata, TRAIN, false); Log.getLogger().log(Level.FINE, "Creating the model..."); int nextModel = Math.abs(ACTIVE_MODEL - 1); if (!runScript(TRAIN_SCRIPT, new String[] {nextModel+""})) { String[] scaleargs = {"-l","-1","-u","1","-s",RANGE,TRAIN}; String[] args={"-t","0","-c","512","-q",TRAIN_SCALE+nextModel,TRAIN_SCALE+nextModel+MODEL_EXT}; svm_train t = new svm_train(); svm_scale s = new svm_scale(); try { s.run(scaleargs, TRAIN_SCALE+nextModel); t.run(args); } catch (IOException e) { Log.getLogger().log(Level.SEVERE, "Failed to create SVM model: " + e.getMessage()); } } synchronized(SVMSupport.class) { ACTIVE_MODEL = nextModel; } Log.getLogger().log(Level.FINE, "SVM train finished.."); trained = true; } /** * Predict * @param m {@link Measurement} * @return path to result file */ public static synchronized String predict(final Measurement m) { File modelfile = new File(TRAIN_SCALE+ACTIVE_MODEL+MODEL_EXT); File outputfile = new File(OUT); Vector<Measurement> testMeasurements = new Vector<Measurement>(); testMeasurements.add(m); transformToSVMFormat(testMeasurements, TEST, true); try { String[] scaleargs = {"-r", RANGE, TEST}; svm_scale s = new svm_scale(); s.run(scaleargs, TEST_SCALE); String[] args={TEST_SCALE,modelfile+"",outputfile+""}; svm_predict.main(args); } catch (FileNotFoundException e) { Log.getLogger().log(Level.SEVERE, "predict failed due to FileNotFoundException: " + e.getMessage()); } catch (IOException e) { Log.getLogger().log(Level.SEVERE, "predict failed due to IOException: " + e.getMessage()); } return OUT; } public static boolean isTrained() { return trained; } /** * Function transforms data (measurements) to the format of an SVM package. * Each measurement is represented as a vector of real numbers. * @param data - list of measurements * @param fileName - destination file name */ public synchronized static void transformToSVMFormat(final List<Measurement> data, String fileName, boolean isNew) { File testfile = new File(fileName); try { BufferedWriter writertest = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(testfile))); Hashtable<Integer, Integer> rssis = new Hashtable<Integer, Integer>(); Vector<Integer> sarray = new Vector<Integer>(); for (Measurement m : data){ Integer categoryID = -1; if (!isNew) { categoryID = getLocationCategory(m); if (categoryID == -1) continue; } StringBuffer line = new StringBuffer(); line.append(categoryID); for (WiFiReading r : m.getWiFiReadings()) { if (r != null && r.getBssid() != null) { Integer id = CategorizerFactory.BSSIDCategorizer().GetCategoryID(r.getBssid()); if (id != -1 && !rssis.contains(id)) { rssis.put(id, r.getRssi()); sarray.add(id); } } } Collections.sort(sarray); for(int i = 0; i < sarray.size(); i++) { if (rssis.get(sarray.get(i)) != null) line.append(" " + sarray.get(i) + ":" + rssis.get(sarray.get(i))); } line.append("\n"); writertest.write(line.toString()); rssis.clear(); sarray.clear(); } writertest.close(); } catch (FileNotFoundException e) { Log.getLogger().log(Level.SEVERE, "transformToSVMFormat failed due to FileNotFoundException: " + e.getMessage()); } catch (IOException e) { Log.getLogger().log(Level.SEVERE, "transformToSVMFormat failed due to IOException: " + e.getMessage()); } } /** * Runs a shell script * @param scriptName * @return true in case of a successful run */ private static synchronized boolean runScript(String scriptName, String[] args) { String command = "./" + scriptName; for (String arg : args){ command += " " + arg; } try { Process p = Runtime.getRuntime().exec(command); int exitvalue = p.waitFor(); if (exitvalue == 0) return true; } catch (InterruptedException e) { Log.getLogger().log(Level.INFO, "runScript failed due to InterruptedException: " + e.getMessage()); } catch (IOException e) { Log.getLogger().log(Level.INFO, "runScript failed due to IOException: " + e.getMessage()); } return false; } /** * Returns location category * @param m {@link Measurement} * @return Location category id */ private static Integer getLocationCategory(Measurement m) { if (m == null || m.getId() == null) return -1; Fingerprint f = HomeFactory.getFingerprintHome().getByMeasurementId(m.getId()); if (f != null) { Location l = (Location) f.getLocation(); if (l != null && l.getId() != null) { return CategorizerFactory.LocationCategorizer().GetCategoryID(l.getId().toString()); } } return -1; } }