/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.solr.client.solrj.io;

import java.util.HashMap;
import java.util.Map;

/**
 * Accumulates the four confusion-matrix counters of a binary classifier
 * (true/false positives and negatives) and derives precision, recall, F1 and
 * accuracy from them.
 *
 * <p>The counters can be exported to and merged from a map keyed by
 * {@code truePositive_i}, {@code trueNegative_i}, {@code falsePositive_i} and
 * {@code falseNegative_i}.
 *
 * <p>Instances are mutable and perform no synchronization.
 */
public class ClassificationEvaluation {

  private static final String TRUE_POSITIVE_KEY = "truePositive_i";
  private static final String TRUE_NEGATIVE_KEY = "trueNegative_i";
  private static final String FALSE_POSITIVE_KEY = "falsePositive_i";
  private static final String FALSE_NEGATIVE_KEY = "falseNegative_i";

  private long truePositive;
  private long falsePositive;
  private long trueNegative;
  private long falseNegative;

  /**
   * Records one observation.
   *
   * <p>A prediction counts as positive only when {@code predicted == 1}; every
   * other value is treated as a negative prediction. For a positive prediction
   * the observation is a true positive when {@code actual == 1} and a false
   * positive otherwise. For a negative prediction it is a true negative when
   * {@code actual == 0} and a false negative otherwise.
   *
   * @param actual the observed label
   * @param predicted the label produced by the classifier
   */
  public void count(int actual, int predicted) {
    if (predicted == 1) {
      if (actual == 1) {
        truePositive++;
      } else {
        falsePositive++;
      }
    } else {
      if (actual == 0) {
        trueNegative++;
      } else {
        falseNegative++;
      }
    }
  }

  /**
   * Writes the four counters into {@code map} as {@code Long} values,
   * overwriting any existing entries under the same keys.
   *
   * @param map the destination map; it must accept {@code String} keys and
   *     {@code Long} values
   */
  // The parameter stays a raw Map so that existing callers, whatever their
  // map's type arguments, keep compiling unchanged.
  @SuppressWarnings({"rawtypes", "unchecked"})
  public void putToMap(Map map) {
    map.put(TRUE_POSITIVE_KEY, truePositive);
    map.put(TRUE_NEGATIVE_KEY, trueNegative);
    map.put(FALSE_POSITIVE_KEY, falsePositive);
    map.put(FALSE_NEGATIVE_KEY, falseNegative);
  }

  /**
   * Returns a new {@link HashMap} holding the four counters.
   *
   * @return a freshly allocated map with the same entries {@link #putToMap(Map)} writes
   */
  // The return type stays raw to remain compatible with existing callers.
  @SuppressWarnings("rawtypes")
  public Map toMap() {
    Map<String, Object> result = new HashMap<>();
    putToMap(result);
    return result;
  }

  /**
   * Builds an evaluation whose counters are initialised from {@code map}.
   *
   * @param map a map containing numeric values under the four counter keys
   * @return a new evaluation holding the counts read from {@code map}
   * @throws NullPointerException if any of the four keys has no value
   * @throws ClassCastException if any of the four values is not a {@link Number}
   */
  public static ClassificationEvaluation create(Map<?, ?> map) {
    ClassificationEvaluation evaluation = new ClassificationEvaluation();
    evaluation.addEvaluation(map);
    return evaluation;
  }

  /**
   * Adds the counters stored in {@code map} to this evaluation's counters.
   *
   * <p>All four values are read and validated before any counter is changed,
   * so a malformed map leaves this instance untouched.
   *
   * @param map a map containing numeric values under the four counter keys
   * @throws NullPointerException if any of the four keys has no value
   * @throws ClassCastException if any of the four values is not a {@link Number}
   */
  public void addEvaluation(Map<?, ?> map) {
    long addedTruePositive = readCount(map, TRUE_POSITIVE_KEY);
    long addedTrueNegative = readCount(map, TRUE_NEGATIVE_KEY);
    long addedFalsePositive = readCount(map, FALSE_POSITIVE_KEY);
    long addedFalseNegative = readCount(map, FALSE_NEGATIVE_KEY);

    this.truePositive += addedTruePositive;
    this.trueNegative += addedTrueNegative;
    this.falsePositive += addedFalsePositive;
    this.falseNegative += addedFalseNegative;
  }

  /**
   * Returns the precision, {@code TP / (TP + FP)}.
   *
   * @return the precision, or 0 when no positive prediction has been counted
   */
  public double getPrecision() {
    long predictedPositive = truePositive + falsePositive;
    if (predictedPositive == 0) {
      return 0;
    }
    return (double) truePositive / predictedPositive;
  }

  /**
   * Returns the recall, {@code TP / (TP + FN)}.
   *
   * @return the recall, or 0 when {@code TP + FN} is zero
   */
  public double getRecall() {
    long actualPositive = truePositive + falseNegative;
    if (actualPositive == 0) {
      return 0;
    }
    return (double) truePositive / actualPositive;
  }

  /**
   * Returns the F1 score, {@code 2 * precision * recall / (precision + recall)}.
   *
   * @return the F1 score, or 0 when precision and recall sum to zero
   */
  public double getF1() {
    double precision = getPrecision();
    double recall = getRecall();
    if (precision + recall == 0) {
      return 0;
    }
    return 2 * (precision * recall) / (precision + recall);
  }

  /**
   * Returns the accuracy, {@code (TP + TN) / (TP + TN + FN + FP)}.
   *
   * @return the accuracy, or 0 when nothing has been counted yet
   */
  public double getAccuracy() {
    long total = truePositive + trueNegative + falseNegative + falsePositive;
    // Without this guard an empty evaluation would yield NaN (0.0 / 0), unlike
    // the other metrics, which all report 0 for a zero denominator.
    if (total == 0) {
      return 0;
    }
    return (double) (truePositive + trueNegative) / total;
  }

  /** Reads the value under {@code key} as a long, accepting any {@link Number} type. */
  private static long readCount(Map<?, ?> map, String key) {
    Object value = map.get(key);
    if (value == null) {
      throw new NullPointerException("Missing count for key: " + key);
    }
    return ((Number) value).longValue();
  }
}