/*
* Ivory: A Hadoop toolkit for web-scale information retrieval
*
* Licensed under the Apache License, Version 2.0 (the "License"); you
* may not use this file except in compliance with the License. You may
* obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. See the License for the specific language governing
* permissions and limitations under the License.
*/
package ivory.ltr;
import ivory.core.ConfigurationException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.CommandLineParser;
import org.apache.commons.cli.GnuParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.OptionBuilder;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
/**
 * Greedy forward-selection trainer for ranking models.
 *
 * <p>Each round line-searches every candidate feature against every model
 * currently being tracked, keeps the best (model, feature) pairs, extends
 * those models with the chosen feature, and stops once the best measure no
 * longer improves by more than {@link #TOLERANCE}.
 *
 * @author Don Metzler
 *
 */
public class GreedyLearn {

  /** Minimum round-over-round improvement of the best measure needed to keep going. */
  private static final double TOLERANCE = 0.0001;

  /**
   * Trains a model by greedily adding features and writes the result to disk.
   *
   * @param featFile path handed to {@code Instances} to load the training data
   * @param modelOutputFile path the final model is written to
   * @param numModels number of best (model, feature) pairs carried into the next
   *        round; must be at least 1
   * @param metricClassName fully qualified name of the {@code Measure}
   *        implementation to instantiate reflectively for every line search
   * @param pruneCorrelated whether to drop pool features whose correlation with
   *        the chosen feature exceeds {@code correlationThreshold}
   * @param correlationThreshold correlation above which a feature is pruned
   * @param logFeatures whether to add a log feature derived from each chosen
   *        atomic feature
   * @param productFeatures whether to add product features pairing each chosen
   *        atomic feature with every feature in the feature map
   * @param quotientFeatures whether to add quotient features pairing each chosen
   *        atomic feature with every feature in the feature map
   * @param numThreads number of worker threads (and feature batches per model);
   *        must be at least 1
   * @throws IllegalArgumentException if {@code numModels} or {@code numThreads}
   *         is smaller than 1
   * @throws ClassNotFoundException if {@code metricClassName} cannot be loaded
   * @throws InstantiationException if the metric class cannot be instantiated
   * @throws IllegalAccessException if the metric class constructor is not accessible
   * @throws ExecutionException if a line-search task fails
   * @throws InterruptedException if interrupted while waiting for a line-search task
   * @throws IOException propagated from the project I/O calls made here
   * @throws ConfigurationException propagated from the project calls made here
   */
  public void train(String featFile, String modelOutputFile, int numModels, String metricClassName,
      boolean pruneCorrelated, double correlationThreshold, boolean logFeatures,
      boolean productFeatures, boolean quotientFeatures, int numThreads)
      throws IOException, InterruptedException, ExecutionException, ConfigurationException,
      InstantiationException, IllegalAccessException, ClassNotFoundException {
    // Fail fast: without these checks a non-positive numModels only surfaced
    // as an IndexOutOfBoundsException in the second round.
    if (numModels < 1) {
      throw new IllegalArgumentException("numModels must be at least 1, got " + numModels);
    }
    if (numThreads < 1) {
      throw new IllegalArgumentException("numThreads must be at least 1, got " + numThreads);
    }

    // read training instances
    Instances trainInstances = new Instances(featFile);

    // get feature map (mapping from feature names to feature number)
    Map<String,Integer> featureMap = trainInstances.getFeatureMap();

    // construct initial model
    Model initialModel = new Model();

    // initialize feature pools (one pool of candidate features per tracked model)
    Map<Model,ArrayList<Feature>> featurePool = new HashMap<Model,ArrayList<Feature>>();
    ArrayList<Feature> initialPool = new ArrayList<Feature>();
    featurePool.put(initialModel, initialPool);

    // add simple features to feature pools
    for (Map.Entry<String,Integer> entry : featureMap.entrySet()) {
      initialPool.add(new SimpleFeature(entry.getValue(), entry.getKey()));
    }

    // eliminate document-independent features; iterate over the pool itself
    // (the enclosing map only ever has one entry at this point)
    List<Feature> constantFeatures = new ArrayList<Feature>();
    for (Feature f : initialPool) {
      if (trainInstances.featureIsConstant(f)) {
        System.err.println("Feature " + f.getName() + " is constant -- removing from feature pool!");
        constantFeatures.add(f);
      }
    }
    initialPool.removeAll(constantFeatures);

    // initialize score tables
    Map<Model,ScoreTable> scoreTable = new HashMap<Model,ScoreTable>();
    scoreTable.put(initialModel, new ScoreTable(trainInstances));

    // initialize model queue
    List<Model> models = new ArrayList<Model>();
    models.add(initialModel);

    // set up threading
    ExecutorService threadPool = Executors.newFixedThreadPool(numThreads);
    try {
      // split the pool round-robin into one batch per worker thread
      Map<Model,ArrayList<ArrayList<Feature>>> featureBatches = new HashMap<Model,ArrayList<ArrayList<Feature>>>();
      ArrayList<ArrayList<Feature>> initialBatches = new ArrayList<ArrayList<Feature>>();
      featureBatches.put(initialModel, initialBatches);
      for (int i = 0; i < numThreads; i++) {
        initialBatches.add(new ArrayList<Feature>());
      }
      for (int i = 0; i < initialPool.size(); i++) {
        initialBatches.get(i % numThreads).add(initialPool.get(i));
      }

      // greedily add features
      double curMetric = 0.0;
      double prevMetric = Double.NEGATIVE_INFINITY;
      int iter = 1;
      while (curMetric - prevMetric > TOLERANCE) {
        Map<ModelFeaturePair,AlphaMeasurePair> modelFeaturePairMeasures = new HashMap<ModelFeaturePair,AlphaMeasurePair>();

        // update models: one line-search task per (model, batch)
        for (Model model : models) {
          List<Future<Map<Feature,AlphaMeasurePair>>> futures = new ArrayList<Future<Map<Feature,AlphaMeasurePair>>>();
          for (int i = 0; i < numThreads; i++) {
            // construct measure (a fresh instance per task)
            Measure metric = (Measure) Class.forName(metricClassName).newInstance();

            // line searcher
            LineSearch search = new LineSearch(model, featureBatches.get(model).get(i), scoreTable.get(model), metric);
            Future<Map<Feature,AlphaMeasurePair>> future = threadPool.submit(search);
            futures.add(future);
          }

          for (Future<Map<Feature,AlphaMeasurePair>> pending : futures) {
            Map<Feature,AlphaMeasurePair> featAlphaMeasureMap = pending.get();
            for (Map.Entry<Feature,AlphaMeasurePair> result : featAlphaMeasureMap.entrySet()) {
              modelFeaturePairMeasures.put(new ModelFeaturePair(model, result.getKey()), result.getValue());
            }
          }
        }

        // sort model-feature pairs, best measure first
        List<ModelFeaturePair> modelFeaturePairs = new ArrayList<ModelFeaturePair>(modelFeaturePairMeasures.keySet());
        Collections.sort(modelFeaturePairs, new ModelFeatureComparator(modelFeaturePairMeasures));

        // Nothing was evaluated (every pool is empty): keep the models we
        // already have instead of indexing into an empty candidate list.
        if (modelFeaturePairs.isEmpty()) {
          System.err.println("No candidate features left to evaluate -- stopping.");
          break;
        }

        // preserve current list of models
        List<Model> oldModels = new ArrayList<Model>(models);

        // add best model feature pairs to pool
        models = new ArrayList<Model>();

        //Lidan: here consider top-K features, rather than just the best one
        // (capped at the number of candidates actually available)
        int numSelected = Math.min(numModels, modelFeaturePairs.size());
        for (int i = 0; i < numSelected; i++) {
          ModelFeaturePair selected = modelFeaturePairs.get(i);
          Model model = selected.model;
          Feature feature = selected.feature;
          String bestFeatureName = feature.getName();
          AlphaMeasurePair bestAlphaMeasure = modelFeaturePairMeasures.get(selected);

          System.err.println("Model = " + model);
          System.err.println("Best feature: " + bestFeatureName);
          System.err.println("Best alpha: " + bestAlphaMeasure.alpha);
          System.err.println("Best measure: " + bestAlphaMeasure.measure);

          Model newModel = new Model(model);
          models.add(newModel);

          // the new model gets its own copies of the parent's batches and pool
          ArrayList<ArrayList<Feature>> newFeatureBatch = new ArrayList<ArrayList<Feature>>();
          for (ArrayList<Feature> fb : featureBatches.get(model)) {
            newFeatureBatch.add(new ArrayList<Feature>(fb));
          }
          featureBatches.put(newModel, newFeatureBatch);
          ArrayList<Feature> newPool = new ArrayList<Feature>(featurePool.get(model));
          featurePool.put(newModel, newPool);

          // add auxiliary features (for atomic features only)
          if (featureMap.containsKey(bestFeatureName)) {
            int bestFeatureIndex = featureMap.get(bestFeatureName);

            // add log features, if requested
            if (logFeatures) {
              Feature logFeature = new LogFeature(bestFeatureIndex, "log(" + bestFeatureName + ")");
              newFeatureBatch.get(bestFeatureIndex % numThreads).add(logFeature);
              newPool.add(logFeature);
            }

            // add product features, if requested
            if (productFeatures) {
              for (Map.Entry<String,Integer> other : featureMap.entrySet()) {
                String featureNameB = other.getKey();
                int indexB = other.getValue();
                Feature prodFeature = new ProductFeature(bestFeatureIndex, indexB, bestFeatureName + "*" + featureNameB);
                newFeatureBatch.get(indexB % numThreads).add(prodFeature);
                newPool.add(prodFeature);
              }
            }

            // add quotient features, if requested
            if (quotientFeatures) {
              for (Map.Entry<String,Integer> other : featureMap.entrySet()) {
                String featureNameB = other.getKey();
                int indexB = other.getValue();
                Feature divFeature = new QuotientFeature(bestFeatureIndex, indexB, bestFeatureName + "/" + featureNameB);
                newFeatureBatch.get(indexB % numThreads).add(divFeature);
                newPool.add(divFeature);
              }
            }
          }

          // prune highly correlated features (only the first time a feature
          // enters this model; the feature itself is never pruned)
          if (pruneCorrelated) {
            if (!newModel.containsFeature(feature)) {
              List<Feature> correlatedFeatures = new ArrayList<Feature>();
              for (Feature f : newPool) {
                if (f == feature) {
                  continue;
                }

                double correl = trainInstances.getCorrelation(f, feature);
                if (correl > correlationThreshold) {
                  System.err.println("Pruning highly correlated feature: " + f.getName());
                  correlatedFeatures.add(f);
                }
              }

              for (ArrayList<Feature> batch : newFeatureBatch) {
                batch.removeAll(correlatedFeatures);
              }
              newPool.removeAll(correlatedFeatures);
            }
          }

          // update score table
          // NOTE(review): iter starts at 1 and only ever increases, so the
          // (iter == 0) branch below never executes and every feature is
          // added with its line-searched alpha. Left untouched because it is
          // unclear whether the unit-weight first step was disabled on
          // purpose -- confirm before changing either side.
          if (iter == 0) {
            scoreTable.put(newModel, scoreTable.get(model).translate(feature, 1.0, 1.0));
            newModel.addFeature(feature, 1.0);
          }
          else {
            scoreTable.put(newModel, scoreTable.get(model).translate(feature, bestAlphaMeasure.alpha, 1.0 / (1.0 + bestAlphaMeasure.alpha)));
            newModel.addFeature(feature, bestAlphaMeasure.alpha);
          }
        }

        // the previous generation is no longer needed
        for (Model model : oldModels) {
          featurePool.remove(model);
          featureBatches.remove(model);
          scoreTable.remove(model);
        }

        // update metrics
        prevMetric = curMetric;
        curMetric = modelFeaturePairMeasures.get(modelFeaturePairs.get(0)).measure;
        iter++;
      }

      // serialize model
      System.out.println("Final Model: " + models.get(0));
      models.get(0).write(modelOutputFile);
    } finally {
      // Always release the workers: they are non-daemon threads, so leaving
      // the pool running after a failure would keep the JVM alive.
      threadPool.shutdownNow();
    }
  }

  /**
   * A candidate step: a model together with the feature proposed for it.
   */
  public class ModelFeaturePair {
    public Model model;
    public Feature feature;

    public ModelFeaturePair(Model m, Feature f) {
      model = m;
      feature = f;
    }
  }

  /**
   * Orders {@link ModelFeaturePair}s by descending measure, looked up in the
   * map supplied at construction. Pairs whose measure is NaN sort after all
   * others so the ordering stays consistent and a NaN is never ranked best.
   */
  public class ModelFeatureComparator implements Comparator<ModelFeaturePair> {
    private final Map<ModelFeaturePair,AlphaMeasurePair> lookup;

    public ModelFeatureComparator(Map<ModelFeaturePair,AlphaMeasurePair> lookup) {
      this.lookup = lookup;
    }

    public int compare(ModelFeaturePair o1, ModelFeaturePair o2) {
      double first = lookup.get(o1).measure;
      double second = lookup.get(o2).measure;

      // Plain < / > comparisons are all false for NaN, which would make NaN
      // "equal" to everything and break the comparator contract.
      boolean firstNaN = Double.isNaN(first);
      boolean secondNaN = Double.isNaN(second);
      if (firstNaN || secondNaN) {
        if (firstNaN && secondNaN) {
          return 0;
        }
        return firstNaN ? 1 : -1;
      }

      if (first > second) {
        return -1;
      }
      else if (first < second) {
        return 1;
      }
      return 0;
    }
  }

  /**
   * Command-line entry point: parses the options, then trains and writes a model.
   *
   * <p>Prints usage and exits with status -1 when the arguments are missing or
   * invalid; exits with status 1 when training fails.
   *
   * @param args command-line arguments
   * @throws InterruptedException if training is interrupted
   * @throws ExecutionException if a line-search task fails
   */
  @SuppressWarnings("static-access")
  public static void main(String[] args) throws InterruptedException, ExecutionException {
    Options options = new Options();
    options.addOption( OptionBuilder.withArgName("input").hasArg().withDescription("Input file that contains training instances.").isRequired().create("input") );
    options.addOption( OptionBuilder.withArgName("model").hasArg().withDescription("Model file to create.").isRequired().create("model") );
    options.addOption( OptionBuilder.withArgName("numModels").hasArg().withDescription("Number of models to consider each iteration (default=1).").create("numModels") );
    options.addOption( OptionBuilder.withArgName("className").hasArg().withDescription("Java class name of metric to optimize for (default=ivory.ltr.NDCGMeasure)").create("metric") );
    options.addOption( OptionBuilder.withArgName("threshold").hasArg().withDescription("Feature correlation threshold for pruning (disabled by default).").create("pruneCorrelated") );
    options.addOption( OptionBuilder.withArgName("log").withDescription("Include log features (default=false).").create("log") );
    options.addOption( OptionBuilder.withArgName("product").withDescription("Include product features (default=false).").create("product") );
    options.addOption( OptionBuilder.withArgName("quotient").withDescription("Include quotient features (default=false).").create("quotient") );
    options.addOption( OptionBuilder.withArgName("numThreads").hasArg().withDescription("Number of threads to utilize (default=1).").create("numThreads") );

    HelpFormatter formatter = new HelpFormatter();
    CommandLineParser parser = new GnuParser();

    String trainFile = null;
    String modelOutputFile = null;
    int numModels = 1;
    String metricClassName = "ivory.ltr.NDCGMeasure";
    boolean pruneCorrelated = false;
    double correlationThreshold = 1.0;
    boolean logFeatures = false;
    boolean productFeatures = false;
    boolean quotientFeatures = false;
    int numThreads = 1;
    boolean badArguments = false;

    // parse the command-line arguments
    try {
      CommandLine line = parser.parse(options, args);
      if (line.hasOption("input")) {
        trainFile = line.getOptionValue("input");
      }
      if (line.hasOption("model")) {
        modelOutputFile = line.getOptionValue("model");
      }
      if (line.hasOption("numModels")) {
        numModels = Integer.parseInt(line.getOptionValue("numModels"));
      }
      if (line.hasOption("metric")) {
        metricClassName = line.getOptionValue("metric");
      }
      if (line.hasOption("pruneCorrelated")) {
        pruneCorrelated = true;
        correlationThreshold = Double.parseDouble(line.getOptionValue("pruneCorrelated"));
      }
      if (line.hasOption("numThreads")) {
        numThreads = Integer.parseInt(line.getOptionValue("numThreads"));
      }
      if (line.hasOption("log")) {
        logFeatures = true;
      }
      if (line.hasOption("product")) {
        productFeatures = true;
      }
      if (line.hasOption("quotient")) {
        quotientFeatures = true;
      }
    }
    catch (ParseException exp) {
      System.err.println(exp.getMessage());
      badArguments = true;
    }
    catch (NumberFormatException exp) {
      // a malformed number is a usage error, not a crash
      System.err.println("Invalid numeric option value: " + exp.getMessage());
      badArguments = true;
    }

    if (!badArguments && (numModels < 1 || numThreads < 1)) {
      System.err.println("numModels and numThreads must both be at least 1.");
      badArguments = true;
    }

    // were all of the required parameters specified?
    if (badArguments || trainFile == null || modelOutputFile == null) {
      formatter.printHelp("GreedyLearn", options, true);
      System.exit(-1);
    }

    // learn the model
    Exception failure = null;
    try {
      GreedyLearn learn = new GreedyLearn();
      learn.train(trainFile, modelOutputFile, numModels, metricClassName, pruneCorrelated, correlationThreshold, logFeatures, productFeatures, quotientFeatures, numThreads);
    } catch (IOException e) {
      failure = e;
    } catch (ConfigurationException e) {
      failure = e;
    } catch (InstantiationException e) {
      failure = e;
    } catch (IllegalAccessException e) {
      failure = e;
    } catch (ClassNotFoundException e) {
      failure = e;
    }

    // report the failure and make it visible to the calling shell
    if (failure != null) {
      failure.printStackTrace();
      System.exit(1);
    }
  }
}