/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* SGDText.java
* Copyright (C) 2012 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers.functions;
import java.io.File;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.RandomizableClassifier;
import weka.classifiers.UpdateableClassifier;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Aggregateable;
import weka.core.DenseInstance;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Stopwords;
import weka.core.Tag;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.core.stemmers.NullStemmer;
import weka.core.stemmers.Stemmer;
import weka.core.tokenizers.Tokenizer;
import weka.core.tokenizers.WordTokenizer;
/**
<!-- globalinfo-start -->
* Implements stochastic gradient descent for learning
* a linear binary class SVM or binary class logistic regression on text data.
* Operates directly (and only) on String attributes. Other types of input
* attributes are accepted but ignored during training and classification.
* <p/>
<!-- globalinfo-end -->
*
<!-- options-start -->
* Valid options are:
* <p/>
*
* <pre>
* -F
* Set the loss function to minimize. 0 = hinge loss (SVM), 1 = log loss (logistic regression)
* (default = 0)
* </pre>
*
* <pre>
* -outputProbs
* Output probabilities for SVMs (fits a logistic
* model to the output of the SVM)
* </pre>
*
* <pre>
* -L
* The learning rate (default = 0.01).
* </pre>
*
* <pre>
* -R <double>
* The lambda regularization constant (default = 0.0001)
* </pre>
*
* <pre>
* -E <integer>
* The number of epochs to perform (batch learning only, default = 500)
* </pre>
*
* <pre>
* -W
* Use word frequencies instead of binary bag of words.
* </pre>
*
* <pre>
* -P <# instances>
* How often to prune the dictionary of low frequency words (default = 0, i.e. don't prune)
* </pre>
*
* <pre>
* -M <double>
* Minimum word frequency. Words with less than this frequency are ignored.
* If periodic pruning is turned on then this is also used to determine which
* words to remove from the dictionary (default = 3).
* </pre>
*
* <pre>
* -normalize
* Normalize document length (use in conjunction with -norm and -lnorm)
* </pre>
*
* <pre>
* -norm <num>
* Specify the norm that each instance must have (default 1.0)
* </pre>
*
* <pre>
* -lnorm <num>
* Specify L-norm to use (default 2.0)
* </pre>
*
* <pre>
* -lowercase
* Convert all tokens to lowercase before adding to the dictionary.
* </pre>
*
* <pre>
* -stoplist
* Ignore words that are in the stoplist.
* </pre>
*
* <pre>
* -stopwords <file>
* A file containing stopwords to override the default ones.
* Using this option automatically sets the flag ('-stoplist') to use the
* stoplist if the file exists.
* Format: one stopword per line, lines starting with '#'
* are interpreted as comments and ignored.
* </pre>
*
* <pre>
* -tokenizer <spec>
* The tokenizing algorithm (classname plus parameters) to use.
* (default: weka.core.tokenizers.WordTokenizer)
* </pre>
*
* <pre>
* -stemmer <spec>
* The stemming algorithm (classname plus parameters) to use.
* </pre>
*
<!-- options-end -->
*
* @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
* @author Eibe Frank (eibe{[at]}cs{[dot]}waikato{[dot]}ac{[dot]}nz)
*
*/
public class SGDText extends RandomizableClassifier implements
UpdateableClassifier, WeightedInstancesHandler, Aggregateable<SGDText> {
/** For serialization (do not change across otherwise-compatible versions) */
private static final long serialVersionUID = 7200171484002029584L;
/**
 * Simple container pairing a word's (weighted) occurrence count with its
 * current SGD weight (model coefficient).
 */
public static class Count implements Serializable {

  /**
   * For serialization
   */
  private static final long serialVersionUID = 2104201532017340967L;

  /** The (instance-weight weighted) number of times the word was seen */
  public double m_count;

  /** The current model coefficient for this word (starts at 0.0) */
  public double m_weight;

  /**
   * Constructs a Count with the given initial count. Note that m_weight is
   * deliberately left at its default of 0.0.
   *
   * @param c the initial count
   */
  public Count(double c) {
    m_count = c;
  }
}
/**
 * The number of training instances at which to periodically prune the
 * dictionary of min frequency words. A value <= 0 indicates don't prune
 * (see pruneDictionary()).
 */
protected int m_periodicP = 0;

/**
 * Only consider dictionary words (features) that occur at least this many
 * times
 */
protected double m_minWordP = 3;

/** Use word frequencies rather than binary bag-of-words if true */
protected boolean m_wordFrequencies = false;

/** Whether to normalize document length or not */
protected boolean m_normalize = false;

/** The length (norm) that each document vector should have in the end */
protected double m_norm = 1.0;

/** The L-norm to use for document length normalization */
protected double m_lnorm = 2.0;

/** The dictionary: maps each word to its global count and model weight */
protected LinkedHashMap<String, Count> m_dictionary;

/** Default (rainbow) stopwords; transient, lazily created when needed */
protected transient Stopwords m_stopwords;

/**
 * A file containing stopwords for using others than the default Rainbow
 * ones. Defaults to the current directory; a directory means "use the
 * built-in defaults".
 */
protected File m_stopwordsFile = new File(System.getProperty("user.dir"));

/** The tokenizer to use */
protected Tokenizer m_tokenizer = new WordTokenizer();

/** Whether or not to convert all tokens to lowercase */
protected boolean m_lowercaseTokens;

/** The stemming algorithm (never null; NullStemmer means no stemming). */
protected Stemmer m_stemmer = new NullStemmer();

/** Whether or not to use a stop list */
protected boolean m_useStopList;

/** The regularization parameter (lambda) */
protected double m_lambda = 0.0001;

/** The learning rate */
protected double m_learningRate = 0.01;

/** Holds the current iteration (time step) number */
protected double m_t;

/** Holds the bias term of the linear model */
protected double m_bias;

/**
 * The number of training instances (set by buildClassifier(); presumably 0
 * when training purely incrementally - TODO confirm against callers)
 */
protected double m_numInstances;

/** The header of the training data */
protected Instances m_data;

/**
 * The number of epochs to perform (batch learning). Total iterations is
 * m_epochs * num instances
 */
protected int m_epochs = 500;

/**
 * Holds the current document vector (LinkedHashMap is more efficient when
 * iterating over EntrySet than HashMap)
 */
protected transient LinkedHashMap<String, Count> m_inputVector;

/** the hinge loss function. */
public static final int HINGE = 0;

/** the log loss function. */
public static final int LOGLOSS = 1;

/** The current loss function to minimize */
protected int m_loss = HINGE;

/** Loss functions to choose from */
public static final Tag[] TAGS_SELECTION = {
    new Tag(HINGE, "Hinge loss (SVM)"),
    new Tag(LOGLOSS, "Log loss (logistic regression)") };

/** Used for producing probabilities for SVM via SGD logistic regression */
protected SGD m_svmProbs;

/**
 * True if a logistic regression is to be fit to the output of the SVM for
 * producing probability estimates
 */
protected boolean m_fitLogistic = false;

/** Header of the two-attribute ("pred", "class") meta data set on which
 * the calibration logistic regression is trained */
protected Instances m_fitLogisticStructure;
/**
 * Derivative of the configured loss function with respect to the margin z.
 *
 * @param z the margin (y * (w.x + bias))
 * @return the (sub)gradient magnitude of the loss at z
 */
protected double dloss(double z) {
  if (m_loss == HINGE) {
    // Sub-gradient of the hinge loss: 1 inside the margin, 0 outside
    if (z < 1) {
      return 1;
    }
    return 0;
  }

  // Log loss: sigmoid(-z), arranged so that Math.exp never overflows
  if (z < 0) {
    return 1.0 / (Math.exp(z) + 1.0);
  }
  double expMinusZ = Math.exp(-z);
  return expMinusZ / (expMinusZ + 1);
}
/**
 * Returns default capabilities of the classifier.
 *
 * @return the capabilities of this classifier
 */
@Override
public Capabilities getCapabilities() {
  Capabilities caps = super.getCapabilities();
  caps.disableAll();

  // attribute capabilities
  caps.enable(Capability.STRING_ATTRIBUTES);
  caps.enable(Capability.NOMINAL_ATTRIBUTES);
  caps.enable(Capability.DATE_ATTRIBUTES);
  caps.enable(Capability.NUMERIC_ATTRIBUTES);
  caps.enable(Capability.MISSING_VALUES);

  // class capabilities: binary classification only
  caps.enable(Capability.BINARY_CLASS);
  caps.enable(Capability.MISSING_CLASS_VALUES);

  // instances
  caps.setMinimumNumberInstances(0);

  return caps;
}
/**
 * Sets the stemming algorithm to use; null means no stemming at all (i.e.,
 * the NullStemmer is used).
 *
 * @param value the configured stemming algorithm, or null
 * @see NullStemmer
 */
public void setStemmer(Stemmer value) {
  // Store a no-op stemmer rather than null so callers never have to check
  m_stemmer = (value != null) ? value : new NullStemmer();
}
/**
 * Returns the current stemming algorithm. Note that although the contract
 * allows null, this implementation never stores null (setStemmer()
 * substitutes a NullStemmer), so a non-null value is always returned.
 *
 * @return the current stemming algorithm, null if none set
 */
public Stemmer getStemmer() {
  return m_stemmer;
}

/**
 * Returns the tip text for this property.
 *
 * @return tip text for this property suitable for displaying in the
 *         explorer/experimenter gui
 */
public String stemmerTipText() {
  return "The stemming algorithm to use on the words.";
}

/**
 * Sets the tokenizer algorithm to use. NOTE(review): unlike setStemmer(),
 * no null-check is performed here - a null value would cause a
 * NullPointerException later in tokenizeInstance().
 *
 * @param value the configured tokenizing algorithm
 */
public void setTokenizer(Tokenizer value) {
  m_tokenizer = value;
}

/**
 * Returns the current tokenizer algorithm.
 *
 * @return the current tokenizer algorithm
 */
public Tokenizer getTokenizer() {
  return m_tokenizer;
}

/**
 * Returns the tip text for this property.
 *
 * @return tip text for this property suitable for displaying in the
 *         explorer/experimenter gui
 */
public String tokenizerTipText() {
  return "The tokenizing algorithm to use on the strings.";
}
/**
 * Returns the tip text for this property
 *
 * @return tip text for this property suitable for displaying in the
 *         explorer/experimenter gui
 */
public String useWordFrequenciesTipText() {
  return "Use word frequencies rather than binary "
      + "bag of words representation";
}

/**
 * Set whether to use word frequencies rather than binary bag of words
 * representation.
 *
 * @param u true if word frequencies are to be used.
 */
public void setUseWordFrequencies(boolean u) {
  m_wordFrequencies = u;
}

/**
 * Get whether to use word frequencies rather than binary bag of words
 * representation.
 *
 * @return true if word frequencies are to be used.
 */
public boolean getUseWordFrequencies() {
  return m_wordFrequencies;
}

/**
 * Returns the tip text for this property
 *
 * @return tip text for this property suitable for displaying in the
 *         explorer/experimenter gui
 */
public String lowercaseTokensTipText() {
  return "Whether to convert all tokens to lowercase";
}

/**
 * Set whether to convert all tokens to lowercase
 *
 * @param l true if all tokens are to be converted to lowercase
 */
public void setLowercaseTokens(boolean l) {
  m_lowercaseTokens = l;
}

/**
 * Get whether to convert all tokens to lowercase
 *
 * @return true if all tokens are to be converted to lowercase
 */
public boolean getLowercaseTokens() {
  return m_lowercaseTokens;
}
/**
 * Returns the tip text for this property
 *
 * @return tip text for this property suitable for displaying in the
 *         explorer/experimenter gui
 */
public String useStopListTipText() {
  return "If true, ignores all words that are on the stoplist.";
}

/**
 * Set whether to ignore all words that are on the stoplist.
 *
 * @param u true to ignore all words on the stoplist.
 */
public void setUseStopList(boolean u) {
  m_useStopList = u;
}

/**
 * Get whether to ignore all words that are on the stoplist.
 *
 * @return true to ignore all words on the stoplist.
 */
public boolean getUseStopList() {
  return m_useStopList;
}

/**
 * Sets the file containing the stopwords; null or a directory unsets the
 * stopwords (the default ones are then used). If the file exists, this
 * automatically turns on the flag to use the stoplist.
 *
 * @param value the file containing the stopwords
 */
public void setStopwords(File value) {
  // null is normalized to the current directory, which means "defaults"
  if (value == null)
    value = new File(System.getProperty("user.dir"));

  m_stopwordsFile = value;
  if (value.exists() && value.isFile())
    setUseStopList(true);
}

/**
 * Returns the file used for obtaining the stopwords; if the file represents
 * a directory then the default ones are used.
 *
 * @return the file containing the stopwords
 */
public File getStopwords() {
  return m_stopwordsFile;
}

/**
 * Returns the tip text for this property.
 *
 * @return tip text for this property suitable for displaying in the
 *         explorer/experimenter gui
 */
public String stopwordsTipText() {
  return "The file containing the stopwords (if this is a directory then the default ones are used).";
}

/**
 * Returns the tip text for this property
 *
 * @return tip text for this property suitable for displaying in the
 *         explorer/experimenter gui
 */
public String periodicPruningTipText() {
  return "How often (number of instances) to prune "
      + "the dictionary of low frequency terms. "
      + "0 means don't prune. Setting a positive "
      + "integer n means prune after every n instances";
}

/**
 * Set how often to prune the dictionary
 *
 * @param p how often to prune
 */
public void setPeriodicPruning(int p) {
  m_periodicP = p;
}

/**
 * Get how often to prune the dictionary
 *
 * @return how often to prune the dictionary
 */
public int getPeriodicPruning() {
  return m_periodicP;
}

/**
 * Returns the tip text for this property
 *
 * @return tip text for this property suitable for displaying in the
 *         explorer/experimenter gui
 */
public String minWordFrequencyTipText() {
  return "Ignore any words that don't occur at least "
      + "min frequency times in the training data. If periodic "
      + "pruning is turned on, then the dictionary is pruned "
      + "according to this value";
}

/**
 * Set the minimum word frequency. Words that don't occur at least min freq
 * times are ignored when updating weights. If periodic pruning is turned on,
 * then min frequency is used when removing words from the dictionary.
 *
 * @param minFreq the minimum word frequency to use
 */
public void setMinWordFrequency(double minFreq) {
  m_minWordP = minFreq;
}

/**
 * Get the minimum word frequency. Words that don't occur at least min freq
 * times are ignored when updating weights. If periodic pruning is turned on,
 * then min frequency is used when removing words from the dictionary.
 *
 * @return the minimum word frequency to use
 */
public double getMinWordFrequency() {
  return m_minWordP;
}
/**
 * Returns the tip text for this property
 *
 * @return tip text for this property suitable for displaying in the
 *         explorer/experimenter gui
 */
public String normalizeDocLengthTipText() {
  return "If true then document length is normalized according "
      + "to the settings for norm and lnorm";
}

/**
 * Set whether to normalize the length of each document
 *
 * @param norm true if document length is to be normalized
 */
public void setNormalizeDocLength(boolean norm) {
  m_normalize = norm;
}

/**
 * Get whether to normalize the length of each document
 *
 * @return true if document length is to be normalized
 */
public boolean getNormalizeDocLength() {
  return m_normalize;
}

/**
 * Returns the tip text for this property
 *
 * @return tip text for this property suitable for displaying in the
 *         explorer/experimenter gui
 */
public String normTipText() {
  return "The norm of the instances after normalization.";
}

/**
 * Get the instance's norm.
 *
 * @return the norm
 */
public double getNorm() {
  return m_norm;
}

/**
 * Set the norm of the instances
 *
 * @param newNorm the norm to which the instances must be set
 */
public void setNorm(double newNorm) {
  m_norm = newNorm;
}

/**
 * Returns the tip text for this property
 *
 * @return tip text for this property suitable for displaying in the
 *         explorer/experimenter gui
 */
public String LNormTipText() {
  return "The LNorm to use for document length normalization.";
}

/**
 * Get the L-norm used.
 *
 * @return the L-norm used
 */
public double getLNorm() {
  return m_lnorm;
}

/**
 * Set the L-norm to use
 *
 * @param newLNorm the L-norm
 */
public void setLNorm(double newLNorm) {
  m_lnorm = newLNorm;
}

/**
 * Returns the tip text for this property
 *
 * @return tip text for this property suitable for displaying in the
 *         explorer/experimenter gui
 */
public String lambdaTipText() {
  return "The regularization constant. (default = 0.0001)";
}

/**
 * Set the value of lambda to use
 *
 * @param lambda the value of lambda to use
 */
public void setLambda(double lambda) {
  m_lambda = lambda;
}

/**
 * Get the current value of lambda
 *
 * @return the current value of lambda
 */
public double getLambda() {
  return m_lambda;
}

/**
 * Set the learning rate.
 *
 * @param lr the learning rate to use.
 */
public void setLearningRate(double lr) {
  m_learningRate = lr;
}

/**
 * Get the learning rate.
 *
 * @return the learning rate
 */
public double getLearningRate() {
  return m_learningRate;
}

/**
 * Returns the tip text for this property
 *
 * @return tip text for this property suitable for displaying in the
 *         explorer/experimenter gui
 */
public String learningRateTipText() {
  return "The learning rate.";
}

/**
 * Returns the tip text for this property
 *
 * @return tip text for this property suitable for displaying in the
 *         explorer/experimenter gui
 */
public String epochsTipText() {
  return "The number of epochs to perform (batch learning). "
      + "The total number of iterations is epochs * num" + " instances.";
}

/**
 * Set the number of epochs to use
 *
 * @param e the number of epochs to use
 */
public void setEpochs(int e) {
  m_epochs = e;
}

/**
 * Get current number of epochs
 *
 * @return the current number of epochs
 */
public int getEpochs() {
  return m_epochs;
}
/**
 * Set the loss function to use. Note that the request is silently ignored
 * unless the supplied tag comes from this class's own TAGS_SELECTION array
 * (reference identity check).
 *
 * @param function the loss function to use.
 */
public void setLossFunction(SelectedTag function) {
  if (function.getTags() == TAGS_SELECTION) {
    m_loss = function.getSelectedTag().getID();
  }
}

/**
 * Get the current loss function.
 *
 * @return the current loss function.
 */
public SelectedTag getLossFunction() {
  return new SelectedTag(m_loss, TAGS_SELECTION);
}
/**
 * Returns the tip text for this property
 *
 * @return tip text for this property suitable for displaying in the
 *         explorer/experimenter gui
 */
public String lossFunctionTipText() {
  // Only hinge and log loss are offered by this classifier (see
  // TAGS_SELECTION) - squared loss is not an option here, so it is not
  // mentioned.
  return "The loss function to use. Hinge loss (SVM) or "
      + "log loss (logistic regression).";
}
/**
 * Set whether to fit a logistic regression (itself trained using SGD) to
 * the outputs of the SVM (if an SVM is being learned).
 *
 * @param o true if a logistic regression is to be fit to the output of the
 *          SVM to produce probability estimates.
 */
public void setOutputProbsForSVM(boolean o) {
  m_fitLogistic = o;
}

/**
 * Get whether to fit a logistic regression (itself trained using SGD) to
 * the outputs of the SVM (if an SVM is being learned).
 *
 * @return true if a logistic regression is to be fit to the output of the
 *         SVM to produce probability estimates.
 */
public boolean getOutputProbsForSVM() {
  return m_fitLogistic;
}

/**
 * Returns the tip text for this property
 *
 * @return tip text for this property suitable for displaying in the
 *         explorer/experimenter gui
 */
public String outputProbsForSVMTipText() {
  return "Fit a logistic regression to the output of SVM for "
      + "producing probability estimates";
}
/**
 * Returns an enumeration describing the available options. The option
 * descriptions below are also the source of the auto-generated
 * "options-start"/"options-end" javadoc sections.
 *
 * @return an enumeration of all the available options.
 */
@Override
public Enumeration<Option> listOptions() {
  Vector<Option> newVector = new Vector<Option>();

  // SGD learning options
  newVector.add(new Option("\tSet the loss function to minimize. 0 = "
      + "hinge loss (SVM), 1 = log loss (logistic regression)\n\t"
      + "(default = 0)", "F", 1, "-F"));
  newVector.add(new Option(
      "\tOutput probabilities for SVMs (fits a logsitic\n\t"
          + "model to the output of the SVM)", "output-probs", 0,
      "-outputProbs"));
  newVector.add(new Option("\tThe learning rate (default = 0.01).", "L", 1,
      "-L"));
  newVector.add(new Option("\tThe lambda regularization constant "
      + "(default = 0.0001)", "R", 1, "-R <double>"));
  newVector.add(new Option("\tThe number of epochs to perform ("
      + "batch learning only, default = 500)", "E", 1, "-E <integer>"));

  // dictionary / text-processing options
  newVector.add(new Option("\tUse word frequencies instead of "
      + "binary bag of words.", "W", 0, "-W"));
  newVector.add(new Option("\tHow often to prune the dictionary "
      + "of low frequency words (default = 0, i.e. don't prune)", "P", 1,
      "-P <# instances>"));
  newVector.add(new Option("\tMinimum word frequency. Words with less "
      + "than this frequence are ignored.\n\tIf periodic pruning "
      + "is turned on then this is also used to determine which\n\t"
      + "words to remove from the dictionary (default = 3).", "M", 1,
      "-M <double>"));
  newVector.addElement(new Option(
      "\tNormalize document length (use in conjunction with -norm and "
          + "-lnorm)", "normalize", 0, "-normalize"));
  newVector.addElement(new Option(
      "\tSpecify the norm that each instance must have (default 1.0)",
      "norm", 1, "-norm <num>"));
  newVector.addElement(new Option("\tSpecify L-norm to use (default 2.0)",
      "lnorm", 1, "-lnorm <num>"));
  newVector.addElement(new Option("\tConvert all tokens to lowercase "
      + "before adding to the dictionary.", "lowercase", 0, "-lowercase"));
  newVector.addElement(new Option("\tIgnore words that are in the stoplist.",
      "stoplist", 0, "-stoplist"));
  newVector
      .addElement(new Option(
          "\tA file containing stopwords to override the default ones.\n"
              + "\tUsing this option automatically sets the flag ('-stoplist') to use the\n"
              + "\tstoplist if the file exists.\n"
              + "\tFormat: one stopword per line, lines starting with '#'\n"
              + "\tare interpreted as comments and ignored.", "stopwords", 1,
          "-stopwords <file>"));
  newVector.addElement(new Option(
      "\tThe tokenizing algorihtm (classname plus parameters) to use.\n"
          + "\t(default: " + WordTokenizer.class.getName() + ")",
      "tokenizer", 1, "-tokenizer <spec>"));
  newVector.addElement(new Option(
      "\tThe stemmering algorihtm (classname plus parameters) to use.",
      "stemmer", 1, "-stemmer <spec>"));

  return newVector.elements();
}
/**
* Parses a given list of options.
* <p/>
*
<!-- options-start -->
* Valid options are:
* <p/>
*
* <pre>
* -F
* Set the loss function to minimize. 0 = hinge loss (SVM), 1 = log loss (logistic regression)
* (default = 0)
* </pre>
*
* <pre>
* -outputProbs
* Output probabilities for SVMs (fits a logistic
* model to the output of the SVM)
* </pre>
*
* <pre>
* -L
* The learning rate (default = 0.01).
* </pre>
*
* <pre>
* -R <double>
* The lambda regularization constant (default = 0.0001)
* </pre>
*
* <pre>
* -E <integer>
* The number of epochs to perform (batch learning only, default = 500)
* </pre>
*
* <pre>
* -W
* Use word frequencies instead of binary bag of words.
* </pre>
*
* <pre>
* -P <# instances>
* How often to prune the dictionary of low frequency words (default = 0, i.e. don't prune)
* </pre>
*
* <pre>
* -M <double>
* Minimum word frequency. Words with less than this frequency are ignored.
* If periodic pruning is turned on then this is also used to determine which
* words to remove from the dictionary (default = 3).
* </pre>
*
* <pre>
* -normalize
* Normalize document length (use in conjunction with -norm and -lnorm)
* </pre>
*
* <pre>
* -norm <num>
* Specify the norm that each instance must have (default 1.0)
* </pre>
*
* <pre>
* -lnorm <num>
* Specify L-norm to use (default 2.0)
* </pre>
*
* <pre>
* -lowercase
* Convert all tokens to lowercase before adding to the dictionary.
* </pre>
*
* <pre>
* -stoplist
* Ignore words that are in the stoplist.
* </pre>
*
* <pre>
* -stopwords <file>
* A file containing stopwords to override the default ones.
* Using this option automatically sets the flag ('-stoplist') to use the
* stoplist if the file exists.
* Format: one stopword per line, lines starting with '#'
* are interpreted as comments and ignored.
* </pre>
*
* <pre>
* -tokenizer <spec>
* The tokenizing algorithm (classname plus parameters) to use.
* (default: weka.core.tokenizers.WordTokenizer)
* </pre>
*
* <pre>
* -stemmer <spec>
* The stemming algorithm (classname plus parameters) to use.
* </pre>
*
<!-- options-end -->
*
* @param options the list of options as an array of strings
* @throws Exception if an option is not supported
*/
@Override
public void setOptions(String[] options) throws Exception {
  // Start from a clean model before applying the new configuration
  reset();
  super.setOptions(options);

  String lossString = Utils.getOption('F', options);
  if (lossString.length() != 0) {
    setLossFunction(new SelectedTag(Integer.parseInt(lossString),
        TAGS_SELECTION));
  }

  setOutputProbsForSVM(Utils.getFlag("output-probs", options));

  String lambdaString = Utils.getOption('R', options);
  if (lambdaString.length() > 0) {
    setLambda(Double.parseDouble(lambdaString));
  }

  String learningRateString = Utils.getOption('L', options);
  if (learningRateString.length() > 0) {
    setLearningRate(Double.parseDouble(learningRateString));
  }

  String epochsString = Utils.getOption("E", options);
  if (epochsString.length() > 0) {
    setEpochs(Integer.parseInt(epochsString));
  }

  setUseWordFrequencies(Utils.getFlag("W", options));

  String pruneFreqS = Utils.getOption("P", options);
  if (pruneFreqS.length() > 0) {
    setPeriodicPruning(Integer.parseInt(pruneFreqS));
  }

  String minFreq = Utils.getOption("M", options);
  if (minFreq.length() > 0) {
    setMinWordFrequency(Double.parseDouble(minFreq));
  }

  setNormalizeDocLength(Utils.getFlag("normalize", options));

  String normFreqS = Utils.getOption("norm", options);
  if (normFreqS.length() > 0) {
    setNorm(Double.parseDouble(normFreqS));
  }

  String lnormFreqS = Utils.getOption("lnorm", options);
  if (lnormFreqS.length() > 0) {
    setLNorm(Double.parseDouble(lnormFreqS));
  }

  setLowercaseTokens(Utils.getFlag("lowercase", options));
  setUseStopList(Utils.getFlag("stoplist", options));

  String stopwordsS = Utils.getOption("stopwords", options);
  if (stopwordsS.length() > 0) {
    setStopwords(new File(stopwordsS));
  } else {
    // reverts to the default (directory => built-in stopword list)
    setStopwords(null);
  }

  String tokenizerString = Utils.getOption("tokenizer", options);
  if (tokenizerString.length() == 0) {
    setTokenizer(new WordTokenizer());
  } else {
    String[] tokenizerSpec = Utils.splitOptions(tokenizerString);
    if (tokenizerSpec.length == 0) {
      throw new Exception("Invalid tokenizer specification string");
    }
    String tokenizerName = tokenizerSpec[0];
    // first element is consumed as the class name; blank it so the
    // remainder can be passed on as the tokenizer's own options
    tokenizerSpec[0] = "";
    // use getDeclaredConstructor().newInstance() rather than the
    // deprecated Class.newInstance()
    Tokenizer tokenizer = (Tokenizer) Class.forName(tokenizerName)
        .getDeclaredConstructor().newInstance();
    if (tokenizer instanceof OptionHandler) {
      ((OptionHandler) tokenizer).setOptions(tokenizerSpec);
    }
    setTokenizer(tokenizer);
  }

  String stemmerString = Utils.getOption("stemmer", options);
  if (stemmerString.length() == 0) {
    // no -stemmer option => no stemming (setStemmer installs NullStemmer)
    setStemmer(null);
  } else {
    String[] stemmerSpec = Utils.splitOptions(stemmerString);
    if (stemmerSpec.length == 0) {
      throw new Exception("Invalid stemmer specification string");
    }
    String stemmerName = stemmerSpec[0];
    stemmerSpec[0] = "";
    Stemmer stemmer = (Stemmer) Class.forName(stemmerName)
        .getDeclaredConstructor().newInstance();
    if (stemmer instanceof OptionHandler) {
      ((OptionHandler) stemmer).setOptions(stemmerSpec);
    }
    setStemmer(stemmer);
  }
}
/**
 * Gets the current settings of the classifier.
 *
 * @return an array of strings suitable for passing to setOptions
 */
@Override
public String[] getOptions() {
  ArrayList<String> options = new ArrayList<String>();

  options.add("-F");
  options.add("" + getLossFunction().getSelectedTag().getID());
  if (getOutputProbsForSVM()) {
    options.add("-output-probs");
  }
  options.add("-L");
  options.add("" + getLearningRate());
  options.add("-R");
  options.add("" + getLambda());
  options.add("-E");
  options.add("" + getEpochs());

  if (getUseWordFrequencies()) {
    options.add("-W");
  }
  options.add("-P");
  options.add("" + getPeriodicPruning());
  options.add("-M");
  options.add("" + getMinWordFrequency());

  if (getNormalizeDocLength()) {
    options.add("-normalize");
  }
  options.add("-norm");
  options.add("" + getNorm());
  options.add("-lnorm");
  options.add("" + getLNorm());
  if (getLowercaseTokens()) {
    options.add("-lowercase");
  }
  if (getUseStopList()) {
    options.add("-stoplist");
  }
  // a directory means "use the default stopwords", so only emit the
  // option when an actual file has been configured
  if (!getStopwords().isDirectory()) {
    options.add("-stopwords");
    options.add(getStopwords().getAbsolutePath());
  }

  options.add("-tokenizer");
  String spec = getTokenizer().getClass().getName();
  if (getTokenizer() instanceof OptionHandler) {
    spec += " "
        + Utils.joinOptions(((OptionHandler) getTokenizer()).getOptions());
  }
  options.add(spec.trim());

  if (getStemmer() != null) {
    options.add("-stemmer");
    spec = getStemmer().getClass().getName();
    if (getStemmer() instanceof OptionHandler) {
      spec += " "
          + Utils.joinOptions(((OptionHandler) getStemmer()).getOptions());
    }
    options.add(spec.trim());
  }

  // use a zero-length seed array: toArray(new String[1]) would leave a
  // trailing null element if the list were ever shorter than one entry
  return options.toArray(new String[0]);
}
/**
 * Returns a string describing this classifier (shown as the "More" info in
 * the GUI and used to generate the class javadoc).
 *
 * @return a description suitable for displaying in the explorer/experimenter
 *         gui
 */
public String globalInfo() {
  return "Implements stochastic gradient descent for learning"
      + " a linear binary class SVM or binary class"
      + " logistic regression on text data. Operates directly (and only) "
      + "on String attributes. Other types of input attributes are accepted "
      + "but ignored during training and classification.";
}
/**
 * Reset the classifier to its initial, untrained state.
 */
public void reset() {
  m_t = 1;
  // Also clear the bias term: the per-word weights are discarded together
  // with the dictionary, but m_bias is stored separately and would
  // otherwise leak from a previous model into the next buildClassifier()
  // run.
  m_bias = 0;
  m_dictionary = null;
}
/**
 * Method for building the classifier. Resets any previously learned model,
 * then trains in batch mode for getEpochs() passes over the (shuffled)
 * data. If probability output for SVMs is enabled, a calibration model is
 * initialized first.
 *
 * @param data the set of training instances.
 * @throws Exception if the classifier can't be built successfully.
 */
@Override
public void buildClassifier(Instances data) throws Exception {
  reset();

  /*
   * boolean hasString = false; for (int i = 0; i < data.numAttributes(); i++)
   * { if (data.attribute(i).isString() && data.classIndex() != i) { hasString
   * = true; break; } }
   *
   * if (!hasString) { throw new
   * Exception("Incoming data does not have any string attributes!"); }
   */

  // can classifier handle the data?
  getCapabilities().testWithFail(data);

  // pre-sized to reduce rehashing while the dictionary grows
  m_dictionary = new LinkedHashMap<String, Count>(10000);

  m_numInstances = data.numInstances();

  // keep only the header for later structure/output purposes
  m_data = new Instances(data, 0);
  // copy so that randomizing below does not disturb the caller's data
  data = new Instances(data);

  if (m_fitLogistic && m_loss == HINGE) {
    initializeSVMProbs(data);
  }

  if (data.numInstances() > 0) {
    data.randomize(new Random(getSeed()));
    train(data);
  }
}
/**
 * Initializes the SGD logistic regression that calibrates SVM outputs into
 * probability estimates. A two-attribute meta data set ("pred" = the raw
 * SVM output, "class" = the actual class) is created and the meta learner
 * is primed on its empty structure; it is then updated incrementally while
 * the SVM trains.
 *
 * @param data the training data (used only to obtain the class labels)
 * @throws Exception if the meta learner cannot be initialized
 */
protected void initializeSVMProbs(Instances data) throws Exception {
  m_svmProbs = new SGD();

  // Must pass SGD's own tag array here: setLossFunction() implementations
  // (see setLossFunction() in this class for the same pattern) compare the
  // tag array by identity and silently ignore a mismatch - passing
  // SGDText.TAGS_SELECTION would leave the calibration model on its
  // default loss instead of log loss.
  m_svmProbs.setLossFunction(new SelectedTag(SGD.LOGLOSS, SGD.TAGS_SELECTION));
  m_svmProbs.setLearningRate(m_learningRate);
  m_svmProbs.setLambda(m_lambda);
  m_svmProbs.setEpochs(m_epochs);

  FastVector atts = new FastVector(2);
  atts.addElement(new Attribute("pred"));
  FastVector attVals = new FastVector(2);
  attVals.addElement(data.classAttribute().value(0));
  attVals.addElement(data.classAttribute().value(1));
  atts.addElement(new Attribute("class", attVals));
  m_fitLogisticStructure = new Instances("data", atts, 0);
  m_fitLogisticStructure.setClassIndex(1);

  m_svmProbs.buildClassifier(m_fitLogisticStructure);
}
/**
 * Performs batch training: m_epochs passes over the data. The dictionary
 * is built during the first epoch only; later epochs just refine weights.
 *
 * @param data the training data
 * @throws Exception if an instance cannot be processed
 */
protected void train(Instances data) throws Exception {
  for (int epoch = 0; epoch < m_epochs; epoch++) {
    boolean updateDictionary = (epoch == 0);
    for (int i = 0; i < data.numInstances(); i++) {
      updateClassifier(data.instance(i), updateDictionary);
    }
  }
}
/**
 * Updates the classifier with the given instance.
 *
 * @param instance the new training instance to include in the model
 * @exception Exception if the instance could not be incorporated in the
 *              model.
 */
@Override
public void updateClassifier(Instance instance) throws Exception {
  // incremental learning always grows/updates the dictionary as it goes
  updateClassifier(instance, true);
}
/**
 * Performs one SGD step for the given instance: tokenizes it, applies L2
 * weight decay to all dictionary weights, and (if the loss is non-zero)
 * adds the gradient step to the weights of the words present in the
 * document. Instances with a missing class are ignored.
 *
 * @param instance the training instance
 * @param updateDictionary true to also update global word counts (and
 *          allow periodic pruning)
 * @throws Exception if the instance cannot be processed
 */
protected void updateClassifier(Instance instance, boolean updateDictionary)
    throws Exception {

  if (!instance.classIsMissing()) {

    // tokenize: fills m_inputVector (and, if requested, the dictionary)
    tokenizeInstance(instance, updateDictionary);

    // make a meta instance for the logistic model before we update
    // the SVM
    if (m_loss == HINGE && m_fitLogistic) {
      double pred = svmOutput();
      double[] vals = new double[2];
      vals[0] = pred;
      vals[1] = instance.classValue();
      DenseInstance metaI = new DenseInstance(instance.weight(), vals);
      metaI.setDataset(m_fitLogisticStructure);
      m_svmProbs.updateClassifier(metaI);
    }
    // ---

    double wx = dotProd(m_inputVector);
    // map class index 0 to -1 and class index 1 to +1
    double y = (instance.classValue() == 0) ? -1 : 1;
    // z is the (signed) margin
    double z = y * (wx + m_bias);

    // Compute multiplier for weight decay (L2 regularization). With
    // m_numInstances == 0 (pure incremental mode) the decay is scaled by
    // the current time step m_t; otherwise by the training set size.
    double multiplier = 1.0;
    if (m_numInstances == 0) {
      multiplier = 1.0 - (m_learningRate * m_lambda) / m_t;
    } else {
      multiplier = 1.0 - (m_learningRate * m_lambda) / m_numInstances;
    }
    for (Count c : m_dictionary.values()) {
      c.m_weight *= multiplier;
    }

    // Only need to do the following if the loss is non-zero
    if (m_loss != HINGE || (z < 1)) {

      // Compute Factor for updates
      double dloss = dloss(z);
      double factor = m_learningRate * y * dloss;

      // Update coefficients for attributes
      for (Map.Entry<String, Count> feature : m_inputVector.entrySet()) {
        String word = feature.getKey();
        // frequency-weighted or binary (1) feature value
        double value = (m_wordFrequencies) ? feature.getValue().m_count : 1;
        Count c = m_dictionary.get(word);
        // words not (or no longer) in the dictionary are skipped
        if (c != null) {
          c.m_weight += factor * value;
        }
      }

      // update the bias
      m_bias += factor;
    }
    m_t++;
  }
}
/**
 * Tokenizes all String attributes of the instance into m_inputVector
 * (a per-instance term -> weighted-count map), applying the configured
 * lower-casing, stemming and stopword filtering in that order.
 *
 * @param instance the instance to tokenize
 * @param updateDictionary true to also accumulate the term counts into
 *          the global dictionary (and trigger periodic pruning)
 */
protected void tokenizeInstance(Instance instance, boolean updateDictionary) {
  // Reuse the input vector between calls to avoid re-allocation.
  if (m_inputVector == null) {
    m_inputVector = new LinkedHashMap<String, Count>();
  } else {
    m_inputVector.clear();
  }
  // Lazily load the stopword list on first use.
  if (m_useStopList && m_stopwords == null) {
    m_stopwords = new Stopwords();
    try {
      if (getStopwords().exists() && !getStopwords().isDirectory()) {
        m_stopwords.read(getStopwords());
      }
    } catch (Exception ex) {
      // Best effort: fall back to the default stopword list if the
      // user-supplied file cannot be read.
      ex.printStackTrace();
    }
  }
  for (int i = 0; i < instance.numAttributes(); i++) {
    // Only String attributes are used; all other types are ignored.
    if (instance.attribute(i).isString() && !instance.isMissing(i)) {
      m_tokenizer.tokenize(instance.stringValue(i));
      while (m_tokenizer.hasMoreElements()) {
        String word = (String) m_tokenizer.nextElement();
        if (m_lowercaseTokens) {
          word = word.toLowerCase();
        }
        word = m_stemmer.stem(word);
        // Stopword check happens after lower-casing and stemming.
        if (m_useStopList) {
          if (m_stopwords.is(word)) {
            continue;
          }
        }
        // Counts are weighted by the instance weight.
        Count docCount = m_inputVector.get(word);
        if (docCount == null) {
          m_inputVector.put(word, new Count(instance.weight()));
        } else {
          docCount.m_count += instance.weight();
        }
        if (updateDictionary) {
          Count count = m_dictionary.get(word);
          if (count == null) {
            m_dictionary.put(word, new Count(instance.weight()));
          } else {
            count.m_count += instance.weight();
          }
        }
      }
    }
  }
  if (updateDictionary) {
    // May remove low-frequency terms, depending on the pruning period.
    pruneDictionary();
  }
}
/**
 * Removes low-frequency terms from the dictionary. Pruning only runs
 * every m_periodicP processed instances, and never when periodic
 * pruning is disabled (m_periodicP <= 0).
 */
protected void pruneDictionary() {
  if (m_periodicP <= 0 || m_t % m_periodicP > 0) {
    return;
  }
  Iterator<Map.Entry<String, Count>> it = m_dictionary.entrySet().iterator();
  while (it.hasNext()) {
    // Drop any term whose accumulated count is below the minimum.
    if (it.next().getValue().m_count < m_minWordP) {
      it.remove();
    }
  }
}
/**
 * Returns the raw (uncalibrated) linear model output for the terms
 * currently held in m_inputVector.
 *
 * @return the dot product of weights and input plus the bias
 */
protected double svmOutput() {
  return dotProd(m_inputVector) + m_bias;
}
/**
 * Computes the class membership probabilities for the given instance.
 *
 * @param inst the instance to classify
 * @return a two-element array of class probabilities
 * @throws Exception if the distribution cannot be computed
 */
@Override
public double[] distributionForInstance(Instance inst) throws Exception {
  tokenizeInstance(inst, false);
  double z = dotProd(m_inputVector) + m_bias;

  // With hinge loss and logistic calibration, hand the raw SVM output
  // to the calibration model to obtain proper probabilities.
  if (m_loss == HINGE && m_fitLogistic) {
    double[] vals = new double[2];
    vals[0] = z;
    vals[1] = Utils.missingValue();
    DenseInstance metaI = new DenseInstance(inst.weight(), vals);
    metaI.setDataset(m_fitLogisticStructure);
    return m_svmProbs.distributionForInstance(metaI);
  }

  double[] result = new double[2];
  if (m_loss == LOGLOSS) {
    // Logistic regression: sigmoid of the linear output, computed in
    // the numerically stable direction for the sign of z.
    if (z <= 0) {
      result[0] = 1.0 / (1.0 + Math.exp(z));
      result[1] = 1.0 - result[0];
    } else {
      result[1] = 1.0 / (1.0 + Math.exp(-z));
      result[0] = 1.0 - result[1];
    }
  } else {
    // Uncalibrated SVM: hard 0/1 assignment from the sign of z.
    if (z <= 0) {
      result[0] = 1;
    } else {
      result[1] = 1;
    }
  }
  return result;
}
/**
 * Computes the dot product between the model weights and the given
 * document vector, optionally L-norm normalizing the document first.
 * Terms absent from the dictionary, or below the minimum frequency
 * threshold, contribute nothing.
 *
 * @param document the term -> count map for the document
 * @return the dot product
 */
protected double dotProd(Map<String, Count> document) {
  // Document length normalization (L_m norm of the feature values).
  double norm = 0;
  if (m_normalize) {
    for (Count c : document.values()) {
      // word counts or bag-of-words?
      double v = (m_wordFrequencies) ? c.m_count : 1.0;
      norm += Math.pow(Math.abs(v), m_lnorm);
    }
    norm = Math.pow(norm, 1.0 / m_lnorm);
  }

  double sum = 0;
  for (Map.Entry<String, Count> feature : document.entrySet()) {
    double freq = (m_wordFrequencies) ? feature.getValue().m_count : 1.0;
    if (m_normalize) {
      freq /= norm * m_norm;
    }
    Count entry = m_dictionary.get(feature.getKey());
    if (entry != null && entry.m_count >= m_minWordP) {
      sum += freq * entry.m_weight;
    }
  }
  return sum;
}
/**
 * Returns a textual description of the model: the loss function, the
 * effective dictionary size, and the weight of every term that meets
 * the minimum frequency threshold.
 *
 * @return a string describing the model
 */
@Override
public String toString() {
  if (m_dictionary == null) {
    return "SGDText: No model built yet.\n";
  }
  // StringBuilder instead of the legacy synchronized StringBuffer:
  // this is single-threaded string assembly.
  StringBuilder buff = new StringBuilder();
  buff.append("SGDText:\n\n");
  buff.append("Loss function: ");
  if (m_loss == HINGE) {
    buff.append("Hinge loss (SVM)\n\n");
  } else {
    buff.append("Log loss (logistic regression)\n\n");
  }
  // Reuse getDictionarySize() rather than duplicating its counting loop.
  buff.append("Dictionary size: " + getDictionarySize() + "\n\n");
  buff.append(m_data.classAttribute().name() + " = \n\n");
  int printed = 0;
  // List only the terms that meet the minimum frequency threshold.
  for (Map.Entry<String, Count> entry : m_dictionary.entrySet()) {
    if (entry.getValue().m_count >= m_minWordP) {
      if (printed > 0) {
        buff.append(" + ");
      } else {
        buff.append(" ");
      }
      buff.append(Utils.doubleToString(entry.getValue().m_weight, 12, 4)
        + " " + entry.getKey() + " " + entry.getValue().m_count + "\n");
      printed++;
    }
  }
  if (m_bias > 0) {
    buff.append(" + " + Utils.doubleToString(m_bias, 12, 4));
  } else {
    buff.append(" - " + Utils.doubleToString(-m_bias, 12, 4));
  }
  return buff.toString();
}
/**
 * Get this model's dictionary (including term weights). Note: this is a
 * direct reference to the internal map, not a copy.
 *
 * @return this model's dictionary.
 */
public LinkedHashMap<String, Count> getDictionary() {
  return m_dictionary;
}
/**
 * Return the size of the dictionary (minus any low frequency terms that
 * are below the threshold but haven't been pruned yet).
 *
 * @return the size of the dictionary.
 */
public int getDictionarySize() {
  int size = 0;
  if (m_dictionary != null) {
    // Only count terms that meet the minimum frequency threshold.
    for (Count c : m_dictionary.values()) {
      if (c.m_count >= m_minWordP) {
        size++;
      }
    }
  }
  return size;
}
/**
 * Returns the current bias term of the linear model.
 *
 * @return the bias
 */
public double bias() {
  return m_bias;
}
/**
 * Sets the bias term of the linear model.
 *
 * @param bias the bias to use
 */
public void setBias(double bias) {
  m_bias = bias;
}
/**
 * Returns the revision string.
 *
 * @return the revision
 */
@Override
public String getRevision() {
  return RevisionUtils.extract("$Revision: 9785 $");
}
protected int m_numModels = 0;
/**
 * Aggregate an object with this one. Term counts and weights from the
 * supplied model are merged (summed) into this model's dictionary; the
 * averaging over all aggregated models happens later in
 * finalizeAggregation().
 *
 * @param toAggregate the object to aggregate
 * @return the result of aggregation
 * @throws Exception if the supplied object can't be aggregated for some
 *           reason
 */
@Override
public SGDText aggregate(SGDText toAggregate) throws Exception {
  if (m_dictionary == null) {
    throw new Exception("No model built yet, can't aggregate");
  }
  LinkedHashMap<String, SGDText.Count> tempDict = toAggregate.getDictionary();
  Iterator<Map.Entry<String, SGDText.Count>> entries = tempDict.entrySet()
    .iterator();
  while (entries.hasNext()) {
    Map.Entry<String, SGDText.Count> entry = entries.next();
    Count masterCount = m_dictionary.get(entry.getKey());
    if (masterCount == null) {
      // we haven't seen this term (or it's been pruned): adopt the
      // other model's count and weight for it as-is
      masterCount = new Count(entry.getValue().m_count);
      masterCount.m_weight = entry.getValue().m_weight;
      m_dictionary.put(entry.getKey(), masterCount);
    } else {
      // add up
      masterCount.m_count += entry.getValue().m_count;
      masterCount.m_weight += entry.getValue().m_weight;
    }
  }
  // Biases are summed here and averaged in finalizeAggregation().
  m_bias += toAggregate.bias();
  m_numModels++;
  return this;
}
/**
 * Call to complete the aggregation process. Averages the summed term
 * counts, term weights and bias over all aggregated models (plus this
 * one), then resets the aggregation counter.
 *
 * @throws Exception if the aggregation can't be finalized for some reason
 */
@Override
public void finalizeAggregation() throws Exception {
  if (m_numModels == 0) {
    throw new Exception("Unable to finalize aggregation - " +
      "haven't seen any models to aggregate");
  }
  int divisor = m_numModels + 1; // plus one for us
  for (Count c : m_dictionary.values()) {
    c.m_count /= divisor;
    c.m_weight /= divisor;
  }
  m_bias /= divisor;
  // aggregation complete
  m_numModels = 0;
}
/**
 * Main method for testing this class.
 *
 * @param args command-line options for the classifier (standard Weka
 *          classifier options)
 */
public static void main(String[] args) {
  runClassifier(new SGDText(), args);
}
}