package de.isabeldrostfromm.sof.termvector;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import de.isabeldrostfromm.sof.ExampleProvider;
import de.isabeldrostfromm.sof.ModelTargets;
import de.isabeldrostfromm.sof.ModelTrainer;
import de.isabeldrostfromm.sof.Trainer;
/**
 * Command-line entry point that trains a logistic-regression model on examples fetched through
 * {@link RESTProvider}, evaluates it once per target state value and then persists it.
 */
public class SofTrainer {
  /** Name of the document field whose value the model predicts. */
  private static final String TARGET_FIELD = "open_status";

  /** Number of training examples to request: 50 per entry in {@link ModelTargets#STATEVALUES}. */
  private static final int NUM_TRAIN = 50 * ModelTargets.STATEVALUES.length;

  /** Number of test examples to request for each entry in {@link ModelTargets#STATEVALUES}. */
  private static final int NUM_TEST = 50;

  /**
   * Trains, evaluates and stores a model in three steps.
   *
   * <ol>
   *   <li>Trains a model on up to {@code NUM_TRAIN} examples obtained from
   *       {@code RESTProvider.negatedFilterInstance} for {@code open_status} and the value
   *       {@code "invalid_status_string"}. These are the top-k documents returned by
   *       Elasticsearch, so the training set is not balanced with respect to posting status.
   *       (The two integer arguments look like offset and count; confirm against
   *       {@link RESTProvider}.)</li>
   *   <li>For every value in {@link ModelTargets#STATEVALUES}, applies the model to
   *       {@code NUM_TEST} examples filtered on that value, starting at {@code NUM_TRAIN}.
   *       The offset presumably skips the documents used for training; TODO confirm that both
   *       queries share one ordering, otherwise test and training examples may overlap.
   *       Reporting of test results is left to {@link ModelTrainer#apply}.</li>
   *   <li>Persists the model via {@link ModelTrainer#store}. Older notes say it is written to
   *       {@code /tmp}; the actual location is decided by {@link Trainer}, so confirm there.</li>
   * </ol>
   *
   * @param args ignored
   * @throws Exception propagated unchanged from training, evaluation or storage
   */
  public static void main(String[] args) throws Exception {
    ModelTrainer trainer = new Trainer();

    ExampleProvider trainingExamples =
        RESTProvider.negatedFilterInstance(TARGET_FIELD, "invalid_status_string", 0, NUM_TRAIN);
    OnlineLogisticRegression model = trainer.train(trainingExamples);

    // Evaluate separately per state value so results can be inspected class by class.
    for (int state = 0; state < ModelTargets.STATEVALUES.length; state++) {
      ExampleProvider testExamples = RESTProvider.filterInstance(
          TARGET_FIELD, ModelTargets.STATEVALUES[state], NUM_TRAIN, NUM_TEST);
      trainer.apply(model, testExamples);
    }

    trainer.store(model);
  }
}