/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* CostMatrix.java
* Copyright (C) 2001 Richard Kirkby
*
*/
package weka.classifiers;
import weka.core.Matrix;
import weka.core.Instances;
import weka.core.Utils;
import java.io.Reader;
import java.io.StreamTokenizer;
import java.util.Random;
/**
* Class for storing and manipulating a misclassification cost matrix.
* The element at position i,j in the matrix is the penalty for classifying
* an instance of class j as class i.
*
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 1.1.1.1 $
*/
public class CostMatrix extends Matrix {

  /** The default file extension for cost matrix files */
  public static String FILE_EXTENSION = ".cost";

  /**
   * Creates a cost matrix that is a copy of another.
   *
   * @param toCopy the matrix to copy.
   */
  public CostMatrix(CostMatrix toCopy) {
    super(toCopy.size(), toCopy.size());
    int n = toCopy.size();
    for (int row = 0; row < n; row++) {
      for (int col = 0; col < n; col++) {
        setElement(row, col, toCopy.getElement(row, col));
      }
    }
  }

  /**
   * Creates a default cost matrix of a particular size. All values will be 0.
   *
   * @param numOfClasses the number of classes that the cost matrix holds.
   */
  public CostMatrix(int numOfClasses) {
    super(numOfClasses, numOfClasses);
  }

  /**
   * Creates a cost matrix from a reader.
   *
   * @param reader the reader to get the values from.
   * @exception Exception if the matrix is invalid (in particular, if it is
   * not square).
   */
  public CostMatrix(Reader reader) throws Exception {
    super(reader);
    // make sure that the matrix is square
    if (numRows() != numColumns()) {
      throw new Exception("Trying to create a non-square cost matrix");
    }
  }

  /**
   * Sets the cost of all correct classifications (the diagonal) to 0, and
   * all misclassifications (every off-diagonal element) to 1.
   */
  public void initialize() {
    int n = size();
    for (int i = 0; i < n; i++) {
      for (int j = 0; j < n; j++) {
        setElement(i, j, i == j ? 0.0 : 1.0);
      }
    }
  }

  /**
   * Gets the size of the matrix, i.e. its number of columns.
   *
   * @return the size.
   */
  public int size() {
    return numColumns();
  }

  /**
   * Applies the cost matrix to a set of instances. If a random number
   * generator is supplied the instances will be resampled, otherwise they
   * will be reweighted. Adapted from code once sitting in Instances.java
   *
   * If any diagonal element is non-zero, the weights are derived from a
   * normalized copy of this matrix; this matrix itself is never modified.
   *
   * @param data the instances to reweight.
   * @param random a random number generator for resampling, if null then
   * instances are reweighted.
   * @return a new dataset reflecting the cost of misclassification.
   * @exception Exception if the data has no class or the matrix is
   * inappropriate (wrong size, or negative costs after normalization).
   */
  public Instances applyCostMatrix(Instances data, Random random) throws Exception {
    if (data.classIndex() < 0) {
      throw new Exception("Class index is not set!");
    }
    int numClasses = data.numClasses();
    if (size() != numClasses) {
      throw new Exception("Misclassification cost matrix has wrong format!");
    }

    // Normalize a copy of the matrix if the diagonal is not already zero.
    // Working on the copy directly (instead of recursing into it) avoids
    // computing the class weights twice and cannot recurse without bound
    // when normalization fails to zero the diagonal (NaN/infinite entries).
    CostMatrix costs = this;
    for (int i = 0; i < size(); i++) {
      if (!Utils.eq(getElement(i, i), 0)) {
        costs = new CostMatrix(this);
        costs.normalize();
        break;
      }
    }

    // Total instance weight per class, and overall.
    int numInstances = data.numInstances();
    double[] weightOfInstancesInClass = new double[numClasses];
    for (int j = 0; j < numInstances; j++) {
      weightOfInstancesInClass[(int) data.instance(j).classValue()] +=
        data.instance(j).weight();
    }
    double sumOfWeights = Utils.sum(weightOfInstancesInClass);

    // Using Kai Ming Ting's formula for deriving weights for
    // the classes and Breiman's heuristic for multiclass
    // problems.
    // NOTE(review): the factor for class i is built from row i
    // (getElement(i, j)), whereas expectedCosts() and normalize() treat the
    // column index as the actual class -- confirm the intended orientation.
    double[] weightFactor = new double[numClasses];
    double sumOfWeightFactors = 0;
    for (int i = 0; i < numClasses; i++) {
      double sumOfMissClassWeights = 0;
      for (int j = 0; j < numClasses; j++) {
        double cost = costs.getElement(i, j);
        if (Utils.sm(cost, 0)) {
          throw new Exception("Neg. weights in misclassification cost matrix!");
        }
        sumOfMissClassWeights += cost;
      }
      weightFactor[i] = sumOfMissClassWeights * sumOfWeights;
      sumOfWeightFactors += sumOfMissClassWeights * weightOfInstancesInClass[i];
    }
    // NOTE(review): sumOfWeightFactors is 0 when every relevant cost or
    // class weight is 0; the division then yields non-finite factors.
    for (int i = 0; i < numClasses; i++) {
      weightFactor[i] /= sumOfWeightFactors;
    }

    // Store new weights
    double[] weightOfInstances = new double[numInstances];
    for (int i = 0; i < numInstances; i++) {
      weightOfInstances[i] = data.instance(i).weight() *
        weightFactor[(int) data.instance(i).classValue()];
    }

    // Change instances weight or do resampling
    if (random != null) {
      return data.resampleWithWeights(random, weightOfInstances);
    }
    Instances instances = new Instances(data);
    for (int i = 0; i < numInstances; i++) {
      instances.instance(i).setWeight(weightOfInstances[i]);
    }
    return instances;
  }

  /**
   * Calculates the expected misclassification cost for each possible class
   * value, given class probability estimates. Entry x of the result is the
   * sum over y of classProbs[y] * element(x, y).
   *
   * @param classProbs the class probability estimates.
   * @return the expected costs.
   * @exception Exception if the wrong number of class probabilities is
   * supplied.
   */
  public double[] expectedCosts(double[] classProbs) throws Exception {
    int n = size();
    if (classProbs.length != n) {
      throw new Exception("Length of probability estimates don't match cost matrix");
    }
    double[] costs = new double[n];
    for (int x = 0; x < n; x++) {
      for (int y = 0; y < n; y++) {
        costs[x] += classProbs[y] * getElement(x, y);
      }
    }
    return costs;
  }

  /**
   * Gets the maximum cost for a particular class value, i.e. the largest
   * element in row classVal of the matrix.
   *
   * @param classVal the class value.
   * @return the maximum cost (negative infinity for an empty matrix).
   */
  public double getMaxCost(int classVal) {
    double maxCost = Double.NEGATIVE_INFINITY;
    int n = size();
    for (int i = 0; i < n; i++) {
      double cost = getElement(classVal, i);
      if (cost > maxCost) {
        maxCost = cost;
      }
    }
    return maxCost;
  }

  /**
   * Normalizes the matrix so that the diagonal contains zeros. For every
   * column the diagonal element of that column is subtracted from each
   * element in the column.
   */
  public void normalize() {
    int n = size();
    for (int y = 0; y < n; y++) {
      // read the diagonal once, before it is itself overwritten with 0
      double diag = getElement(y, y);
      for (int x = 0; x < n; x++) {
        setElement(x, y, getElement(x, y) - diag);
      }
    }
  }

  /**
   * Loads a cost matrix in the old format from a reader. Adapted from code
   * once sitting in Instances.java
   *
   * The matrix is first reset via initialize(). Each non-empty line must
   * then hold three numbers: two distinct class indices followed by a
   * strictly positive cost, which is stored at (first index, second index).
   * Text following a '%' is treated as a comment.
   *
   * @param reader the reader to get the values from.
   * @exception Exception if the matrix cannot be read correctly.
   */
  public void readOldFormat(Reader reader) throws Exception {
    StreamTokenizer tokenizer = new StreamTokenizer(reader);

    initialize();
    tokenizer.commentChar('%');
    tokenizer.eolIsSignificant(true);

    int currentToken;
    while ((currentToken = tokenizer.nextToken()) != StreamTokenizer.TT_EOF) {

      // Skip empty lines
      if (currentToken == StreamTokenizer.TT_EOL) {
        continue;
      }

      // Get index of first class.
      if (currentToken != StreamTokenizer.TT_NUMBER) {
        throw new Exception("Only numbers and comments allowed in cost file!");
      }
      int firstIndex = toClassIndex(tokenizer.nval, "First");

      // Get index of second class.
      int secondIndex = toClassIndex(nextCostFileNumber(tokenizer), "Second");
      if (secondIndex == firstIndex) {
        throw new Exception("Diagonal of cost matrix non-zero!");
      }

      // Get cost factor.
      double weight = nextCostFileNumber(tokenizer);
      if (!Utils.gr(weight, 0)) {
        throw new Exception("Only positive weights allowed!");
      }
      setElement(firstIndex, secondIndex, weight);
    }
  }

  /**
   * Reads the next token of the current line and returns it as a number.
   *
   * @param tokenizer the tokenizer to read from.
   * @return the numeric value of the token.
   * @exception Exception if the file or line ends prematurely, or the token
   * is not a number.
   */
  private static double nextCostFileNumber(StreamTokenizer tokenizer) throws Exception {
    int token = tokenizer.nextToken();
    if (token == StreamTokenizer.TT_EOF) {
      throw new Exception("Premature end of file!");
    }
    if (token == StreamTokenizer.TT_EOL) {
      throw new Exception("Premature end of line!");
    }
    if (token != StreamTokenizer.TT_NUMBER) {
      throw new Exception("Only numbers and comments allowed in cost file!");
    }
    return tokenizer.nval;
  }

  /**
   * Converts a number read from a cost file into a class index for this
   * matrix.
   *
   * @param value the number read from the file.
   * @param position "First" or "Second"; used as the prefix of the error
   * message.
   * @return the class index.
   * @exception Exception if the value is not a whole number or lies outside
   * the range 0 .. size()-1.
   */
  private int toClassIndex(double value, String position) throws Exception {
    int index = (int) value;
    if (!Utils.eq((double) index, value)) {
      throw new Exception(position + " number in line has to be index of a class!");
    }
    // Negative indices must be rejected too, otherwise they reach
    // setElement() as an invalid array position.
    if (index < 0 || index >= size()) {
      throw new Exception("Class index out of range!");
    }
    return index;
  }
}