package quickml.supervised.ensembles.randomForest.randomRegressionForest; import com.google.common.base.Preconditions; import com.google.common.collect.Maps; import com.google.common.util.concurrent.AtomicDouble; import quickml.data.AttributesMap; import quickml.data.PredictionMap; import quickml.supervised.classifier.AbstractClassifier; import quickml.supervised.ensembles.randomForest.RandomForest; import quickml.supervised.tree.regressionTree.RegressionTree; import java.io.Serializable; import java.util.*; /** * Created with IntelliJ IDEA. * User: ian * Date: 4/18/13 * Time: 4:17 PM * To change this template use File | Settings | File Templates. */ public class RandomRegressionForest implements RandomForest<Double, RegressionTree> { static final long serialVersionUID = 56394564395638954L; public final List<RegressionTree> regressionTrees; protected RandomRegressionForest(List<RegressionTree> regressionTrees) { Preconditions.checkArgument(regressionTrees.size() > 0, "We must have at least one oldTree"); this.regressionTrees = regressionTrees; } @Override public Double predictWithoutAttributes(AttributesMap attributes, Set<String> attributesToIgnore) { double total = 0; for (RegressionTree regressionTree : regressionTrees) { final double value = regressionTree.predictWithoutAttributes(attributes, attributesToIgnore); if (Double.isInfinite(value) || Double.isNaN(value)) { throw new RuntimeException("Probability must be a normal number, not "+value); } total += value; } return total / regressionTrees.size(); } @Override public Double predict(AttributesMap attributes) { double total = 0; for (RegressionTree regressionTree : regressionTrees) { final double value = regressionTree.predict(attributes); if (Double.isInfinite(value) || Double.isNaN(value)) { throw new RuntimeException("Probability must be a normal number, not "+value); } total += value; } return total / regressionTrees.size(); } @Override public boolean equals(final Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; final RandomRegressionForest that = (RandomRegressionForest) o; if (!regressionTrees.equals(that.regressionTrees)) return false; return true; } @Override public int hashCode() { return regressionTrees.hashCode(); } }