/*
* avenir: Predictive analytic based on Hadoop Map Reduce
* Author: Pranab Ghosh
*
* Licensed under the Apache License, Version 2.0 (the "License"); you
* may not use this file except in compliance with the License. You may
* obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. See the License for the specific language governing
* permissions and limitations under the License.
*/
package org.avenir.reinforce;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Set;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
import org.chombo.util.DynamicBean;
import org.chombo.util.Utility;
/**
 * Implements greedy multi-arm bandit reinforcement learning algorithms as a Hadoop job.
 * For every group of items in the input, a batch of items is selected for the current round
 * using either epsilon-greedy with a linearly / log-linearly decaying exploration probability,
 * or an Auer style epsilon-greedy rule.
 * @author pranab
 *
 */
public class GreedyRandomBandit extends Configured implements Tool {

    /**
     * Configures and runs the bandit selection job.
     * @param args args[0] is the input path, args[1] the output path
     * @return 0 when the job succeeds, 1 otherwise
     * @throws Exception if job setup or execution fails
     */
    @Override
    public int run(String[] args) throws Exception {
        if (args.length < 2) {
            throw new IllegalArgumentException("expected arguments: <input path> <output path>");
        }
        Job job = new Job(getConf());
        String jobName = "Greedy random bandit problem";
        job.setJobName(jobName);
        job.setJarByClass(GreedyRandomBandit.class);
        FileInputFormat.addInputPath(job, new Path(args[0]));
        FileOutputFormat.setOutputPath(job, new Path(args[1]));
        Utility.setConfiguration(job.getConfiguration(), "avenir");
        job.setMapperClass(GreedyRandomBandit.BanditMapper.class);
        job.setOutputKeyClass(NullWritable.class);
        job.setOutputValueClass(Text.class);
        int status = job.waitForCompletion(true) ? 0 : 1;
        return status;
    }

    /**
     * Mapper that accumulates the items of a group from consecutive input rows and, when the
     * group ID changes or the input ends, selects a batch of items for that group and emits
     * one line per selected item ("groupID,itemID") or per distinct selected item with its
     * selection count ("groupID,itemID,count") when output.decision.count is set.
     * Input rows are split on field.delim.regex; field 0 is the group ID, field 1 the item ID,
     * and count.ordinal / reward.ordinal locate the integer count and reward fields.
     * NOTE: grouping relies on rows of one group being contiguous in this mapper's input; a
     * group whose rows are not adjacent is processed once per contiguous run.
     * @author pranab
     *
     */
    public static class BanditMapper extends Mapper<LongWritable, Text, NullWritable, Text> {
        private String fieldDelimRegex;
        private String fieldDelim;
        private String[] items;
        private Text outVal = new Text();
        private int roundNum;
        private float randomSelectionProb;
        private String probRedAlgorithm;
        private String curGroupID = null;
        private String groupID;
        private int countOrdinal;
        private int rewardOrdinal;
        private static final String PROB_RED_LINEAR = "linear";
        private static final String PROB_RED_LOG_LINEAR = "logLinear";
        private float probReductionConstant;
        private static final String AUER_GREEDY = "AuerGreedy";
        private static final String ITEM_ID = "itemID";
        private static final String ITEM_COUNT = "count";
        private static final String ITEM_REWARD = "reward";
        private Map<String, Integer> groupBatchCount = new HashMap<String, Integer>();
        private int auerGreedyConstant;
        private GroupedItems groupedItems = new GroupedItems();
        private int globalBatchSize;
        private boolean selectionUnique;
        private int minReward;
        private boolean outputDecisionCount;

        /**
         * Reads and validates the job configuration.
         * @throws IllegalStateException if the selection algorithm is unknown, the count/reward
         *         field ordinals are missing, or no batch size is configured
         */
        @Override
        protected void setup(Context context) throws IOException, InterruptedException {
            Configuration conf = context.getConfiguration();
            fieldDelimRegex = conf.get("field.delim.regex", ",");
            fieldDelim = conf.get("field.delim", ",");
            roundNum = Utility.assertIntConfigParam(conf, "current.round.num", "missing round number config param");
            randomSelectionProb = conf.getFloat("random.selection.prob", (float)0.5);
            probRedAlgorithm = conf.get("prob.reduction.algorithm", PROB_RED_LINEAR);
            if (!PROB_RED_LINEAR.equals(probRedAlgorithm) && !PROB_RED_LOG_LINEAR.equals(probRedAlgorithm)
                    && !AUER_GREEDY.equals(probRedAlgorithm)) {
                // fail fast instead of hitting a null selection list for the first group
                throw new IllegalStateException("invalid prob.reduction.algorithm: " + probRedAlgorithm);
            }
            probReductionConstant = conf.getFloat("prob.reduction.constant", (float)1.0);
            countOrdinal = conf.getInt("count.ordinal", -1);
            rewardOrdinal = conf.getInt("reward.ordinal", -1);
            if (countOrdinal < 0 || rewardOrdinal < 0) {
                // both fields are parsed for every row in collectGroupItems()
                throw new IllegalStateException("count.ordinal and reward.ordinal config params must be non negative field indexes");
            }
            auerGreedyConstant = conf.getInt("auer.greedy.constant", 5);
            selectionUnique = conf.getBoolean("selection.unique", false);
            minReward = conf.getInt("min.reward", 5);
            outputDecisionCount = conf.getBoolean("output.decision.count", false);

            // batch size is the number of items selected in each round for each group: either one
            // global value, or per group values read from lines of the form "groupID,batchSize"
            globalBatchSize = conf.getInt("global.batch.size", -1);
            if (globalBatchSize < 0) {
                List<String[]> lines = Utility.parseFileLines(conf, "group.item.count.path", ",");
                if (lines.isEmpty()) {
                    throw new IllegalStateException("either global batch size or groupwise batch size needs to be defined");
                }
                for (String[] line : lines) {
                    groupBatchCount.put(line[0], Integer.parseInt(line[1]));
                }
            }
        }

        /**
         * Selects items for the last group, which is only complete once the input is exhausted.
         * Does nothing when this mapper received no rows.
         */
        @Override
        protected void cleanup(Context context) throws IOException, InterruptedException {
            if (null != curGroupID) {
                select(context);
            }
        }

        /**
         * Adds the row's item to the current group; on a group ID change, first runs selection
         * for the finished group and resets the collected items.
         */
        @Override
        protected void map(LongWritable key, Text value, Context context)
            throws IOException, InterruptedException {
            items = value.toString().split(fieldDelimRegex);
            groupID = items[0];
            if (null != curGroupID && !groupID.equals(curGroupID)) {
                // previous group is complete: select for it, then start collecting the new one
                select(context);
                groupedItems.initialize();
            }
            curGroupID = groupID;
            collectGroupItems();
        }

        /**
         * @return the batch size for the current group: the global batch size when no per group
         *         sizes were loaded, otherwise the current group's configured size
         * @throws IllegalStateException if per group sizes are in use but the current group has none
         */
        private int getBatchSize() {
            if (groupBatchCount.isEmpty()) {
                return globalBatchSize;
            }
            Integer batchSize = groupBatchCount.get(curGroupID);
            if (null == batchSize) {
                throw new IllegalStateException("no batch size defined for group " + curGroupID);
            }
            return batchSize;
        }

        /**
         * Adds the item of the current row (ID, count, reward) to the grouped items.
         */
        private void collectGroupItems() {
            groupedItems.createtem(items[1], Integer.parseInt(items[countOrdinal]), Integer.parseInt(items[rewardOrdinal]));
        }

        /**
         * Selects a batch of items from the current group with the configured algorithm and
         * writes the selection to the output.
         * @throws InterruptedException
         * @throws IOException
         */
        private void select(Context context) throws IOException, InterruptedException {
            List<String> selItems;
            if (probRedAlgorithm.equals(PROB_RED_LINEAR)) {
                selItems = linearSelect(context, false);
            } else if (probRedAlgorithm.equals(PROB_RED_LOG_LINEAR)) {
                selItems = linearSelect(context, true);
            } else if (probRedAlgorithm.equals(AUER_GREEDY)) {
                selItems = greedyAuerSelect(context);
            } else {
                throw new IllegalStateException("invalid prob.reduction.algorithm: " + probRedAlgorithm);
            }

            if (outputDecisionCount) {
                // one line per distinct item with the number of times it was selected
                Map<String, Integer> decisionCount = new HashMap<String, Integer>();
                for (String item : selItems) {
                    Integer itemCount = decisionCount.get(item);
                    decisionCount.put(item, null == itemCount ? 1 : itemCount + 1);
                }
                for (Map.Entry<String, Integer> entry : decisionCount.entrySet()) {
                    outVal.set(curGroupID + fieldDelim + entry.getKey() + fieldDelim + entry.getValue());
                    context.write(NullWritable.get(), outVal);
                }
            } else {
                for (String item : selItems) {
                    outVal.set(curGroupID + fieldDelim + item);
                    context.write(NullWritable.get(), outVal);
                }
            }
        }

        /**
         * Epsilon-greedy selection where the exploration probability decays with the trial count:
         * randomSelectionProb * probReductionConstant / n (linear) or
         * randomSelectionProb * probReductionConstant * ln(n) / n (log linear),
         * never exceeding randomSelectionProb. The trial count n starts from
         * (roundNum - 1) * batchSize.
         * @param logLinear true for log linear decay, false for linear decay
         * @return selected item IDs; distinct when selection.unique is set
         * @throws InterruptedException
         * @throws IOException
         */
        private List<String> linearSelect(Context context, boolean logLinear) throws IOException, InterruptedException {
            List<String> selected = new ArrayList<String>();
            int batchSize = getBatchSize();
            if (selectionUnique && batchSize > groupedItems.size()) {
                // cannot pick more distinct items than the group holds; without this cap the
                // uniqueness retry loop below never terminates
                batchSize = groupedItems.size();
            }
            int count = (roundNum - 1) * batchSize;
            for (int i = 0; i < batchSize; ++i) {
                ++count;
                float curProb;
                if (logLinear) {
                    curProb = (float)(randomSelectionProb * probReductionConstant * Math.log(count) / count);
                } else {
                    curProb = randomSelectionProb * probReductionConstant / count;
                }
                curProb = curProb <= randomSelectionProb ? curProb : randomSelectionProb;

                String itemID = linearSelectHelper(curProb, context);
                if (selectionUnique) {
                    while (selected.contains(itemID)) {
                        // duplicate pick: retry with a pure random pick so termination does not
                        // depend on the (possibly tiny) decayed exploration probability
                        itemID = linearSelectHelper(1.0f, context);
                    }
                }
                selected.add(itemID);
            }
            return selected;
        }

        /**
         * Auer style epsilon-greedy selection. Items never tried are selected first; each
         * remaining slot is then filled by a random item with probability
         * min(1, c * K / (d^2 * n)) and by the best item so far otherwise, where c is
         * auer.greedy.constant, K the number of items in the group, n the trial count and d the
         * reward gap between the best and second best item relative to the best reward.
         * @return selected item IDs (may contain repeats)
         * @throws IOException
         * @throws InterruptedException
         */
        private List<String> greedyAuerSelect(Context context) throws IOException, InterruptedException {
            List<String> selected = new ArrayList<String>();
            int batchSize = getBatchSize();
            int count = (roundNum - 1) * batchSize;
            int groupCount = groupedItems.size();

            while (selected.size() < batchSize) {
                // clear all use counts and start over
                groupedItems.clearAllUseCount();

                // items not tried before are always selected first
                List<DynamicBean> collectedItems = groupedItems.collectItemsNotTried(batchSize);
                count += collectedItems.size();
                for (DynamicBean it : collectedItems) {
                    selected.add(it.getString(ITEM_ID));
                    groupedItems.select(it, minReward);
                }

                // fill the rest of the batch with the greedy rule
                while (selected.size() < batchSize) {
                    DynamicBean maxRewardItem = groupedItems.getMaxRewardItem();
                    // temporarily take the best item out to find the runner up
                    groupedItems.remove(maxRewardItem);
                    DynamicBean nextMaxRewardItem = groupedItems.getMaxRewardItem();
                    groupedItems.add(maxRewardItem);

                    double prob = auerExplorationProb(maxRewardItem, nextMaxRewardItem, groupCount, count);
                    DynamicBean selectedItem;
                    if (Math.random() < prob) {
                        // explore
                        selectedItem = groupedItems.selectRandom();
                    } else {
                        // exploit: the one with best reward
                        selectedItem = groupedItems.select(maxRewardItem);
                    }
                    selected.add(selectedItem.getString(ITEM_ID));
                    groupedItems.select(selectedItem, minReward);
                    ++count;
                }
            }
            return selected;
        }

        /**
         * Exploration probability for the Auer style rule: min(1, c * K / (d^2 * n)).
         * @param maxRewardItem best item so far
         * @param nextMaxRewardItem runner up, or null when there is none
         * @param groupCount number of items in the group (K)
         * @param count trial count (n)
         * @return 0 when there is no runner up (nothing to explore), 1 when the best two rewards
         *         tie, otherwise the formula value capped at 1
         */
        private double auerExplorationProb(DynamicBean maxRewardItem, DynamicBean nextMaxRewardItem,
                int groupCount, int count) {
            if (null == nextMaxRewardItem) {
                return 0.0;
            }
            int maxReward = maxRewardItem.getInt(ITEM_REWARD);
            int nextMaxReward = nextMaxRewardItem.getInt(ITEM_REWARD);
            if (maxReward == nextMaxReward) {
                return 1.0;
            }
            double rewardDiff = (double)(maxReward - nextMaxReward) / maxReward;
            // double arithmetic for the numerator avoids int overflow of c * K
            double prob = (double)auerGreedyConstant * groupCount / (rewardDiff * rewardDiff * count);
            return prob > 1.0 ? 1.0 : prob;
        }

        /**
         * Selects one item: a random item with probability curProb, otherwise the item with the
         * best reward so far (or a random item when no best item is available).
         * @param curProb probability of a random (exploratory) selection
         * @param context
         * @return the selected item ID
         * @throws IOException
         * @throws InterruptedException
         */
        private String linearSelectHelper(float curProb, Context context) throws IOException, InterruptedException {
            groupedItems.clearAllUseCount();
            DynamicBean selItem = null;
            // Math.random() is in [0, 1): exploit with probability 1 - curProb
            if (Math.random() >= curProb) {
                selItem = groupedItems.getMaxRewardItem();
            }
            if (null == selItem) {
                // exploring, or nothing tried yet: choose randomly
                selItem = groupedItems.selectRandom();
            }
            String itemID = selItem.getString(ITEM_ID);
            groupedItems.select(selItem, minReward);
            return itemID;
        }
    }

    /**
     * Command line entry point.
     * @param args input path and output path, optionally preceded by Hadoop generic options
     * @throws Exception
     */
    public static void main(String[] args) throws Exception {
        int exitCode = ToolRunner.run(new GreedyRandomBandit(), args);
        System.exit(exitCode);
    }
}