/*
* avenir: Predictive analytic based on Hadoop Map Reduce
* Author: Pranab Ghosh
*
* Licensed under the Apache License, Version 2.0 (the "License"); you
* may not use this file except in compliance with the License. You may
* obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. See the License for the specific language governing
* permissions and limitations under the License.
*/
package org.avenir.markov;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.lang.ArrayUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.avenir.util.StateTransitionProbability;
import org.chombo.util.Tuple;
import org.chombo.util.Utility;
/**
* Builds a hidden Markov model (HMM) from labeled sequence data. The input
* may be either fully tagged or only partially tagged with states.
* @author pranab
*
*/
public class HiddenMarkovModelBuilder extends Configured implements Tool {
/**
 * Configures and launches the map reduce job that builds the HMM.
 *
 * @param args command line arguments; args[0] is used as the input path and
 *             args[1] as the output path
 * @return 0 when the job completes successfully, 1 otherwise
 * @throws Exception if the job can not be set up or run
 */
@Override
public int run(String[] args) throws Exception {
    final Job hmmJob = new Job(getConf());
    hmmJob.setJobName("HMM model builder");
    hmmJob.setJarByClass(HiddenMarkovModelBuilder.class);

    // input and output locations come straight from the command line
    FileInputFormat.addInputPath(hmmJob, new Path(args[0]));
    FileOutputFormat.setOutputPath(hmmJob, new Path(args[1]));

    // project level configuration is applied on the job's own configuration
    final Configuration jobConf = hmmJob.getConfiguration();
    Utility.setConfiguration(jobConf, "avenir");

    // processing classes
    hmmJob.setMapperClass(HiddenMarkovModelBuilder.StateTransitionMapper.class);
    hmmJob.setReducerClass(HiddenMarkovModelBuilder.StateTransitionReducer.class);
    hmmJob.setCombinerClass(MarkovStateTransitionModel.StateTransitionCombiner.class);

    // intermediate and final key value types
    hmmJob.setMapOutputKeyClass(Tuple.class);
    hmmJob.setMapOutputValueClass(IntWritable.class);
    hmmJob.setOutputKeyClass(NullWritable.class);
    hmmJob.setOutputValueClass(Text.class);

    // reducer count is configurable, one by default
    hmmJob.setNumReduceTasks(jobConf.getInt("num.reducer", 1));

    if (hmmJob.waitForCompletion(true)) {
        return 0;
    }
    return 1;
}
/**
* @author pranab
*
*/
public static class StateTransitionMapper extends Mapper<LongWritable, Text, Tuple, IntWritable> {
private String fieldDelimRegex;
private String[] items;
private int skipFieldCount;
private Tuple outKey = new Tuple();
private IntWritable outVal = new IntWritable(1);
private List<String[]> obsStateList = new ArrayList<String[]>();
private String subFieldDelim;
private boolean partiallyTagged;
private String[] states;
private int[] windowFunction;
private List<Integer> stateIndexes = new ArrayList<Integer>();
private static Integer STATE_TRANS = 0;
private static Integer STATE_OBS = 1;
private static Integer INITIAL_STATE = 2;
private static final Logger LOG = Logger.getLogger(StateTransitionMapper.class);
/* (non-Javadoc)
* @see org.apache.hadoop.mapreduce.Mapper#setup(org.apache.hadoop.mapreduce.Mapper.Context)
*/
protected void setup(Context context) throws IOException, InterruptedException {
Configuration conf = context.getConfiguration();
if (conf.getBoolean("debug.on", false)) {
LOG.setLevel(Level.DEBUG);
}
fieldDelimRegex = conf.get("field.delim.regex", ",");
skipFieldCount = conf.getInt("hmmb.skip.field.count", 0);
subFieldDelim = conf.get("sub.field.delim", ":");
partiallyTagged = conf.getBoolean("hmmb.partially.tagged", false);
states = conf.get("hmmb.model.states").split(",");
if (partiallyTagged) {
windowFunction = Utility.intArrayFromString(conf.get("hmmb.window.function"), ",");
}
}
/* (non-Javadoc)
* @see org.apache.hadoop.mapreduce.Mapper#map(KEYIN, VALUEIN, org.apache.hadoop.mapreduce.Mapper.Context)
*/
protected void map(LongWritable key, Text value, Context context)
throws IOException, InterruptedException {
items = value.toString().split(fieldDelimRegex);
if (partiallyTagged) {
processPartiallyTagged(context);
} else {
processFullyTagged(context);
}
}
/**
* @param context
* @throws IOException
* @throws InterruptedException
*/
private void processFullyTagged(Context context) throws IOException, InterruptedException {
obsStateList.clear();
if (items.length >= (skipFieldCount + 2)) {
for (int i = skipFieldCount; i < items.length; ++i) {
String[] obsState = items[i].split(subFieldDelim);;
obsStateList.add(obsState);
}
}
//all observation state pairs
for (int i = 0; i < obsStateList.size(); ++i) {
if (i == 0) {
//intial state
outKey.initialize();
outKey.add(INITIAL_STATE, obsStateList.get(i)[1], obsStateList.get(i)[1]);
context.write(outKey, outVal);
}
//state observation
outKey.initialize();
outKey.add(STATE_OBS, obsStateList.get(i)[1], obsStateList.get(i)[0]);
context.write(outKey, outVal);
if (i > 0) {
//state transition
outKey.initialize();
outKey.add(STATE_TRANS, obsStateList.get(i-1)[1], obsStateList.get(i)[1]);
context.write(outKey, outVal);
}
}
}
/**
* @param context
* @throws InterruptedException
* @throws IOException
*/
private void processPartiallyTagged(Context context) throws IOException, InterruptedException {
//identify states
stateIndexes.clear();
for (int i = 0; i < items.length; ++i) {
if (ArrayUtils.contains(states, items[i])) {
stateIndexes.add(i);
}
}
//intial state
outKey.initialize();
outKey.add(INITIAL_STATE, items[stateIndexes.get(0)], items[stateIndexes.get(0)]);
outVal.set(1);
context.write(outKey, outVal);
//state to observation
int leftBound = 0;
int rightBound = 0;
int leftWindow = 0;
int rightWindow = 0;
for (int i = 0; i < stateIndexes.size(); ++i ) {
//boundary on left
if (i > 0) {
leftWindow = stateIndexes.get(i) - stateIndexes.get(i-1) / 2;
leftBound =stateIndexes.get(i) - leftWindow;
} else {
leftBound = -1;
}
//boundary on right
if (i < stateIndexes.size() -1) {
rightWindow = stateIndexes.get(i+1) - stateIndexes.get(i) / 2;
rightBound = stateIndexes.get(i) + rightWindow;
} else {
rightBound = -1;
}
//at ends
if (leftBound == -1 && rightBound != -1) {
//first state
leftBound =stateIndexes.get(i) - rightWindow;
if (leftBound < 0) {
leftBound = 0;
}
} else if (rightBound == -1 && leftBound != -1) {
//last state
rightBound =stateIndexes.get(i) + leftWindow;
if (rightBound >= items.length) {
rightBound = items.length-1;
}
} else if (leftBound == -1 && rightBound == -1) {
//only one state
leftBound =stateIndexes.get(i) / 2;
rightBound = stateIndexes.get(i) + (items.length - 1 - stateIndexes.get(i)) / 2;
}
//state observation count to left
String state = items[stateIndexes.get(i)];
for (int j = stateIndexes.get(i)-1, k=0; j >= leftBound; --j,++k ) {
String obs = items[j];
outKey.initialize();
outKey.add(STATE_OBS, state, obs);
int val = k < windowFunction.length ? windowFunction[k] : windowFunction[ windowFunction.length -1];
outVal.set(val);
context.write(outKey, outVal);
}
//state observation count to left
for (int j = stateIndexes.get(i)+1, k=0; j <= rightBound; ++j,++k ) {
String obs = items[j];
outKey.initialize();
outKey.add(STATE_OBS, state, obs);
int val = k < windowFunction.length ? windowFunction[k] : windowFunction[ windowFunction.length -1];
outVal.set(val);
context.write(outKey, outVal);
}
}
//state to state
for (int i = 0; i < stateIndexes.size() -1; ++i) {
outKey.initialize();
outKey.add(STATE_TRANS, items[stateIndexes.get(i)], items[stateIndexes.get(i+1)]);
outVal.set(1);
context.write(outKey, outVal);
}
}
}
/**
* @author pranab
*
*/
public static class StateTransitionReducer extends Reducer<Tuple, IntWritable, NullWritable, Text> {
private String fieldDelim;
private Text outVal = new Text();
private String[] states;
private String[] observations;
private String[] initial;
private StateTransitionProbability stateTransProb;
private StateTransitionProbability stateObsProb;
private StateTransitionProbability initialStateProb;
private int count;
private static Integer STATE_TRANS = 0;
private static Integer STATE_OBS = 1;
private static Integer INITIAL_STATE = 2;
private static String INITIAL = "initial";
private static final Logger LOG = Logger.getLogger(StateTransitionMapper.class);
protected void setup(Context context)
throws IOException, InterruptedException {
Configuration conf = context.getConfiguration();
if (conf.getBoolean("debug.on", false)) {
LOG.setLevel(Level.DEBUG);
}
fieldDelim = conf.get("field.delim.out", ",");
states = conf.get("hmmb.model.states").split(",");
observations = conf.get("hmmb.model.observations").split(",");
int transProbScale = conf.getInt("hmmb.trans.prob.scale", 1000);
//state transition
stateTransProb = new StateTransitionProbability(states, states);
stateTransProb.setScale(transProbScale);
//state observation
stateObsProb = new StateTransitionProbability(states, observations);
stateObsProb.setScale(transProbScale);
//initial state
initial = new String[1];
initial[0] = INITIAL;
initialStateProb = new StateTransitionProbability(initial, states);
}
protected void cleanup(Context context)
throws IOException, InterruptedException {
//all states
outVal.set(Utility.join(states));
context.write(NullWritable.get(),outVal);
//all observations
outVal.set(Utility.join(observations));
context.write(NullWritable.get(),outVal);
//state transition
stateTransProb.normalizeRows();
for (int i = 0; i < states.length; ++i) {
String val = stateTransProb.serializeRow(i);
outVal.set(val);
context.write(NullWritable.get(),outVal);
}
//state observation
stateObsProb.normalizeRows();
for (int i = 0; i < states.length; ++i) {
String val = stateObsProb.serializeRow(i);
outVal.set(val);
context.write(NullWritable.get(),outVal);
}
//intial state
initialStateProb.normalizeRows();
String val = initialStateProb.serializeRow(0);
outVal.set(val);
context.write(NullWritable.get(),outVal);
}
protected void reduce(Tuple key, Iterable<IntWritable> values, Context context)
throws IOException, InterruptedException {
count = 0;
for (IntWritable value : values) {
count += value.get();
}
//state transition
if (key.getInt(0) == STATE_TRANS) {
String fromSt = key.getString(1);
String toSt = key.getString(2);
stateTransProb.add(fromSt, toSt, count);
} else if (key.getInt(0) == STATE_OBS) {
String fromSt = key.getString(1);
String toObs = key.getString(2);
stateObsProb.add(fromSt, toObs, count);
} else if (key.getInt(0) == INITIAL_STATE) {
String toSt = key.getString(1);
initialStateProb.add(INITIAL, toSt, count);
}
}
}
/**
 * Command line entry point. Runs the job through ToolRunner and exits the
 * JVM with the status returned by the job.
 *
 * @param args command line arguments passed on to the job
 * @throws Exception if the job fails with an exception
 */
public static void main(String[] args) throws Exception {
    final HiddenMarkovModelBuilder builder = new HiddenMarkovModelBuilder();
    System.exit(ToolRunner.run(builder, args));
}
}