/** * * @author Walid Shalaby * Inspired from wekaexamples/classifiers/WekawekaClassifier.java */ package edu.uncc.cs.watsonsim.researchers; import java.io.BufferedReader; import java.io.BufferedWriter; import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.FileReader; import java.io.FileWriter; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.util.ArrayList; import java.util.Collections; import java.util.Enumeration; import java.util.List; import edu.uncc.cs.watsonsim.Answer; import edu.uncc.cs.watsonsim.Question; import weka.classifiers.Classifier; import weka.classifiers.Evaluation; import weka.core.Attribute; import weka.core.Instance; import weka.core.Instances; import weka.core.Utils; /** * Combine the scores coming in from many scorers, in order to generate a * single combined score in the end. * * Sorts and reverses the result, so that the top answer is at rank 0. * * As of April 23, the best SVM settings at are at C=16 gamma=0.1 * * @author Walid Shalaby */ public class CombineScores extends Researcher { static String scorerModelPath = "data/scorer/models/allengines.model"; static String scorerDatasetPath = "data/scorer/schemas/allengines-01-schema.arff"; Classifier scorerModel = null; Instances qResultsDataset = null; List<String> names = new ArrayList<>(); public CombineScores() { try { LoadModel(scorerModelPath); qResultsDataset = new Instances(new BufferedReader(new FileReader(scorerDatasetPath))); } catch (ClassNotFoundException e) { e.printStackTrace(); throw new RuntimeException("Weka learners are missing. " + "Did you install Weka correctly?"); } catch (IOException e) { e.printStackTrace(); throw new RuntimeException("Weka models appear to be missing. " + "Do you have data/scorers? It is not possible to run " + "without them."); } // Get the attribute's names as a string @SuppressWarnings("unchecked") List<Attribute> attributes = Collections.list((Enumeration<Attribute>) qResultsDataset.enumerateAttributes()); for (Attribute a : attributes) names.add(a.name()); // Only decide the class attribute afterward because otherwise weka // will cut it out and the length will not match qResultsDataset.setClassIndex(qResultsDataset.attribute("CORRECT").index()); } @Override /** * Perform softmax on scores while retrieving them answer-by-answer */ public List<Answer> question(Question question, List<Answer> answers) { // Collect double[] scores = new double[answers.size()]; { int i = 0; for (Answer a : answers) { try { scores[i++] = score(a.scores.getEach(names)); } catch (Exception e) { System.out.println("An unknown error occured while scoring with Weka. Some results may be scored wrong."); e.printStackTrace(); a.setOverallScore(0.0); } } } { // Then scale (just for cleanliness) double sum = 0; for (int i=0; i<scores.length; i++) { scores[i] = Math.exp(scores[i]); sum += scores[i]; } for (int i=0; i<scores.length; i++) { scores[i] /= sum; } } // Finally, apply. { int i = 0; for (Answer a : answers) a.setOverallScore(scores[i++]); } Collections.sort(answers); Collections.reverse(answers); return answers; } /*public static QuestionResultsScorer prepareGenericScorer(String schemapath, String modelpath) { QuestionResultsScorer scorer = new QuestionResultsScorer(); scorer.scorerModelPath = modelpath; scorer.scorerDatasetPath = schemapath; return scorer; }*/ public void LoadModel(String modelpath) throws IOException, ClassNotFoundException { ObjectInputStream ois = new ObjectInputStream(new FileInputStream(modelpath)); scorerModel = (Classifier) ois.readObject(); ois.close(); } /** * @param attributesValues: one or more attributes used to score the result e.g., indri rank * @throws Exception */ public double score(double[] attributesValues) throws Exception { Instance inst = new Instance(1, attributesValues); inst.setDataset(qResultsDataset); return scorerModel.distributionForInstance(inst)[1]; //return scorerModel.classifyInstance(inst); } /** * @param inputpath: path of arff file containing results training instances * @param outputpath: path to write training statistics * @param modelpath: path to write trained model * @param targetAttributeName: attribute name of target e.g., "correct" * @param doEvaluate: perform evaluation of model after training * @throws Exception */ public static void buildScorerModel(String inputpath, String outputpath, String modelpath, String targetAttributeName, boolean doEvaluate) throws Exception { // use logistic regression as default classifier //String classifierName = "weka.classifiers.functions.SimpleLogistic"; //String[] classifierOptions = new String[]{"-I", "0", "-M", "500", "-H", "100", "-W", "0.0"}; //String classifierName = "weka.classifiers.functions.Logistic"; //String[] classifierOptions = new String[]{"-R", "1.0E-8", "-M", "-1"}; String classifierName = "weka.classifiers.functions.MultilayerPerceptron"; String[] classifierOptions = new String[]{"-L", "0.3", "-M", "0.2", "-N", "500", "-V", "0", "-S", "0", "-E", "20", "-H", "a"}; //String classifierName = "weka.classifiers.lazy.KStar"; //String[] classifierOptions = new String[]{"-B", "20", "-M", "a"}; buildScorerModel(inputpath, outputpath, modelpath, classifierName, classifierOptions, targetAttributeName, doEvaluate); } /** * @param inputpath: path of arff file containing results training instances * @param outputpath: path to write training statistics * @param modelpath: path to write trained model * @param classifierName: Weka classifier class name * @param classifierOptions: parameters of classifier * @param targetAttributeName: attribute name of target e.g., "correct" * @param doEvaluate: perform 10 fold cross-validation evaluation of model after training * @throws Exception */ public static void buildScorerModel(String inputpath, String outputpath, String modelpath, String classifierName, String[] classifierOptions, String targetAttributeName, boolean doEvaluate) throws Exception { // initialize the classifier Classifier classifier = Classifier.forName(classifierName, classifierOptions); try { // load training instances Instances qResults = new Instances(new BufferedReader(new FileReader(inputpath))); // set target attribute qResults.setClass(qResults.attribute(targetAttributeName)); // build classifier classifier.buildClassifier(qResults); Evaluation evaluation = null; if(doEvaluate) { // 10fold CV with seed=1 evaluation = new Evaluation(qResults); evaluation.crossValidateModel(classifier, qResults, 10, qResults.getRandomNumberGenerator(1)); } // Write training statistics to output file writeStatistics(classifier, qResults, evaluation, outputpath); // write model writeModel(classifier, modelpath); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } } private static void writeModel(Classifier classifier, String modelpath) throws IOException { ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(modelpath)); oos.writeObject(classifier); oos.close(); } /** * @param classifier * @param qResults * @param evaluation * @param outputpath * @throws IOException */ private static void writeStatistics(Classifier classifier, Instances qResults, Evaluation evaluation, String outputpath) throws IOException { BufferedWriter writer = new BufferedWriter(new FileWriter(new File(outputpath))); writer.append("Classifier...: " + classifier.getClass().getName() + " " + Utils.joinOptions(classifier.getOptions()) + "\n"); writer.append("Relation: " + qResults.relationName() + "\n"); writer.append("Instances: " + qResults.numInstances() + "\n"); writer.append("Attributes: " + qResults.numAttributes() + "\n"); if(qResults.numAttributes()<=100) { for(int i=0; i<qResults.numAttributes(); i++) writer.append(" " + qResults.attribute(i).name() + "\n"); } writer.append("\n\n"); // model weights writer.append(classifier.toString() + "\n"); if(evaluation!=null) { // some statistics writer.append(evaluation.toSummaryString() + "\n"); try { // per class statistics writer.append(evaluation.toClassDetailsString() + "\n"); } catch (Exception e) { e.printStackTrace(); } try { // confusion matrix writer.append(evaluation.toMatrixString() + "\n"); } catch (Exception e) { e.printStackTrace(); } } writer.close(); } /** Train the classifier * @throws Exception */ public static void main(String[] arguments) throws Exception { buildScorerModel("data/weka-log.arff", "data/model.log", scorerModelPath, "CORRECT", false); } }