/**
* Copyright 2014, Emory University
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package edu.emory.clir.clearnlp.component;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.List;
import org.tukaani.xz.LZMA2Options;
import org.tukaani.xz.XZInputStream;
import org.tukaani.xz.XZOutputStream;
import edu.emory.clir.clearnlp.classification.instance.StringInstance;
import edu.emory.clir.clearnlp.classification.model.StringModel;
import edu.emory.clir.clearnlp.classification.trainer.AbstractOnlineTrainer;
import edu.emory.clir.clearnlp.classification.trainer.AdaGradSVM;
import edu.emory.clir.clearnlp.classification.vector.StringFeatureVector;
import edu.emory.clir.clearnlp.component.configuration.AbstractConfiguration;
import edu.emory.clir.clearnlp.component.evaluation.AbstractEval;
import edu.emory.clir.clearnlp.component.state.AbstractState;
import edu.emory.clir.clearnlp.component.utils.CFlag;
import edu.emory.clir.clearnlp.dependency.DEPTree;
import edu.emory.clir.clearnlp.feature.AbstractFeatureExtractor;
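/**
* Base class for components backed by statistical models: it holds the configuration, feature extractors,
* models, evaluator and mode flag, and provides model (de)serialization and online-training helpers.
* @since 3.0.0
* @author Jinho D. Choi ({@code jinho.choi@emory.edu})
*/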
abstract public class AbstractStatisticalComponent<LabelType, StateType extends AbstractState<?,LabelType>, EvalType extends AbstractEval<?>, FeatureType extends AbstractFeatureExtractor<?,?,?>, ConfigurationType extends AbstractConfiguration> extends AbstractComponent
{
/** Configuration of this component; assigned via {@link #setConfiguration}. */
protected ConfigurationType t_configuration;
/** Feature extractors; assigned via {@link #setFeatureExtractors} and written first by {@link #save}. */
protected FeatureType[] f_extractors;
/** Statistical models; assigned via {@link #setModels}. */
protected StringModel[] s_models;
/** Evaluator whose score is read by {@link #onlineScore}; presumably created by {@link #initEval()} (confirm in subclasses). */
protected EvalType c_eval;
/** Current mode of this component (collect, train, bootstrap, evaluate or decode). */
protected CFlag c_flag;
/** Constructs an empty component; no configuration, flag, extractors or models are assigned here. */
public AbstractStatisticalComponent()
{
}
/**
 * Constructs a statistical component in {@link CFlag#COLLECT} mode.
 * @param configuration the configuration to be used by this component.
 */
public AbstractStatisticalComponent(ConfigurationType configuration)
{
    this.setConfiguration(configuration);
    this.setFlag(CFlag.COLLECT);
}
/** Constructs a statistical component for train. */
public AbstractStatisticalComponent(ConfigurationType configuration, FeatureType[] extractors, Object lexicons, boolean binary, int modelSize)
{
setConfiguration(configuration);
setFlag(CFlag.TRAIN);
setFeatureExtractors(extractors);
setLexicons(lexicons);
setModels(createModels(binary, modelSize));
}
/**
 * Constructs a statistical component for bootstrapping or evaluation using existing models.
 * @param configuration the configuration to be used by this component.
 * @param extractors the feature extractors.
 * @param lexicons the lexicons, handed to {@link #setLexicons(Object)}.
 * @param models the existing models.
 * @param bootstrap if {@code true}, the mode is {@link CFlag#BOOTSTRAP};
 * otherwise, the mode is {@link CFlag#EVALUATE} and {@link #initEval()} is called.
 */
public AbstractStatisticalComponent(ConfigurationType configuration, FeatureType[] extractors, Object lexicons, StringModel[] models, boolean bootstrap)
{
    this.setConfiguration(configuration);
    this.setFlag(bootstrap ? CFlag.BOOTSTRAP : CFlag.EVALUATE);

    // The evaluator is initialized before the extractors, lexicons and models are assigned.
    if (!bootstrap)
        initEval();

    this.setFeatureExtractors(extractors);
    this.setLexicons(lexicons);
    this.setModels(models);
}
/**
 * Constructs a statistical component for decoding, restoring its state from an object stream.
 * @param configuration the configuration to be used by this component.
 * @param in the stream handed to {@link #initDecode(ObjectInputStream)}.
 */
public AbstractStatisticalComponent(ConfigurationType configuration, ObjectInputStream in)
{
    this.setConfiguration(configuration);
    this.initDecode(in);
}
/**
 * Constructs a statistical component for decoding, restoring its state from a byte array.
 * @param configuration the configuration to be used by this component.
 * @param models the bytes handed to {@link #initDecode(byte[])}.
 */
public AbstractStatisticalComponent(ConfigurationType configuration, byte[] models)
{
    this.setConfiguration(configuration);
    this.initDecode(models);
}
/**
 * Creates an array of new models.
 * @param binary passed to the constructor of every {@link StringModel}.
 * @param modelSize the length of the array to create.
 * @return the array filled with {@code modelSize} new models.
 */
private StringModel[] createModels(boolean binary, int modelSize)
{
    final StringModel[] created = new StringModel[modelSize];
    int index = 0;

    while (index < modelSize)
    {
        created[index] = new StringModel(binary);
        index++;
    }

    return created;
}
/**
 * Switches this component to {@link CFlag#DECODE} mode and loads its state from the stream.
 * Any exception thrown while loading is printed to the standard error and not propagated.
 * @param in the stream passed to {@link #load(ObjectInputStream)}.
 */
protected void initDecode(ObjectInputStream in)
{
    this.setFlag(CFlag.DECODE);

    try
    {
        this.load(in);
    }
    catch (Exception e)
    {
        e.printStackTrace();
    }
}
/**
 * Initializes this component for decoding from an XZ-compressed byte array,
 * such as the one produced by {@link #toByteArray()}.
 * The stream chain opened over the bytes is closed once loading has finished or failed.
 * Any {@link IOException} raised while opening or closing the stream is printed to the standard error
 * and not propagated.
 * @param models the compressed bytes to read from.
 */
protected void initDecode(byte[] models)
{
    // try-with-resources closes the object stream, and with it the XZ decoder, on every path
    try (ObjectInputStream ois = new ObjectInputStream(new XZInputStream(new BufferedInputStream(new ByteArrayInputStream(models)))))
    {
        initDecode(ois);
    }
    catch (IOException e) {e.printStackTrace();}
}
// ====================================== CONFIGURATION ======================================
/**
 * Sets the configuration of this component.
 * @param configuration the configuration to store.
 */
public void setConfiguration(ConfigurationType configuration)
{
    this.t_configuration = configuration;
}
// ====================================== LOAD/SAVE ======================================
/**
 * Loads the feature extractors, the lexicons and the models of this component, in that order,
 * which is the order {@link #save(ObjectOutputStream)} writes them in.
 * NOTE(review): this relies on Java object deserialization; presumably the stream comes from a trusted source.
 * @param in the stream to read from.
 * @throws Exception if reading from the stream or reconstructing any object fails.
 */
@SuppressWarnings("unchecked")
public void load(ObjectInputStream in) throws Exception
{
    final FeatureType[] extractors = (FeatureType[])in.readObject();
    setFeatureExtractors(extractors);

    final Object lexicons = in.readObject();
    setLexicons(lexicons);

    final StringModel[] models = loadModels(in);
    setModels(models);
}
/**
 * Saves the feature extractors, the lexicons and the models of this component, in that order.
 * @param out the stream to write to.
 * @throws Exception if writing to the stream fails.
 */
public void save(ObjectOutputStream out) throws Exception
{
    out.writeObject(this.f_extractors);

    final Object lexicons = getLexicons();
    out.writeObject(lexicons);

    saveModels(out);
}
/**
 * Reads the number of models followed by each model from the stream.
 * @param in the stream to read from.
 * @return the models read from the stream.
 * @throws Exception if reading the count or constructing any model fails.
 */
private StringModel[] loadModels(ObjectInputStream in) throws Exception
{
    final int count = in.readInt();
    final StringModel[] loaded = new StringModel[count];
    int index = 0;

    while (index < count)
    {
        loaded[index] = new StringModel(in);
        index++;
    }

    return loaded;
}
/**
 * Writes the number of models followed by each model to the stream.
 * @param out the stream to write to.
 * @throws Exception if writing the count or saving any model fails.
 */
private void saveModels(ObjectOutputStream out) throws Exception
{
    final StringModel[] models = this.s_models;
    out.writeInt(models.length);

    for (int index = 0; index < models.length; index++)
        models[index].save(out);
}
/**
 * Serializes this component (everything written by {@link #save(ObjectOutputStream)})
 * into an XZ-compressed byte array.
 * The output stream is closed even if saving fails, so the compressor is always released.
 * @return the compressed bytes.
 * @throws Exception if saving or compressing fails.
 */
public byte[] toByteArray() throws Exception
{
    ByteArrayOutputStream bos = new ByteArrayOutputStream();

    try (ObjectOutputStream oos = new ObjectOutputStream(new XZOutputStream(new BufferedOutputStream(bos), new LZMA2Options())))
    {
        save(oos);
    }

    // The bytes are read only after the stream is closed, so the compressed output is complete
    return bos.toByteArray();
}
/**
 * Serializes only the models of this component into an XZ-compressed byte array.
 * The number of models is not written; {@link #byteArrayToModels(byte[])} reads the bytes back
 * into the models already held by this component.
 * The output stream is closed even if saving fails, so the compressor is always released.
 * @return the compressed bytes.
 * @throws Exception if saving or compressing fails.
 */
public byte[] modelsToByteArray() throws Exception
{
    ByteArrayOutputStream bos = new ByteArrayOutputStream();

    try (ObjectOutputStream oos = new ObjectOutputStream(new XZOutputStream(new BufferedOutputStream(bos), new LZMA2Options())))
    {
        for (StringModel model : s_models)
            model.save(oos);
    }

    // The bytes are read only after the stream is closed, so the compressed output is complete
    return bos.toByteArray();
}
/**
 * Loads the models already held by this component, in array order, from an XZ-compressed byte array
 * such as the one produced by {@link #modelsToByteArray()}.
 * The input stream is closed even if loading fails.
 * @param bytes the compressed bytes to read from.
 * @throws Exception if decompressing or loading any model fails.
 */
public void byteArrayToModels(byte[] bytes) throws Exception
{
    try (ObjectInputStream oin = new ObjectInputStream(new XZInputStream(new BufferedInputStream(new ByteArrayInputStream(bytes)))))
    {
        for (StringModel model : s_models)
            model.load(oin);
    }
}
// ====================================== LEXICONS ======================================
/**
 * Returns the lexicons of this component as a single object; the returned object is what
 * {@link #save(ObjectOutputStream)} writes to the stream.
 * @return all objects containing lexicons.
 */
abstract public Object getLexicons();
/**
 * Sets the lexicons used for this component.
 * Called from the constructors and from {@link #load(ObjectInputStream)} with the deserialized object.
 * @param lexicons the object containing lexicons.
 */
abstract public void setLexicons(Object lexicons);
// ====================================== FEATURES ======================================
/** @return the feature extractors held by this component (the internal array, not a copy). */
public FeatureType[] getFeatureExtractors()
{
    return this.f_extractors;
}
/**
 * Sets the feature extractors of this component; the array is stored as is, not copied.
 * @param features the feature extractors to store.
 */
public void setFeatureExtractors(FeatureType[] features)
{
    this.f_extractors = features;
}
// ====================================== MODELS ======================================
/**
 * @param index the position of the model in the model array.
 * @return the model at the specific index.
 */
public StringModel getModel(int index)
{
    final StringModel[] models = this.s_models;
    return models[index];
}
/** @return the models held by this component (the internal array, not a copy). */
public StringModel[] getModels()
{
    return this.s_models;
}
/**
 * Sets the models of this component; the array is stored as is, not copied.
 * @param models the models to store.
 */
public void setModels(StringModel[] models)
{
    this.s_models = models;
}
// ====================================== PROCESS ======================================
/**
 * Advances the state until it terminates, choosing each label according to the current flag:
 * training and bootstrapping use their own methods, and every other flag decodes.
 * @param state the processing state to advance.
 * @return the instances collected during training or bootstrapping; {@code null} for any other flag.
 */
protected List<StringInstance> process(StateType state)
{
    List<StringInstance> collected = null;

    if (isTrainOrBootstrap())
        collected = new ArrayList<>();

    while (!state.isTerminate())
    {
        final LabelType next;

        switch (c_flag)
        {
        case TRAIN:
            next = train(state, collected);
            break;
        case BOOTSTRAP:
            next = bootstrap(state, collected);
            break;
        default:
            next = decode(state);
            break;
        }

        state.next(next);
    }

    return collected;
}
/**
 * Creates a training instance from the current state using the gold label.
 * No instance is added when the feature vector is empty.
 * @param state the current processing state.
 * @param instances the list to which the new instance is added.
 * @return the gold label of the current state.
 */
protected LabelType train(StateType state, List<StringInstance> instances)
{
    final StringFeatureVector features = createStringFeatureVector(state);
    final LabelType gold = state.getGoldLabel();

    if (features.isEmpty())
        return gold;

    instances.add(new StringInstance(gold.toString(), features));
    return gold;
}
/**
 * Creates a training instance from the current state using the gold label,
 * but returns the automatically predicted label instead of the gold one.
 * No instance is added when the feature vector is empty.
 * @param state the current processing state.
 * @param instances the list to which the new instance is added.
 * @return the label given by {@link #getAutoLabel}.
 */
protected LabelType bootstrap(StateType state, List<StringInstance> instances)
{
    final StringFeatureVector features = createStringFeatureVector(state);
    final LabelType gold = state.getGoldLabel();
    final boolean hasFeatures = !features.isEmpty();

    if (hasFeatures)
    {
        final String goldName = gold.toString();
        instances.add(new StringInstance(goldName, features));
    }

    return getAutoLabel(state, features);
}
/**
 * Predicts the label of the current state from its feature vector.
 * @param state the current processing state.
 * @return the label given by {@link #getAutoLabel}.
 */
protected LabelType decode(StateType state)
{
    return getAutoLabel(state, createStringFeatureVector(state));
}
/**
 * Creates the feature vector for the current state.
 * Called once per step by {@link #train}, {@link #bootstrap} and {@link #decode}.
 * @param state the current processing state.
 * @return the feature vector; an empty vector causes no training instance to be added.
 */
abstract protected StringFeatureVector createStringFeatureVector(StateType state);
/**
 * Returns the label to apply to the current state when bootstrapping or decoding;
 * presumably predicted by the models from the feature vector (confirm in subclasses).
 * @param state the current processing state.
 * @param vector the feature vector created for the state.
 * @return the label passed on to the state as its next label.
 */
abstract protected LabelType getAutoLabel(StateType state, StringFeatureVector vector);
// ====================================== EVAL ======================================
/** @return the evaluator currently held by this component. */
public EvalType getEval()
{
    return this.c_eval;
}
/**
 * Initializes the evaluator; presumably assigns a fresh evaluator to {@code c_eval} (confirm in subclasses).
 * Called by the evaluation constructor and at the start of every {@link #onlineScore} call.
 */
abstract protected void initEval();
// ====================================== FLAG ======================================
/** @return the current mode of this component. */
public CFlag getFlag()
{
    return this.c_flag;
}
/**
 * Sets the mode of this component.
 * @param flag the mode to switch to.
 */
public void setFlag(CFlag flag)
{
    this.c_flag = flag;
}
/** @return {@code true} if this component is in {@link CFlag#COLLECT} mode. */
public boolean isCollect()
{
    return CFlag.COLLECT == this.c_flag;
}
/** @return {@code true} if this component is in {@link CFlag#TRAIN} mode. */
public boolean isTrain()
{
    return CFlag.TRAIN == this.c_flag;
}
/** @return {@code true} if this component is in {@link CFlag#BOOTSTRAP} mode. */
public boolean isBootstrap()
{
    return CFlag.BOOTSTRAP == this.c_flag;
}
/** @return {@code true} if this component is in {@link CFlag#EVALUATE} mode. */
public boolean isEvaluate()
{
    return CFlag.EVALUATE == this.c_flag;
}
/** @return {@code true} if this component is in {@link CFlag#DECODE} mode. */
public boolean isDecode()
{
    return CFlag.DECODE == this.c_flag;
}
/** @return {@code true} if this component is in either train or bootstrap mode. */
public boolean isTrainOrBootstrap()
{
    if (isTrain())
        return true;

    return isBootstrap();
}
/** @return {@code true} if this component is in either decode or evaluate mode. */
public boolean isDecodeOrEvaluate()
{
    if (isDecode())
        return true;

    return isEvaluate();
}
// ====================================== ONLINE TRAIN ======================================
/**
 * Trains this component online on the given trees.
 * NOTE(review): implementations presumably delegate to a helper such as
 * {@link #onlineTrainSingleAdaGrad(List)}; confirm in subclasses.
 * @param trees the trees to train on.
 */
abstract public void onlineTrain(List<DEPTree> trees);
/**
 * Trains the first model online with AdaGrad, repeating while the score on the given trees keeps improving.
 * A snapshot of this component is taken before every training pass; once a pass fails to improve the score,
 * the snapshot taken before that pass is restored.
 * Nothing is done when the current score is already 100.
 * Any exception thrown during training is printed to the standard error and not propagated.
 * NOTE(review): the snapshot is restored through {@link #initDecode(byte[])}, which also switches the flag
 * to {@link CFlag#DECODE}; confirm that callers expect this.
 * @param trees the gold-standard trees used both for bootstrapping and for scoring.
 */
protected void onlineTrainSingleAdaGrad(List<DEPTree> trees)
{
    // Measure how accurately the current model performs on the gold-standard trees
    double score = onlineScore(trees);
    if (score == 100) return;

    onlineBootstrap(trees);
    final AbstractOnlineTrainer adagrad = new AdaGradSVM(s_models[0], 0, 0, false, 0.01, 0.1, 0d);

    try
    {
        byte[] snapshot;
        double lastScore;

        do
        {
            snapshot = toByteArray();
            lastScore = score;
            adagrad.train();
            score = onlineScore(trees);
        }
        // Written as a negation so that a NaN score keeps the loop going
        while (!(lastScore >= score));

        // The last pass did not improve the score: roll back to the state before it
        initDecode(snapshot);
    }
    catch (Exception e)
    {
        e.printStackTrace();
    }
}
/**
 * Evaluates this component on the given trees and returns the resulting score.
 * The flag is temporarily switched to {@link CFlag#EVALUATE} and is restored afterwards,
 * even if processing a tree throws, so a failure cannot leave the component stuck in evaluate mode.
 * @param trees the trees to evaluate on.
 * @return the score reported by the evaluator after all trees are processed.
 */
protected double onlineScore(List<DEPTree> trees)
{
    CFlag originalFlag = c_flag;
    c_flag = CFlag.EVALUATE;

    try
    {
        // Start from a freshly initialized evaluator so the score covers only these trees
        initEval();

        for (DEPTree tree : trees)
            process(tree);
    }
    finally
    {
        c_flag = originalFlag;
    }

    return c_eval.getScore();
}
/**
 * Processes the given trees in {@link CFlag#BOOTSTRAP} mode, updating the lexicons with each tree
 * before it is processed.
 * The flag is restored afterwards, even if a tree fails to process,
 * so a failure cannot leave the component stuck in bootstrap mode.
 * @param trees the trees to bootstrap on.
 */
protected void onlineBootstrap(List<DEPTree> trees)
{
    CFlag originalFlag = c_flag;
    c_flag = CFlag.BOOTSTRAP;

    try
    {
        for (DEPTree tree : trees)
        {
            onlineLexicons(tree);
            process(tree);
        }
    }
    finally
    {
        c_flag = originalFlag;
    }
}
/**
 * Hook called by {@link #onlineBootstrap(List)} for each tree right before the tree is processed;
 * presumably updates the lexicons with entries from the tree (confirm in subclasses).
 * @param tree the tree about to be processed.
 */
abstract protected void onlineLexicons(DEPTree tree);
}