package de.tud.inf.operator.learner.regressionensemble; import java.util.HashMap; import java.util.Iterator; import java.util.LinkedList; import java.util.List; import java.util.HashSet; import java.util.Set; import com.rapidminer.example.Attribute; import com.rapidminer.example.Example; import com.rapidminer.example.ExampleSet; import com.rapidminer.example.set.AttributeValueFilterSingleCondition; import com.rapidminer.example.set.Condition; import com.rapidminer.example.set.ConditionedExampleSet; import com.rapidminer.operator.OperatorException; import com.rapidminer.operator.learner.PredictionModel; import com.rapidminer.tools.Tools; @SuppressWarnings("unused") public class EnsembleRegressionModel extends PredictionModel implements Iterable<EnsembleMember> { private static final long serialVersionUID = 4075168877323509822L; private List<EnsembleMember> members; private Set<Integer> seenIds; //private ExampleSet exampleSet; protected EnsembleRegressionModel(ExampleSet exampleSet) { this(exampleSet, null); } protected EnsembleRegressionModel(ExampleSet exampleSet, EnsembleMember[] initMembers) { super(exampleSet); members = new LinkedList<EnsembleMember>(); if(initMembers != null) { for (EnsembleMember member : initMembers) { members.add(member); } } //this.exampleSet = exampleSet; this.seenIds = new HashSet<Integer>(); } public Set<Integer> getSeenIds() { return seenIds; } public void setSeenIds(Set<Integer> seenIds) { this.seenIds = seenIds; } public boolean addMember(EnsembleMember member) { return members.add(member); } public void deleteMember(int index) { members.remove(index); } public void deleteMember(EnsembleMember member) { members.remove(member); } public int getNumberOfMembers() { return members.size(); } public EnsembleMember getMember(int index) { return members.get(index); } /** * assign labels to the examples in the passed example set */ @Override public ExampleSet performPrediction(ExampleSet exampleSet, Attribute predictedLabel) throws OperatorException { // Attribute idAttribute = exampleSet.getAttributes().getId(); // if(idAttribute == null) { // throw new OperatorException("id-Attribute missing!"); // } HashMap<Integer, Double> predictions = new HashMap<Integer, Double>(exampleSet.size()); // initialize Iterator<Example> initIter = exampleSet.iterator(); while(initIter.hasNext()) { Example currentExample = initIter.next(); int currentId = (int) currentExample.getId(); predictions.put(currentId, 0.0); } for(EnsembleMember member : members) { // skip unstable members if(member.getState() == MemberState.UNSTABLE) { continue; } // let the member model write it's prediction into the example set member.getModel().performPrediction(exampleSet, predictedLabel); // gather the predictions Iterator<Example> iter = exampleSet.iterator(); while(iter.hasNext()) { Example currentExample = iter.next(); int currentId = (int) currentExample.getId(); double weight = member.getWeight(); double memberprediction = currentExample.getNumericalValue(predictedLabel); double predictionSoFar = predictions.get(currentId); predictions.put(currentId, predictionSoFar + memberprediction * weight); } } return exampleSet; } // /** // * assign labels to the examples in the passed example set // */ // @Override // public ExampleSet performPrediction(ExampleSet exampleSet, Attribute predictedLabel) throws OperatorException { // Attribute idAttribute = exampleSet.getAttributes().getId(); // // if(idAttribute == null) { // throw new OperatorException("id-Attribute missing!"); // } // // Iterator<Example> iter = exampleSet.iterator(); // while(iter.hasNext()) { // Example currentExample = iter.next(); // int currentId = (int) currentExample.getNumericalValue(idAttribute); // double predictionForExample = 0; // // Condition currentExampleCondition = new AttributeValueFilterSingleCondition( // idAttribute, // AttributeValueFilterSingleCondition.EQUALS, // Integer.toString(currentId)); // // ConditionedExampleSet currentExampleSet = new ConditionedExampleSet(exampleSet, currentExampleCondition); // // log(Integer.toString(currentExampleSet.size())); // // for(EnsembleMember member : members) { // // skip unstable members // if(member.getState() == MemberState.UNSTABLE) { // continue; // } // // // let the member model write it's prediction into the example set // member.getModel().performPrediction(currentExampleSet, predictedLabel); // // // the condition should assure that only on example qualifies // double memberprediction = currentExampleSet.getExample(0).getNumericalValue(predictedLabel); // // predictionForExample += memberprediction * member.getWeight(); // } // // //currentExampleSet.getExample(0).setValue(predictedLabel, predictionForExample); // } // // return exampleSet; // } @Override public String toString() { StringBuffer result = new StringBuffer(); for (EnsembleMember member : members) { PredictionModel model = member.getModel(); result.append(model.getName()); result.append(Tools.getLineSeparator()); result.append(model.toString()); result.append(Tools.getLineSeparator()); } return result.toString(); } public Iterator<EnsembleMember> iterator() { return members.iterator(); } // public ExampleSet getExampleSet() { // return exampleSet; //} // //public void setExampleSet(ExampleSet exampleSet) { // this.exampleSet = exampleSet; //} }