/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.mahout.classifier; import java.io.File; import java.nio.charset.Charset; import java.util.List; import com.google.common.io.Files; import org.apache.commons.cli2.CommandLine; import org.apache.commons.cli2.Group; import org.apache.commons.cli2.Option; import org.apache.commons.cli2.builder.ArgumentBuilder; import org.apache.commons.cli2.builder.DefaultOptionBuilder; import org.apache.commons.cli2.builder.GroupBuilder; import org.apache.commons.cli2.commandline.Parser; import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.analysis.standard.StandardAnalyzer; import org.apache.lucene.util.Version; import org.apache.mahout.classifier.bayes.Algorithm; import org.apache.mahout.classifier.bayes.BayesAlgorithm; import org.apache.mahout.classifier.bayes.BayesParameters; import org.apache.mahout.classifier.bayes.CBayesAlgorithm; import org.apache.mahout.classifier.bayes.Datastore; import org.apache.mahout.classifier.bayes.ClassifierContext; import org.apache.mahout.classifier.bayes.InMemoryBayesDatastore; import org.apache.mahout.common.ClassUtils; import org.apache.mahout.common.nlp.NGrams; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Runs the Bayes classifier using the given model location on HDFS * */ public final class Classify { private static final Logger log = LoggerFactory.getLogger(Classify.class); private Classify() { } public static void main(String[] args) throws Exception { DefaultOptionBuilder obuilder = new DefaultOptionBuilder(); ArgumentBuilder abuilder = new ArgumentBuilder(); GroupBuilder gbuilder = new GroupBuilder(); Option pathOpt = obuilder.withLongName("path").withRequired(true).withArgument( abuilder.withName("path").withMinimum(1).withMaximum(1).create()).withDescription( "The local file system path").withShortName("m").create(); Option classifyOpt = obuilder.withLongName("classify").withRequired(true).withArgument( abuilder.withName("classify").withMinimum(1).withMaximum(1).create()).withDescription( "The doc to classify").withShortName("").create(); Option encodingOpt = obuilder.withLongName("encoding").withRequired(true).withArgument( abuilder.withName("encoding").withMinimum(1).withMaximum(1).create()).withDescription( "The file encoding. Default: UTF-8").withShortName("e").create(); Option analyzerOpt = obuilder.withLongName("analyzer").withRequired(true).withArgument( abuilder.withName("analyzer").withMinimum(1).withMaximum(1).create()).withDescription( "The Analyzer to use").withShortName("a").create(); Option defaultCatOpt = obuilder.withLongName("defaultCat").withRequired(true).withArgument( abuilder.withName("defaultCat").withMinimum(1).withMaximum(1).create()).withDescription( "The default category").withShortName("d").create(); Option gramSizeOpt = obuilder.withLongName("gramSize").withRequired(true).withArgument( abuilder.withName("gramSize").withMinimum(1).withMaximum(1).create()).withDescription( "Size of the n-gram").withShortName("ng").create(); Option typeOpt = obuilder.withLongName("classifierType").withRequired(true).withArgument( abuilder.withName("classifierType").withMinimum(1).withMaximum(1).create()).withDescription( "Type of classifier").withShortName("type").create(); Option dataSourceOpt = obuilder.withLongName("dataSource").withRequired(true).withArgument( abuilder.withName("dataSource").withMinimum(1).withMaximum(1).create()).withDescription( "Location of model: hdfs").withShortName("source").create(); Group options = gbuilder.withName("Options").withOption(pathOpt).withOption(classifyOpt).withOption( encodingOpt).withOption(analyzerOpt).withOption(defaultCatOpt).withOption(gramSizeOpt).withOption( typeOpt).withOption(dataSourceOpt).create(); Parser parser = new Parser(); parser.setGroup(options); CommandLine cmdLine = parser.parse(args); int gramSize = 1; if (cmdLine.hasOption(gramSizeOpt)) { gramSize = Integer.parseInt((String) cmdLine.getValue(gramSizeOpt)); } BayesParameters params = new BayesParameters(); params.setGramSize(gramSize); String modelBasePath = (String) cmdLine.getValue(pathOpt); params.setBasePath(modelBasePath); log.info("Loading model from: {}", params.print()); Algorithm algorithm; Datastore datastore; String classifierType = (String) cmdLine.getValue(typeOpt); String dataSource = (String) cmdLine.getValue(dataSourceOpt); if ("hdfs".equals(dataSource)) { if ("bayes".equalsIgnoreCase(classifierType)) { log.info("Using Bayes Classifier"); algorithm = new BayesAlgorithm(); datastore = new InMemoryBayesDatastore(params); } else if ("cbayes".equalsIgnoreCase(classifierType)) { log.info("Using Complementary Bayes Classifier"); algorithm = new CBayesAlgorithm(); datastore = new InMemoryBayesDatastore(params); } else { throw new IllegalArgumentException("Unrecognized classifier type: " + classifierType); } } else { throw new IllegalArgumentException("Unrecognized dataSource type: " + dataSource); } ClassifierContext classifier = new ClassifierContext(algorithm, datastore); classifier.initialize(); String defaultCat = "unknown"; if (cmdLine.hasOption(defaultCatOpt)) { defaultCat = (String) cmdLine.getValue(defaultCatOpt); } File docPath = new File((String) cmdLine.getValue(classifyOpt)); String encoding = "UTF-8"; if (cmdLine.hasOption(encodingOpt)) { encoding = (String) cmdLine.getValue(encodingOpt); } Analyzer analyzer = null; if (cmdLine.hasOption(analyzerOpt)) { analyzer = ClassUtils.instantiateAs((String) cmdLine.getValue(analyzerOpt), Analyzer.class); } if (analyzer == null) { analyzer = new StandardAnalyzer(Version.LUCENE_31); } log.info("Converting input document to proper format"); String[] document = BayesFileFormatter.readerToDocument(analyzer,Files.newReader(docPath, Charset.forName(encoding))); StringBuilder line = new StringBuilder(); for (String token : document) { line.append(token).append(' '); } List<String> doc = new NGrams(line.toString(), gramSize).generateNGramsWithoutLabel(); log.info("Done converting"); log.info("Classifying document: {}", docPath); ClassifierResult category = classifier.classifyDocument(doc.toArray(new String[doc.size()]), defaultCat); log.info("Category for {} is {}", docPath, category); } }