/*
* Copyright (c) 2011 The S4 Project, http://s4.io.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
* either express or implied. See the License for the specific
* language governing permissions and limitations under the
* License. See accompanying LICENSE file.
*/
package org.apache.s4.example.model;
import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;

import org.apache.s4.base.Event;
import org.apache.s4.core.App;
import org.apache.s4.core.ProcessingElement;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Processing element that accumulates a confusion matrix from
 * {@code ResultEvent}s and logs it whenever a trigger fires.
 *
 * <p>Each event carries a true class id ({@code getClassId()}) and a
 * hypothesized class id ({@code getHypId()}). Counts are kept per
 * (class, hypothesis) pair; {@link #toString()} renders them as per-row
 * percentages followed by the overall accuracy.
 */
final public class MetricsPE extends ProcessingElement {

    private static final Logger logger = LoggerFactory
            .getLogger(MetricsPE.class);

    /**
     * classId -> (hypId -> count). Created in {@link #onCreate()} as a sorted
     * map so that report rows come out in ascending class order, matching the
     * ascending column header.
     */
    private Map<Integer, Map<Integer, MutableInt>> counts;

    /** Total number of events processed by {@link #onEvent(Event)}. */
    private long totalCount = 0;

    public MetricsPE(App app) {
        super(app);
    }

    /**
     * Records one classification result in the confusion matrix.
     *
     * @param event expected to be a {@code ResultEvent}; any other event type
     *        causes a {@link ClassCastException}
     */
    public void onEvent(Event event) {
        ResultEvent resultEvent = (ResultEvent) event;
        int classID = resultEvent.getClassId();
        int hypID = resultEvent.getHypId();

        totalCount += 1;

        /* Find (or create) the row for this true class. */
        Map<Integer, MutableInt> hypCounts = counts.get(classID);
        if (hypCounts == null) {
            hypCounts = new HashMap<Integer, MutableInt>();
            counts.put(classID, hypCounts);
        }

        /* Find (or create) the cell for this hypothesis, then increment it. */
        MutableInt value = hypCounts.get(hypID);
        if (value == null) {
            value = new MutableInt();
            hypCounts.put(hypID, value);
        }
        value.inc();
    }

    /** Logs the current confusion-matrix report at INFO level. */
    public void onTrigger(Event event) {
        // Rendering walks the entire matrix; skip that work when INFO is off.
        if (logger.isInfoEnabled()) {
            logger.info(this.toString());
        }
    }

    @Override
    protected void onCreate() {
        counts = new TreeMap<Integer, Map<Integer, MutableInt>>();
    }

    @Override
    protected void onRemove() {
    }

    /** @return number of data vectors processed. */
    public long getCount() {
        return totalCount;
    }

    /**
     * Renders the confusion matrix as row percentages plus overall accuracy.
     *
     * <p>The number of matrix columns is the number of distinct true classes
     * seen so far. If any hypothesis id falls outside {@code [0, numClasses)},
     * or if no events have been processed yet, a short
     * "Insufficient data." message is returned instead of a report.
     *
     * @return the formatted report, or "Insufficient data."
     */
    @Override
    public String toString() {
        // No data yet (or not initialized): avoid 0/0 = NaN accuracy and NPE.
        if (counts == null || totalCount == 0) {
            return "Insufficient data.";
        }

        // Snapshot so header, rows and bounds check all agree on the width.
        int numClasses = counts.size();

        StringBuilder report = new StringBuilder();
        report.append("\n\nConfusion Matrix [%]:\n");
        report.append("\n ");
        for (int i = 0; i < numClasses; i++) {
            report.append(String.format("%6d", i));
        }
        report.append("\n ----------------------------------------\n");

        long truePositives = 0;
        for (Map.Entry<Integer, Map<Integer, MutableInt>> entry : counts
                .entrySet()) {
            int classID = entry.getKey();
            report.append(String.format("%5d:", classID));

            Map<Integer, MutableInt> hypCounts = entry.getValue();
            long totalCountForClass = getTotalCountForClass(hypCounts);
            float[] rowPercentages = new float[numClasses];

            for (Map.Entry<Integer, MutableInt> hypEntry : hypCounts
                    .entrySet()) {
                int hypID = hypEntry.getKey();
                long count = hypEntry.getValue().get();
                /*
                 * Because of timing, it is possible to have a hypId that was
                 * not yet counted as a class. A negative hypId can never be
                 * placed in the matrix either. In both cases we bail out
                 * without producing a report rather than index out of bounds.
                 */
                if (hypID < 0 || hypID >= numClasses) {
                    return "Insufficient data.";
                }
                rowPercentages[hypID] = (float) count / totalCountForClass
                        * 100f;
                if (classID == hypID) {
                    truePositives += count;
                }
            }

            for (int i = 0; i < numClasses; i++) {
                report.append(String.format("%6.1f", rowPercentages[i]));
            }
            report.append("\n");
        }

        report.append(String.format(
                "\nAccuracy: %6.1f%% - Num Observations: %6d\n",
                (float) truePositives / totalCount * 100f, totalCount));
        return report.toString();
    }

    /**
     * Sums all hypothesis counts in one row of the matrix.
     *
     * @param hypCounts hypId -> count for a single true class
     * @return total number of events observed for that class
     */
    private long getTotalCountForClass(Map<Integer, MutableInt> hypCounts) {
        long count = 0;
        for (MutableInt cell : hypCounts.values()) {
            count += cell.get();
        }
        return count;
    }

    /**
     * Mutable counter stored as a map value so increments avoid re-boxing.
     * Static: it needs no reference to the enclosing processing element.
     */
    private static final class MutableInt {
        private int value = 0;

        private void inc() {
            ++value;
        }

        private int get() {
            return value;
        }
    }
}