/* * avenir: Predictive analytic based on Hadoop Map Reduce * Author: Pranab Ghosh * * Licensed under the Apache License, Version 2.0 (the "License"); you * may not use this file except in compliance with the License. You may * obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. See the License for the specific language governing * permissions and limitations under the License. */ package storm.applications.model.learner; import java.util.HashMap; import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import storm.applications.constants.ReinforcementLearnerConstants.Conf; import storm.applications.util.config.Configuration; import storm.applications.util.math.HistogramStat; /** * Interval estimator reinforcement learner based on confidence bound * @author pranab * */ public class IntervalEstimator extends ReinforcementLearner { private static final Logger LOG = LoggerFactory.getLogger(IntervalEstimator.class); private int binWidth; private int confidenceLimit; private int minConfidenceLimit; private int curConfidenceLimit; private int confidenceLimitReductionStep; private int confidenceLimitReductionRoundInterval; private int minDistrSample; private int lastRoundNum = 1; private long randomSelectCount; private long intvEstSelectCount; private long logCounter; private long roundCounter; private boolean lowSample = true; private boolean debugOn; private Map<String, HistogramStat> rewardDistr = new HashMap<>(); @Override public void initialize(Configuration config) { binWidth = config.getInt(Conf.BIN_WIDTH); confidenceLimit = config.getInt(Conf.CONFIDENCE_LIMIT); minConfidenceLimit = config.getInt(Conf.MIN_CONFIDENCE_LIMIT); curConfidenceLimit = confidenceLimit; confidenceLimitReductionStep = config.getInt(Conf.CONFIDENCE_LIMIT_RED_STEP); confidenceLimitReductionRoundInterval = config.getInt(Conf.CONFIDENCE_LIMIT_RED_ROUND_INT); minDistrSample = config.getInt(Conf.MIN_DIST_SAMPLE); debugOn = config.getBoolean(Conf.DEBUG_ON, false); for (String action : actions) { rewardDistr.put(action, new HistogramStat(binWidth)); } initSelectedActions(); if (debugOn) { LOG.info("confidenceLimit:" + confidenceLimit + " minConfidenceLimit:" + minConfidenceLimit + " confidenceLimitReductionStep:" + confidenceLimitReductionStep + " confidenceLimitReductionRoundInterval:" + confidenceLimitReductionRoundInterval + " minDistrSample:" + minDistrSample); } } @Override public String[] nextActions(int roundNum) { String selAction = null; ++logCounter; ++roundCounter; //make sure reward distributions have enough sample if (lowSample) { lowSample = false; for (String action : rewardDistr.keySet()) { int sampleCount = rewardDistr.get(action).getCount(); if (debugOn && logCounter % 100 == 0) { LOG.info("action:" + action + " distr sampleCount: " + sampleCount); } if (sampleCount < minDistrSample) { lowSample = true; break; } } if (!lowSample && debugOn) { LOG.info("got full sample"); lastRoundNum = roundNum; } } if (lowSample) { //select randomly selAction = actions[(int)(Math.random() * actions.length)]; ++randomSelectCount; } else { //reduce confidence limit adjustConfLimit(roundNum); //select as per interval estimate, choosing distr with max upper conf bound int maxUpperConfBound = 0; for (String action : rewardDistr.keySet()) { HistogramStat stat = rewardDistr.get(action); int[] confBounds = stat.getConfidenceBounds(curConfidenceLimit); if (debugOn) { LOG.info("curConfidenceLimit:" + curConfidenceLimit + " action:" + action + " conf bounds:" + confBounds[0] + " " + confBounds[1]); } if (confBounds[1] > maxUpperConfBound) { maxUpperConfBound = confBounds[1]; selAction = action; } } ++intvEstSelectCount; } selActions[0] = selAction; return selActions; } /** * @param roundNum */ private void adjustConfLimit(int roundNum) { if (curConfidenceLimit > minConfidenceLimit) { int redStep = (roundNum - lastRoundNum) / confidenceLimitReductionRoundInterval; if (debugOn) { LOG.info("redStep:" + redStep + " roundNum:" + roundNum + " lastRoundNum:" + lastRoundNum); } if (redStep > 0) { curConfidenceLimit -= (redStep * confidenceLimitReductionStep); if (curConfidenceLimit < minConfidenceLimit) { curConfidenceLimit = minConfidenceLimit; } if (debugOn) { LOG.info("reduce conf limit roundNum:" + roundNum + " lastRoundNum:" + lastRoundNum); } lastRoundNum = roundNum; } } } @Override public void setReward(String action, int reward) { HistogramStat stat = rewardDistr.get(action); if (null == stat) { throw new IllegalArgumentException("invalid action:" + action); } stat.add(reward); if (debugOn) { LOG.info("setReward action:" + action + " reward:" + reward + " sample count:" + stat.getCount()); } } @Override public String getStat() { return "randomSelectCount:" + randomSelectCount + " intvEstSelectCount:" + intvEstSelectCount; } }