/**
* Copyright (c) 2009, Regents of the University of Colorado All rights
* reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer. Redistributions in binary
* form must reproduce the above copyright notice, this list of conditions and
* the following disclaimer in the documentation and/or other materials provided
* with the distribution. Neither the name of the University of Colorado at
* Boulder nor the names of its contributors may be used to endorse or promote
* products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
*/
package clear.engine;
import clear.decode.OneVsAllDecoder;
import clear.dep.DepNode;
import clear.dep.DepTree;
import clear.ftr.map.DepFtrMap;
import clear.ftr.xml.DepFtrXml;
import clear.model.OneVsAllModel;
import clear.parse.AbstractDepParser;
import clear.parse.AbstractParser;
import clear.parse.ShiftEagerParser;
import clear.parse.ShiftPopParser;
import clear.reader.AbstractReader;
import clear.reader.CoNLLXReader;
import clear.reader.DepReader;
import java.io.Closeable;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintStream;
import org.apache.commons.compress.archivers.jar.JarArchiveEntry;
import org.apache.commons.compress.archivers.jar.JarArchiveOutputStream;
import org.apache.commons.compress.utils.IOUtils;
import org.kohsuke.args4j.CmdLineException;
import org.kohsuke.args4j.CmdLineParser;
import org.kohsuke.args4j.Option;
/**
* Trains a conditional dependency parser. <b>Last update:</b> 11/19/2010
*
* @author Jinho D. Choi
*/
public class DepTrain extends AbstractTrain {
    /** Path of the feature template XML file (command-line option {@code -t}). */
    @Option(name = "-t", usage = "feature template file", required = true, metaVar = "REQUIRED")
    private String s_featureXml = null;

    /** Path of the training data file (command-line option {@code -i}). */
    @Option(name = "-i", usage = "training file", required = true, metaVar = "REQUIRED")
    private String s_trainFile = null;

    /** Number of bootstrapping rounds run after the initial model (command-line option {@code -n}). */
    @Option(name = "-n", usage = "bootstrapping level (default = 2)", required = false, metaVar = "OPTIONAL")
    private int n_boot = 2;

    /** Feature template taken from the parser at the end of the lexicon pass. */
    private DepFtrXml t_xml = null;

    /** Feature map taken from the parser at the end of the instance pass. */
    private DepFtrMap t_map = null;

    /** Most recently trained model; wrapped in a decoder for each bootstrapping pass. */
    private OneVsAllModel m_model = null;

    /** Intentionally empty: this trainer adds nothing in this hook. */
    @Override
    public void initElements() {
    }

    /**
     * Parses the command-line options into this object, then runs {@code init()} and
     * {@link #train()}.
     *
     * <p>Failures are reported on {@code System.err} instead of being propagated: a
     * {@link CmdLineException} prints its message and the usage text, any other exception
     * prints its stack trace.
     *
     * @param args command-line arguments
     */
    public DepTrain(String[] args) {
        CmdLineParser cmd = new CmdLineParser(this);
        try {
            cmd.parseArgument(args);
            init();
            train();
        } catch (CmdLineException e) {
            System.err.println(e.getMessage());
            cmd.printUsage(System.err);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Stores the given settings, then runs {@code init()} and {@link #train()}. Any
     * exception is reported as a stack trace on {@code System.err} and not propagated.
     *
     * @param configFile configuration file path
     * @param featureXml feature template file path
     * @param trainFile  training data file path
     * @param modelFile  output model file path; bootstrapped models are written next to it
     *                   with the suffix {@code .boot<i>}
     * @param nBoot      number of bootstrapping rounds
     */
    public DepTrain(String configFile, String featureXml, String trainFile, String modelFile, int nBoot) {
        s_configFile = configFile;
        s_featureXml = featureXml;
        s_trainFile = trainFile;
        s_modelFile = modelFile;
        n_boot = nBoot;
        try {
            init();
            train();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains the initial model into {@code s_modelFile} and then {@code n_boot}
     * bootstrapped models into {@code s_modelFile + ".boot" + i}.
     *
     * <p>The temporary lexica file {@code ENTRY_LEXICA} is deleted when training ends,
     * whether it succeeded or failed.
     *
     * @throws Exception if any parsing, training, or I/O step fails
     */
    public final void train() throws Exception {
        printConfig();
        try {
            buildModel(s_modelFile, AbstractParser.FLAG_TRAIN_INSTANCE);
            for (int i = 1; i <= n_boot; i++) {
                String bootFile = s_modelFile + ".boot" + i;
                out.print("\n== Bootstrapping: " + i + " ==\n");
                buildModel(bootFile, AbstractParser.FLAG_TRAIN_BOOST);
            }
        } finally {
            // Every input stream on the lexica file has been closed by now, so the
            // delete is not blocked by an open handle.
            new File(ENTRY_LEXICA).delete();
        }
    }

    /**
     * Writes one model archive: runs the parser pass selected by {@code flag}, trains a
     * model from the collected instances, and stores everything in {@code modelFile}.
     * For the instance pass the lexicon pass is run first.
     *
     * @param modelFile path of the archive to create
     * @param flag      {@code AbstractParser.FLAG_TRAIN_INSTANCE} or
     *                  {@code AbstractParser.FLAG_TRAIN_BOOST}
     * @throws Exception if any parsing, training, or I/O step fails
     */
    private void buildModel(String modelFile, byte flag) throws Exception {
        JarArchiveOutputStream zout = new JarArchiveOutputStream(new FileOutputStream(modelFile));
        boolean finished = false;
        try {
            if (flag == AbstractParser.FLAG_TRAIN_INSTANCE) {
                trainDepParser(AbstractParser.FLAG_TRAIN_LEXICON, null);
            }
            trainDepParser(flag, zout);
            // Drop the previous model before training the next one.
            m_model = null;
            m_model = (OneVsAllModel) trainModel(0, zout);
            a_yx = null;
            zout.flush();
            finished = true;
        } finally {
            if (finished) {
                zout.close();
            } else {
                // A close failure here must not hide the exception that aborted training.
                closeQuietly(zout);
            }
        }
    }

    /**
     * Runs one pass of the dependency parser over the training file.
     *
     * <ul>
     * <li>{@code FLAG_TRAIN_LEXICON}: saves the tags to {@code ENTRY_LEXICA} and keeps the
     * parser's feature template.</li>
     * <li>{@code FLAG_TRAIN_INSTANCE}: keeps the parser's training instances and feature
     * map, and writes the parser, feature and lexica entries to {@code zout}.</li>
     * <li>{@code FLAG_TRAIN_BOOST}: same as the instance pass, but the parser is driven by
     * a decoder built on the current model and the feature map is left unchanged.</li>
     * </ul>
     *
     * @param flag one of the {@code AbstractParser.FLAG_TRAIN_*} values listed above
     * @param zout archive receiving the entries; may be {@code null} only for the lexicon pass
     * @throws IllegalArgumentException if the flag, the configured parser algorithm or the
     *                                  data format is not supported, or if {@code zout} is
     *                                  {@code null} for a pass that writes to it
     * @throws Exception                if parsing or I/O fails
     */
    private void trainDepParser(byte flag, JarArchiveOutputStream zout) throws Exception {
        boolean writesArchive = flag == AbstractParser.FLAG_TRAIN_INSTANCE
                || flag == AbstractParser.FLAG_TRAIN_BOOST;
        if (writesArchive && zout == null) {
            // Fail before the expensive pass over the training data, not after it.
            throw new IllegalArgumentException("An output archive is required for training flag " + flag);
        }

        AbstractDepParser parser = createParser(flag);
        AbstractReader<DepNode, DepTree> reader = createReader();
        parser.setLanguage(s_language);
        reader.setLanguage(s_language);

        int n = 0;
        for (DepTree tree = reader.nextTree(); tree != null; tree = reader.nextTree()) {
            parser.parse(tree);
            if (n % 1000 == 0) {
                out.printf("\r- parsing: %dK", n / 1000);
            }
            n++;
        }
        out.println("\r- parsing: " + n);

        if (flag == AbstractParser.FLAG_TRAIN_LEXICON) {
            out.println("- saving");
            parser.saveTags(ENTRY_LEXICA);
            t_xml = parser.getDepFtrXml();
        } else if (writesArchive) {
            a_yx = parser.a_trans;
            writeParserEntries(zout);
            if (flag == AbstractParser.FLAG_TRAIN_INSTANCE) {
                t_map = parser.getDepFtrMap();
            }
        }
    }

    /**
     * Creates the parser for the given training flag and the configured algorithm
     * ({@code s_depParser}), printing the progress header of the pass.
     *
     * @param flag one of the {@code AbstractParser.FLAG_TRAIN_*} values
     * @return the parser, never {@code null}
     * @throws IllegalArgumentException if the flag or the algorithm is not supported
     * @throws Exception                if the parser cannot be constructed
     */
    private AbstractDepParser createParser(byte flag) throws Exception {
        AbstractDepParser parser = null;
        if (flag == AbstractParser.FLAG_TRAIN_LEXICON) {
            out.println("\n* Save lexica");
            switch (s_depParser) {
            case AbstractDepParser.ALG_SHIFT_EAGER:
                parser = new ShiftEagerParser(flag, s_featureXml);
                break;
            case AbstractDepParser.ALG_SHIFT_POP:
                parser = new ShiftPopParser(flag, s_featureXml);
                break;
            }
        } else if (flag == AbstractParser.FLAG_TRAIN_INSTANCE) {
            out.println("\n* Print training instances");
            out.println("- loading lexica");
            switch (s_depParser) {
            case AbstractDepParser.ALG_SHIFT_EAGER:
                parser = new ShiftEagerParser(flag, t_xml, ENTRY_LEXICA);
                break;
            case AbstractDepParser.ALG_SHIFT_POP:
                parser = new ShiftPopParser(flag, t_xml, ENTRY_LEXICA);
                break;
            }
        } else if (flag == AbstractParser.FLAG_TRAIN_BOOST) {
            out.println("\n* Train conditional");
            OneVsAllDecoder decoder = new OneVsAllDecoder(m_model);
            switch (s_depParser) {
            case AbstractDepParser.ALG_SHIFT_EAGER:
                parser = new ShiftEagerParser(flag, t_xml, t_map, decoder);
                break;
            case AbstractDepParser.ALG_SHIFT_POP:
                parser = new ShiftPopParser(flag, t_xml, t_map, decoder);
                break;
            }
        }
        if (parser == null) {
            throw new IllegalArgumentException(
                    "Unsupported dependency parser or training flag: parser=" + s_depParser + ", flag=" + flag);
        }
        return parser;
    }

    /**
     * Opens the training file with the reader matching the configured format
     * ({@code s_format}).
     *
     * @return the reader, never {@code null}
     * @throws IllegalArgumentException if the format is not supported
     * @throws Exception                if the reader cannot be constructed
     */
    private AbstractReader<DepNode, DepTree> createReader() throws Exception {
        // NOTE(review): the reader is never closed; whether AbstractReader offers a
        // close method cannot be told from this file.
        AbstractReader<DepNode, DepTree> reader = null;
        switch (s_format) {
        case AbstractReader.FORMAT_DEP:
            reader = new DepReader(s_trainFile, true);
            break;
        case AbstractReader.FORMAT_CONLLX:
            reader = new CoNLLXReader(s_trainFile, true);
            break;
        }
        if (reader == null) {
            throw new IllegalArgumentException("Unsupported training data format: " + s_format);
        }
        return reader;
    }

    /**
     * Writes the parser algorithm, the feature template file and the lexica file as
     * entries {@code ENTRY_PARSER}, {@code ENTRY_FEATURE} and {@code ENTRY_LEXICA}.
     *
     * @param zout archive to write to; left open
     * @throws IOException if reading a source file or writing an entry fails
     */
    private void writeParserEntries(JarArchiveOutputStream zout) throws IOException {
        zout.putArchiveEntry(new JarArchiveEntry(ENTRY_PARSER));
        // Flushed but deliberately not closed: closing the wrapper would close the archive.
        PrintStream fout = new PrintStream(zout);
        fout.print(s_depParser);
        fout.flush();
        zout.closeArchiveEntry();

        addFileEntry(zout, new JarArchiveEntry(ENTRY_FEATURE), s_featureXml);
        addFileEntry(zout, new JarArchiveEntry(ENTRY_LEXICA), ENTRY_LEXICA);
    }

    /**
     * Copies the content of a file into a new archive entry and closes the file.
     *
     * @param zout  archive to write to; left open
     * @param entry entry to create
     * @param path  path of the file whose bytes become the entry content
     * @throws IOException if the file cannot be read or the entry cannot be written
     */
    private static void addFileEntry(JarArchiveOutputStream zout, JarArchiveEntry entry, String path)
            throws IOException {
        InputStream in = new FileInputStream(path);
        try {
            zout.putArchiveEntry(entry);
            IOUtils.copy(in, zout);
            zout.closeArchiveEntry();
        } finally {
            closeQuietly(in);
        }
    }

    /**
     * Closes a resource, ignoring any {@link IOException} raised by the close itself.
     *
     * @param resource resource to close; {@code null} is ignored
     */
    private static void closeQuietly(Closeable resource) {
        if (resource == null) {
            return;
        }
        try {
            resource.close();
        } catch (IOException ignored) {
            // Only used where a close failure carries no useful information.
        }
    }

    /** Prints the effective training configuration to {@code out}. */
    protected void printConfig() {
        out.println("* Configurations");
        out.println("- language : " + s_language);
        out.println("- format : " + s_format);
        out.println("- parser : " + s_depParser);
        out.println("- feature_xml: " + s_featureXml);
        out.println("- train_file : " + s_trainFile);
        out.println("- model_file : " + s_modelFile);
        out.println("- n_boots : " + n_boot);
    }

    /**
     * Command-line entry point; all work happens in the constructor.
     *
     * @param args command-line arguments
     */
    static public void main(String[] args) {
        new DepTrain(args);
    }
}