/*
* avenir: Predictive analytic based on Hadoop Map Reduce
* Author: Pranab Ghosh
*
* Licensed under the Apache License, Version 2.0 (the "License"); you
* may not use this file except in compliance with the License. You may
* obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. See the License for the specific language governing
* permissions and limitations under the License.
*/
package org.avenir.reinforce;
import java.util.HashMap;
import java.util.Map;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.chombo.util.HistogramStat;
import org.chombo.util.ConfigUtility;
import org.chombo.util.Utility;
/**
 * Interval estimator reinforcement learner based on confidence bound.
 *
 * <p>The learner keeps one reward histogram per action. While any histogram
 * holds fewer than {@code min.reward.distr.sample} samples, actions are picked
 * at random. Once every histogram has enough samples, the action whose reward
 * distribution has the largest upper confidence bound is picked, and the
 * confidence limit is stepped down towards {@code min.confidence.limit} as
 * trials accumulate.
 *
 * @author pranab
 *
 */
public class IntervalEstimatorLearner extends ReinforcementLearner {
    private static final Logger LOG = Logger.getLogger(IntervalEstimatorLearner.class);

    /** Reward histogram per action id, populated in {@link #initialize(Map)}. */
    private final Map<String, HistogramStat> rewardDistr = new HashMap<String, HistogramStat>();

    private int binWidth;
    private int confidenceLimit;
    private int minConfidenceLimit;
    private int curConfidenceLimit;
    private int confidenceLimitReductionStep;
    private int confidenceLimitReductionRoundInterval;
    private int minDistrSample;
    /** Trial count at which the confidence limit schedule was last (re)started. */
    private long lastRoundNum = 1;
    private long randomSelectCount;
    private long intvEstSelectCount;
    private boolean debugOn;
    private long logCounter;
    /** True until every reward histogram has reached {@code minDistrSample} samples. */
    private boolean lowSample = true;

    /**
     * Reads the learner configuration and creates an empty reward histogram
     * for each action known to the base class.
     *
     * @param config configuration; read keys are {@code bin.width},
     *     {@code confidence.limit}, {@code min.confidence.limit},
     *     {@code confidence.limit.reduction.step},
     *     {@code confidence.limit.reduction.round.interval},
     *     {@code min.reward.distr.sample} and the optional {@code debug.on}
     */
    @Override
    public void initialize(Map<String, Object> config) {
        super.initialize(config);
        binWidth = ConfigUtility.getInt(config, "bin.width");
        confidenceLimit = ConfigUtility.getInt(config, "confidence.limit");
        minConfidenceLimit = ConfigUtility.getInt(config, "min.confidence.limit");
        curConfidenceLimit = confidenceLimit;
        confidenceLimitReductionStep = ConfigUtility.getInt(config, "confidence.limit.reduction.step");
        confidenceLimitReductionRoundInterval =
                ConfigUtility.getInt(config, "confidence.limit.reduction.round.interval");
        minDistrSample = ConfigUtility.getInt(config, "min.reward.distr.sample");

        for (Action action : actions) {
            rewardDistr.put(action.getId(), new HistogramStat(binWidth));
        }

        debugOn = ConfigUtility.getBoolean(config, "debug.on", false);
        if (debugOn) {
            LOG.setLevel(Level.INFO);
            LOG.info("confidenceLimit:" + confidenceLimit + " minConfidenceLimit:" + minConfidenceLimit +
                    " confidenceLimitReductionStep:" + confidenceLimitReductionStep + " confidenceLimitReductionRoundInterval:" +
                    confidenceLimitReductionRoundInterval + " minDistrSample:" + minDistrSample);
        }
    }

    /* (non-Javadoc)
     * @see org.avenir.reinforce.ReinforcementLearner#nextAction()
     */
    @Override
    public Action nextAction() {
        ++logCounter;
        ++totalTrialCount;

        //make sure reward distributions have enough sample
        if (lowSample) {
            lowSample = anyDistrUnderSampled();
            if (!lowSample) {
                // Start the confidence limit schedule from this trial. This must not
                // depend on the debug flag, otherwise the random warm up trials are
                // counted as reduction rounds when debugging is off.
                lastRoundNum = totalTrialCount;
                if (debugOn) {
                    LOG.info("got full sample");
                }
            }
        }

        Action selAction;
        if (lowSample) {
            //select randomly
            selAction = Utility.selectRandom(actions);
            ++randomSelectCount;
        } else {
            //reduce confidence limit
            adjustConfLimit();

            //select as per interval estimate, choosing distr with max upper conf bound
            selAction = findAction(findMaxUpperConfBoundAction());
            ++intvEstSelectCount;
        }
        selAction.select();
        return selAction;
    }

    /**
     * Checks whether at least one reward histogram still has fewer samples
     * than {@code minDistrSample}. Stops at the first such histogram.
     *
     * @return true if some action is under sampled
     */
    private boolean anyDistrUnderSampled() {
        for (Map.Entry<String, HistogramStat> entry : rewardDistr.entrySet()) {
            int sampleCount = entry.getValue().getCount();
            if (debugOn && logCounter % 100 == 0) {
                LOG.info("action:" + entry.getKey() + " distr sampleCount: " + sampleCount);
            }
            if (sampleCount < minDistrSample) {
                return true;
            }
        }
        return false;
    }

    /**
     * Finds the action whose reward histogram has the largest upper confidence
     * bound at the current confidence limit. The first candidate is always
     * accepted, so an action is found even when every upper bound is zero or
     * negative.
     *
     * @return id of the selected action, or null if there are no reward histograms
     */
    private String findMaxUpperConfBoundAction() {
        int maxUpperConfBound = Integer.MIN_VALUE;
        String selActionId = null;
        for (Map.Entry<String, HistogramStat> entry : rewardDistr.entrySet()) {
            String action = entry.getKey();
            int[] confBounds = entry.getValue().getConfidenceBounds(curConfidenceLimit);
            if (debugOn) {
                LOG.info("curConfidenceLimit:" + curConfidenceLimit + " action:" + action + " conf bounds:" + confBounds[0] + " " + confBounds[1]);
            }
            if (null == selActionId || confBounds[1] > maxUpperConfBound) {
                maxUpperConfBound = confBounds[1];
                selActionId = action;
            }
        }
        return selActionId;
    }

    /**
     * Lowers the current confidence limit, never below the configured minimum.
     * One reduction step is applied for every
     * {@code confidenceLimitReductionRoundInterval} trials elapsed since
     * {@code lastRoundNum}; when a reduction happens {@code lastRoundNum} is
     * moved to the current trial count.
     */
    private void adjustConfLimit() {
        if (curConfidenceLimit <= minConfidenceLimit) {
            return;
        }

        int redStep = (int) ((totalTrialCount - lastRoundNum) / confidenceLimitReductionRoundInterval);
        if (debugOn) {
            LOG.info("redStep:" + redStep + " roundNum:" + totalTrialCount + " lastRoundNum:" + lastRoundNum);
        }
        if (redStep > 0) {
            curConfidenceLimit = Math.max(
                    curConfidenceLimit - redStep * confidenceLimitReductionStep, minConfidenceLimit);
            if (debugOn) {
                LOG.info("reduce conf limit roundNum:" + totalTrialCount + " lastRoundNum:" + lastRoundNum);
            }
            lastRoundNum = totalTrialCount;
        }
    }

    /**
     * Records a reward for an action, both in its reward histogram and on the
     * action itself.
     *
     * @param action id of the rewarded action
     * @param reward reward value
     * @throws IllegalArgumentException if no reward histogram exists for the action id
     */
    @Override
    public void setReward(String action, int reward) {
        HistogramStat stat = rewardDistr.get(action);
        if (null == stat) {
            throw new IllegalArgumentException("invalid action:" + action);
        }
        stat.add(reward);
        findAction(action).reward(reward);
        if (debugOn) {
            LOG.info("setReward action:" + action + " reward:" + reward + " sample count:" + stat.getCount());
        }
    }

    /**
     * Reports how many selections were made randomly and how many by
     * interval estimate.
     *
     * @return selection counters formatted as a single string
     */
    public String getStat() {
        return "randomSelectCount:" + randomSelectCount + " intvEstSelectCount:" + intvEstSelectCount;
    }
}