/**
* Copyright 2014, Emory University
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package edu.emory.clir.clearnlp.classification.trainer;
import java.util.List;
import java.util.Objects;

import edu.emory.clir.clearnlp.classification.instance.IntInstance;
import edu.emory.clir.clearnlp.classification.model.SparseModel;
import edu.emory.clir.clearnlp.classification.model.StringModel;
import edu.emory.clir.clearnlp.classification.vector.AbstractWeightVector;
/**
 * Base class for trainers that operate on a classification model's training
 * instances and weight vector.
 * <p>
 * A trainer is constructed from either a {@link SparseModel} or a {@link StringModel}.
 * Construction first calls the model's {@code initializeForTraining} method to obtain
 * the training instances, then captures the model's weight vector. Subclasses
 * implement {@link #train()} and {@link #trainerInfo()} on top of those fields.
 *
 * @since 3.0.0
 * @author Jinho D. Choi ({@code jinho.choi@emory.edu})
 */
public abstract class AbstractTrainer
{
	/**
	 * Fixed seed value shared by all trainers. It is presumably meant for reproducible
	 * randomization in subclasses; confirm against the subclasses that use it.
	 */
	// static: a compile-time constant has no reason to be stored once per trainer instance.
	protected static final int RANDOM_SEED = 5;

	/** The trainer type supplied at construction. */
	protected final TrainerType t_type;

	/** Training instances returned by the model's {@code initializeForTraining} call. */
	protected List<IntInstance> l_instances;

	/**
	 * Weight vector obtained from the model at construction. It is declared
	 * {@code volatile}, so reassignments of this reference are visible to other threads.
	 */
	volatile protected AbstractWeightVector w_vector;

	/**
	 * Creates a trainer over a sparse model.
	 *
	 * @param type  the trainer type, returned by {@link #getTrainerType()}
	 * @param model the model whose training instances and weight vector are used
	 * @throws NullPointerException if {@code model} is {@code null}
	 */
	public AbstractTrainer(TrainerType type, SparseModel model)
	{
		// Constructor arguments are evaluated left to right, so initializeForTraining()
		// runs before getWeightVector(). Keep that order in case initialization
		// affects the weight vector (not verifiable from this file).
		this(type, Objects.requireNonNull(model, "model").initializeForTraining(), model.getWeightVector());
	}

	/**
	 * Creates a trainer over a string-feature model.
	 *
	 * @param type          the trainer type, returned by {@link #getTrainerType()}
	 * @param model         the model whose training instances and weight vector are used
	 * @param labelCutoff   forwarded to {@code model.initializeForTraining(int, int)}
	 *                      (presumably a minimum label frequency; confirm in {@link StringModel})
	 * @param featureCutoff forwarded to {@code model.initializeForTraining(int, int)}
	 *                      (presumably a minimum feature frequency; confirm in {@link StringModel})
	 * @throws NullPointerException if {@code model} is {@code null}
	 */
	public AbstractTrainer(TrainerType type, StringModel model, int labelCutoff, int featureCutoff)
	{
		// Same left-to-right ordering guarantee as the SparseModel constructor.
		this(type, Objects.requireNonNull(model, "model").initializeForTraining(labelCutoff, featureCutoff), model.getWeightVector());
	}

	/** Shared field initialization for both public constructors. */
	private AbstractTrainer(TrainerType type, List<IntInstance> instances, AbstractWeightVector vector)
	{
		l_instances = instances;
		w_vector    = vector;
		t_type      = type;
	}

	/**
	 * Returns {@link #trainerInfo()} followed by the label, feature, and instance counts,
	 * one per line and separated by {@code "\n"}.
	 *
	 * @return a multi-line summary of this trainer
	 */
	public String trainerInfoFull()
	{
		return new StringBuilder()
			.append(trainerInfo()).append("\n")
			.append("- Labels : ").append(getLabelSize()).append("\n")
			.append("- Features : ").append(getFeatureSize()).append("\n")
			.append("- Instances: ").append(getInstanceSize())
			.toString();
	}

	/**
	 * Returns a short description of this trainer. It is used as the first line of
	 * {@link #trainerInfoFull()}.
	 *
	 * @return the trainer description
	 */
	public abstract String trainerInfo();

	/** Runs training. The behavior is defined by the concrete subclass. */
	public abstract void train();

	/**
	 * Returns the number of labels, as reported by the weight vector.
	 *
	 * @return {@code w_vector.getLabelSize()}
	 */
	public int getLabelSize()
	{
		return w_vector.getLabelSize();
	}

	/**
	 * Returns the number of features, as reported by the weight vector.
	 *
	 * @return {@code w_vector.getFeatureSize()}
	 */
	public int getFeatureSize()
	{
		return w_vector.getFeatureSize();
	}

	/**
	 * Returns the number of training instances.
	 *
	 * @return the size of the training-instance list
	 */
	public int getInstanceSize()
	{
		return l_instances.size();
	}

	/**
	 * Returns the training instance at the given position.
	 *
	 * @param index zero-based position in the training-instance list
	 * @return the instance at {@code index}
	 * @throws IndexOutOfBoundsException if {@code index} is out of range (per the {@link List#get} contract)
	 */
	public IntInstance getInstance(int index)
	{
		return l_instances.get(index);
	}

	/**
	 * Returns the trainer type supplied at construction.
	 *
	 * @return the trainer type
	 */
	public TrainerType getTrainerType()
	{
		return t_type;
	}
}