/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

/**
 * @author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a>
 */

package cc.mallet.classify.tests;

import junit.framework.*;

import java.net.URI;
import java.io.File;

import cc.mallet.classify.*;
import cc.mallet.pipe.*;
import cc.mallet.pipe.iterator.ArrayIterator;
import cc.mallet.pipe.iterator.FileIterator;
import cc.mallet.types.*;
import cc.mallet.util.*;

public class TestNaiveBayes extends TestCase {

    public TestNaiveBayes (String name) {
        super (name);
    }

    public void testNonTrained () {
        Alphabet fdict = new Alphabet ();
        System.out.println ("fdict.size="+fdict.size());
        LabelAlphabet ldict = new LabelAlphabet ();
        Multinomial.Estimator me1 = new Multinomial.LaplaceEstimator (fdict);
        Multinomial.Estimator me2 = new Multinomial.LaplaceEstimator (fdict);

        // Prior
        ldict.lookupIndex ("sports");
        ldict.lookupIndex ("politics");
        ldict.stopGrowth ();
        System.out.println ("ldict.size="+ldict.size());
        Multinomial prior = new Multinomial (new double[] {.5, .5}, ldict);

        // Sports
        me1.increment ("win", 5);
        me1.increment ("puck", 5);
        me1.increment ("team", 5);
        System.out.println ("fdict.size="+fdict.size());

        // Politics
        me2.increment ("win", 5);
        me2.increment ("speech", 5);
        me2.increment ("vote", 5);

        Multinomial sports = me1.estimate();
        Multinomial politics = me2.estimate();

        // We must estimate from me1 and me2 after all data is incremented,
        // so that the "sports" multinomial knows the full dictionary size!
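        // Build the classifier directly from hand-constructed parameters: a Noop
        // pipe (the data is already a FeatureVector), the class prior, and one
        // Multinomial of per-feature likelihoods for each label, in the same
        // order as the label alphabet ("sports", "politics").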
        Classifier c = new NaiveBayes (new Noop (fdict, ldict), prior,
                                       new Multinomial[] {sports, politics});

        Instance inst = c.getInstancePipe().instanceFrom(
            new Instance (new FeatureVector (fdict,
                                             new Object[] {"speech", "win"},
                                             new double[] {1, 1}),
                          ldict.lookupLabel ("politics"),
                          null, null));
        System.out.println ("inst.data = "+inst.getData ());

        Classification cf = c.classify (inst);
        LabelVector l = (LabelVector) cf.getLabeling();
        //System.out.println ("l.size="+l.size());
        System.out.println ("l.getBestIndex="+l.getBestIndex());
        assertTrue (cf.getLabeling().getBestLabel() == ldict.lookupLabel("politics"));
        assertTrue (cf.getLabeling().getBestValue() > 0.6);
    }

    public void testStringTrained () {
        String[] africaTraining = new String[] {
            "on the plains of africa the lions roar",
            "in swahili ngoma means to dance",
            "nelson mandela became president of south africa",
            "the saraha dessert is expanding"};
        String[] asiaTraining = new String[] {
            "panda bears eat bamboo",
            "china's one child policy has resulted in a surplus of boys",
            "tigers live in the jungle"};

        InstanceList instances = new InstanceList (
            new SerialPipes (new Pipe[] {
                new Target2Label (),
                new CharSequence2TokenSequence (),
                new TokenSequence2FeatureSequence (),
                new FeatureSequence2FeatureVector ()}));

        instances.addThruPipe (new ArrayIterator (africaTraining, "africa"));
        instances.addThruPipe (new ArrayIterator (asiaTraining, "asia"));
        Classifier c = new NaiveBayesTrainer ().train (instances);

        Classification cf = c.classify ("nelson mandela never eats lions");
        assertTrue (cf.getLabeling().getBestLabel()
                    == ((LabelAlphabet)instances.getTargetAlphabet()).lookupLabel("africa"));
    }

    public void testRandomTrained () {
        InstanceList ilist = new InstanceList (new Randoms(1), 10, 2);
        Classifier c = new NaiveBayesTrainer ().train (ilist);
        // test on the training data
        int numCorrect = 0;
        for (int i = 0; i < ilist.size(); i++) {
            Instance inst = ilist.get(i);
            Classification cf = c.classify (inst);
            cf.print ();
            if (cf.getLabeling().getBestLabel() == inst.getLabeling().getBestLabel())
                numCorrect++;
        }
        System.out.println ("Accuracy on training set = " + ((double)numCorrect)/ilist.size());
    }

    public void testIncrementallyTrainedGrowingAlphabets () {
        System.out.println("testIncrementallyTrainedGrowingAlphabets");
        String[] args = new String[] {
            "src/cc/mallet/classify/tests/NaiveBayesData/learn/a",
            "src/cc/mallet/classify/tests/NaiveBayesData/learn/b"
        };
        File[] directories = new File[args.length];
        for (int i = 0; i < args.length; i++)
            directories[i] = new File (args[i]);

        // MALLET pipeline for converting instances to feature vectors
        SerialPipes instPipe = new SerialPipes(new Pipe[] {
            new Target2Label(),
            new Input2CharSequence(),
            //SKIP_HEADER only works for Unix
            //new CharSubsequence(CharSubsequence.SKIP_HEADER),
            new CharSequence2TokenSequence(),
            new TokenSequenceLowercase(),
            new TokenSequenceRemoveStopwords(),
            new TokenSequence2FeatureSequence(),
            new FeatureSequence2FeatureVector()
        });

        InstanceList instList = new InstanceList(instPipe);
        instList.addThruPipe(new FileIterator(directories, FileIterator.STARTING_DIRECTORIES));

        System.out.println("Training 1");
        NaiveBayesTrainer trainer = new NaiveBayesTrainer();
        NaiveBayes classifier = trainer.trainIncremental(instList);
        //instList.getDataAlphabet().stopGrowth();

        // incrementally train...
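        // The second InstanceList below is built on the same pipe, so it shares
        // the same (still growing) data and target alphabets; the alphabet sizes
        // printed before and after reflect features seen across both rounds.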
        String[] t2directories = {
            "src/cc/mallet/classify/tests/NaiveBayesData/learn/b"
        };

        System.out.println("data alphabet size " + instList.getDataAlphabet().size());
        System.out.println("target alphabet size " + instList.getTargetAlphabet().size());

        InstanceList instList2 = new InstanceList(instPipe);
        instList2.addThruPipe(new FileIterator(t2directories, FileIterator.STARTING_DIRECTORIES));

        System.out.println("Training 2");
        System.out.println("data alphabet size " + instList2.getDataAlphabet().size());
        System.out.println("target alphabet size " + instList2.getTargetAlphabet().size());
        NaiveBayes classifier2 = (NaiveBayes) trainer.trainIncremental(instList2);
    }

    public void testIncrementallyTrained () {
        System.out.println("testIncrementallyTrained");
        String[] args = new String[] {
            "src/cc/mallet/classify/tests/NaiveBayesData/learn/a",
            "src/cc/mallet/classify/tests/NaiveBayesData/learn/b"
        };
        File[] directories = new File[args.length];
        for (int i = 0; i < args.length; i++)
            directories[i] = new File (args[i]);

        // MALLET pipeline for converting instances to feature vectors
        SerialPipes instPipe = new SerialPipes(new Pipe[] {
            new Target2Label(),
            new Input2CharSequence(),
            //SKIP_HEADER only works for Unix
            //new CharSubsequence(CharSubsequence.SKIP_HEADER),
            new CharSequence2TokenSequence(),
            new TokenSequenceLowercase(),
            new TokenSequenceRemoveStopwords(),
            new TokenSequence2FeatureSequence(),
            new FeatureSequence2FeatureVector()
        });

        InstanceList instList = new InstanceList(instPipe);
        instList.addThruPipe(new FileIterator(directories, FileIterator.STARTING_DIRECTORIES));

        System.out.println("Training 1");
        NaiveBayesTrainer trainer = new NaiveBayesTrainer();
        NaiveBayes classifier = (NaiveBayes) trainer.trainIncremental(instList);

        Classification initialClassification = classifier.classify("Hello Everybody");
        Classification initial2Classification = classifier.classify("Goodbye now");
        System.out.println("Initial Classification = ");
        initialClassification.print();
        initial2Classification.print();
        System.out.println("data alphabet " + classifier.getAlphabet());
        System.out.println("label alphabet " + classifier.getLabelAlphabet());

        // incrementally train...
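        // Second round of training with the same trainer, this time on documents
        // drawn only from directory "b"; trainIncremental folds the new instances
        // into the previously trained model instead of starting from scratch.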
        String[] t2directories = {
            "src/cc/mallet/classify/tests/NaiveBayesData/learn/b"
        };

        System.out.println("data alphabet size " + instList.getDataAlphabet().size());
        System.out.println("target alphabet size " + instList.getTargetAlphabet().size());

        InstanceList instList2 = new InstanceList(instPipe);
        instList2.addThruPipe(new FileIterator(t2directories, FileIterator.STARTING_DIRECTORIES));

        System.out.println("Training 2");
        System.out.println("data alphabet size " + instList2.getDataAlphabet().size());
        System.out.println("target alphabet size " + instList2.getTargetAlphabet().size());
        NaiveBayes classifier2 = (NaiveBayes) trainer.trainIncremental(instList2);
    }

    public void testEmptyStringBug () {
        System.out.println("testEmptyStringBug");
        String[] args = new String[] {
            "src/cc/mallet/classify/tests/NaiveBayesData/learn/a",
            "src/cc/mallet/classify/tests/NaiveBayesData/learn/b"
        };
        File[] directories = new File[args.length];
        for (int i = 0; i < args.length; i++)
            directories[i] = new File (args[i]);

        // MALLET pipeline for converting instances to feature vectors
        SerialPipes instPipe = new SerialPipes(new Pipe[] {
            new Target2Label(),
            new Input2CharSequence(),
            //SKIP_HEADER only works for Unix
            //new CharSubsequence(CharSubsequence.SKIP_HEADER),
            new CharSequence2TokenSequence(),
            new TokenSequenceLowercase(),
            new TokenSequenceRemoveStopwords(),
            new TokenSequence2FeatureSequence(),
            new FeatureSequence2FeatureVector()
        });

        InstanceList instList = new InstanceList(instPipe);
        instList.addThruPipe(new FileIterator(directories, FileIterator.STARTING_DIRECTORIES));

        System.out.println("Training 1");
        NaiveBayesTrainer trainer = new NaiveBayesTrainer();
        NaiveBayes classifier = (NaiveBayes) trainer.trainIncremental(instList);

        Classification initialClassification = classifier.classify("Hello Everybody");
        Classification initial2Classification = classifier.classify("Goodbye now");
        System.out.println("Initial Classification = ");
        initialClassification.print();
        initial2Classification.print();
        System.out.println("data alphabet " + classifier.getAlphabet());
        System.out.println("label alphabet " + classifier.getLabelAlphabet());

        // test
        String[] t2directories = {
            "src/cc/mallet/classify/tests/NaiveBayesData/learn/b"
        };

        System.out.println("data alphabet size " + instList.getDataAlphabet().size());
        System.out.println("target alphabet size " + instList.getTargetAlphabet().size());

        InstanceList instList2 = new InstanceList(instPipe);
        instList2.addThruPipe(new FileIterator(t2directories, FileIterator.STARTING_DIRECTORIES, true));

        System.out.println("Training 2");
        System.out.println("data alphabet size " + instList2.getDataAlphabet().size());
        System.out.println("target alphabet size " + instList2.getTargetAlphabet().size());
        NaiveBayes classifier2 = (NaiveBayes) trainer.trainIncremental(instList2);

        Classification secondClassification = classifier.classify("Goodbye now");
        secondClassification.print();
    }

    static Test suite () {
        return new TestSuite (TestNaiveBayes.class);

        //TestSuite suite= new TestSuite();
        ////suite.addTest(new TestNaiveBayes("testIncrementallyTrained"));
        //suite.addTest(new TestNaiveBayes("testEmptyStringBug"));
        //return suite;
    }

    protected void setUp () {
    }

    public static void main (String[] args) {
        junit.textui.TestRunner.run (suite());
    }

}