package LBJ2.learn;
import java.lang.reflect.Field;
import java.io.BufferedOutputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.nio.channels.FileChannel;
import java.util.Arrays;
import java.util.Date;
import java.util.zip.ZipEntry;
import java.util.zip.ZipOutputStream;
import LBJ2.learn.Learner;
import LBJ2.parse.FoldSeparator;
import LBJ2.parse.FoldParser;
import LBJ2.parse.Parser;
import LBJ2.parse.ArrayFileParser;
import LBJ2.util.ExceptionlessInputStream;
import LBJ2.util.ExceptionlessOutputStream;
import LBJ2.util.TableFormat;
import LBJ2.util.Sort;
import LBJ2.util.StudentT;
/**
* Use this class to batch train a {@link Learner}.
*
* @author Nick Rizzolo
**/
public class BatchTrainer
{
/** <!-- writeExample(ExceptionlessOutputStream,int[],double[],int[],double[]) -->
 * Writes one labeled example to <code>out</code>. Every feature is treated
 * as unpruned, and features are emitted in the order in which they appear
 * in <code>featureIndexes</code>.
 *
 * @param out            The output stream.
 * @param featureIndexes The lexicon indexes of the features.
 * @param featureValues  The values or "strengths" of the features.
 * @param labelIndexes   The lexicon indexes of the labels.
 * @param labelValues    The values or "strengths" of the labels.
 **/
public static void writeExample(ExceptionlessOutputStream out,
                                int[] featureIndexes,
                                double[] featureValues, int[] labelIndexes,
                                double[] labelValues) {
  // No pruning information is available, so every feature counts as
  // unpruned; the int overload supplies the "no lexicon" ordering.
  writeExample(out, featureIndexes, featureValues, labelIndexes,
               labelValues, featureIndexes.length);
}
/** <!-- writeExample(ExceptionlessOutputStream,int[],double[],int[],double[],int) -->
 * Writes one labeled example to <code>out</code>, emitting features in the
 * order in which they appear in <code>featureIndexes</code>. The first
 * <code>unpruned</code> features are recorded as unpruned and the rest as
 * pruned.
 *
 * @param out            The output stream.
 * @param featureIndexes The lexicon indexes of the features.
 * @param featureValues  The values or "strengths" of the features.
 * @param labelIndexes   The lexicon indexes of the labels.
 * @param labelValues    The values or "strengths" of the labels.
 * @param unpruned       The number of features in the vector that aren't
 *                       pruned.
 **/
public static void writeExample(ExceptionlessOutputStream out,
                                int[] featureIndexes,
                                double[] featureValues, int[] labelIndexes,
                                double[] labelValues, int unpruned) {
  // Without a lexicon the features keep their original order.
  Lexicon noLexicon = null;
  writeExample(out, featureIndexes, featureValues, labelIndexes,
               labelValues, unpruned, noLexicon);
}
/** <!-- writeExample(ExceptionlessOutputStream,int[],double[],int[],double[],Lexicon) -->
 * Writes one labeled example to <code>out</code>, treating every feature
 * as unpruned. When <code>lex</code> is non-null, features are written
 * sorted by their keys in that lexicon; otherwise they are written in the
 * order in which they appear in <code>featureIndexes</code>.
 *
 * @param out            The output stream.
 * @param featureIndexes The lexicon indexes of the features.
 * @param featureValues  The values or "strengths" of the features.
 * @param labelIndexes   The lexicon indexes of the labels.
 * @param labelValues    The values or "strengths" of the labels.
 * @param lex            A lexicon used for ordering, or <code>null</code>.
 **/
public static void writeExample(ExceptionlessOutputStream out,
                                int[] featureIndexes,
                                double[] featureValues, int[] labelIndexes,
                                double[] labelValues, Lexicon lex) {
  int allUnpruned = featureIndexes.length;
  writeExample(out, featureIndexes, featureValues, labelIndexes,
               labelValues, allUnpruned, lex);
}
/** <!-- writeExample(ExceptionlessOutputStream,int[],double[],int[],double[],int,Lexicon) -->
 * Writes one labeled example to <code>out</code> in the following layout:
 * the label count, then an (index, value) pair per label, then the number
 * of unpruned features, then the number of pruned features, then an
 * (index, value) pair per feature.
 *
 * <p> When <code>lexicon</code> is non-null, the first
 * <code>unpruned</code> features are sorted by comparing their keys in
 * that lexicon, and the remaining (pruned) features follow in their
 * original order. When it is <code>null</code>, all features are written
 * in their original order. The input arrays are not modified.
 *
 * @param out            The output stream.
 * @param featureIndexes The lexicon indexes of the features.
 * @param featureValues  The values or "strengths" of the features.
 * @param labelIndexes   The lexicon indexes of the labels.
 * @param labelValues    The values or "strengths" of the labels.
 * @param unpruned       The number of features in the vector that aren't
 *                       pruned.
 * @param lexicon        A lexicon used for ordering, or <code>null</code>.
 **/
public static void writeExample(ExceptionlessOutputStream out,
                                final int[] featureIndexes,
                                double[] featureValues, int[] labelIndexes,
                                double[] labelValues, int unpruned,
                                final Lexicon lexicon) {
  // Permutation of feature positions; null means "write as given".
  int[] order = null;
  if (lexicon != null) {
    order = new int[featureIndexes.length];
    for (int j = 0; j < order.length; ++j) order[j] = j;
    // Only the unpruned prefix is sorted; the pruned tail stays put.
    Sort.sort(order, 0, unpruned,
              new Sort.IntComparator() {
                public int compare(int a, int b) {
                  return lexicon.lookupKey(featureIndexes[a])
                         .compareTo(lexicon.lookupKey(featureIndexes[b]));
                }
              });
  }

  // Labels: count followed by (index, value) pairs.
  out.writeInt(labelIndexes.length);
  for (int j = 0; j < labelIndexes.length; ++j) {
    out.writeInt(labelIndexes[j]);
    out.writeDouble(labelValues[j]);
  }

  // Features: unpruned count, pruned count, then (index, value) pairs.
  out.writeInt(unpruned);
  out.writeInt(featureIndexes.length - unpruned);
  for (int j = 0; j < featureIndexes.length; ++j) {
    int source = order == null ? j : order[j];
    out.writeInt(featureIndexes[source]);
    out.writeDouble(featureValues[source]);
  }
}
// Instance member variables.
/** The learning classifier being trained. */
protected Learner learner;
/**
* The parser from which training data for {@link #learner} is received.
* Pre-extraction and pruning replace it with an
* {@link LBJ2.parse.ArrayFileParser} over the extracted examples, and cross
* validation wraps it in a {@link LBJ2.parse.FoldParser}.
**/
protected Parser parser;
/**
* The number of training examples in between status messages printed to
* <code>STDOUT</code>. Any value that is not positive (e.g. 0) suppresses
* these messages.
**/
protected int progressOutput;
/**
* Spacing inserted into status messages after the learner's name, used to
* make nested status output prettier.
**/
protected String messageIndent;
/** {@link #learner}'s runtime class. */
protected Class learnerClass;
/**
* {@link #learner}'s public <code>isTraining</code> field, looked up
* reflectively in the constructor. It is read and written with a
* <code>null</code> receiver, i.e. as a static field.
**/
protected Field fieldIsTraining;
/**
* The number of examples extracted during pre-extraction (also set by
* pruning and by {@link #fillInSizes()}). When positive, training passes it
* to {@link Learner#initialize(int,int)} unless cross validation is active.
**/
protected int examples;
/**
* The number of features extracted during pre-extraction. After pruning it
* holds the lexicon's pruning cutoff instead. Training uses a positive
* value as the signal that examples were pre-extracted.
**/
protected int lexiconSize;
// Constructors.
/** <!-- <init>(Learner,String) -->
 * Creates a new trainer that reads pre-extracted examples from a
 * compressed example file and doesn't produce status messages.
 *
 * @param l The learner to be trained.
 * @param p The path to a compressed example file.
 **/
public BatchTrainer(Learner l, String p) {
  this(l, new ArrayFileParser(p, true));
}
/** <!-- <init>(Learner,String,int) -->
 * Creates a new trainer that reads pre-extracted examples from a
 * compressed example file and produces status messages.
 *
 * @param l The learner to be trained.
 * @param p The path to a compressed example file.
 * @param o The number of examples in between status messages on STDOUT.
 **/
public BatchTrainer(Learner l, String p, int o) {
  this(l, new ArrayFileParser(p, true), o);
}
/** <!-- <init>(Learner,String,int,String) -->
 * Creates a new trainer that reads pre-extracted examples from a
 * compressed example file and produces status messages with the specified
 * indentation spacing.
 *
 * @param l The learner to be trained.
 * @param p The path to a compressed example file.
 * @param o The number of examples in between status messages on STDOUT.
 * @param i The indentation spacing for status messages.
 **/
public BatchTrainer(Learner l, String p, int o, String i) {
  this(l, new ArrayFileParser(p, true), o, i);
}
/** <!-- <init>(Learner,String,boolean) -->
 * Creates a new trainer that reads pre-extracted examples from the given
 * file and doesn't produce status messages.
 *
 * @param l The learner to be trained.
 * @param p The path to an example file.
 * @param z Whether or not the example file is compressed.
 **/
public BatchTrainer(Learner l, String p, boolean z) {
  this(l, new ArrayFileParser(p, z), 0, "");
}
/** <!-- <init>(Learner,String,boolean,int) -->
 * Creates a new trainer that reads pre-extracted examples from the given
 * file and produces status messages.
 *
 * @param l The learner to be trained.
 * @param p The path to an example file.
 * @param z Whether or not the example file is compressed.
 * @param o The number of examples in between status messages on STDOUT.
 **/
public BatchTrainer(Learner l, String p, boolean z, int o) {
  this(l, new ArrayFileParser(p, z), o, "");
}
/** <!-- <init>(Learner,String,boolean,int,String) -->
 * Creates a new trainer that reads pre-extracted examples from the given
 * file and produces status messages with the specified indentation
 * spacing.
 *
 * @param l The learner to be trained.
 * @param p The path to an example file.
 * @param z Whether or not the example file is compressed.
 * @param o The number of examples in between status messages on STDOUT.
 * @param i The indentation spacing for status messages.
 **/
public BatchTrainer(Learner l, String p, boolean z, int o, String i) {
  this(l,
       new ArrayFileParser(p, z),
       o,
       i);
}
/** <!-- <init>(Learner,Parser) -->
 * Creates a new trainer that doesn't produce status messages.
 *
 * @param l The learner to be trained.
 * @param p The parser from which training data is received.
 **/
public BatchTrainer(Learner l, Parser p) {
  this(l, p, 0, "");
}
/** <!-- <init>(Learner,Parser,int) -->
 * Creates a new trainer that produces status messages with no extra
 * indentation.
 *
 * @param l The learner to be trained.
 * @param p The parser from which training data is received.
 * @param o The number of examples in between status messages on STDOUT.
 **/
public BatchTrainer(Learner l, Parser p, int o) {
  String noIndent = "";
  this(l, p, o, noIndent);
}
/** <!-- <init>(Learner,Parser,int,String) -->
 * Creates a new trainer that produces status messages with the specified
 * indentation spacing. The learner's public <code>isTraining</code> field
 * is located reflectively here; if that fails, an error is printed to
 * <code>STDERR</code> and the JVM exits with status 1.
 *
 * @param l The learner to be trained.
 * @param p The parser from which training data is received.
 * @param o The number of examples in between status messages on STDOUT.
 * @param i The indentation spacing for status messages.
 **/
public BatchTrainer(Learner l, Parser p, int o, String i) {
  learner = l;
  parser = p;
  progressOutput = o;
  messageIndent = i;
  learnerClass = l.getClass();

  try {
    fieldIsTraining = learnerClass.getField("isTraining");
  }
  catch (Exception e) {
    System.err.println("Can't access " + learnerClass
                       + "'s 'isTraining' field: " + e);
    System.exit(1);
  }
}
/** Returns the value of {@link #progressOutput}. */
public int getProgressOutput() { return progressOutput; }
/** Returns the value of {@link #parser}. */
public Parser getParser() { return parser; }
/** <!-- setIsTraining(boolean) -->
* Sets the static <code>isTraining</code> flag inside {@link #learner}'s
* runtime class to the specified value. This probably doesn't need to
* be tinkered with after pre-extraction, since it can only affect the
* code that does the extraction.
*
* @param b The new value for the flag.
**/
protected void setIsTraining(boolean b) {
try { fieldIsTraining.setBoolean(null, b); }
catch (Exception e) {
System.err.println("Can't set " + learnerClass
+ "'s 'isTraining' field: " + e);
System.exit(1);
}
}
/** <!-- getIsTraining() -->
* Returns the value of the static <code>isTraining</code> flag inside
* {@link #learner}'s runtime class.
**/
protected boolean getIsTraining() {
try { return fieldIsTraining.getBoolean(null); }
catch (Exception e) {
System.err.println("Can't get " + learnerClass
+ "'s 'isTraining' field: " + e);
System.exit(1);
}
return false;
}
/** <!-- preExtract(String) -->
* Performs labeled feature vector pre-extraction into the specified file
* (or memory), replacing {@link #parser} with one that reads from that
* file (or memory). After pre-extraction, the lexicon is written to disk.
* It is assumed that {@link #learner} already knows where to write the
* lexicon. If it doesn't, call {@link Learner#setLexiconLocation(String)}
* or {@link Learner#setLexiconLocation(java.net.URL)} on that object
* before calling this method.
*
* <p> Calling this method is equivalent to calling
* {@link #preExtract(String,boolean)} with the second argument
* <code>true</code>.
*
* @param exampleFile The full path to a file into which examples will be
* written, or <code>null</code> to extract into
* memory.
* @return The resulting lexicon.
**/
public Lexicon preExtract(String exampleFile) {
return preExtract(exampleFile, true);
}
/** <!-- preExtract(String,boolean) -->
* Performs labeled feature vector pre-extraction into the specified file
* (or memory), replacing {@link #parser} with one that reads from that
* file (or memory). After pre-extraction, the lexicon is written to disk.
* It is assumed that {@link #learner} already knows where to write the
* lexicon. If it doesn't, call {@link Learner#setLexiconLocation(String)}
* or {@link Learner#setLexiconLocation(java.net.URL)} on that object
* before calling this method.
*
* @param exampleFile The full path to a file into which examples will be
* written, or <code>null</code> to extract into
* memory.
* @param zip Whether or not to compress the extracted examples.
* @return The resulting lexicon.
**/
public Lexicon preExtract(String exampleFile, boolean zip) {
Learner preExtractLearner =
preExtract(exampleFile, zip, Lexicon.CountPolicy.none);
preExtractLearner.saveLexicon();
return preExtractLearner.getLexicon();
}
/** <!-- preExtract(String,Lexicon.CountPolicy) -->
* Performs labeled feature vector pre-extraction into the specified file
* (or memory), replacing {@link #parser} with one that reads from that
* file (or memory). If <code>exampleFile</code> already exists, this
* method writes the examples to a temporary file, then copies the contents
* to the existing file after pre-extraction completes. This is done in
* case the parser providing the examples to this method is reading the
* existing file.
*
* <p> Note that this method does <i>not</i> write the feature lexicon it
* produces to disk. Calling this method is equivalent to calling
* {@link #preExtract(String,boolean,Lexicon.CountPolicy)} with the second
* argument <code>true</code>.
*
* @param exampleFile The full path to a file into which examples will be
* written, or <code>null</code> to extract into
* memory.
* @param countPolicy The feature counting policy for the learner's
* feature lexicon.
* @return A new learning classifier containing the lexicon built during
* pre-extraction.
**/
public Learner preExtract(String exampleFile,
Lexicon.CountPolicy countPolicy) {
return preExtract(exampleFile, true, countPolicy);
}
/** <!-- preExtract(String,boolean,Lexicon.CountPolicy) -->
* Performs labeled feature vector pre-extraction into the specified file
* (or memory), replacing {@link #parser} with one that reads from that
* file (or memory). If <code>exampleFile</code> already exists, this
* method writes the examples to a temporary file, then copies the contents
* to the existing file after pre-extraction completes. This is done in
* case the parser providing the examples to this method is reading the
* existing file.
*
* <p> Note that this method does <i>not</i> write the feature lexicon it
* produces to disk.
*
* @param exampleFile The full path to a file into which examples will be
* written, or <code>null</code> to extract into
* memory.
* @param zip Whether or not to compress the extracted examples.
* @param countPolicy The feature counting policy for the learner's
* feature lexicon.
* @return A new learning classifier containing the lexicon built during
* pre-extraction.
**/
public Learner preExtract(String exampleFile, boolean zip,
Lexicon.CountPolicy countPolicy) {
Learner preExtractLearner = learner.emptyClone();
preExtractLearner.setLabelLexicon(learner.getLabelLexicon());
Lexicon lexicon = learner.getLexicon();
preExtractLearner.setLexicon(lexicon);
preExtractLearner.countFeatures(countPolicy);
learner.setLexicon(null);
setIsTraining(true);
examples = 0;
// Establish an output stream for writing examples.
ExceptionlessOutputStream eos = null;
ByteArrayOutputStream baos = null;
File fExampleFile = null;
File fTempFile = null;
boolean copy = false;
if (exampleFile != null) {
fExampleFile = new File(exampleFile);
if (fExampleFile.exists()) {
int lastSlash = exampleFile.lastIndexOf(File.separatorChar);
try {
if (lastSlash == -1) fTempFile = File.createTempFile("LBJ", null);
else
fTempFile =
File.createTempFile(
"LBJ", null, new File(exampleFile.substring(0, lastSlash)));
}
catch (Exception e) {
System.err.println(
"LBJ ERROR: BatchTrainer.preExtract: Can't create temporary "
+ "file: " + e);
System.exit(1);
}
fTempFile.deleteOnExit();
copy = true;
}
else fTempFile = fExampleFile;
try {
if (zip)
eos =
ExceptionlessOutputStream.openCompressedStream(
fTempFile.toURI().toURL());
else
eos =
ExceptionlessOutputStream.openBufferedStream(
fTempFile.toURI().toURL());
}
catch (Exception e) {
System.err.println(
"LBJ ERROR: BatchTrainer.preExtract: Can't convert file name '"
+ fTempFile + "' to URL: " + e);
System.exit(1);
}
}
else {
baos = new ByteArrayOutputStream(1 << 18);
if (zip) {
ZipOutputStream zos = new ZipOutputStream(baos);
try {
zos.putNextEntry(
new ZipEntry(ExceptionlessInputStream.zipEntryName));
}
catch (Exception e) {
System.err.println("ERROR: Can't create in-memory zip data:");
e.printStackTrace();
System.exit(1);
}
eos = new ExceptionlessOutputStream(new BufferedOutputStream(zos));
}
else eos = new ExceptionlessOutputStream(baos);
}
// Write examples to the output stream.
boolean alreadyExtracted = parser instanceof ArrayFileParser;
if (alreadyExtracted) ((ArrayFileParser) parser).setIncludePruned(true);
for (Object example = parser.next(); example != null;
example = parser.next()) {
if (progressOutput > 0 && examples % progressOutput == 0)
System.out.println(
" " + learner.name + ", pre-extract: " + messageIndent + examples
+ " examples at " + new Date());
if (example == FoldSeparator.separator) eos.writeInt(-1);
else {
++examples;
Object[] exampleArray =
alreadyExtracted ? (Object[]) example
: preExtractLearner.getExampleArray(example);
int[] featureIndexes = (int[]) exampleArray[0];
double[] featureValues = (double[]) exampleArray[1];
int[] labelIndexes = (int[]) exampleArray[2];
double[] labelValues = (double[]) exampleArray[3];
if (alreadyExtracted && countPolicy != Lexicon.CountPolicy.none) {
int labelIndex =
countPolicy == Lexicon.CountPolicy.perClass
? labelIndexes[0] : -1;
for (int i = 0; i < featureIndexes.length; ++i)
lexicon.lookup(lexicon.lookupKey(featureIndexes[i]), true,
labelIndex);
}
writeExample(eos, featureIndexes, featureValues, labelIndexes,
labelValues, lexicon);
}
}
if (progressOutput > 0)
System.out.println(
" " + learner.name + ", pre-extract: " + messageIndent + examples
+ " examples at " + new Date());
parser.close();
eos.close();
if (copy) {
try {
FileChannel in = (new FileInputStream(fTempFile)).getChannel();
FileChannel out = (new FileOutputStream(fExampleFile)).getChannel();
in.transferTo(0, fTempFile.length(), out);
in.close();
out.close();
}
catch (Exception e) {
System.err.println("LBJ ERROR: Can't copy example file:");
e.printStackTrace();
System.exit(1);
}
}
setIsTraining(false);
lexiconSize = preExtractLearner.getLexicon().size();
// Set up a new parser to read the pre-extracted examples.
if (fTempFile != null)
parser = new ArrayFileParser(fTempFile.getPath(), zip);
else parser = new ArrayFileParser(baos.toByteArray(), zip);
learner.setLabelLexicon(preExtractLearner.getLabelLexicon());
return preExtractLearner;
}
/** <!-- fillInSizes() -->
* This method sets the {@link #examples} and {@link #lexiconSize}
* variables by querying {@link #parser} and {@link #learner} respectively.
* It sets {@link #examples} to 0 if {@link #parser} is not an
* {@link LBJ2.parse.ArrayFileParser} and {@link #lexiconSize} to 0 if
* {@link #learner} doesn't either have the lexicon loaded or know where to
* find it.
**/
public void fillInSizes() {
if (parser instanceof ArrayFileParser) {
ArrayFileParser afp = (ArrayFileParser) parser;
examples = afp.getNumExamples();
}
else examples = 0;
lexiconSize = learner.getPrunedLexiconSize();
}
/** <!-- pruneDataset(String,Lexicon.PruningPolicy,Learner) -->
* Prunes the data returned by {@link #parser} according to the given
* policy, under the assumption that feature counts have already been
* compiled in the given learner's lexicon. The pruned data is written to
* the given file (or memory), and at the end of the method,
* {@link #parser} is replaced with a new parser that reads from that file
* (or memory). The pruned lexicon is also written to disk.
*
* <p> If <code>exampleFile</code> already exists, this method writes the
* examples to a temporary file, then copies the contents to the existing
* file after pruning completes. This is done in case the parser providing
* the examples to this method is reading the existing file.
*
* <p> When calling this method, it must be the case that {@link #parser}
* is a {@link LBJ2.parse.ArrayFileParser}. This condition is easy to
* satisfy, since the
* {@link #preExtract(String,boolean,Lexicon.CountPolicy)} method will
* usually be called prior to this method to count the features in the
* dataset, and this method also replaces {@link #parser} with a
* {@link LBJ2.parse.ArrayFileParser}.
*
* <p> It is assumed that <code>preExtractLearner</code> already knows
* where to write the lexicon. If it doesn't, call
* {@link Learner#setLexiconLocation(String)} or
* {@link Learner#setLexiconLocation(java.net.URL)} on that object before
* calling this method.
*
* <p> Calling this method is equivalent to calling
* {@link #pruneDataset(String,boolean,Lexicon.PruningPolicy,Learner)} with
* the second argument <code>true</code>.
*
* @param exampleFile The full path to a file into which examples
* will be written, or <code>null</code> to
* extract into memory.
* @param policy The type of feature pruning.
* @param preExtractLearner A learner whose lexicon contains all the
* necessary feature count information.
**/
public void pruneDataset(String exampleFile, Lexicon.PruningPolicy policy,
Learner preExtractLearner) {
pruneDataset(exampleFile, true, policy, preExtractLearner);
}
/** <!-- pruneDataset(String,boolean,Lexicon.PruningPolicy,Learner) -->
* Prunes the data returned by {@link #parser} according to the given
* policy, under the assumption that feature counts have already been
* compiled in the given learner's lexicon. The pruned data is written to
* the given file (or memory), and at the end of the method,
* {@link #parser} is replaced with a new parser that reads from that file
* (or memory). The pruned lexicon is also written to disk.
*
* <p> If <code>exampleFile</code> already exists, this method writes the
* examples to a temporary file, then copies the contents to the existing
* file after pruning completes. This is done in case the parser providing
* the examples to this method is reading the existing file.
*
* <p> When calling this method, it must be the case that {@link #parser}
* is an {@link LBJ2.parse.ArrayFileParser ArrayFileParser}. This
* condition is easy to satisfy, since the
* {@link #preExtract(String,boolean,Lexicon.CountPolicy)} method will
* usually be called prior to this method to count the features in the
* dataset, and this method also replaces {@link #parser} with an
* {@link LBJ2.parse.ArrayFileParser ArrayFileParser}.
*
* <p> It is assumed that <code>preExtractLearner</code> already knows
* where to write the lexicon. If it doesn't, call
* {@link Learner#setLexiconLocation(String)} or
* {@link Learner#setLexiconLocation(java.net.URL)} on that object before
* calling this method.
*
* @param exampleFile The full path to a file into which examples
* will be written, or <code>null</code> to
* extract into memory.
* @param zip Whether or not to compress the extracted
* examples.
* @param policy The type of feature pruning.
* @param preExtractLearner A learner whose lexicon contains all the
* necessary feature count information.
**/
public void pruneDataset(String exampleFile, boolean zip,
Lexicon.PruningPolicy policy,
Learner preExtractLearner) {
Lexicon lexicon = preExtractLearner.getLexicon();
if (!policy.isNone()
&& lexicon.getCountPolicy() == Lexicon.CountPolicy.none)
throw new IllegalArgumentException(
"LBJ ERROR: BatchTrainer.pruneDataset: Can't prune with policy '"
+ policy + "' if features haven't been counted.");
if (!(parser instanceof ArrayFileParser))
throw new IllegalArgumentException(
"LBJ ERROR: BatchTrainer.pruneDataset can't be called unless "
+ "feature pre-extraction has already been performed.");
ArrayFileParser afp = (ArrayFileParser) parser;
afp.setIncludePruned(true);
int[] swapMap = lexicon.prune(policy);
// Establish an output stream for writing examples.
ExceptionlessOutputStream eos = null;
ByteArrayOutputStream baos = null;
File fExampleFile = null;
File fTempFile = null;
boolean copy = false;
if (exampleFile != null) {
fExampleFile = new File(exampleFile);
if (fExampleFile.exists()) {
int lastSlash = exampleFile.lastIndexOf(File.separatorChar);
try {
if (lastSlash == -1) fTempFile = File.createTempFile("LBJ", null);
else
fTempFile =
File.createTempFile(
"LBJ", null, new File(exampleFile.substring(0, lastSlash)));
}
catch (Exception e) {
System.err.println(
"LBJ ERROR: BatchTrainer.preExtract: Can't create temporary "
+ "file: " + e);
System.exit(1);
}
fTempFile.deleteOnExit();
copy = true;
}
else fTempFile = fExampleFile;
try {
if (zip)
eos =
ExceptionlessOutputStream.openCompressedStream(
fTempFile.toURI().toURL());
else
eos =
ExceptionlessOutputStream.openBufferedStream(
fTempFile.toURI().toURL());
}
catch (Exception e) {
System.err.println(
"LBJ ERROR: BatchTrainer.preExtract: Can't convert file name '"
+ fTempFile + "' to URL: " + e);
System.exit(1);
}
}
else {
baos = new ByteArrayOutputStream(1 << 18);
if (zip) {
ZipOutputStream zos = new ZipOutputStream(baos);
try {
zos.putNextEntry(
new ZipEntry(ExceptionlessInputStream.zipEntryName));
}
catch (Exception e) {
System.err.println("ERROR: Can't create in-memory zip data:");
e.printStackTrace();
System.exit(1);
}
eos = new ExceptionlessOutputStream(new BufferedOutputStream(zos));
}
else eos = new ExceptionlessOutputStream(baos);
}
// Write examples to the output stream.
examples = 0;
for (Object example = afp.next(); example != null; example = afp.next()) {
if (progressOutput > 0 && examples % progressOutput == 0)
System.out.println(" " + learner.name + ", pruning: " + examples
+ " examples at " + new Date());
if (example == FoldSeparator.separator) eos.writeInt(-1);
else {
++examples;
Object[] exampleArray = (Object[]) example;
int[] featureIndexes = (int[]) exampleArray[0];
double[] featureValues = (double[]) exampleArray[1];
int[] labelIndexes = (int[]) exampleArray[2];
double[] labelValues = (double[]) exampleArray[3];
int unpruned = featureIndexes.length;
if (swapMap != null) {
// First, map the old feature indexes to the new ones.
for (int i = 0; i < featureIndexes.length; ++i)
featureIndexes[i] = swapMap[featureIndexes[i]];
// Second, put the pruned features at the end of the example array.
while (unpruned > 0
&& lexicon.isPruned(featureIndexes[unpruned - 1],
labelIndexes[0], policy))
--unpruned;
for (int i = unpruned - 2; i >= 0; --i)
if (lexicon.isPruned(featureIndexes[i], labelIndexes[0], policy))
{
int t = featureIndexes[i];
featureIndexes[i] = featureIndexes[--unpruned];
featureIndexes[unpruned] = t;
double d = featureValues[i];
featureValues[i] = featureValues[unpruned];
featureValues[unpruned] = d;
}
}
writeExample(eos, featureIndexes, featureValues, labelIndexes,
labelValues, unpruned, lexicon);
}
}
if (progressOutput > 0)
System.out.println(" " + learner.name + ", pruning: " + examples
+ " examples at " + new Date());
parser.close();
eos.close();
if (copy) {
try {
FileChannel in = (new FileInputStream(fTempFile)).getChannel();
FileChannel out = (new FileOutputStream(fExampleFile)).getChannel();
in.transferTo(0, fTempFile.length(), out);
in.close();
out.close();
}
catch (Exception e) {
System.err.println("LBJ ERROR: Can't copy example file:");
e.printStackTrace();
System.exit(1);
}
}
lexiconSize = lexicon.getCutoff();
preExtractLearner.saveLexicon();
// Set up a new parser to read the pre-extracted and pruned examples.
if (fTempFile != null)
parser = new ArrayFileParser(fTempFile.getPath(), zip);
else parser = new ArrayFileParser(baos.toByteArray(), zip);
}
/** <!-- interface DoneWithRound -->
* Provides access to a hook into {@link #train(int)} so that additional
* processing can be performed at the end of each round. This processing
* supplements the processing in {@link Learner#doneWithRound()} which is
* already called from withink {@link #train(int)}.
**/
public static interface DoneWithRound
{
/**
* The hook into {@link #train(int)} as described above.
*
* @param r The 1-based number of the training round that just
* completed.
**/
public void doneWithRound(int r);
}
/** <!-- train(int) -->
* Trains {@link #learner} for the specified number of rounds. This
* learning happens on top of any learning that {@link #learner} may have
* already done.
*
* @param rounds The number of passes to make over the training data.
**/
public void train(int rounds) { train(1, rounds); }
/** <!-- train(int,int) -->
* Trains {@link #learner} for the specified number of rounds. This
* learning happens on top of any learning that {@link #learner} may have
* already done.
*
* @param start The 1-based number of the first training round.
* @param rounds The total number of training rounds including those before
* <code>start</code>.
**/
public void train(int start, int rounds) {
train(start, rounds,
new DoneWithRound() { public void doneWithRound(int r) { } });
}
/** <!-- train(int,DoneWithRound) -->
* Trains {@link #learner} for the specified number of rounds. This
* learning happens on top of any learning that {@link #learner} may have
* already done.
*
* @param rounds The number of passes to make over the training data.
* @param dwr Performs post processing at the end of each round.
**/
public void train(int rounds, DoneWithRound dwr) {
train(1, rounds, dwr);
}
/** <!-- train(int,int,DoneWithRound) -->
* Trains {@link #learner} for the specified number of rounds. This
* learning happens on top of any learning that {@link #learner} may have
* already done.
*
* @param start The 1-based number of the first training round.
* @param rounds The total number of training rounds including those before
* <code>start</code>.
* @param dwr Performs post processing at the end of each round.
**/
public void train(int start, int rounds, DoneWithRound dwr) {
if (lexiconSize > 0) {
// If the parser is a FoldParser, it means we're doing cross validation
// in which we train on just part of the data. So the examples variable
// doesn't accurately reflect how many training examples we'll see in
// this episode of training.
learner.initialize(parser instanceof FoldParser ? 0 : examples,
lexiconSize);
}
else setIsTraining(true);
for (int i = start; i <= rounds; ++i) {
int examples = 0;
for (Object example = parser.next(); example != null;
example = parser.next()) {
if (example == FoldSeparator.separator) continue;
if (progressOutput > 0 && examples % progressOutput == 0) {
System.out.print(" " + learner.name + ": " + messageIndent);
if (rounds != 1) System.out.print("Round " + i + ", ");
System.out.println(examples + " examples processed at "
+ new Date());
}
learner.learn(example);
++examples;
}
if (progressOutput > 0) {
System.out.print(" " + learner.name + ": " + messageIndent);
if (rounds != 1) System.out.print("Round " + i + ", ");
System.out.println(examples + " examples processed at " + new Date());
}
parser.reset();
learner.doneWithRound();
dwr.doneWithRound(i);
}
learner.doneLearning();
if (lexiconSize == 0) setIsTraining(false);
}
/** <!-- crossValidation(int[],int,FoldParser.SplitPolicy,double,TestingMetric,boolean) -->
 * Performs cross validation, computing a confidence interval on the
 * performance of the learner after each of the specified rounds of
 * training.  This method assumes that {@link #learner} has not yet done
 * any learning.  The learner will again be empty in this sense when the
 * method exits, except that any label lexicon present before the method
 * was called will be restored.  The label lexicon needs to persist in this
 * way so that it can ultimately be written into the model file.
 *
 * <p> The <code>rounds</code> array is sorted in place, so that after the
 * call <code>results[i]</code> corresponds to <code>rounds[i]</code> in the
 * caller's array.  Duplicate entries in <code>rounds</code> share a single
 * evaluation.
 *
 * <p> The {@link #parser} field is temporarily replaced by a
 * {@link FoldParser} wrapping it.  The original parser and the message
 * indentation are restored even if training or testing throws.
 *
 * @param rounds         An array of training rounds after which
 *                       performance of the learner should be evaluated on
 *                       the testing data.  Must not be empty.
 * @param k              The number of folds.
 * @param splitPolicy    The policy according to which the data is split
 *                       up.
 * @param alpha          The fraction of the distribution to leave outside
 *                       the confidence interval.  For example, <code>alpha
 *                       = .05</code> gives a 95% confidence interval.
 * @param metric         A metric with which to evaluate the learner on
 *                       testing data.
 * @param statusMessages If set <code>true</code> status messages will be
 *                       produced, even if {@link #progressOutput} is zero.
 * @return A 2D array <code>results</code> where <code>results[i][0]</code>
 *         is the average performance of the learner after
 *         <code>rounds[i]</code> rounds of training and
 *         <code>results[i][1]</code> is half the size of the corresponding
 *         confidence interval.
 * @throws IllegalArgumentException If <code>k &lt;= 1</code> while the
 *                                  split policy is not manual, or if
 *                                  <code>rounds</code> is empty.
**/
public double[][] crossValidation(final int[] rounds,
                                  int k,
                                  FoldParser.SplitPolicy splitPolicy,
                                  double alpha,
                                  final TestingMetric metric,
                                  boolean statusMessages) {
  if (!(k > 1 || splitPolicy == FoldParser.SplitPolicy.manual))
    throw new IllegalArgumentException(
        "LBJ ERROR: BatchTrainer.crossValidation: if the data splitting "
        + "policy is not 'Manual', the number of folds must be greater "
        + "than 1.");
  if (rounds.length == 0)
    throw new IllegalArgumentException(
        "LBJ ERROR: BatchTrainer.crossValidation: at least one number of "
        + "training rounds must be specified.");
  if (splitPolicy == FoldParser.SplitPolicy.manual) k = -1;

  // Sorted in place on purpose: results[i] must line up with rounds[i] in
  // the caller's array.
  Arrays.sort(rounds);
  final int totalRounds = rounds[rounds.length - 1];

  // Status messages.
  if (statusMessages || progressOutput > 0) {
    System.out.print(" " + learner.name + ": " + messageIndent
                     + "Cross Validation: ");
    if (k != -1) System.out.print("k = " + k + ", ");
    System.out.print("Split = " + splitPolicy);
    if (totalRounds != 1) System.out.print(", Rounds = " + totalRounds);
    System.out.println();
  }

  // Instantiate a fold parser.
  final FoldParser foldParser;
  // If we pre-extracted, we know how many examples there are already;
  // otherwise FoldParser will have to compute it.
  if (examples > 0)
    foldParser = new FoldParser(parser, k, splitPolicy, 0, false, examples);
  else foldParser = new FoldParser(parser, k, splitPolicy, 0, false);
  if (splitPolicy == FoldParser.SplitPolicy.manual) k = foldParser.getK();

  final double[][] performances = new double[rounds.length][k];
  final Lexicon labelLexicon = learner.getLabelLexicon();
  // Restore the saved indent rather than trimming a fixed number of
  // characters, so the indent is always restored exactly.
  final String outerIndent = messageIndent;

  parser = foldParser;
  try {
    // Train and get testing performances for each fold.
    for (int i = 0; i < k; foldParser.setPivot(++i)) {
      if (statusMessages || progressOutput > 0)
        System.out.println(
            " " + learner.name + ": " + messageIndent
            + "Training against subset " + i + " at " + new Date());
      final int fold = i;
      messageIndent = outerIndent + " ";
      train(totalRounds,
            new DoneWithRound() {
              int r = 0;
              public void doneWithRound(int round) {
                if (round >= totalRounds || rounds[r] != round) return;
                double performance =
                  crossValidationTesting(foldParser, metric, true, false);
                // Consume every (duplicate) entry for this round so the
                // cursor never gets stuck on a value that is already past.
                while (r < rounds.length && rounds[r] == round)
                  performances[r++][fold] = performance;
              }
            });

      // Entries equal to the final round count are filled from one
      // evaluation of the fully trained learner.
      double finalPerformance =
        crossValidationTesting(foldParser, metric, false, statusMessages);
      for (int r = rounds.length - 1; r >= 0 && rounds[r] == totalRounds;
           --r)
        performances[r][i] = finalPerformance;

      messageIndent = outerIndent;
      learner.forget();
      if (labelLexicon != null && labelLexicon.size() > 0
          && learner.getLabelLexicon().size() == 0)
        learner.setLabelLexicon(labelLexicon);
    }
  }
  finally {
    parser = foldParser.getParser();
    messageIndent = outerIndent;
  }

  // Compute the confidence interval.
  double[][] results = new double[rounds.length][];
  boolean usingAccuracy = metric instanceof Accuracy;
  for (int r = 0; r < rounds.length; ++r) {
    results[r] = StudentT.confidenceInterval(performances[r], alpha);
    // Progress output reports every evaluated round; status messages alone
    // report only the final one.
    if ((r == rounds.length - 1 && statusMessages) || progressOutput > 0) {
      double mean = Math.round(results[r][0] * 100000) / 100000.0;
      double half = Math.round(results[r][1] * 100000) / 100000.0;
      System.out.print(
          " " + learner.name + ": " + messageIndent + (100 * (1 - alpha))
          + "% confidence interval after " + rounds[r] + " rounds: "
          + mean);
      if (usingAccuracy) System.out.print("%");
      System.out.print(" +/- " + half);
      if (usingAccuracy) System.out.print("%");
      System.out.println();
    }
  }
  return results;
}
/** <!-- crossValidationTesting(FoldParser,TestingMetric,boolean,boolean) -->
 * Tests the learner as a subroutine inside cross validation.  The fold
 * parser is switched to read from its pivot (testing) fold for the
 * duration of the test.  Afterwards it is reset and switched back, even if
 * the test throws.
 *
 * <p> When the underlying parser is an {@link ArrayFileParser}, pruned
 * features are temporarily included and no labeler is passed to the
 * metric.  Otherwise the trainer's training flag is cleared for the test
 * and set back to <code>true</code> afterwards.
 * NOTE(review): the flag is set to <code>true</code> unconditionally, even
 * if it was <code>false</code> before the call; confirm this is intended.
 *
 * @param foldParser     The cross validation parser that splits up the
 *                       data.
 * @param metric         The metric used to evaluate the performance of the
 *                       learner.
 * @param clone          Whether or not the learner should be cloned (and
 *                       it should be cloned if more learning will take
 *                       place after making this call).
 * @param statusMessages If set <code>true</code> status messages will be
 *                       produced, even if {@link #progressOutput} is zero.
 * @return The result produced by the testing metric on the current cross
 *         validation fold expressed as a percentage (instead of a
 *         fraction) if the testing metric is {@link Accuracy}.
**/
protected double crossValidationTesting(FoldParser foldParser,
                                        TestingMetric metric,
                                        boolean clone,
                                        boolean statusMessages) {
  Parser originalParser = foldParser.getParser();
  foldParser.setFromPivot(true);
  double result;
  try {
    Learner testLearner = learner;
    if (clone) {
      testLearner = (Learner) learner.clone();
      testLearner.doneLearning();
    }

    if (originalParser instanceof ArrayFileParser) {
      ArrayFileParser afp = (ArrayFileParser) originalParser;
      afp.setIncludePruned(true);
      try {
        result = metric.test(testLearner, null, foldParser);
      }
      finally {
        afp.setIncludePruned(false);
      }
    }
    else {
      setIsTraining(false);
      try {
        result =
          metric.test(testLearner, testLearner.getLabeler(), foldParser);
      }
      finally {
        setIsTraining(true);
      }
    }
  }
  finally {
    // Leave the fold parser ready for the next round of training.
    foldParser.reset();
    foldParser.setFromPivot(false);
  }

  if (metric instanceof Accuracy) result *= 100;
  if (statusMessages || progressOutput > 0) {
    double printResult = Math.round(result * 100000) / 100000.0;
    System.out.print(
        " " + learner.name + ": " + messageIndent + "Subset "
        + foldParser.getPivot() + " " + metric.getName() + ": "
        + printResult);
    if (metric instanceof Accuracy) System.out.print("%");
    System.out.println();
  }
  return result;
}
/** <!-- tune(Learner.Parameters[],int[],int,FoldParser.SplitPolicy,double,TestingMetric) -->
 * Tunes learning algorithm parameters using cross validation.  Note that
 * this interface takes both an array of
 * {@link LBJ2.learn.Learner.Parameters} objects and an array of rounds.
 * As such, the value in the {@link LBJ2.learn.Learner.Parameters#rounds}
 * field is ignored during tuning.  It is also overwritten in each of the
 * {@link LBJ2.learn.Learner.Parameters} objects when the optimal number of
 * rounds is determined in terms of the other parameters in each object.
 * Finally, in addition to returning the parameters that got the best
 * performance, this method also sets {@link #learner} with those
 * parameters at the end of the method.
 *
 * <p> This method assumes that {@link #learner} has not yet done any
 * learning.  The learner will again be empty in this sense when the method
 * exits, except that any label lexicon present before the method was
 * called will be restored.  The label lexicon needs to persist in this way
 * so that it can ultimately be written into the model file.
 *
 * @param parameters  An array of parameter settings objects.  Must not be
 *                    empty.
 * @param rounds      An array of training rounds after which performance
 *                    of the learner should be evaluated on the testing
 *                    data.  It is sorted in place by the cross validation.
 * @param k           The number of folds.
 * @param splitPolicy The policy according to which the data is split up.
 * @param alpha       The fraction of the distribution to leave outside
 *                    the confidence interval.  For example, <code>alpha =
 *                    .05</code> gives a 95% confidence interval.
 * @param metric      A metric with which to evaluate the learner.
 * @return The element of <code>parameters</code> that resulted in the best
 *         performance according to <code>metric</code>.
 * @throws IllegalArgumentException If <code>parameters</code> is empty.
**/
public Learner.Parameters tune(Learner.Parameters[] parameters,
                               int[] rounds,
                               int k,
                               FoldParser.SplitPolicy splitPolicy,
                               double alpha,
                               TestingMetric metric) {
  if (parameters.length == 0)
    throw new IllegalArgumentException(
        "LBJ ERROR: BatchTrainer.tune: at least one parameter set must be "
        + "specified.");

  int best = -1;
  String[] parameterStrings = new String[parameters.length];
  double[][] scores = new double[parameters.length][];

  for (int i = 0; i < parameters.length; ++i) {
    parameterStrings[i] = parameters[i].nonDefaultString();
    // Status message.
    if (progressOutput > 0)
      System.out.println(
          " " + learner.name + ": " + messageIndent + "Trying parameters ("
          + parameterStrings[i] + ")");

    learner.setParameters(parameters[i]);
    // Restore the saved indent rather than trimming a fixed number of
    // characters; also restore it if cross validation throws.
    String outerIndent = messageIndent;
    messageIndent = outerIndent + " ";
    double[][] results;
    try {
      // crossValidation sorts rounds in place, so results[j] lines up
      // with rounds[j] below.
      results =
        crossValidation(rounds, k, splitPolicy, alpha, metric, false);
    }
    finally {
      messageIndent = outerIndent;
    }

    // Update best scores, rounds, and parameters.
    int bestRounds = 0;
    if (best == -1 || results[0][0] > scores[best][0]) best = i;
    scores[i] = results[0];
    for (int j = 1; j < results.length; ++j)
      if (results[j][0] > scores[i][0]) {
        bestRounds = j;
        scores[i] = results[j];
        if (results[j][0] > scores[best][0]) best = i;
      }
    parameters[i].rounds = rounds[bestRounds];
  }

  if (progressOutput > 0) {
    // Print a table of results.
    double[][] data = new double[parameters.length][4];
    for (int i = 0; i < parameters.length; ++i) {
      data[i][0] = i + 1;
      data[i][1] = scores[i][0];
      data[i][2] = scores[i][1];
      data[i][3] = parameters[i].rounds;
    }

    String[] columnLabels = { "Set", metric.getName(), "+/-", "Rounds" };
    int[] sigDigits = { 0, 3, 3, 0 };
    String[] s =
      TableFormat.tableFormat(columnLabels, null, data, sigDigits,
                              new int[]{ 0 });

    System.out.println(" " + learner.name + ": " + messageIndent + "----");
    System.out.println(
        " " + learner.name + ": " + messageIndent + "Parameter sets:");
    for (int i = 0; i < parameterStrings.length; ++i)
      System.out.println(
          " " + learner.name + ": " + messageIndent + (i+1) + ": "
          + parameterStrings[i]);
    for (int i = 0; i < s.length; ++i)
      System.out.println(" " + learner.name + ": " + messageIndent + s[i]);
    System.out.println(" " + learner.name + ": " + messageIndent + "----");

    // Status message.
    double bestScore = Math.round(scores[best][0] * 100000) / 100000.0;
    System.out.println(
        " " + learner.name + ": " + messageIndent + "Best "
        + metric.getName() + ": " + bestScore);
    System.out.print(
        " " + learner.name + ": " + messageIndent + "with ");
    if (parameterStrings[best].length() > 0) {
      System.out.println(parameterStrings[best]);
      System.out.print(
          " " + learner.name + ": " + messageIndent + "and ");
    }
    System.out.println(parameters[best].rounds + " rounds");
  }

  learner.setParameters(parameters[best]);
  return parameters[best];
}
/** <!-- tune(Learner.Parameters[],int[],Parser,TestingMetric) -->
 * Tunes learning algorithm parameters against a development set.  Note
 * that this interface takes both an array of
 * {@link LBJ2.learn.Learner.Parameters} objects and an array of rounds.
 * As such, the value in the {@link LBJ2.learn.Learner.Parameters#rounds}
 * field is ignored during tuning.  It is also overwritten in each of the
 * {@link LBJ2.learn.Learner.Parameters} objects when the optimal number of
 * rounds is determined in terms of the other parameters in each object.
 * Finally, in addition to returning the parameters that got the best
 * performance, this method also sets {@link #learner} with those
 * parameters at the end of the method.
 *
 * <p> This method assumes that {@link #learner} has not yet done any
 * learning.  The learner will again be empty in this sense when the method
 * exits, except that any label lexicon present before the method was
 * called will be restored.  The label lexicon needs to persist in this way
 * so that it can ultimately be written into the model file.
 *
 * @param parameters An array of parameter settings objects.  Must not be
 *                   empty.
 * @param rounds     An array of training rounds after which performance of
 *                   the learner should be evaluated on the testing data.
 *                   Must not be empty; it is sorted in place, and
 *                   duplicate entries share a single evaluation.
 * @param devParser  A parser from which development set examples are
 *                   obtained.
 * @param metric     A metric with which to evaluate the learner.
 * @return The element of <code>parameters</code> that resulted in the best
 *         performance according to <code>metric</code>.
 * @throws IllegalArgumentException If <code>parameters</code> or
 *                                  <code>rounds</code> is empty.
**/
public Learner.Parameters tune(Learner.Parameters[] parameters,
                               final int[] rounds,
                               final Parser devParser,
                               final TestingMetric metric) {
  if (parameters.length == 0)
    throw new IllegalArgumentException(
        "LBJ ERROR: BatchTrainer.tune: at least one parameter set must be "
        + "specified.");
  if (rounds.length == 0)
    throw new IllegalArgumentException(
        "LBJ ERROR: BatchTrainer.tune: at least one number of training "
        + "rounds must be specified.");

  int best = -1;
  double[] scores = new double[parameters.length];
  String[] parameterStrings = new String[parameters.length];
  Arrays.sort(rounds);
  final int totalRounds = rounds[rounds.length - 1];
  Lexicon labelLexicon = learner.getLabelLexicon();

  for (int i = 0; i < parameters.length; ++i) {
    parameterStrings[i] = parameters[i].nonDefaultString();
    // Status message.
    if (progressOutput > 0)
      System.out.println(
          " " + learner.name + ": " + messageIndent + "Trying parameters ("
          + parameterStrings[i] + ")");

    final double[] results = new double[rounds.length];
    learner.setParameters(parameters[i]);
    // Restore the saved indent rather than trimming a fixed number of
    // characters; also restore it if training or testing throws.
    String outerIndent = messageIndent;
    messageIndent = outerIndent + " ";
    try {
      train(totalRounds,
            new DoneWithRound() {
              int r = 0;
              public void doneWithRound(int round) {
                if (round >= totalRounds || rounds[r] != round) return;
                double score = testMidTraining(devParser, metric, true);
                // Consume every (duplicate) entry for this round so the
                // cursor never gets stuck on a value that is already past.
                while (r < rounds.length && rounds[r] == round)
                  results[r++] = score;
              }
            });

      // Entries equal to the final round count are filled from one
      // evaluation of the fully trained learner.
      double finalScore = testMidTraining(devParser, metric, false);
      for (int j = rounds.length - 1; j >= 0 && rounds[j] == totalRounds;
           --j)
        results[j] = finalScore;
    }
    finally {
      messageIndent = outerIndent;
    }

    // Update best scores, rounds, and parameters.
    int bestRounds = 0;
    if (best == -1 || results[0] > scores[best]) best = i;
    scores[i] = results[0];
    for (int j = 1; j < results.length; ++j)
      if (results[j] > scores[i]) {
        bestRounds = j;
        scores[i] = results[j];
        if (results[j] > scores[best]) best = i;
      }
    parameters[i].rounds = rounds[bestRounds];

    learner.forget();
    if (labelLexicon != null && labelLexicon.size() > 0
        && learner.getLabelLexicon().size() == 0)
      learner.setLabelLexicon(labelLexicon);
  }

  if (progressOutput > 0) {
    // Print a table of results.
    double[][] data = new double[parameters.length][3];
    for (int i = 0; i < parameters.length; ++i) {
      data[i][0] = i + 1;
      data[i][1] = scores[i];
      data[i][2] = parameters[i].rounds;
    }

    String[] columnLabels = { "Set", metric.getName(), "Rounds" };
    int[] sigDigits = { 0, 3, 0 };
    String[] s =
      TableFormat.tableFormat(columnLabels, null, data, sigDigits,
                              new int[]{ 0 });

    System.out.println(" " + learner.name + ": " + messageIndent + "----");
    System.out.println(
        " " + learner.name + ": " + messageIndent + "Parameter sets:");
    for (int i = 0; i < parameterStrings.length; ++i)
      System.out.println(
          " " + learner.name + ": " + messageIndent + (i+1) + ": "
          + parameterStrings[i]);
    for (int i = 0; i < s.length; ++i)
      System.out.println(" " + learner.name + ": " + messageIndent + s[i]);
    System.out.println(" " + learner.name + ": " + messageIndent + "----");

    // Status message.
    double bestScore = Math.round(scores[best] * 100000) / 100000.0;
    System.out.println(
        " " + learner.name + ": " + messageIndent + "Best "
        + metric.getName() + ": " + bestScore);
    System.out.print(
        " " + learner.name + ": " + messageIndent + "with ");
    if (parameterStrings[best].length() > 0) {
      System.out.println(parameterStrings[best]);
      System.out.print(
          " " + learner.name + ": " + messageIndent + "and ");
    }
    System.out.println(parameters[best].rounds + " rounds");
  }

  learner.setParameters(parameters[best]);
  return parameters[best];
}
/** <!-- testMidTraining(Parser,TestingMetric,boolean) -->
 * Tests {@link #learner} on the specified data, assuming this test happens
 * in between rounds of training.  The learner (or its clone) has
 * <code>doneLearning()</code> called on it before testing.  The testing
 * parser is reset afterwards, even if the test throws.
 *
 * <p> When <code>testParser</code> is an {@link ArrayFileParser}, pruned
 * features are temporarily included and no labeler is passed to the
 * metric.  Otherwise the trainer's training flag is cleared for the test
 * and set back to <code>true</code> afterwards.
 * NOTE(review): the flag is set to <code>true</code> unconditionally, even
 * if it was <code>false</code> before the call; confirm this is intended.
 *
 * @param testParser A parser producing labeled testing examples.
 * @param metric     The metric used to evaluate the performance of the
 *                   learner.
 * @param clone      Whether or not the learner should be cloned (and it
 *                   should be cloned if more learning will take place
 *                   after making this call).
 * @return The result produced by the testing metric on the testing data
 *         expressed as a percentage (instead of a fraction) if the testing
 *         metric is {@link Accuracy}.
**/
protected double testMidTraining(Parser testParser,
                                 TestingMetric metric,
                                 boolean clone) {
  Learner testLearner = clone ? (Learner) learner.clone() : learner;
  testLearner.doneLearning();
  double result;
  try {
    if (testParser instanceof ArrayFileParser) {
      ArrayFileParser afp = (ArrayFileParser) testParser;
      afp.setIncludePruned(true);
      try {
        result = metric.test(testLearner, null, testParser);
      }
      finally {
        afp.setIncludePruned(false);
      }
    }
    else {
      setIsTraining(false);
      try {
        result =
          metric.test(testLearner, testLearner.getLabeler(), testParser);
      }
      finally {
        setIsTraining(true);
      }
    }
  }
  finally {
    // Leave the parser at its start so it can be read again later.
    testParser.reset();
  }

  if (metric instanceof Accuracy) result *= 100;
  if (progressOutput > 0) {
    double printResult = Math.round(result * 100000) / 100000.0;
    System.out.print(
        " " + learner.name + ": " + messageIndent + metric.getName() + ": "
        + printResult);
    if (metric instanceof Accuracy) System.out.print("%");
    System.out.println();
  }
  return result;
}
}