/**
* Copyright 2014, Emory University
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package edu.emory.clir.clearnlp.component.trainer;
import java.io.InputStream;
import java.util.Arrays;
import java.util.List;
import edu.emory.clir.clearnlp.bin.helper.AbstractNLPTrain;
import edu.emory.clir.clearnlp.classification.model.StringModel;
import edu.emory.clir.clearnlp.classification.trainer.AbstractTrainer;
import edu.emory.clir.clearnlp.collection.list.FloatArrayList;
import edu.emory.clir.clearnlp.collection.pair.ObjectDoublePair;
import edu.emory.clir.clearnlp.component.AbstractStatisticalComponent;
import edu.emory.clir.clearnlp.component.configuration.AbstractConfiguration;
import edu.emory.clir.clearnlp.component.evaluation.AbstractEval;
import edu.emory.clir.clearnlp.dependency.DEPTree;
import edu.emory.clir.clearnlp.reader.TSVReader;
import edu.emory.clir.clearnlp.util.BinUtils;
import edu.emory.clir.clearnlp.util.IOUtils;
/**
 * Skeleton for training a statistical NLP component: collects lexicons from the
 * training files, generates training instances, trains the models while scoring
 * them on the development files, and optionally repeats the whole procedure in
 * bootstrap mode until the development score stops improving.
 * <p>
 * Subclasses supply the configuration and the component instances used at each stage.
 *
 * @since 3.0.0
 * @author Jinho D. Choi ({@code jinho.choi@emory.edu})
 */
public abstract class AbstractNLPTrainer
{
    /** Training configuration produced by {@link #createConfiguration(InputStream)}. */
    protected AbstractConfiguration t_configuration;

//  ====================================== CONSTRUCTORS ======================================

    /**
     * Builds the trainer from a configuration stream.
     * <p>
     * Note: {@link #createConfiguration(InputStream)} is invoked from this constructor,
     * i.e., before any subclass field initializer or constructor body has run, so
     * implementations must not depend on subclass state.
     *
     * @param configuration the stream passed to {@link #createConfiguration(InputStream)}.
     */
    public AbstractNLPTrainer(InputStream configuration)
    {
        t_configuration = createConfiguration(configuration);
    }

//  ====================================== TRAINING ======================================

    /**
     * Trains a component and, if the configuration enables bootstrapping and
     * {@code AbstractNLPTrain.d_stop} is not positive, keeps re-training with the
     * previous models until the development score no longer improves.
     *
     * @param trainFiles   the files used to collect lexicons and generate training instances.
     * @param developFiles the files used to score each training round.
     * @return the best component paired with its development score.
     * @throws IllegalStateException if any bootstrap iteration fails; the original
     *         failure is attached as the cause.
     */
    public ObjectDoublePair<AbstractStatisticalComponent<?,?,?,?,?>> train(List<String> trainFiles, List<String> developFiles)
    {
        Object lexicons = getLexicons(trainFiles);
        ObjectDoublePair<AbstractStatisticalComponent<?,?,?,?,?>> prev = train(trainFiles, developFiles, lexicons, null, 0);
        if (!t_configuration.isBootstrap() || AbstractNLPTrain.d_stop > 0) return prev;

        ObjectDoublePair<AbstractStatisticalComponent<?,?,?,?,?>> curr;
        byte[] backup;
        int boot = 1;

        try
        {
            while (true)
            {
                // Snapshot the previous models first: the next round is handed the same
                // model objects, so they must be restorable if that round scores no better.
                backup = prev.o.modelsToByteArray();
                curr = train(trainFiles, developFiles, lexicons, prev.o.getModels(), boot);

                if (prev.d >= curr.d)
                {
                    // No improvement: roll back to the snapshot and stop.
                    prev.o.byteArrayToModels(backup);
                    return prev;
                }

                prev = curr;
                boot++;
            }
        }
        catch (Exception e)
        {
            // Keep the exception type callers already see, but preserve the root cause
            // instead of only printing it.
            throw new IllegalStateException("Bootstrap training failed at iteration "+boot, e);
        }
    }

    /**
     * Runs the lexicon-collecting component (if any) over the training files.
     *
     * @return the collected lexicons, or {@code null} if
     *         {@link #createComponentForCollect()} returns {@code null}.
     */
    private Object getLexicons(List<String> trainFiles)
    {
        AbstractStatisticalComponent<?,?,?,?,?> component = createComponentForCollect();
        if (component == null) return null;

        BinUtils.LOG.info("Collecting lexicons:\n");
        process(component, trainFiles, true);
        return component.getLexicons();
    }

    /**
     * Performs one full training round: generates training instances, then trains
     * and evaluates the resulting models.
     *
     * @param lexicons the lexicons returned by {@link #getLexicons(List)} (may be {@code null}).
     * @param models   the models to bootstrap from, or {@code null} for the initial round.
     * @param boot     the bootstrap iteration (0 for the initial round); used for logging only.
     * @return the evaluation component paired with its development score.
     */
    private ObjectDoublePair<AbstractStatisticalComponent<?,?,?,?,?>> train(List<String> trainFiles, List<String> developFiles, Object lexicons, StringModel[] models, int boot)
    {
        // train
        AbstractStatisticalComponent<?,?,?,?,?> component;

        if (models == null)
            component = createComponentForTrain(lexicons);
        else
            component = createComponentForBootstrap(lexicons, models);

        BinUtils.LOG.info("Generating training instances: "+boot+"\n");
        process(component, trainFiles, true);

        // evaluate
        StringModel[] trained = component.getModels();
        AbstractTrainer[] trainers = t_configuration.getTrainers(trained);
        component = createComponentForEvaluate(lexicons, trained);
        double score = trainPipeline(component, trainers, developFiles);

        return new ObjectDoublePair<AbstractStatisticalComponent<?,?,?,?,?>>(component, score);
    }

//  ====================================== ABSTRACT ======================================

    /** Initializes the training configuration. */
    protected abstract AbstractConfiguration createConfiguration(InputStream in);

    /** Creates an NLP component for collecting lexicons. */
    protected abstract AbstractStatisticalComponent<?,?,?,?,?> createComponentForCollect();

    /** Creates an NLP component for training. */
    protected abstract AbstractStatisticalComponent<?,?,?,?,?> createComponentForTrain(Object lexicons);

    /** Creates an NLP component for bootstrap. */
    protected abstract AbstractStatisticalComponent<?,?,?,?,?> createComponentForBootstrap(Object lexicons, StringModel[] models);

    /** Creates an NLP component for evaluation. */
    protected abstract AbstractStatisticalComponent<?,?,?,?,?> createComponentForEvaluate(Object lexicons, StringModel[] models);

    /** Creates an NLP component for decode. */
    protected abstract AbstractStatisticalComponent<?,?,?,?,?> createComponentForDecode(byte[] models);

//  ====================================== PIPELINE ======================================

    /**
     * Logs every trainer's description and dispatches to the training routine that
     * matches the type of the first trainer.
     * <p>
     * Only {@code trainers[0]} is inspected, so the array must be non-empty and the
     * type of the first trainer is applied to all of them.
     *
     * @return the score reported by the selected routine, or 0 if the trainer type
     *         matches none of the handled cases.
     */
    private double trainPipeline(AbstractStatisticalComponent<?,?,?,?,?> component, AbstractTrainer[] trainers, List<String> developFiles)
    {
        double score = 0;

        for (AbstractTrainer trainer : trainers)
            BinUtils.LOG.info(trainer.trainerInfoFull()+"\n");

        switch (trainers[0].getTrainerType())
        {
        case ONLINE    : score = trainOnline  (component, trainers, developFiles); break;
        case ONE_VS_ALL: score = trainOneVsAll(component, trainers, developFiles); break;
        }

        BinUtils.LOG.info("\n");
        return score;
    }

    /**
     * Trains the models round-robin, one pass per model per iteration, evaluating on
     * the development files after each pass. A model whose pass does not raise the
     * best score so far is rolled back to its last saved weights and is no longer
     * trained; the loop ends when no model improved during an iteration.
     * <p>
     * NOTE(review): an earlier single-trainer version of this routine also stopped
     * as soon as the score reached {@code AbstractNLPTrain.d_stop}; this version has
     * no such check - confirm that is intended.
     *
     * @return the best development score observed (0 if no pass scored above 0).
     */
    private double trainOnline(AbstractStatisticalComponent<?,?,?,?,?> component, AbstractTrainer[] trainers, List<String> developFiles)
    {
        int size = trainers.length;
        FloatArrayList[] bestWeights = new FloatArrayList[size];
        StringModel[] models = component.getModels();
        AbstractEval<?> eval = component.getEval();
        boolean[] active = new boolean[size];
        double currScore, bestScore = 0;
        int improved, iter = -1;

        Arrays.fill(active, true);

        do
        {
            improved = 0;
            iter++;

            for (int i=0; i<size; i++)
            {
                if (!active[i]) continue;

                trainers[i].train();
                eval.clear();
                process(component, developFiles, false);
                currScore = eval.getScore();
                BinUtils.LOG.info(String.format("%3d:%3d: %s\n", iter, i, eval.toString()));

                if (bestScore < currScore)
                {
                    improved++;
                    bestScore = currScore;
                    bestWeights[i] = models[i].getWeightVector().cloneWeights();
                }
                else
                {
                    active[i] = false;

                    // A model that never improved the score has no snapshot; handing
                    // null to setWeights would wipe out its weight vector.
                    if (bestWeights[i] != null)
                        models[i].getWeightVector().setWeights(bestWeights[i]);
                }
            }
        }
        while (improved > 0);

        return bestScore;
    }

    /**
     * Placeholder for one-vs-all training; currently performs no training.
     * <p>
     * TODO: train each trainer, evaluate the component on the development files,
     * log the evaluation, and return its score.
     *
     * @return always 0.
     */
    private double trainOneVsAll(AbstractStatisticalComponent<?,?,?,?,?> component, AbstractTrainer[] trainers, List<String> developFiles)
    {
        return 0;
    }

//  ====================================== PROCESS ======================================

    /**
     * Feeds every file in the list to the component.
     *
     * @param log if {@code true}, logs a "." after each file and a blank line at the end.
     */
    public void process(AbstractStatisticalComponent<?,?,?,?,?> component, List<String> filelist, boolean log)
    {
        for (String filename : filelist)
        {
            process(component, filename);
            if (log) BinUtils.LOG.info(".");
        }

        if (log) BinUtils.LOG.info("\n\n");
    }

    /**
     * Reads every dependency tree from the file with the configured reader and
     * passes it to the component. The reader is closed even if reading or
     * processing fails.
     */
    public void process(AbstractStatisticalComponent<?,?,?,?,?> component, String filename)
    {
        TSVReader reader = (TSVReader)t_configuration.getReader();
        reader.open(IOUtils.createFileInputStream(filename));

        try
        {
            DEPTree tree;

            while ((tree = reader.next()) != null)
            {
                component.process(tree);
            }
        }
        finally
        {
            reader.close();
        }
    }
}