/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.mahout.classifier.rbm.test; import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import org.apache.commons.cli2.builder.DefaultOptionBuilder; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileStatus; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.mapreduce.Job; import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; import org.apache.hadoop.util.ToolRunner; import org.apache.mahout.classifier.ClassifierResult; import org.apache.mahout.classifier.ResultAnalyzer; import org.apache.mahout.classifier.rbm.RBMClassifier; import org.apache.mahout.common.AbstractJob; import org.apache.mahout.common.HadoopUtil; import org.apache.mahout.common.Pair; import org.apache.mahout.common.commandline.DefaultOptionCreator; import org.apache.mahout.common.iterator.sequencefile.PathFilters; import org.apache.mahout.common.iterator.sequencefile.PathType; import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable; import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable; import org.apache.mahout.math.Vector; import org.apache.mahout.math.VectorWritable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * The Class TestRBMClassifierJob which runs the tests in map/reduce or locally multithreaded. */ public class TestRBMClassifierJob extends AbstractJob { /** The Constant log. */ private static final Logger log = LoggerFactory.getLogger(TestRBMClassifierJob.class); /** * The main method. * * @param args the arguments * @throws Exception the exception */ public static void main(String[] args) throws Exception { ToolRunner.run(new Configuration(), new TestRBMClassifierJob(), args); } private int iterations; /* (non-Javadoc) * @see org.apache.hadoop.util.Tool#run(java.lang.String[]) */ @Override public int run(String[] args) throws Exception { addInputOption(); addOption("model", "m", "The path to the model built during training", true); addOption("labelcount", "lc", "total count of labels existent in the training set", true); addOption(DefaultOptionCreator.MAX_ITERATIONS_OPTION, "max", "least number of stable iterations in classification layer when classifying","10"); addOption(new DefaultOptionBuilder() .withLongName(DefaultOptionCreator.MAPREDUCE_METHOD) .withRequired(false) .withDescription("Run tests with map/reduce") .withShortName("mr").create()); Map<String, String> parsedArgs = parseArguments(args); if (parsedArgs == null) { return -1; } int labelcount = Integer.parseInt(getOption("labelcount")); iterations = Integer.parseInt(getOption("maxIter")); //check models existence Path model = new Path(parsedArgs.get("--model")); if(!model.getFileSystem(getConf()).exists(model)) { log.error("Model file does not exist!"); return -1; } //create the list of all labels List<String> lables= new ArrayList<String>(); for(int i = 0; i<labelcount; i++) lables.add(String.valueOf(i)); FileSystem fs = getInputPath().getFileSystem(getConf()); ResultAnalyzer analyzer = new ResultAnalyzer(lables, "-1"); //initiate the paths to the test batches Path[] batches; if(fs.isFile(getInputPath())) batches = new Path[]{getInputPath()}; else { FileStatus[] stati = fs.listStatus(getInputPath()); batches = new Path[stati.length]; for (int i = 0; i < stati.length; i++) { batches[i] = stati[i].getPath(); } } if(hasOption("mapreduce")) HadoopUtil.delete(getConf(), getTempPath("testresults")); for (Path input : batches) { if(hasOption("mapreduce")) { HadoopUtil.cacheFiles(model, getConf()); //the output key is the expected value, the output value are the scores for all the labels Job testJob = prepareJob(input, getTempPath("testresults"), SequenceFileInputFormat.class, TestRBMClassifierMapper.class, IntWritable.class, VectorWritable.class, SequenceFileOutputFormat.class); testJob.getConfiguration().set("maxIter", String.valueOf(iterations)); testJob.waitForCompletion(true); //loop over the results and create the confusion matrix SequenceFileDirIterable<IntWritable, VectorWritable> dirIterable = new SequenceFileDirIterable<IntWritable, VectorWritable>(getTempPath("testresults"), PathType.LIST, PathFilters.partFilter(), getConf()); analyzeResults(dirIterable, analyzer); } else { //test job locally runTestsLocally(model, analyzer,input); } } //output the result of the tests log.info("RBMClassifier Results: {}", analyzer); //stop all running threads if(executor!=null) executor.shutdownNow(); return 0; } /** The executor. */ private ExecutorService executor; /** The tasks. */ List<RBMClassifierCall> tasks; /** * Analyze results locally. * * @param model the model * @param analyzer the analyzer * @param input the input * @throws IOException Signals that an I/O exception has occurred. * @throws InterruptedException the interrupted exception * @throws ExecutionException the execution exception */ private void runTestsLocally(Path model, ResultAnalyzer analyzer, Path input) throws IOException, InterruptedException, ExecutionException { int testsize =0; //maximum number of threads that are used, I think 20 is ok int threadCount =20; RBMClassifier rbmCl = RBMClassifier.materialize(model, getConf()); //initialize the executor if not already done if(executor==null) executor = Executors.newFixedThreadPool(threadCount); //initialize the tasks, which are run by the executor if(tasks==null) tasks = new ArrayList<RBMClassifierCall>(); for (Pair<IntWritable, VectorWritable> record : new SequenceFileIterable<IntWritable, VectorWritable>(input,getConf())) { //prepare the tasks if(tasks.size()<threadCount) tasks.add(new RBMClassifierCall(rbmCl.clone(), record.getSecond().get(), record.getFirst().get(), iterations)); else { tasks.get(testsize%threadCount).input = record.getSecond().get(); tasks.get(testsize%threadCount).label = record.getFirst().get(); } //run the tasks if(testsize%threadCount==threadCount-1) { List<Future<Pair<Integer,Vector>>> futureResults = executor.invokeAll(tasks); //analyze results for (int i = 0; i < futureResults.size(); i++) { int bestIdx = Integer.MIN_VALUE; double bestScore = Long.MIN_VALUE; Pair<Integer, Vector> pair = futureResults.get(i).get(); for (Vector.Element element : pair.getSecond()) { if (element.get() > bestScore) { bestScore = element.get(); bestIdx = element.index(); } } if (bestIdx != Integer.MIN_VALUE) { ClassifierResult classifierResult = new ClassifierResult(String.valueOf(bestIdx), bestScore); analyzer.addInstance(String.valueOf(pair.getFirst()), classifierResult); } } } testsize++; } //run and analyze remaining tasks if(testsize%20!=0) { List<Future<Pair<Integer,Vector>>> futureResults = executor.invokeAll(tasks.subList(0, (testsize-1) %20)); for (int i = 0; i < futureResults.size(); i++) { int bestIdx = Integer.MIN_VALUE; double bestScore = Long.MIN_VALUE; Pair<Integer, Vector> pair = futureResults.get(i).get(); for (Vector.Element element : pair.getSecond()) { if (element.get() > bestScore) { bestScore = element.get(); bestIdx = element.index(); } } if (bestIdx != Integer.MIN_VALUE) { ClassifierResult classifierResult = new ClassifierResult(String.valueOf(bestIdx), bestScore); analyzer.addInstance(String.valueOf(pair.getFirst()), classifierResult); } } } } /** * Analyze results of M/R job. * * @param dirIterable the directory with the results * @param analyzer the analyzer */ private void analyzeResults(SequenceFileDirIterable<IntWritable, VectorWritable> dirIterable, ResultAnalyzer analyzer) { for (Pair<IntWritable, VectorWritable> pair : dirIterable) { int bestIdx = Integer.MIN_VALUE; double bestScore = Long.MIN_VALUE; for (Vector.Element element : pair.getSecond().get()) { if (element.get() > bestScore) { bestScore = element.get(); bestIdx = element.index(); } } if (bestIdx != Integer.MIN_VALUE) { ClassifierResult classifierResult = new ClassifierResult(String.valueOf(bestIdx), bestScore); analyzer.addInstance(String.valueOf(pair.getFirst().get()), classifierResult); } } } /** * The Class RBMClassifier is the callable thread for the local classifying task. */ class RBMClassifierCall implements Callable<Pair<Integer,Vector>> { /** The rbm cl. */ private RBMClassifier rbmCl; /** The input. */ private Vector input; /** The label. */ private int label; /** The iterations. */ private int iterations; /** * Instantiates a new rBM classifier call. * * @param rbmCl the rbm cl * @param input the input * @param label the label * @param iterations the number of iterations until stable */ public RBMClassifierCall(RBMClassifier rbmCl, Vector input, int label, int iterations) { this.rbmCl = rbmCl; this.input = input; this.label = label; this.iterations = iterations; } /* (non-Javadoc) * @see java.util.concurrent.Callable#call() */ @Override public Pair<Integer,Vector> call() throws Exception { return new Pair<Integer, Vector>(label, rbmCl.classify(input, iterations)); } } }