/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* HardPairwiseSelector.java
* Copyright (C) 2002 Mikhail Bilenko
*
*/
package weka.core.metrics;
import java.util.*;
import java.io.Serializable;
import weka.core.*;
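/**
 * HardPairwiseSelector class. Given a metric and training data, creates a list of
 * same-class (positive) and different-class (negative) instance pairs to be used as
 * metric training data. Pairs can be sampled randomly or chosen by how "difficult"
 * they are, i.e. by their distance under the metric.
 *
 * @author Mikhail Bilenko (mbilenko@cs.utexas.edu)
 * @version $Revision: 1.3 $
 */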
public class HardPairwiseSelector extends PairwiseSelector implements Serializable, OptionHandler {

  /** Selection mode: sample pairs uniformly at random, without replacement. */
  public static final int PAIRS_RANDOM = 1;

  /** Selection mode: take the "hardest" pairs first (see {@link #createPairList}). */
  public static final int PAIRS_HARDEST = 2;

  /** Selection mode: take the "easiest" pairs first (see {@link #createPairList}). */
  public static final int PAIRS_EASIEST = 4;

  /** Selection mode: pairs within a percentile range of distances; not implemented yet (selects nothing). */
  public static final int PAIRS_INTERVAL = 8;

  /** Tags describing the selection modes, for use with {@link SelectedTag}. */
  public static final Tag[] TAGS_PAIR_SELECTION_MODE = {
    new Tag(PAIRS_RANDOM, "Random pairs"),
    new Tag(PAIRS_HARDEST, "Hardest pairs"),
    new Tag(PAIRS_EASIEST, "Easiest pairs"),
    new Tag(PAIRS_INTERVAL, "Pairs in a percentile range")
  };

  /** Selection mode for same-class (positive) pairs. */
  protected int m_positivesMode = PAIRS_RANDOM;

  /** Selection mode for different-class (negative) pairs. */
  protected int m_negativesMode = PAIRS_RANDOM;

  /**
   * Comparator that inverts the natural ordering of Comparable objects. It is used
   * to iterate a TreeSet of pairs from the largest value down, e.g. to get the
   * hardest positive pairs (those with the largest distance).
   */
  public class ReverseComparator implements Comparator {
    public int compare(Object o1, Object o2) {
      // Swap the operands rather than negating the result: negating
      // Integer.MIN_VALUE overflows and would leave the sign unchanged.
      Comparable c = (Comparable) o2;
      return c.compareTo(o1);
    }
  }

  /** A default constructor */
  public HardPairwiseSelector() {
  }

  /**
   * Builds a list of training pairs: first the selected positive (same-class) pairs,
   * then the selected negative (different-class) pairs.
   *
   * Each kind of pair is selected according to its mode:
   * <ul>
   * <li>PAIRS_RANDOM: sample uniformly without replacement. If no more pairs exist
   *     than were requested, all of them are used.</li>
   * <li>PAIRS_EASIEST / PAIRS_HARDEST: walk the pairs sorted by distance and keep the
   *     first distinct ones. Positives use ascending order for "easiest" and
   *     descending order for "hardest"; negatives use the opposite orders.
   *     Assuming TrainingPair sorts by ascending distance, the hardest positives are
   *     the most distant same-class pairs and the hardest negatives are the closest
   *     different-class pairs.</li>
   * <li>PAIRS_INTERVAL: not implemented; prints a TODO and selects nothing.</li>
   * </ul>
   *
   * @param instances data to build the pairs from
   * @param numPosPairs number of positive pairs requested
   * @param numNegPairs number of negative pairs requested
   * @param metric the metric used to compute pair distances
   * @return a list of TrainingPair objects, positives first
   * @exception Exception if a selection mode is unknown or a distance computation fails
   */
  public ArrayList createPairList(Instances instances, int numPosPairs, int numNegPairs, Metric metric) throws Exception {
    ArrayList pairList = new ArrayList();

    // INITIALIZE
    initSelector(instances);
    System.out.println("m_numPotentialPositives=" + m_numPotentialPositives + "\tm_numPotentialNegatives=" + m_numPotentialNegatives);

    // SELECT POSITIVE PAIRS
    switch (m_positivesMode) {
    case PAIRS_EASIEST: {
      TreeSet posPairSet = new TreeSet();
      populatePositivePairSet(metric, posPairSet);
      pairList = getUniquePairs(posPairSet, metric, numPosPairs);
      break;
    }
    case PAIRS_HARDEST: {
      TreeSet posPairSet = new TreeSet(new ReverseComparator());
      populatePositivePairSet(metric, posPairSet);
      pairList = getUniquePairs(posPairSet, metric, numPosPairs);
      break;
    }
    case PAIRS_RANDOM:
      pairList = samplePairs(createAllPositivePairs(metric), numPosPairs);
      break;
    case PAIRS_INTERVAL:
      System.err.println("TODO PAIRS_INTERVAL!!!");
      break;
    default:
      throw new Exception("Unknown method for selecting positive pairs: " + m_positivesMode);
    }
    int numActualPositives = pairList.size();

    // SELECT NEGATIVE PAIRS
    switch (m_negativesMode) {
    case PAIRS_EASIEST: {
      // Easiest negatives come first in descending order.
      TreeSet negPairSet = new TreeSet(new ReverseComparator());
      populateNegativePairSet(metric, negPairSet);
      pairList.addAll(getUniquePairs(negPairSet, metric, numNegPairs));
      break; // must not fall through into PAIRS_HARDEST
    }
    case PAIRS_HARDEST: {
      TreeSet negPairSet = new TreeSet();
      populateNegativePairSet(metric, negPairSet);
      pairList.addAll(getUniquePairs(negPairSet, metric, numNegPairs));
      break;
    }
    case PAIRS_RANDOM:
      pairList.addAll(samplePairs(createAllNegativePairs(metric), numNegPairs));
      break;
    case PAIRS_INTERVAL:
      System.err.println("TODO PAIRS_INTERVAL!!!");
      break;
    default:
      throw new Exception("Unknown method for selecting negative pairs: " + m_negativesMode);
    }
    int numActualNegatives = pairList.size() - numActualPositives;

    System.out.println();
    System.out.println("POSITIVES: requested=" + numPosPairs + "\tpossible=" + m_numPotentialPositives +
                       "\tactual=" + numActualPositives);
    System.out.println("NEGATIVES: requested=" + numNegPairs + "\tpossible=" + m_numPotentialNegatives +
                       "\tactual=" + numActualNegatives);
    return pairList;
  }

  /**
   * Creates every same-class pair: each unordered pair of instances within each
   * list stored in m_classInstanceMap.
   *
   * @param metric the metric used to compute pair distances
   * @return a list of positive TrainingPair objects
   */
  private ArrayList createAllPositivePairs(Metric metric) throws Exception {
    ArrayList pairs = new ArrayList();
    Iterator classIterator = m_classInstanceMap.values().iterator();
    while (classIterator.hasNext()) {
      ArrayList instanceList = (ArrayList) classIterator.next();
      for (int i = 0; i < instanceList.size(); i++) {
        Instance instance1 = (Instance) instanceList.get(i);
        for (int j = i + 1; j < instanceList.size(); j++) {
          Instance instance2 = (Instance) instanceList.get(j);
          pairs.add(new TrainingPair(instance1, instance2, true, metric.distance(instance1, instance2)));
        }
      }
    }
    return pairs;
  }

  /**
   * Creates every different-class pair: each instance of class i paired with each
   * instance of every later class k > i, in m_classValueList order.
   *
   * @param metric the metric used to compute pair distances
   * @return a list of negative TrainingPair objects
   */
  private ArrayList createAllNegativePairs(Metric metric) throws Exception {
    ArrayList pairs = new ArrayList();
    for (int i = 0; i < m_classValueList.size(); i++) {
      ArrayList instanceList1 = (ArrayList) m_classInstanceMap.get(m_classValueList.get(i));
      for (int j = 0; j < instanceList1.size(); j++) {
        Instance instance1 = (Instance) instanceList1.get(j);
        for (int k = i + 1; k < m_classValueList.size(); k++) {
          ArrayList instanceList2 = (ArrayList) m_classInstanceMap.get(m_classValueList.get(k));
          for (int l = 0; l < instanceList2.size(); l++) {
            Instance instance2 = (Instance) instanceList2.get(l);
            pairs.add(new TrainingPair(instance1, instance2, false, metric.distance(instance1, instance2)));
          }
        }
      }
    }
    return pairs;
  }

  /**
   * Picks numPairs candidates uniformly at random, without replacement.
   *
   * @param candidates all available pairs. If the sample is drawn, this list is
   *        reordered in place.
   * @param numPairs number of pairs requested
   * @return candidates itself if it holds no more than numPairs elements,
   *         otherwise a new list with the sampled pairs
   */
  private static ArrayList samplePairs(ArrayList candidates, int numPairs) {
    if (candidates.size() <= numPairs) {
      return candidates;
    }
    ArrayList sample = new ArrayList();
    Random random = new Random();
    int remaining = candidates.size();
    for (int i = 0; i < numPairs; i++) {
      int idx = random.nextInt(remaining);
      sample.add(candidates.get(idx));
      // Move the last unused candidate into the hole (O(1), unlike ArrayList.remove).
      remaining--;
      candidates.set(idx, candidates.get(remaining));
    }
    return sample;
  }

  /** This helper method goes through a TreeSet containing sorted TrainingPairs
   * and returns up to numPairs of them in iteration order. For a LearnableMetric,
   * pairs whose diff instance matches (at float precision) the diff instance of an
   * already accepted pair are skipped as duplicates.
   * @param pairSet a sorted set of TrainingPair's
   * @param metric the metric that is used for creating DiffInstance's
   * @param numPairs the number of desired pairs
   * @return a list with training pairs
   */
  protected ArrayList getUniquePairs(TreeSet pairSet, Metric metric, int numPairs) {
    ArrayList pairList = new ArrayList();
    // checksum -> list of diff instances already accepted with that checksum
    HashMap checksumMap = new HashMap();
    Iterator iterator = pairSet.iterator();
    while (iterator.hasNext() && pairList.size() < numPairs) {
      TrainingPair pair = (TrainingPair) iterator.next();
      if (!(metric instanceof LearnableMetric)) {
        pairList.add(pair);
        continue;
      }
      Instance diffInstance = ((LearnableMetric) metric).createDiffInstance(pair.instance1, pair.instance2);
      Double key = new Double(diffChecksum(diffInstance));
      ArrayList checksumList = (ArrayList) checksumMap.get(key);
      if (checksumList == null) {
        // this checksum has not been encountered before
        checksumList = new ArrayList();
        checksumMap.put(key, checksumList);
      } else {
        System.out.println("Collision for " + key + ": " + checksumList.size());
        if (containsDuplicate(checksumList, diffInstance)) {
          System.out.println("Dupe!");
          continue;
        }
      }
      pairList.add(pair);
      checksumList.add(diffInstance);
    }
    return pairList;
  }

  /**
   * Computes a cheap position-weighted checksum of a diff instance's values. It is
   * rounded to float precision to help with machine precision errors.
   */
  private static double diffChecksum(Instance diffInstance) {
    double checksum = 0;
    for (int j = 0; j < diffInstance.numValues(); j++) {
      checksum += j * 17 * diffInstance.value(j);
    }
    return (float) checksum;
  }

  /**
   * Returns true if some instance in candidates has every value equal (compared
   * as floats) to the corresponding value of diffInstance.
   */
  private static boolean containsDuplicate(ArrayList candidates, Instance diffInstance) {
    for (int k = 0; k < candidates.size(); k++) {
      Instance other = (Instance) candidates.get(k);
      boolean differs = false;
      for (int l = 0; l < other.numValues() && !differs; l++) {
        differs = ((float) other.value(l)) != ((float) diffInstance.value(l));
      }
      if (!differs) {
        return true;
      }
    }
    return false;
  }

  /** Adds a pair to the set, perturbing its value until it does not collide with an
   * element already in the set.
   * @param set a set to which a new pair should be added
   * @param pair a new pair that is to be added; value is the distance between the
   *        instances. The value may be modified in place.
   * @return the unique value of the distance (possibly perturbed) with which the pair was added
   */
  protected double addUniquePair(TreeSet set, TrainingPair pair) {
    Random random = new Random();
    double epsilon = 0.00001;
    // Random.nextInt(bound) rejects bound <= 0, which happens e.g. when there are no
    // potential positives but zero-distance negative pairs collide.
    int zeroBound = Math.max(1, m_numPotentialPositives);
    int counter = 0;
    while (set.contains(pair)) {
      double perturbation;
      if (pair.value == 0) {
        // A zero value cannot be scaled, so shift it by a non-zero multiple of the
        // smallest positive double (the "1 +" guarantees the shift is non-zero).
        perturbation = Double.MIN_VALUE * (1 + random.nextInt(zeroBound));
      } else {
        // Relative jitter with a random sign.
        perturbation = pair.value * epsilon * ((random.nextDouble() > 0.5) ? 1 : -1);
      }
      pair.value += perturbation;
      counter++;
      // Widen the jitter every 10 failed attempts.
      if (counter % 10 == 0) {
        epsilon *= 10;
      }
    }
    set.add(pair);
    return pair.value;
  }

  /** Populate a treeset with all positive TrainingPair's
   * @param metric a metric that will be used to calculate distance
   * @param pairSet an empty set that will be populated
   * @return an array with distance values of the created pairs, sized by
   *         m_numPotentialPositives
   */
  protected double[] populatePositivePairSet(Metric metric, TreeSet pairSet) throws Exception {
    ArrayList pairs = createAllPositivePairs(metric);
    double [] posPairDistances = new double[m_numPotentialPositives];
    for (int i = 0; i < pairs.size(); i++) {
      posPairDistances[i] = addUniquePair(pairSet, (TrainingPair) pairs.get(i));
    }
    return posPairDistances;
  }

  /** Populate a treeset with all negative TrainingPair's
   * @param metric a metric that will be used to calculate distance
   * @param pairSet an empty set that will be populated
   * @return an array with distance values of the created pairs, sized by
   *         m_numPotentialNegatives
   */
  protected double[] populateNegativePairSet(Metric metric, TreeSet pairSet) throws Exception {
    ArrayList pairs = createAllNegativePairs(metric);
    double [] negPairDistances = new double[m_numPotentialNegatives];
    for (int i = 0; i < pairs.size(); i++) {
      negPairDistances[i] = addUniquePair(pairSet, (TrainingPair) pairs.get(i));
    }
    return negPairDistances;
  }

  /** Given a set, return a TreeSet whose items are accessed in descending order
   * @param set any set containing Comparable objects
   * @return a new ordered set with those objects in reverse order
   */
  public TreeSet reverseCopy(Set set) {
    TreeSet reverseSet = new TreeSet(new ReverseComparator());
    reverseSet.addAll(set);
    return reverseSet;
  }

  /** Set the selection mode for positives. Tags from any other tag array are ignored.
   * @param mode selection mode
   */
  public void setPositivesMode(SelectedTag mode) {
    if (mode.getTags() == TAGS_PAIR_SELECTION_MODE) {
      m_positivesMode = mode.getSelectedTag().getID();
    }
  }

  /**
   * return the selection mode for positives
   * @return one of the selection modes
   */
  public SelectedTag getPositivesMode() {
    return new SelectedTag(m_positivesMode, TAGS_PAIR_SELECTION_MODE);
  }

  /** Set the selection mode for negatives. Tags from any other tag array are ignored.
   * @param mode selection mode
   */
  public void setNegativesMode(SelectedTag mode) {
    if (mode.getTags() == TAGS_PAIR_SELECTION_MODE) {
      m_negativesMode = mode.getSelectedTag().getID();
    }
  }

  /**
   * return the selection mode for negatives
   * @return one of the selection modes
   */
  public SelectedTag getNegativesMode() {
    return new SelectedTag(m_negativesMode, TAGS_PAIR_SELECTION_MODE);
  }

  /**
   * Gets the current settings of HardPairwiseSelector: "-P" followed by the
   * positives mode flag, then "-N" followed by the negatives mode flag, padded with
   * empty strings to a length of 5. The flags are -r random, -h hardest,
   * -e easiest and -i interval; no flag is emitted for an unrecognized mode.
   *
   * @return an array of strings suitable for passing to setOptions()
   */
  public String [] getOptions() {
    String [] options = new String [5];
    int current = 0;
    options[current++] = "-P";
    String positivesFlag = modeFlag(m_positivesMode);
    if (positivesFlag != null) {
      options[current++] = positivesFlag;
    }
    options[current++] = "-N";
    String negativesFlag = modeFlag(m_negativesMode);
    if (negativesFlag != null) {
      options[current++] = negativesFlag;
    }
    while (current < options.length) {
      options[current++] = "";
    }
    return options;
  }

  /** Maps a selection mode to its command-line flag, or null if the mode is unknown. */
  private static String modeFlag(int mode) {
    switch (mode) {
    case PAIRS_RANDOM:
      return "-r";
    case PAIRS_HARDEST:
      return "-h";
    case PAIRS_EASIEST:
      return "-e";
    case PAIRS_INTERVAL:
      return "-i";
    default:
      return null;
    }
  }

  /**
   * Parses a given list of options. Currently a no-op: no options are parsed,
   * so the modes keep their current values.
   */
  public void setOptions(String[] options) throws Exception {
  }

  /**
   * Returns an enumeration describing the available options.
   *
   * @return an empty enumeration (no options are described yet).
   */
  public Enumeration listOptions() {
    Vector newVector = new Vector(0);
    return newVector.elements();
  }

  /**
   * Gets an array of numIdxs distinct random indices out of maxIdx possible values
   * (0 .. maxIdx-1). If more indices are requested than exist, returns all maxIdx
   * values in permuted order. The generator is seeded with maxIdx + numIdxs, so the
   * result is deterministic for given arguments.
   * @param numIdxs number of indices to return
   * @param maxIdx the number of possible index values
   * @return an array of indices
   */
  public static int[] randomSubset(int numIdxs, int maxIdx) {
    Random r = new Random(maxIdx + numIdxs);
    int[] indexes = new int[maxIdx];
    for (int i = 0; i < maxIdx; i++) {
      indexes[i] = i;
    }
    int count = Math.min(numIdxs, maxIdx);
    // Partial Fisher-Yates shuffle. Position i is final once step i is done, so only
    // the first `count` steps are needed; later steps never touch returned slots.
    for (int i = 0; i < count; i++) {
      int idx = r.nextInt(maxIdx - i);
      int temp = indexes[i + idx];
      indexes[i + idx] = indexes[i];
      indexes[i] = temp;
    }
    int[] returnIdxs = new int[count];
    System.arraycopy(indexes, 0, returnIdxs, 0, returnIdxs.length);
    return returnIdxs;
  }
}