/* Copyright 2003, Carnegie Mellon, All Rights Reserved */
package edu.cmu.minorthird.classify.multi;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.ExampleSchema;
/**
* A label which is associated with an instance---either by a classifier,
* or in training data.
*
*<p>
* MultiClassLabels should be weighted to that the weight for a class name
* is (approximately) the log-odds having that class name, ie if
* the probability of class "POS" is p, the getWeight("POS") should
* return Math.log( p/(1-p) ).
*
* The POS and NEG class labels (as defined in
* ExampleSchema.POS_CLASS_NAME and ExampleSchema.NEG_CLASS_NAME) are
* special. Binary class labels should be created with the
* positiveLabel(posWeight) and negativeLabel(negWeight) routines, or
* else the binaryLabel routine. The numericLabel() returns +1 or -1
* for binary classLabels. The posWeight() method returns the score
* of the positive class.
* The classLabel.numericLabel() method ignores the underlying score.
* For testing binary examples, classLabel.isPositive(),
* classLabel.isNegative(), and classLabel.bestWeight() should be
* used.
*
* @author Cameron Williams
*/
public class MultiClassLabel implements Serializable{
static final long serialVersionUID=20080130L;
private ClassLabel[] labels;
private int dimensions;
public MultiClassLabel(){
;
}
public MultiClassLabel(ClassLabel[] labels){
this.labels=labels;
this.dimensions=labels.length;
}
public ClassLabel[] getLabels(){
return labels;
}
/** Return the number of dimensions in the multiLabel */
public int numDimensions(){
return dimensions;
}
/** See if this is one of the distinguished binary labels. */
public boolean[] isBinary(){
boolean[] binary=new boolean[dimensions];
for(int i=0;i<dimensions;i++){
binary[i]=ExampleSchema.BINARY_EXAMPLE_SCHEMA.isValid(labels[i]);
}
return binary;
}
/** See if this is the distinguished positive label. */
public boolean[] isPositive(){
boolean[] positive=new boolean[dimensions];
for(int i=0;i<dimensions;i++){
positive[i]=
ExampleSchema.POS_CLASS_NAME.equals(labels[i].bestClassName());
}
return positive;
}
/** See if this is the distinguished negative label. */
public boolean[] isNegative(){
boolean[] negative=new boolean[dimensions];
for(int i=0;i<dimensions;i++){
negative[i]=
ExampleSchema.NEG_CLASS_NAME.equals(labels[i].bestClassName());
}
return negative;
}
/** Return a numeric score of +1, or -1 for a binary example */
public double[] numericLabel(){
double[] numLabel=new double[dimensions];
for(int i=0;i<dimensions;i++){
numLabel[i]=labels[i].numericLabel();
}
return numLabel;
}
/** Returns the highest-ranking label. */
public String[] bestClassName(){
String[] bestName=new String[dimensions];
for(int i=0;i<dimensions;i++){
bestName[i]=labels[i].bestClassName();
}
return bestName;
}
/** Returns the weight of the highest-ranking label. */
public double[] bestWeight(){
double[] bestWeight=new double[dimensions];
for(int i=0;i<dimensions;i++){
bestWeight[i]=labels[i].bestWeight();
}
return bestWeight;
}
/** Returns the weight of the positive class name */
public double[] posWeight(){
double[] posWeight=new double[dimensions];
for(int i=0;i<dimensions;i++){
posWeight[i]=labels[i].getWeight(ExampleSchema.POS_CLASS_NAME);
}
return posWeight;
}
/** Returns the probability of the positive class name */
public double[] posProbability(){
double[] posProb=new double[dimensions];
for(int i=0;i<dimensions;i++){
posProb[i]=labels[i].getProbability(ExampleSchema.POS_CLASS_NAME);
}
return posProb;
}
/** Returns the weight of the label. */
public double[] getWeight(String[] label){
double[] weight=new double[dimensions];
for(int i=0;i<dimensions;i++){
weight[i]=labels[i].getWeight(label[i]);
}
return weight;
}
/** Returns the probability of a label. */
public double[] getProbability(String[] label){
double[] odds=new double[dimensions];
for(int i=0;i<dimensions;i++){
double expOdds=Math.exp(labels[i].getWeight(label[i]));
odds[i]=expOdds/(1.0+expOdds);
}
return odds;
}
/** Returns the set of labels that appear in the ranking. */
public List<Set<String>> possibleLabels(){
List<Set<String>> sets=new ArrayList<Set<String>>(dimensions);
for(int i=0;i<dimensions;i++){
sets.add(labels[i].possibleLabels());
}
return sets;
}
/** Is this label correct, relative to another label? */
public boolean[] isMultiCorrect(MultiClassLabel otherLabel){
if(otherLabel==null)
throw new IllegalArgumentException("null otherLabel?");
if(bestClassName()==null)
throw new IllegalArgumentException("null bestClassName?");
if(dimensions!=otherLabel.numDimensions())
throw new IllegalArgumentException("Number of Dimensions do not match");
boolean[] correct=new boolean[dimensions];
for(int i=0;i<dimensions;i++){
correct[i]=
this.labels[i].bestClassName().equals(
otherLabel.labels[i].bestClassName());
}
return correct;
}
/** Is this label correct, relative to another label? */
public boolean isCorrect(MultiClassLabel otherLabel){
if(otherLabel==null)
throw new IllegalArgumentException("null otherLabel?");
if(bestClassName()==null)
throw new IllegalArgumentException("null bestClassName?");
if(dimensions!=otherLabel.numDimensions())
throw new IllegalArgumentException("Number of Dimensions do not match");
boolean correct=true;
for(int i=0;i<dimensions;i++){
correct=
correct&&
this.labels[i].bestClassName().equals(
otherLabel.labels[i].bestClassName());
}
return correct;
}
@Override
public String toString(){
String labelString="";
for(int i=0;i<dimensions;i++){
labelString=labelString+labels[i].toString();
}
return labelString;
}
public String toDetails(){
String details="";
for(int i=0;i<dimensions;i++){
details=details+labels[i].toDetails();
}
return details;
}
}