/*
* File: DefaultConfusionMatrix.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry Learning Core
*
* Copyright January 11, 2011, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government. Export
* of this program may require a license from the United States Government.
*/
package gov.sandia.cognition.learning.performance.categorization;
import gov.sandia.cognition.learning.data.TargetEstimatePair;
import gov.sandia.cognition.math.MutableDouble;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.Pair;
import gov.sandia.cognition.util.Summarizer;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
/**
* A default implementation of the {@code ConfusionMatrix} interface. It is
* backed by a two-level map storing the category object counts, making a
* sparse representation.
*
* @param <CategoryType>
* The type of the category object over the confusion matrix.
* @author Justin Basilico
* @since 3.1
*/
public class DefaultConfusionMatrix<CategoryType>
extends AbstractConfusionMatrix<CategoryType>
{
/** The backing map of confusion matrix entries. The first key is the
* actual category and the second is the predicted category. */
protected Map<CategoryType, Map<CategoryType, MutableDouble>> confusions;
/**
 * Creates a new {@code DefaultConfusionMatrix} that contains no entries.
 * The backing map is a {@code LinkedHashMap} from each actual category to
 * its map of predicted-category counts.
 */
public DefaultConfusionMatrix()
{
    // The superclass no-argument constructor is invoked implicitly.
    confusions =
        new LinkedHashMap<CategoryType, Map<CategoryType, MutableDouble>>();
}
/**
 * Creates a new {@code DefaultConfusionMatrix} as a copy of another
 * confusion matrix. The new matrix is first initialized empty and is then
 * populated by passing {@code other} to {@code addAll}.
 *
 * @param   other
 *      The confusion matrix whose entries are copied into the new one.
 */
public DefaultConfusionMatrix(
    final ConfusionMatrix<? extends CategoryType> other)
{
    this();
    addAll(other);
}
/**
 * Creates a copy of this confusion matrix whose backing maps and count
 * objects are all freshly allocated, so that changing the copy does not
 * change this matrix and vice versa. The category keys themselves are
 * shared, not cloned.
 *
 * @return
 *      A clone of this confusion matrix.
 */
@Override
public DefaultConfusionMatrix<CategoryType> clone()
{
    @SuppressWarnings("unchecked")
    final DefaultConfusionMatrix<CategoryType> copy =
        (DefaultConfusionMatrix<CategoryType>) super.clone();

    if (this.confusions == null)
    {
        // There is no backing map to duplicate.
        return copy;
    }

    final Map<CategoryType, Map<CategoryType, MutableDouble>> copiedRows =
        new LinkedHashMap<CategoryType, Map<CategoryType, MutableDouble>>(
            this.confusions.size());
    copy.confusions = copiedRows;

    for (Map.Entry<CategoryType, Map<CategoryType, MutableDouble>> row
        : this.confusions.entrySet())
    {
        final Map<CategoryType, MutableDouble> sourceCounts = row.getValue();
        final Map<CategoryType, MutableDouble> copiedCounts =
            new LinkedHashMap<CategoryType, MutableDouble>(
                sourceCounts.size());
        copiedRows.put(row.getKey(), copiedCounts);

        // Each count is cloned so the two matrices never share a
        // MutableDouble instance.
        for (Map.Entry<CategoryType, MutableDouble> cell
            : sourceCounts.entrySet())
        {
            copiedCounts.put(cell.getKey(), cell.getValue().clone());
        }
    }

    return copy;
}
/**
 * Adds the given value to the entry for the given (target, estimate)
 * pair. If there is no row for the target or no entry for the estimate
 * yet, it is created; a new entry starts at the given value.
 *
 * @param   target
 *      The actual category; the first-level key of the backing map.
 * @param   estimate
 *      The predicted category; the second-level key of the backing map.
 * @param   value
 *      The amount to add to the entry.
 */
@Override
public void add(
    final CategoryType target,
    final CategoryType estimate,
    final double value)
{
    Map<CategoryType, MutableDouble> subMap = this.confusions.get(target);
    if (subMap == null)
    {
        // Use an insertion-ordered map, like the outer map and the rows
        // built by clone(), so that the iteration order of predicted
        // categories is deterministic and identical for a matrix and its
        // clones.
        subMap = new LinkedHashMap<CategoryType, MutableDouble>();
        this.confusions.put(target, subMap);
    }

    final MutableDouble entry = subMap.get(estimate);
    if (entry == null)
    {
        subMap.put(estimate, new MutableDouble(value));
    }
    else
    {
        entry.value += value;
    }
}
/**
 * Gets the value stored for the given (target, estimate) pair.
 *
 * @param   target
 *      The actual category.
 * @param   estimate
 *      The predicted category.
 * @return
 *      The stored value for the pair, or 0.0 when the matrix has no row
 *      for the target or no entry for the estimate in that row.
 */
@Override
public double getCount(
    final CategoryType target,
    final CategoryType estimate)
{
    final Map<CategoryType, MutableDouble> row = confusions.get(target);
    if (row == null)
    {
        return 0.0;
    }

    final MutableDouble cell = row.get(estimate);
    return (cell == null) ? 0.0 : cell.getValue();
}
/**
 * Gets the sum of all values stored in the row of the given actual
 * category, across every predicted category in that row.
 *
 * @param   target
 *      The actual category.
 * @return
 *      The sum of the row's values, or 0.0 when there is no row for the
 *      target.
 */
@Override
public double getActualCount(
    final CategoryType target)
{
    double total = 0.0;
    final Map<CategoryType, MutableDouble> row = confusions.get(target);
    if (row != null)
    {
        for (final MutableDouble count : row.values())
        {
            total += count.getValue();
        }
    }
    return total;
}
/**
 * Removes every entry from this confusion matrix by clearing the backing
 * map.
 */
@Override
public void clear()
{
    confusions.clear();
}
/**
 * Gets the union of the actual and the predicted categories. The result
 * is a newly allocated set that lists the actual categories first,
 * followed by any predicted categories not already present.
 *
 * @return
 *      A new set containing every category in the matrix.
 */
@Override
public Set<CategoryType> getCategories()
{
    final LinkedHashSet<CategoryType> categories =
        new LinkedHashSet<CategoryType>(this.getActualCategories());
    categories.addAll(this.getPredictedCategories());
    return categories;
}
/**
 * Gets the actual categories, which are the first-level keys of the
 * backing map.
 *
 * @return
 *      An unmodifiable view of the actual categories. The view reflects
 *      later changes to the matrix, but cannot itself be used to change
 *      the matrix.
 */
@Override
public Set<CategoryType> getActualCategories()
{
    // Returning the raw key set would let a caller remove keys and thereby
    // silently delete whole rows of the matrix, so expose it read-only.
    return Collections.unmodifiableSet(this.confusions.keySet());
}
/**
 * Gets all predicted categories, which are the second-level keys found in
 * any row of the backing map. The result is a newly allocated set.
 *
 * @return
 *      A new set containing every predicted category in the matrix.
 */
@Override
public Set<CategoryType> getPredictedCategories()
{
    final LinkedHashSet<CategoryType> predicted =
        new LinkedHashSet<CategoryType>(this.confusions.size());
    for (final Map<CategoryType, MutableDouble> row
        : this.confusions.values())
    {
        for (final CategoryType category : row.keySet())
        {
            predicted.add(category);
        }
    }
    return predicted;
}
/**
 * Gets the predicted categories that have an entry in the row of the
 * given actual category.
 *
 * @param   target
 *      The actual category.
 * @return
 *      An unmodifiable set of the predicted categories recorded for the
 *      target. It is empty when there is no row for the target; otherwise
 *      it is a read-only view of that row's keys.
 */
@Override
public Set<CategoryType> getPredictedCategories(
    final CategoryType target)
{
    final Map<CategoryType, MutableDouble> subMap =
        this.confusions.get(target);
    if (subMap == null)
    {
        return Collections.emptySet();
    }

    // The empty-set branch above is immutable; keep this branch read-only
    // too so callers cannot remove entries from the internal row through
    // the returned set.
    return Collections.unmodifiableSet(subMap.keySet());
}
/**
 * Gets the string form of the backing two-level map.
 *
 * @return
 *      The result of {@code toString()} on the backing map.
 */
@Override
public String toString()
{
    return confusions.toString();
}
/**
 * Creates a new {@code DefaultConfusionMatrix} from the given
 * target-estimate pairs. Each pair's target and estimate are passed, in
 * that order, to the two-argument {@code add} method of a new, initially
 * empty matrix.
 *
 * @param   <CategoryType>
 *      The category type.
 * @param   pairs
 *      The target-estimate pairs to add.
 * @return
 *      A new confusion matrix populated from the given target-estimate
 *      pairs.
 */
public static <CategoryType> DefaultConfusionMatrix<CategoryType> createUnweighted(
    final Collection<? extends TargetEstimatePair<? extends CategoryType, ? extends CategoryType>> pairs)
{
    final DefaultConfusionMatrix<CategoryType> matrix =
        new DefaultConfusionMatrix<CategoryType>();
    for (final TargetEstimatePair<? extends CategoryType, ? extends CategoryType> pair
        : pairs)
    {
        final CategoryType target = pair.getTarget();
        final CategoryType estimate = pair.getEstimate();
        matrix.add(target, estimate);
    }
    return matrix;
}
/**
 * Creates a new {@code DefaultConfusionMatrix} from the given
 * actual-predicted pairs. Each pair's first and second elements are
 * passed, in that order, to the two-argument {@code add} method of a new,
 * initially empty matrix.
 *
 * @param   <CategoryType>
 *      The category type.
 * @param   pairs
 *      The actual-predicted pairs to add.
 * @return
 *      A new confusion matrix populated from the given actual-predicted
 *      pairs.
 */
public static <CategoryType> DefaultConfusionMatrix<CategoryType> createFromActualPredictedPairs(
    final Collection<? extends Pair<? extends CategoryType, ? extends CategoryType>> pairs)
{
    final DefaultConfusionMatrix<CategoryType> matrix =
        new DefaultConfusionMatrix<CategoryType>();
    for (final Pair<? extends CategoryType, ? extends CategoryType> entry
        : pairs)
    {
        final CategoryType actual = entry.getFirst();
        final CategoryType predicted = entry.getSecond();
        matrix.add(actual, predicted);
    }
    return matrix;
}
/**
 * A summarizer that turns a collection of actual-predicted pairs into a
 * {@code DefaultConfusionMatrix}.
 *
 * @param   <CategoryType>
 *      The type of category of the summarizer.
 */
public static class ActualPredictedPairSummarizer<CategoryType>
    extends AbstractCloneableSerializable
    implements Summarizer<Pair<? extends CategoryType, ? extends CategoryType>, DefaultConfusionMatrix<CategoryType>>
{

    /**
     * Creates a new {@code ActualPredictedPairSummarizer}.
     */
    public ActualPredictedPairSummarizer()
    {
        // The superclass no-argument constructor is invoked implicitly.
    }

    /**
     * Builds a confusion matrix from the given pairs by delegating to
     * {@code DefaultConfusionMatrix.createFromActualPredictedPairs}.
     *
     * @param   data
     *      The actual-predicted pairs to summarize.
     * @return
     *      A new confusion matrix created from the given pairs.
     */
    @Override
    public DefaultConfusionMatrix<CategoryType> summarize(
        final Collection<? extends Pair<? extends CategoryType, ? extends CategoryType>> data)
    {
        final DefaultConfusionMatrix<CategoryType> summary =
            DefaultConfusionMatrix.<CategoryType>createFromActualPredictedPairs(data);
        return summary;
    }

}
/**
 * A summarizer that combines a collection of confusion matrices into one
 * {@code DefaultConfusionMatrix}.
 *
 * @param   <CategoryType>
 *      The type of category of the summarizer.
 */
public static class CombineSummarizer<CategoryType>
    extends AbstractCloneableSerializable
    implements Summarizer<ConfusionMatrix<CategoryType>, DefaultConfusionMatrix<CategoryType>>
{

    /**
     * Creates a new {@code CombineSummarizer}.
     */
    public CombineSummarizer()
    {
        // The superclass no-argument constructor is invoked implicitly.
    }

    /**
     * Builds a new, initially empty confusion matrix and passes each of
     * the given matrices to its {@code addAll} method, in iteration order.
     *
     * @param   data
     *      The confusion matrices to combine.
     * @return
     *      A new confusion matrix holding the combined result.
     */
    @Override
    public DefaultConfusionMatrix<CategoryType> summarize(
        final Collection<? extends ConfusionMatrix<CategoryType>> data)
    {
        final DefaultConfusionMatrix<CategoryType> combined =
            new DefaultConfusionMatrix<CategoryType>();
        for (final ConfusionMatrix<CategoryType> matrix : data)
        {
            combined.addAll(matrix);
        }
        return combined;
    }

}
/**
 * A factory that creates new, empty {@code DefaultConfusionMatrix}
 * objects.
 *
 * @param   <CategoryType>
 *      The type of category that the confusion is computed over.
 */
public static class Factory<CategoryType>
    extends AbstractCloneableSerializable
    implements gov.sandia.cognition.factory.Factory<DefaultConfusionMatrix<CategoryType>>
{

    /**
     * Creates a new {@code Factory}.
     */
    public Factory()
    {
        // The superclass no-argument constructor is invoked implicitly.
    }

    /**
     * Creates a new confusion matrix using the no-argument constructor of
     * {@code DefaultConfusionMatrix}.
     *
     * @return
     *      A new, empty {@code DefaultConfusionMatrix}.
     */
    @Override
    public DefaultConfusionMatrix<CategoryType> create()
    {
        final DefaultConfusionMatrix<CategoryType> matrix =
            new DefaultConfusionMatrix<CategoryType>();
        return matrix;
    }

}
}