package de.tud.inf.operator.learner.regressionensemble;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map.Entry;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.Attributes;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleReader;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.AttributeValueFilter;
import com.rapidminer.example.set.Condition;
import com.rapidminer.example.set.ConditionedExampleSet;
import com.rapidminer.example.set.SortedExampleReader;
import com.rapidminer.example.table.AttributeFactory;
import com.rapidminer.example.table.ExampleTable;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.Learner;
import com.rapidminer.operator.learner.LearnerCapability;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.learner.meta.AbstractMetaLearner;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeFile;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.parameter.ParameterTypeStringCategory;
import de.tud.inf.support.TreeMultiMap;
public class BatchEnsembleRegression extends AbstractMetaLearner {
/** Parameter key: maximum number of members to select into the ensemble. */
private static final String ENSEMBLE_MAX_MEMBERS = "maximum members";
/** Parameter key: local prediction error threshold below which a prediction counts as a hit. */
private static final String ENSEMBLE_LOCAL_THRESHOLD = "local threshold";
/** Parameter key: name of the similarity measure (one of {@link ENSEMBLE_SIMILARITY_MEASURES}). */
private static final String ENSEMBLE_SIMILARITY_MEASURE = "similarity measure";
/** Parameter key: test candidates on their training window (true) or on a held-out tail (false). */
private static final String ENSEMBLE_SLIDING_TEST = "test on training window";
/** Parameter key: percentage of examples held out for testing when not testing on the training window. */
private static final String ENSEMBLE_LEAVE_OUT_SIZE = "percentage of test examples";
/** Parameter key: file the trained ensemble is written to. */
private static final String ENSEMBLE_STATE_FILE = "state file";
/** Parameter key: keep every candidate and select afterwards (true) or prune while scanning (false). */
private static final String ENSEMBLE_FULL_MATERIALIZED = "materialize full";
/** Parameter key: weight of the window length penalty term in the candidate score. */
private static final String ENSEMBLE_PENALTY_WEIGHT = "penalty weight";
/** Parameter key: step between the start positions of successively tested windows. */
private static final String ENSEMBLE_SAMPLING_INTERVAL = "sampling interval";
/** Parameter key: build the candidates in worker threads. */
private static final String ENSEMBLE_USE_THREADS = "use threads";
/** Shared result holder for the worker threads; set by createEnsembleThreaded and reset to null when it finishes. */
private MemberGatherer gatherer = null;
/** Monitor that the worker threads synchronize on while updating the gatherer's maps. */
private Object lock = new Object();
/**
 * Creates the operator; all set-up is delegated to the super class.
 *
 * @param description operator description handed through to the super constructor
 */
public BatchEnsembleRegression(OperatorDescription description) {
super(description);
}
/**
 * Builds the regression ensemble for the given example set, writes it to the
 * configured state file and returns it.
 *
 * <p>The inner operator (index 0) is used as base learner. Depending on the
 * "use threads" and "materialize full" parameters the work is delegated to
 * {@code createEnsembleThreaded}, {@code createEnsembleTotalMaterialized} or
 * {@code createEnsembleScanMaterialized}.
 *
 * @param exampleSet the training data
 * @return the trained ensemble model
 * @throws OperatorException if the similarity measure is unknown, ensemble
 *         construction fails, the calling thread is interrupted while waiting
 *         for workers, or the state file cannot be written (as a UserError)
 */
public Model learn(ExampleSet exampleSet) throws OperatorException {
    Operator learner = this.getOperator(0);
    File state_file = getParameterAsFile(ENSEMBLE_STATE_FILE);
    int max_members = getParameterAsInt(ENSEMBLE_MAX_MEMBERS);
    int sampling_interval = getParameterAsInt(ENSEMBLE_SAMPLING_INTERVAL);
    int leaf_out_size = getParameterAsInt(ENSEMBLE_LEAVE_OUT_SIZE);
    double local_threshold = getParameterAsDouble(ENSEMBLE_LOCAL_THRESHOLD);
    double penalty_weight = getParameterAsDouble(ENSEMBLE_PENALTY_WEIGHT);
    boolean sliding_test = getParameterAsBoolean(ENSEMBLE_SLIDING_TEST);
    boolean materialize_full = getParameterAsBoolean(ENSEMBLE_FULL_MATERIALIZED);
    boolean use_threads = getParameterAsBoolean(ENSEMBLE_USE_THREADS);
    String similarity_measure_type = getParameterAsString(ENSEMBLE_SIMILARITY_MEASURE);

    // Resolve the distance function. Fail right away instead of carrying a
    // null distance into the (potentially long) training phase.
    ENSEMBLE_SIMILARITY_MEASURES measure;
    try {
        measure = ENSEMBLE_SIMILARITY_MEASURES.valueOf(similarity_measure_type);
    } catch (IllegalArgumentException iae) {
        logError("Unknown distance function");
        throw new OperatorException("Unknown distance function: " + similarity_measure_type, iae);
    }
    Distance distance = null;
    switch (measure) {
    case EuclideanDistance:
        distance = new L2Norm();
        break;
    default:
        logError("Unknown distance function");
        throw new OperatorException("Unknown distance function: " + similarity_measure_type);
    }

    EnsembleRegressionModel ensemble;
    if (use_threads) {
        try {
            ensemble = createEnsembleThreaded(exampleSet, learner, distance, max_members, local_threshold, penalty_weight, sliding_test, leaf_out_size, sampling_interval);
        } catch (InterruptedException ie) {
            // Restore the interrupt flag so that callers further up can still see it.
            Thread.currentThread().interrupt();
            throw new OperatorException("Threading Problem", ie);
        }
    } else if (materialize_full) {
        ensemble = createEnsembleTotalMaterialized(exampleSet, learner, distance, max_members, local_threshold, penalty_weight, sliding_test, leaf_out_size, sampling_interval);
    } else {
        ensemble = createEnsembleScanMaterialized(exampleSet, learner, distance, max_members, local_threshold, penalty_weight, sliding_test, leaf_out_size, sampling_interval);
    }
    log("Back in main");

    // save the ensemble
    OutputStream out = null;
    try {
        out = new FileOutputStream(state_file);
        if (ensemble != null) {
            ensemble.write(out);
        }
    } catch (IOException e) {
        throw new UserError(this, e, 303, new Object[] { state_file, e.getMessage() });
    } finally {
        if (out != null) {
            try {
                out.close();
            } catch (IOException e) {
                logError("Cannot close stream to file " + state_file);
            }
        }
    }
    log("Write done");
    return ensemble;
}
/**
 * Holder for the state that the worker threads fill: the current candidate
 * members, their score index and a flag recording that a worker failed.
 * The class performs no locking itself.
 */
protected class MemberGatherer {
    /** Candidate members keyed by an integer index chosen by the caller. */
    HashMap<Integer, EnsembleMember> candidates;
    /** Score of each candidate mapped to the candidate's index. */
    TreeMultiMap<Double, Integer> ratios;
    /** Becomes true once {@link #flagException()} has been called. */
    boolean exceptionOccured = false;

    /**
     * @param candidates map that receives the candidate members
     * @param ratios     multi map that receives the candidate scores
     */
    public MemberGatherer(HashMap<Integer, EnsembleMember> candidates, TreeMultiMap<Double, Integer> ratios) {
        this.ratios = ratios;
        this.candidates = candidates;
    }

    /** Records that a failure happened. */
    public void flagException() {
        this.exceptionOccured = true;
    }

    /** @return true if {@link #flagException()} has been called */
    public boolean getExceptionOccured() {
        return this.exceptionOccured;
    }

    /** @return the candidate map handed to the constructor */
    public HashMap<Integer, EnsembleMember> getCandidates() {
        return this.candidates;
    }

    /** @return the score multi map handed to the constructor */
    public TreeMultiMap<Double, Integer> getRatios() {
        return this.ratios;
    }
}
/**
 * Worker that builds one candidate ensemble member: it trains the learner on
 * the window that starts at the id of the given round and ends at
 * {@code trainingEndId}, counts how many predictions fall below the local
 * threshold, and offers the scored member to the parent's gatherer.
 *
 * <p>Any failure is reported through the gatherer's exception flag, because
 * {@link #run()} cannot throw checked exceptions.
 *
 * <p>NOTE(review): all workers are handed the same {@code learner}, the same
 * {@code exampleSet} and the same {@code predictedLabel} attribute. Whether
 * concurrent {@code learn}/{@code performPrediction} calls on these shared
 * objects are safe cannot be determined from this file - confirm.
 */
protected class EstimateAndTest implements Runnable {
    protected ExampleSet exampleSet;
    protected Distance distance;
    protected ArrayList<Integer> ids;
    protected Learner learner;
    protected Attribute idAttribute;
    protected Attribute predictedLabel;
    protected int round;
    protected int trainingEndId;
    protected boolean sliding_test;
    protected int max_members;
    protected double local_threshold;
    protected double penalty_weight;
    protected int leaf_out_size;
    protected int sampling_interval;
    protected int testBeginId;
    protected int maxId;
    protected int size;
    protected BatchEnsembleRegression parent;

    /**
     * @param exampleSet        data the windows are cut from
     * @param distance          distance between label and prediction
     * @param ids               example ids; {@code ids.get(round - 1)} is the window start
     * @param learner           base learner used to train the member
     * @param round             1-based round number, also used as candidate index
     * @param trainingEndId     last id (inclusive) of the training window
     * @param testBeginId       first id (inclusive) of the held-out test set
     * @param maxId             last id (inclusive) of the held-out test set
     * @param idAttribute       attribute the window conditions are built on
     * @param predictedLabel    attribute the predictions are written to
     * @param sliding_test      test on the training window instead of the held-out set
     * @param max_members       maximum number of candidates kept in the gatherer
     * @param local_threshold   a prediction counts as positive if its distance is below this value
     * @param penalty_weight    weight of the window size penalty in the score
     * @param leaf_out_size     stored only; not used by this worker
     * @param sampling_interval stored only; not used by this worker
     * @param size              number of possible rounds, used to normalize the size penalty
     * @param parent            operator whose gatherer receives the result
     */
    protected EstimateAndTest(
            ExampleSet exampleSet,
            Distance distance,
            ArrayList<Integer> ids,
            Learner learner,
            int round,
            int trainingEndId,
            int testBeginId,
            int maxId,
            Attribute idAttribute,
            Attribute predictedLabel,
            boolean sliding_test,
            int max_members,
            double local_threshold,
            double penalty_weight,
            int leaf_out_size,
            int sampling_interval,
            int size,
            BatchEnsembleRegression parent
    ) {
        this.exampleSet = exampleSet;
        this.ids = ids;
        this.distance = distance;
        this.round = round;
        this.idAttribute = idAttribute;
        this.trainingEndId = trainingEndId;
        this.learner = learner;
        this.sliding_test = sliding_test;
        this.predictedLabel = predictedLabel;
        this.max_members = max_members;
        this.local_threshold = local_threshold;
        this.penalty_weight = penalty_weight;
        this.leaf_out_size = leaf_out_size;
        this.sampling_interval = sampling_interval;
        this.testBeginId = testBeginId;
        this.maxId = maxId;
        this.size = size;
        this.parent = parent;
    }

    /** Trains and scores the member of this round and hands it to the gatherer. */
    public void run() {
        try {
            offer(trainAndScore());
        } catch (OperatorException oe) {
            reportFailure(oe);
        } catch (RuntimeException re) {
            // Without this the thread would die silently and the main thread
            // would build the ensemble from an incomplete candidate set.
            reportFailure(re);
        }
    }

    /** Builds the example subset whose id lies in {@code [fromId, toId]}. */
    private ConditionedExampleSet select(int fromId, int toId) {
        String idName = idAttribute.getName();
        Condition condition = new AttributeValueFilter(exampleSet, idName + " >= " + fromId + " && " + idName + " <= " + toId);
        return new ConditionedExampleSet(exampleSet, condition);
    }

    /** Trains the model on the round's window and counts hits and misses on the evaluation set. */
    private EnsembleMember trainAndScore() throws OperatorException {
        int currentId = ids.get(round - 1);
        ConditionedExampleSet window = select(currentId, trainingEndId);
        PredictionModel model = (PredictionModel) learner.learn(window);

        // sliding test -> evaluate on the training window itself,
        // otherwise on the held-out examples [testBeginId, maxId]
        ConditionedExampleSet evaluationSet = sliding_test ? window : select(testBeginId, maxId);
        // batch predict; the result is written to predictedLabel
        model.performPrediction(evaluationSet, predictedLabel);

        int positives = 0;
        int negatives = 0;
        Iterator<Example> iter = evaluationSet.iterator();
        while (iter.hasNext()) {
            Example example = iter.next();
            double dist = distance.distance(example.getLabel(), example.getNumericalValue(predictedLabel));
            if (dist < local_threshold) {
                positives++;
            } else {
                negatives++;
            }
        }

        EnsembleMember member = new EnsembleMember();
        member.setPositive(positives);
        member.setNegative(negatives);
        member.setIntroducedAt(currentId);
        //OPTION correct state init
        member.setState(MemberState.STABLE);
        member.setModel(model);
        return member;
    }

    /**
     * Inserts the member into the gatherer. While fewer than
     * {@code max_members} candidates are present it is simply added;
     * afterwards it replaces the candidate with the lowest score, but only if
     * its own score is strictly higher.
     */
    private void offer(EnsembleMember member) {
        double ratio = member.getRatio();
        double size_penalty = ((double) ((size + 1) - round)) / ((double) size);
        double currentPenalizedRatio = penalty_weight * size_penalty + (1 - penalty_weight) * ratio;
        synchronized (lock) {
            MemberGatherer sink = parent.getGatherer();
            HashMap<Integer, EnsembleMember> candidates = sink.getCandidates();
            TreeMultiMap<Double, Integer> ratios = sink.getRatios();
            if (candidates.size() < max_members) {
                // can safely use round as "index" here
                candidates.put(round, member);
                ratios.put(currentPenalizedRatio, round);
            } else {
                // check whether there is a lower ranking member to replace; always replace in the lowest key
                double leastKey = ratios.firstKey();
                if (leastKey < currentPenalizedRatio) {
                    // there is at least one entry
                    int leastMemberIndex = ratios.get(leastKey).get(0);
                    ratios.remove(leastKey, leastMemberIndex);
                    candidates.remove(leastMemberIndex);
                    candidates.put(round, member);
                    ratios.put(currentPenalizedRatio, round);
                }
            }
        }
    }

    /** Flags the failure for the main thread and makes it visible in the operator log. */
    private void reportFailure(Exception e) {
        parent.getGatherer().flagException();
        logError("Worker for round " + round + " failed: " + e);
        // The log call above only carries a message, so keep the stack trace on stderr.
        e.printStackTrace();
    }
}
/**
 * Builds the ensemble with one worker thread per tested window.
 *
 * <p>Windows start at every {@code sampling_interval}-th position, counted
 * backwards from the end of the training range, and all end at the same
 * training end id. The workers keep at most {@code max_members} candidates in
 * the shared gatherer. After all workers have finished, the members are
 * weighted by their share of positive tests and retrained on all examples
 * from their start id up to the largest id.
 *
 * @param exampleSet        training data; a predicted label attribute is added temporarily
 * @param learnerOperator   inner operator, cast to {@code Learner}
 * @param distance          distance between label and prediction
 * @param max_members       maximum number of ensemble members
 * @param local_threshold   distance below which a prediction counts as positive
 * @param penalty_weight    weight of the window size penalty in the candidate score
 * @param sliding_test      test on the training window (true) or on the held-out tail (false)
 * @param leaf_out_size     percentage of examples held out when {@code sliding_test} is false
 * @param sampling_interval step between successive window start positions
 * @return the weighted and retrained ensemble
 * @throws OperatorException    if there is nothing to train on, a worker failed or retraining fails
 * @throws InterruptedException if the calling thread is interrupted while waiting for the workers
 */
protected EnsembleRegressionModel createEnsembleThreaded(
        ExampleSet exampleSet,
        Operator learnerOperator,
        Distance distance,
        int max_members,
        double local_threshold,
        double penalty_weight,
        boolean sliding_test,
        int leaf_out_size,
        int sampling_interval
) throws OperatorException, InterruptedException {
    EnsembleRegressionModel ensemble = new EnsembleRegressionModel(exampleSet);
    // candidate set; is maintained by the workers while they process the windows
    HashMap<Integer, EnsembleMember> candidates = new HashMap<Integer, EnsembleMember>();
    // the ratios for maintaining the candidates
    TreeMultiMap<Double, Integer> ratios = new TreeMultiMap<Double, Integer>();
    ArrayList<Thread> threads = new ArrayList<Thread>();

    Attribute idAttribute = exampleSet.getAttributes().getId();
    Attribute labelAttribute = exampleSet.getAttributes().getLabel();
    Attribute predictedLabel = createPredictedLabel(exampleSet, labelAttribute);

    // True while workers may still be touching the gatherer and the predicted
    // label; in that state the shared objects must not be torn down.
    boolean workersMayBeRunning = false;
    try {
        // access object for the threads
        gatherer = new MemberGatherer(candidates, ratios);

        ArrayList<Integer> ids = getAllIds(exampleSet);
        if (ids.isEmpty()) {
            throw new OperatorException("Cannot build an ensemble from an empty example set");
        }
        int maxId = ids.get(ids.size() - 1);
        Learner learner = (Learner) learnerOperator;

        int first = 1;
        int last;
        if (sliding_test) {
            last = exampleSet.size();
        } else {
            int leaf_out_count = exampleSet.size() * leaf_out_size / 100;
            if (leaf_out_count == 0) {
                //FIXME
                leaf_out_count = 1;
            }
            last = exampleSet.size() - leaf_out_count;
        }
        if (last < first) {
            throw new OperatorException("No training examples left: " + exampleSet.size() + " examples with " + leaf_out_size + "% held out for testing");
        }
        int trainingEndId = ids.get(last - 1);
        // only read by the workers when testing on the held-out examples
        int testBeginId = sliding_test ? 0 : ids.get(last);

        int size = last - first + 1;
        threads.ensureCapacity(size);
        // build the threads
        for (int round = last; round >= first; round = round - sampling_interval) {
            EstimateAndTest runnable = new EstimateAndTest(
                    exampleSet,
                    distance,
                    ids,
                    learner,
                    round,
                    trainingEndId,
                    testBeginId,
                    maxId,
                    idAttribute,
                    predictedLabel,
                    sliding_test,
                    max_members,
                    local_threshold,
                    penalty_weight,
                    leaf_out_size,
                    sampling_interval,
                    size,
                    this
            );
            threads.add(new Thread(runnable));
        }

        workersMayBeRunning = true;
        for (Thread t : threads) {
            t.start();
        }
        for (Thread t : threads) {
            t.join();
        }
        workersMayBeRunning = false;

        if (gatherer.getExceptionOccured()) {
            throw new OperatorException("Exception in one of the worker threads");
        }

        int ensemble_positives = 0;
        for (EnsembleMember member : candidates.values()) {
            ensemble_positives += member.getPositive();
            ensemble.addMember(member);
        }
        ensemble.setSeenIds(new HashSet<Integer>(ids));

        if (ensemble_positives == 0) {
            log("Looser Ensemble encountered");
            // floating point division: 1/max_members would truncate to 0 for max_members > 1
            double uniformWeight = 1.0 / max_members;
            for (EnsembleMember mem : ensemble) {
                mem.setWeight(uniformWeight);
            }
        } else {
            for (EnsembleMember mem : ensemble) {
                // cast first so the share of positives is not truncated by integer division
                mem.setWeight((double) mem.getPositive() / ensemble_positives);
                log(Double.toString(mem.getRatio()));
                log(Double.toString(mem.getIntroducedAt()));
            }
        }

        // retrain on all examples
        String idName = idAttribute.getName();
        for (EnsembleMember mem : ensemble) {
            String windowConditionExpression = idName + " >= " + mem.getIntroducedAt() + " && " + idName + " <= " + maxId;
            Condition windowCondition = new AttributeValueFilter(exampleSet, windowConditionExpression);
            ConditionedExampleSet window = new ConditionedExampleSet(exampleSet, windowCondition);
            mem.setModel((PredictionModel) learner.learn(window));
        }
        return ensemble;
    } finally {
        // clean up, also on failure, but never underneath running workers
        if (!workersMayBeRunning) {
            removePredictedLabel(exampleSet, predictedLabel);
            gatherer = null;
        }
    }
}
/**
 * Builds the ensemble by training and scoring a candidate for every tested
 * window first and selecting the best {@code max_members} afterwards.
 *
 * <p>Windows start at every {@code sampling_interval}-th position, counted
 * backwards from the end of the training range. Each candidate's score is
 * {@code penalty_weight * size_penalty + (1 - penalty_weight) * ratio}.
 * Candidates are picked from the highest score downwards, weighted by their
 * share of positive tests and retrained on all examples from their start id
 * up to the largest id.
 *
 * @param exampleSet        training data; a predicted label attribute is added temporarily
 * @param learnerOperator   inner operator, cast to {@code Learner}
 * @param distance          distance between label and prediction
 * @param max_members       maximum number of ensemble members
 * @param local_threshold   distance below which a prediction counts as positive
 * @param penalty_weight    weight of the window size penalty in the candidate score
 * @param sliding_test      test on the training window (true) or on the held-out tail (false)
 * @param leaf_out_size     percentage of examples held out when {@code sliding_test} is false
 * @param sampling_interval step between successive window start positions
 * @return the weighted and retrained ensemble
 * @throws OperatorException if there is nothing to train on, or learning or prediction fails
 */
protected EnsembleRegressionModel createEnsembleTotalMaterialized(
        ExampleSet exampleSet,
        Operator learnerOperator,
        Distance distance,
        int max_members,
        double local_threshold,
        double penalty_weight,
        boolean sliding_test,
        int leaf_out_size,
        int sampling_interval
) throws OperatorException {
    EnsembleRegressionModel ensemble = new EnsembleRegressionModel(exampleSet);
    // candidate set; is materialized fully
    HashMap<Integer, EnsembleMember> candidates = new HashMap<Integer, EnsembleMember>();
    TreeMultiMap<Double, Integer> ratios = new TreeMultiMap<Double, Integer>();

    Attribute idAttribute = exampleSet.getAttributes().getId();
    Attribute labelAttribute = exampleSet.getAttributes().getLabel();
    Attribute predictedLabel = createPredictedLabel(exampleSet, labelAttribute);
    try {
        ArrayList<Integer> ids = getAllIds(exampleSet);
        if (ids.isEmpty()) {
            throw new OperatorException("Cannot build an ensemble from an empty example set");
        }
        int maxId = ids.get(ids.size() - 1);
        Learner learner = (Learner) learnerOperator;
        String idName = idAttribute.getName();

        int first = 1;
        int last;
        if (sliding_test) {
            last = exampleSet.size();
        } else {
            int leaf_out_count = exampleSet.size() * leaf_out_size / 100;
            if (leaf_out_count == 0) {
                //FIXME
                leaf_out_count = 1;
            }
            last = exampleSet.size() - leaf_out_count;
        }
        if (last < first) {
            throw new OperatorException("No training examples left: " + exampleSet.size() + " examples with " + leaf_out_size + "% held out for testing");
        }
        int trainingEndId = ids.get(last - 1);
        // only used when testing on the held-out examples
        int testBeginId = sliding_test ? 0 : ids.get(last);
        int size = last - first + 1;

        for (int round = last; round >= first; round = round - sampling_interval) {
            int currentId = ids.get(round - 1);
            String windowConditionExpression = idName + " >= " + currentId + " && " + idName + " <= " + trainingEndId;
            Condition windowCondition = new AttributeValueFilter(exampleSet, windowConditionExpression);
            ConditionedExampleSet window = new ConditionedExampleSet(exampleSet, windowCondition);
            PredictionModel model = (PredictionModel) learner.learn(window);

            // sliding test -> evaluate on the training window itself,
            // otherwise on the held-out examples [testBeginId, maxId]
            ConditionedExampleSet evaluationSet;
            if (sliding_test) {
                evaluationSet = window;
            } else {
                String testSetConditionExpression = idName + " >= " + testBeginId + " && " + idName + " <= " + maxId;
                Condition testSetCondition = new AttributeValueFilter(exampleSet, testSetConditionExpression);
                evaluationSet = new ConditionedExampleSet(exampleSet, testSetCondition);
            }
            // batch predict; the result is written to predictedLabel
            model.performPrediction(evaluationSet, predictedLabel);

            int positives = 0;
            int negatives = 0;
            Iterator<Example> iter = evaluationSet.iterator();
            while (iter.hasNext()) {
                Example example = iter.next();
                double dist = distance.distance(example.getLabel(), example.getNumericalValue(predictedLabel));
                if (dist < local_threshold) {
                    positives++;
                } else {
                    negatives++;
                }
            }

            EnsembleMember member = new EnsembleMember();
            member.setPositive(positives);
            member.setNegative(negatives);
            member.setIntroducedAt(currentId);
            //OPTION correct state init
            member.setState(MemberState.STABLE);
            member.setModel(model);
            candidates.put(round - 1, member);

            double ratio = member.getRatio();
            double size_penalty = ((double) ((size + 1) - round)) / ((double) size);
            double key = penalty_weight * size_penalty + (1 - penalty_weight) * ratio;
            // put into map of ratios; use corrected index for the member
            ratios.put(key, (round - 1));
        }

        // select members: walk the score map from the highest key downwards
        int ensemble_positives = 0;
        int selected_members = 0;
        Entry<Double, ArrayList<Integer>> currEntry = ratios.lastEntry();
        while (currEntry != null && selected_members < max_members) {
            ArrayList<Integer> currList = currEntry.getValue();
            if (currList == null) {
                break;
            }
            for (int i = 0; i < currList.size() && selected_members < max_members; i++) {
                EnsembleMember selected_member = candidates.get(currList.get(i));
                // maintain the sum of positives in the ensemble
                ensemble_positives += selected_member.getPositive();
                ensemble.addMember(selected_member);
                selected_members++;
            }
            currEntry = ratios.lowerEntry(currEntry.getKey());
        }
        log(ratios.toString());
        ensemble.setSeenIds(new HashSet<Integer>(ids));

        if (ensemble_positives == 0) {
            log("Looser Ensemble encountered");
            // floating point division: 1/max_members would truncate to 0 for max_members > 1
            double uniformWeight = 1.0 / max_members;
            for (EnsembleMember mem : ensemble) {
                mem.setWeight(uniformWeight);
            }
        } else {
            for (EnsembleMember mem : ensemble) {
                // cast first so the share of positives is not truncated by integer division
                mem.setWeight((double) mem.getPositive() / ensemble_positives);
                log(Double.toString(mem.getRatio()));
                log(Double.toString(mem.getIntroducedAt()));
            }
        }

        // retrain on all examples
        for (EnsembleMember mem : ensemble) {
            String windowConditionExpression = idName + " >= " + mem.getIntroducedAt() + " && " + idName + " <= " + maxId;
            Condition windowCondition = new AttributeValueFilter(exampleSet, windowConditionExpression);
            ConditionedExampleSet window = new ConditionedExampleSet(exampleSet, windowCondition);
            mem.setModel((PredictionModel) learner.learn(window));
        }
        return ensemble;
    } finally {
        // clean up, also when learning or prediction failed
        removePredictedLabel(exampleSet, predictedLabel);
    }
}
/**
 * Builds the ensemble while scanning the windows, keeping at most
 * {@code max_members} candidates at any time.
 *
 * <p>Windows start at every {@code sampling_interval}-th position, counted
 * backwards from the end of the training range. Each candidate's score is
 * {@code penalty_weight * size_penalty + (1 - penalty_weight) * ratio}. Once
 * the candidate set is full, a new candidate replaces the one with the lowest
 * score if its own score is strictly higher. The survivors are weighted by
 * their share of positive tests and retrained on all examples from their
 * start id up to the largest id.
 *
 * @param exampleSet        training data; a predicted label attribute is added temporarily
 * @param learnerOperator   inner operator, cast to {@code Learner}
 * @param distance          distance between label and prediction
 * @param max_members       maximum number of ensemble members
 * @param local_threshold   distance below which a prediction counts as positive
 * @param penalty_weight    weight of the window size penalty in the candidate score
 * @param sliding_test      test on the training window (true) or on the held-out tail (false)
 * @param leaf_out_size     percentage of examples held out when {@code sliding_test} is false
 * @param sampling_interval step between successive window start positions
 * @return the weighted and retrained ensemble
 * @throws OperatorException if there is nothing to train on, or learning or prediction fails
 */
protected EnsembleRegressionModel createEnsembleScanMaterialized(
        ExampleSet exampleSet,
        Operator learnerOperator,
        Distance distance,
        int max_members,
        double local_threshold,
        double penalty_weight,
        boolean sliding_test,
        int leaf_out_size,
        int sampling_interval
) throws OperatorException {
    EnsembleRegressionModel ensemble = new EnsembleRegressionModel(exampleSet);
    // candidate set; is maintained while iterating through the windows
    HashMap<Integer, EnsembleMember> candidates = new HashMap<Integer, EnsembleMember>();
    // the ratios for maintaining the candidates
    TreeMultiMap<Double, Integer> ratios = new TreeMultiMap<Double, Integer>();

    Attribute idAttribute = exampleSet.getAttributes().getId();
    Attribute labelAttribute = exampleSet.getAttributes().getLabel();
    Attribute predictedLabel = createPredictedLabel(exampleSet, labelAttribute);
    try {
        ArrayList<Integer> ids = getAllIds(exampleSet);
        if (ids.isEmpty()) {
            throw new OperatorException("Cannot build an ensemble from an empty example set");
        }
        int maxId = ids.get(ids.size() - 1);
        Learner learner = (Learner) learnerOperator;
        String idName = idAttribute.getName();

        int first = 1;
        int last;
        if (sliding_test) {
            last = exampleSet.size();
        } else {
            int leaf_out_count = exampleSet.size() * leaf_out_size / 100;
            if (leaf_out_count == 0) {
                //FIXME
                leaf_out_count = 1;
            }
            last = exampleSet.size() - leaf_out_count;
        }
        if (last < first) {
            throw new OperatorException("No training examples left: " + exampleSet.size() + " examples with " + leaf_out_size + "% held out for testing");
        }
        int trainingEndId = ids.get(last - 1);
        // only used when testing on the held-out examples
        int testBeginId = sliding_test ? 0 : ids.get(last);
        int size = last - first + 1;

        for (int round = last; round >= first; round = round - sampling_interval) {
            int currentId = ids.get(round - 1);
            String windowConditionExpression = idName + " >= " + currentId + " && " + idName + " <= " + trainingEndId;
            Condition windowCondition = new AttributeValueFilter(exampleSet, windowConditionExpression);
            ConditionedExampleSet window = new ConditionedExampleSet(exampleSet, windowCondition);
            PredictionModel model = (PredictionModel) learner.learn(window);

            // sliding test -> evaluate on the training window itself,
            // otherwise on the held-out examples [testBeginId, maxId]
            ConditionedExampleSet evaluationSet;
            if (sliding_test) {
                evaluationSet = window;
            } else {
                String testSetConditionExpression = idName + " >= " + testBeginId + " && " + idName + " <= " + maxId;
                Condition testSetCondition = new AttributeValueFilter(exampleSet, testSetConditionExpression);
                evaluationSet = new ConditionedExampleSet(exampleSet, testSetCondition);
            }
            // batch predict; the result is written to predictedLabel
            model.performPrediction(evaluationSet, predictedLabel);

            int positives = 0;
            int negatives = 0;
            Iterator<Example> iter = evaluationSet.iterator();
            while (iter.hasNext()) {
                Example example = iter.next();
                double dist = distance.distance(example.getLabel(), example.getNumericalValue(predictedLabel));
                if (dist < local_threshold) {
                    positives++;
                } else {
                    negatives++;
                }
            }

            EnsembleMember member = new EnsembleMember();
            member.setPositive(positives);
            member.setNegative(negatives);
            member.setIntroducedAt(currentId);
            //OPTION correct state init
            member.setState(MemberState.STABLE);
            member.setModel(model);

            double ratio = member.getRatio();
            double size_penalty = ((double) ((size + 1) - round)) / ((double) size);
            double currentPenalizedRatio = penalty_weight * size_penalty + (1 - penalty_weight) * ratio;
            if (candidates.size() < max_members) {
                // can safely use round as "index" here
                candidates.put(round, member);
                ratios.put(currentPenalizedRatio, round);
            } else {
                // check whether there is a lower ranking member to replace; always replace in the lowest key
                double leastKey = ratios.firstKey();
                if (leastKey < currentPenalizedRatio) {
                    // there is at least one entry
                    int leastMemberIndex = ratios.get(leastKey).get(0);
                    ratios.remove(leastKey, leastMemberIndex);
                    candidates.remove(leastMemberIndex);
                    candidates.put(round, member);
                    ratios.put(currentPenalizedRatio, round);
                }
            }
        }

        int ensemble_positives = 0;
        for (EnsembleMember member : candidates.values()) {
            ensemble_positives += member.getPositive();
            ensemble.addMember(member);
        }
        ensemble.setSeenIds(new HashSet<Integer>(ids));

        if (ensemble_positives == 0) {
            log("Looser Ensemble encountered");
            // floating point division: 1/max_members would truncate to 0 for max_members > 1
            double uniformWeight = 1.0 / max_members;
            for (EnsembleMember mem : ensemble) {
                mem.setWeight(uniformWeight);
            }
        } else {
            for (EnsembleMember mem : ensemble) {
                // cast first so the share of positives is not truncated by integer division
                mem.setWeight((double) mem.getPositive() / ensemble_positives);
                log(Double.toString(mem.getRatio()));
                log(Double.toString(mem.getIntroducedAt()));
            }
        }

        // retrain on all examples
        for (EnsembleMember mem : ensemble) {
            String windowConditionExpression = idName + " >= " + mem.getIntroducedAt() + " && " + idName + " <= " + maxId;
            Condition windowCondition = new AttributeValueFilter(exampleSet, windowConditionExpression);
            ConditionedExampleSet window = new ConditionedExampleSet(exampleSet, windowCondition);
            mem.setModel((PredictionModel) learner.learn(window));
        }
        return ensemble;
    } finally {
        // clean up, also when learning or prediction failed
        removePredictedLabel(exampleSet, predictedLabel);
    }
}
/**
 * Collects the id of every example, truncated to {@code int}, in the order
 * in which a {@code SortedExampleReader} delivers the examples.
 * NOTE(review): callers treat the last element as the largest id, so the
 * reader presumably yields ascending ids - confirm.
 *
 * @param exampleSet the examples to read the ids from
 * @return the ids in reader order
 */
protected ArrayList<Integer> getAllIds(ExampleSet exampleSet) {
    ArrayList<Integer> collected = new ArrayList<Integer>();
    for (ExampleReader cursor = new SortedExampleReader(exampleSet); cursor.hasNext();) {
        int id = (int) cursor.next().getId();
        collected.add(id);
    }
    return collected;
}
/**
 * Creates a prediction attribute derived from the given label, adds it to the
 * example table behind the example set and registers it as the set's
 * predicted label.
 *
 * @param exampleSet example set that receives the new attribute
 * @param label      label attribute the prediction attribute is created from
 * @return the newly created prediction attribute
 */
protected static Attribute createPredictedLabel(ExampleSet exampleSet, Attribute label) {
    final Attribute prediction = AttributeFactory.createAttribute(label, Attributes.PREDICTION_NAME);
    prediction.clearTransformations();
    // make the column available in the underlying table first ...
    exampleSet.getExampleTable().addAttribute(prediction);
    // ... then give it the predicted label role in this view
    exampleSet.getAttributes().setPredictedLabel(prediction);
    return prediction;
}
/**
 * Removes the given prediction attribute from the example table and from the
 * attribute set of the example set.
 *
 * @param exampleSet     example set the attribute was added to
 * @param predictedLabel the attribute to remove
 */
protected static void removePredictedLabel(ExampleSet exampleSet, Attribute predictedLabel) {
    ExampleTable backingTable = exampleSet.getExampleTable();
    backingTable.removeAttribute(predictedLabel);
    Attributes attributeView = exampleSet.getAttributes();
    attributeView.remove(predictedLabel);
}
/**
 * Extends the parameter list of the super class with the ensemble specific
 * parameters, in this order: state file, similarity measure, maximum members,
 * local threshold, test on training window, percentage of test examples,
 * materialize full, penalty weight, sampling interval, use threads.
 *
 * @return the super class parameters followed by the ensemble parameters
 */
@Override
public List<ParameterType> getParameterTypes() {
    List<ParameterType> parameters = super.getParameterTypes();

    // where the trained ensemble is stored
    ParameterTypeFile stateFile = new ParameterTypeFile(ENSEMBLE_STATE_FILE, "path to the ensemble state file", "mod", false);
    stateFile.setExpert(false);
    parameters.add(stateFile);

    // how label and prediction are compared
    ParameterTypeStringCategory measure = new ParameterTypeStringCategory(ENSEMBLE_SIMILARITY_MEASURE, "similarity measure to use", getMeasureNames(), getDefaultMeasureNames());
    measure.setExpert(true);
    parameters.add(measure);

    // ensemble size
    ParameterTypeInt maxMembers = new ParameterTypeInt(ENSEMBLE_MAX_MEMBERS, "maximum number of members to select into ensemble", 1, Integer.MAX_VALUE, 5);
    maxMembers.setExpert(false);
    parameters.add(maxMembers);

    // hit/miss threshold
    ParameterTypeDouble localThreshold = new ParameterTypeDouble(ENSEMBLE_LOCAL_THRESHOLD, "local prediction error threshold", 0, Double.MAX_VALUE, 5);
    localThreshold.setExpert(false);
    parameters.add(localThreshold);

    // test strategy
    ParameterTypeBoolean slidingTest = new ParameterTypeBoolean(ENSEMBLE_SLIDING_TEST, "test member candidates on training set (true) or on leaf out set (false)", true);
    slidingTest.setExpert(false);
    parameters.add(slidingTest);

    ParameterTypeInt leaveOutSize = new ParameterTypeInt(ENSEMBLE_LEAVE_OUT_SIZE, "percentage of examples to leaf out of training for testing", 0, 100, 10);
    leaveOutSize.setExpert(false);
    parameters.add(leaveOutSize);

    // candidate handling
    ParameterTypeBoolean materializeFull = new ParameterTypeBoolean(ENSEMBLE_FULL_MATERIALIZED, "materialize all possible members", true);
    materializeFull.setExpert(false);
    parameters.add(materializeFull);

    ParameterTypeDouble penaltyWeight = new ParameterTypeDouble(ENSEMBLE_PENALTY_WEIGHT, "weight of the window legth penalty term", 0.0, 1.0, 0.0);
    penaltyWeight.setExpert(false);
    parameters.add(penaltyWeight);

    ParameterTypeInt samplingInterval = new ParameterTypeInt(ENSEMBLE_SAMPLING_INTERVAL, "interval between successively tested windows", 1, Integer.MAX_VALUE, 1);
    samplingInterval.setExpert(true);
    parameters.add(samplingInterval);

    // execution mode
    ParameterTypeBoolean useThreads = new ParameterTypeBoolean(ENSEMBLE_USE_THREADS, "use threads", false);
    useThreads.setExpert(true);
    parameters.add(useThreads);

    return parameters;
}
/**
 * Answers the capability question by asking the inner operator (index 0),
 * which is cast to {@code Learner}.
 *
 * @param capability the capability in question
 * @return whatever the inner learner reports for the capability
 */
public boolean supportsCapability(LearnerCapability capability) {
    Learner innerLearner = (Learner) this.getOperator(0);
    return innerLearner.supportsCapability(capability);
}
/**
 * @return 1; together with {@link #getMinNumberOfInnerOperators()} this
 *         states that exactly one inner operator is expected
 */
@Override
public int getMaxNumberOfInnerOperators() {
return 1;
}
/**
 * @return 1; together with {@link #getMaxNumberOfInnerOperators()} this
 *         states that exactly one inner operator is expected
 */
@Override
public int getMinNumberOfInnerOperators() {
return 1;
}
/**
 * Verifies that the inner operator (index 0) is a {@code Learner}.
 *
 * @throws UserError with code 127 if the inner operator is not a learner
 */
public void checkInnerOperator() throws UserError {
    boolean isLearner = this.getOperator(0) instanceof Learner;
    if (isLearner) {
        return;
    }
    // inner operator must be a learner
    throw new UserError(this, 127, "Inner operator is not a learner");
}
/**
 * Similarity measures that can be selected through the "similarity measure"
 * parameter; the constant names double as the parameter values.
 */
protected static enum ENSEMBLE_SIMILARITY_MEASURES {
// mapped to L2Norm in learn()
EuclideanDistance
};
/**
 * Lists the string form of every {@link ENSEMBLE_SIMILARITY_MEASURES}
 * constant in declaration order.
 *
 * @return the names offered for the similarity measure parameter
 */
protected static String[] getMeasureNames() {
    ENSEMBLE_SIMILARITY_MEASURES[] measures = ENSEMBLE_SIMILARITY_MEASURES.values();
    String[] names = new String[measures.length];
    for (int i = 0; i < measures.length; i++) {
        names[i] = measures[i].toString();
    }
    return names;
}
/**
 * @return the string form of {@code EuclideanDistance}, used as the default
 *         value of the similarity measure parameter
 */
protected static String getDefaultMeasureNames() {
return ENSEMBLE_SIMILARITY_MEASURES.EuclideanDistance.toString();
}
/**
 * @return the gatherer currently shared with the worker threads; null
 *         unless createEnsembleThreaded has set it
 */
private MemberGatherer getGatherer() {
return gatherer;
}
}