/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.mahout.classifier.naivebayes; import org.apache.mahout.common.MahoutTestCase; import org.apache.mahout.math.DenseMatrix; import org.apache.mahout.math.DenseVector; import org.apache.mahout.math.Matrix; import org.apache.mahout.math.Vector; import org.apache.mahout.math.Vector.Element; public abstract class NaiveBayesTestBase extends MahoutTestCase { private NaiveBayesModel model; @Override public void setUp() throws Exception { super.setUp(); model = createNaiveBayesModel(); model.validate(); } protected NaiveBayesModel getModel() { return model; } protected static double complementaryNaiveBayesThetaWeight(int label, Matrix weightMatrix, Vector labelSum, Vector featureSum) { double weight = 0.0; double alpha = 1.0; for (int i = 0; i < featureSum.size(); i++) { double score = weightMatrix.get(i, label); double lSum = labelSum.get(label); double fSum = featureSum.get(i); double totalSum = featureSum.zSum(); double numerator = fSum - score + alpha; double denominator = totalSum - lSum + featureSum.size(); weight += Math.log(numerator / denominator); } return weight; } protected static double naiveBayesThetaWeight(int label, Matrix weightMatrix, Vector labelSum, Vector featureSum) { double weight = 0.0; double alpha = 1.0; for (int feature = 0; feature < featureSum.size(); feature++) { double score = weightMatrix.get(feature, label); double lSum = labelSum.get(label); double numerator = score + alpha; double denominator = lSum + featureSum.size(); weight += Math.log(numerator / denominator); } return weight; } protected static NaiveBayesModel createNaiveBayesModel() { double[][] matrix = { { 0.7, 0.1, 0.1, 0.3 }, { 0.4, 0.4, 0.1, 0.1 }, { 0.1, 0.0, 0.8, 0.1 }, { 0.1, 0.1, 0.1, 0.7 } }; double[] labelSumArray = { 1.2, 1.0, 1.0, 1.0 }; double[] featureSumArray = { 1.3, 0.6, 1.1, 1.2 }; DenseMatrix weightMatrix = new DenseMatrix(matrix); DenseVector labelSum = new DenseVector(labelSumArray); DenseVector featureSum = new DenseVector(featureSumArray); double[] thetaNormalizerSum = { naiveBayesThetaWeight(0, weightMatrix, labelSum, featureSum), naiveBayesThetaWeight(1, weightMatrix, labelSum, featureSum), naiveBayesThetaWeight(2, weightMatrix, labelSum, featureSum), naiveBayesThetaWeight(3, weightMatrix, labelSum, featureSum) }; // now generate the model return new NaiveBayesModel(weightMatrix, featureSum, labelSum, new DenseVector(thetaNormalizerSum), 1.0f); } protected static NaiveBayesModel createComplementaryNaiveBayesModel() { double[][] matrix = { { 0.7, 0.1, 0.1, 0.3 }, { 0.4, 0.4, 0.1, 0.1 }, { 0.1, 0.0, 0.8, 0.1 }, { 0.1, 0.1, 0.1, 0.7 } }; double[] labelSumArray = { 1.2, 1.0, 1.0, 1.0 }; double[] featureSumArray = { 1.3, 0.6, 1.1, 1.2 }; DenseMatrix weightMatrix = new DenseMatrix(matrix); DenseVector labelSum = new DenseVector(labelSumArray); DenseVector featureSum = new DenseVector(featureSumArray); double[] thetaNormalizerSum = { complementaryNaiveBayesThetaWeight(0, weightMatrix, labelSum, featureSum), complementaryNaiveBayesThetaWeight(1, weightMatrix, labelSum, featureSum), complementaryNaiveBayesThetaWeight(2, weightMatrix, labelSum, featureSum), complementaryNaiveBayesThetaWeight(3, weightMatrix, labelSum, featureSum) }; // now generate the model return new NaiveBayesModel(weightMatrix, featureSum, labelSum, new DenseVector(thetaNormalizerSum), 1.0f); } protected static int maxIndex(Vector instance) { int maxIndex = -1; double maxScore = Integer.MIN_VALUE; for (Element label : instance) { if (label.get() >= maxScore) { maxIndex = label.index(); maxScore = label.get(); } } return maxIndex; } }