// CRFClassifier -- a probabilistic (CRF) sequence model, mainly used for NER.
// Copyright (c) 2002-2008 The Board of Trustees of
// The Leland Stanford Junior University. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 2
// of the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
//
// For more information, bug reports, fixes, contact:
// Christopher Manning
// Dept of Computer Science, Gates 1A
// Stanford CA 94305-9010
// USA
// Support/Questions: java-nlp-user@lists.stanford.edu
// Licensing: java-nlp-support@lists.stanford.edu
package edu.stanford.nlp.ie.crf;
import edu.stanford.nlp.util.logging.Redwood;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.*;
import edu.stanford.nlp.sequences.*;
import edu.stanford.nlp.util.*;
import java.io.*;
import java.util.*;
import java.util.zip.GZIPInputStream;
/**
 * Subclass of {@link CRFClassifier} whose {@link #trainWeights} optimizes a single-precision
 * ({@code float}) objective, {@link CRFLogConditionalObjectiveFloatFunction}, instead of the
 * superclass's default objective. The learned {@code float} weights are widened to
 * {@code double[]} before being returned.
 *
 * @author Mengqiu Wang
 */
public class CRFClassifierFloat<IN extends CoreMap> extends CRFClassifier<IN> {

  /** A logger for this class */
  private static final Redwood.RedwoodChannels log = Redwood.channels(CRFClassifierFloat.class);

  /** Creates a classifier configured with a fresh default {@link SeqClassifierFlags}. */
  protected CRFClassifierFloat() {
    super(new SeqClassifierFlags());
  }

  /**
   * Creates a classifier configured from the given properties.
   *
   * @param props configuration, passed unchanged to the {@link CRFClassifier} constructor
   */
  public CRFClassifierFloat(Properties props) {
    super(props);
  }

  /**
   * Creates a classifier configured with the given flags.
   *
   * @param flags configuration, passed unchanged to the {@link CRFClassifier} constructor
   */
  public CRFClassifierFloat(SeqClassifierFlags flags) {
    super(flags);
  }

  /**
   * Trains the CRF weights by minimizing a {@link CRFLogConditionalObjectiveFloatFunction} with a
   * {@link QNMinimizer}.
   *
   * <p>The objective function is also stored in {@code cliquePotentialFunctionHelper}. When
   * {@code flags.interimOutputFreq} is non-zero, the minimizer is built with a
   * {@link ResultStoringFloatMonitor} configured from {@code flags.interimOutputFreq} and
   * {@code flags.serializeTo}. Starting weights come from the objective's {@code initial()} method,
   * or, if {@code flags.initialWeights} is set, are read from that gzip-compressed file via
   * {@link ConvertByteArray#readFloatArr}.
   *
   * @param data training data, passed to the objective function
   * @param labels gold labels, passed to the objective function
   * @param evaluators not used by this implementation
   * @param pruneFeatureItr when {@code 0}, the minimizer's M is set from {@code flags.QNsize};
   *     otherwise from {@code flags.QNsize2}
   * @param featureVals not used by this implementation
   * @return the trained weights, widened from {@code float} to {@code double}
   * @throws RuntimeException if the initial-weights file cannot be opened or read; the underlying
   *     {@link IOException} is attached as the cause
   */
  @Override
  protected double[] trainWeights(int[][][][] data, int[][] labels, Evaluator[] evaluators,
                                  int pruneFeatureItr, double[][][][] featureVals) {
    CRFLogConditionalObjectiveFloatFunction func = new CRFLogConditionalObjectiveFloatFunction(data, labels,
        featureIndex, windowSize, classIndex, labelIndices, map, flags.backgroundSymbol, flags.sigma);
    cliquePotentialFunctionHelper = func;

    QNMinimizer minimizer;
    if (flags.interimOutputFreq != 0) {
      FloatFunction monitor = new ResultStoringFloatMonitor(flags.interimOutputFreq, flags.serializeTo);
      minimizer = new QNMinimizer(monitor);
    } else {
      minimizer = new QNMinimizer();
    }

    if (pruneFeatureItr == 0) {
      minimizer.setM(flags.QNsize);
    } else {
      minimizer.setM(flags.QNsize2);
    }

    float[] initialWeights;
    if (flags.initialWeights == null) {
      initialWeights = func.initial();
    } else {
      log.info("Reading initial weights from file " + flags.initialWeights);
      // try-with-resources guarantees the file handle is released, even if reading fails.
      try (DataInputStream dis = new DataInputStream(new BufferedInputStream(new GZIPInputStream(
          new FileInputStream(flags.initialWeights))))) {
        initialWeights = ConvertByteArray.readFloatArr(dis);
      } catch (IOException e) {
        // Keep the original IOException as the cause so the real failure is diagnosable.
        throw new RuntimeException("Could not read from float initial weight file " + flags.initialWeights, e);
      }
    }
    log.info("numWeights: " + initialWeights.length);

    float[] weights = minimizer.minimize(func, (float) flags.tolerance, initialWeights);
    return ArrayMath.floatArrayToDoubleArray(weights);
  }

} // end class CRFClassifierFloat