/**
 *  Copyright 2014 Marco Cornolti
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

package it.acubelab.smaph.learn;

import it.acubelab.smaph.SmaphAnnotatorDebugger;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Vector;

import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;

/**
 * Binary filter backed by a LibSVM model and a per-feature range file.
 *
 * <p>Subclasses decide how a named feature map is turned into the positional
 * feature vector (see {@link #featuresToFtrVect(HashMap)}); this class scales
 * that vector with the loaded ranges and asks the model for a prediction.
 */
public abstract class LibSvmFilter {
    private svm_model model;
    private double[] rangeMins, rangeMaxs;
    private String modelFile;
    private final String rangeFile;

    /**
     * Loads the model and the feature ranges.
     *
     * <p>Note that this constructor calls the public, overridable methods
     * {@link #setModel(String)} and {@link #resetRanges()}.
     *
     * @param modelFile path of the LibSVM model file.
     * @param rangeFile path of the feature range file.
     * @throws IOException declared for backward compatibility with existing
     *         callers; read failures in {@link #setModel(String)} and
     *         {@link #resetRanges()} are rethrown as {@link RuntimeException}.
     */
    public LibSvmFilter(String modelFile, String rangeFile) throws IOException {
        setModel(modelFile);
        this.rangeFile = rangeFile;
        resetRanges();
    }

    /**
     * Converts a map of named features into the positional feature vector.
     *
     * @param features feature values keyed by feature name.
     * @return the feature values as an array.
     */
    public abstract double[] featuresToFtrVect(HashMap<String, Double> features);

    /**
     * Classifies a set of features.
     *
     * @param features feature values keyed by feature name.
     * @return {@code true} iff the model's prediction for the scaled feature
     *         vector is strictly positive.
     */
    public boolean predict(HashMap<String, Double> features) {
        svm_node[] ftrVect = LibSvmUtils.featuresArrayToNode(featuresToFtrVect(features));
        LibSvmUtils.scaleNode(ftrVect, rangeMins, rangeMaxs);
        return svm.svm_predict(model, ftrVect) > 0;
    }

    /**
     * (Re)loads the feature ranges from the range file given at construction.
     *
     * <p>Only lines made of exactly three whitespace-separated tokens
     * ({@code <featureId> <min> <max>}, with 1-based ids) are used; every
     * other line is ignored. The ranges currently in use are replaced only
     * after the whole file has been read and parsed successfully.
     *
     * @throws RuntimeException if the range file cannot be read.
     * @throws NumberFormatException if a three-token line is not numeric.
     * @throws IllegalArgumentException if a feature id is lower than 1.
     */
    public void resetRanges() {
        List<String[]> entries = new ArrayList<>();
        // try-with-resources closes the reader even when readLine() fails.
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(
                new FileInputStream(rangeFile), StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                String[] tokens = line.trim().split("\\s+");
                if (tokens.length == 3) {
                    entries.add(tokens);
                }
            }
        } catch (IOException e) {
            throw new RuntimeException("Could not read range file " + rangeFile, e);
        }

        int count = entries.size();
        int[] featureIds = new int[count];
        // NOTE(review): bounds are parsed as float (as before) and widened to
        // double on store; switching to Double.parseDouble would slightly
        // change the scaled values -- confirm before changing.
        float[] mins = new float[count];
        float[] maxs = new float[count];
        int maxFeatureId = 0;
        for (int i = 0; i < count; i++) {
            String[] tokens = entries.get(i);
            int featureId = Integer.parseInt(tokens[0]);
            if (featureId < 1) {
                throw new IllegalArgumentException("Invalid feature id " + featureId
                        + " in range file " + rangeFile + " (ids are 1-based)");
            }
            featureIds[i] = featureId;
            mins[i] = Float.parseFloat(tokens[1]);
            maxs[i] = Float.parseFloat(tokens[2]);
            maxFeatureId = Math.max(maxFeatureId, featureId);
        }

        // Size by the highest id rather than by the number of rows: arrays are
        // indexed by (featureId - 1), so a file that skips an id would
        // otherwise overflow them. Slots of skipped ids keep 0.0 for both
        // bounds; how LibSvmUtils.scaleNode treats such a slot is not visible
        // from this class -- TODO confirm.
        double[] newMins = new double[maxFeatureId];
        double[] newMaxs = new double[maxFeatureId];
        for (int i = 0; i < count; i++) {
            newMins[featureIds[i] - 1] = mins[i];
            newMaxs[featureIds[i] - 1] = maxs[i];
            SmaphAnnotatorDebugger.out.printf("Feature %d range: [%.3f, %.3f]%n",
                    featureIds[i], mins[i], maxs[i]);
        }
        rangeMins = newMins;
        rangeMaxs = newMaxs;
    }

    /**
     * @return the path of the model file currently in use.
     */
    public String getModel() {
        return modelFile;
    }

    /**
     * Loads a new model, replacing the current one.
     *
     * <p>The current model and its path are kept untouched if loading fails.
     *
     * @param modelFile path of the LibSVM model file to load.
     * @throws RuntimeException if the model file cannot be read or parsed.
     */
    public void setModel(String modelFile) {
        svm_model loaded;
        try {
            loaded = svm.svm_load_model(modelFile);
        } catch (IOException e) {
            throw new RuntimeException("Could not load SVM model " + modelFile, e);
        }
        // Guard against a null model so the failure surfaces here rather than
        // later inside predict().
        if (loaded == null) {
            throw new RuntimeException("Could not parse SVM model " + modelFile);
        }
        this.modelFile = modelFile;
        this.model = loaded;
    }
}