/*
* avenir: Predictive analytic based on Hadoop Map Reduce
* Author: Pranab Ghosh
*
* Licensed under the Apache License, Version 2.0 (the "License"); you
* may not use this file except in compliance with the License. You may
* obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. See the License for the specific language governing
* permissions and limitations under the License.
*/
package org.avenir.reinforce;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
import org.chombo.util.DynamicBean;
import org.chombo.util.Utility;
/**
 * Deterministic Auer MAB (multi armed bandit) algorithm, run as a map only
 * Hadoop job. Input records are grouped by group ID and carry an item ID,
 * a trial count and a reward; for each group a batch of items is selected
 * per round using the UCB1 value: reward/maxReward + sqrt(2 ln(count) / itemCount).
 *
 * @author pranab
 */
public class AuerDeterministic extends Configured implements Tool {

    /**
     * Configures and launches the map only selection job.
     *
     * @param args args[0] = input path, args[1] = output path
     * @return 0 on success, 1 on failure
     */
    @Override
    public int run(String[] args) throws Exception {
        Job job = new Job(getConf());
        String jobName = "Auer deterministic MAB";
        job.setJobName(jobName);

        job.setJarByClass(AuerDeterministic.class);

        FileInputFormat.addInputPath(job, new Path(args[0]));
        FileOutputFormat.setOutputPath(job, new Path(args[1]));

        //load avenir specific configuration into the job
        Utility.setConfiguration(job.getConfiguration(), "avenir");
        job.setMapperClass(AuerDeterministic.BanditMapper.class);

        job.setOutputKeyClass(NullWritable.class);
        job.setOutputValueClass(Text.class);

        int status = job.waitForCompletion(true) ? 0 : 1;
        return status;
    }

    /**
     * Mapper: accumulates items for the current group (input is expected to be
     * sorted by group ID) and emits the selected batch whenever a group ends.
     *
     * @author pranab
     */
    public static class BanditMapper extends Mapper<LongWritable, Text, NullWritable, Text> {
        private String fieldDelimRegex;
        private String fieldDelim;
        private String[] items;
        private Text outVal = new Text();
        //current learning round; scales the assumed number of prior trials
        private int roundNum;
        private String detAlgorithm;
        //group whose items are currently being accumulated; null until first record
        private String curGroupID = null;
        private String groupID;
        //field ordinals of trial count and reward within the input record
        private int countOrdinal;
        private int rewardOrdinal;
        private static final String AUER_DET_UBC1 = "AuerUBC1";
        private static final String ITEM_ID = "itemID";
        private static final String ITEM_COUNT = "count";
        private static final String ITEM_REWARD = "reward";
        //per group batch size, used only when no global batch size is configured
        private Map<String, Integer> groupBatchCount = new HashMap<String, Integer>();
        private GroupedItems groupedItems = new GroupedItems();
        private int globalBatchSize;
        //reward credited to an item when it gets selected
        private int minReward;
        //if true, emit one line per item with its selection count instead of one line per selection
        private boolean outputDecisionCount;
        private static final int GR_ORD = 0;
        private static final int IT_ORD = 1;

        /**
         * Reads configuration: delimiters, round number, algorithm name, field
         * ordinals and the batch size (global, or per group from a file).
         *
         * @see org.apache.hadoop.mapreduce.Mapper#setup(org.apache.hadoop.mapreduce.Mapper.Context)
         */
        protected void setup(Context context) throws IOException, InterruptedException {
            Configuration conf = context.getConfiguration();
            fieldDelimRegex = conf.get("field.delim.regex", ",");
            fieldDelim = conf.get("field.delim", ",");
            roundNum = Utility.assertIntConfigParam(conf, "current.round.num", "missing round number config param");
            detAlgorithm = conf.get("det.algorithm", AUER_DET_UBC1);
            countOrdinal = conf.getInt("count.ordinal", -1);
            rewardOrdinal = conf.getInt("reward.ordinal", -1);
            minReward = conf.getInt("min.reward", 5);
            outputDecisionCount = conf.getBoolean("output.decision.count", false);

            //batch size: either one global value or a per group size file
            globalBatchSize = conf.getInt("global.batch.size", -1);
            if (globalBatchSize < 0) {
                List<String[]> lines = Utility.parseFileLines(conf, "group.item.count.path", ",");
                if (lines.isEmpty()) {
                    throw new IllegalStateException("either global batch size or groupwise batch size needs to be defined");
                }
                for (String[] line : lines) {
                    //line format: groupID, batchSize
                    groupBatchCount.put(line[0], Integer.parseInt(line[1]));
                }
            }
        }

        /**
         * Emits the selection for the last group, if any input was seen.
         *
         * @see org.apache.hadoop.mapreduce.Mapper#cleanup(org.apache.hadoop.mapreduce.Mapper.Context)
         */
        protected void cleanup(Context context) throws IOException, InterruptedException {
            //guard against an empty split, where no group was ever started
            if (null != curGroupID) {
                select(context);
            }
        }

        @Override
        protected void map(LongWritable key, Text value, Context context)
            throws IOException, InterruptedException {
            items = value.toString().split(fieldDelimRegex);
            groupID = items[GR_ORD];
            if (null != curGroupID && !groupID.equals(curGroupID)) {
                //group boundary: emit selection for the finished group and reset
                select(context);
                groupedItems.initialize();
            }
            curGroupID = groupID;
            collectGroupItems();
        }

        /**
         * @return batch size for the current group
         * @throws IllegalStateException if per group sizes are in use and the
         *         current group has none defined
         */
        private int getBatchSize() {
            if (groupBatchCount.isEmpty()) {
                return globalBatchSize;
            }
            Integer batchSize = groupBatchCount.get(curGroupID);
            if (null == batchSize) {
                throw new IllegalStateException("no batch size defined for group " + curGroupID);
            }
            return batchSize;
        }

        /**
         * Adds the current record's item, with its trial count and reward,
         * to the current group.
         */
        private void collectGroupItems() {
            groupedItems.createtem(items[IT_ORD], Integer.parseInt(items[countOrdinal]),
                Integer.parseInt(items[rewardOrdinal]));
        }

        /**
         * Dispatches to the configured deterministic selection algorithm.
         *
         * @param context mapper context for emitting output
         * @throws InterruptedException
         * @throws IOException
         */
        private void select(Context context) throws IOException, InterruptedException {
            if (detAlgorithm.equals(AUER_DET_UBC1)) {
                deterministicAuerSelect(context);
            } else {
                throw new IllegalArgumentException("invalid auer deterministic algorithm");
            }
        }

        /**
         * Selects a batch of items for the current group: untried items first,
         * then items by UCB1 value, and emits the result.
         *
         * @param context mapper context for emitting output
         * @throws IOException
         * @throws InterruptedException
         */
        private void deterministicAuerSelect(Context context) throws IOException, InterruptedException {
            List<String> selItems = new ArrayList<String>();
            int batchSize = getBatchSize();
            //assumed number of trials completed before this round
            int count = (roundNum - 1) * batchSize;

            //collect items not tried before
            count = collectUntriedItems(selItems, batchSize, count);

            //fill the remainder of the batch according to UCB1 value
            count = collectItemsByValue(selItems, batchSize, count);

            //emit all selected items
            if (outputDecisionCount) {
                //aggregate repeated selections: one output line per item with its count
                Map<String, Integer> decisionCount = new HashMap<String, Integer>();
                for (String item : selItems) {
                    Integer itemCount = decisionCount.get(item);
                    decisionCount.put(item, null == itemCount ? 1 : itemCount + 1);
                }
                for (Map.Entry<String, Integer> entry : decisionCount.entrySet()) {
                    outVal.set(curGroupID + fieldDelim + entry.getKey() + fieldDelim + entry.getValue());
                    context.write(NullWritable.get(), outVal);
                }
            } else {
                //one output line per selection
                for (String item : selItems) {
                    outVal.set(curGroupID + fieldDelim + item);
                    context.write(NullWritable.get(), outVal);
                }
            }
        }

        /**
         * Selects items that have never been tried, up to the batch size.
         *
         * @param items output list of selected item IDs
         * @param batchSize max number of items to select
         * @param count running trial count
         * @return updated trial count
         */
        private int collectUntriedItems(List<String> items, int batchSize, int count) {
            List<DynamicBean> collectedItems = groupedItems.collectItemsNotTried(batchSize);
            count += collectedItems.size();
            for (DynamicBean it : collectedItems) {
                items.add(it.getString(ITEM_ID));
                if (minReward > 0) {
                    //mark as selected, crediting the assumed minimum reward
                    groupedItems.select(it, minReward);
                }
            }
            return count;
        }

        /**
         * Fills the batch with the items of highest UCB1 value; each selection
         * is fed back into the group so subsequent picks see updated state.
         *
         * @param items output list of selected item IDs
         * @param batchSize max number of items in the batch
         * @param count running trial count
         * @return updated trial count
         */
        private int collectItemsByValue(List<String> items, int batchSize, int count) {
            while (items.size() < batchSize) {
                //max reward in this group, recomputed each pass since selections mutate rewards
                int maxReward = groupedItems.getMaxRewardItem().getInt(ITEM_REWARD);
                double valueMax = 0.0;
                String selectedItemID = null;
                DynamicBean selectedGroupItem = null;
                for (DynamicBean groupItem : groupedItems.getGroupItems()) {
                    int reward = groupedItems.getReward(groupItem);
                    int thisCount = groupedItems.getTotalCount(groupItem);
                    if (thisCount > 0) {
                        //UCB1: normalized reward plus exploration bonus
                        double value = ((double)reward) / maxReward + Math.sqrt(2.0 * Math.log(count) / thisCount);
                        if (value > valueMax) {
                            selectedItemID = groupItem.getString(ITEM_ID);
                            valueMax = value;
                            selectedGroupItem = groupItem;
                        }
                    }
                }
                if (null != selectedGroupItem) {
                    items.add(selectedItemID);
                    if (minReward > 0) {
                        groupedItems.select(selectedGroupItem, minReward);
                    }
                    ++count;
                } else {
                    throw new IllegalStateException("should not be here, failed to select item by value");
                }
            }
            return count;
        }
    }

    /**
     * @param args command line: input path, output path
     * @throws Exception
     */
    public static void main(String[] args) throws Exception {
        int exitCode = ToolRunner.run(new AuerDeterministic(), args);
        System.exit(exitCode);
    }
}