/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.bayes;

import java.io.File;
import java.io.FilenameFilter;
import java.io.IOException;
import java.nio.charset.Charset;
import java.util.List;
import java.util.Map;

import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.OptionException;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.mahout.classifier.ClassifierResult;
import org.apache.mahout.classifier.ConfusionMatrix;
import org.apache.mahout.classifier.ResultAnalyzer;
import org.apache.mahout.classifier.bayes.mapreduce.bayes.BayesClassifierDriver;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.common.iterator.FileLineIterable;
import org.apache.mahout.common.TimingStatistics;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.common.nlp.NGrams;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Test the Naive Bayes classifier with improved weighting
 * <p/>
 * To run the twenty newsgroups example, refer to http://cwiki.apache.org/MAHOUT/twentynewsgroups.html
 */
public final class TestClassifier {

  private static final Logger log = LoggerFactory.getLogger(TestClassifier.class);

  private TestClassifier() {
    // utility class: not instantiable
  }

  /**
   * Command-line entry point. Parses the options, copies them (or their defaults) into a
   * {@link BayesParameters} instance and dispatches to {@link #classifySequential(BayesParameters)}
   * or {@link #classifyParallel(BayesParameters)} depending on the {@code --method} option.
   * <p/>
   * If the arguments cannot be parsed, the parse error is logged and the usage help is printed.
   *
   * @param args the raw command-line arguments
   * @throws IOException propagated from the selected classification routine
   * @throws InvalidDatastoreException propagated from {@link #classifySequential(BayesParameters)}
   * @throws IllegalArgumentException if {@code --method} is neither {@code sequential} nor
   *         {@code mapreduce} (compared case-insensitively)
   * @throws NumberFormatException if the {@code --gramSize} value is not an integer
   */
  public static void main(String[] args) throws IOException, InvalidDatastoreException {
    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
    ArgumentBuilder abuilder = new ArgumentBuilder();
    GroupBuilder gbuilder = new GroupBuilder();

    Option pathOpt = obuilder.withLongName("model").withRequired(true).withArgument(
      abuilder.withName("model").withMinimum(1).withMaximum(1).create()).withDescription(
      "The path on HDFS as defined by the -source parameter").withShortName("m")
        .create();

    Option dirOpt = obuilder.withLongName("testDir").withRequired(true).withArgument(
      abuilder.withName("testDir").withMinimum(1).withMaximum(1).create()).withDescription(
      "The directory where test documents resides in").withShortName("d").create();

    Option helpOpt = DefaultOptionCreator.helpOption();

    Option encodingOpt = obuilder.withLongName("encoding").withArgument(
      abuilder.withName("encoding").withMinimum(1).withMaximum(1).create()).withDescription(
      "The file encoding.  Defaults to UTF-8").withShortName("e").create();

    Option defaultCatOpt = obuilder.withLongName("defaultCat").withArgument(
      abuilder.withName("defaultCat").withMinimum(1).withMaximum(1).create()).withDescription(
      "The default category Default Value: unknown").withShortName("default").create();

    Option gramSizeOpt = obuilder.withLongName("gramSize").withRequired(false).withArgument(
      abuilder.withName("gramSize").withMinimum(1).withMaximum(1).create()).withDescription(
      "Size of the n-gram. Default Value: 1").withShortName("ng").create();

    Option alphaOpt = obuilder.withLongName("alpha").withRequired(false).withArgument(
      abuilder.withName("a").withMinimum(1).withMaximum(1).create()).withDescription(
      "Smoothing parameter Default Value: 1.0").withShortName("a").create();

    Option verboseOutputOpt = obuilder.withLongName("verbose").withRequired(false).withDescription(
      "Output which values were correctly and incorrectly classified").withShortName("v").create();

    Option typeOpt = obuilder.withLongName("classifierType").withRequired(false).withArgument(
      abuilder.withName("classifierType").withMinimum(1).withMaximum(1).create()).withDescription(
      "Type of classifier: bayes|cbayes. Default Value: bayes").withShortName("type").create();

    Option dataSourceOpt = obuilder.withLongName("dataSource").withRequired(false).withArgument(
      abuilder.withName("dataSource").withMinimum(1).withMaximum(1).create()).withDescription(
      "Location of model: hdfs").withShortName("source").create();

    Option methodOpt = obuilder.withLongName("method").withRequired(false).withArgument(
      abuilder.withName("method").withMinimum(1).withMaximum(1).create()).withDescription(
      "Method of Classification: sequential|mapreduce. Default Value: mapreduce").withShortName("method")
        .create();

    Option confusionMatrixOpt = obuilder.withLongName("confusionMatrix").withRequired(false).withArgument(
      abuilder.withName("confusionMatrix").withMinimum(1).withMaximum(1).create()).withDescription(
      "Export ConfusionMatrix as SequenceFile").withShortName("cm").create();

    Group group = gbuilder.withName("Options").withOption(defaultCatOpt).withOption(dirOpt).withOption(
      encodingOpt).withOption(gramSizeOpt).withOption(pathOpt).withOption(typeOpt).withOption(dataSourceOpt)
        .withOption(helpOpt).withOption(methodOpt).withOption(verboseOutputOpt).withOption(alphaOpt)
        .withOption(confusionMatrixOpt).create();

    try {
      Parser parser = new Parser();
      parser.setGroup(group);
      CommandLine cmdLine = parser.parse(args);

      if (cmdLine.hasOption(helpOpt)) {
        CommandLineUtil.printHelp(group);
        return;
      }

      // Resolve every option, falling back to its documented default when absent.
      int gramSize = Integer.parseInt(valueOrDefault(cmdLine, gramSizeOpt, "1"));
      String modelBasePath = (String) cmdLine.getValue(pathOpt);
      String classifierType = valueOrDefault(cmdLine, typeOpt, "bayes");
      String dataSource = valueOrDefault(cmdLine, dataSourceOpt, "hdfs");
      String defaultCat = valueOrDefault(cmdLine, defaultCatOpt, "unknown");
      String encoding = valueOrDefault(cmdLine, encodingOpt, "UTF-8");
      String alphaI = valueOrDefault(cmdLine, alphaOpt, "1.0");
      boolean verbose = cmdLine.hasOption(verboseOutputOpt);
      String testDirPath = (String) cmdLine.getValue(dirOpt);
      String classificationMethod = valueOrDefault(cmdLine, methodOpt, "mapreduce");
      // Stays null when no confusion-matrix export was requested.
      String confusionMatrixFile = valueOrDefault(cmdLine, confusionMatrixOpt, null);

      BayesParameters params = new BayesParameters();
      params.setGramSize(gramSize);
      params.set("verbose", Boolean.toString(verbose));
      params.setBasePath(modelBasePath);
      params.set("classifierType", classifierType);
      params.set("dataSource", dataSource);
      params.set("defaultCat", defaultCat);
      params.set("encoding", encoding);
      params.set("alpha_i", alphaI);
      params.set("testDirPath", testDirPath);
      params.set("confusionMatrix", confusionMatrixFile);

      if ("sequential".equalsIgnoreCase(classificationMethod)) {
        classifySequential(params);
      } else if ("mapreduce".equalsIgnoreCase(classificationMethod)) {
        classifyParallel(params);
      } else {
        // Previously an unknown method fell through and the program exited without doing anything.
        throw new IllegalArgumentException("Unrecognized classification method: " + classificationMethod);
      }
    } catch (OptionException e) {
      // Report what was wrong with the arguments before showing the usage text.
      log.error("Error while parsing options", e);
      CommandLineUtil.printHelp(group);
    }
  }

  /**
   * Returns the String value supplied for {@code option}, or {@code defaultValue} when the option
   * is not present on the command line.
   */
  private static String valueOrDefault(CommandLine cmdLine, Option option, String defaultValue) {
    return cmdLine.hasOption(option) ? (String) cmdLine.getValue(option) : defaultValue;
  }

  /**
   * Classifies, in the current JVM, every line of every non-hidden entry found directly under the
   * directory named by the {@code testDirPath} parameter, and logs the aggregated results.
   * <p/>
   * Each line is turned into n-grams via {@link NGrams#generateNGrams()}; every entry of the returned
   * map is classified with the entry's key taken as the expected label. After each file the current
   * confusion matrix is logged and passed to
   * {@link BayesClassifierDriver#confusionMatrixSeqFileExport}.
   * <p/>
   * Parameters read: {@code verbose}, {@code testDirPath}, {@code dataSource}, {@code classifierType},
   * {@code defaultCat}, {@code encoding} and {@code gramSize}. The encoding and gram size are resolved
   * once, before any file is read.
   *
   * @param params the classification settings
   * @throws IOException declared for I/O failures while reading the test data
   * @throws InvalidDatastoreException declared for failures initializing the model datastore
   * @throws IllegalArgumentException if {@code dataSource} is not {@code hdfs}, or
   *         {@code classifierType} is neither {@code bayes} nor {@code cbayes}
   */
  public static void classifySequential(BayesParameters params) throws IOException, InvalidDatastoreException {
    log.info("Loading model from: {}", params.print());
    boolean verbose = Boolean.parseBoolean(params.get("verbose"));
    File dir = new File(params.get("testDirPath"));
    File[] subdirs = dir.listFiles(new FilenameFilter() {
      @Override
      public boolean accept(File file, String s) {
        // skip hidden entries
        return !s.startsWith(".");
      }
    });

    if (!"hdfs".equals(params.get("dataSource"))) {
      throw new IllegalArgumentException("Unrecognized dataSource type: " + params.get("dataSource"));
    }

    String classifierType = params.get("classifierType");
    Algorithm algorithm;
    if ("bayes".equalsIgnoreCase(classifierType)) {
      log.info("Testing Bayes Classifier");
      algorithm = new BayesAlgorithm();
    } else if ("cbayes".equalsIgnoreCase(classifierType)) {
      log.info("Testing Complementary Bayes Classifier");
      algorithm = new CBayesAlgorithm();
    } else {
      throw new IllegalArgumentException("Unrecognized classifier type: " + classifierType);
    }
    Datastore datastore = new InMemoryBayesDatastore(params);

    ClassifierContext classifier = new ClassifierContext(algorithm, datastore);
    classifier.initialize();

    String defaultCat = params.get("defaultCat");
    ResultAnalyzer resultAnalyzer = new ResultAnalyzer(classifier.getLabels(), defaultCat);
    TimingStatistics totalStatistics = new TimingStatistics();

    if (subdirs == null) {
      // File.listFiles returns null when the path is not a directory or an I/O error occurs.
      log.warn("Could not list test files under {}", dir);
    } else {
      // Loop-invariant settings: resolve them once instead of per file / per line.
      Charset charset = Charset.forName(params.get("encoding"));
      int gramSize = Integer.parseInt(params.get("gramSize"));

      for (File file : subdirs) {
        if (verbose) {
          log.info("--------------");
          log.info("Testing: {}", file);
        }
        TimingStatistics operationStats = new TimingStatistics();

        long lineNum = 0;
        for (String line : new FileLineIterable(file, charset, false)) {
          Map<String,List<String>> document = new NGrams(line, gramSize).generateNGrams();
          for (Map.Entry<String,List<String>> stringListEntry : document.entrySet()) {
            String correctLabel = stringListEntry.getKey();
            List<String> strings = stringListEntry.getValue();
            String[] tokens = strings.toArray(new String[strings.size()]);

            // Time the same call twice: once for this file, once for the whole run.
            TimingStatistics.Call call = operationStats.newCall();
            TimingStatistics.Call outercall = totalStatistics.newCall();
            ClassifierResult classifiedLabel = classifier.classifyDocument(tokens, defaultCat);
            call.end();
            outercall.end();

            boolean correct = resultAnalyzer.addInstance(correctLabel, classifiedLabel);
            if (verbose) {
              // We have one document per line
              log.info("Line Number: {} Line(30): {} Expected Label: {} Classified Label: {} Correct: {}",
                new Object[] {lineNum, line.length() > 30 ? line.substring(0, 30) : line, correctLabel,
                              classifiedLabel.getLabel(), correct});
            }
          }
          lineNum++;
        }

        ConfusionMatrix matrix = resultAnalyzer.getConfusionMatrix();
        log.info("{}", matrix);
        BayesClassifierDriver.confusionMatrixSeqFileExport(params, matrix);
        log.info("ConfusionMatrix: {}", matrix.toString());
        log.info("Classified instances from {}", file.getName());
        if (verbose) {
          log.info("Performance stats {}", operationStats.toString());
        }
      }
    }

    if (verbose) {
      log.info("{}", totalStatistics);
    }
    log.info("{}", resultAnalyzer);
  }

  /**
   * Runs the classification by delegating to {@link BayesClassifierDriver#runJob}.
   *
   * @param params the classification settings, passed through unchanged
   * @throws IOException propagated from the driver
   */
  public static void classifyParallel(BayesParameters params) throws IOException {
    BayesClassifierDriver.runJob(params);
  }
}