/**
* Copyright 2014, Emory University
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package edu.emory.clir.clearnlp.bin.helper;
import java.io.BufferedOutputStream;
import java.io.FileOutputStream;
import java.io.InputStream;
import java.io.ObjectOutputStream;
import java.util.List;
import java.util.concurrent.ExecutionException;
import org.kohsuke.args4j.Option;
import org.kohsuke.args4j.spi.StringArrayOptionHandler;
import org.tukaani.xz.LZMA2Options;
import org.tukaani.xz.XZOutputStream;
import edu.emory.clir.clearnlp.collection.pair.ObjectDoublePair;
import edu.emory.clir.clearnlp.component.AbstractStatisticalComponent;
import edu.emory.clir.clearnlp.component.trainer.AbstractNLPTrainer;
import edu.emory.clir.clearnlp.component.utils.GlobalLexica;
import edu.emory.clir.clearnlp.component.utils.NLPMode;
import edu.emory.clir.clearnlp.util.BinUtils;
import edu.emory.clir.clearnlp.util.FileUtils;
import edu.emory.clir.clearnlp.util.IOUtils;
/**
* @since 3.0.0
* @author Jinho D. Choi ({@code jinho.choi@emory.edu})
*/
public abstract class AbstractNLPTrain
{
/**
* Path of the configuration file, bound to {@code -c} (required).
* NOTE(review): the usage text reads "confinguration" -- looks like a typo for "configuration"; confirm before changing the help output.
*/
@Option(name="-c", usage="confinguration file (required)", required=true, metaVar="<filename>")
protected String s_configurationFile;
/** Paths of the feature template files, bound to {@code -f} (required); parsed as a string array. */
@Option(name="-f", usage="feature template files (required)", required=true, metaVar="<filename>", handler=StringArrayOptionHandler.class)
protected String[] s_featureFiles;
/** Output path of the trained model, bound to {@code -m} (optional); when left {@code null} no model is saved by the argument-parsing constructor. */
@Option(name="-m", usage="model filename (optional)", required=false, metaVar="<filename>")
protected String s_modelPath = null;
/** Path from which the training files are listed, bound to {@code -t} (required). */
@Option(name="-t", usage="training path (required)", required=true, metaVar="<filepath>")
protected String s_trainPath;
/** Path from which the development files are listed, bound to {@code -d} (required). */
@Option(name="-d", usage="development path (required)", required=true, metaVar="<filepath>")
protected String s_developPath;
/** Extension passed to the file listing for training files, bound to {@code -te}; defaults to {@code "*"}. */
@Option(name="-te", usage="training file extension (default: *)", required=false, metaVar="<string>")
protected String s_trainExt = "*";
/** Extension passed to the file listing for development files, bound to {@code -de}; defaults to {@code "*"}. */
@Option(name="-de", usage="development file extension (default: *)", required=false, metaVar="<string>")
protected String s_developExt = "*";
/**
* Name of the NLP mode, bound to {@code -mode} (required); converted with {@code NLPMode.valueOf} by the argument-parsing constructor.
* NOTE(review): the initializer {@code ".*"} is only visible when arguments are not parsed, since the option is required -- confirm it is intentional.
*/
@Option(name="-mode", usage="pos|dep|ner|srl", required=true, metaVar="<mode>")
protected String s_mode = ".*";
// @Option(name="-threads", usage="number of threads (default: 1)", required=false, metaVar="<Integer>")
// protected int n_threads = 1;
/** Stopping score for training, bound to {@code -stop}; defaults to 0. Declared static, so the value is shared by all instances. */
@Option(name="-stop", usage="stopping score for training", required=false, metaVar="<double>")
static public double d_stop = 0;
/** Creates an instance without parsing any arguments; the option fields keep their declared initial values. */
public AbstractNLPTrain() {}
/**
* Parses the command-line arguments into the option fields, trains a component
* on the listed training/development files, logs the final score, and saves the
* trained component when a model path was supplied.
* @param args the command-line arguments to bind to the {@code @Option} fields.
* @throws InterruptedException declared by the signature; not visibly raised by this body.
* @throws ExecutionException declared by the signature; not visibly raised by this body.
*/
public AbstractNLPTrain(String[] args) throws InterruptedException, ExecutionException
{
	BinUtils.initArgs(args, this);

	// Resolve the inputs in the same order as before: training files, development files, then the mode.
	final List<String> trainingFiles    = FileUtils.getFileList(s_trainPath, s_trainExt, false);
	final List<String> developmentFiles = FileUtils.getFileList(s_developPath, s_developExt, false);
	final NLPMode nlpMode = NLPMode.valueOf(s_mode);
	// Collections.sort(trainFiles);

	final ObjectDoublePair<AbstractStatisticalComponent<?,?,?,?,?>> result =
			train(trainingFiles, developmentFiles, s_featureFiles, s_configurationFile, nlpMode);

	final String scoreMessage = String.format("Final score: %4.2f\n", result.d);
	BinUtils.LOG.info(scoreMessage);

	// Nothing to persist unless the optional -m argument was given.
	if (s_modelPath == null)
	{
		return;
	}

	saveModel(result.o, s_modelPath);
}
/**
* Trains a component for the given mode and returns what the trainer produces.
* @param trainFiles the training files handed to the trainer.
* @param developFiles the development files handed to the trainer.
* @param featureFiles the feature template files; one input stream is opened per file.
* @param configurationFile the configuration file; it is opened twice, once for the trainer and once for {@code GlobalLexica}.
* @param mode the NLP mode used to select the trainer via {@link #getTrainer}.
* @return the pair returned by the trainer's {@code train} call.
*/
public ObjectDoublePair<AbstractStatisticalComponent<?,?,?,?,?>> train(List<String> trainFiles, List<String> developFiles, String[] featureFiles, String configurationFile, NLPMode mode)
{
	final InputStream trainerConfig    = IOUtils.createFileInputStream(configurationFile);
	final InputStream[] featureStreams = IOUtils.createFileInputStreams(featureFiles);

	// The lexica get their own stream over the same file -- presumably so the trainer's stream stays unread; confirm.
	final InputStream lexicaConfig = IOUtils.createFileInputStream(configurationFile);
	GlobalLexica.init(lexicaConfig);

	return getTrainer(mode, trainerConfig, featureStreams).train(trainFiles, developFiles);
}
/**
* Writes the given component to {@code modelPath} as an XZ-compressed object stream.
* The stream is managed by a try-with-resources statement, so it is closed whether
* or not {@code component.save(out)} completes normally; previously a failure in
* {@code save} skipped the close and left the file handle open.
* Failures are reported with a stack trace and are not propagated to the caller.
* @param component the trained component to serialize.
* @param modelPath the path of the model file to create.
*/
public void saveModel(AbstractStatisticalComponent<?,?,?,?,?> component, String modelPath)
{
	try (ObjectOutputStream out = new ObjectOutputStream(new XZOutputStream(new BufferedOutputStream(new FileOutputStream(modelPath)), new LZMA2Options())))
	{
		component.save(out);
	}
	// Kept best-effort as before: report the failure without aborting the caller.
	catch (Exception e) {e.printStackTrace();}
}
/**
* Supplies the trainer used by {@link #train}; implemented by subclasses.
* @param mode the NLP mode the trainer is requested for.
* @param configuration an input stream opened over the configuration file.
* @param features input streams opened over the feature template files.
* @return the trainer whose {@code train} method is invoked with the training and development files.
*/
protected abstract AbstractNLPTrainer getTrainer(NLPMode mode, InputStream configuration, InputStream[] features);
// public NLPTrain(String[] args) throws InterruptedException, ExecutionException
// {
// BinUtils.initArgs(args, this);
//
// List<String> trainFiles = FileUtils.getFileList(s_trainPath , s_trainExt , false);
// List<String> developFiles = FileUtils.getFileList(s_developPath, s_developExt, false);
// NLPMode mode = NLPMode.valueOf(s_mode);
//
// List<Callable<ObjectObjectDoubleTriple<AbstractStatisticalComponent<?,?,?,?,?>,String>>> tasks = new ArrayList<>();
// Callable<ObjectObjectDoubleTriple<AbstractStatisticalComponent<?,?,?,?,?>,String>> c;
// ExecutorService executor = Executors.newFixedThreadPool(n_threads);
//
// for (String configurationFile : s_configurationFiles)
// {
// for (String featureFile : s_featureFiles)
// {
// System.out.println(featureFile);
//
// c = new Callable<ObjectObjectDoubleTriple<AbstractStatisticalComponent<?,?,?,?,?>,String>>()
// {
// @Override
// public ObjectObjectDoubleTriple<AbstractStatisticalComponent<?,?,?,?,?>,String> call() throws Exception
// {
// final ObjectDoublePair<AbstractStatisticalComponent<?,?,?,?,?>> p = train(trainFiles, developFiles, Splitter.splitColons(featureFile), configurationFile, mode);
// return new ObjectObjectDoubleTriple<AbstractStatisticalComponent<?,?,?,?,?>,String>(p.o, FileUtils.getBaseName(configurationFile)+", "+FileUtils.getBaseName(featureFile), p.d);
// }
// };
//
// tasks.add(c);
// }
// }
//
// List<Future<ObjectObjectDoubleTriple<AbstractStatisticalComponent<?,?,?,?,?>,String>>> futures = executor.invokeAll(tasks);
// ObjectObjectDoubleTriple<AbstractStatisticalComponent<?,?,?,?,?>,String> max = null, t;
// int i, size = futures.size();
//
// for (i=0; i<size; i++)
// {
// t = futures.get(i).get();
// System.out.printf("%s: %5.2f\n", t.o2, t.d);
// if (max == null || max.compareTo(t) < 0) max = t;
// }
//
// executor.shutdown();
// if (size > 1) BinUtils.LOG.info(String.format("Best\n%s: %5.2f\n", max.o2, max.d));
// if (s_modelPath != null) saveModel(max.o1, s_modelPath);
// }
// void onlineTrain()
// {
// try
// {
// DefaultPOSTagger tagger = new DefaultPOSTagger(new ObjectInputStream(new XZInputStream(new BufferedInputStream(new FileInputStream(s_modelPath)))));
// for (DEPTree tree : getTrees())
// {
// tagger.process(tree);
// System.out.println(tree.toStringPOS()+"\n");
// }
// tagger.onlineTrain(getTrees());
// System.out.println("---------------------------\n");
// for (DEPTree tree : getTrees())
// {
// tagger.process(tree);
// System.out.println(tree.toStringPOS()+"\n");
// }
// }
// catch (Exception e) {e.printStackTrace();}
// }
//
// private List<DEPTree> getTrees()
// {
// List<DEPTree> list = Lists.newArrayList();
// DEPTree tree;
//
// tree = new DEPTree(5);
// tree.add(new DEPNode(1, "mr.", "NNP", new DEPFeat()));
// tree.add(new DEPNode(2, "boom", "NNP", new DEPFeat()));
// tree.add(new DEPNode(3, "toissed", "VBD", new DEPFeat()));
// tree.add(new DEPNode(4, "paat", "JJ", new DEPFeat()));
// tree.add(new DEPNode(5, "balll", "NN", new DEPFeat()));
// list.add(tree);
//
// tree = new DEPTree(4);
// tree.add(new DEPNode(1, "John", "NNP", new DEPFeat()));
// tree.add(new DEPNode(2, "bought", "VBD", new DEPFeat()));
// tree.add(new DEPNode(3, "a", "DT", new DEPFeat()));
// tree.add(new DEPNode(4, "car", "NN", new DEPFeat()));
// list.add(tree);
//
// return list;
// }
}