/**
* Copyright (C) 2013 Isabel Drost-Fromm
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package de.isabeldrostfromm.sof.naive;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import de.isabeldrostfromm.sof.ExampleProvider;
import de.isabeldrostfromm.sof.ModelTargets;
import de.isabeldrostfromm.sof.ModelTrainer;
import de.isabeldrostfromm.sof.Trainer;
/**
* Utility to start training and testing in one go. Demonstrates document vectorisation with Mahout
*
* TODO add tests
*
* TODO add some interaction features to deal with off_topic etc. states.
* */
public class SofTrainer {

    /** Name of the document field whose value the model is trained to predict. */
    private static final String TARGET_FIELD = "open_status";

    /**
     * Number of examples requested for training: 50 times the number of entries in
     * {@code ModelTargets.STATEVALUES}.
     */
    private static final int TRAINING_SET_SIZE = 50 * ModelTargets.STATEVALUES.length;

    /** Number of examples requested for each per-state test run. */
    private static final int TEST_SET_SIZE = 50;

    /**
     * Entry point: trains a model, tests it once per target state, then stores it.
     *
     * <ol>
     *   <li>Training: examples come from a negated filter on a status string that is not
     *       expected to occur, so presumably every document qualifies. According to the
     *       original author's notes these are the top-k documents returned by ES, which
     *       makes the training set unbalanced with respect to posting status.</li>
     *   <li>Testing: for each entry of {@code ModelTargets.STATEVALUES} a filtered provider
     *       is created and handed to the trainer together with the model.</li>
     *   <li>Storing: the model is passed to {@code ModelTrainer.store}. The destination is
     *       not chosen here; the original notes state that it ends up in /tmp.</li>
     * </ol>
     *
     * @param args command line arguments; not read
     * @throws Exception whatever the provider or trainer calls raise is propagated unchanged
     */
    public static void main(String[] args) throws Exception {
        ModelTrainer pipeline = new Trainer();

        // NOTE(review): the trailing (0, TRAINING_SET_SIZE) pair presumably means
        // "offset, count" -- confirm against RESTProvider.
        ExampleProvider trainingExamples = RESTProvider.negatedFilterInstance(
                TARGET_FIELD, "invalid_status_string", 0, TRAINING_SET_SIZE);
        OnlineLogisticRegression trainedModel = pipeline.train(trainingExamples);

        evaluatePerState(pipeline, trainedModel);

        pipeline.store(trainedModel);
    }

    /**
     * Applies the model once for every entry of {@code ModelTargets.STATEVALUES}, in array order.
     *
     * @param pipeline trainer used to apply the model to each test provider
     * @param trainedModel model to apply
     * @throws Exception whatever the provider or trainer calls raise is propagated unchanged
     */
    private static void evaluatePerState(ModelTrainer pipeline, OnlineLogisticRegression trainedModel)
            throws Exception {
        int stateIdx = 0;
        while (stateIdx < ModelTargets.STATEVALUES.length) {
            // Test queries start at TRAINING_SET_SIZE -- presumably an offset that skips
            // past the window used for training; verify against RESTProvider.
            ExampleProvider heldOutExamples = RESTProvider.filterInstance(
                    TARGET_FIELD, ModelTargets.STATEVALUES[stateIdx], TRAINING_SET_SIZE, TEST_SET_SIZE);
            pipeline.apply(trainedModel, heldOutExamples);
            stateIdx++;
        }
    }
}