/*
* Ivory: A Hadoop toolkit for web-scale information retrieval
*
* Licensed under the Apache License, Version 2.0 (the "License"); you
* may not use this file except in compliance with the License. You may
* obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. See the License for the specific language governing
* permissions and limitations under the License.
*/
package ivory.ltr;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
/**
 * Coarse line search over the weight (alpha) of individual features.
 *
 * <p>For each feature, alpha values are tried in a geometric sequence. Each candidate is
 * scored by applying {@code scores.translate(feature, alpha, 1.0)} and evaluating the result
 * with the configured {@link Measure}. The alpha with the highest measure is kept, together
 * with that measure.
 *
 * <p>The class implements {@link Callable}, so a list of features can be searched as a single
 * task that returns a map from each feature to its best {@link AlphaMeasurePair}.
 *
 * @author Don Metzler
 *
 */
public class LineSearch implements Callable<Map<Feature,AlphaMeasurePair>> {
  /** First step size, as a fraction of the model's maximum weight. */
  private static final double SCALE_FACTOR = 0.01;

  /** Number of step sizes tried in each direction. */
  private static final int MAX_STEPS = 5;

  /**
   * Geometric growth factor between successive step sizes.
   *
   * <p>Starting from {@code maxWeight * SCALE_FACTOR} and taking {@code MAX_STEPS} steps, the
   * largest magnitude actually tried is {@code maxWeight * SCALE_FACTOR^(1 / MAX_STEPS)}, which
   * is about 0.4 * maxWeight with the current constants. NOTE(review): if the intent was to
   * reach maxWeight itself, the exponent should be {@code 1.0 / (MAX_STEPS - 1)}. Confirm
   * before changing.
   */
  private static final double MULTIPLIER = Math.pow((1.0 / SCALE_FACTOR), (1.0 / MAX_STEPS));

  private final Model model; // current model
  private final List<Feature> features; // features to search over
  private final ScoreTable scores; // score table (instances, scores)
  private final Measure measure; // evaluation metric

  /**
   * Creates a line-search task over the given features.
   *
   * @param model the current model; its feature count and maximum weight drive the search
   * @param features the features whose weights should be searched
   * @param scoreTable the current scores that each candidate alpha is applied to
   * @param evaluator the metric to maximize
   */
  public LineSearch(Model model, List<Feature> features, ScoreTable scoreTable, Measure evaluator) {
    this.model = model;
    this.features = features;
    this.scores = scoreTable;
    this.measure = evaluator;
  }

  /**
   * Runs {@link #lineSearch} for every feature given at construction time.
   *
   * @return a map from each feature to the best (alpha, measure) pair found for it
   */
  @Override
  public Map<Feature,AlphaMeasurePair> call() throws Exception {
    Map<Feature,AlphaMeasurePair> results = new HashMap<Feature,AlphaMeasurePair>();
    for (Feature f : features) {
      results.put(f, lineSearch(model, f, scores, measure));
    }
    return results;
  }

  /**
   * Searches for the alpha that maximizes {@code measure} when {@code feature} is applied to
   * {@code scores} via {@code scores.translate(feature, alpha, 1.0)}.
   *
   * <p>If the model has no features yet, the feature is evaluated at alpha = 1.0 only.
   * Otherwise the baseline is alpha = 0 (the unchanged scores). Positive alphas are searched
   * first. Negative alphas are searched only if no positive alpha beat the baseline. Each
   * directional search stops early once a step scores below the best measure seen so far.
   * The result is also logged to standard error.
   *
   * @param model the current model
   * @param feature the feature whose weight is being searched
   * @param scores the current score table
   * @param measure the metric to maximize
   * @return the best alpha found and its measure; alpha is 0.0 if nothing beat the baseline
   */
  public static AlphaMeasurePair lineSearch(Model model, Feature feature, ScoreTable scores, Measure measure) {
    if (model.getNumFeatures() == 0) {
      // Empty model: the feature is scored at a fixed unit weight.
      ScoreTable newScoreTable = scores.translate(feature, 1.0, 1.0);
      double m = measure.evaluate(newScoreTable);
      System.err.println("Feature: " + feature.getName() + ", Measure: " + m);
      return new AlphaMeasurePair(1.0, m);
    }

    // Baseline: alpha = 0, evaluated on the untranslated scores.
    AlphaMeasurePair best = new AlphaMeasurePair(0.0, measure.evaluate(scores));
    // Presumably getMaxWeight() is positive, so sign +1.0 means "positive alphas". TODO confirm.
    double initialStep = model.getMaxWeight() * SCALE_FACTOR;

    searchDirection(feature, scores, measure, initialStep, 1.0, best);
    if (best.alpha == 0.0) {
      // No positive alpha improved on the baseline; try the negative direction.
      searchDirection(feature, scores, measure, initialStep, -1.0, best);
    }

    System.err.println("Feature: " + feature.getName() + ", Measure: " + best.measure);
    return best;
  }

  /**
   * Tries {@code MAX_STEPS} geometrically growing alphas in one direction and updates
   * {@code best} in place whenever a step strictly improves on its measure.
   *
   * <p>The walk stops as soon as a step scores below the best measure seen so far.
   *
   * @param feature the feature whose weight is being searched
   * @param scores the current score table
   * @param measure the metric to maximize
   * @param initialStep magnitude of the first step
   * @param sign 1.0 to search positive alphas, -1.0 to search negative alphas
   * @param best the running best pair; mutated in place
   */
  private static void searchDirection(Feature feature, ScoreTable scores, Measure measure,
      double initialStep, double sign, AlphaMeasurePair best) {
    double step = initialStep;
    for (int iter = 0; iter < MAX_STEPS; iter++) {
      double alpha = sign * step;
      double m = measure.evaluate(scores.translate(feature, alpha, 1.0));
      // Compare against the best measure (the original compared against best.alpha by mistake).
      if (m < best.measure) {
        break;
      }
      if (m > best.measure) {
        best.alpha = alpha;
        best.measure = m;
      }
      step *= MULTIPLIER;
    }
  }
}