/*
* Copyright [2012-2014] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core;
import ml.shifu.shifu.container.ConfusionMatrixObject;
import ml.shifu.shifu.container.ModelResultObject;
import ml.shifu.shifu.util.QuickSort;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.BufferedWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
/**
 * Confusion matrix calculator.
 *
 * <p>Produces a sequence of cumulative confusion matrices over a list of scored records. The
 * records are ordered once, at construction time, with
 * {@link ModelResultObject.ModelResultObjectComparator} (presumably by descending score — confirm
 * against the comparator). The first matrix predicts every record negative (TP = FP = 0). Each
 * following matrix moves the next record, in sorted order, to the predicted-positive side:
 * <ul>
 * <li>a record whose tag is in {@code posTags} moves from FN to TP;</li>
 * <li>any other record moves from TN to FP.</li>
 * </ul>
 * Unweighted counts are scaled by the positive/negative scale factor. Weighted counts are
 * additionally multiplied by the record's weight.
 */
public class ConfusionMatrixCalculator {

    private static final Logger log = LoggerFactory.getLogger(ConfusionMatrixCalculator.class);

    /** Output row layout: TP|FP|FN|TN|weightedTP|weightedFP|weightedFN|weightedTN|score. */
    private static final String ROW_FORMAT = "%s|%s|%s|%s|%s|%s|%s|%s|%s\n";

    /**
     * Score written on the initial (all-negative) row by {@link #calculate(BufferedWriter)}.
     * NOTE(review): {@link #calculate()} uses the first record's score for that row instead;
     * presumably 1000 is meant as a sentinel above every real score — confirm.
     */
    private static final int INITIAL_WRITER_SCORE = 1000;

    // input
    private final List<ModelResultObject> moList;
    private final List<String> posTags;
    // Kept for API symmetry. Every record whose tag is not in posTags is counted as negative.
    @SuppressWarnings("unused")
    private final List<String> negTags;
    private Double negScaleFactor = 1.0;
    private Double posScaleFactor = 1.0;

    /**
     * Creates a calculator and sorts the records.
     *
     * <p>The given list is not copied. It is passed directly to {@code QuickSort.sort}, so the
     * caller's list is (presumably) reordered in place.
     *
     * @param posTags tags that identify positive records
     * @param negTags tags of negative records; currently unused
     * @param moList scored records to evaluate; sorted by this constructor
     */
    public ConfusionMatrixCalculator(List<String> posTags, List<String> negTags, List<ModelResultObject> moList) {
        this.moList = moList;
        this.posTags = posTags;
        this.negTags = negTags;
        QuickSort.sort(this.moList, new ModelResultObject.ModelResultObjectComparator());
    }

    /**
     * Computes all cumulative confusion matrices in memory.
     *
     * @return {@code moList.size() + 1} matrices. The first is the all-negative starting matrix,
     *         scored with the first record's score (its score is left unset when there are no
     *         records). Each following matrix adds one record and carries that record's score.
     */
    public List<ConfusionMatrixObject> calculate() {
        List<ConfusionMatrixObject> cmoList = new ArrayList<ConfusionMatrixObject>(moList.size() + 1);

        ConfusionMatrixObject prevCmo = createInitialMatrix();
        // Guard: the previous code called moList.get(0) unconditionally and crashed on empty input.
        if(!moList.isEmpty()) {
            prevCmo.setScore(moList.get(0).getScore());
        }
        cmoList.add(prevCmo);

        for(ModelResultObject mo: moList) {
            prevCmo = nextMatrix(prevCmo, mo);
            cmoList.add(prevCmo);
        }
        return cmoList;
    }

    /**
     * Sets the factor applied to each negative (non-positive-tag) record's contribution.
     *
     * @param negScaleFactor multiplier for FP/TN counts and their weighted variants
     */
    public void setNegScaleFactor(Double negScaleFactor) {
        this.negScaleFactor = negScaleFactor;
    }

    /**
     * Sets the factor applied to each positive record's contribution.
     *
     * @param posScaleFactor multiplier for TP/FN counts and their weighted variants
     */
    public void setPosScaleFactor(Double posScaleFactor) {
        this.posScaleFactor = posScaleFactor;
    }

    /**
     * Computes the cumulative confusion matrices and streams each one to {@code writer} as a
     * {@code |}-separated row.
     *
     * <p>The initial all-negative row is written with score {@value #INITIAL_WRITER_SCORE}. If a
     * write fails, the failure is logged, the writer is closed, and no further rows are written.
     *
     * @param writer destination for the rows; it is not closed on success
     */
    public void calculate(BufferedWriter writer) {
        ConfusionMatrixObject prevCmo = createInitialMatrix();
        prevCmo.setScore(INITIAL_WRITER_SCORE);
        if(!writeRow(writer, prevCmo)) {
            return;
        }
        for(ModelResultObject mo: moList) {
            prevCmo = nextMatrix(prevCmo, mo);
            if(!writeRow(writer, prevCmo)) {
                // The writer has been closed by writeRow, so every further write would fail too.
                return;
            }
        }
    }

    /**
     * Writes one confusion-matrix row to {@code writer}. On an {@link IOException} the error is
     * logged and the writer is closed; the exception is not rethrown.
     *
     * @param writer destination for the row
     * @param cmo matrix to serialize
     */
    public static void saveConfusionMaxtrixWithWriter(BufferedWriter writer, ConfusionMatrixObject cmo) {
        writeRow(writer, cmo);
    }

    /** Builds the starting matrix where every record is predicted negative; the score is left unset. */
    private ConfusionMatrixObject createInitialMatrix() {
        double sumPos = 0.0, sumNeg = 0.0, sumWeightedPos = 0.0, sumWeightedNeg = 0.0;
        for(ModelResultObject mo: moList) {
            if(posTags.contains(mo.getTag())) {
                // Positive
                sumPos += posScaleFactor;
                sumWeightedPos += mo.getWeight() * posScaleFactor;
            } else {
                // Negative
                sumNeg += negScaleFactor;
                sumWeightedNeg += mo.getWeight() * negScaleFactor;
            }
        }

        ConfusionMatrixObject cmo = new ConfusionMatrixObject();
        cmo.setTp(0.0);
        cmo.setFp(0.0);
        cmo.setFn(sumPos);
        cmo.setTn(sumNeg);
        cmo.setWeightedTp(0.0);
        cmo.setWeightedFp(0.0);
        cmo.setWeightedFn(sumWeightedPos);
        cmo.setWeightedTn(sumWeightedNeg);
        return cmo;
    }

    /** Returns a copy of {@code prevCmo} with {@code mo} moved to the predicted-positive side. */
    private ConfusionMatrixObject nextMatrix(ConfusionMatrixObject prevCmo, ModelResultObject mo) {
        ConfusionMatrixObject cmo = new ConfusionMatrixObject(prevCmo);
        if(posTags.contains(mo.getTag())) {
            // Positive instance: FN -> TP
            cmo.setTp(cmo.getTp() + posScaleFactor);
            cmo.setFn(cmo.getFn() - posScaleFactor);
            cmo.setWeightedTp(cmo.getWeightedTp() + mo.getWeight() * posScaleFactor);
            cmo.setWeightedFn(cmo.getWeightedFn() - mo.getWeight() * posScaleFactor);
        } else {
            // Negative instance: TN -> FP
            cmo.setFp(cmo.getFp() + negScaleFactor);
            cmo.setTn(cmo.getTn() - negScaleFactor);
            cmo.setWeightedFp(cmo.getWeightedFp() + mo.getWeight() * negScaleFactor);
            cmo.setWeightedTn(cmo.getWeightedTn() - mo.getWeight() * negScaleFactor);
        }
        cmo.setScore(mo.getScore());
        return cmo;
    }

    /**
     * Writes one row. Returns {@code false} after logging and closing the writer on failure.
     * This lets callers stop instead of writing repeatedly to a closed stream.
     */
    private static boolean writeRow(BufferedWriter writer, ConfusionMatrixObject cmo) {
        try {
            writer.write(String.format(ROW_FORMAT, cmo.getTp(), cmo.getFp(), cmo.getFn(), cmo.getTn(),
                    cmo.getWeightedTp(), cmo.getWeightedFp(), cmo.getWeightedFn(), cmo.getWeightedTn(),
                    cmo.getScore()));
            return true;
        } catch (IOException e) {
            try {
                writer.close();
            } catch (IOException e1) {
                log.error("Could not close the writer while write into confusion matrix", e1);
            }
            log.error("Could not write into confusion matrix", e);
            return false;
        }
    }
}