/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.lucene.classification.utils; import java.io.IOException; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import org.apache.lucene.classification.ClassificationResult; import org.apache.lucene.classification.Classifier; import org.apache.lucene.document.Document; import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TermRangeQuery; import org.apache.lucene.search.TopDocs; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.NamedThreadFactory; /** * Utility class to generate the confusion matrix of a {@link Classifier} */ public class ConfusionMatrixGenerator { private ConfusionMatrixGenerator() { } /** * get the {@link org.apache.lucene.classification.utils.ConfusionMatrixGenerator.ConfusionMatrix} of a given {@link Classifier}, * generated on the given {@link IndexReader}, class and text fields. * * @param reader the {@link IndexReader} containing the index used for creating the {@link Classifier} * @param classifier the {@link Classifier} whose confusion matrix has to be generated * @param classFieldName the name of the Lucene field used as the classifier's output * @param textFieldName the nome the Lucene field used as the classifier's input * @param timeoutMilliseconds timeout to wait before stopping creating the confusion matrix * @param <T> the return type of the {@link ClassificationResult} returned by the given {@link Classifier} * @return a {@link org.apache.lucene.classification.utils.ConfusionMatrixGenerator.ConfusionMatrix} * @throws IOException if problems occurr while reading the index or using the classifier */ public static <T> ConfusionMatrix getConfusionMatrix(IndexReader reader, Classifier<T> classifier, String classFieldName, String textFieldName, long timeoutMilliseconds) throws IOException { ExecutorService executorService = Executors.newFixedThreadPool(1, new NamedThreadFactory("confusion-matrix-gen-")); try { Map<String, Map<String, Long>> counts = new HashMap<>(); IndexSearcher indexSearcher = new IndexSearcher(reader); TopDocs topDocs = indexSearcher.search(new TermRangeQuery(classFieldName, null, null, true, true), Integer.MAX_VALUE); double time = 0d; int counter = 0; for (ScoreDoc scoreDoc : topDocs.scoreDocs) { if (timeoutMilliseconds > 0 && time >= timeoutMilliseconds) { break; } Document doc = reader.document(scoreDoc.doc); String[] correctAnswers = doc.getValues(classFieldName); if (correctAnswers != null && correctAnswers.length > 0) { Arrays.sort(correctAnswers); ClassificationResult<T> result; String text = doc.get(textFieldName); if (text != null) { try { // fail if classification takes more than 5s long start = System.currentTimeMillis(); result = executorService.submit(() -> classifier.assignClass(text)).get(5, TimeUnit.SECONDS); long end = System.currentTimeMillis(); time += end - start; if (result != null) { T assignedClass = result.getAssignedClass(); if (assignedClass != null) { counter++; String classified = assignedClass instanceof BytesRef ? ((BytesRef) assignedClass).utf8ToString() : assignedClass.toString(); String correctAnswer; if (Arrays.binarySearch(correctAnswers, classified) >= 0) { correctAnswer = classified; } else { correctAnswer = correctAnswers[0]; } Map<String, Long> stringLongMap = counts.get(correctAnswer); if (stringLongMap != null) { Long aLong = stringLongMap.get(classified); if (aLong != null) { stringLongMap.put(classified, aLong + 1); } else { stringLongMap.put(classified, 1L); } } else { stringLongMap = new HashMap<>(); stringLongMap.put(classified, 1L); counts.put(correctAnswer, stringLongMap); } } } } catch (TimeoutException timeoutException) { // add classification timeout time += 5000; } catch (ExecutionException | InterruptedException executionException) { throw new RuntimeException(executionException); } } } } return new ConfusionMatrix(counts, time / counter, counter); } finally { executorService.shutdown(); } } /** * a confusion matrix, backed by a {@link Map} representing the linearized matrix */ public static class ConfusionMatrix { private final Map<String, Map<String, Long>> linearizedMatrix; private final double avgClassificationTime; private final int numberOfEvaluatedDocs; private double accuracy = -1d; private ConfusionMatrix(Map<String, Map<String, Long>> linearizedMatrix, double avgClassificationTime, int numberOfEvaluatedDocs) { this.linearizedMatrix = linearizedMatrix; this.avgClassificationTime = avgClassificationTime; this.numberOfEvaluatedDocs = numberOfEvaluatedDocs; } /** * get the linearized confusion matrix as a {@link Map} * * @return a {@link Map} whose keys are the correct classification answers and whose values are the actual answers' * counts */ public Map<String, Map<String, Long>> getLinearizedMatrix() { return Collections.unmodifiableMap(linearizedMatrix); } /** * calculate precision on the given class * * @param klass the class to calculate the precision for * @return the precision for the given class */ public double getPrecision(String klass) { Map<String, Long> classifications = linearizedMatrix.get(klass); double tp = 0; double den = 0; // tp + fp if (classifications != null) { for (Map.Entry<String, Long> entry : classifications.entrySet()) { if (klass.equals(entry.getKey())) { tp += entry.getValue(); } } for (Map<String, Long> values : linearizedMatrix.values()) { if (values.containsKey(klass)) { den += values.get(klass); } } } return tp > 0 ? tp / den : 0; } /** * calculate recall on the given class * * @param klass the class to calculate the recall for * @return the recall for the given class */ public double getRecall(String klass) { Map<String, Long> classifications = linearizedMatrix.get(klass); double tp = 0; double fn = 0; if (classifications != null) { for (Map.Entry<String, Long> entry : classifications.entrySet()) { if (klass.equals(entry.getKey())) { tp += entry.getValue(); } else { fn += entry.getValue(); } } } return tp + fn > 0 ? tp / (tp + fn) : 0; } /** * get the F-1 measure of the given class * * @param klass the class to calculate the F-1 measure for * @return the F-1 measure for the given class */ public double getF1Measure(String klass) { double recall = getRecall(klass); double precision = getPrecision(klass); return precision > 0 && recall > 0 ? 2 * precision * recall / (precision + recall) : 0; } /** * get the F-1 measure on this confusion matrix * * @return the F-1 measure */ public double getF1Measure() { double recall = getRecall(); double precision = getPrecision(); return precision > 0 && recall > 0 ? 2 * precision * recall / (precision + recall) : 0; } /** * Calculate accuracy on this confusion matrix using the formula: * {@literal accuracy = correctly-classified / (correctly-classified + wrongly-classified)} * * @return the accuracy */ public double getAccuracy() { if (this.accuracy == -1) { double tp = 0d; double tn = 0d; double tfp = 0d; // tp + fp double fn = 0d; for (Map.Entry<String, Map<String, Long>> classification : linearizedMatrix.entrySet()) { String klass = classification.getKey(); for (Map.Entry<String, Long> entry : classification.getValue().entrySet()) { if (klass.equals(entry.getKey())) { tp += entry.getValue(); } else { fn += entry.getValue(); } } for (Map<String, Long> values : linearizedMatrix.values()) { if (values.containsKey(klass)) { tfp += values.get(klass); } else { tn++; } } } this.accuracy = (tp + tn) / (tfp + fn + tn); } return this.accuracy; } /** * get the macro averaged precision (see {@link #getPrecision(String)}) over all the classes. * * @return the macro averaged precision as computed from the confusion matrix */ public double getPrecision() { double p = 0; for (Map.Entry<String, Map<String, Long>> classification : linearizedMatrix.entrySet()) { String klass = classification.getKey(); p += getPrecision(klass); } return p / linearizedMatrix.size(); } /** * get the macro averaged recall (see {@link #getRecall(String)}) over all the classes * * @return the recall as computed from the confusion matrix */ public double getRecall() { double r = 0; for (Map.Entry<String, Map<String, Long>> classification : linearizedMatrix.entrySet()) { String klass = classification.getKey(); r += getRecall(klass); } return r / linearizedMatrix.size(); } @Override public String toString() { return "ConfusionMatrix{" + "linearizedMatrix=" + linearizedMatrix + ", avgClassificationTime=" + avgClassificationTime + ", numberOfEvaluatedDocs=" + numberOfEvaluatedDocs + '}'; } /** * get the average classification time in milliseconds * * @return the avg classification time */ public double getAvgClassificationTime() { return avgClassificationTime; } /** * get the no. of documents evaluated while generating this confusion matrix * * @return the no. of documents evaluated */ public int getNumberOfEvaluatedDocs() { return numberOfEvaluatedDocs; } } }