// Stanford Classifier - a multiclass maxent classifier
// NaiveBayesClassifierFactory
// Copyright (c) 2003-2007 The Board of Trustees of
// The Leland Stanford Junior University. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 2
// of the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
//
// For more information, bug reports, fixes, contact:
// Christopher Manning
// Dept of Computer Science, Gates 1A
// Stanford CA 94305-9010
// USA
// Support/Questions: java-nlp-user@lists.stanford.edu
// Licensing: java-nlp-support@lists.stanford.edu
// http://www-nlp.stanford.edu/software/classifier.shtml
package edu.stanford.nlp.classify;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.optimization.DiffFunction;
import edu.stanford.nlp.optimization.Minimizer;
import edu.stanford.nlp.optimization.QNMinimizer;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.logging.Redwood;
import java.util.*;
/**
 * Creates a {@link NaiveBayesClassifier} from training data.
 *
 * <p>Three training regimes are supported, selected by the {@code kind} constructor argument:
 * <ul>
 *   <li>{@link #JL}: joint likelihood, i.e. add-alpha smoothed relative-frequency counts.</li>
 *   <li>{@link #CL}: conditional likelihood with equality constraints, optimized with a {@link QNMinimizer}.</li>
 *   <li>{@link #UCL}: unconstrained conditional likelihood, optimized with a {@link QNMinimizer}.</li>
 * </ul>
 *
 * <p>Feature values are treated as non-negative integers, because each value indexes a per-value
 * weight slot. Real-valued inputs are truncated to {@code int}.
 *
 * @author Kristina Toutanova (kristina@cs.stanford.edu)
 */
public class NaiveBayesClassifierFactory<L, F> implements ClassifierFactory<L, F, NaiveBayesClassifier<L, F>> {

  /** A logger for this class */
  private static final Redwood.RedwoodChannels logger = Redwood.channels(NaiveBayesClassifierFactory.class);

  private static final long serialVersionUID = -8164165428834534041L;

  /** Train by maximizing joint likelihood (smoothed counting). */
  public static final int JL = 0;
  /** Train by maximizing conditional likelihood under equality constraints. */
  public static final int CL = 1;
  /** Train by maximizing unconstrained conditional likelihood. */
  public static final int UCL = 2;

  /** Which training regime to use: one of {@link #JL}, {@link #CL}, {@link #UCL}. */
  private int kind = JL;
  /** Additive smoothing for class priors (used by {@link #JL} only). */
  private double alphaClass;
  /** Additive smoothing for feature-value counts (used by {@link #JL} only). */
  private double alphaFeature;
  /** Prior parameter passed to the {@link #CL} and {@link #UCL} objective functions. */
  private double sigma;
  /** Ordinal of a {@code LogPrior.LogPriorType}, passed to the CL/UCL objective functions. */
  private int prior = LogPrior.LogPriorType.NULL.ordinal();
  private Index<L> labelIndex;
  private Index<F> featureIndex;

  /** Creates a factory that trains with {@link #JL}, no smoothing and a NULL prior. */
  public NaiveBayesClassifierFactory() {
  }

  /**
   * @param alphaC smoothing constant for class priors (JL)
   * @param alphaF smoothing constant for feature-value counts (JL)
   * @param sigma prior parameter for the CL/UCL objectives
   * @param prior ordinal of a {@code LogPrior.LogPriorType} for the CL/UCL objectives
   * @param kind one of {@link #JL}, {@link #CL}, {@link #UCL}
   */
  public NaiveBayesClassifierFactory(double alphaC, double alphaF, double sigma, int prior, int kind) {
    alphaClass = alphaC;
    alphaFeature = alphaF;
    this.sigma = sigma;
    this.prior = prior;
    this.kind = kind;
  }

  /**
   * Trains weights on dense integer data and packages them as a classifier.
   *
   * @param data one row per datum; {@code data[i][f]} is the value of feature {@code f}
   * @param labels label index of each datum
   * @param numFeatures number of features (columns used from each row)
   * @param numClasses number of distinct labels
   * @param labelIndex maps label indices back to labels
   * @param featureIndex maps feature indices back to features
   * @return a classifier keyed by {@code ((label, feature), value)} with the trained weights
   */
  private NaiveBayesClassifier<L, F> trainClassifier(int[][] data, int[] labels, int numFeatures,
      int numClasses, Index<L> labelIndex, Index<F> featureIndex) {
    NBWeights nbWeights = trainWeights(data, labels, numFeatures, numClasses);

    Set<L> labelSet = Generics.newHashSet();
    Counter<L> priors = new ClassicCounter<>();
    double[] pr = nbWeights.priors;
    for (int i = 0; i < pr.length; i++) {
      L label = labelIndex.get(i);
      priors.incrementCount(label, pr[i]);
      labelSet.add(label);
    }

    // Key: ((label, feature), value) -> weight for that value of the feature under that label.
    Counter<Pair<Pair<L, F>, Number>> weightsCounter = new ClassicCounter<>();
    double[][][] wts = nbWeights.weights;
    for (int c = 0; c < numClasses; c++) {
      L label = labelIndex.get(c);
      for (int f = 0; f < numFeatures; f++) {
        Pair<L, F> p = new Pair<>(label, featureIndex.get(f));
        for (int val = 0; val < wts[c][f].length; val++) {
          Pair<Pair<L, F>, Number> key = new Pair<>(p, Integer.valueOf(val));
          weightsCounter.incrementCount(key, wts[c][f][val]);
        }
      }
    }
    return new NaiveBayesClassifier<>(weightsCounter, priors, labelSet);
  }

  /**
   * Trains on the datums of {@code examples}, which are read as RVFDatums via
   * {@code getRVFDatum}.
   *
   * <p>Datums need not list zero-valued features: every feature absent from a datum is treated
   * as having value 0. Feature values are truncated to {@code int}.
   *
   * @param examples training data
   * @param featureSet the complete feature space; every feature occurring in a datum must be in it
   * @return the trained classifier
   * @throws IllegalArgumentException if a datum contains a feature not in {@code featureSet}
   */
  public NaiveBayesClassifier<L, F> trainClassifier(GeneralDataset<L, F> examples, Set<F> featureSet) {
    int numFeatures = featureSet.size();
    int numExamples = examples.size();
    int[][] data = new int[numExamples][numFeatures];
    int[] labels = new int[numExamples];
    labelIndex = new HashIndex<>();
    featureIndex = new HashIndex<>();
    for (F feat : featureSet) {
      featureIndex.add(feat);
    }
    for (int d = 0; d < numExamples; d++) {
      RVFDatum<L, F> datum = examples.getRVFDatum(d);
      Counter<F> c = datum.asFeaturesCounter();
      for (F feature : c.keySet()) {
        int fNo = featureIndex.indexOf(feature);
        if (fNo < 0) {
          // Previously surfaced as an opaque ArrayIndexOutOfBoundsException.
          throw new IllegalArgumentException("Feature " + feature + " of datum " + d
              + " is not in the supplied featureSet");
        }
        // Truncation is deliberate: values are used as indices into per-value weight arrays.
        data[d][fNo] = (int) c.getCount(feature);
      }
      labelIndex.add(datum.label());
      labels[d] = labelIndex.indexOf(datum.label());
    }
    return trainClassifier(data, labels, numFeatures, labelIndex.size(), labelIndex, featureIndex);
  }

  /**
   * Dispatches to the trainer selected by {@link #kind}.
   *
   * <p>{@code data} must be dense: one row per datum, of length {@code numFeatures}, with every
   * feature value stored explicitly, including zeroes.
   *
   * @return priors plus weights indexed as {@code [label][fno][value]}
   * @throws IllegalStateException if {@code kind} is not one of JL, CL, UCL
   */
  private NBWeights trainWeights(int[][] data, int[] labels, int numFeatures, int numClasses) {
    switch (kind) {
      case JL:
        return trainWeightsJL(data, labels, numFeatures, numClasses);
      case CL:
        return trainWeightsCL(data, labels, numFeatures, numClasses);
      case UCL:
        return trainWeightsUCL(data, labels, numFeatures, numClasses);
      default:
        // Returning null here used to produce a NullPointerException further downstream.
        throw new IllegalStateException("Unknown training kind " + kind + "; expected JL, CL or UCL");
    }
  }

  /**
   * Joint-likelihood training via add-alpha smoothed counts.
   *
   * <p>Computes:
   * <ul>
   *   <li>{@code weights[c][f][v] = log((count(c, f=v) + alphaFeature) / (count(c) + alphaFeature * numValues[f]))}</li>
   *   <li>{@code priors[c] = log((count(c) + alphaClass) / (N + alphaClass * numClasses))}</li>
   * </ul>
   */
  private NBWeights trainWeightsJL(int[][] data, int[] labels, int numFeatures, int numClasses) {
    int[] numValues = numberValues(data, numFeatures);
    double[] priors = new double[numClasses];
    double[][][] weights = new double[numClasses][numFeatures][];
    for (int cl = 0; cl < numClasses; cl++) {
      for (int fno = 0; fno < numFeatures; fno++) {
        weights[cl][fno] = new double[numValues[fno]];
      }
    }
    // Accumulate raw class counts and (class, feature, value) counts.
    for (int i = 0; i < data.length; i++) {
      priors[labels[i]]++;
      for (int fno = 0; fno < numFeatures; fno++) {
        weights[labels[i]][fno][data[i][fno]]++;
      }
    }
    for (int cl = 0; cl < numClasses; cl++) {
      // priors[cl] still holds the raw class count here; it is converted to a log prior only
      // after the feature weights of this class have used it.
      for (int fno = 0; fno < numFeatures; fno++) {
        for (int val = 0; val < numValues[fno]; val++) {
          weights[cl][fno][val] = Math.log((weights[cl][fno][val] + alphaFeature) / (priors[cl] + alphaFeature * numValues[fno]));
        }
      }
      priors[cl] = Math.log((priors[cl] + alphaClass) / (data.length + alphaClass * numClasses));
    }
    return new NBWeights(priors, weights);
  }

  /**
   * Unconstrained conditional-likelihood training.
   *
   * <p>Each (feature, value) pair is recoded as a separate binary feature. The recoded features
   * are fed to a {@code LogConditionalObjectiveFunction} together with an always-on feature
   * (code 0) whose weights serve as the class priors.
   */
  private NBWeights trainWeightsUCL(int[][] data, int[] labels, int numFeatures, int numClasses) {
    int[] numValues = numberValues(data, numFeatures);
    int[] offsets = valueOffsets(numValues);
    // Column 0 is the always-on code 0; column j+1 holds the code of (feature j, its value).
    int[][] newdata = new int[data.length][numFeatures + 1];
    for (int i = 0; i < data.length; i++) {
      newdata[i][0] = 0;
      for (int j = 0; j < numFeatures; j++) {
        newdata[i][j + 1] = offsets[j] + data[i][j] + 1;
      }
    }
    // One code per (feature, value) pair, plus code 0. Summing avoids indexing
    // offsets[numFeatures - 1], which failed when there were no features.
    int totalFeatures = 1;
    for (int nv : numValues) {
      totalFeatures += nv;
    }
    logger.info("total feats " + totalFeatures);
    LogConditionalObjectiveFunction<L, F> objective = new LogConditionalObjectiveFunction<>(totalFeatures, numClasses, newdata, labels, prior, sigma, 0.0);
    Minimizer<DiffFunction> min = new QNMinimizer();
    double[] argmin = min.minimize(objective, 1e-4, objective.initial());
    double[][] wts = objective.to2D(argmin);
    logger.info("weights have dimension " + wts.length);
    return new NBWeights(wts, numValues);
  }

  /** Constrained conditional-likelihood training via {@code LogConditionalEqConstraintFunction}. */
  private NBWeights trainWeightsCL(int[][] data, int[] labels, int numFeatures, int numClasses) {
    LogConditionalEqConstraintFunction objective = new LogConditionalEqConstraintFunction(numFeatures, numClasses, data, labels, prior, sigma, 0.0);
    Minimizer<DiffFunction> min = new QNMinimizer();
    double[] argmin = min.minimize(objective, 1e-4, objective.initial());
    double[][][] wts = objective.to3D(argmin);
    double[] priors = objective.priors(argmin);
    return new NBWeights(priors, wts);
  }

  /**
   * Returns, for each feature, one more than the largest value it takes in {@code data}, i.e.
   * the number of value slots {@code 0..max}. A feature gets 0 when {@code data} has no rows.
   *
   * @param data dense rows; each row must have at most {@code numFeatures} entries
   * @param numFeatures length of the returned array
   */
  static int[] numberValues(int[][] data, int numFeatures) {
    int[] numValues = new int[numFeatures];
    for (int[] row : data) {
      for (int j = 0; j < row.length; j++) {
        if (numValues[j] < row[j] + 1) {
          numValues[j] = row[j] + 1;
        }
      }
    }
    return numValues;
  }

  /**
   * Returns, for each feature {@code f}, how many (feature, value) codes precede it, i.e. the
   * sum of {@code numValues[0..f-1]}.
   */
  private static int[] valueOffsets(int[] numValues) {
    int[] offsets = new int[numValues.length];
    for (int j = 1; j < numValues.length; j++) {
      offsets[j] = offsets[j - 1] + numValues[j - 1];
    }
    return offsets;
  }

  /** Trained Naive Bayes parameters: per-class priors and weights indexed {@code [class][feature][value]}. */
  static class NBWeights {

    double[] priors;
    double[][][] weights;

    NBWeights(double[] priors, double[][][] weights) {
      this.priors = priors;
      this.weights = weights;
    }

    /**
     * Creates the parameters from a coded representation.
     *
     * <p>In {@code wts[code][class]}, code 0 is the always-on feature, whose weights are the
     * priors. Codes {@code offset(f) + v + 1} hold the weights for feature {@code f} having
     * value {@code v}.
     *
     * @param wts weights indexed by code, then class
     * @param numValues number of value slots of each feature
     */
    NBWeights(double[][] wts, int[] numValues) {
      int numClasses = wts[0].length;
      // A plain copy: this array is local to the constructor, so no lock is needed.
      priors = wts[0].clone();
      int numFeatures = numValues.length;
      int[] offsets = valueOffsets(numValues);
      weights = new double[numClasses][numFeatures][];
      for (int fno = 0; fno < numFeatures; fno++) {
        for (int c = 0; c < numClasses; c++) {
          weights[c][fno] = new double[numValues[fno]];
        }
        for (int val = 0; val < numValues[fno]; val++) {
          int code = offsets[fno] + val + 1; // +1 skips the always-on code 0
          for (int cls = 0; cls < numClasses; cls++) {
            weights[cls][fno][val] = wts[code][cls];
          }
        }
      }
    }
  }

  /**
   * Trains on a non-RVF dataset using its internal data array, labels and indices.
   *
   * <p>NOTE(review): the private trainer expects dense rows of length {@code numFeatures} holding
   * feature values (including zeroes). {@code getDataArray()} presumably holds each datum's active
   * feature indices instead. Confirm that the two layouts agree before relying on this path.
   *
   * @throws UnsupportedOperationException if {@code dataset} is an {@link RVFDataset}
   */
  @Override
  public NaiveBayesClassifier<L, F> trainClassifier(GeneralDataset<L, F> dataset) {
    if (dataset instanceof RVFDataset) {
      throw new UnsupportedOperationException("Not sure if RVFDataset runs correctly in this method. Please update this code if it does.");
    }
    return trainClassifier(dataset.getDataArray(), dataset.labels, dataset.numFeatures(),
        dataset.numClasses(), dataset.labelIndex, dataset.featureIndex);
  }
}