/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* NaiveBayesMultinomialText.java
* Copyright (C) 2012 University of Waikato, Hamilton, New Zealand
*/
package weka.classifiers.bayes;
import java.io.File;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.Vector;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.UpdateableClassifier;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Aggregateable;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.Stopwords;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.core.stemmers.NullStemmer;
import weka.core.stemmers.Stemmer;
import weka.core.tokenizers.Tokenizer;
import weka.core.tokenizers.WordTokenizer;
/**
<!-- globalinfo-start -->
* Multinomial naive bayes for text data. Operates
* directly (and only) on String attributes. Other types of input attributes are
* accepted but ignored during training and classification
* <p/>
<!-- globalinfo-end -->
*
<!-- options-start -->
* Valid options are:
* <p/>
*
* <pre>
* -W
* Use word frequencies instead of binary bag of words.
* </pre>
*
* <pre>
* -P <# instances>
* How often to prune the dictionary of low frequency words (default = 0, i.e. don't prune)
* </pre>
*
* <pre>
* -M <double>
* Minimum word frequency. Words with less than this frequency are ignored.
* If periodic pruning is turned on then this is also used to determine which
* words to remove from the dictionary (default = 3).
* </pre>
*
* <pre>
* -normalize
* Normalize document length (use in conjunction with -norm and -lnorm)
* </pre>
*
* <pre>
* -norm <num>
* Specify the norm that each instance must have (default 1.0)
* </pre>
*
* <pre>
* -lnorm <num>
* Specify L-norm to use (default 2.0)
* </pre>
*
* <pre>
* -lowercase
* Convert all tokens to lowercase before adding to the dictionary.
* </pre>
*
* <pre>
* -stoplist
* Ignore words that are in the stoplist.
* </pre>
*
* <pre>
* -stopwords <file>
* A file containing stopwords to override the default ones.
* Using this option automatically sets the flag ('-stoplist') to use the
* stoplist if the file exists.
* Format: one stopword per line, lines starting with '#'
* are interpreted as comments and ignored.
* </pre>
*
* <pre>
* -tokenizer <spec>
* The tokenizing algorithm (classname plus parameters) to use.
* (default: weka.core.tokenizers.WordTokenizer)
* </pre>
*
* <pre>
* -stemmer <spec>
* The stemming algorithm (classname plus parameters) to use.
* </pre>
*
<!-- options-end -->
*
* @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
* @author Andrew Golightly (acg4@cs.waikato.ac.nz)
* @author Bernhard Pfahringer (bernhard@cs.waikato.ac.nz)
*
*/
public class NaiveBayesMultinomialText extends AbstractClassifier implements
UpdateableClassifier, WeightedInstancesHandler,
Aggregateable<NaiveBayesMultinomialText> {
/** For serialization */
private static final long serialVersionUID = 2139025532014821394L;
/**
 * Simple mutable holder for a (weighted) count, so that dictionary entries
 * can be incremented in place without re-inserting into the map.
 */
private static class Count implements Serializable {
/**
 * For serialization
 */
private static final long serialVersionUID = 2104201532017340967L;
/** The accumulated (weighted) count */
public double m_count;
/**
 * Constructs a Count with the given initial value.
 *
 * @param c the initial count
 */
public Count(double c) {
m_count = c;
}
}
/** The header of the training data */
protected Instances m_data;
/** Laplace-corrected prior weight accumulated for each class */
protected double[] m_probOfClass;
/** Total (Laplace-corrected) word weight accumulated per class */
protected double[] m_wordsPerClass;
/** Per-class dictionary mapping each word to its accumulated count */
protected Map<Integer, LinkedHashMap<String, Count>> m_probOfWordGivenClass;
/**
 * Holds the current document vector (LinkedHashMap is more efficient when
 * iterating over EntrySet than HashMap)
 */
protected transient LinkedHashMap<String, Count> m_inputVector;
/** Default (rainbow) stopwords; loaded lazily in tokenizeInstance() */
protected transient Stopwords m_stopwords;
/**
 * a file containing stopwords for using others than the default Rainbow ones.
 * Defaults to the working directory (a directory means "use the defaults").
 */
protected File m_stopwordsFile = new File(System.getProperty("user.dir"));
/** The tokenizer to use */
protected Tokenizer m_tokenizer = new WordTokenizer();
/** Whether or not to convert all tokens to lowercase */
protected boolean m_lowercaseTokens;
/** The stemming algorithm; never null (NullStemmer means no stemming). */
protected Stemmer m_stemmer = new NullStemmer();
/** Whether or not to use a stop list */
protected boolean m_useStopList;
/**
 * The number of training instances at which to periodically prune the
 * dictionary of min frequency words. Zero or a negative value indicates
 * don't prune
 */
protected int m_periodicP = 0;
/**
 * Only consider dictionary words (features) that occur at least this many
 * times
 */
protected double m_minWordP = 3;
/** Use word frequencies rather than bag-of-words if true */
protected boolean m_wordFrequencies = false;
/** normalize document length ? */
protected boolean m_normalize = false;
/** The length that each document vector should have in the end */
protected double m_norm = 1.0;
/** The L-norm to use */
protected double m_lnorm = 2.0;
/** Laplace-like correction factor for zero frequency */
protected double m_leplace = 1.0;
/** Holds the current instance number */
protected double m_t;
/**
 * Returns a string describing this classifier.
 *
 * @return a description suitable for displaying in the explorer/experimenter
 *         gui
 */
public String globalInfo() {
  return "Multinomial naive bayes for text data. Operates directly "
    + "(and only) on String attributes. Other types of input attributes "
    + "are accepted but ignored during training and classification";
}
/**
 * Returns default capabilities of the classifier.
 *
 * @return the capabilities of this classifier
 */
@Override
public Capabilities getCapabilities() {
  Capabilities result = super.getCapabilities();
  result.disableAll();
  // attribute capabilities: strings carry the text; the other attribute
  // types are tolerated but ignored by the learner
  Capability[] enabled = {
    Capability.STRING_ATTRIBUTES, Capability.NOMINAL_ATTRIBUTES,
    Capability.DATE_ATTRIBUTES, Capability.NUMERIC_ATTRIBUTES,
    Capability.MISSING_VALUES, Capability.MISSING_CLASS_VALUES,
    Capability.NOMINAL_CLASS };
  for (Capability cap : enabled) {
    result.enable(cap);
  }
  // no minimum number of training instances required
  result.setMinimumNumberInstances(0);
  return result;
}
/**
 * Generates the classifier.
 *
 * @param data set of instances serving as training data
 * @throws Exception if the classifier has not been generated successfully
 */
@Override
public void buildClassifier(Instances data) throws Exception {
  reset();
  // can classifier handle the data?
  getCapabilities().testWithFail(data);
  m_data = new Instances(data, 0);
  data = new Instances(data);
  int numClasses = data.numClasses();
  m_wordsPerClass = new double[numClasses];
  m_probOfClass = new double[numClasses];
  m_probOfWordGivenClass = new HashMap<Integer, LinkedHashMap<String, Count>>();
  double laplace = 1.0;
  for (int c = 0; c < numClasses; c++) {
    // one dictionary per class; presize assuming roughly 10000 words total
    m_probOfWordGivenClass.put(c,
      new LinkedHashMap<String, Count>(10000 / numClasses));
    // class prior starts at the Laplace correction; the per-word Laplace
    // correction is applied whenever a new word (attribute) is first seen
    m_probOfClass[c] = laplace;
    m_wordsPerClass[c] = 0;
  }
  // train incrementally on each instance
  for (int i = 0; i < data.numInstances(); i++) {
    updateClassifier(data.instance(i));
  }
}
/**
 * Updates the classifier with the given instance.
 *
 * @param instance the new training instance to include in the model
 * @throws Exception if the instance could not be incorporated in the model.
 */
@Override
public void updateClassifier(Instance instance) throws Exception {
// delegate, also updating the per-class dictionaries
updateClassifier(instance, true);
}
/**
 * Updates the model with the given instance.
 *
 * @param instance the training instance to process
 * @param updateDictionary true to also update the per-class word
 *          dictionaries (false when only the input vector is needed)
 * @throws Exception if the instance cannot be processed
 */
protected void updateClassifier(Instance instance, boolean updateDictionary)
throws Exception {
// instances with a missing class contribute nothing to the model
if (!instance.classIsMissing()) {
int classIndex = (int) instance.classValue();
// class prior is updated with the instance weight
m_probOfClass[classIndex] += instance.weight();
tokenizeInstance(instance, updateDictionary);
// advance the instance counter (drives periodic pruning)
m_t++;
}
}
/**
 * Calculates the class membership probabilities for the given test instance.
 *
 * @param instance the instance to be classified
 * @return predicted class probability distribution
 * @throws Exception if there is a problem generating the prediction
 */
@Override
public double[] distributionForInstance(Instance instance) throws Exception {
  tokenizeInstance(instance, false);
  double[] probOfClassGivenDoc = new double[m_data.numClasses()];
  double[] logDocGivenClass = new double[m_data.numClasses()];
  for (int i = 0; i < m_data.numClasses(); i++) {
    logDocGivenClass[i] += Math.log(m_probOfClass[i]);
    LinkedHashMap<String, Count> dictForClass = m_probOfWordGivenClass.get(i);
    // FIX: was declared int, which silently truncated fractional
    // frequencies (possible when document length normalization is on or
    // instances carry non-integer weights)
    double allWords = 0;
    // for document normalization (if in use)
    double iNorm = 0;
    double fv = 0;
    if (m_normalize) {
      for (Map.Entry<String, Count> feature : m_inputVector.entrySet()) {
        String word = feature.getKey();
        Count c = feature.getValue();
        // check the word against all the dictionaries (all classes)
        boolean ok = false;
        for (int clss = 0; clss < m_data.numClasses(); clss++) {
          if (m_probOfWordGivenClass.get(clss).get(word) != null) {
            ok = true;
            break;
          }
        }
        // only normalize with respect to those words that we've seen during
        // training (i.e. dictionary over all classes)
        if (ok) {
          // word counts or bag-of-words?
          fv = (m_wordFrequencies) ? c.m_count : 1.0;
          iNorm += Math.pow(Math.abs(fv), m_lnorm);
        }
      }
      iNorm = Math.pow(iNorm, 1.0 / m_lnorm);
    }
    for (Map.Entry<String, Count> feature : m_inputVector.entrySet()) {
      String word = feature.getKey();
      Count dictCount = dictForClass.get(word);
      // check the word against all the dictionaries (all classes)
      boolean ok = false;
      for (int clss = 0; clss < m_data.numClasses(); clss++) {
        if (m_probOfWordGivenClass.get(clss).get(word) != null) {
          ok = true;
          break;
        }
      }
      // ignore words we haven't seen in the training data
      if (ok) {
        double freq = (m_wordFrequencies) ? feature.getValue().m_count : 1.0;
        if (m_normalize) {
          freq /= iNorm * m_norm;
        }
        allWords += freq;
        if (dictCount != null) {
          logDocGivenClass[i] += freq * Math.log(dictCount.m_count);
        } else {
          // Laplace correction for zero frequency
          logDocGivenClass[i] += freq * Math.log(m_leplace);
        }
      }
    }
    if (m_wordsPerClass[i] > 0) {
      logDocGivenClass[i] -= allWords * Math.log(m_wordsPerClass[i]);
    }
  }
  // convert log-likelihoods to a normalized distribution; subtract the max
  // before exponentiating for numerical stability
  double max = logDocGivenClass[Utils.maxIndex(logDocGivenClass)];
  for (int i = 0; i < m_data.numClasses(); i++) {
    probOfClassGivenDoc[i] = Math.exp(logDocGivenClass[i] - max);
  }
  Utils.normalize(probOfClassGivenDoc);
  return probOfClassGivenDoc;
}
/**
 * Tokenizes all String attributes of the given instance into the transient
 * input vector (word -> accumulated instance weight), applying the
 * configured lowercasing, stemming and stopword filtering. If
 * updateDictionary is true, the per-class dictionaries and word totals are
 * updated as well and periodic pruning is triggered.
 *
 * @param instance the instance to tokenize
 * @param updateDictionary whether to also update the model's dictionaries
 */
protected void tokenizeInstance(Instance instance, boolean updateDictionary) {
if (m_inputVector == null) {
m_inputVector = new LinkedHashMap<String, Count>();
} else {
m_inputVector.clear();
}
// lazily load the stopword list the first time it is needed (the field is
// transient, so this also re-loads after deserialization)
if (m_useStopList && m_stopwords == null) {
m_stopwords = new Stopwords();
try {
if (getStopwords().exists() && !getStopwords().isDirectory()) {
m_stopwords.read(getStopwords());
}
} catch (Exception ex) {
ex.printStackTrace();
}
}
// build the document vector from all non-missing String attributes
for (int i = 0; i < instance.numAttributes(); i++) {
if (instance.attribute(i).isString() && !instance.isMissing(i)) {
m_tokenizer.tokenize(instance.stringValue(i));
while (m_tokenizer.hasMoreElements()) {
String word = (String) m_tokenizer.nextElement();
if (m_lowercaseTokens) {
word = word.toLowerCase();
}
word = m_stemmer.stem(word);
if (m_useStopList) {
if (m_stopwords.is(word)) {
continue;
}
}
Count docCount = m_inputVector.get(word);
if (docCount == null) {
m_inputVector.put(word, new Count(instance.weight()));
} else {
docCount.m_count += instance.weight();
}
}
}
}
if (updateDictionary) {
// NOTE(review): assumes the class value is not missing here -- the only
// caller passing updateDictionary == true checks classIsMissing() first
int classValue = (int) instance.classValue();
LinkedHashMap<String, Count> dictForClass = m_probOfWordGivenClass
.get(classValue);
// document normalization
double iNorm = 0;
double fv = 0;
if (m_normalize) {
for (Count c : m_inputVector.values()) {
// word counts or bag-of-words?
fv = (m_wordFrequencies) ? c.m_count : 1.0;
iNorm += Math.pow(Math.abs(fv), m_lnorm);
}
iNorm = Math.pow(iNorm, 1.0 / m_lnorm);
}
for (Map.Entry<String, Count> feature : m_inputVector.entrySet()) {
String word = feature.getKey();
double freq = (m_wordFrequencies) ? feature.getValue().m_count : 1.0;
if (m_normalize) {
freq /= (iNorm * m_norm);
}
// make sure every class dictionary contains the word, initializing it
// with the Laplace correction and updating each class's word total
for (int i = 0; i < m_data.numClasses(); i++) {
LinkedHashMap<String, Count> dict = m_probOfWordGivenClass.get(i);
if (dict.get(word) == null) {
dict.put(word, new Count(m_leplace));
m_wordsPerClass[i] += m_leplace;
}
}
// the loop above guarantees dictCount is non-null here
Count dictCount = dictForClass.get(word);
dictCount.m_count += freq;
m_wordsPerClass[classValue] += freq;
}
pruneDictionary();
}
}
/**
 * Removes low-frequency words from every class dictionary. Only runs when
 * periodic pruning is enabled and the current instance count is an exact
 * multiple of the pruning period; removed words also have their counts
 * subtracted from the per-class word totals.
 */
protected void pruneDictionary() {
  if (m_periodicP <= 0 || m_t % m_periodicP > 0) {
    return;
  }
  for (Map.Entry<Integer, LinkedHashMap<String, Count>> classEntry :
      m_probOfWordGivenClass.entrySet()) {
    int classIndex = classEntry.getKey();
    Iterator<Map.Entry<String, Count>> it = classEntry.getValue().entrySet()
      .iterator();
    while (it.hasNext()) {
      double count = it.next().getValue().m_count;
      if (count < m_minWordP) {
        m_wordsPerClass[classIndex] -= count;
        it.remove();
      }
    }
  }
}
/**
 * Reset the classifier: drops all learned counts and restarts the
 * instance counter.
 */
public void reset() {
  m_probOfClass = null;
  m_probOfWordGivenClass = null;
  m_wordsPerClass = null;
  m_t = 1;
}
/**
 * Sets the stemming algorithm to use; null means no stemming at all (i.e.,
 * the NullStemmer is used).
 *
 * @param value the configured stemming algorithm, or null
 * @see NullStemmer
 */
public void setStemmer(Stemmer value) {
  m_stemmer = (value != null) ? value : new NullStemmer();
}
/**
 * Returns the current stemming algorithm (a NullStemmer when no stemming
 * is performed; never null, see setStemmer).
 *
 * @return the current stemming algorithm
 */
public Stemmer getStemmer() {
return m_stemmer;
}
/**
 * Returns the tip text for this property.
 *
 * @return tip text for this property suitable for displaying in the
 * explorer/experimenter gui
 */
public String stemmerTipText() {
return "The stemming algorithm to use on the words.";
}
/**
 * Sets the tokenizer algorithm to use.
 *
 * @param value the configured tokenizing algorithm
 */
public void setTokenizer(Tokenizer value) {
m_tokenizer = value;
}
/**
 * Returns the current tokenizer algorithm.
 *
 * @return the current tokenizer algorithm
 */
public Tokenizer getTokenizer() {
return m_tokenizer;
}
/**
 * Returns the tip text for this property.
 *
 * @return tip text for this property suitable for displaying in the
 * explorer/experimenter gui
 */
public String tokenizerTipText() {
return "The tokenizing algorithm to use on the strings.";
}
/**
 * Returns the tip text for this property
 *
 * @return tip text for this property suitable for displaying in the
 * explorer/experimenter gui
 */
public String useWordFrequenciesTipText() {
return "Use word frequencies rather than binary "
+ "bag of words representation";
}
/**
 * Set whether to use word frequencies rather than binary bag of words
 * representation.
 *
 * @param u true if word frequencies are to be used.
 */
public void setUseWordFrequencies(boolean u) {
m_wordFrequencies = u;
}
/**
 * Get whether to use word frequencies rather than binary bag of words
 * representation.
 *
 * @return true if word frequencies are to be used.
 */
public boolean getUseWordFrequencies() {
return m_wordFrequencies;
}
/**
 * Returns the tip text for this property
 *
 * @return tip text for this property suitable for displaying in the
 * explorer/experimenter gui
 */
public String lowercaseTokensTipText() {
return "Whether to convert all tokens to lowercase";
}
/**
 * Set whether to convert all tokens to lowercase
 *
 * @param l true if all tokens are to be converted to lowercase
 */
public void setLowercaseTokens(boolean l) {
m_lowercaseTokens = l;
}
/**
 * Get whether to convert all tokens to lowercase
 *
 * @return true if all tokens are to be converted to lowercase
 */
public boolean getLowercaseTokens() {
return m_lowercaseTokens;
}
/**
 * Returns the tip text for this property
 *
 * @return tip text for this property suitable for displaying in the
 * explorer/experimenter gui
 */
public String periodicPruningTipText() {
return "How often (number of instances) to prune "
+ "the dictionary of low frequency terms. "
+ "0 means don't prune. Setting a positive "
+ "integer n means prune after every n instances";
}
/**
 * Set how often to prune the dictionary
 *
 * @param p how often to prune (number of instances; 0 disables pruning)
 */
public void setPeriodicPruning(int p) {
m_periodicP = p;
}
/**
 * Get how often to prune the dictionary
 *
 * @return how often to prune the dictionary
 */
public int getPeriodicPruning() {
return m_periodicP;
}
/**
 * Returns the tip text for this property
 *
 * @return tip text for this property suitable for displaying in the
 * explorer/experimenter gui
 */
public String minWordFrequencyTipText() {
return "Ignore any words that don't occur at least "
+ "min frequency times in the training data. If periodic "
+ "pruning is turned on, then the dictionary is pruned "
+ "according to this value";
}
/**
 * Set the minimum word frequency. Words that don't occur at least min freq
 * times are ignored when updating weights. If periodic pruning is turned on,
 * then min frequency is used when removing words from the dictionary.
 *
 * @param minFreq the minimum word frequency to use
 */
public void setMinWordFrequency(double minFreq) {
m_minWordP = minFreq;
}
/**
 * Get the minimum word frequency. Words that don't occur at least min freq
 * times are ignored when updating weights. If periodic pruning is turned on,
 * then min frequency is used when removing words from the dictionary.
 *
 * @return the minimum word frequency to use
 */
public double getMinWordFrequency() {
return m_minWordP;
}
/**
 * Returns the tip text for this property
 *
 * @return tip text for this property suitable for displaying in the
 * explorer/experimenter gui
 */
public String normalizeDocLengthTipText() {
return "If true then document length is normalized according "
+ "to the settings for norm and lnorm";
}
/**
 * Set whether to normalize the length of each document
 *
 * @param norm true if document length is to be normalized
 */
public void setNormalizeDocLength(boolean norm) {
m_normalize = norm;
}
/**
 * Get whether to normalize the length of each document
 *
 * @return true if document length is to be normalized
 */
public boolean getNormalizeDocLength() {
return m_normalize;
}
/**
 * Returns the tip text for this property
 *
 * @return tip text for this property suitable for displaying in the
 * explorer/experimenter gui
 */
public String normTipText() {
return "The norm of the instances after normalization.";
}
/**
 * Get the instance's Norm.
 *
 * @return the Norm
 */
public double getNorm() {
return m_norm;
}
/**
 * Set the norm of the instances
 *
 * @param newNorm the norm to which the instances must be set
 */
public void setNorm(double newNorm) {
m_norm = newNorm;
}
/**
 * Returns the tip text for this property
 *
 * @return tip text for this property suitable for displaying in the
 * explorer/experimenter gui
 */
public String LNormTipText() {
return "The LNorm to use for document length normalization.";
}
/**
 * Get the L Norm used.
 *
 * @return the L-norm used
 */
public double getLNorm() {
return m_lnorm;
}
/**
 * Set the L-norm to use
 *
 * @param newLNorm the L-norm
 */
public void setLNorm(double newLNorm) {
m_lnorm = newLNorm;
}
/**
 * Returns the tip text for this property
 *
 * @return tip text for this property suitable for displaying in the
 * explorer/experimenter gui
 */
public String useStopListTipText() {
return "If true, ignores all words that are on the stoplist.";
}
/**
 * Set whether to ignore all words that are on the stoplist.
 *
 * @param u true to ignore all words on the stoplist.
 */
public void setUseStopList(boolean u) {
m_useStopList = u;
}
/**
 * Get whether to ignore all words that are on the stoplist.
 *
 * @return true to ignore all words on the stoplist.
 */
public boolean getUseStopList() {
return m_useStopList;
}
/**
 * sets the file containing the stopwords, null or a directory unset the
 * stopwords. If the file exists, it automatically turns on the flag to use
 * the stoplist.
 *
 * @param value the file containing the stopwords
 */
public void setStopwords(File value) {
// null maps to the working directory, which means "use the defaults"
if (value == null)
value = new File(System.getProperty("user.dir"));
m_stopwordsFile = value;
if (value.exists() && value.isFile())
setUseStopList(true);
}
/**
 * returns the file used for obtaining the stopwords, if the file represents a
 * directory then the default ones are used.
 *
 * @return the file containing the stopwords
 */
public File getStopwords() {
return m_stopwordsFile;
}
/**
 * Returns the tip text for this property.
 *
 * @return tip text for this property suitable for displaying in the
 * explorer/experimenter gui
 */
public String stopwordsTipText() {
return "The file containing the stopwords (if this is a directory then the default ones are used).";
}
/**
 * Returns an enumeration describing the available options.
 *
 * @return an enumeration of all the available options.
 */
@Override
public Enumeration<Option> listOptions() {
  Vector<Option> newVector = new Vector<Option>();
  newVector.add(new Option("\tUse word frequencies instead of "
    + "binary bag of words.", "W", 0, "-W"));
  newVector.add(new Option("\tHow often to prune the dictionary "
    + "of low frequency words (default = 0, i.e. don't prune)", "P", 1,
    "-P <# instances>"));
  // FIX: "frequence" -> "frequency" in the user-visible description
  newVector.add(new Option("\tMinimum word frequency. Words with less "
    + "than this frequency are ignored.\n\tIf periodic pruning "
    + "is turned on then this is also used to determine which\n\t"
    + "words to remove from the dictionary (default = 3).", "M", 1,
    "-M <double>"));
  newVector.addElement(new Option(
    "\tNormalize document length (use in conjunction with -norm and "
      + "-lnorm)", "normalize", 0, "-normalize"));
  newVector.addElement(new Option(
    "\tSpecify the norm that each instance must have (default 1.0)",
    "norm", 1, "-norm <num>"));
  newVector.addElement(new Option("\tSpecify L-norm to use (default 2.0)",
    "lnorm", 1, "-lnorm <num>"));
  newVector.addElement(new Option("\tConvert all tokens to lowercase "
    + "before adding to the dictionary.", "lowercase", 0, "-lowercase"));
  newVector.addElement(new Option("\tIgnore words that are in the stoplist.",
    "stoplist", 0, "-stoplist"));
  newVector.addElement(new Option(
    "\tA file containing stopwords to override the default ones.\n"
      + "\tUsing this option automatically sets the flag ('-stoplist') to use the\n"
      + "\tstoplist if the file exists.\n"
      + "\tFormat: one stopword per line, lines starting with '#'\n"
      + "\tare interpreted as comments and ignored.", "stopwords", 1,
    "-stopwords <file>"));
  // FIX: "algorihtm" -> "algorithm" in the user-visible description
  newVector.addElement(new Option(
    "\tThe tokenizing algorithm (classname plus parameters) to use.\n"
      + "\t(default: " + WordTokenizer.class.getName() + ")",
    "tokenizer", 1, "-tokenizer <spec>"));
  // FIX: "stemmering algorihtm" -> "stemming algorithm"
  newVector.addElement(new Option(
    "\tThe stemming algorithm (classname plus parameters) to use.",
    "stemmer", 1, "-stemmer <spec>"));
  return newVector.elements();
}
/**
* Parses a given list of options.
* <p/>
*
<!-- options-start -->
* Valid options are:
* <p/>
*
* <pre>
* -W
* Use word frequencies instead of binary bag of words.
* </pre>
*
* <pre>
* -P <# instances>
* How often to prune the dictionary of low frequency words (default = 0, i.e. don't prune)
* </pre>
*
* <pre>
* -M <double>
* Minimum word frequency. Words with less than this frequency are ignored.
* If periodic pruning is turned on then this is also used to determine which
* words to remove from the dictionary (default = 3).
* </pre>
*
* <pre>
* -normalize
* Normalize document length (use in conjunction with -norm and -lnorm)
* </pre>
*
* <pre>
* -norm <num>
* Specify the norm that each instance must have (default 1.0)
* </pre>
*
* <pre>
* -lnorm <num>
* Specify L-norm to use (default 2.0)
* </pre>
*
* <pre>
* -lowercase
* Convert all tokens to lowercase before adding to the dictionary.
* </pre>
*
* <pre>
* -stoplist
* Ignore words that are in the stoplist.
* </pre>
*
* <pre>
* -stopwords <file>
* A file containing stopwords to override the default ones.
* Using this option automatically sets the flag ('-stoplist') to use the
* stoplist if the file exists.
* Format: one stopword per line, lines starting with '#'
* are interpreted as comments and ignored.
* </pre>
*
* <pre>
* -tokenizer <spec>
* The tokenizing algorithm (classname plus parameters) to use.
* (default: weka.core.tokenizers.WordTokenizer)
* </pre>
*
* <pre>
* -stemmer <spec>
* The stemming algorithm (classname plus parameters) to use.
* </pre>
*
<!-- options-end -->
*
* @param options the list of options as an array of strings
* @throws Exception if an option is not supported
*/
@Override
public void setOptions(String[] options) throws Exception {
// start from a clean model before applying new settings
reset();
super.setOptions(options);
setUseWordFrequencies(Utils.getFlag("W", options));
String pruneFreqS = Utils.getOption("P", options);
if (pruneFreqS.length() > 0) {
setPeriodicPruning(Integer.parseInt(pruneFreqS));
}
String minFreq = Utils.getOption("M", options);
if (minFreq.length() > 0) {
setMinWordFrequency(Double.parseDouble(minFreq));
}
setNormalizeDocLength(Utils.getFlag("normalize", options));
String normFreqS = Utils.getOption("norm", options);
if (normFreqS.length() > 0) {
setNorm(Double.parseDouble(normFreqS));
}
String lnormFreqS = Utils.getOption("lnorm", options);
if (lnormFreqS.length() > 0) {
setLNorm(Double.parseDouble(lnormFreqS));
}
setLowercaseTokens(Utils.getFlag("lowercase", options));
setUseStopList(Utils.getFlag("stoplist", options));
String stopwordsS = Utils.getOption("stopwords", options);
if (stopwordsS.length() > 0) {
// setStopwords also turns the stoplist flag on if the file exists
setStopwords(new File(stopwordsS));
} else {
// null resets the stopwords file to the default (working directory)
setStopwords(null);
}
String tokenizerString = Utils.getOption("tokenizer", options);
if (tokenizerString.length() == 0) {
setTokenizer(new WordTokenizer());
} else {
// spec format: classname followed by the tokenizer's own options
String[] tokenizerSpec = Utils.splitOptions(tokenizerString);
if (tokenizerSpec.length == 0)
throw new Exception("Invalid tokenizer specification string");
String tokenizerName = tokenizerSpec[0];
tokenizerSpec[0] = "";
Tokenizer tokenizer = (Tokenizer) Class.forName(tokenizerName)
.newInstance();
if (tokenizer instanceof OptionHandler)
((OptionHandler) tokenizer).setOptions(tokenizerSpec);
setTokenizer(tokenizer);
}
String stemmerString = Utils.getOption("stemmer", options);
if (stemmerString.length() == 0) {
// no -stemmer option: disables stemming (installs a NullStemmer)
setStemmer(null);
} else {
// spec format: classname followed by the stemmer's own options
String[] stemmerSpec = Utils.splitOptions(stemmerString);
if (stemmerSpec.length == 0)
throw new Exception("Invalid stemmer specification string");
String stemmerName = stemmerSpec[0];
stemmerSpec[0] = "";
Stemmer stemmer = (Stemmer) Class.forName(stemmerName).newInstance();
if (stemmer instanceof OptionHandler)
((OptionHandler) stemmer).setOptions(stemmerSpec);
setStemmer(stemmer);
}
}
/**
 * Gets the current settings of the classifier.
 *
 * @return an array of strings suitable for passing to setOptions
 */
@Override
public String[] getOptions() {
  ArrayList<String> options = new ArrayList<String>();
  if (getUseWordFrequencies()) {
    options.add("-W");
  }
  options.add("-P");
  options.add("" + getPeriodicPruning());
  options.add("-M");
  options.add("" + getMinWordFrequency());
  if (getNormalizeDocLength()) {
    options.add("-normalize");
  }
  options.add("-norm");
  options.add("" + getNorm());
  options.add("-lnorm");
  options.add("" + getLNorm());
  if (getLowercaseTokens()) {
    options.add("-lowercase");
  }
  if (getUseStopList()) {
    options.add("-stoplist");
  }
  // a directory means "use the default stopwords" -- only emit the option
  // when an actual file has been configured
  if (!getStopwords().isDirectory()) {
    options.add("-stopwords");
    options.add(getStopwords().getAbsolutePath());
  }
  options.add("-tokenizer");
  String spec = getTokenizer().getClass().getName();
  if (getTokenizer() instanceof OptionHandler) {
    spec += " "
      + Utils.joinOptions(((OptionHandler) getTokenizer()).getOptions());
  }
  options.add(spec.trim());
  if (getStemmer() != null) {
    options.add("-stemmer");
    spec = getStemmer().getClass().getName();
    if (getStemmer() instanceof OptionHandler) {
      spec += " "
        + Utils.joinOptions(((OptionHandler) getStemmer()).getOptions());
    }
    options.add(spec.trim());
  }
  // FIX: pass a zero-length array -- toArray(new String[1]) would leave a
  // stray null element if the list were ever empty
  return options.toArray(new String[0]);
}
/**
 * Returns a textual description of this classifier.
 *
 * @return a textual description of this classifier.
 */
@Override
public String toString() {
  if (m_probOfClass == null) {
    return "NaiveBayesMultinomialText: No model built yet.\n";
  }
  // StringBuilder instead of the synchronized StringBuffer: purely local,
  // so no synchronization is needed
  StringBuilder result = new StringBuilder();
  // build a master dictionary over all classes
  HashSet<String> master = new HashSet<String>();
  for (int i = 0; i < m_data.numClasses(); i++) {
    master.addAll(m_probOfWordGivenClass.get(i).keySet());
  }
  result.append("The independent probability of a class\n");
  result.append("--------------------------------------\n");
  for (int i = 0; i < m_data.numClasses(); i++) {
    result.append(m_data.classAttribute().value(i)).append("\t")
      .append(Double.toString(m_probOfClass[i])).append("\n");
  }
  result.append("\nThe probability of a word given the class\n");
  result.append("-----------------------------------------\n\t");
  for (int i = 0; i < m_data.numClasses(); i++) {
    result.append(m_data.classAttribute().value(i)).append("\t");
  }
  result.append("\n");
  for (String word : master) {
    result.append(word).append("\t");
    for (int i = 0; i < m_data.numClasses(); i++) {
      // NOTE(review): the counts are stored raw (not as logs), yet this
      // prints exp(count) -- verify this is the intended display
      Count c = m_probOfWordGivenClass.get(i).get(word);
      if (c == null) {
        result.append("-------------\t");
      } else {
        result.append(Double.toString(Math.exp(c.m_count))).append("\t");
      }
    }
    result.append("\n");
  }
  return result.toString();
}
/**
 * Returns the revision string.
 *
 * @return the revision
 */
@Override
public String getRevision() {
// the revision number is substituted by the version control system
return RevisionUtils.extract("$Revision: 9785 $");
}
/** Number of models aggregated into this one so far */
protected int m_numModels = 0;
/**
 * Aggregates the counts from another model into this one: class priors are
 * summed and per-class dictionaries are merged by adding word counts.
 *
 * @param toAggregate the model to merge into this one
 * @return this model, after aggregation
 * @throws Exception if no model has been built yet, or the class
 *           attributes of the two models do not match
 */
@Override
public NaiveBayesMultinomialText aggregate(
NaiveBayesMultinomialText toAggregate) throws Exception {
// NOTE(review): nothing in this class ever sets m_numModels to
// Integer.MIN_VALUE, so this "finalized" guard looks vestigial -- confirm
if (m_numModels == Integer.MIN_VALUE) {
throw new Exception("Can't aggregate further - model has already been "
+ "aggregated and finalized");
}
if (m_probOfClass == null) {
throw new Exception("No model built yet, can't aggregate");
}
// just check the class attribute for compatibility as we will be
// merging dictionaries
if (!m_data.classAttribute().equals(toAggregate.m_data.classAttribute())) {
throw new Exception("Can't aggregate - class attribute in data headers "
+ "does not match: "
+ m_data.classAttribute().equalsMsg(
toAggregate.m_data.classAttribute()));
}
// sum the class priors
for (int i = 0; i < m_probOfClass.length; i++) {
m_probOfClass[i] += toAggregate.m_probOfClass[i];
}
// NOTE(review): m_wordsPerClass is not updated here -- verify whether the
// per-class word totals should also be merged during aggregation
Map<Integer, LinkedHashMap<String, Count>> dicts = toAggregate.m_probOfWordGivenClass;
Iterator<Map.Entry<Integer, LinkedHashMap<String, Count>>> perClass = dicts
.entrySet().iterator();
while (perClass.hasNext()) {
Map.Entry<Integer, LinkedHashMap<String, Count>> currentClassDict = perClass
.next();
LinkedHashMap<String, Count> masterDict = m_probOfWordGivenClass
.get(currentClassDict.getKey());
if (masterDict == null) {
// we haven't seen this class during our training
masterDict = new LinkedHashMap<String, Count>();
m_probOfWordGivenClass.put(currentClassDict.getKey(), masterDict);
}
// now process words seen for this class
Iterator<Map.Entry<String, Count>> perClassEntries = currentClassDict
.getValue().entrySet().iterator();
while (perClassEntries.hasNext()) {
Map.Entry<String, Count> entry = perClassEntries.next();
Count masterCount = masterDict.get(entry.getKey());
if (masterCount == null) {
// we haven't seen this entry (or its been pruned)
masterCount = new Count(entry.getValue().m_count);
masterDict.put(entry.getKey(), masterCount);
} else {
// add up
masterCount.m_count += entry.getValue().m_count;
}
}
}
m_numModels++;
return this;
}
/**
 * Signals the end of aggregation. Only validates that at least one model
 * was aggregated; no averaging is required, so further models may still be
 * aggregated afterwards.
 *
 * @throws Exception if no models have been aggregated yet
 */
@Override
public void finalizeAggregation() throws Exception {
if (m_numModels == 0) {
throw new Exception("Unable to finalize aggregation - "
+ "haven't seen any models to aggregate");
}
// Nothing more to do - we don't need to average anything,
// therefore further models can be aggregated at any time
}
/**
 * Main method for testing this class.
 *
 * @param args the options
 */
public static void main(String[] args) {
// delegates option handling/evaluation to the standard Weka runner
runClassifier(new NaiveBayesMultinomialText(), args);
}
}