/*
 * avenir: Predictive analytic based on Hadoop Map Reduce
 * Author: Pranab Ghosh
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you
 * may not use this file except in compliance with the License. You may
 * obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied. See the License for the specific language governing
 * permissions and limitations under the License.
 */

package org.avenir.reinforce;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
import org.chombo.util.DynamicBean;
import org.chombo.util.RandomSampler;
import org.chombo.util.Utility;

/**
 * SoftMax multi arm bandit
 * @author pranab
 *
 */
public class SoftMaxBandit extends Configured implements Tool {

	@Override
	public int run(String[] args) throws Exception {
		Job job = new Job(getConf());
		String jobName = "Soft max MAB";
		job.setJobName(jobName);

		job.setJarByClass(SoftMaxBandit.class);

		FileInputFormat.addInputPath(job, new Path(args[0]));
		FileOutputFormat.setOutputPath(job, new Path(args[1]));

		Utility.setConfiguration(job.getConfiguration(), "avenir");
		job.setMapperClass(SoftMaxBandit.BanditMapper.class);

		job.setOutputKeyClass(NullWritable.class);
		job.setOutputValueClass(Text.class);

		int status = job.waitForCompletion(true) ? 0 : 1;
		return status;
	}
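
	/*
	 * Input assumption: each input record is a delimited line of the form
	 *     groupID, itemID, ..., count, ..., reward, ...
	 * where the positions of the count and reward fields are given by the count.ordinal and
	 * reward.ordinal configuration parameters, and records are assumed to arrive sorted by
	 * groupID so that a change in the first field marks a group boundary for the mapper below.
	 */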

	/**
	 * Mapper that does per group softmax item selection
	 * @author pranab
	 *
	 */
	public static class BanditMapper extends Mapper<LongWritable, Text, NullWritable, Text> {
		private String fieldDelimRegex;
		private String fieldDelim;
		private String[] items;
		private Text outVal = new Text();
		private int roundNum;
		private double tempConstant;
		private String curGroupID = null;
		private String groupID;
		private int countOrdinal;
		private int rewardOrdinal;
		private Map<String, Integer> groupBatchCount = new HashMap<String, Integer>();
		private GroupedItems groupedItems = new GroupedItems();
		private RandomSampler sampler = new RandomSampler();
		private static final int DISTR_SCALE = 1000;

		/* (non-Javadoc)
		 * @see org.apache.hadoop.mapreduce.Mapper#setup(org.apache.hadoop.mapreduce.Mapper.Context)
		 */
		protected void setup(Context context) throws IOException, InterruptedException {
			Configuration conf = context.getConfiguration();
			fieldDelimRegex = conf.get("field.delim.regex", ",");
			fieldDelim = conf.get("field.delim", ",");
			roundNum = conf.getInt("current.round.num", -1);
			tempConstant = Double.parseDouble(conf.get("temp.constant", "1.0"));
			countOrdinal = conf.getInt("count.ordinal", -1);
			rewardOrdinal = conf.getInt("reward.ordinal", -1);

			//batch size per group
			List<String[]> lines = Utility.parseFileLines(conf, "group.item.count.path", ",");
			String groupID;
			int batchSize;
			for (String[] line : lines) {
				groupID = line[0];
				batchSize = Integer.parseInt(line[1]);
				groupBatchCount.put(groupID, batchSize);
			}
		}

		/* (non-Javadoc)
		 * @see org.apache.hadoop.mapreduce.Mapper#cleanup(org.apache.hadoop.mapreduce.Mapper.Context)
		 */
		protected void cleanup(Context context) throws IOException, InterruptedException {
			//process the last group
			select(context);
		}

		@Override
		protected void map(LongWritable key, Text value, Context context)
				throws IOException, InterruptedException {
			items = value.toString().split(fieldDelimRegex);
			groupID = items[0];
			if (null == curGroupID || !groupID.equals(curGroupID)) {
				//new group
				if (null == curGroupID) {
					collectGroupItems();
					curGroupID = groupID;
				} else {
					//process this group
					select(context);

					//start next group
					groupedItems.initialize();
					curGroupID = groupID;
					collectGroupItems();
				}
			} else {
				//existing group
				collectGroupItems();
			}
		}

		/**
		 * @return batch size for the current group, defaulting to 1 when no per group count is configured
		 */
		private int getBatchSize() {
			int batchSize = groupBatchCount.isEmpty() ? 1 : groupBatchCount.get(curGroupID);
			return batchSize;
		}

		/**
		 * adds the current record's item, count and reward to the current group
		 */
		private void collectGroupItems() {
			groupedItems.createtem(items[1], Integer.parseInt(items[countOrdinal]), Integer.parseInt(items[rewardOrdinal]));
		}
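
		/*
		 * Selection sketch: select() first takes any items that have not been tried yet, then fills
		 * the rest of the batch by softmax sampling. Each item's weight is
		 * exp((reward / maxReward) / tempConstant), scaled by DISTR_SCALE to an integer;
		 * RandomSampler is assumed to return items with probability proportional to those weights,
		 * so smaller values of temp.constant concentrate the selection on the highest reward items.
		 */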

		/**
		 * selects a batch of items for the current group and emits them
		 * @param context
		 * @throws IOException
		 * @throws InterruptedException
		 */
		private void select(Context context) throws IOException, InterruptedException {
			List<String> items = new ArrayList<String>();
			int batchSize = getBatchSize();
			int count = (roundNum - 1) * batchSize;

			//collect items not tried before
			List<DynamicBean> collectedItems = groupedItems.collectItemsNotTried(batchSize);
			count += collectedItems.size();
			for (DynamicBean it : collectedItems) {
				items.add(it.getString(GroupedItems.ITEM_ID));
			}

			//random sampling based on distribution
			sampler.initialize();
			int maxReward = groupedItems.getMaxRewardItem().getInt(GroupedItems.ITEM_REWARD);
			for (DynamicBean item : groupedItems.getGroupItems()) {
				//softmax weight: exponential of the normalized reward divided by the temperature
				double distr = ((double) item.getInt(GroupedItems.ITEM_REWARD)) / maxReward;
				int scaledDistr = (int) (Math.exp(distr / tempConstant) * DISTR_SCALE);
				sampler.addToDistr(item.getString(GroupedItems.ITEM_ID), scaledDistr);
			}

			//sample without replacement until the batch is full
			Set<String> sampledItems = new HashSet<String>();
			while (items.size() < batchSize) {
				String selected = sampler.sample();
				if (!sampledItems.contains(selected)) {
					sampledItems.add(selected);
					items.add(selected);
					++count;
				}
			}

			//emit all selected items
			for (String it : items) {
				outVal.set(curGroupID + fieldDelim + it);
				context.write(NullWritable.get(), outVal);
			}
		}
	}

	/**
	 * @param args
	 * @throws Exception
	 */
	public static void main(String[] args) throws Exception {
		int exitCode = ToolRunner.run(new SoftMaxBandit(), args);
		System.exit(exitCode);
	}

}
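
/*
 * Usage sketch (the jar name below is an assumption): as a Hadoop Tool the job can be launched with
 *     hadoop jar avenir.jar org.avenir.reinforce.SoftMaxBandit <inputPath> <outputPath>
 * The mapper reads the configuration parameters field.delim.regex, field.delim, current.round.num,
 * temp.constant, count.ordinal, reward.ordinal and group.item.count.path, the last one pointing to
 * a file of groupID,batchSize pairs that sizes the per group selection batch.
 */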