/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.mahout.classifier.bayes; import java.io.BufferedWriter; import java.io.File; import java.util.List; import com.google.common.base.Charsets; import com.google.common.io.Closeables; import com.google.common.io.Files; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.mahout.classifier.ClassifierData; import org.apache.mahout.classifier.ClassifierResult; import org.apache.mahout.classifier.ResultAnalyzer; import org.apache.mahout.classifier.bayes.mapreduce.bayes.BayesClassifierDriver; import org.apache.mahout.common.MahoutTestCase; import org.apache.mahout.common.nlp.NGrams; import org.junit.Before; import org.junit.Test; public final class BayesClassifierSelfTest extends MahoutTestCase { @Override @Before public void setUp() throws Exception { super.setUp(); File tempInputFile = getTestTempFile("bayesinput"); BufferedWriter writer = Files.newWriter(tempInputFile, Charsets.UTF_8); try { for (String[] entry : ClassifierData.DATA) { writer.write(entry[0] + '\t' + entry[1] + '\n'); } } finally { Closeables.closeQuietly(writer); } Path input = getTestTempFilePath("bayesinput"); Configuration conf = new Configuration(); FileSystem fs = input.getFileSystem(conf); fs.copyFromLocalFile(new Path(tempInputFile.getAbsolutePath()), input); } @Test public void testSelfTestBayes() throws Exception { BayesParameters params = new BayesParameters(); params.setGramSize(1); params.set("alpha_i", "1.0"); params.set("dataSource", "hdfs"); Path bayesInputPath = getTestTempFilePath("bayesinput"); Path bayesModelPath = getTestTempDirPath("bayesmodel"); TrainClassifier.trainNaiveBayes(bayesInputPath, bayesModelPath, params); params.set("verbose", "true"); params.setBasePath(bayesModelPath.toString()); params.set("classifierType", "bayes"); params.set("dataSource", "hdfs"); params.set("defaultCat", "unknown"); params.set("encoding", "UTF-8"); params.set("alpha_i", "1.0"); Algorithm algorithm = new BayesAlgorithm(); Datastore datastore = new InMemoryBayesDatastore(params); ClassifierContext classifier = new ClassifierContext(algorithm, datastore); classifier.initialize(); ResultAnalyzer resultAnalyzer = new ResultAnalyzer(classifier.getLabels(), params.get("defaultCat")); for (String[] entry : ClassifierData.DATA) { List<String> document = new NGrams(entry[1], params.getGramSize()).generateNGramsWithoutLabel(); assertEquals(3, classifier.classifyDocument(document.toArray(new String[document.size()]), params.get("defaultCat"), 100).length); ClassifierResult result = classifier.classifyDocument(document.toArray(new String[document.size()]), params .get("defaultCat")); assertEquals(entry[0], result.getLabel()); resultAnalyzer.addInstance(entry[0], result); } int[][] matrix = resultAnalyzer.getConfusionMatrix().getConfusionMatrix(); for (int i = 0; i < 3; i++) { for (int j = 0; j < 3; j++) { assertEquals(i == j ? 4 : 0, matrix[i][j]); } } params.set("testDirPath", bayesInputPath.toString()); TestClassifier.classifyParallel(params); Configuration conf = new Configuration(); Path outputFiles = getTestTempFilePath("bayesinput-output/part*"); matrix = BayesClassifierDriver.readResult(outputFiles, conf, params).getConfusionMatrix(); for (int i = 0; i < 3; i++) { for (int j = 0; j < 3; j++) { assertEquals(i == j ? 4 : 0, matrix[i][j]); } } } @Test public void testSelfTestCBayes() throws Exception { BayesParameters params = new BayesParameters(); params.setGramSize(1); params.set("alpha_i", "1.0"); params.set("dataSource", "hdfs"); Path bayesInputPath = getTestTempFilePath("bayesinput"); Path bayesModelPath = getTestTempDirPath("cbayesmodel"); TrainClassifier.trainCNaiveBayes(bayesInputPath, bayesModelPath, params); params.set("verbose", "true"); params.setBasePath(bayesModelPath.toString()); params.set("classifierType", "cbayes"); params.set("dataSource", "hdfs"); params.set("defaultCat", "unknown"); params.set("encoding", "UTF-8"); params.set("alpha_i", "1.0"); Algorithm algorithm = new CBayesAlgorithm(); Datastore datastore = new InMemoryBayesDatastore(params); ClassifierContext classifier = new ClassifierContext(algorithm, datastore); classifier.initialize(); ResultAnalyzer resultAnalyzer = new ResultAnalyzer(classifier.getLabels(), params.get("defaultCat")); for (String[] entry : ClassifierData.DATA) { List<String> document = new NGrams(entry[1], params.getGramSize()).generateNGramsWithoutLabel(); assertEquals(3, classifier.classifyDocument(document.toArray(new String[document.size()]), params.get("defaultCat"), 100).length); ClassifierResult result = classifier.classifyDocument(document.toArray(new String[document.size()]), params .get("defaultCat")); assertEquals(entry[0], result.getLabel()); resultAnalyzer.addInstance(entry[0], result); } int[][] matrix = resultAnalyzer.getConfusionMatrix().getConfusionMatrix(); for (int i = 0; i < 3; i++) { for (int j = 0; j < 3; j++) { assertEquals(i == j ? 4 : 0, matrix[i][j]); } } params.set("testDirPath", bayesInputPath.toString()); TestClassifier.classifyParallel(params); Configuration conf = new Configuration(); Path outputFiles = getTestTempFilePath("bayesinput-output/part*"); matrix = BayesClassifierDriver.readResult(outputFiles, conf, params).getConfusionMatrix(); for (int i = 0; i < 3; i++) { for (int j = 0; j < 3; j++) { assertEquals(i == j ? 4 : 0, matrix[i][j]); } } } }