/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* LearnableMetric.java
* Copyright (C) 2001 Mikhail Bilenko
*
*/
package weka.core.metrics;
import java.util.ArrayList;
import java.util.HashMap;
import weka.core.*;
import weka.classifiers.*;
import weka.clusterers.regularizers.*;
/**
 * Abstract base class for distance metrics whose parameters (per-attribute
 * weights and, optionally, a pair-space classifier) can be learned from data.
 *
 * @author Mikhail Bilenko (mbilenko@cs.utexas.edu)
 * @version $Revision: 1.13 $
 */
public abstract class LearnableMetric extends Metric {

  /** Weights of individual attributes; null until weights have been set */
  protected double[] m_attrWeights = null;

  /** The metric may require normalizing all data */
  protected boolean m_normalizeData = false;

  /** The maximum number of same-class examples to construct diff-instances from */
  protected int m_numPosDiffInstances = 200;

  /** Proportion of different-class versus same-class diff-instances */
  protected double m_posNegDiffInstanceRatio = 1.0;

  /** A metric may utilize a classifier for learning its parameters */
  boolean m_usesClassifier = false;

  /** Class name of the classifier most recently requested via useClassifier */
  protected String m_classifierClassName = null;

  /** The classifier instance used in pair space; null when none is in use */
  public Classifier m_classifier = null;

  /** Whether the classifier needs a nominal class attribute (some can use non-nominal ones) */
  protected boolean m_classifierRequiresNominalClass = false;

  /** True if metric learning is used. Set to false to turn off metric learning */
  protected boolean m_trainable = false;

  /** Has the metric been trained? */
  boolean m_trained = false;

  /** True if metric uses an external estimator for calculating distances */
  boolean m_external = false;

  /** Current regularizer value */
  double m_regularizerVal = 0;

  /** use regularization on weights */
  protected boolean m_regularize = false;

  /** The regularizer used to compute the regularizer value from the weights */
  protected Regularizer m_regularizer = new Rayleigh();

  /** Current normalizer value */
  double m_normalizer = 0;

  /** The metric may return its own maximum distance */
  public boolean m_fixedMaxDistance = false;

  /** The maximum distance reported by getMaxDistance */
  protected double m_maxDistance = Double.MAX_VALUE;

  /**
   * Get the maximum distance of this metric.
   *
   * @return the current value of the maximum distance
   */
  public double getMaxDistance() {
    return m_maxDistance;
  }

  /**
   * Train the distance metric. A specific metric will take care of
   * its own training either via a metric learner or by itself.
   *
   * @param data the training data
   * @throws Exception if training fails
   */
  public abstract void learnMetric(Instances data) throws Exception;

  /**
   * Switch from calculating the metric to pair-space classification.
   *
   * Historical note: using DistributionClassifier because it actually reports
   * a margin; SMO is first, will try others as well.
   *
   * @param classifierClassName Some classifier that classifies pairs of points
   * @param classifierRequiresNominalClass does classifier need a nominal class attribute?
   * @throws Exception if the class cannot be loaded or instantiated
   */
  public void useClassifier(String classifierClassName,
                            boolean classifierRequiresNominalClass) throws Exception {
    // Instantiate first: if loading or construction fails, none of the
    // fields below are touched and the metric keeps its previous state.
    Classifier classifier =
      (Classifier) Class.forName(classifierClassName).newInstance();

    m_classifierClassName = classifierClassName;
    m_classifier = classifier;
    m_classifierRequiresNominalClass = classifierRequiresNominalClass;
    m_usesClassifier = true;
  }

  /**
   * Switch from using a classifier in difference-space to vanilla L-1
   * norm distance.
   */
  public void useNoClassifier() {
    m_usesClassifier = false;
    m_classifier = null;
  }

  /**
   * Is this metric defined in vanilla space, or difference space?
   *
   * @return true if metric uses a classifier to classify L1-Norm as in-cluster
   * or out-of-cluster
   */
  public boolean usesClassifier() {
    return m_usesClassifier;
  }

  /**
   * Is the metric normalizing?
   *
   * @return true if the metric requires normalized data
   */
  public boolean doesNormalizeData() {
    return m_normalizeData;
  }

  /**
   * Reset all values that have been learned: marks the metric as untrained,
   * sets every existing attribute weight back to 1 and recomputes the
   * normalizer and regularizer values.
   *
   * @throws Exception declared for subclasses; nothing in this implementation
   * throws a checked exception
   */
  public void resetMetric() throws Exception {
    m_trained = false;
    if (m_attrWeights != null) {
      for (int i = 0; i < m_attrWeights.length; i++) {
        m_attrWeights[i] = 1;
      }
    }
    // Both recompute methods cope with weights that have not been set yet.
    recomputeNormalizer();
    recomputeRegularizer();
  }

  /**
   * Create an instance with features corresponding to components of the two given instances
   *
   * @param instance1 first instance
   * @param instance2 second instance
   * @return the newly created diff-instance
   */
  public abstract Instance createDiffInstance(Instance instance1, Instance instance2);

  /**
   * Get the values of the partial derivatives for the metric components
   * for a particular instance pair
   *
   * @param instance1 the first instance
   * @param instance2 the second instance
   * @return an array of partial derivatives
   * @throws Exception if the gradients cannot be computed
   */
  public abstract double[] getGradients(Instance instance1, Instance instance2) throws Exception;

  /**
   * Set the feature weights. The array is copied, so later changes to the
   * argument do not affect the metric. The normalizer and regularizer values
   * are recomputed afterwards.
   *
   * @param _weights an array of double weights for features
   * @throws Exception if the array is null or its length differs from the
   * number of attributes
   */
  public void setWeights(double[] _weights) throws Exception {
    if (_weights == null) {
      throw new Exception("Weights array is null");
    }
    if (_weights.length != m_numAttributes) {
      throw new Exception("Number of weights " + _weights.length +
                          " is not equal to the number of attributes " + m_numAttributes);
    }
    m_attrWeights = new double[_weights.length];
    System.arraycopy(_weights, 0, m_attrWeights, 0, _weights.length);
    recomputeNormalizer();
    recomputeRegularizer();
  }

  /**
   * Get the feature weights
   *
   * @return the internal array of feature weights (not a copy); null if no
   * weights have been set
   */
  public double[] getWeights() {
    return m_attrWeights;
  }

  /**
   * Given a cluster of instances, return the centroid of that cluster
   *
   * @param instances data points belonging to a cluster
   * @param fastMode implementation-specific flag (see the concrete metric)
   * @param normalized implementation-specific flag (see the concrete metric)
   * @return a centroid instance for the given cluster
   */
  public abstract Instance getCentroidInstance(Instances instances, boolean fastMode, boolean normalized);

  /**
   * Fast version of meanOrMode - streamlined from Instances.meanOrMode for efficiency.
   * Does not check for missing attributes, assumes numeric attributes, assumes
   * Sparse instances (every instance is cast to SparseInstance).
   *
   * @param insts the instances to average
   * @return for each attribute, the instance-weighted mean of its values; all
   * zeros if the total instance weight is (approximately) zero
   */
  public double[] meanOrMode(Instances insts) {
    int numAttributes = insts.numAttributes();
    // Java zero-initializes new arrays, so the sums start at 0.
    double[] value = new double[numAttributes];
    double weight = 0;

    for (int j = 0; j < insts.numInstances(); j++) {
      SparseInstance inst = (SparseInstance) (insts.instance(j));
      double instWeight = inst.weight();
      weight += instWeight;
      // Only visit the values that are actually stored in the sparse instance.
      for (int i = 0; i < inst.numValues(); i++) {
        int attrIdx = inst.index(i);
        value[attrIdx] += instWeight * inst.value(attrIdx);
      }
    }

    if (Utils.eq(weight, 0)) {
      // No usable weight mass: report an all-zero vector instead of dividing by 0.
      for (int k = 0; k < numAttributes; k++) {
        value[k] = 0;
      }
    } else {
      for (int k = 0; k < numAttributes; k++) {
        value[k] = value[k] / weight;
      }
    }
    return value;
  }

  /**
   * Get the value of metricTraining
   *
   * @return Value of metricTraining
   */
  public boolean getTrainable() {
    return m_trainable;
  }

  /**
   * Set the value of metricTraining
   *
   * @param metricTraining Value of metricTraining
   */
  public void setTrainable(boolean metricTraining) {
    m_trainable = metricTraining;
  }

  /**
   * Get the value of m_external
   *
   * @return Value of m_external
   */
  public boolean getExternal() {
    return m_external;
  }

  /**
   * Set the value of m_external
   *
   * @param external if true, an external estimator will be used for distance
   */
  public void setExternal(boolean external) {
    m_external = external;
  }

  /**
   * Set the number of positive instances to be used for training
   *
   * @param numPosInstances the number amounts of positive examples (diff-instances)
   */
  public void setNumPosDiffInstances(int numPosInstances) {
    m_numPosDiffInstances = numPosInstances;
  }

  /**
   * Get the number of positive instances to be used for training
   *
   * @return the number of positive examples (diff-instances)
   */
  public int getNumPosDiffInstances() {
    return m_numPosDiffInstances;
  }

  /**
   * Set the ratio of positive and negative instances to be used for training
   *
   * @param ratio the relative amounts of negative examples compared to positive examples.
   * If -1, all possible negatives will be used (use with care!)
   */
  public void setPosNegDiffInstanceRatio(double ratio) {
    m_posNegDiffInstanceRatio = ratio;
  }

  /**
   * Get the ratio of positive and negative instances to be used for training
   *
   * @return the relative amounts of negative examples compared to positive examples.
   * If -1, all possible negatives will be used (use with care!)
   */
  public double getPosNegDiffInstanceRatio() {
    return m_posNegDiffInstanceRatio;
  }

  /**
   * Get the regularizer value
   *
   * @return the most recently computed regularizer value
   */
  public double regularizer() {
    return m_regularizerVal;
  }

  /**
   * Recompute the regularizer value from the current weights using the
   * configured regularizer; children may override.
   * If no weights or no regularizer are set, the value is reset to 0.
   */
  public void recomputeRegularizer() {
    // NOTE(review): treating "nothing to regularize" as 0 assumes the
    // regularizer has no meaningful value for null weights - confirm.
    if (m_attrWeights == null || m_regularizer == null) {
      m_regularizerVal = 0;
      return;
    }
    m_regularizerVal = m_regularizer.computeRegularizer(m_attrWeights);
  }

  /**
   * Get the normalizer value
   *
   * @return the most recently computed normalizer value
   */
  public double getNormalizer() {
    return m_normalizer;
  }

  /**
   * Recompute the normalizer; children may override. By default it is the
   * sum of the natural logarithms of all strictly positive weights, and 0
   * when no weights have been set.
   */
  public void recomputeNormalizer() {
    m_normalizer = 0;
    if (m_attrWeights == null) {
      // resetMetric() may call this before any weights exist.
      return;
    }
    for (int i = 0; i < m_attrWeights.length; i++) {
      if (m_attrWeights[i] > 0) {
        m_normalizer += Math.log(m_attrWeights[i]);
      }
    }
  }

  /**
   * Normalizes the values of an Instance utilizing feature weights: every
   * non-class value is divided by sqrt(sum_i w_i * x_i^2), where the sum
   * runs over the non-class attributes. The instance is modified in place.
   * If that norm is exactly 0 the instance is left unchanged, since dividing
   * would fill it with NaN/Infinity values.
   * Requires the attribute weights to have been set.
   *
   * @param inst Instance to be normalized
   */
  public void normalizeInstanceWeighted(Instance inst) {
    int classIndex = inst.classIndex();

    if (inst instanceof SparseInstance) {
      double norm = 0;
      for (int i = 0; i < inst.numValues(); i++) {
        int idx = inst.index(i);
        if (idx != classIndex) { // don't normalize the class index
          norm += m_attrWeights[idx] * inst.value(idx) * inst.value(idx);
        }
      }
      norm = Math.sqrt(norm);
      if (norm == 0) {
        return; // nothing to scale; avoid division by zero
      }
      for (int i = 0; i < inst.numValues(); i++) {
        int idx = inst.index(i);
        if (idx != classIndex) {
          inst.setValueSparse(i, inst.value(idx) / norm);
        }
      }
    } else { // non-sparse instances
      double norm = 0;
      double[] values = inst.toDoubleArray();
      for (int i = 0; i < values.length; i++) {
        if (i != classIndex) { // don't normalize the class index
          norm += m_attrWeights[i] * values[i] * values[i];
        }
      }
      norm = Math.sqrt(norm);
      if (norm == 0) {
        return; // nothing to scale; avoid division by zero
      }
      for (int i = 0; i < values.length; i++) {
        if (i != classIndex) { // don't normalize the class index
          values[i] /= norm;
        }
      }
      inst.setValueArray(values);
    }
  }

  /**
   * Set the regularizer
   *
   * @param reg the regularizer to use when recomputing the regularizer value
   */
  public void setRegularizer(Regularizer reg) {
    m_regularizer = reg;
  }

  /**
   * Get the regularizer
   *
   * @return the regularizer currently in use
   */
  public Regularizer getRegularizer() {
    return m_regularizer;
  }

  /**
   * Creates a metric by class name. Delegates to Utils.forName with
   * LearnableMetric as the required type and casts the result to
   * LearnableMetric.
   *
   * @param metricName the name of the metric class
   * @param options options handed to Utils.forName
   * @return the metric returned by Utils.forName
   * @throws Exception if Utils.forName fails
   */
  public static Metric forName(String metricName,
                               String[] options) throws Exception {
    return (LearnableMetric) Utils.forName(LearnableMetric.class,
                                           metricName,
                                           options);
  }

  /**
   * Create a copy of this metric. The attribute weights are copied and the
   * classifier (if any) is copied via Classifier.makeCopies. If copying the
   * classifier fails, the problem is reported on stderr and the clone keeps
   * whatever classifier reference super.clone() produced.
   *
   * @return the cloned metric
   */
  public Object clone() {
    LearnableMetric m = (LearnableMetric) super.clone();

    // clone the fields
    if (m_attrWeights != null) {
      m.m_attrWeights = (double[]) m_attrWeights.clone();
    }
    if (m_classifier != null) {
      try {
        m.m_classifier = Classifier.makeCopies(m_classifier, 1)[0];
      } catch (Exception e) {
        System.err.println("Problems cloning a non-null classifier in LearnableMetric; this should never be reached");
        // Surface the cause instead of discarding it.
        e.printStackTrace();
      }
    }
    return m;
  }
}