// Stanford Classifier - a multiclass maxent classifier
// LinearClassifierFactory
// Copyright (c) 2003-2016 The Board of Trustees of
// The Leland Stanford Junior University. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 2
// of the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
//
// For more information, bug reports, fixes, contact:
// Christopher Manning
// Dept of Computer Science, Gates 1A
// Stanford CA 94305-9010
// USA
// Support/Questions: java-nlp-user@lists.stanford.edu
// Licensing: java-nlp-support@lists.stanford.edu
// http://www-nlp.stanford.edu/software/classifier.shtml
package edu.stanford.nlp.classify;
import java.io.BufferedReader;
import java.util.List;
import java.util.function.Function;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.*;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.stats.MultiClassAccuracyStats;
import edu.stanford.nlp.stats.Scorer;
import edu.stanford.nlp.util.*;
import edu.stanford.nlp.util.logging.Redwood;
/**
* Builds various types of linear classifiers, with functionality for
* setting objective function, optimization method, and other parameters.
* Classifiers can be defined with passed constructor arguments or using setter methods.
* Defaults to Quasi-newton optimization of a {@code LogConditionalObjectiveFunction}.
* (Merges old classes: CGLinearClassifierFactory, QNLinearClassifierFactory, and MaxEntClassifierFactory.)
* Note that a bias term is not assumed, and so if you want to learn
* a bias term you should add an "always-on" feature to your examples.
*
* @author Jenny Finkel
* @author Chris Cox (merged factories, 8/11/04)
* @author Dan Klein (CGLinearClassifierFactory, MaxEntClassifierFactory)
* @author Galen Andrew (tuneSigma)
* @author Marie-Catherine de Marneffe (CV in tuneSigma)
* @author Sarah Spikes (Templatization, though I don't know what to do with the Minimizer)
* @author Ramesh Nallapati (nmramesh@cs.stanford.edu) {@link #trainSemiSupGE} methods
*/
public class LinearClassifierFactory<L, F> extends AbstractLinearClassifierFactory<L, F> {

  private static final long serialVersionUID = 7893768984379107397L;

  // Convergence tolerance handed to the Minimizer; set by every constructor (typically 1e-4).
  private double TOL;
  //public double sigma;
  // History size ("memory") for quasi-Newton minimization; see setMem() for guidance.
  private int mem = 15;
  // When false (the default), freshly created minimizers are silenced via shutUp().
  private boolean verbose = false;
  //private int prior;
  //private double epsilon = 0.0;
  // Regularization prior (type + hyperparameters) applied to the training objective.
  private LogPrior logPrior;
  //private Minimizer<DiffFunction> minimizer;
  //private boolean useSum = false;
  // When true, sigma is tuned on a held-out split (70%/30%) before final training.
  private boolean tuneSigmaHeldOut = false;
  // When true, sigma is tuned by k-fold cross-validation (k = folds) before final training.
  private boolean tuneSigmaCV = false;
  //private boolean resetWeight = true;
  // Number of folds used when tuneSigmaCV is enabled; set by setTuneSigmaCV(int).
  private int folds;
  // range of values to tune sigma across
  private double min = 0.1;
  private double max = 10.0;
  // When true, interim weights found during sigma tuning are discarded and the final
  // model is retrained from scratch at the chosen sigma.
  private boolean retrainFromScratchAfterSigmaTuning = false;
  // Factory producing a fresh Minimizer per training run; non-null after any constructor runs.
  private Factory<Minimizer<DiffFunction>> minimizerCreator = null;
  // Evaluation interval and evaluators, installed on minimizers that implement HasEvaluators.
  private int evalIters = -1;
  private Evaluator[] evaluators; // = null;

  /** A logger for this class */
  private final static Redwood.RedwoodChannels logger = Redwood.channels(LinearClassifierFactory.class);

  /** This is the {@code Factory<Minimizer<DiffFunction>>} that we use over and over again. */
  private class QNFactory implements Factory<Minimizer<DiffFunction>> {

    private static final long serialVersionUID = 9028306475652690036L;

    @Override
    public Minimizer<DiffFunction> create() {
      // Read the enclosing factory's current mem setting each time a minimizer is created,
      // so later setMem() calls take effect.
      QNMinimizer qnMinimizer = new QNMinimizer(LinearClassifierFactory.this.mem);
      if (! verbose) {
        qnMinimizer.shutUp();
      }
      return qnMinimizer;
    }

  } // end class QNFactory
  /** Builds a factory with all defaults: quasi-Newton minimization, tol = 1e-4, quadratic prior, sigma = 1.0. */
  public LinearClassifierFactory() {
    this((Factory<Minimizer<DiffFunction>>) null);
  }

  /** NOTE: Constructors that take in a Minimizer create a LinearClassifierFactory that will reuse the minimizer
   *  and will not be threadsafe (unless the Minimizer itself is ThreadSafe, which is probably not the case).
   */
  public LinearClassifierFactory(Minimizer<DiffFunction> min) {
    this(min, 1e-4, false);
  }

  /** Builds a factory using the given minimizer factory with tol = 1e-4 and default prior. */
  public LinearClassifierFactory(Factory<Minimizer<DiffFunction>> min) {
    this(min, 1e-4, false);
  }

  public LinearClassifierFactory(Minimizer<DiffFunction> min, double tol, boolean useSum) {
    this(min, tol, useSum, 1.0);
  }

  public LinearClassifierFactory(Factory<Minimizer<DiffFunction>> min, double tol, boolean useSum) {
    this(min, tol, useSum, 1.0);
  }

  /** Uses the default quasi-Newton minimizer with the given tolerance and sigma. */
  public LinearClassifierFactory(double tol, boolean useSum, double sigma) {
    this((Factory<Minimizer<DiffFunction>>) null, tol, useSum, sigma);
  }

  public LinearClassifierFactory(Minimizer<DiffFunction> min, double tol, boolean useSum, double sigma) {
    this(min, tol, useSum, LogPrior.LogPriorType.QUADRATIC.ordinal(), sigma);
  }

  public LinearClassifierFactory(Factory<Minimizer<DiffFunction>> min, double tol, boolean useSum, double sigma) {
    this(min, tol, useSum, LogPrior.LogPriorType.QUADRATIC.ordinal(), sigma);
  }

  public LinearClassifierFactory(Minimizer<DiffFunction> min, double tol, boolean useSum, int prior, double sigma) {
    this(min, tol, useSum, prior, sigma, 0.0);
  }

  public LinearClassifierFactory(Factory<Minimizer<DiffFunction>> min, double tol, boolean useSum, int prior, double sigma) {
    this(min, tol, useSum, prior, sigma, 0.0);
  }

  public LinearClassifierFactory(double tol, boolean useSum, int prior, double sigma, double epsilon) {
    this((Factory<Minimizer<DiffFunction>>) null, tol, useSum, new LogPrior(prior, sigma, epsilon));
  }

  /** As above, but also sets the quasi-Newton memory parameter (see {@link #setMem(int)}). */
  public LinearClassifierFactory(double tol, boolean useSum, int prior, double sigma, double epsilon, final int mem) {
    this((Factory<Minimizer<DiffFunction>>) null, tol, useSum, new LogPrior(prior, sigma, epsilon));
    this.mem = mem;
  }
  /**
   * Create a factory that builds linear classifiers from training data.
   *
   * @param min The method to be used for optimization (minimization) (default: {@link QNMinimizer})
   * @param tol The convergence threshold for the minimization (default: 1e-4)
   * @param useSum Asks to the optimizer to minimize the sum of the
   *     likelihoods of individual data items rather than their product (default: false)
   *     NOTE: this is currently ignored!!!
   * @param prior What kind of prior to use, as an enum constant from class
   *     LogPrior
   * @param sigma The strength of the prior (smaller is stronger for most
   *     standard priors) (default: 1.0)
   * @param epsilon A second parameter to the prior (currently only used
   *     by the Huber prior)
   */
  public LinearClassifierFactory(Minimizer<DiffFunction> min, double tol, boolean useSum, int prior, double sigma, double epsilon) {
    this(min, tol, useSum, new LogPrior(prior, sigma, epsilon));
  }

  /** Same as the Minimizer-taking variant above, but takes a thread-safe minimizer factory. */
  public LinearClassifierFactory(Factory<Minimizer<DiffFunction>> min, double tol, boolean useSum, int prior, double sigma, double epsilon) {
    this(min, tol, useSum, new LogPrior(prior, sigma, epsilon));
  }
public LinearClassifierFactory(final Minimizer<DiffFunction> min, double tol, boolean useSum, LogPrior logPrior) {
this.minimizerCreator = new Factory<Minimizer<DiffFunction>>() {
private static final long serialVersionUID = -6439748445540743949L;
@Override
public Minimizer<DiffFunction> create() {
return min;
}
};
this.TOL = tol;
//this.useSum = useSum;
this.logPrior = logPrior;
}
/**
* Create a factory that builds linear classifiers from training data. This is the recommended constructor to
* bottom out with. Use of a minimizerCreator makes the classifier threadsafe.
*
* @param minimizerCreator A Factory for creating minimizers. If this is null, a standard quasi-Newton minimizer
* factory will be used.
* @param tol The convergence threshold for the minimization (default: 1e-4)
* @param useSum Asks to the optimizer to minimize the sum of the
* likelihoods of individual data items rather than their product (Klein and Manning 2001 WSD.)
* NOTE: this is currently ignored!!! At some point support for this option was deleted
* @param logPrior What kind of prior to use, this class specifies its type and hyperparameters.
*/
public LinearClassifierFactory(Factory<Minimizer<DiffFunction>> minimizerCreator, double tol, boolean useSum, LogPrior logPrior) {
if (minimizerCreator == null) {
this.minimizerCreator = new QNFactory();
} else {
this.minimizerCreator = minimizerCreator;
}
this.TOL = tol;
//this.useSum = useSum;
this.logPrior = logPrior;
}
  /**
   * Set the tolerance. 1e-4 is the default.
   */
  public void setTol(double tol) {
    this.TOL = tol;
  }

  /**
   * Set the prior.
   *
   * @param logPrior One of the priors defined in
   *              {@code LogConditionalObjectiveFunction}.
   *              {@code LogPrior.QUADRATIC} is the default.
   */
  public void setPrior(LogPrior logPrior) {
    this.logPrior = logPrior;
  }

  /**
   * Set the verbose flag for {@link CGMinimizer}.
   * {@code false} is the default.
   */
  public void setVerbose(boolean verbose) {
    this.verbose = verbose;
  }

  /**
   * Sets the minimizer. {@link QNMinimizer} is the default.
   */
  public void setMinimizerCreator(Factory<Minimizer<DiffFunction>> minimizerCreator) {
    this.minimizerCreator = minimizerCreator;
  }

  /**
   * Sets the epsilon value for {@link LogConditionalObjectiveFunction}.
   */
  public void setEpsilon(double eps) {
    logPrior.setEpsilon(eps);
  }

  /** Sets the prior strength sigma (smaller is a stronger prior for most standard priors). */
  public void setSigma(double sigma) {
    logPrior.setSigma(sigma);
  }

  /** Returns the current prior strength sigma. */
  public double getSigma() {
    return logPrior.getSigma();
  }
/**
* Sets the minimizer to QuasiNewton. {@link QNMinimizer} is the default.
*/
public void useQuasiNewton() {
this.minimizerCreator = new QNFactory();
}
public void useQuasiNewton(final boolean useRobust) {
this.minimizerCreator = new Factory<Minimizer<DiffFunction>>() {
private static final long serialVersionUID = -9108222058357693242L;
@Override
public Minimizer<DiffFunction> create() {
QNMinimizer qnMinimizer = new QNMinimizer(LinearClassifierFactory.this.mem, useRobust);
if (!verbose) {
qnMinimizer.shutUp();
}
return qnMinimizer;
}
};
}
public void useStochasticQN(final double initialSMDGain, final int stochasticBatchSize){
this.minimizerCreator = new Factory<Minimizer<DiffFunction>>() {
private static final long serialVersionUID = -7760753348350678588L;
@Override
public Minimizer<DiffFunction> create() {
SQNMinimizer<DiffFunction> sqnMinimizer = new SQNMinimizer<>(LinearClassifierFactory.this.mem, initialSMDGain, stochasticBatchSize, false);
if (!verbose) {
sqnMinimizer.shutUp();
}
return sqnMinimizer;
}
};
}
public void useStochasticMetaDescent(){
useStochasticMetaDescent(0.1, 15, StochasticCalculateMethods.ExternalFiniteDifference, 20);
}
public void useStochasticMetaDescent(final double initialSMDGain, final int stochasticBatchSize,
final StochasticCalculateMethods stochasticMethod,final int passes) {
this.minimizerCreator = new Factory<Minimizer<DiffFunction>>() {
private static final long serialVersionUID = 6860437108371914482L;
@Override
public Minimizer<DiffFunction> create() {
SMDMinimizer<DiffFunction> smdMinimizer = new SMDMinimizer<>(initialSMDGain, stochasticBatchSize, stochasticMethod, passes);
if (!verbose) {
smdMinimizer.shutUp();
}
return smdMinimizer;
}
};
}
public void useStochasticGradientDescent(){
useStochasticGradientDescent(0.1,15);
}
public void useStochasticGradientDescent(final double gainSGD, final int stochasticBatchSize){
this.minimizerCreator = new Factory<Minimizer<DiffFunction>>() {
private static final long serialVersionUID = 2564615420955196299L;
@Override
public Minimizer<DiffFunction> create() {
InefficientSGDMinimizer<DiffFunction> sgdMinimizer = new InefficientSGDMinimizer<>(gainSGD, stochasticBatchSize);
if (!verbose) {
sgdMinimizer.shutUp();
}
return sgdMinimizer;
}
};
}
public void useInPlaceStochasticGradientDescent() {
useInPlaceStochasticGradientDescent(-1, -1, 1.0);
}
public void useInPlaceStochasticGradientDescent(final int SGDPasses, final int tuneSampleSize, final double sigma) {
this.minimizerCreator = new Factory<Minimizer<DiffFunction>>() {
private static final long serialVersionUID = -5319225231759162616L;
@Override
public Minimizer<DiffFunction> create() {
SGDMinimizer<DiffFunction> sgdMinimizer = new SGDMinimizer<>(sigma, SGDPasses, tuneSampleSize);
if (!verbose) {
sgdMinimizer.shutUp();
}
return sgdMinimizer;
}
};
}
public void useHybridMinimizerWithInPlaceSGD(final int SGDPasses, final int tuneSampleSize, final double sigma) {
this.minimizerCreator = new Factory<Minimizer<DiffFunction>>() {
private static final long serialVersionUID = -3042400543337763144L;
@Override
public Minimizer<DiffFunction> create() {
SGDMinimizer<DiffFunction> firstMinimizer = new SGDMinimizer<>(sigma, SGDPasses, tuneSampleSize);
QNMinimizer secondMinimizer = new QNMinimizer(mem);
if (!verbose) {
firstMinimizer.shutUp();
secondMinimizer.shutUp();
}
return new HybridMinimizer(firstMinimizer, secondMinimizer, SGDPasses);
}
};
}
public void useStochasticGradientDescentToQuasiNewton(final double SGDGain, final int batchSize, final int sgdPasses,
final int qnPasses, final int hessSamples, final int QNMem,
final boolean outputToFile) {
this.minimizerCreator = new Factory<Minimizer<DiffFunction>>() {
private static final long serialVersionUID = 5823852936137599566L;
@Override
public Minimizer<DiffFunction> create() {
SGDToQNMinimizer sgdToQNMinimizer = new SGDToQNMinimizer(SGDGain, batchSize, sgdPasses,
qnPasses, hessSamples, QNMem, outputToFile);
if (!verbose) {
sgdToQNMinimizer.shutUp();
}
return sgdToQNMinimizer;
}
};
}
public void useHybridMinimizer() {
useHybridMinimizer(0.1, 15, StochasticCalculateMethods.ExternalFiniteDifference, 0);
}
public void useHybridMinimizer(final double initialSMDGain, final int stochasticBatchSize,
final StochasticCalculateMethods stochasticMethod, final int cutoffIteration){
this.minimizerCreator = () -> {
SMDMinimizer<DiffFunction> firstMinimizer = new SMDMinimizer<>(initialSMDGain, stochasticBatchSize, stochasticMethod, cutoffIteration);
QNMinimizer secondMinimizer = new QNMinimizer(mem);
if (!verbose) {
firstMinimizer.shutUp();
secondMinimizer.shutUp();
}
return new HybridMinimizer(firstMinimizer, secondMinimizer, cutoffIteration);
};
}
  /**
   * Set the mem value for {@link QNMinimizer}.
   * Only used with quasi-newton minimization. 15 is the default.
   *
   * @param mem Number of previous function/derivative evaluations to store
   *    to estimate second derivative.  Storing more previous evaluations
   *    improves training convergence speed.  This number can be very
   *    small, if memory conservation is the priority.  For large
   *    optimization systems (of 100,000-1,000,000 dimensions), setting this
   *    to 15 produces quite good results, but setting it to 50 can
   *    decrease the iteration count by about 20% over a value of 15.
   */
  public void setMem(int mem) {
    this.mem = mem;
  }
/**
* Sets the minimizer to {@link CGMinimizer}, with the passed {@code verbose} flag.
*/
public void useConjugateGradientAscent(boolean verbose) {
this.verbose = verbose;
useConjugateGradientAscent();
}
/**
* Sets the minimizer to {@link CGMinimizer}.
*/
public void useConjugateGradientAscent() {
this.minimizerCreator = new Factory<Minimizer<DiffFunction>>() {
private static final long serialVersionUID = -561168861131879990L;
@Override
public Minimizer<DiffFunction> create() {
return new CGMinimizer(!LinearClassifierFactory.this.verbose);
}
};
}
  /**
   * NOTE: nothing is actually done with this value!
   *
   * SetUseSum sets the {@code useSum} flag: when turned on,
   * the Summed Conditional Objective Function is used.  Otherwise, the
   * LogConditionalObjectiveFunction is used.  The default is false.
   */
  public void setUseSum(boolean useSum) {
    // Intentionally a no-op: support for the summed objective was removed,
    // but the setter is kept so existing callers still compile.
    //this.useSum = useSum;
  }
  /**
   * Creates a fresh Minimizer from the current creator, installing the configured
   * evaluators on it when the minimizer supports them.
   */
  private Minimizer<DiffFunction> getMinimizer() {
    // Create a new minimizer
    Minimizer<DiffFunction> minimizer = minimizerCreator.create();
    if (minimizer instanceof HasEvaluators) {
      // evalIters defaults to -1; semantics of a negative interval are up to the minimizer.
      ((HasEvaluators) minimizer).setEvaluators(evalIters, evaluators);
    }
    return minimizer;
  }
/**
* Adapt classifier (adjust the mean of Gaussian prior).
* Under construction -pichuan
*
* @param origWeights the original weights trained from the training data
* @param adaptDataset the Dataset used to adapt the trained weights
* @return adapted weights
*/
public double[][] adaptWeights(double[][] origWeights, GeneralDataset<L, F> adaptDataset) {
Minimizer<DiffFunction> minimizer = getMinimizer();
logger.info("adaptWeights in LinearClassifierFactory. increase weight dim only");
double[][] newWeights = new double[adaptDataset.featureIndex.size()][adaptDataset.labelIndex.size()];
synchronized (System.class) {
System.arraycopy(origWeights, 0, newWeights, 0, origWeights.length);
}
AdaptedGaussianPriorObjectiveFunction<L, F> objective = new AdaptedGaussianPriorObjectiveFunction<>(adaptDataset, logPrior, newWeights);
double[] initial = objective.initial();
double[] weights = minimizer.minimize(objective, TOL, initial);
return objective.to2D(weights);
//Question: maybe the adaptWeights can be done just in LinearClassifier ?? (pichuan)
}
  /** Trains weights on the dataset with no initial weights; delegates to the three-argument form. */
  @Override
  public double[][] trainWeights(GeneralDataset<L, F> dataset) {
    return trainWeights(dataset, null);
  }

  /** Trains weights starting from {@code initial}, honoring any sigma-tuning flags that are set. */
  public double[][] trainWeights(GeneralDataset<L, F> dataset, double[] initial) {
    return trainWeights(dataset, initial, false);
  }
  /**
   * Trains classifier weights on the dataset.
   *
   * @param dataset the training data
   * @param initial starting weights (flattened), or null to use the objective's default start
   * @param bypassTuneSigma if true, skip sigma tuning even when a tuning flag is set
   *     (used internally by the tuning routines themselves to avoid infinite recursion)
   * @return the trained weights as a 2-D array
   */
  public double[][] trainWeights(GeneralDataset<L, F> dataset, double[] initial, boolean bypassTuneSigma) {
    Minimizer<DiffFunction> minimizer = getMinimizer();
    if(dataset instanceof RVFDataset)
      ((RVFDataset<L,F>)dataset).ensureRealValues();
    double[] interimWeights = null;
    if(! bypassTuneSigma) {
      if (tuneSigmaHeldOut) {
        interimWeights = heldOutSetSigma(dataset); // the optimum interim weights from held-out training data have already been found.
      } else if (tuneSigmaCV) {
        crossValidateSetSigma(dataset,folds); // TODO: assign optimum interim weights as part of this process.
      }
    }
    LogConditionalObjectiveFunction<L, F> objective = new LogConditionalObjectiveFunction<>(dataset, logPrior);
    // Warm-start from the tuning run's weights unless the caller supplied its own initial
    // weights or retraining from scratch was requested.
    if(initial == null && interimWeights != null && ! retrainFromScratchAfterSigmaTuning) {
      //logger.info("## taking advantage of interim weights as starting point.");
      initial = interimWeights;
    }
    if (initial == null) {
      initial = objective.initial();
    }
    double[] weights = minimizer.minimize(objective, TOL, initial);
    return objective.to2D(weights);
  }
/**
* IMPORTANT: dataset and biasedDataset must have same featureIndex, labelIndex
*/
public Classifier<L, F> trainClassifierSemiSup(GeneralDataset<L, F> data, GeneralDataset<L, F> biasedData, double[][] confusionMatrix, double[] initial) {
double[][] weights = trainWeightsSemiSup(data, biasedData, confusionMatrix, initial);
LinearClassifier<L, F> classifier = new LinearClassifier<>(weights, data.featureIndex(), data.labelIndex());
return classifier;
}
  /**
   * Trains weights by combining a plain conditional objective on {@code data} with a biased
   * conditional objective on {@code biasedData} (via the given confusion matrix), wrapped in a
   * semi-supervised objective regularized by this factory's prior.
   *
   * @param initial starting weights (flattened), or null to use the objective's default start
   * @return the trained weights as a 2-D array
   */
  public double[][] trainWeightsSemiSup(GeneralDataset<L, F> data, GeneralDataset<L, F> biasedData, double[][] confusionMatrix, double[] initial) {
    Minimizer<DiffFunction> minimizer = getMinimizer();
    // The component objectives carry NULL priors; regularization is applied once, by the
    // combined semi-supervised objective below.
    LogConditionalObjectiveFunction<L, F> objective = new LogConditionalObjectiveFunction<>(data, new LogPrior(LogPrior.LogPriorType.NULL));
    BiasedLogConditionalObjectiveFunction biasedObjective = new BiasedLogConditionalObjectiveFunction(biasedData, confusionMatrix, new LogPrior(LogPrior.LogPriorType.NULL));
    SemiSupervisedLogConditionalObjectiveFunction semiSupObjective = new SemiSupervisedLogConditionalObjectiveFunction(objective, biasedObjective, logPrior);
    if (initial == null) {
      initial = objective.initial();
    }
    double[] weights = minimizer.minimize(semiSupObjective, TOL, initial);
    return objective.to2D(weights);
  }
  /**
   * Trains the linear classifier using Generalized Expectation criteria as described in
   * <tt>Generalized Expectation Criteria for Semi Supervised Learning of Conditional Random Fields</tt>, Mann and McCallum, ACL 2008.
   * The original algorithm is proposed for CRFs but has been adopted to LinearClassifier (which is a simpler special case of a CRF).
   * IMPORTANT: the labeled features that are passed as an argument are assumed to be binary valued, although
   * other features are allowed to be real valued.
   *
   * @param labeledDataset supervised training data
   * @param unlabeledDataList unlabeled data used for the GE term
   * @param GEFeatures the (binary) features used as GE constraints
   * @param convexComboCoeff mixing weight between the supervised and GE objectives
   */
  public LinearClassifier<L,F> trainSemiSupGE(GeneralDataset<L, F> labeledDataset, List<? extends Datum<L, F>> unlabeledDataList, List<F> GEFeatures, double convexComboCoeff) {
    Minimizer<DiffFunction> minimizer = getMinimizer();
    // NULL prior on the supervised term; the combined objective below handles the mix.
    LogConditionalObjectiveFunction<L, F> objective = new LogConditionalObjectiveFunction<>(labeledDataset, new LogPrior(LogPrior.LogPriorType.NULL));
    GeneralizedExpectationObjectiveFunction<L,F> geObjective = new GeneralizedExpectationObjectiveFunction<>(labeledDataset, unlabeledDataList, GEFeatures);
    SemiSupervisedLogConditionalObjectiveFunction semiSupObjective = new SemiSupervisedLogConditionalObjectiveFunction(objective, geObjective, null,convexComboCoeff);
    double[] initial = objective.initial();
    double[] weights = minimizer.minimize(semiSupObjective, TOL, initial);
    return new LinearClassifier<>(objective.to2D(weights), labeledDataset.featureIndex(), labeledDataset.labelIndex());
  }
/**
* Trains the linear classifier using Generalized Expectation criteria as described in
* <tt>Generalized Expectation Criteria for Semi Supervised Learning of Conditional Random Fields</tt>, Mann and McCallum, ACL 2008.
* The original algorithm is proposed for CRFs but has been adopted to LinearClassifier (which is a simpler, special case of a CRF).
* Automatically discovers high precision, high frequency labeled features to be used as GE constraints.
* IMPORTANT: the current feature selector assumes the features are binary. The GE constraints assume the constraining features are binary anyway, although
* it doesn't make such assumptions about other features.
*/
public LinearClassifier<L,F> trainSemiSupGE(GeneralDataset<L, F> labeledDataset, List<? extends Datum<L, F>> unlabeledDataList) {
List<F> GEFeatures = getHighPrecisionFeatures(labeledDataset,0.9,10);
return trainSemiSupGE(labeledDataset, unlabeledDataList, GEFeatures, 0.5);
}
public LinearClassifier<L,F> trainSemiSupGE(GeneralDataset<L, F> labeledDataset, List<? extends Datum<L, F>> unlabeledDataList, double convexComboCoeff) {
List<F> GEFeatures = getHighPrecisionFeatures(labeledDataset,0.9,10);
return trainSemiSupGE(labeledDataset, unlabeledDataList, GEFeatures, convexComboCoeff);
}
/**
* Returns a list of featured thresholded by minPrecision and sorted by their frequency of occurrence.
* precision in this case, is defined as the frequency of majority label over total frequency for that feature.
*
* @return list of high precision features.
*/
private List<F> getHighPrecisionFeatures(GeneralDataset<L,F> dataset, double minPrecision, int maxNumFeatures){
int[][] feature2label = new int[dataset.numFeatures()][dataset.numClasses()];
// shouldn't be necessary as Java zero fills arrays
// for(int f = 0; f < dataset.numFeatures(); f++)
// Arrays.fill(feature2label[f],0);
int[][] data = dataset.data;
int[] labels = dataset.labels;
for(int d = 0; d < data.length; d++){
int label = labels[d];
//System.out.println("datum id:"+d+" label id: "+label);
if(data[d] != null){
//System.out.println(" number of features:"+data[d].length);
for(int n = 0; n < data[d].length; n++){
feature2label[data[d][n]][label]++;
}
}
}
Counter<F> feature2freq = new ClassicCounter<>();
for(int f = 0; f < dataset.numFeatures(); f++){
int maxF = ArrayMath.max(feature2label[f]);
int total = ArrayMath.sum(feature2label[f]);
double precision = ((double)maxF)/total;
F feature = dataset.featureIndex.get(f);
if(precision >= minPrecision){
feature2freq.incrementCount(feature, total);
}
}
if(feature2freq.size() > maxNumFeatures){
Counters.retainTop(feature2freq, maxNumFeatures);
}
//for(F feature : feature2freq.keySet())
//System.out.println(feature+" "+feature2freq.getCount(feature));
//System.exit(0);
return Counters.toSortedList(feature2freq);
}
  /**
   * Train a classifier with a sigma tuned on a validation set.
   *
   * @param train the training data
   * @param validation the held-out data used to tune sigma
   * @param min lower bound of the sigma search range
   * @param max upper bound of the sigma search range
   * @param accuracy NOTE(review): this parameter appears unused in this method -- confirm
   * @return The constructed classifier
   */
  public LinearClassifier<L, F> trainClassifierV(GeneralDataset<L, F> train, GeneralDataset<L, F> validation, double min, double max, boolean accuracy) {
    labelIndex = train.labelIndex();
    featureIndex = train.featureIndex();
    this.min = min;
    this.max = max;
    // Tune sigma on the explicit validation set, then train final weights on all of train.
    heldOutSetSigma(train, validation);
    double[][] weights = trainWeights(train);
    return new LinearClassifier<>(weights, train.featureIndex(), train.labelIndex());
  }
  /**
   * Train a classifier with a sigma tuned on a validation set.
   * In this case we are fitting on the last 30% of the training data.
   *
   * @param train The data to train (and validate) on.
   * @param min lower bound of the sigma search range
   * @param max upper bound of the sigma search range
   * @param accuracy NOTE(review): this parameter appears unused in this method -- confirm
   * @return The constructed classifier
   */
  public LinearClassifier<L, F> trainClassifierV(GeneralDataset<L, F> train, double min, double max, boolean accuracy) {
    labelIndex = train.labelIndex();
    featureIndex = train.featureIndex();
    // Side effect: leaves held-out tuning enabled on this factory for subsequent calls.
    tuneSigmaHeldOut = true;
    this.min = min;
    this.max = max;
    heldOutSetSigma(train);
    double[][] weights = trainWeights(train);
    return new LinearClassifier<>(weights, train.featureIndex(), train.labelIndex());
  }
  /**
   * setTuneSigmaHeldOut sets the {@code tuneSigmaHeldOut} flag: when turned on,
   * the sigma is tuned by means of held-out (70%-30%). Otherwise no tuning on sigma is done.
   * The default is false.
   */
  public void setTuneSigmaHeldOut() {
    // The two tuning modes are mutually exclusive.
    tuneSigmaHeldOut = true;
    tuneSigmaCV = false;
  }

  /**
   * setTuneSigmaCV sets the {@code tuneSigmaCV} flag: when turned on,
   * the sigma is tuned by cross-validation. The number of folds is the parameter.
   * If there is less data than the number of folds, leave-one-out is used.
   * The default is false.
   */
  public void setTuneSigmaCV(int folds) {
    // The two tuning modes are mutually exclusive.
    tuneSigmaCV = true;
    tuneSigmaHeldOut = false;
    this.folds = folds;
  }
  /**
   * NOTE: Nothing is actually done with this value.
   *
   * resetWeight sets the {@code restWeight} flag. This flag makes sense only if sigma is tuned:
   * when turned on, the weights output by the tuneSigma method will be reset to zero when training the
   * classifier.
   * The default is false.
   */
  public void resetWeight() {
    // Intentionally a no-op: the resetWeight field was removed, but the method is kept
    // so existing callers still compile.
    //resetWeight = true;
  }

  // Candidate sigma values for grid-style tuning.
  // NOTE(review): this array appears unreferenced in the visible portion of this file -- confirm.
  protected static final double[] sigmasToTry = {0.5,1.0,2.0,4.0,10.0, 20.0, 100.0};
  /**
   * Calls the method {@link #crossValidateSetSigma(GeneralDataset, int)} with 5-fold cross-validation.
   * @param dataset the data set to optimize sigma on.
   */
  public void crossValidateSetSigma(GeneralDataset<L, F> dataset) {
    crossValidateSetSigma(dataset, 5);
  }

  /**
   * Calls the method {@link #crossValidateSetSigma(GeneralDataset, int, Scorer, LineSearcher)} with
   * multi-class log-likelihood scoring (see {@link MultiClassAccuracyStats}) and golden-section line search
   * (see {@link GoldenSectionLineSearch}).
   *
   * @param dataset the data set to optimize sigma on.
   */
  public void crossValidateSetSigma(GeneralDataset<L, F> dataset,int kfold) {
    logger.info("##you are here.");
    crossValidateSetSigma(dataset, kfold, new MultiClassAccuracyStats<>(MultiClassAccuracyStats.USE_LOGLIKELIHOOD), new GoldenSectionLineSearch(true, 1e-2, min, max));
  }

  /** As above, but with a caller-supplied scorer and the default golden-section line search. */
  public void crossValidateSetSigma(GeneralDataset<L, F> dataset,int kfold, final Scorer<L> scorer) {
    crossValidateSetSigma(dataset, kfold, scorer, new GoldenSectionLineSearch(true, 1e-2, min, max));
  }

  /** As above, but with a caller-supplied line searcher and the default log-likelihood scorer. */
  public void crossValidateSetSigma(GeneralDataset<L, F> dataset,int kfold, LineSearcher minimizer) {
    crossValidateSetSigma(dataset, kfold, new MultiClassAccuracyStats<>(MultiClassAccuracyStats.USE_LOGLIKELIHOOD), minimizer);
  }
  /**
   * Sets the sigma parameter to a value that optimizes the cross-validation score given by {@code scorer}.  Search for an optimal value
   * is carried out by {@code minimizer}.
   *
   * @param dataset the data set to optimize sigma on.
   * @param kfold number of cross-validation folds
   * @param scorer scores a trained classifier on a held-out fold
   * @param minimizer line searcher over candidate sigma values
   */
  public void crossValidateSetSigma(GeneralDataset<L, F> dataset,int kfold, final Scorer<L> scorer, LineSearcher minimizer) {
    logger.info("##in Cross Validate, folds = " + kfold);
    logger.info("##Scorer is " + scorer);
    featureIndex = dataset.featureIndex;
    labelIndex = dataset.labelIndex;
    final CrossValidator<L, F> crossValidator = new CrossValidator<>(dataset, kfold);
    // Scores one fold: trains on the fold's training portion (warm-started from the fold's
    // saved weights), caches the new weights back into the fold state, then scores on dev.
    final Function<Triple<GeneralDataset<L, F>,GeneralDataset<L, F>,CrossValidator.SavedState>,Double> scoreFn =
        fold -> {
          GeneralDataset<L, F> trainSet = fold.first();
          GeneralDataset<L, F> devSet = fold.second();
          double[] weights = (double[])fold.third().state;
          double[][] weights2D;
          weights2D = trainWeights(trainSet, weights,true); // must of course bypass sigma tuning here.
          fold.third().state = ArrayUtils.flatten(weights2D);
          LinearClassifier<L, F> classifier = new LinearClassifier<>(weights2D, trainSet.featureIndex, trainSet.labelIndex);
          double score = scorer.score(classifier, devSet);
          //System.out.println("score: "+score);
          System.out.print(".");
          return score;
        };
    // Negated average CV score as a function of sigma, for the line-search minimizer.
    Function<Double,Double> negativeScorer =
        sigmaToTry -> {
          //sigma = sigmaToTry;
          setSigma(sigmaToTry);
          Double averageScore = crossValidator.computeAverage(scoreFn);
          logger.info("##sigma = "+getSigma()+" -> average Score: "+ averageScore);
          return -averageScore;
        };
    double bestSigma = minimizer.minimize(negativeScorer);
    logger.info("##best sigma: " + bestSigma);
    setSigma(bestSigma);
  }
  /**
   * Set the {@link LineSearcher} to be used in {@link #heldOutSetSigma(GeneralDataset, GeneralDataset)}.
   */
  public void setHeldOutSearcher(LineSearcher heldOutSearcher) {
    this.heldOutSearcher = heldOutSearcher;
  }

  // Optional custom line searcher for held-out sigma tuning; when null, a
  // GoldenSectionLineSearch over [min, max] is used.
  private LineSearcher heldOutSearcher; // = null;

  /** Tunes sigma on a 70%/30% split of {@code train}; returns interim weights from tuning. */
  public double[] heldOutSetSigma(GeneralDataset<L, F> train) {
    Pair<GeneralDataset<L, F>, GeneralDataset<L, F>> data = train.split(0.3);
    return heldOutSetSigma(data.first(), data.second());
  }

  /** As above, with a caller-supplied scorer. */
  public double[] heldOutSetSigma(GeneralDataset<L, F> train, Scorer<L> scorer) {
    Pair<GeneralDataset<L, F>, GeneralDataset<L, F>> data = train.split(0.3);
    return heldOutSetSigma(data.first(), data.second(), scorer);
  }

  /** Tunes sigma on explicit train/dev sets with default log-likelihood scoring. */
  public double[] heldOutSetSigma(GeneralDataset<L, F> train, GeneralDataset<L, F> dev) {
    return heldOutSetSigma(train, dev, new MultiClassAccuracyStats<>(MultiClassAccuracyStats.USE_LOGLIKELIHOOD), heldOutSearcher == null ? new GoldenSectionLineSearch(true, 1e-2, min, max) : heldOutSearcher);
  }

  /** As above, with a caller-supplied scorer and the default golden-section line search. */
  public double[] heldOutSetSigma(GeneralDataset<L, F> train, GeneralDataset<L, F> dev, final Scorer<L> scorer) {
    return heldOutSetSigma(train, dev, scorer, new GoldenSectionLineSearch(true, 1e-2, min, max));
  }
  /** As below, with a caller-supplied line searcher and the default log-likelihood scorer. */
  public double[] heldOutSetSigma(GeneralDataset<L, F> train, GeneralDataset<L, F> dev, LineSearcher minimizer) {
    return heldOutSetSigma(train, dev, new MultiClassAccuracyStats<>(MultiClassAccuracyStats.USE_LOGLIKELIHOOD), minimizer);
  }

  /**
   * Sets the sigma parameter to a value that optimizes the held-out score given by {@code scorer}.  Search for an
   * optimal value is carried out by {@code minimizer} dataset the data set to optimize sigma on. kfold
   *
   * @return an interim set of optimal weights: the weights
   */
  public double[] heldOutSetSigma(final GeneralDataset<L, F> trainSet, final GeneralDataset<L, F> devSet, final Scorer<L> scorer, LineSearcher minimizer) {
    featureIndex = trainSet.featureIndex;
    labelIndex = trainSet.labelIndex;
    //double[] resultWeights = null;
    Timing timer = new Timing();
    // NegativeScorer trains at each trial sigma and remembers the last weights it produced.
    NegativeScorer negativeScorer = new NegativeScorer(trainSet,devSet,scorer,timer);
    timer.start();
    double bestSigma = minimizer.minimize(negativeScorer);
    logger.info("##best sigma: " + bestSigma);
    setSigma(bestSigma);
    // Retrain once at the best sigma (bypassing tuning) so the returned weights really
    // correspond to the chosen sigma, not to the last trial the line search happened to run.
    return ArrayUtils.flatten(trainWeights(trainSet,negativeScorer.weights,true)); // make sure it's actually the interim weights from best sigma
  }
/**
 * Adapter exposing the (negated) held-out classifier score as a function of sigma,
 * so that a LineSearcher — which minimizes — effectively maximizes the held-out score.
 */
class NegativeScorer implements Function<Double, Double> {
  /** Interim flattened weights from the most recent training run; reused to warm-start the next run. */
  public double[] weights; // = null;
  GeneralDataset<L, F> trainSet;
  GeneralDataset<L, F> devSet;
  Scorer<L> scorer;
  Timing timer;

  public NegativeScorer(GeneralDataset<L, F> trainSet, GeneralDataset<L, F> devSet, Scorer<L> scorer, Timing timer) {
    super();
    this.trainSet = trainSet;
    this.devSet = devSet;
    this.scorer = scorer;
    this.timer = timer;
  }

  /**
   * Trains weights at the given sigma (warm-started from the previous call), scores the
   * resulting classifier on the dev set, and returns the negated score.
   */
  @Override
  public Double apply(Double sigmaToTry) {
    setSigma(sigmaToTry);
    final double[][] trained = trainWeights(trainSet, weights, true); //bypass.
    weights = ArrayUtils.flatten(trained);
    final LinearClassifier<L, F> candidate = new LinearClassifier<>(trained, trainSet.featureIndex, trainSet.labelIndex);
    final double heldOutScore = scorer.score(candidate, devSet);
    logger.info("##sigma = " + getSigma() + " -> average Score: " + heldOutScore);
    logger.info("##time elapsed: " + timer.stop() + " milliseconds.");
    timer.restart();
    return -heldOutScore;
  }
}
/** If set to true, then when training a classifier, after an optimal sigma is chosen a model is relearned from
 * scratch. If set to false (the default), then the model is updated from wherever it wound up in the sigma-tuning process.
 * The latter is likely to be faster, but it's not clear which model will wind up better.
 *
 * @param retrainFromScratchAfterSigmaTuning whether to discard interim weights and retrain from scratch after sigma tuning
 */
public void setRetrainFromScratchAfterSigmaTuning( boolean retrainFromScratchAfterSigmaTuning) {
this.retrainFromScratchAfterSigmaTuning = retrainFromScratchAfterSigmaTuning;
}
/**
 * Trains a linear classifier over an iterable of datums, building fresh feature and
 * label indices from the data before optimizing a log-conditional objective.
 */
public Classifier<L, F> trainClassifier(Iterable<Datum<L, F>> dataIterable) {
  Minimizer<DiffFunction> minimizer = getMinimizer();
  Index<F> featureIndex = Generics.newIndex();
  Index<L> labelIndex = Generics.newIndex();
  for (Datum<L, F> datum : dataIterable) {
    labelIndex.add(datum.label());
    // Index ignores duplicates, so repeated features are added once.
    featureIndex.addAll(datum.asFeatures());
  }
  logger.info(String.format("Training linear classifier with %d features and %d labels", featureIndex.size(), labelIndex.size()));
  LogConditionalObjectiveFunction<L, F> objective = new LogConditionalObjectiveFunction<>(dataIterable, logPrior, featureIndex, labelIndex);
  double[] optimum = minimizer.minimize(objective, TOL, objective.initial());
  return new LinearClassifier<>(objective.to2D(optimum), featureIndex, labelIndex);
}
/**
 * Trains a linear classifier on the dataset with per-datum weights and an explicit prior,
 * using the dataset's existing feature and label indices.
 */
public Classifier<L, F> trainClassifier(GeneralDataset<L, F> dataset, float[] dataWeights, LogPrior prior) {
  Minimizer<DiffFunction> minimizer = getMinimizer();
  // Real-valued datasets must have their values validated before training.
  if (dataset instanceof RVFDataset) {
    ((RVFDataset<L, F>) dataset).ensureRealValues();
  }
  LogConditionalObjectiveFunction<L, F> objective = new LogConditionalObjectiveFunction<>(dataset, dataWeights, prior);
  double[] optimum = minimizer.minimize(objective, TOL, objective.initial());
  return new LinearClassifier<>(objective.to2D(optimum), dataset.featureIndex(), dataset.labelIndex());
}
/**
 * Trains a {@link LinearClassifier} on the given dataset, starting optimization from
 * the objective's default initial point (no initial weights supplied).
 */
@Override
public LinearClassifier<L, F> trainClassifier(GeneralDataset<L, F> dataset) {
return trainClassifier(dataset, null);
}
/**
 * Trains a {@link LinearClassifier} on the dataset, optionally warm-starting from
 * {@code initial} flattened weights.
 *
 * @param initial flattened initial weights, or null to use the objective's default start
 * @throws IllegalArgumentException if any initial weight is NaN or infinite
 */
public LinearClassifier<L, F> trainClassifier(GeneralDataset<L, F> dataset, double[] initial) {
  // Sanity check: real-valued datasets must have their values validated first.
  if (dataset instanceof RVFDataset) {
    ((RVFDataset<L, F>) dataset).ensureRealValues();
  }
  if (initial != null) {
    for (double w : initial) {
      if (!Double.isFinite(w)) {
        throw new IllegalArgumentException("Initial weights are invalid!");
      }
    }
  }
  // Train classifier
  double[][] trainedWeights = trainWeights(dataset, initial, false);
  return new LinearClassifier<>(trainedWeights, dataset.featureIndex(), dataset.labelIndex());
}
/**
 * Trains a classifier warm-started from 2D initial weights, which are flattened
 * before training; a null argument means no warm start.
 */
public LinearClassifier<L, F> trainClassifierWithInitialWeights(GeneralDataset<L, F> dataset, double[][] initialWeights2D) {
  if (initialWeights2D == null) {
    return trainClassifier(dataset, null);
  }
  return trainClassifier(dataset, ArrayUtils.flatten(initialWeights2D));
}
/**
 * Trains a classifier warm-started from another classifier's weights;
 * a null classifier means no warm start.
 */
public LinearClassifier<L, F> trainClassifierWithInitialWeights(GeneralDataset<L, F> dataset, LinearClassifier<L,F> initialClassifier) {
  double[][] seedWeights = null;
  if (initialClassifier != null) {
    seedWeights = initialClassifier.weights();
  }
  return trainClassifierWithInitialWeights(dataset, seedWeights);
}
/**
 * Given the path to a file representing the text based serialization of a
 * Linear Classifier, reconstitutes and returns that LinearClassifier.
 *
 * TODO: Leverage Index
 *
 * @param file path to the text-serialized classifier
 * @return the reconstituted classifier
 * @throws RuntimeIOException if the file cannot be read or its contents are malformed
 */
public static LinearClassifier<String, String> loadFromFilename(String file) {
  // try-with-resources: the original closed the reader only on the success path,
  // leaking it whenever parsing threw.
  try (BufferedReader in = IOUtils.readerFromString(file)) {
    // Format: read indices first, weights, then thresholds
    Index<String> labelIndex = HashIndex.loadFromReader(in);
    Index<String> featureIndex = HashIndex.loadFromReader(in);
    double[][] weights = new double[featureIndex.size()][labelIndex.size()];
    int currLine = 1;
    String line = in.readLine();
    while (line != null && line.length() > 0) {
      String[] tuples = line.split(LinearClassifier.TEXT_SERIALIZATION_DELIMITER);
      if (tuples.length != 3) {
        // Unchecked (was a raw checked Exception); still caught below and wrapped.
        throw new RuntimeException("Error: incorrect number of tokens in weight specifier, line="
            + currLine + " in file " + file);
      }
      currLine++;
      int feature = Integer.parseInt(tuples[0]);
      int label = Integer.parseInt(tuples[1]);
      double value = Double.parseDouble(tuples[2]);
      weights[feature][label] = value;
      line = in.readLine();
    }
    // First line in thresholds is the number of thresholds
    int numThresholds = Integer.parseInt(in.readLine());
    // NOTE(review): thresholds are parsed to consume the serialized format but are
    // currently not passed to the classifier — confirm whether this is intentional.
    double[] thresholds = new double[numThresholds];
    int curr = 0;
    while ((line = in.readLine()) != null) {
      double tval = Double.parseDouble(line.trim());
      thresholds[curr++] = tval;
    }
    return new LinearClassifier<>(weights, featureIndex, labelIndex);
  } catch (Exception e) {
    throw new RuntimeIOException("Error in LinearClassifierFactory, loading from file=" + file, e);
  }
}
/**
 * Sets evaluators to be run periodically during training.
 *
 * @param iters run the evaluators every this many iterations
 * @param evaluators the evaluators to run
 */
public void setEvaluators(int iters, Evaluator[] evaluators) {
this.evalIters = iters;
this.evaluators = evaluators;
}
/**
 * Returns a factory that builds LinearClassifiers over this dataset's feature and label
 * indices from flat weight vectors (e.g. points produced by an optimizer).
 */
public LinearClassifierCreator<L,F> getClassifierCreator(GeneralDataset<L, F> dataset) {
// LogConditionalObjectiveFunction<L, F> objective = new LogConditionalObjectiveFunction<L, F>(dataset, logPrior);
return new LinearClassifierCreator<>(dataset.featureIndex, dataset.labelIndex);
}
/**
 * Creates {@link LinearClassifier} instances from flat weight vectors, e.g. points
 * produced by an optimizer. If an objective function is supplied, its {@code to2D}
 * mapping reshapes the weights; otherwise a dense featureIndex x labelIndex layout
 * is assumed.
 */
public static class LinearClassifierCreator<L,F> implements ClassifierCreator, ProbabilisticClassifierCreator
{
  LogConditionalObjectiveFunction objective;
  Index<F> featureIndex;
  Index<L> labelIndex;

  public LinearClassifierCreator(LogConditionalObjectiveFunction objective, Index<F> featureIndex, Index<L> labelIndex)
  {
    this.objective = objective;
    this.featureIndex = featureIndex;
    this.labelIndex = labelIndex;
  }

  public LinearClassifierCreator(Index<F> featureIndex, Index<L> labelIndex)
  {
    this.featureIndex = featureIndex;
    this.labelIndex = labelIndex;
  }

  /**
   * Reshapes {@code weights} to 2D and wraps them in a LinearClassifier.
   * Return type is now parameterized (was raw), so callers keep type information;
   * raw assignment still compiles, making this backward-compatible.
   */
  public LinearClassifier<L, F> createLinearClassifier(double[] weights) {
    double[][] weights2D;
    if (objective != null) {
      weights2D = objective.to2D(weights);
    } else {
      weights2D = ArrayUtils.to2D(weights, featureIndex.size(), labelIndex.size());
    }
    return new LinearClassifier<>(weights2D, featureIndex, labelIndex);
  }

  @Override
  public Classifier createClassifier(double[] weights) {
    return createLinearClassifier(weights);
  }

  @Override
  public ProbabilisticClassifier createProbabilisticClassifier(double[] weights) {
    return createLinearClassifier(weights);
  }
}
}