package rainbownlp.machinelearning;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.List;
import rainbownlp.machinelearning.convertor.SVMLightFormatConvertor;
import rainbownlp.util.ConfigurationUtil;
import rainbownlp.util.HibernateUtil;
import rainbownlp.util.SystemUtil;
/**
 * Base class for learners that hand training and classification off to an external
 * SVM-light style tool run through shell commands.
 *
 * <p>Subclasses supply the train/test shell commands and say whether the task is binary.
 * This class writes the example files (via {@code SVMLightFormatConvertor}), runs the
 * commands, parses the classifier's output file, and writes each prediction back with an
 * {@code update MLExample ...} statement.
 */
public abstract class SVMLightBasedLearnerEngine extends LearnerEngine {

    /**
     * Writes the training examples to {@code trainFile} and runs the subclass's training command.
     *
     * @param pTrainExamples examples to train on
     * @throws Exception propagated from the file convertor or the shell command runner
     */
    public void train(List<MLExample> pTrainExamples) throws Exception {
        ConfigurationUtil.TrainingMode = true;
        HibernateUtil.clearLoaderSession();
        // Only plain ids are handed to the convertor; per the original author's note this
        // was introduced because working through the session was too slow.
        List<Integer> trainExampleIds = collectExampleIds(pTrainExamples);
        if (isBinaryClassification()) {
            trainFile = SVMLightFormatConvertor.writeToFileBinary(trainExampleIds, taskName);
        } else {
            trainFile = SVMLightFormatConvertor.writeToFile(trainExampleIds, taskName);
        }
        SystemUtil.runShellCommand(getTrainCommand());
    }

    /**
     * Classifies {@code pTestExamples} with the trained model and stores each prediction.
     *
     * <p>Every example's {@code predictedClass} is first reset to -1 in the database, so an
     * example for which the classifier emits no output line keeps -1. The output file is then
     * read line by line, with line i assigned to the i-th example in list order. Both the
     * in-memory example and its database row are updated.
     *
     * @param pTestExamples examples to classify; an empty list is a no-op after the model check
     * @throws Exception if the model file is missing, if the classifier does not produce its
     *     result file, or as propagated from the convertor, shell runner or database layer
     */
    public void test(List<MLExample> pTestExamples) throws Exception {
        String modelPath = getModelFilePath();
        if (!new File(modelPath).exists()) {
            throw new Exception("Model file is missing, train before test: " + modelPath);
        }
        ConfigurationUtil.TrainingMode = false;
        if (pTestExamples.isEmpty()) {
            // Nothing to classify, and an empty "in ()" list would make the reset query invalid.
            return;
        }
        List<Integer> testExampleIds = collectExampleIds(pTestExamples);

        // 1. clear previous predictions so unanswered examples are left at -1
        String resetQuery = "update MLExample set predictedClass = -1 where exampleId in ("
                + joinIds(testExampleIds) + ")";
        HibernateUtil.executeNonReader(resetQuery, true);

        String resultFile = modelFile + "_result.txt";
        if (isBinaryClassification()) {
            testFile = SVMLightFormatConvertor.writeToFileBinary(testExampleIds, taskName);
        } else {
            testFile = SVMLightFormatConvertor.writeToFile(testExampleIds, taskName);
        }
        SystemUtil.runShellCommand(getTestCommand(resultFile));
        if (!new File(resultFile).exists()) {
            throw new Exception("SVM result not generated!");
        }

        // 2. read classification output and update database
        BufferedReader reader = new BufferedReader(new FileReader(resultFile));
        try {
            int counter = 0;
            for (MLExample example : pTestExamples) {
                // readLine() (not ready()) is the reliable end-of-file signal.
                String line = reader.readLine();
                if (line == null) {
                    break;
                }
                double[] prediction = parsePredictionLine(line);
                double weight = prediction[1];
                example.setPredictedClass(Double.toString(prediction[0]));
                example.setPredictionWeight(weight);
                String savePredictedQuery = "update MLExample set predictedClass ="
                        + example.getPredictedClass() + " , predictionWeight = "
                        + weight + " where exampleId=" + example.getExampleId();
                HibernateUtil.executeNonReader(savePredictedQuery);
                counter++;
            }
            // Sanity checks (active only with -ea): output and example counts should match.
            boolean outputRemaining = reader.readLine() != null;
            assert !outputRemaining : "Something wrong file remained, updated rows:" + counter;
            assert counter == pTestExamples.size()
                    : "Something wrong resultset remained, updated rows:" + counter;
        } finally {
            reader.close();
        }
    }

    /**
     * Parses one line of classifier output into {@code {predictedClass, weight}}.
     *
     * <p>Binary tasks: the first token is the decision value. The class is 2 when that value is
     * positive and 1 otherwise, and the weight is the decision value itself.
     *
     * <p>Multiclass tasks: the first token is the 1-based predicted label, which is converted to
     * a 0-based class index. The weight is the score that follows the label for that class, or
     * 0 when the line has no such token. This assumes SVM^multiclass-style output (label followed
     * by one score per class); confirm against the command from {@link #getTestCommand(String)}.
     */
    private double[] parsePredictionLine(String line) {
        String[] tokens = line.trim().split("\\s+");
        if (isBinaryClassification()) {
            double decision = Double.parseDouble(tokens[0]);
            return new double[] {decision > 0 ? 2D : 1D, decision};
        }
        double classIndex = Double.parseDouble(tokens[0]) - 1; // convert to index (e.g. 1 -> 0)
        // Token 0 is the label itself, so class index i's score sits at token i + 1.
        int scoreToken = (int) classIndex + 1;
        double weight = 0D;
        if (scoreToken >= 1 && scoreToken < tokens.length) {
            weight = Double.parseDouble(tokens[scoreToken]);
        }
        return new double[] {classIndex, weight};
    }

    /** Returns the ids of {@code examples}, in list order. */
    private static List<Integer> collectExampleIds(List<MLExample> examples) {
        List<Integer> ids = new ArrayList<Integer>(examples.size());
        for (MLExample example : examples) {
            ids.add(example.getExampleId());
        }
        return ids;
    }

    /** Joins ids into a comma-separated list for an SQL {@code in (...)} clause. */
    private static String joinIds(List<Integer> ids) {
        StringBuilder joined = new StringBuilder();
        for (Integer id : ids) {
            if (joined.length() > 0) {
                joined.append(',');
            }
            joined.append(id);
        }
        return joined.toString();
    }

    /**
     * Returns whether the task is binary. This selects the example-file writer used in
     * {@link #train(List)} and {@link #test(List)}, and how each output line is interpreted.
     */
    protected abstract boolean isBinaryClassification();

    /**
     * Returns the shell command that {@link #train(List)} runs after writing {@code trainFile}.
     * Presumably it builds the model read back via {@code getModelFilePath()}; confirm in
     * subclasses.
     */
    protected abstract String getTrainCommand();

    /**
     * Returns the shell command that {@link #test(List)} runs after writing {@code testFile}.
     * The command must create {@code resultFile} with one prediction line per test example,
     * in example order, or {@code test} fails.
     *
     * @param resultFile path the classification output will be read from
     */
    protected abstract String getTestCommand(String resultFile);
}