package de.tud.inf.operator.learner.regressionensemble;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
//import java.util.Random;
import java.util.Set;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.Attributes;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleReader;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.AttributeValueFilter;
import com.rapidminer.example.set.AttributeValueFilterSingleCondition;
import com.rapidminer.example.set.Condition;
import com.rapidminer.example.set.ConditionedExampleSet;
import com.rapidminer.example.table.AttributeFactory;
import com.rapidminer.example.table.ExampleTable;
import com.rapidminer.operator.AbstractIOObject;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.Learner;
import com.rapidminer.operator.learner.LearnerCapability;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.learner.meta.AbstractMetaLearner;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeFile;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.parameter.ParameterTypeStringCategory;
import de.tud.inf.example.set.ReverseSortedExampleReader;
//OPTION gather statistics on member creation and deletion
//OPTION prediction mechanism (no writing to the example set/creating of attributes)
/**
 * Meta learner that wraps exactly one inner {@link Learner} and incrementally maintains an
 * {@link EnsembleRegressionModel}.
 *
 * <p>On every call to {@link #learn(ExampleSet)} the ensemble is either created from scratch or
 * restored from the configured state file, updated once for every example id that has not been
 * seen before, and finally written back to the state file.</p>
 */
public class EnsembleRegression extends AbstractMetaLearner {

    // definition of parameter keys
    private static final String ENSEMBLE_STATE_FILE = "state file";
    private static final String ENSEMBLE_MAX_MEMBERS = "maximum members";
    private static final String ENSEMBLE_MIN_MEMBERS = "minimum members";
    private static final String ENSEMBLE_EVICTION_RATIO = "eviction ratio";
    private static final String ENSEMBLE_LOCAL_THRESHOLD = "local threshold";
    private static final String ENSEMBLE_SIMILARITY_MEASURE = "similarity measure";
    private static final String ENSEMBLE_DISCARD_STATE = "discard state on first run";
    private static final String ENSEMBLE_INITIAL_WINDOW = "initial window";

    /** The available similarity measures; the constant names double as parameter values. */
    protected static enum ENSEMBLE_SIMILARITY_MEASURES {
        EuclideanDistance
    }

    /**
     * Returns the names of all available similarity measures.
     *
     * @return one entry per constant of {@link ENSEMBLE_SIMILARITY_MEASURES}, in declaration order
     */
    protected static String[] getMeasureNames() {
        ENSEMBLE_SIMILARITY_MEASURES[] measures = ENSEMBLE_SIMILARITY_MEASURES.values();
        List<String> names = new ArrayList<String>(measures.length);
        for (ENSEMBLE_SIMILARITY_MEASURES measure : measures) {
            names.add(measure.toString());
        }
        return names.toArray(new String[names.size()]);
    }

    /**
     * Returns the name of the similarity measure that is used as parameter default.
     *
     * @return the name of the Euclidean distance measure
     */
    protected static String getDefaultMeasureNames() {
        return ENSEMBLE_SIMILARITY_MEASURES.EuclideanDistance.toString();
    }

    /** Whether the next call to {@link #learn(ExampleSet)} is the first one that initializes a fresh ensemble. */
    protected boolean firstRun = true;

    public EnsembleRegression(OperatorDescription description) {
        super(description);
    }

    /**
     * Restores or initializes the ensemble, updates it with every unseen example of the given
     * example set and persists the result to the state file.
     *
     * @param exampleSet the example set containing the already seen and the new examples
     * @return the updated ensemble model
     * @throws OperatorException if the state file cannot be read or written, or if updating the
     *         ensemble fails
     */
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        Operator learner = this.getOperator(0);
        File state_file = getParameterAsFile(ENSEMBLE_STATE_FILE);
        String similarity_measure_type = getParameterAsString(ENSEMBLE_SIMILARITY_MEASURE);

        Distance distance = null;
        switch (ENSEMBLE_SIMILARITY_MEASURES.valueOf(similarity_measure_type)) {
        case EuclideanDistance:
            distance = new L2Norm();
            break;
        default:
            // NOTE(review): only reachable once a measure is added to the enum without a case
            // here; distance stays null in that situation.
            logError("Unknown distance function");
            break;
        }

        int max_members = getParameterAsInt(ENSEMBLE_MAX_MEMBERS);
        int min_members = getParameterAsInt(ENSEMBLE_MIN_MEMBERS);
        int initial_window = getParameterAsInt(ENSEMBLE_INITIAL_WINDOW);
        double eviction_ratio = getParameterAsDouble(ENSEMBLE_EVICTION_RATIO);
        double local_threshold = getParameterAsDouble(ENSEMBLE_LOCAL_THRESHOLD);
        boolean discard_model = getParameterAsBoolean(ENSEMBLE_DISCARD_STATE);

        // Reuse the persisted ensemble if the state must be kept and a state file exists, or if
        // the state was already discarded by an earlier run of this operator instance.
        EnsembleRegressionModel ensemble;
        boolean restoreState = (!discard_model && state_file.exists()) || (discard_model && !firstRun);
        if (restoreState) {
            ensemble = readEnsembleState(state_file);
        } else {
            // the ensemble keeps the complete passed example set
            ensemble = new EnsembleRegressionModel(exampleSet);
            firstRun = false; // from now on: read the state file
            log("ensemble initialized");
        }

        // For performance reasons the compatibility of the old and the new example set is not
        // checked; the example sets are simply assumed to be compatible.

        // determine the parts of the example set that are new
        List<Integer> newIds = newExampleIds(ensemble.getSeenIds(), exampleSet);
        if (newIds == null || newIds.isEmpty()) {
            // nothing to update; the unchanged ensemble is still persisted and returned
            logWarning("no new Examples found; returning old state");
        } else {
            // NOTE(review): ids are processed in the order delivered by newExampleIds, which
            // presumably is newest first -- confirm this is intended when several examples are new.
            for (int id : newIds) {
                ensemble = updateEnsemble(
                        exampleSet,
                        ensemble,
                        learner,
                        distance,
                        max_members,
                        min_members,
                        initial_window,
                        eviction_ratio,
                        local_threshold,
                        id);
            }

            // maintain the list of seen ids
            Set<Integer> seenIds = ensemble.getSeenIds();
            seenIds.addAll(newIds);
            ensemble.setSeenIds(seenIds);
        }

        writeEnsembleState(ensemble, state_file);
        return ensemble;
    }

    /**
     * Reads the persisted ensemble from the given state file.
     *
     * @param stateFile the file to read from
     * @return the ensemble stored in the file
     * @throws OperatorException a {@link UserError} with code 303 if reading fails with an I/O error
     */
    private EnsembleRegressionModel readEnsembleState(File stateFile) throws OperatorException {
        InputStream in = null;
        try {
            in = new FileInputStream(stateFile);
            return (EnsembleRegressionModel) AbstractIOObject.read(in);
        } catch (IOException e) {
            // NOTE(review): the same error code is used for read and write failures -- confirm.
            throw new UserError(this, e, 303, new Object[] { stateFile, e.getMessage() });
        } finally {
            if (in != null) {
                try {
                    in.close();
                } catch (IOException e) {
                    logError("Cannot close stream from file " + stateFile);
                }
            }
        }
    }

    /**
     * Writes the ensemble to the given state file, replacing its previous content.
     *
     * @param ensemble the ensemble to persist
     * @param stateFile the file to write to
     * @throws OperatorException a {@link UserError} with code 303 if writing fails with an I/O error
     */
    private void writeEnsembleState(EnsembleRegressionModel ensemble, File stateFile) throws OperatorException {
        OutputStream out = null;
        try {
            out = new FileOutputStream(stateFile);
            ensemble.write(out);
        } catch (IOException e) {
            throw new UserError(this, e, 303, new Object[] { stateFile, e.getMessage() });
        } finally {
            if (out != null) {
                try {
                    out.close();
                } catch (IOException e) {
                    logError("Cannot close stream to file " + stateFile);
                }
            }
        }
    }

    /**
     * Performs one update step of the ensemble for a single new example.
     *
     * <p>Every member that is not unstable is evaluated on the new example and its
     * positive/negative counters are adjusted. Every member is then retrained on the examples
     * from its introduction id up to the new example and its state is advanced. Afterwards the
     * stable member with the lowest ratio below the eviction ratio is deleted (if more than
     * {@code min_members} members exist), the weights are recomputed and a new unstable member is
     * added while the ensemble has fewer than {@code max_members} members.</p>
     *
     * @param currentExampleSet the example set containing the new example and the training data
     * @param ensemble the ensemble to update
     * @param learnerOperator the inner operator; must implement {@link Learner}
     * @param distance the distance used to compare label and prediction
     * @param max_members the maximum number of ensemble members
     * @param min_members the minimum number of members that must remain before a deletion
     * @param initial_window the number of most recent examples a new member is trained on
     * @param eviction_ratio stable members with a ratio below this value are deletion candidates
     * @param local_threshold predictions with a distance below this value count as positive
     * @param mostRecentExample the id of the new example
     * @return the updated ensemble (the same instance that was passed in)
     * @throws OperatorException if a prediction or the training of a member fails
     */
    protected EnsembleRegressionModel updateEnsemble(
            ExampleSet currentExampleSet,
            EnsembleRegressionModel ensemble,
            Operator learnerOperator,
            Distance distance,
            int max_members,
            int min_members,
            int initial_window,
            double eviction_ratio,
            double local_threshold,
            int mostRecentExample
    ) throws OperatorException {
        Learner learner = (Learner) learnerOperator;

        // sum of the positive counters of all members, used to normalize the weights
        double totalPositive = 0;

        Attribute idAttribute = currentExampleSet.getAttributes().getId();
        Attribute labelAttribute = currentExampleSet.getAttributes().getLabel();

        // temporary attribute that receives the predictions of the members
        Attribute predictedLabel = createPredictedLabel(currentExampleSet, labelAttribute);

        ConditionedExampleSet currentExample;
        EnsembleMember deletionCandidate = null;
        try {
            // build example set for testing containing only the current new item
            Condition currentExampleCondition = new AttributeValueFilterSingleCondition(
                    idAttribute,
                    AttributeValueFilterSingleCondition.EQUALS,
                    Integer.toString(mostRecentExample));
            currentExample = new ConditionedExampleSet(currentExampleSet, currentExampleCondition);

            Iterator<EnsembleMember> iter = ensemble.iterator();
            while (iter.hasNext()) {
                EnsembleMember member = iter.next();

                if (member.getState() != MemberState.UNSTABLE) {
                    member.getModel().performPrediction(currentExample, predictedLabel);
                    Example testedExample = currentExample.getExample(0);
                    double label = testedExample.getLabel();
                    double prediction = testedExample.getDataRow().get(predictedLabel);

                    if (distance.distance(label, prediction) < local_threshold) {
                        member.incPositive();
                    } else {
                        member.incNegative();
                    }

                    // remember the stable member with the lowest ratio below the eviction ratio
                    if (member.getState() == MemberState.STABLE) {
                        double ratio = member.getRatio();
                        if (ratio < eviction_ratio
                                && (deletionCandidate == null || ratio < deletionCandidate.getRatio())) {
                            deletionCandidate = member;
                        }
                    }
                }

                // retrain the member on all examples from its introduction up to the new example
                Condition windowCondition = new AttributeValueFilter(currentExampleSet,
                        idAttribute.getName() + " >= " + member.getIntroducedAt() + " && " + idAttribute.getName() + " <= " + mostRecentExample);
                ConditionedExampleSet memberExampleSet = new ConditionedExampleSet(currentExampleSet, windowCondition);
                member.setModel((PredictionModel) learner.learn(memberExampleSet));

                if (member.getState() != MemberState.STABLE) {
                    int attributeCount = currentExampleSet.getAttributes().allSize();
                    int specialCount = currentExampleSet.getAttributes().specialSize();
                    int trainingThreshold = attributeCount - specialCount;
                    // NOTE(review): both thresholds are identical, so an unstable member that
                    // passes the first check becomes stable within the same step -- confirm.
                    int stableThreshold = trainingThreshold;
                    if (member.getState() == MemberState.UNSTABLE) {
                        if (mostRecentExample - member.getIntroducedAt() > trainingThreshold) {
                            member.setState(MemberState.TRAINING);
                        }
                    }
                    if (member.getState() == MemberState.TRAINING) {
                        if (mostRecentExample - member.getIntroducedAt() > stableThreshold) {
                            member.setState(MemberState.STABLE);
                        }
                    }
                }

                totalPositive += member.getPositive();
            }
        } finally {
            // always take the temporary attribute out again so that a failing prediction or
            // training does not leave it behind in the caller's example set
            removePredictedLabel(currentExampleSet, predictedLabel);
        }

        // delete the worst member, but only if more than min_members members exist
        if (deletionCandidate != null && ensemble.getNumberOfMembers() > min_members) {
            // compensate totalPositive
            totalPositive -= deletionCandidate.getPositive();
            ensemble.deleteMember(deletionCandidate);
        }

        // weight of a member: its share of all positive predictions
        Iterator<EnsembleMember> weightIter = ensemble.iterator();
        while (weightIter.hasNext()) {
            EnsembleMember member = weightIter.next();
            // without any positive prediction the quotient would be 0/0 = NaN; use the same
            // weight a freshly created member gets instead
            double weight = totalPositive > 0
                    ? (double) member.getPositive() / totalPositive
                    : 0.0;
            member.setWeight(weight);
        }

        // add a new member if needed
        if (ensemble.getNumberOfMembers() < max_members) {
            PredictionModel model = null;
            if (initial_window == 1) {
                model = (PredictionModel) learner.learn(currentExample);
            } else {
                Condition initialCondition = new AttributeValueFilter(currentExampleSet,
                        idAttribute.getName() + " >= " + getLastIDFromWindow(currentExampleSet, initial_window));
                ConditionedExampleSet initialExampleSet = new ConditionedExampleSet(currentExampleSet, initialCondition);
                model = (PredictionModel) learner.learn(initialExampleSet);
            }
            EnsembleMember newMember = new EnsembleMember();
            newMember.setWeight(0);
            newMember.setPositive(0);
            newMember.setNegative(0);
            newMember.setIntroducedAt(mostRecentExample);
            newMember.setState(MemberState.UNSTABLE);
            newMember.setModel(model);
            ensemble.addMember(newMember);
        }

        return ensemble;
    }

    /**
     * Returns the id of the {@code window}-th example delivered by a
     * {@link ReverseSortedExampleReader} over the given example set.
     *
     * @param exampleSet the example set to read
     * @param window the number of examples to read
     * @return the id of the last example read; if the example set holds fewer than {@code window}
     *         examples the id of its last delivered example, and 0 if nothing was read
     */
    protected static int getLastIDFromWindow(ExampleSet exampleSet, int window) {
        int windowId = 0;
        int counter = 0;
        ExampleReader reader = new ReverseSortedExampleReader(exampleSet);
        while (reader.hasNext()) {
            windowId = (int) reader.next().getId();
            counter++;
            if (counter == window) {
                break;
            }
        }
        return windowId;
    }

    /**
     * Creates a prediction attribute based on the given label, adds it to the example table and
     * registers it as predicted label of the example set.
     *
     * @param exampleSet the example set that receives the attribute
     * @param label the label attribute the prediction attribute is derived from
     * @return the newly created prediction attribute
     */
    protected static Attribute createPredictedLabel(ExampleSet exampleSet, Attribute label) {
        Attribute predictedLabel = AttributeFactory.createAttribute(label, Attributes.PREDICTION_NAME);
        predictedLabel.clearTransformations();
        ExampleTable table = exampleSet.getExampleTable();
        table.addAttribute(predictedLabel);
        exampleSet.getAttributes().setPredictedLabel(predictedLabel);
        return predictedLabel;
    }

    /**
     * Removes the given prediction attribute from the example table and from the attributes of
     * the example set.
     *
     * @param exampleSet the example set to clean up
     * @param predictedLabel the attribute to remove
     */
    protected static void removePredictedLabel(ExampleSet exampleSet, Attribute predictedLabel) {
        exampleSet.getExampleTable().removeAttribute(predictedLabel);
        exampleSet.getAttributes().remove(predictedLabel);
    }

    /**
     * Collects the ids of the leading examples that have not been seen yet.
     *
     * <p>The examples are read with a {@link ReverseSortedExampleReader}; reading stops at the
     * first example whose id is already contained in {@code seenIds}.</p>
     *
     * @param seenIds the ids that were already processed
     * @param exampleSet the example set to inspect
     * @return the unseen ids in reading order; empty if the first example read is already known
     */
    protected List<Integer> newExampleIds(Set<Integer> seenIds, ExampleSet exampleSet) {
        List<Integer> newIds = new LinkedList<Integer>();
        ExampleReader reader = new ReverseSortedExampleReader(exampleSet);
        while (reader.hasNext()) {
            int currId = (int) reader.next().getId();
            if (seenIds.contains(currId)) {
                break;
            }
            newIds.add(currId);
        }
        return newIds;
    }

    /**
     * Delegates the capability check to the inner learner.
     *
     * @param capability the capability to check
     * @return whether the inner learner supports the capability
     */
    public boolean supportsCapability(LearnerCapability capability) {
        Operator innerOperator = this.getOperator(0);
        return ((Learner) innerOperator).supportsCapability(capability);
    }

    @Override
    public int getMaxNumberOfInnerOperators() {
        return 1;
    }

    @Override
    public int getMinNumberOfInnerOperators() {
        return 1;
    }

    /**
     * Verifies that the inner operator is a learner.
     *
     * @throws UserError with code 127 if the inner operator does not implement {@link Learner}
     */
    public void checkInnerOperator() throws UserError {
        Operator innerOperator = this.getOperator(0);
        if (!(innerOperator instanceof Learner)) {
            throw new UserError(this, 127, "Inner operator is not a learner");
        }
    }

    /**
     * Adds the ensemble specific parameters to the parameters of the super class.
     *
     * @return the parameter types of the super class followed by the ensemble parameters
     */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();

        ParameterTypeInt max_members = new ParameterTypeInt(
                ENSEMBLE_MAX_MEMBERS,
                "maximum number of members in the ensemble",
                1,
                Integer.MAX_VALUE,
                5);
        max_members.setExpert(false);
        types.add(max_members);

        ParameterTypeInt min_members = new ParameterTypeInt(
                ENSEMBLE_MIN_MEMBERS,
                "minimum number of members in the ensemble",
                0,
                Integer.MAX_VALUE,
                0);
        min_members.setExpert(true);
        types.add(min_members);

        // NOTE(review): the description speaks of a weight, but the value is compared against
        // the ratio of a member in updateEnsemble -- confirm the wording.
        ParameterTypeDouble eviction_ratio = new ParameterTypeDouble(
                ENSEMBLE_EVICTION_RATIO,
                "minimum weight that ensemble members need",
                0,
                Double.MAX_VALUE,
                0.25);
        eviction_ratio.setExpert(false);
        types.add(eviction_ratio);

        ParameterTypeDouble local_threshold = new ParameterTypeDouble(
                ENSEMBLE_LOCAL_THRESHOLD,
                "threshold for the prediction error",
                0,
                Double.MAX_VALUE,
                5);
        local_threshold.setExpert(false);
        types.add(local_threshold);

        ParameterTypeInt initial_window = new ParameterTypeInt(
                ENSEMBLE_INITIAL_WINDOW,
                "initial size of the training window",
                1,
                Integer.MAX_VALUE,
                1);
        initial_window.setExpert(true);
        types.add(initial_window);

        ParameterTypeStringCategory similarity_measure = new ParameterTypeStringCategory(
                ENSEMBLE_SIMILARITY_MEASURE,
                "similarity measure to use",
                getMeasureNames(),
                getDefaultMeasureNames());
        similarity_measure.setExpert(true);
        types.add(similarity_measure);

        ParameterTypeBoolean discard_state = new ParameterTypeBoolean(
                ENSEMBLE_DISCARD_STATE,
                "discard the ensemble in the state file",
                false);
        discard_state.setExpert(false);
        types.add(discard_state);

        ParameterTypeFile state_file = new ParameterTypeFile(
                ENSEMBLE_STATE_FILE,
                "path to the ensemble state file",
                "mod",
                false);
        state_file.setExpert(false);
        types.add(state_file);

        return types;
    }
}