/* Copyright 2006, Carnegie Mellon, All Rights Reserved */
package edu.cmu.minorthird.classify.ranking;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.algorithms.linear.Hyperplane;
import edu.cmu.minorthird.util.ProgressCounter;
/**
* A ranking method based on a voted perceptron.
*/
public class RankingPerceptron extends BatchRankingLearner
{
private int numEpochs;
private static final double MARGIN = 0.1;
public RankingPerceptron()
{
this(100);
}
public RankingPerceptron(int numEpochs)
{
this.numEpochs=numEpochs;
}
@Override
public Classifier batchTrain(Dataset data)
{
Hyperplane h = new Hyperplane();
Hyperplane s = new Hyperplane();
int numUpdates = 0;
Map<String,List<Example>> rankingMap = listsWithOneExampleEach( splitIntoRankings(data) );
//Map rankingMap = splitIntoRankings(data);
ProgressCounter pc = new ProgressCounter("perceptron training", "epoch", numEpochs);
for (int e=0; e<numEpochs; e++) {
//System.out.println("epoch "+e+"/"+numEpochs);
for (Iterator<String> i=rankingMap.keySet().iterator(); i.hasNext(); ) {
String subpop = i.next();
List<Example> ranking = rankingMap.get(subpop);
numUpdates += batchTrainSubPop( h, s, ranking );
}
pc.progress();
}
pc.finished();
// turn sum hyperplane into an average
s.multiply( 1.0/(numUpdates) );
//new ViewerFrame("hyperplane", s.toGUI());
return s;
}
// return the number of times h has been updated
private int batchTrainSubPop( Hyperplane h, Hyperplane s, List<Example> ranking )
{
sortByScore(h,ranking);
int updates = 0;
// int highestNegativeIndex = ranking.size();
Example highestNegativeExample = null;
for (int i=0; i<ranking.size(); i++) {
Example exi = ranking.get(i);
if (exi.getLabel().isNegative()) {
// highestNegativeIndex = i;
highestNegativeExample = ranking.get(i);
break;
}
}
// look for positive example, update
for (int i=0; i<ranking.size(); i++) {
Example exi = ranking.get(i);
if (exi.getLabel().isPositive()) {
if (highestNegativeExample!=null && (h.score(exi) < h.score(highestNegativeExample)+MARGIN)) {
//if (i>highestNegativeIndex) {
// the positive example is ranked below the
// highestNegativeExample, which is incorrect
Example pos = ranking.get(i);
h.increment( highestNegativeExample, -1.0);
h.increment( pos, +1.0 );
}
s.increment( h );
updates++;
}
}
return updates;
}
}