package com.maalaang.omtwitter.ml;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Pattern;
import jnisvmlight.KernelParam;
import jnisvmlight.LabeledFeatureVector;
import jnisvmlight.LearnParam;
import jnisvmlight.SVMLightInterface;
import jnisvmlight.SVMLightModel;
import jnisvmlight.TrainingParameters;
/**
 * Static helpers for training an SVM model through the JNI SVM-light binding
 * from a text file that holds one labeled example per line.
 *
 * <p>Each example line is expected to look like
 * {@code <label> <dim>:<value> <dim>:<value> ...}, where the label and the
 * dimensions are integers and the values are decimal numbers.
 */
public class SvmTrainer {

    /** Separator between the tokens of an example line: any run of colons, spaces or tabs. */
    private static final Pattern TOKEN_SEPARATOR = Pattern.compile("(:| |\t)+");

    /**
     * Counts the lines of the given file.
     *
     * <p>Every line is counted, including blank ones, so the result is an upper
     * bound on the number of examples {@link #train} will read from the same file.
     *
     * @param exampleFile path of the example file
     * @return the number of lines in the file
     * @throws IOException if the file cannot be opened or read
     */
    public static int numOfExamples(String exampleFile) throws NumberFormatException, IOException {
        int cnt = 0;
        // try-with-resources closes the reader even when reading fails part-way.
        try (BufferedReader br = new BufferedReader(new FileReader(exampleFile))) {
            while (br.readLine() != null) {
                cnt++;
            }
        }
        return cnt;
    }

    /**
     * Trains a linear-kernel classification model on the examples in
     * {@code exampleFile} and writes the resulting model to {@code modelFile}.
     *
     * <p>Blank lines and lines starting with {@code #} are skipped. The training
     * array is sized by the number of examples actually read, so a
     * {@code numOfDocs} that is too small or too large no longer causes an
     * out-of-bounds write or {@code null} entries in the training data.
     *
     * @param exampleFile path of the example file
     * @param modelFile   path the trained model is written to
     * @param numOfDocs   expected number of examples; used only as an initial
     *                    capacity hint
     * @throws IOException           if the example file cannot be opened or read
     * @throws NumberFormatException if a label, dimension or value cannot be parsed
     */
    public static void train(String exampleFile, String modelFile, int numOfDocs) throws IOException {
        SVMLightInterface trainer = new SVMLightInterface();
        SVMLightInterface.SORT_INPUT_VECTORS = false;

        // The reader is fully consumed and closed before the training run starts.
        LabeledFeatureVector[] trainData = readExamples(exampleFile, numOfDocs);

        TrainingParameters tp = new TrainingParameters();
        LearnParam learnParam = tp.getLearningParameters();
        learnParam.verbosity = 1;
        learnParam.kernel_cache_size = 2048;
        learnParam.type = LearnParam.CLASSIFICATION;
        KernelParam kernelParam = tp.getKernelParameters();
        kernelParam.kernel_type = KernelParam.LINEAR;

        SVMLightModel model = trainer.trainModel(trainData, tp);
        model.writeModelToFile(modelFile);
    }

    /** Reads every example line of the file into an array holding exactly the parsed examples. */
    private static LabeledFeatureVector[] readExamples(String exampleFile, int sizeHint) throws IOException {
        // A negative hint would make the ArrayList constructor throw; treat it as "unknown".
        List<LabeledFeatureVector> examples = new ArrayList<>(Math.max(sizeHint, 0));
        try (BufferedReader br = new BufferedReader(new FileReader(exampleFile))) {
            String line;
            while ((line = br.readLine()) != null) {
                // Trimming prevents an empty leading token when a line starts with whitespace.
                String trimmed = line.trim();
                if (trimmed.isEmpty() || trimmed.startsWith("#")) {
                    continue;
                }
                examples.add(parseExample(trimmed));
            }
        }
        return examples.toArray(new LabeledFeatureVector[0]);
    }

    /** Parses one non-empty, trimmed example line into a labeled feature vector. */
    private static LabeledFeatureVector parseExample(String line) {
        String[] tokens = TOKEN_SEPARATOR.split(line);
        int label = Integer.parseInt(tokens[0]);
        // Tokens after the label come in (dimension, value) pairs; a dangling
        // unpaired trailing token is ignored by the integer division.
        int nDims = (tokens.length - 1) / 2;
        int[] dims = new int[nDims];
        double[] values = new double[nDims];
        for (int i = 1, j = 0; j < nDims; i += 2, j++) {
            dims[j] = Integer.parseInt(tokens[i]);
            values[j] = Double.parseDouble(tokens[i + 1]);
        }
        return new LabeledFeatureVector(label, dims, values);
    }
}