/******************************************************************************* * Copyright (c) 2010 Haifeng Li * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. *******************************************************************************/ package smile.classification; import smile.validation.LOOCV; import smile.data.AttributeDataset; import smile.data.parser.ArffParser; import java.io.BufferedReader; import java.io.IOException; import java.util.ArrayList; import org.junit.After; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; import smile.math.Math; import smile.feature.Bag; import smile.stat.distribution.Distribution; import smile.stat.distribution.GaussianMixture; import smile.validation.CrossValidation; import static org.junit.Assert.*; /** * * @author Haifeng Li */ public class NaiveBayesTest { String[] feature = { "outstanding", "wonderfully", "wasted", "lame", "awful", "poorly", "ridiculous", "waste", "worst", "bland", "unfunny", "stupid", "dull", "fantastic", "laughable", "mess", "pointless", "terrific", "memorable", "superb", "boring", "badly", "subtle", "terrible", "excellent", "perfectly", "masterpiece", "realistic", "flaws" }; double[][] moviex; int[] moviey; public NaiveBayesTest() { String[][] x = new String[2000][]; int[] y = new int[2000]; try(BufferedReader input = smile.data.parser.IOUtils.getTestDataReader("text/movie.txt")) { for (int i = 0; i < x.length; i++) { String[] words = input.readLine().trim().split(" "); if (words[0].equalsIgnoreCase("pos")) { y[i] = 1; } else if (words[0].equalsIgnoreCase("neg")) { y[i] = 0; } else { System.err.println("Invalid class label: " + words[0]); } x[i] = words; } } catch (IOException ex) { System.err.println(ex); } moviex = new double[x.length][]; moviey = new int[y.length]; Bag<String> bag = new Bag<>(feature); for (int i = 0; i < x.length; i++) { moviex[i] = bag.feature(x[i]); moviey[i] = y[i]; } } @BeforeClass public static void setUpClass() throws Exception { } @AfterClass public static void tearDownClass() throws Exception { } @Before public void setUp() { } @After public void tearDown() { } /** * Test of predict method, of class NaiveBayes. */ @Test public void testPredict() { System.out.println("predict"); ArffParser arffParser = new ArffParser(); arffParser.setResponseIndex(4); try { AttributeDataset iris = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/iris.arff")); double[][] x = iris.toArray(new double[iris.size()][]); int[] y = iris.toArray(new int[iris.size()]); int n = x.length; LOOCV loocv = new LOOCV(n); int error = 0; for (int l = 0; l < n; l++) { double[][] trainx = Math.slice(x, loocv.train[l]); int[] trainy = Math.slice(y, loocv.train[l]); int p = trainx[0].length; int k = Math.max(trainy) + 1; double[] priori = new double[k]; Distribution[][] condprob = new Distribution[k][p]; for (int i = 0; i < k; i++) { priori[i] = 1.0 / k; for (int j = 0; j < p; j++) { ArrayList<Double> axi = new ArrayList<>(); for (int m = 0; m < trainx.length; m++) { if (trainy[m] == i) { axi.add(trainx[m][j]); } } double[] xi = new double[axi.size()]; for (int m = 0; m < xi.length; m++) { xi[m] = axi.get(m); } condprob[i][j] = new GaussianMixture(xi, 3); } } NaiveBayes bayes = new NaiveBayes(priori, condprob); if (y[loocv.test[l]] != bayes.predict(x[loocv.test[l]])) error++; } System.out.format("Iris error rate = %.2f%%%n", 100.0 * error / x.length); assertEquals(5, error); } catch (Exception ex) { System.err.println(ex); } } /** * Test of learn method, of class SequenceNaiveBayes. */ @Test public void testLearnMultinomial() { System.out.println("batch learn Multinomial"); double[][] x = moviex; int[] y = moviey; int n = x.length; int k = 10; CrossValidation cv = new CrossValidation(n, k); int error = 0; int total = 0; for (int i = 0; i < k; i++) { double[][] trainx = Math.slice(x, cv.train[i]); int[] trainy = Math.slice(y, cv.train[i]); NaiveBayes bayes = new NaiveBayes(NaiveBayes.Model.MULTINOMIAL, 2, feature.length); bayes.learn(trainx, trainy); double[][] testx = Math.slice(x, cv.test[i]); int[] testy = Math.slice(y, cv.test[i]); for (int j = 0; j < testx.length; j++) { int label = bayes.predict(testx[j]); if (label != -1) { total++; if (testy[j] != label) { error++; } } } } System.out.format("Multinomial error = %d of %d%n", error, total); assertTrue(error < 265); } /** * Test of learn method, of class SequenceNaiveBayes. */ @Test public void testLearnMultinomial2() { System.out.println("online learn Multinomial"); double[][] x = moviex; int[] y = moviey; int n = x.length; int k = 10; CrossValidation cv = new CrossValidation(n, k); int error = 0; int total = 0; for (int i = 0; i < k; i++) { double[][] trainx = Math.slice(x, cv.train[i]); int[] trainy = Math.slice(y, cv.train[i]); NaiveBayes bayes = new NaiveBayes(NaiveBayes.Model.MULTINOMIAL, 2, feature.length); for (int j = 0; j < trainx.length; j++) { bayes.learn(trainx[j], trainy[j]); } double[][] testx = Math.slice(x, cv.test[i]); int[] testy = Math.slice(y, cv.test[i]); for (int j = 0; j < testx.length; j++) { int label = bayes.predict(testx[j]); if (label != -1) { total++; if (testy[j] != label) { error++; } } } } System.out.format("Multinomial error = %d of %d%n", error, total); assertTrue(error < 265); } /** * Test of learn method, of class SequenceNaiveBayes. */ @Test public void testLearnBernoulli() { System.out.println("batch learn Bernoulli"); double[][] x = moviex; int[] y = moviey; int n = x.length; int k = 10; CrossValidation cv = new CrossValidation(n, k); int error = 0; int total = 0; for (int i = 0; i < k; i++) { double[][] trainx = Math.slice(x, cv.train[i]); int[] trainy = Math.slice(y, cv.train[i]); NaiveBayes bayes = new NaiveBayes(NaiveBayes.Model.BERNOULLI, 2, feature.length); bayes.learn(trainx, trainy); double[][] testx = Math.slice(x, cv.test[i]); int[] testy = Math.slice(y, cv.test[i]); for (int j = 0; j < testx.length; j++) { int label = bayes.predict(testx[j]); if (label != -1) { total++; if (testy[j] != label) { error++; } } } } System.out.format("Bernoulli error = %d of %d%n", error, total); assertTrue(error < 270); } /** * Test of learn method, of class SequenceNaiveBayes. */ @Test public void testLearnBernoulli2() { System.out.println("online learn Bernoulli"); double[][] x = moviex; int[] y = moviey; int n = x.length; int k = 10; CrossValidation cv = new CrossValidation(n, k); int error = 0; int total = 0; for (int i = 0; i < k; i++) { double[][] trainx = Math.slice(x, cv.train[i]); int[] trainy = Math.slice(y, cv.train[i]); NaiveBayes bayes = new NaiveBayes(NaiveBayes.Model.BERNOULLI, 2, feature.length); for (int j = 0; j < trainx.length; j++) { bayes.learn(trainx[j], trainy[j]); } double[][] testx = Math.slice(x, cv.test[i]); int[] testy = Math.slice(y, cv.test[i]); for (int j = 0; j < testx.length; j++) { int label = bayes.predict(testx[j]); if (label != -1) { total++; if (testy[j] != label) { error++; } } } } System.out.format("Bernoulli error = %d of %d%n", error, total); assertTrue(error < 270); } }