package rainbownlp.machinelearning; import java.util.ArrayList; import java.util.List; import rainbownlp.machinelearning.convertor.WekaFormatConvertor; import rainbownlp.util.ConfigurationUtil; import rainbownlp.util.HibernateUtil; import weka.classifiers.Classifier; import weka.classifiers.Evaluation; import weka.classifiers.bayes.NaiveBayes; import weka.core.Instances; import weka.core.SerializationHelper; import weka.core.converters.ConverterUtils.DataSource; public class WekaClassifier extends LearnerEngine { String modelFile; String taskName; String trainFile; String testFile; int reinforcedCount = 0; String[] reinforcedModels = new String[reinforcedCount]; public Classifier wekaAlgorithm = new NaiveBayes(); public String[] options = null; public String wekaAlgorithmName = "NaiveBayes"; private WekaClassifier() { } @Override public void train(List<MLExample> pTrainExamples) throws Exception { ConfigurationUtil.TrainingMode = true; setPaths(); //This part added since the session was so slow List<Integer> train_example_ids = new ArrayList<Integer>(); for(MLExample example : pTrainExamples) { train_example_ids.add(example.getExampleId()); } WekaFormatConvertor.writeToFile(train_example_ids, trainFile,taskName, new String[]{"1", "2"}); DataSource source = new DataSource(trainFile); Instances data = source.getDataSet(); // setting class attribute if the data format does not provide this information if (data.classIndex() == -1) data.setClassIndex(data.numAttributes() - 1); if(options!=null) wekaAlgorithm.setOptions(options); // set the options wekaAlgorithm.buildClassifier(data); // build classifier // serialize model SerializationHelper.write(modelFile, wekaAlgorithm); } @Override public void test(List<MLExample> pTestExamples) throws Exception { ConfigurationUtil.TrainingMode = false; List<Integer> test_example_ids = new ArrayList<Integer>(); String exampleids = ""; for(MLExample example : pTestExamples) { exampleids = exampleids.concat(","+example.getExampleId()); test_example_ids.add(example.getExampleId()); } exampleids = exampleids.replaceFirst(",", ""); String resetQuery = "update MLExample set predictedClass = -1 where exampleId in ("+ exampleids +")"; HibernateUtil.executeNonReader(resetQuery); WekaFormatConvertor.writeToFile(test_example_ids, testFile,taskName, new String[]{"1", "2"}); // deserialize model wekaAlgorithm = (Classifier) SerializationHelper.read(modelFile); // classify new examples and update database DataSource source = new DataSource(testFile); Instances testData = source.getDataSet(); // setting class attribute if the data format does not provide this information if (testData.classIndex() == -1) testData.setClassIndex(testData.numAttributes() - 1); System.out.println(pTestExamples.size() +"=="+ testData.numInstances()); int counter = 0; while (counter<pTestExamples.size() && counter<testData.numInstances()) { Double clsLabel = wekaAlgorithm.classifyInstance(testData.instance(counter)); pTestExamples.get(counter).setPredictedClass(clsLabel.toString()); MLExample test = pTestExamples.get(counter); String savePredictedQuery = "update MLExample set predictedClass ="+test.getPredictedClass()+ " where exampleId="+test.getExampleId(); HibernateUtil.executeNonReader(savePredictedQuery); counter++; System.out.println("Processed :"+counter+"/"+pTestExamples.size()); } Evaluation eval = new Evaluation(testData); eval.evaluateModel(wekaAlgorithm, testData); System.out.println("\n====\n"+eval.toSummaryString()+"\n"+eval.toClassDetailsString()+"\n====\n"); } public static LearnerEngine getLearnerEngine(String pTaskName) { WekaClassifier learnerEngine = new WekaClassifier(); learnerEngine.setTaskName(pTaskName); learnerEngine.setPaths(); return learnerEngine; } private void setPaths() { String fold = (ConfigurationUtil.crossFoldCurrent>0)?("Fold"+ConfigurationUtil.crossFoldCurrent):""; setModelFilePath(ConfigurationUtil.getValue("TempFolder")+ fold+"-"+wekaAlgorithmName+ taskName+".weka"); setTrainFilePath(ConfigurationUtil.getValue("TempFolder")+ fold+"-"+wekaAlgorithmName+ "-train-" + taskName + ".arff"); setTestFilePath(ConfigurationUtil.getValue("TempFolder")+ fold+"-" + wekaAlgorithmName+ "-test-" + taskName + ".arff"); } private void setTestFilePath(String pTestFile) { testFile = pTestFile; } private void setTrainFilePath(String pTrainFile) { trainFile = pTrainFile; } private void setModelFilePath(String pModelFile) { modelFile = pModelFile; } }