/* * avenir: Predictive analytic based on Hadoop Map Reduce * Author: Pranab Ghosh * * Licensed under the Apache License, Version 2.0 (the "License"); you * may not use this file except in compliance with the License. You may * obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. See the License for the specific language governing * permissions and limitations under the License. */ package org.avenir.model; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import org.chombo.util.Pair; /** * @author pranab * */ public class EnsemblePredictiveModel extends PredictiveModel { private List<WeightedModel> models = new ArrayList<WeightedModel>(); private Map<String, Double> votes = new HashMap<String, Double>(); private List<VoteCount> sortedVotes = new ArrayList<VoteCount>(); private double minOddsRatio = -1.0; public EnsemblePredictiveModel() { super(); } /** * @param minOdds * @return */ public EnsemblePredictiveModel withMinOdds(double minOddsRatio) { this.minOddsRatio = minOddsRatio; return this; } /** * @param model */ public void addModel(PredictiveModel model) { addModel(model, 1.0); } /** * @param model */ public void addModel(PredictiveModel model, double weight) { models.add(new WeightedModel(model, weight)); } /* (non-Javadoc) * @see org.avenir.model.PredictiveModel#predict(java.lang.String[]) */ @Override public String predict(String[] items) { if (models.size() % 2 == 0) { throw new IllegalStateException("neem odd number of models in ensemble"); } //get votes votes.clear(); for (WeightedModel weightedModel : models) { PredictiveModel model = weightedModel.getLeft(); double weight = weightedModel.getRight(); String predClass = model.predict(items); Double count = votes.get(predClass); if (null == count) { votes.put(predClass, weight); } else { votes.put(predClass, count+weight); } } //sort by vote count sortedVotes.clear(); for (String prClass : votes.keySet()) { Double voteCount = votes.get(prClass); sortedVotes.add(new VoteCount(prClass, voteCount)); } Collections.sort(sortedVotes); if (minOddsRatio > 1.0) { //null implies ambiguous double oddsRatio = sortedVotes.get(0).getRight() / sortedVotes.get(1).getRight(); predClass = oddsRatio > minOddsRatio ? sortedVotes.get(0).getLeft() : null; } else { //select max vote predClass = sortedVotes.get(0).getLeft(); } if (errorCountingEnabled) { countError(); } return predClass; } @Override protected Pair<String, Double> predictClassProb(String[] items) { return null; } private static class WeightedModel extends Pair<PredictiveModel, Double> { public WeightedModel(PredictiveModel model, double weight) { super(model, weight); } } /** * @author pranab * */ private static class VoteCount extends Pair<String, Double> implements Comparable<VoteCount> { /** * @param predClass * @param voteCount */ public VoteCount(String predClass, double voteCount) { super(predClass, voteCount); } @Override public int compareTo(VoteCount that) { return that.right.compareTo(this.right); } } }