/*
 * avenir: Predictive analytic based on Hadoop Map Reduce
 * Author: Pranab Ghosh
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you
 * may not use this file except in compliance with the License. You may
 * obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied. See the License for the specific language governing
 * permissions and limitations under the License.
 */

package org.avenir.model;

import java.io.IOException;
import java.io.InputStream;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
import org.avenir.tree.DecisionTreeModel;
import org.chombo.util.FeatureSchema;
import org.chombo.util.Utility;

/**
 * Generic classification model predictor MR
 * @author pranab
 *
 */
public class ModelPredictor extends Configured implements Tool {

    @Override
    public int run(String[] args) throws Exception {
        Job job = new Job(getConf());
        String jobName = "model predictor MR";
        job.setJobName(jobName);

        job.setJarByClass(ModelPredictor.class);
        FileInputFormat.addInputPath(job, new Path(args[0]));
        FileOutputFormat.setOutputPath(job, new Path(args[1]));

        Utility.setConfiguration(job.getConfiguration(), "avenir");
        job.setMapperClass(ModelPredictor.PredictorMapper.class);

        job.setOutputKeyClass(NullWritable.class);
        job.setOutputValueClass(Text.class);

        int status = job.waitForCompletion(true) ? 0 : 1;
        return status;
    }
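
    /*
     * Illustrative only, not part of the original source: one possible way to launch
     * the job for the simplest case (single decision tree model, default output mode).
     * The jar name, schema file, model file, and HDFS paths are hypothetical
     * placeholders; the -D generic options (handled by ToolRunner) set the
     * configuration parameters read in PredictorMapper.setup().
     *
     *   hadoop jar avenir.jar org.avenir.model.ModelPredictor \
     *       -D mop.feature.schema.file.path=/path/to/feature_schema.json \
     *       -D mop.model.dir.path=/path/to/model/dir \
     *       -D mop.model.file.names=dec_tree_model.txt \
     *       -D mop.classifier.type=decTreeClassifier \
     *       /input/data/path /output/prediction/path
     */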

    /**
     * @author pranab
     *
     */
    public static class PredictorMapper extends Mapper<LongWritable, Text, NullWritable, Text> {
        private String[] items;
        private Text outVal = new Text();
        private String fieldDelimRegex;
        private String fieldDelim;
        private PredictiveModel model;
        private EnsemblePredictiveModel ensembleModel;
        private String predClass;
        private String outputMode;
        private int idOrdinal;
        private int classAttrOrdinal;
        private StringBuilder stBld = new StringBuilder();
        private static final String CLASS_DEC_TREE = "decTreeClassifier";
        private static final String OUTPUT_WITH_RECORD = "withRecord";
        private static final String OUTPUT_WITH_ID = "withKId";
        private static final String OUTPUT_WITH_CLASS_ATTR = "withActualClassAttr";

        /* (non-Javadoc)
         * @see org.apache.hadoop.mapreduce.Mapper#setup(org.apache.hadoop.mapreduce.Mapper.Context)
         */
        protected void setup(Context context) throws IOException, InterruptedException {
            Configuration config = context.getConfiguration();
            fieldDelimRegex = config.get("field.delim.regex", ",");
            fieldDelim = config.get("field.delim.out", ",");

            //schema
            FeatureSchema schema = Utility.getFeatureSchema(config, "mop.feature.schema.file.path");

            //model files
            String modelDirPath = Utility.assertStringConfigParam(config, "mop.model.dir.path",
                    "missing model directory path");
            String[] modelFileNames = Utility.assertStringArrayConfigParam(config, "mop.model.file.names",
                    Utility.configDelim, "missing model file names");
            String classifierType = Utility.assertStringConfigParam(config, "mop.classifier.type",
                    "missing classifier type");

            //error counting
            boolean errorCountingEnabled = config.getBoolean("mop.error.counting.enabled", false);
            int classAttrOrd = -1;
            if (errorCountingEnabled) {
                classAttrOrd = Utility.assertIntConfigParam(config, "mop.class.attr.ord",
                        "missing class attribute ordinal");
            }

            //cost based classification
            boolean costBasedPredictionEnabled = config.getBoolean("mop.cost.based.prediction.enabled", false);
            String[] classAttrValues = null;
            double[] misclassCosts = null;
            if (costBasedPredictionEnabled || errorCountingEnabled) {
                classAttrValues = Utility.assertStringArrayConfigParam(config, "mop.class.attr.values",
                        Utility.configDelim, "missing class attribute values, needed for cost based prediction");
                if (classAttrValues.length > 2) {
                    throw new IllegalStateException("cost based classification possible only for binary classification");
                }
                if (costBasedPredictionEnabled) {
                    misclassCosts = Utility.assertDoubleArrayConfigParam(config, "mop.miss.class.costs",
                            Utility.configDelim, "missing misclassification costs");
                }
            }

            //build model
            if (modelFileNames.length > 1) {
                //ensemble
                double[] memberWeights = Utility.optionalDoubleArrayConfigParam(config,
                        "mop.ensemble.memeber.weights", Utility.configDelim);
                ensembleModel = new EnsemblePredictiveModel();
                for (int i = 0; i < modelFileNames.length; ++i) {
                    PredictiveModel memberModel = buildModel(schema, modelDirPath, modelFileNames[i], classifierType,
                            false, classAttrOrd, classAttrValues, misclassCosts);
                    double weight = null != memberWeights ? memberWeights[i] : 1.0;
                    ensembleModel.addModel(memberModel, weight);
                }

                //error counting
                if (errorCountingEnabled) {
                    ensembleModel.enableErrorCounting(classAttrOrd, classAttrValues[0], classAttrValues[1]);
                }
            } else {
                //single
                model = buildModel(schema, modelDirPath, modelFileNames[0], classifierType, errorCountingEnabled,
                        classAttrOrd, classAttrValues, misclassCosts);
            }

            //output
            outputMode = config.get("mop.output.mode", OUTPUT_WITH_RECORD);
            if (outputMode.equals(OUTPUT_WITH_ID)) {
                idOrdinal = Utility.assertIntConfigParam(config, "mop.rec.id.ordinal", "missing id ordinal");
            }
            if (outputMode.equals(OUTPUT_WITH_CLASS_ATTR)) {
                classAttrOrdinal = Utility.assertIntConfigParam(config, "mop.rec.class.attr.ordinal",
                        "missing class attribute ordinal");
            }
        }
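
        /*
         * Illustrative only, not part of the original source: a minimal sketch of the
         * configuration read by setup() above, for a single decision tree model with
         * the default output mode. Property names and the classifier type value come
         * from this class; paths and file names are hypothetical placeholders.
         *
         *   mop.feature.schema.file.path=/path/to/feature_schema.json
         *   mop.model.dir.path=/path/to/model/dir
         *   mop.model.file.names=dec_tree_model.txt
         *   mop.classifier.type=decTreeClassifier
         *   mop.output.mode=withRecord
         *   mop.error.counting.enabled=false
         *   mop.cost.based.prediction.enabled=false
         */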

        /**
         * @param schema
         * @param modelDirPath
         * @param modelFileName
         * @param classifierType
         * @param errorCountingEnabled
         * @param classAttrOrd
         * @param classAttrValues
         * @param misclassCosts
         * @return
         * @throws IOException
         */
        private PredictiveModel buildModel(FeatureSchema schema, String modelDirPath, String modelFileName,
                String classifierType, boolean errorCountingEnabled, int classAttrOrd, String[] classAttrValues,
                double[] misclassCosts) throws IOException {
            PredictiveModel model = null;
            String modelFilePath = modelDirPath + "/" + modelFileName;
            InputStream modelStream = Utility.getFileStream(modelFilePath);
            if (classifierType.equals(CLASS_DEC_TREE)) {
                model = new DecisionTreeModel(schema, modelStream);
            } else {
                throw new IllegalStateException("invalid classifier type");
            }
            modelStream.close();

            //error counting
            if (errorCountingEnabled) {
                model.enableErrorCounting(classAttrOrd, classAttrValues[0], classAttrValues[1]);
            }

            //cost based classification
            if (null != misclassCosts) {
                model.enableCostBasedPrediction(classAttrValues[0], classAttrValues[1], misclassCosts[0], misclassCosts[1]);
            }
            return model;
        }

        @Override
        protected void map(LongWritable key, Text value, Context context)
                throws IOException, InterruptedException {
            items = value.toString().split(fieldDelimRegex);
            predClass = null != model ? model.predict(items) : ensembleModel.predict(items);

            stBld.delete(0, stBld.length());
            if (outputMode.equals(OUTPUT_WITH_RECORD)) {
                //full record
                stBld.append(value.toString()).append(fieldDelim).append(predClass);
            } else {
                //partial record
                if (outputMode.equals(OUTPUT_WITH_ID)) {
                    stBld.append(items[idOrdinal]).append(fieldDelim);
                }
                if (outputMode.equals(OUTPUT_WITH_CLASS_ATTR)) {
                    stBld.append(items[classAttrOrdinal]).append(fieldDelim);
                }
                if (stBld.length() == 0) {
                    throw new IllegalStateException("invalid output mode");
                }
                stBld.append(predClass);
            }

            outVal.set(stBld.toString());
            context.write(NullWritable.get(), outVal);
        }
    }

    /**
     * @param args
     * @throws Exception
     */
    public static void main(String[] args) throws Exception {
        int exitCode = ToolRunner.run(new ModelPredictor(), args);
        System.exit(exitCode);
    }

}