package qa.qcri.aidr.predict.classification.nominal;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Calendar;
import java.util.Date;
import java.util.HashMap;
import org.apache.log4j.Logger;
import weka.classifiers.Classifier;
import weka.core.Instances;
import qa.qcri.aidr.predict.DataStore;
import qa.qcri.aidr.predict.common.*;
import qa.qcri.aidr.predict.data.Document;
import qa.qcri.aidr.predict.dbentities.ModelFamilyEC;
/**
* ModelController handles classification of DocumentSet objects, with the
* assumption that feature extraction has been done previously in the pipeline.
*
* @author jrogstadius & Imran
*/
public class ModelController extends PipelineProcess {
private static Logger logger = Logger.getLogger(ModelController.class);
ModelSet models = new ModelSet();
ModelRetrainTrigger trainingFeedMonitor;
long lastCheckForDefinitionUpdates = 0;
final long definitionChangeFrequencyMs = 5 * 60000;
//labelTrainingWeights contains <attributeID,<labelID, trainingWeight>>
HashMap<Integer, HashMap<Integer, Double>> labelTrainingWeights;
long lastLabelTrainingWeightUpdate = 0;
//classifiedDocCount contains <modelID,<labelID, documentCount>>
HashMap<Integer, HashMap<Integer, Integer>> classifiedDocCount = new HashMap<>();
long lastDocCountSaveTime = 0;
long nowTime;
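/**
 * Subscribes to retrain events from the training feed monitor and starts
 * the monitor on its own thread.
 */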
public ModelController() {
trainingFeedMonitor = new ModelRetrainTrigger();
trainingFeedMonitor.onRetrain
.subscribe(new Function<EventArgs<CrisisAttributePair>>() {
@Override
public void execute(EventArgs<CrisisAttributePair> args) {
onRetrainModel(args.result);
}
});
Thread t = new Thread(trainingFeedMonitor);
t.start();
}
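/**
 * Classifies a single document with every active model for its crisis,
 * updates the per-model label counts, assigns the document a value as a
 * training sample, and periodically flushes the counts to the database.
 * Documents that already carry human labels are skipped.
 */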
@Override
protected void processItem(Document item) {
// TODO: Decide whether pre-labeled items should be re-classified.
// Currently they are treated as training examples and skipped.
if (item.hasHumanLabels()) {
return;
}
synchronized (this) {
loadModelsIfNeeded();
Model[] ms = models.getModels(item.getCrisisID().intValue());
for (Model m : ms) {
//Classify the document and get the output label
NominalLabelBC label = m.classify(item);
if (!classifiedDocCount.containsKey(m.getModelID())) {
classifiedDocCount.put(m.getModelID(), new HashMap<Integer, Integer>());
}
HashMap<Integer, Integer> labelCount = classifiedDocCount.get(m.getModelID());
if (!labelCount.containsKey(label.getNominalLabelID())) {
labelCount.put(label.getNominalLabelID(), 1);
} else {
int count = labelCount.get(label.getNominalLabelID());
labelCount.put(label.getNominalLabelID(), count + 1);
}
}
item.setValueAsTrainingSample(calculateValueAsTrainingSample(item));
writeClassifiedDocCountToDB();
}
}
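/**
 * Flushes the accumulated per-model label counts to the database, at most
 * once per minute, on a background worker so classification is not blocked.
 */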
private void writeClassifiedDocCountToDB() {
Calendar d = Calendar.getInstance();
nowTime = d.getTimeInMillis();
if (nowTime - lastDocCountSaveTime < 60000 || classifiedDocCount.isEmpty()) {
return;
}
HashMap<Integer, HashMap<Integer, Integer>> docCounts = classifiedDocCount;
classifiedDocCount = new HashMap<>();
Function<HashMap<Integer, HashMap<Integer, Integer>>> task =
new Function<HashMap<Integer, HashMap<Integer, Integer>>>() {
@Override
public void execute(HashMap<Integer, HashMap<Integer, Integer>> data) {
DataStore.saveClassifiedDocumentCounts(data);
lastDocCountSaveTime = nowTime;
}
};
AsyncWorker<HashMap<Integer, HashMap<Integer, Integer>>> worker = new AsyncWorker<>(task, docCounts);
worker.start();
}
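/** Flushes any pending document counts while the pipeline has no items to process. */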
@Override
protected void idle() {
writeClassifiedDocCountToDB();
}
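/**
 * Polls the database for active model families at most once every
 * definitionChangeFrequencyMs and loads any RUNNING model that is not yet
 * held in memory.
 */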
private void loadModelsIfNeeded() {
long now = (new Date()).getTime();
if ((now - lastCheckForDefinitionUpdates) > definitionChangeFrequencyMs) {
ArrayList<ModelFamilyEC> activeModels = DataStore.getActiveModels();
for (ModelFamilyEC family : activeModels) {
if (family.getState() == ModelFamilyEC.State.RUNNING
&& !models.hasModel(family.getCrisisID(), family.getNominalAttribute().getNominalAttributeID())) {
loadModel(family.getCrisisID(),
family.getNominalAttribute().getNominalAttributeID(),
family.getCurrentModelID());
}
}
lastCheckForDefinitionUpdates = now;
}
}
// This method is called when the training set grows sufficiently
// (event-based, see the subscription in the constructor)
private void onRetrainModel(CrisisAttributePair info) {
logger.info("Training a new model for event "
+ info.eventID + " and attribute " + info.attributeID);
Model oldModel = models.getModel(info.eventID, info.attributeID);
Model newModel;
try {
newModel = ModelFactory.buildModel(info.eventID, info.attributeID,
oldModel);
} catch (Exception e) {
logger.error("Exception while attempting to build model", e);
return;
}
if (newModel != null && newModel != oldModel) {
replaceModel(info.eventID, info.attributeID, newModel);
} else if (newModel == null) {
logger.info(
"Performance was too low, new model was discarded");
} else {
logger.info("New model did not outperform old model");
}
}
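/**
 * Saves a newly trained model to the database and to the model store on
 * disk, then swaps it in as the active model for the given event and
 * attribute. Classification is paused briefly while the in-memory model
 * set is updated.
 */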
private void replaceModel(int eventID, int attributeID, Model newModel) {
logger.info("Replacing model for event " + eventID
+ " and attribute " + attributeID);
// Insert the new model into the database and update the currentModelID
// for this event and ontology
int modelID = DataStore.saveModelToDatabase(eventID, attributeID,
newModel);
if (modelID == DataStore.MODEL_ID_ERROR) {
logger.error("Model was not correctly written to the database. Aborting write to disk.");
throw new RuntimeException("Model was not correctly written to the database. Aborting write to disk.");
}
newModel.setModelID(modelID);
// Write the new model to the model store
try {
weka.core.SerializationHelper.writeAll(
getModelPath(eventID, attributeID, newModel.getModelID()), // filename
new Object[]{newModel.getClassifier(),
newModel.getAttributeSpecification()} // classifier and attribute specification
);
} catch (Exception e) {
logger.error("Exception when writing new model to file", e);
return; // There is no point proceeding if the model could not be
// written
}
// Pause classification and change the active model
synchronized (this) {
models.setModel(eventID, attributeID, newModel);
}
}
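/**
 * Loads attribute and label definitions, loads every RUNNING model from disk
 * into memory, and initializes the training feed monitor.
 */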
public void initialize() {
// Load models from disk into memory
DataStore.getAttributesLabels();
ArrayList<ModelFamilyEC> families = DataStore.getActiveModels();
for (ModelFamilyEC family : families) {
if (family.getState() == ModelFamilyEC.State.RUNNING) {
loadModel(family.getCrisisID(),
family.getNominalAttribute().getNominalAttributeID(),
family.getCurrentModelID());
}
}
trainingFeedMonitor.initialize(families);
}
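/**
 * Reads a serialized model (classifier plus attribute specification) from the
 * model store and registers it in the in-memory model set. If the file cannot
 * be read, the operator is prompted on stdin whether to delete the database
 * reference and retrain.
 */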
private boolean loadModel(int eventID, int attributeID, int modelID) {
// Load the model from disk and deserialize it
Object[] o;
try {
String path = getModelPath(eventID, attributeID, modelID);
o = weka.core.SerializationHelper.readAll(path);
} catch (Exception e) {
System.out.println("Could not load model from disk (crisis " + eventID
+ ", attribute " + attributeID + ", model " + modelID
+ "). Delete model reference in DB and retrain? [y/n]");
try {
if (System.in.read() == 'y') {
DataStore.deleteModel(modelID);
onRetrainModel(new CrisisAttributePair(eventID, attributeID));
}
} catch (IOException ex) {
logger.warn("Unable to read input.");
}
return false;
}
Classifier classifier = (Classifier) o[0];
Instances specification = (Instances) o[1];
Model model = new Model(attributeID, classifier, specification);
model.setModelID(modelID);
models.setModel(eventID, attributeID, model);
logger.info("Loaded model for crisis " + eventID
+ ", attribute " + attributeID);
return true;
}
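/**
 * Builds the on-disk model path: modelStorePath + eventID + "_" + attributeID
 * + "_" + modelID + ".model".
 */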
private static String getModelPath(int eventID, int attributeID, int modelID) {
String modelsPath = TaggerConfigurator.getInstance().getProperty(
TaggerConfigurationProperty.MODEL_STORE_PATH);
if (!modelsPath.endsWith(File.separator)) {
modelsPath += File.separator;
}
return modelsPath + eventID + "_" + attributeID + "_"
+ modelID + ".model";
}
private void unloadModel(int eventID, int attributeID) {
// TODO: Unload models from memory when they get deleted or are no
// longer in use
models.removeModel(eventID, attributeID);
}
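/**
 * Estimates the value of this document as a training sample. Documents
 * without nominal labels default to 0.5; otherwise the value is the mean
 * per-label weight derived from the classifier confidence and the label's
 * training weight (see the formula below).
 */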
public double calculateValueAsTrainingSample(Document doc) {
ArrayList<NominalLabelBC> labels = doc.getLabels(NominalLabelBC.class);
if (labels.isEmpty()) {
return 0.5;
}
double sum = 0;
for (NominalLabelBC label : labels) {
double a = label.getConfidence();
double b = getLabelTrainingWeight(label.getAttributeID(), label.getNominalLabelID());
// The weight increases with (1) increasing probability that this item
// belongs to a class with few samples and (2) decreasing confidence in
// the classification. The weight lies in [0, 1]. For the domain, see
// http://www.wolframalpha.com/input/?i=plot+y%3D1+-+%280.5+%2B+0.5*%28%282*a-1%29*%282*b-1%29%29%29+from+a%3D0+to+1%2C+b%3D0+to+1
double weight = 1 - (0.5 + 0.5 * ((2 * a - 1) * (2 * b - 1)));
sum += weight;
}
return sum / (double) labels.size();
}
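/**
 * Returns the training weight for an attribute/label pair from a cache that
 * is refreshed from the database at most every five minutes. Unknown pairs
 * default to a weight of 1.
 */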
private double getLabelTrainingWeight(int attributeID, int labelID) {
if (labelTrainingWeights == null
|| (System.currentTimeMillis() - lastLabelTrainingWeightUpdate) > 300000) { // 5 minutes
updateLabelTrainingWeights();
}
if (!labelTrainingWeights.containsKey(attributeID)
|| !labelTrainingWeights.get(attributeID)
.containsKey(labelID)) {
return 1;
}
return labelTrainingWeights.get(attributeID).get(labelID);
}
private void updateLabelTrainingWeights() {
labelTrainingWeights = DataStore.getNominalLabelTrainingValues();
lastLabelTrainingWeightUpdate = System.currentTimeMillis();
}
}