/* * Copyright 2008-2011 Grant Ingersoll, Thomas Morton and Drew Farris * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ------------------- * To purchase or learn more about Taming Text, by Grant Ingersoll, Thomas Morton and Drew Farris, visit * http://www.manning.com/ingersoll */ package com.tamingtext.classifier.bayes; import java.io.BufferedReader; import java.io.File; import java.io.FileInputStream; import java.io.IOException; import java.io.InputStreamReader; import java.util.HashMap; import java.util.Map; import junit.framework.TestCase; import org.apache.mahout.classifier.bayes.TestClassifier; import org.apache.mahout.classifier.bayes.TrainClassifier; import org.apache.solr.SolrTestCaseJ4; import org.junit.BeforeClass; import org.junit.Test; import com.tamingtext.classifier.bayes.ExtractTrainingData; public class ExtractTrainingDataTest extends SolrTestCaseJ4 { static File baseDir; @BeforeClass public static void beforeClass() throws Exception { baseDir = new File("target/test-output/extract-test"); baseDir.delete(); baseDir.mkdirs(); initCore("bayes-update-config.xml", "bayes-update-schema.xml"); } public static File createTempDirectory() throws IOException { File file = File.createTempFile("extract-test", "test"); file.delete(); file.mkdirs(); if (!file.isDirectory()) { throw new IOException("Could not create temporary directory: " + file); } return file; } public static String readFile(File file) throws IOException { StringBuilder buf = new StringBuilder(); String line; BufferedReader r = new BufferedReader(new InputStreamReader(new FileInputStream(file), "UTF-8")); while ((line = r.readLine()) != null) { buf.append(line); } return buf.toString(); } @Test public void testExtract() throws Exception { assertU("Add a doc to be classified", adoc("id", "1", "category", "scifi", "details", "Star Wars: A New Hope")); assertU("Add a doc to be classified", adoc("id", "2", "category", "scifi", "details", "Star Wars: The Empire Strikes Back")); assertU("Add a doc to be classified", adoc("id", "3", "category", "scifi", "details", "Star Wars: The Revenge of the Jedi")); assertU("Add a doc to be classified", adoc("id", "4", "category", "fantasy", "details", "Lord of the Rings: Fellowship of the Ring")); assertU("Add a doc to be classified", adoc("id", "5", "category", "fantasy", "details", "Lord of the Rings: The Two Towers")); assertU("Add a doc to be classified", adoc("id", "6", "category", "fantasy", "details", "Lord of the Rings: Return of the King")); assertU(commit()); File outputDir = new File(baseDir, "extract"); File indexDir = new File(dataDir, "index"); File categoryFile = new File("src/test/resources/solr/conf/categories.txt"); String[] extractArgs = { "--dir", indexDir.getAbsolutePath(), "--categories", categoryFile.getAbsolutePath(), "--output", outputDir.getAbsolutePath(), "--category-fields", "category", "--text-fields", "details", "--use-term-vectors" }; ExtractTrainingData.main(extractArgs); File[] files = outputDir.listFiles(); Map<String,File> names = new HashMap<String,File>(); for (File f: files) { names.put(f.getName(), f); } TestCase.assertEquals(2, files.length); TestCase.assertTrue("File list contains scifi", names.containsKey("scifi")); TestCase.assertTrue("File list contains fantasy", names.containsKey("fantasy")); String scifiContent = readFile(names.get("scifi")); TestCase.assertTrue("Sci-fi file has proper label", scifiContent.startsWith("scifi\t")); TestCase.assertTrue("Sci-fi file contents", scifiContent.contains("star")); TestCase.assertTrue("Sci-fi file contents", scifiContent.contains("jedi")); String fantasyContent = readFile(names.get("fantasy")); TestCase.assertTrue("Fantasy file has proper label", fantasyContent.startsWith("fantasy\t")); TestCase.assertTrue("Fantasy file contents", fantasyContent.contains("lord")); TestCase.assertTrue("Fantasy file contents", fantasyContent.contains("tower")); File modelDir = new File(baseDir, "model"); String[] trainArgs = { "-i", outputDir.getAbsolutePath(), "-o", modelDir.getAbsolutePath(), "-ng", "1", "-type", "bayes", "-source", "hdfs" }; TrainClassifier.main(trainArgs); String[] testArgs = { "-d", outputDir.getAbsolutePath(), "-m", modelDir.getAbsolutePath(), "-ng", "1", "-type", "bayes", "-source", "hdfs" }; TestClassifier.main(testArgs); } }