/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ /* * SGDText.java * Copyright (C) 2011 University of Waikato, Hamilton, New Zealand * */ package weka.classifiers.functions; import java.io.File; import java.io.Serializable; import java.util.ArrayList; import java.util.Enumeration; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.Map; import java.util.Random; import java.util.Vector; import weka.classifiers.RandomizableClassifier; import weka.classifiers.UpdateableClassifier; import weka.core.Capabilities; import weka.core.Capabilities.Capability; import weka.core.Instance; import weka.core.Instances; import weka.core.Option; import weka.core.OptionHandler; import weka.core.RevisionUtils; import weka.core.SelectedTag; import weka.core.Stopwords; import weka.core.Tag; import weka.core.Utils; import weka.core.WeightedInstancesHandler; import weka.core.stemmers.NullStemmer; import weka.core.stemmers.Stemmer; import weka.core.tokenizers.Tokenizer; import weka.core.tokenizers.WordTokenizer; /** <!-- globalinfo-start --> * Implements stochastic gradient descent for learning a linear binary class SVM or binary class logistic regression on text data. Operates directly on String attributes. * <p/> <!-- globalinfo-end --> * <!-- options-start --> * Valid options are: <p/> * * <pre> -F * Set the loss function to minimize. 0 = hinge loss (SVM), 1 = log loss (logistic regression) * (default = 0)</pre> * * <pre> -L * The learning rate (default = 0.01).</pre> * * <pre> -R <double> * The lambda regularization constant (default = 0.0001)</pre> * * <pre> -E <integer> * The number of epochs to perform (batch learning only, default = 500)</pre> * * <pre> -W * Use word frequencies instead of binary bag of words.</pre> * * <pre> -P <# instances> * How often to prune the dictionary of low frequency words (default = 0, i.e. don't prune)</pre> * * <pre> -M <double> * Minimum word frequency. Words with less than this frequence are ignored. * If periodic pruning is turned on then this is also used to determine which * words to remove from the dictionary (default = 3).</pre> * * <pre> -norm <num> * Specify the norm that each instance must have (default 1.0)</pre> * * <pre> -lnorm <num> * Specify L-norm to use (default 2.0)</pre> * * <pre> -lowercase * Convert all tokens to lowercase before adding to the dictionary.</pre> * * <pre> -S * Ignore words that are in the stoplist.</pre> * * <pre> -stopwords <file> * A file containing stopwords to override the default ones. * Using this option automatically sets the flag ('-S') to use the * stoplist if the file exists. * Format: one stopword per line, lines starting with '#' * are interpreted as comments and ignored.</pre> * * <pre> -tokenizer <spec> * The tokenizing algorihtm (classname plus parameters) to use. * (default: weka.core.tokenizers.WordTokenizer)</pre> * * <pre> -stemmer <spec> * The stemmering algorihtm (classname plus parameters) to use.</pre> * <!-- options-end --> * * @author Mark Hall (mhall{[at]}pentaho{[dot]}com) * @author Eibe Frank (eibe{[at]}cs{[dot]}waikato{[dot]}ac{[dot]}nz) * */ public class SGDText extends RandomizableClassifier implements UpdateableClassifier, WeightedInstancesHandler { /** For serialization */ private static final long serialVersionUID = 7200171484002029584L; private static class Count implements Serializable { /** * For serialization */ private static final long serialVersionUID = 2104201532017340967L; public double m_count; public double m_weight; public Count(double c) { m_count = c; } } /** * The number of training instances at which to periodically prune the dictionary * of min frequency words. Empty or null string indicates don't prune */ protected int m_periodicP = 0; /** Only consider dictionary words (features) that occur at least this many times */ protected double m_minWordP = 3; /** Use word frequencies rather than bag-of-words if true */ protected boolean m_wordFrequencies = false; /** The length that each document vector should have in the end */ protected double m_norm = 1.0; /** The L-norm to use */ protected double m_lnorm = 2.0; /** The dictionary (and term weights) */ protected LinkedHashMap<String, Count> m_dictionary; /** Default (rainbow) stopwords */ protected transient Stopwords m_stopwords; /** * a file containing stopwords for using others than the default Rainbow * ones. */ protected File m_stopwordsFile = new File(System.getProperty("user.dir")); /** The tokenizer to use */ protected Tokenizer m_tokenizer = new WordTokenizer(); /** Whether or not to convert all tokens to lowercase */ protected boolean m_lowercaseTokens; /** The stemming algorithm. */ protected Stemmer m_stemmer = new NullStemmer(); /** Whether or not to use a stop list */ protected boolean m_useStopList; /** The regularization parameter */ protected double m_lambda = 0.0001; /** The learning rate */ protected double m_learningRate = 0.01; /** Holds the current iteration number */ protected double m_t; /** Holds the bias term */ protected double m_bias; /** The number of training instances */ protected double m_numInstances; /** The header of the training data */ protected Instances m_data; /** * The number of epochs to perform (batch learning). Total iterations is * m_epochs * num instances */ protected int m_epochs = 500; /** * Holds the current document vector (LinkedHashMap is more efficient * when iterating over EntrySet than HashMap) */ protected transient LinkedHashMap<String, Count> m_inputVector; /** the hinge loss function. */ public static final int HINGE = 0; /** the log loss function. */ public static final int LOGLOSS = 1; /** The current loss function to minimize */ protected int m_loss = HINGE; /** Loss functions to choose from */ public static final Tag [] TAGS_SELECTION = { new Tag(HINGE, "Hinge loss (SVM)"), new Tag(LOGLOSS, "Log loss (logistic regression)") }; protected double dloss(double z) { if (m_loss == HINGE) { return (z < 1) ? 1 : 0; } else { // log loss if (z < 0) { return 1.0 / (Math.exp(z) + 1.0); } else { double t = Math.exp(-z); return t / (t + 1); } } } /** * Returns default capabilities of the classifier. * * @return the capabilities of this classifier */ public Capabilities getCapabilities() { Capabilities result = super.getCapabilities(); result.disableAll(); //attributes result.enable(Capability.STRING_ATTRIBUTES); result.enable(Capability.MISSING_VALUES); result.enable(Capability.BINARY_CLASS); result.enable(Capability.MISSING_CLASS_VALUES); // instances result.setMinimumNumberInstances(0); return result; } /** * the stemming algorithm to use, null means no stemming at all (i.e., the * NullStemmer is used). * * @param value the configured stemming algorithm, or null * @see NullStemmer */ public void setStemmer(Stemmer value) { if (value != null) m_stemmer = value; else m_stemmer = new NullStemmer(); } /** * Returns the current stemming algorithm, null if none is used. * * @return the current stemming algorithm, null if none set */ public Stemmer getStemmer() { return m_stemmer; } /** * Returns the tip text for this property. * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String stemmerTipText() { return "The stemming algorithm to use on the words."; } /** * the tokenizer algorithm to use. * * @param value the configured tokenizing algorithm */ public void setTokenizer(Tokenizer value) { m_tokenizer = value; } /** * Returns the current tokenizer algorithm. * * @return the current tokenizer algorithm */ public Tokenizer getTokenizer() { return m_tokenizer; } /** * Returns the tip text for this property. * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String tokenizerTipText() { return "The tokenizing algorithm to use on the strings."; } /** * Returns the tip text for this property * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String useWordFrequenciesTipText() { return "Use word frequencies rather than binary " + "bag of words representation"; } /** * Set whether to use word frequencies rather than binary * bag of words representation. * * @param u true if word frequencies are to be used. */ public void setUseWordFrequencies(boolean u) { m_wordFrequencies = u; } /** * Get whether to use word frequencies rather than binary * bag of words representation. * * @param u true if word frequencies are to be used. */ public boolean getUseWordFrequencies() { return m_wordFrequencies; } /** * Returns the tip text for this property * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String lowercaseTokensTipText() { return "Whether to convert all tokens to lowercase"; } /** * Set whether to convert all tokens to lowercase * * @param l true if all tokens are to be converted to * lowercase */ public void setLowercaseTokens(boolean l) { m_lowercaseTokens = l; } /** * Get whether to convert all tokens to lowercase * * @return true true if all tokens are to be converted to * lowercase */ public boolean getLowercaseTokens() { return m_lowercaseTokens; } /** * Returns the tip text for this property * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String useStopListTipText() { return "If true, ignores all words that are on the stoplist."; } /** * Set whether to ignore all words that are on the stoplist. * * @param u true to ignore all words on the stoplist. */ public void setUseStopList(boolean u) { m_useStopList = u; } /** * Get whether to ignore all words that are on the stoplist. * * @return true to ignore all words on the stoplist. */ public boolean getUseStopList() { return m_useStopList; } /** * sets the file containing the stopwords, null or a directory unset the * stopwords. If the file exists, it automatically turns on the flag to * use the stoplist. * * @param value the file containing the stopwords */ public void setStopwords(File value) { if (value == null) value = new File(System.getProperty("user.dir")); m_stopwordsFile = value; if (value.exists() && value.isFile()) setUseStopList(true); } /** * returns the file used for obtaining the stopwords, if the file represents * a directory then the default ones are used. * * @return the file containing the stopwords */ public File getStopwords() { return m_stopwordsFile; } /** * Returns the tip text for this property. * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String stopwordsTipText() { return "The file containing the stopwords (if this is a directory then the default ones are used)."; } /** * Returns the tip text for this property * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String periodicPruningTipText() { return "How often (number of instances) to prune " + "the dictionary of low frequency terms. " + "0 means don't prune. Setting a positive " + "integer n means prune after every n instances"; } /** * Set how often to prune the dictionary * * @param p how often to prune */ public void setPeriodicPruning(int p) { m_periodicP = p; } /** * Get how often to prune the dictionary * * @return how often to prune the dictionary */ public int getPeriodicPruning() { return m_periodicP; } /** * Returns the tip text for this property * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String minWordFrequencyTipText() { return "Ignore any words that don't occur at least " + "min frequency times in the training data. If periodic " + "pruning is turned on, then the dictionary is pruned " + "according to this value"; } /** * Set the minimum word frequency. Words that don't occur * at least min freq times are ignored when updating weights. * If periodic pruning is turned on, then min frequency is used * when removing words from the dictionary. * * @param minFreq the minimum word frequency to use */ public void setMinWordFrequency(double minFreq) { m_minWordP = minFreq; } /** * Get the minimum word frequency. Words that don't occur * at least min freq times are ignored when updating weights. * If periodic pruning is turned on, then min frequency is used * when removing words from the dictionary. * * @param return the minimum word frequency to use */ public double getMinWordFrequency() { return m_minWordP; } /** * Returns the tip text for this property * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String normTipText() { return "The norm of the instances after normalization."; } /** * Get the instance's Norm. * * @return the Norm */ public double getNorm() { return m_norm; } /** * Set the norm of the instances * * @param newNorm the norm to wich the instances must be set */ public void setNorm(double newNorm) { m_norm = newNorm; } /** * Returns the tip text for this property * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String LNormTipText() { return "The LNorm to use for document length normalization."; } /** * Get the L Norm used. * * @return the L-norm used */ public double getLNorm() { return m_lnorm; } /** * Set the L-norm to used * * @param newLNorm the L-norm */ public void setLNorm(double newLNorm) { m_lnorm = newLNorm; } /** * Returns the tip text for this property * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String lambdaTipText() { return "The regularization constant. (default = 0.0001)"; } /** * Set the value of lambda to use * * @param lambda the value of lambda to use */ public void setLambda(double lambda) { m_lambda = lambda; } /** * Get the current value of lambda * * @return the current value of lambda */ public double getLambda() { return m_lambda; } /** * Set the learning rate. * * @param lr the learning rate to use. */ public void setLearningRate(double lr) { m_learningRate = lr; } /** * Get the learning rate. * * @return the learning rate */ public double getLearningRate() { return m_learningRate; } /** * Returns the tip text for this property * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String learningRateTipText() { return "The learning rate."; } /** * Returns the tip text for this property * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String epochsTipText() { return "The number of epochs to perform (batch learning). " + "The total number of iterations is epochs * num" + " instances."; } /** * Set the number of epochs to use * * @param e the number of epochs to use */ public void setEpochs(int e) { m_epochs = e; } /** * Get current number of epochs * * @return the current number of epochs */ public int getEpochs() { return m_epochs; } /** * Set the loss function to use. * * @param function the loss function to use. */ public void setLossFunction(SelectedTag function) { if (function.getTags() == TAGS_SELECTION) { m_loss = function.getSelectedTag().getID(); } } /** * Get the current loss function. * * @return the current loss function. */ public SelectedTag getLossFunction() { return new SelectedTag(m_loss, TAGS_SELECTION); } /** * Returns the tip text for this property * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String lossFunctionTipText() { return "The loss function to use. Hinge loss (SVM), " + "log loss (logistic regression) or " + "squared loss (regression)."; } /** * Returns an enumeration describing the available options. * * @return an enumeration of all the available options. */ public Enumeration<Option> listOptions() { Vector<Option> newVector = new Vector<Option>(); newVector.add(new Option("\tSet the loss function to minimize. 0 = " + "hinge loss (SVM), 1 = log loss (logistic regression)\n\t" + "(default = 0)", "F", 1, "-F")); newVector.add(new Option("\tThe learning rate (default = 0.01).", "L", 1, "-L")); newVector.add(new Option("\tThe lambda regularization constant " + "(default = 0.0001)", "R", 1, "-R <double>")); newVector.add(new Option("\tThe number of epochs to perform (" + "batch learning only, default = 500)", "E", 1, "-E <integer>")); newVector.add(new Option("\tUse word frequencies instead of " + "binary bag of words.", "W", 0, "-W")); newVector.add(new Option("\tHow often to prune the dictionary " + "of low frequency words (default = 0, i.e. don't prune)", "P", 1, "-P <# instances>")); newVector.add(new Option("\tMinimum word frequency. Words with less " + "than this frequence are ignored.\n\tIf periodic pruning " + "is turned on then this is also used to determine which\n\t" + "words to remove from the dictionary (default = 3).", "M", 1, "-M <double>")); newVector.addElement(new Option( "\tSpecify the norm that each instance must have (default 1.0)", "norm", 1, "-norm <num>")); newVector.addElement(new Option( "\tSpecify L-norm to use (default 2.0)", "lnorm", 1, "-lnorm <num>")); newVector.addElement(new Option("\tConvert all tokens to lowercase " + "before adding to the dictionary.", "lowercase", 0, "-lowercase")); newVector.addElement(new Option( "\tIgnore words that are in the stoplist.", "S", 0, "-S")); newVector.addElement(new Option( "\tA file containing stopwords to override the default ones.\n" + "\tUsing this option automatically sets the flag ('-S') to use the\n" + "\tstoplist if the file exists.\n" + "\tFormat: one stopword per line, lines starting with '#'\n" + "\tare interpreted as comments and ignored.", "stopwords", 1, "-stopwords <file>")); newVector.addElement(new Option( "\tThe tokenizing algorihtm (classname plus parameters) to use.\n" + "\t(default: " + WordTokenizer.class.getName() + ")", "tokenizer", 1, "-tokenizer <spec>")); newVector.addElement(new Option( "\tThe stemmering algorihtm (classname plus parameters) to use.", "stemmer", 1, "-stemmer <spec>")); return newVector.elements(); } /** * Parses a given list of options. <p/> * <!-- options-start --> * Valid options are: <p/> * * <pre> -F * Set the loss function to minimize. 0 = hinge loss (SVM), 1 = log loss (logistic regression) * (default = 0)</pre> * * <pre> -L * The learning rate (default = 0.01).</pre> * * <pre> -R <double> * The lambda regularization constant (default = 0.0001)</pre> * * <pre> -E <integer> * The number of epochs to perform (batch learning only, default = 500)</pre> * * <pre> -W * Use word frequencies instead of binary bag of words.</pre> * * <pre> -P <# instances> * How often to prune the dictionary of low frequency words (default = 0, i.e. don't prune)</pre> * * <pre> -M <double> * Minimum word frequency. Words with less than this frequence are ignored. * If periodic pruning is turned on then this is also used to determine which * words to remove from the dictionary (default = 3).</pre> * * <pre> -norm <num> * Specify the norm that each instance must have (default 1.0)</pre> * * <pre> -lnorm <num> * Specify L-norm to use (default 2.0)</pre> * * <pre> -lowercase * Convert all tokens to lowercase before adding to the dictionary.</pre> * * <pre> -S * Ignore words that are in the stoplist.</pre> * * <pre> -stopwords <file> * A file containing stopwords to override the default ones. * Using this option automatically sets the flag ('-S') to use the * stoplist if the file exists. * Format: one stopword per line, lines starting with '#' * are interpreted as comments and ignored.</pre> * * <pre> -tokenizer <spec> * The tokenizing algorihtm (classname plus parameters) to use. * (default: weka.core.tokenizers.WordTokenizer)</pre> * * <pre> -stemmer <spec> * The stemmering algorihtm (classname plus parameters) to use.</pre> * <!-- options-end --> * * @param options the list of options as an array of strings * @throws Exception if an option is not supported */ public void setOptions(String[] options) throws Exception { reset(); super.setOptions(options); String lossString = Utils.getOption('F', options); if (lossString.length() != 0) { setLossFunction(new SelectedTag(Integer.parseInt(lossString), TAGS_SELECTION)); } String lambdaString = Utils.getOption('R', options); if (lambdaString.length() > 0) { setLambda(Double.parseDouble(lambdaString)); } String learningRateString = Utils.getOption('L', options); if (learningRateString.length() > 0) { setLearningRate(Double.parseDouble(learningRateString)); } String epochsString = Utils.getOption("E", options); if (epochsString.length() > 0) { setEpochs(Integer.parseInt(epochsString)); } setUseWordFrequencies(Utils.getFlag("W", options)); String pruneFreqS = Utils.getOption("P", options); if (pruneFreqS.length() > 0) { setPeriodicPruning(Integer.parseInt(pruneFreqS)); } String minFreq = Utils.getOption("M", options); if (minFreq.length() > 0) { setMinWordFrequency(Double.parseDouble(minFreq)); } String normFreqS = Utils.getOption("norm", options); if (normFreqS.length() > 0) { setNorm(Double.parseDouble(normFreqS)); } String lnormFreqS = Utils.getOption("lnorm", options); if (lnormFreqS.length() > 0) { setLNorm(Double.parseDouble(lnormFreqS)); } setLowercaseTokens(Utils.getFlag("lowercase", options)); setUseStopList(Utils.getFlag("S", options)); String stopwordsS = Utils.getOption("stopwords", options); if (stopwordsS.length() > 0) { setStopwords(new File(stopwordsS)); } else { setStopwords(null); } String tokenizerString = Utils.getOption("tokenizer", options); if (tokenizerString.length() == 0) { setTokenizer(new WordTokenizer()); } else { String[] tokenizerSpec = Utils.splitOptions(tokenizerString); if (tokenizerSpec.length == 0) throw new Exception("Invalid tokenizer specification string"); String tokenizerName = tokenizerSpec[0]; tokenizerSpec[0] = ""; Tokenizer tokenizer = (Tokenizer) Class.forName(tokenizerName).newInstance(); if (tokenizer instanceof OptionHandler) ((OptionHandler) tokenizer).setOptions(tokenizerSpec); setTokenizer(tokenizer); } String stemmerString = Utils.getOption("stemmer", options); if (stemmerString.length() == 0) { setStemmer(null); } else { String[] stemmerSpec = Utils.splitOptions(stemmerString); if (stemmerSpec.length == 0) throw new Exception("Invalid stemmer specification string"); String stemmerName = stemmerSpec[0]; stemmerSpec[0] = ""; Stemmer stemmer = (Stemmer) Class.forName(stemmerName).newInstance(); if (stemmer instanceof OptionHandler) ((OptionHandler) stemmer).setOptions(stemmerSpec); setStemmer(stemmer); } } /** * Gets the current settings of the classifier. * * @return an array of strings suitable for passing to setOptions */ public String[] getOptions() { ArrayList<String> options = new ArrayList<String>(); options.add("-F"); options.add("" + getLossFunction().getSelectedTag().getID()); options.add("-L"); options.add("" + getLearningRate()); options.add("-R"); options.add("" + getLambda()); options.add("-E"); options.add("" + getEpochs()); if (getUseWordFrequencies()) { options.add("-W"); } options.add("-P"); options.add("" + getPeriodicPruning()); options.add("-M"); options.add("" + getMinWordFrequency()); options.add("-norm"); options.add("" + getNorm()); options.add("-lnorm"); options.add("" + getLNorm()); if (getLowercaseTokens()) { options.add("-lowercase"); } if (getUseStopList()) { options.add("-S"); } if (!getStopwords().isDirectory()) { options.add("-stopwords"); options.add(getStopwords().getAbsolutePath()); } options.add("-tokenizer"); String spec = getTokenizer().getClass().getName(); if (getTokenizer() instanceof OptionHandler) spec += " " + Utils.joinOptions( ((OptionHandler) getTokenizer()).getOptions()); options.add(spec.trim()); return options.toArray(new String[1]); } /** * Returns a string describing classifier * @return a description suitable for * displaying in the explorer/experimenter gui */ public String globalInfo() { return "Implements stochastic gradient descent for learning" + " a linear binary class SVM or binary class" + " logistic regression on text data. Operates directly on String " + "attributes."; } /** * Reset the classifier. */ public void reset() { m_t = 1; m_dictionary = null; } /** * Method for building the classifier. * * @param data the set of training instances. * @throws Exception if the classifier can't be built successfully. */ public void buildClassifier(Instances data) throws Exception { reset(); // can classifier handle the data? getCapabilities().testWithFail(data); m_dictionary = new LinkedHashMap<String, Count>(10000); m_numInstances = data.numInstances(); m_data = new Instances(data, 0); data = new Instances(data); if (data.numInstances() > 0) { data.randomize(new Random(getSeed())); train(data); } } protected void train(Instances data) throws Exception { for (int e = 0; e < m_epochs; e++) { for (int i = 0; i < data.numInstances(); i++) { if (e == 0) { updateClassifier(data.instance(i), true); } else { updateClassifier(data.instance(i), false); } } } } /** * Updates the classifier with the given instance. * * @param instance the new training instance to include in the model * @exception Exception if the instance could not be incorporated in * the model. */ public void updateClassifier(Instance instance) throws Exception { updateClassifier(instance, true); } protected void updateClassifier(Instance instance, boolean updateDictionary) throws Exception { if (!instance.classIsMissing()) { // tokenize tokenizeInstance(instance, updateDictionary); // --- double wx = dotProd(m_inputVector); double y = (instance.classValue() == 0) ? -1 : 1; double z = y * (wx + m_bias); // Compute multiplier for weight decay double multiplier = 1.0; if (m_numInstances == 0) { multiplier = 1.0 - (m_learningRate * m_lambda) / m_t; } else { multiplier = 1.0 - (m_learningRate * m_lambda) / m_numInstances; } for (Count c : m_dictionary.values()) { c.m_weight *= multiplier; } // Only need to do the following if the loss is non-zero if (m_loss != HINGE || (z < 1)) { // Compute Factor for updates double factor = m_learningRate * y * dloss(z); // Update coefficients for attributes for (Map.Entry<String, Count> feature : m_inputVector.entrySet()) { String word = feature.getKey(); double value = (m_wordFrequencies) ? feature.getValue().m_count : 1; Count c = m_dictionary.get(word); if (c != null) { c.m_weight += factor * value; } } // update the bias m_bias += factor; } m_t++; } } protected void tokenizeInstance(Instance instance, boolean updateDictionary) { if (m_inputVector == null) { m_inputVector = new LinkedHashMap<String, Count>(); } else { m_inputVector.clear(); } if (m_useStopList && m_stopwords == null) { m_stopwords = new Stopwords(); try { if (getStopwords().exists() && !getStopwords().isDirectory()) { m_stopwords.read(getStopwords()); } } catch (Exception ex) { ex.printStackTrace(); } } for (int i = 0; i < instance.numAttributes(); i++) { if (instance.attribute(i).isString() && !instance.isMissing(i)) { m_tokenizer.tokenize(instance.stringValue(i)); while (m_tokenizer.hasMoreElements()) { String word = ((String)m_tokenizer.nextElement()).intern(); if (m_lowercaseTokens) { word = word.toLowerCase().intern(); } word = m_stemmer.stem(word); if (m_useStopList) { if (m_stopwords.is(word)) { continue; } } Count docCount = m_inputVector.get(word); if (docCount == null) { m_inputVector.put(word, new Count(instance.weight())); } else { docCount.m_count += instance.weight(); } if (updateDictionary) { Count count = m_dictionary.get(word); if (count == null) { m_dictionary.put(word, new Count(instance.weight())); } else { count.m_count += instance.weight(); } } } } } if (updateDictionary) { pruneDictionary(); } } protected void pruneDictionary() { if (m_periodicP <= 0 || m_t % m_periodicP > 0) { return; } Iterator<Map.Entry<String, Count>> entries = m_dictionary.entrySet().iterator(); while (entries.hasNext()) { Map.Entry<String, Count> entry = entries.next(); if (entry.getValue().m_count < m_minWordP) { entries.remove(); } } } public double[] distributionForInstance(Instance inst) throws Exception { double[] result = new double[2]; tokenizeInstance(inst, false); double wx = dotProd(m_inputVector); double z = (wx + m_bias); if (z <= 0) { if (m_loss == LOGLOSS) { result[0] = 1.0 / (1.0 + Math.exp(z)); result[1] = 1.0 - result[0]; } else { result[0] = 1; } } else { if (m_loss == LOGLOSS) { result[1] = 1.0 / (1.0 + Math.exp(-z)); result[0] = 1.0 - result[1]; } else { result[1] = 1; } } return result; } protected double dotProd(Map<String, Count> document) { double result = 0; // document normalization double iNorm = 0; double fv = 0; for (Count c : document.values()) { // word counts or bag-of-words? fv = (m_wordFrequencies) ? c.m_count : 1.0; iNorm += Math.pow(Math.abs(fv), m_lnorm); } iNorm = Math.pow(iNorm, 1.0 / m_lnorm); for (Map.Entry<String, Count> feature : document.entrySet()) { String word = feature.getKey(); double freq = (feature.getValue().m_count / iNorm * m_norm); Count weight = m_dictionary.get(word); if (weight != null && weight.m_count >= m_minWordP) { result += freq * weight.m_weight; } } return result; } public String toString() { if (m_dictionary == null) { return "SGDText: No model built yet.\n"; } StringBuffer buff = new StringBuffer(); buff.append("SGDText:\n\n"); buff.append("Loss function: "); if (m_loss == HINGE) { buff.append("Hinge loss (SVM)\n\n"); } else { buff.append("Log loss (logistic regression)\n\n"); } buff.append("Dictionary size: " + m_dictionary.size() + "\n\n"); buff.append(m_data.classAttribute().name() + " = \n\n"); int printed = 0; Iterator<Map.Entry<String, Count>> entries = m_dictionary.entrySet().iterator(); while (entries.hasNext()) { Map.Entry<String, Count> entry = entries.next(); if (printed > 0) { buff.append(" + "); } else { buff.append(" "); } buff.append(Utils.doubleToString(entry.getValue().m_weight, 12, 4) + " " + entry.getKey() + "\n"); printed++; } if (m_bias > 0) { buff.append(" + " + Utils.doubleToString(m_bias, 12, 4)); } else { buff.append(" - " + Utils.doubleToString(-m_bias, 12, 4)); } return buff.toString(); } /** * Returns the revision string. * * @return the revision */ public String getRevision() { return RevisionUtils.extract("$Revision: 7787 $"); } /** * Main method for testing this class. */ public static void main(String[] args) { runClassifier(new SGDText(), args); } }