package qa.qcri.aidr.predict.classification.nominal;
import java.util.ArrayList;
import org.apache.log4j.Logger;
import qa.qcri.aidr.predict.DataStore;
import qa.qcri.aidr.predict.classification.ClassifierFactory;
import qa.qcri.aidr.predict.common.AlgorithmType;
import qa.qcri.aidr.predict.common.TaggerConfigurationProperty;
import qa.qcri.aidr.predict.common.TaggerConfigurator;
import weka.attributeSelection.AttributeSelection;
import weka.attributeSelection.InfoGainAttributeEval;
import weka.attributeSelection.Ranker;
import weka.classifiers.Classifier;
import weka.classifiers.trees.RandomForest;
import weka.core.Attribute;
import weka.core.Instances;
/**
* ModelFactory encapsulates the training of new classifiers.
* When a model is built, the class handles retrieval of training and evaluation
* data from the database, as well as model training and evaluation.
*
* @author jrogstadius
*/
public class ModelFactory {

    private static final Logger logger = Logger.getLogger(ModelFactory.class);

    /**
     * Tolerance used when comparing a freshly trained model against the old
     * one: the new model is accepted if its weighted performance is greater
     * than the old performance minus this margin.
     */
    private static final double PERFORMANCE_IMPROVEMENT_MARGIN = Double
            .parseDouble(TaggerConfigurator.getInstance().getProperty(
                    TaggerConfigurationProperty.PERFORMANCE_IMPROVEMENT_MARGIN));

    /**
     * If the new model was trained on more than this many additional examples
     * compared to the old model, the new model is accepted regardless of the
     * performance comparison.
     */
    private static final int TRAINING_EXAMPLES_FORCE_RETRAIN = Integer
            .parseInt(TaggerConfigurator
                    .getInstance()
                    .getProperty(
                            TaggerConfigurationProperty.TRAINING_EXAMPLES_FORCE_RETRAIN));

    /**
     * Value of the SAMPLE_COUNT_THRESHOLD configuration property. Not
     * referenced by any method of this class; kept so that class
     * initialisation still reads and validates the property.
     */
    private static final int sampleCountThreshold = Integer
            .parseInt(TaggerConfigurator.getInstance().getProperty(
                    TaggerConfigurationProperty.SAMPLE_COUNT_THRESHOLD));

    /**
     * Train a new model for the specified event and ontology.
     *
     * Training is postponed (and {@code oldModel} returned unchanged) when the
     * last attribute of the training set has fewer than two values or when the
     * evaluation set holds fewer than two instances.
     *
     * @param crisisID identifier passed to the data store to select the data
     * @param attributeID identifier passed to the data store and recorded in
     * the new model
     * @param oldModel An existing model to compare performance against. Null if
     * no previous model exists.
     * @return The new model if its weighted performance is within
     * PERFORMANCE_IMPROVEMENT_MARGIN of the old model's (or better), or if it
     * was trained on more than TRAINING_EXAMPLES_FORCE_RETRAIN additional
     * examples; otherwise the old model (which may be null).
     * @throws Exception if data retrieval, attribute selection, training or
     * evaluation fails
     */
    public static Model buildModel(int crisisID, int attributeID, Model oldModel)
            throws Exception {
        // TODO: Improve model training to try different classifiers and
        // different mixes of old and new data

        // Get training and evaluation data
        Instances trainingSet = DataStore.getTrainingSet(crisisID, attributeID);
        Instances evaluationSet = DataStore.getEvaluationSet(crisisID,
                attributeID, trainingSet);

        // The last attribute is used as the class attribute (see getTemplateSet).
        if (trainingSet.attribute(trainingSet.numAttributes() - 1).numValues() < 2) {
            logger.info("ModelFactory: "
                    + "All training examples have the same label. Postponing training.");
            return oldModel;
        }
        if (evaluationSet.numInstances() < 2) {
            logger.info("ModelFactory: "
                    + "The evaluation set is too small. Postponing training.");
            return oldModel;
        }

        // Do attribute selection; the same selector is applied to both sets so
        // that they keep an identical attribute layout.
        AttributeSelection selector = getAttributeSelector(trainingSet);
        trainingSet = selector.reduceDimensionality(trainingSet);
        evaluationSet = selector.reduceDimensionality(evaluationSet);

        // Train classifier
        Classifier classifier = trainClassifier(trainingSet);

        // Create the model object
        Model model = new Model(attributeID, classifier, getTemplateSet(trainingSet));
        model.setTrainingSampleCount(trainingSet.size());

        // Evaluate classifier
        model.evaluate(evaluationSet);
        double newPerformance = model.getWeightedPerformance();

        // A missing old model is treated as a baseline with zero performance
        // and zero training samples.
        double oldPerformance = 0;
        long oldSampleCount = 0;
        if (oldModel != null) {
            oldModel.evaluate(evaluationSet);
            oldPerformance = oldModel.getWeightedPerformance();
            oldSampleCount = oldModel.getTrainingSampleCount();
        }

        // Accept the new model when it is not noticeably worse than the old one.
        if (newPerformance > oldPerformance - PERFORMANCE_IMPROVEMENT_MARGIN) {
            return model;
        }
        // Otherwise accept it only if it was trained on substantially more data.
        // Using the pre-computed count avoids dereferencing a null oldModel here.
        if (model.getTrainingSampleCount() > oldSampleCount + TRAINING_EXAMPLES_FORCE_RETRAIN) {
            return model;
        }
        return oldModel;
    }

    /**
     * Build an empty data set ("spec") that carries the same attributes as the
     * given data set, with the last attribute set as the class attribute.
     *
     * @param dataSet the data set whose attributes are copied
     * @return an empty Instances object with the same attribute list
     */
    private static Instances getTemplateSet(Instances dataSet) {
        ArrayList<Attribute> attributes = new ArrayList<Attribute>(
                dataSet.numAttributes());
        for (int i = 0; i < dataSet.numAttributes(); i++) {
            attributes.add(dataSet.attribute(i));
        }
        Instances specification = new Instances("spec", attributes, 0);
        specification.setClassIndex(specification.numAttributes() - 1);
        return specification;
    }

    /**
     * Create an attribute selector that ranks attributes with an information
     * gain evaluator and keeps at most 500 of them (or all but one attribute
     * if the data set has fewer), fitted on the given training data.
     *
     * @param trainingData the data on which attribute selection is performed
     * @return the fitted attribute selector
     * @throws Exception if attribute selection fails
     */
    private static AttributeSelection getAttributeSelector(
            Instances trainingData) throws Exception {
        AttributeSelection selector = new AttributeSelection();
        InfoGainAttributeEval evaluator = new InfoGainAttributeEval();
        Ranker ranker = new Ranker();
        ranker.setNumToSelect(Math.min(500, trainingData.numAttributes() - 1));
        selector.setEvaluator(evaluator);
        selector.setSearch(ranker);
        selector.SelectAttributes(trainingData);
        return selector;
    }

    /**
     * Obtain a classifier for AlgorithmType.RANDOM_FOREST from the
     * ClassifierFactory and build it on the given training set.
     *
     * @param trainingSet the data used to build the classifier
     * @return the trained classifier
     * @throws Exception if the classifier cannot be built
     */
    private static Classifier trainClassifier(Instances trainingSet)
            throws Exception {
        Classifier model = (Classifier) ClassifierFactory.getClassifier(AlgorithmType.RANDOM_FOREST);
        model.buildClassifier(trainingSet);
        return model;
    }
}