/* Copyright 2006, Carnegie Mellon, All Rights Reserved */
package edu.cmu.minorthird.classify.ranking;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import edu.cmu.minorthird.classify.BatchBinaryClassifierLearner;
import edu.cmu.minorthird.classify.BinaryClassifier;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.Instance;
/**
* Learns to rank sets of examples, rather than to classify individual examples.
*
* Data is presented to a BatchRankingLearner as an ordinary binary
* dataset. Examples from the same subpopulation are comparable, and
* should be ranked so that positive examples have a higher score than
* negative examples.
*
* @author William Cohen
*/
public abstract class BatchRankingLearner extends BatchBinaryClassifierLearner{

	/**
	 * Sorts a dataset into 'rankings'. Each ranking is a List of Examples
	 * that share the same subpopulation id (as returned by
	 * {@link Example#getSubpopulationId()}); all positive examples in a
	 * ranking should be ranked above all negative examples. Within each
	 * ranking, examples appear in the order the dataset's iterator
	 * produced them.
	 *
	 * @param data the dataset to split
	 * @return a map such that map.get(key) is an ArrayList of the examples
	 *         whose subpopulation id is key
	 */
	static public Map<String,List<Example>> splitIntoRankings(Dataset data){
		Map<String,List<Example>> map=new HashMap<String,List<Example>>();
		for(Iterator<Example> i=data.iterator();i.hasNext();){
			Example ex=i.next();
			String id=ex.getSubpopulationId();
			List<Example> list=map.get(id);
			if(list==null){
				list=new ArrayList<Example>();
				map.put(id,list);
			}
			list.add(ex);
		}
		return map;
	}

	/**
	 * Splits a Map output by {@link #splitIntoRankings(Dataset)} into lists
	 * that contain exactly one positive example each.
	 * <p>
	 * For every input ranking under key k, and for the j-th positive example
	 * p of that ranking (counting from 0), the result maps
	 * {@code k+"."+j} to a new ArrayList holding all of that ranking's
	 * negative examples (in their original order) followed by p. Rankings
	 * with no positive examples contribute nothing to the result.
	 *
	 * @param rankingLists map from ranking key to the examples in that ranking
	 * @return a new map from derived keys to single-positive rankings
	 */
	static public Map<String,List<Example>> listsWithOneExampleEach(Map<String,List<Example>> rankingLists){
		Map<String,List<Example>> map1=new HashMap<String,List<Example>>();
		for(Map.Entry<String,List<Example>> entry:rankingLists.entrySet()){
			String key=entry.getKey();
			List<Example> posExamples=new ArrayList<Example>();
			List<Example> negExamples=new ArrayList<Example>();
			// enhanced for avoids O(n) get(j) calls on non-random-access lists
			for(Example ex:entry.getValue()){
				if(ex.getLabel().isPositive()){
					posExamples.add(ex);
				}else{
					negExamples.add(ex);
				}
			}
			for(int j=0;j<posExamples.size();j++){
				List<Example> ranking1=new ArrayList<Example>(negExamples.size()+1);
				ranking1.addAll(negExamples);
				ranking1.add(posExamples.get(j));
				map1.put(key+"."+j,ranking1);
			}
		}
		return map1;
	}

	/**
	 * Sorts a List of Examples in place by decreasing score according to the
	 * classifier.
	 * <p>
	 * Each example is scored exactly once, and the list is then sorted on the
	 * cached scores. Scores are compared with {@link Double#compare}, so the
	 * ordering is a valid total order even if the classifier returns NaN
	 * (NaN is treated as greater than any other score, so NaN-scored
	 * examples come first). When two examples have the same score, the one
	 * with the smaller numeric label (i.e. the negative example) is placed
	 * first rather than in arbitrary order. The sort is stable, so examples
	 * tied on both score and label keep their original relative order.
	 *
	 * @param c the classifier used to score the examples
	 * @param data the examples to sort; its list iterator must support set
	 *             (the same requirement as {@link Collections#sort(List)})
	 */
	static public void sortByScore(final BinaryClassifier c,List<Example> data){
		List<ScoredExample> scored=new ArrayList<ScoredExample>(data.size());
		for(Example ex:data){
			scored.add(new ScoredExample(ex,c.score(ex)));
		}
		Collections.sort(scored,new Comparator<ScoredExample>(){
			@Override
			public int compare(ScoredExample a,ScoredExample b){
				// higher scores first; Double.compare keeps this transitive for NaN
				int cmp=Double.compare(b.score,a.score);
				if(cmp!=0){
					return cmp;
				}
				// rather than be random, sort negative examples
				// above positive if scores are the same
				return Double.compare(a.example.getLabel().numericLabel(),b.example
						.getLabel().numericLabel());
			}
		});
		// write the sorted order back through the list iterator, which is
		// O(n) for both random-access and linked lists
		ListIterator<Example> it=data.listIterator();
		for(ScoredExample se:scored){
			it.next();
			it.set(se.example);
		}
	}

	/** An example paired with its precomputed classifier score. */
	private static final class ScoredExample{

		final Example example;

		final double score;

		ScoredExample(Example example,double score){
			this.example=example;
			this.score=score;
		}
	}
}