// CRFClassifier -- a probabilistic (CRF) sequence model, mainly used for NER.
// Copyright (c) 2002-2008 The Board of Trustees of
// The Leland Stanford Junior University. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 2
// of the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
//
// For more information, bug reports, fixes, contact:
// Christopher Manning
// Dept of Computer Science, Gates 1A
// Stanford CA 94305-9010
// USA
// Support/Questions: java-nlp-user@lists.stanford.edu
// Licensing: java-nlp-support@lists.stanford.edu
package edu.stanford.nlp.ie.crf;
import edu.stanford.nlp.util.logging.Redwood;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.objectbank.ObjectBank;
import edu.stanford.nlp.sequences.*;
import edu.stanford.nlp.util.*;
import java.util.*;
/**
 * Subclass of CRFClassifier that performs dropout feature-noising training.
 *
 * <p>When {@code flags.unsupDropoutFile} is set, unlabeled documents are read from that file in
 * {@link #loadAuxiliaryData} and their features are later handed to
 * {@link CRFLogConditionalObjectiveFunctionWithDropout} as additional unsupervised dropout data.
 *
 * @author Mengqiu Wang
 */
public class CRFClassifierWithDropout<IN extends CoreMap> extends CRFClassifier<IN> {

  /** A logger for this class */
  private static final Redwood.RedwoodChannels log = Redwood.channels(CRFClassifierWithDropout.class);

  /**
   * Unlabeled documents read from {@code flags.unsupDropoutFile}. Remains {@code null} until
   * {@link #loadAuxiliaryData} reads such a file; once set, it is kept across later calls.
   */
  private List<List<IN>> unsupDocs;

  public CRFClassifierWithDropout(SeqClassifierFlags flags) {
    super(flags);
  }

  /**
   * Reads the unsupervised dropout documents, if configured, and optionally merges them into the
   * training documents.
   *
   * <p>If {@code flags.unsupDropoutFile} is non-null, every document in that file is read with
   * {@code readerAndWriter}. Each token's answer and gold answer are set to
   * {@code flags.backgroundSymbol}, and the documents are stored in {@link #unsupDocs}. Any
   * previously stored documents are replaced.
   *
   * @param docs the labeled training documents
   * @param readerAndWriter the reader used to parse {@code flags.unsupDropoutFile}
   * @return a new list holding {@code docs} followed by the unlabeled documents when unlabeled
   *     documents are available and {@code flags.doFeatureDiscovery} is set; otherwise
   *     {@code docs} itself
   */
  @Override
  protected Collection<List<IN>> loadAuxiliaryData(Collection<List<IN>> docs, DocumentReaderAndWriter<IN> readerAndWriter) {
    if (flags.unsupDropoutFile != null) {
      log.info("Reading unsupervised dropout data from file: " + flags.unsupDropoutFile);
      Timing timer = new Timing();
      timer.start();
      unsupDocs = new ArrayList<>();
      ObjectBank<List<IN>> unsupObjBank = makeObjectBankFromFile(flags.unsupDropoutFile, readerAndWriter);
      for (List<IN> doc : unsupObjBank) {
        // These documents are unlabeled, so every token is marked with the background symbol
        // as both its answer and its gold answer.
        for (IN tok : doc) {
          tok.set(CoreAnnotations.AnswerAnnotation.class, flags.backgroundSymbol);
          tok.set(CoreAnnotations.GoldAnswerAnnotation.class, flags.backgroundSymbol);
        }
        unsupDocs.add(doc);
      }
      long elapsedMs = timer.stop();
      log.info("Time to read: " + Timing.toSecondsString(elapsedMs) + " seconds");
    }
    if (unsupDocs != null && flags.doFeatureDiscovery) {
      // Include the unlabeled documents too, so feature discovery also sees their features.
      List<List<IN>> totalDocs = new ArrayList<>(docs.size() + unsupDocs.size());
      totalDocs.addAll(docs);
      totalDocs.addAll(unsupDocs);
      return totalDocs;
    }
    return docs;
  }

  /**
   * Builds the dropout objective function for training.
   *
   * <p>If unlabeled documents were loaded, they are converted with
   * {@code documentsToDataAndLabelsList}. Only the data component ({@code first()}) of each
   * converted document is kept, and it is passed as the unsupervised dropout data.
   *
   * @param data feature data for the labeled training documents, passed through unchanged
   * @param labels labels for {@code data}, passed through unchanged
   * @return a {@link CRFLogConditionalObjectiveFunctionWithDropout} configured from {@code flags}
   */
  @Override
  protected CRFLogConditionalObjectiveFunction getObjectiveFunction(int[][][][] data, int[][] labels) {
    int[][][][] unsupDropoutData = null;
    if (unsupDocs != null) {
      Timing timer = new Timing();
      timer.start();
      List<Triple<int[][][], int[], double[][][]>> unsupDataAndLabels = documentsToDataAndLabelsList(unsupDocs);
      unsupDropoutData = new int[unsupDataAndLabels.size()][][][];
      for (int q = 0; q < unsupDropoutData.length; q++) {
        unsupDropoutData[q] = unsupDataAndLabels.get(q).first();
      }
      long elapsedMs = timer.stop();
      log.info("Time to read unsupervised dropout data: " + Timing.toSecondsString(elapsedMs) + " seconds, read " + unsupDropoutData.length + " files");
    }
    return new CRFLogConditionalObjectiveFunctionWithDropout(data, labels, windowSize, classIndex,
        labelIndices, map, flags.priorType, flags.backgroundSymbol, flags.sigma, null, flags.dropoutRate, flags.dropoutScale, flags.multiThreadGrad, flags.dropoutApprox, flags.unsupDropoutScale, unsupDropoutData);
  }

} // end class CRFClassifierWithDropout