package qa.qcri.aidr.predict.classification.nominal;

import java.util.ArrayList;
import java.util.List;

import org.apache.log4j.Logger;

import qa.qcri.aidr.predict.DataStore;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.SparseInstance;
import qa.qcri.aidr.predict.data.Document;
import qa.qcri.aidr.predict.featureextraction.WordSet;

/**
 * Model contains an instance of a Weka classifier, and transforms a feature
 * vector into a Weka Instance for classification. Model also contains
 * information about what features are used in the classifier and the model's
 * latest performance metrics.
 *
 * <p>The last attribute of the attribute specification is treated as the class
 * attribute; its values are string forms of nominal label IDs.
 *
 * @author jrogstadius
 */
public class Model {

    private static final Logger logger = Logger.getLogger(Model.class);

    private final int attributeID;
    private final Classifier classifier;
    private final Instances attributeSpecification;
    /** Result of the most recent {@link #evaluate(Instances)} call; null until then. */
    private Evaluation evaluation;
    /** All-zero vector used to fill missing values of instances built by {@link #wordsToInstance}. */
    private final double[] missingVal;
    private double meanPrecision, meanRecall, meanAuc;
    /** -1 until {@link #setModelID(int)} is called; {@link #classify} refuses to run before that. */
    private int modelID = -1;
    private Integer trainingSampleCount;
    /** Per-label metrics from the most recent {@link #evaluate(Instances)} call. */
    private List<ModelNominalLabelPerformance> labelPerformanceList;

    /**
     * Creates a model wrapper.
     *
     * @param attributeID   ID of the nominal attribute this model predicts
     * @param classifier    the trained Weka classifier
     * @param specification Weka header describing the feature attributes; its last
     *                      attribute is used as the class attribute
     */
    public Model(int attributeID, Classifier classifier, Instances specification) {
        this.attributeID = attributeID;
        this.classifier = classifier;
        this.attributeSpecification = specification;
        this.missingVal = new double[specification.numAttributes()];
        this.labelPerformanceList = new ArrayList<ModelNominalLabelPerformance>();
    }

    /**
     * Evaluates the classifier on the given instances and records the results on
     * this model: the returned {@link Evaluation} (also available via
     * {@link #getEvaluationResults()}), per-label precision/recall/AUC (replacing
     * any list from a previous call), and the mean of each metric.
     *
     * <p>Means are taken over every value of the class attribute, including any
     * "null" label. State is only updated once the whole evaluation has
     * succeeded, so a failure leaves the previous results intact.
     *
     * @param evaluationSet labeled instances to evaluate against
     * @return the Weka evaluation object
     * @throws Exception if Weka evaluation fails, or a class value is not an integer label ID
     */
    public Evaluation evaluate(Instances evaluationSet) throws Exception {
        Evaluation result = new Evaluation(evaluationSet);
        result.evaluateModel(classifier, evaluationSet);

        Attribute classAttribute = getClassAttribute();
        int labelCount = classAttribute.numValues();
        List<ModelNominalLabelPerformance> perLabel =
                new ArrayList<ModelNominalLabelPerformance>(labelCount);

        double precSum = 0, recSum = 0, aucSum = 0;
        for (int i = 0; i < labelCount; i++) {
            double precision = result.precision(i);
            double recall = result.recall(i);
            double auc = result.areaUnderROC(i);
            // Written as a negated range check so that NaN (AUC unavailable)
            // is also rejected and replaced with 0.
            if (!(auc >= 0 && auc <= 1)) {
                logger.warn("AUC is not available for the trained model");
                auc = 0;
            }

            // Class values are the string form of nominal label IDs.
            perLabel.add(new ModelNominalLabelPerformance(
                    Integer.parseInt(classAttribute.value(i)), precision, recall, auc, 0));

            precSum += precision;
            recSum += recall;
            aucSum += auc;
        }

        meanPrecision = precSum / labelCount;
        meanRecall = recSum / labelCount;
        meanAuc = aucSum / labelCount;
        labelPerformanceList = perLabel;
        evaluation = result;
        return result;
    }

    /** Returns the wrapped Weka classifier. */
    public Classifier getClassifier() {
        return classifier;
    }

    /** Returns the Weka header describing this model's attributes. */
    public Instances getAttributeSpecification() {
        return attributeSpecification;
    }

    /**
     * Returns the evaluation produced by the most recent {@link #evaluate(Instances)}
     * call, or null if the model has not been evaluated.
     */
    public Evaluation getEvaluationResults() {
        return evaluation;
    }

    /** Returns the ID of the nominal attribute this model predicts. */
    public int getAttributeID() {
        return attributeID;
    }

    /**
     * Classifies a document from its word-set features and attaches the resulting
     * label to the document.
     *
     * @param item document to classify
     * @return the most probable label, or null if the document has no word-set
     *         features or classification failed (the failure is logged)
     * @throws IllegalStateException if {@link #setModelID(int)} has not been called
     */
    public NominalLabelBC classify(Document item) {
        if (modelID == -1) {
            logger.error("Model has not been initialized");
            throw new IllegalStateException("Model has not been initialized");
        }

        ArrayList<WordSet> wordSets = item.getFeatures(WordSet.class);
        if (wordSets.isEmpty()) {
            return null;
        }

        Instance instance = wordsToInstance(WordSet.join(wordSets));
        Attribute classAttribute = getClassAttribute();

        try {
            double[] labelProbabilities = classifier.distributionForInstance(instance);

            // labelIndex is the position of the label in the class attribute's
            // list of possible values.
            int labelIndex = findLargestIndex(labelProbabilities);

            // Weka's class attribute has only string values, so parse the value
            // to get a nominalLabelID.
            int labelID = Integer.parseInt(classAttribute.value(labelIndex));

            NominalLabelBC label = new NominalLabelBC(
                    modelID, attributeID, labelID, labelProbabilities[labelIndex]);
            item.addLabel(label);
            return label;
        } catch (Exception e) {
            // Best effort: a failure on one document is logged, not propagated.
            logger.error("Exception when classifying document set", e);
        }
        return null;
    }

    /**
     * Returns the index of the largest value; ties resolve to the lowest index.
     * Returns 0 for an empty array.
     */
    int findLargestIndex(double[] probabilities) {
        int best = 0;
        for (int j = 1; j < probabilities.length; j++) {
            if (probabilities[j] > probabilities[best]) {
                best = j;
            }
        }
        return best;
    }

    /**
     * Builds a sparse binary instance: each word that names an attribute in the
     * specification is set to 1; words without a matching attribute are ignored.
     */
    Instance wordsToInstance(WordSet words) {
        Instance item = new SparseInstance(attributeSpecification.numAttributes());
        item.setDataset(attributeSpecification);

        for (String word : words.getWords()) {
            Attribute attribute = attributeSpecification.attribute(word);
            if (attribute != null) {
                item.setValue(attribute, 1);
            }
        }

        item.replaceMissingValues(missingVal);
        return item;
    }

    /** Mean per-label precision from the most recent evaluation. */
    public double getMeanPrecision() {
        return meanPrecision;
    }

    /** Mean per-label recall from the most recent evaluation. */
    public double getMeanRecall() {
        return meanRecall;
    }

    /** Mean per-label ROC AUC from the most recent evaluation (unavailable AUCs count as 0). */
    public double getMeanAuc() {
        return meanAuc;
    }

    /** Combined score: mean precision plus half of mean recall. */
    public double getWeightedPerformance() {
        return meanPrecision + 0.5 * meanRecall;
    }

    /** Records how many samples the classifier was trained on. */
    public void setTrainingSampleCount(int count) {
        trainingSampleCount = count;
    }

    /**
     * Returns the number of training samples.
     *
     * @throws IllegalStateException if {@link #setTrainingSampleCount(int)} has not been called
     */
    public int getTrainingSampleCount() {
        if (trainingSampleCount == null) {
            logger.error("trainingSampleCount has not been set");
            throw new IllegalStateException("trainingSampleCount has not been set");
        }
        return trainingSampleCount;
    }

    /** Sets this model's ID; required before {@link #classify(Document)} can be used. */
    public void setModelID(int modelID) {
        this.modelID = modelID;
    }

    /** Returns this model's ID, or -1 if it has not been set. */
    public int getModelID() {
        return modelID;
    }

    /** Returns the per-label performance list from the most recent evaluation. */
    public List<ModelNominalLabelPerformance> getLabelPerformanceList() {
        return labelPerformanceList;
    }

    /** Returns the class attribute, which by convention is the last attribute. */
    private Attribute getClassAttribute() {
        return attributeSpecification.attribute(attributeSpecification.numAttributes() - 1);
    }
}