/* Copyright 2012 Tim Garrett, Mothsoft LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mothsoft.alexis.engine.predictive;

import java.io.File;
import java.sql.Timestamp;
import java.text.DateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import opennlp.maxent.io.SuffixSensitiveGISModelReader;
import opennlp.model.AbstractModel;
import opennlp.model.MaxentModel;
import org.apache.commons.lang.time.StopWatch;
import org.apache.log4j.Logger;
import org.hibernate.ScrollableResults;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.transaction.TransactionDefinition;
import org.springframework.transaction.TransactionStatus;
import org.springframework.transaction.support.DefaultTransactionDefinition;
import org.springframework.transaction.support.TransactionCallback;
import org.springframework.transaction.support.TransactionTemplate;
import com.mothsoft.alexis.dao.DataSetPointDao;
import com.mothsoft.alexis.dao.DocumentDao;
import com.mothsoft.alexis.dao.ModelDao;
import com.mothsoft.alexis.domain.DataSetPoint;
import com.mothsoft.alexis.domain.Document;
import com.mothsoft.alexis.domain.Model;
import com.mothsoft.alexis.domain.ModelState;
import com.mothsoft.alexis.domain.ModelType;
import com.mothsoft.alexis.domain.SortOrder;
import com.mothsoft.alexis.domain.TimeUnits;
import com.mothsoft.alexis.engine.Task;
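
/**
 * Task that executes every maxent {@link Model} in the READY state: each model's
 * serialized OpenNLP GIS model is evaluated against the current period's documents,
 * and the resulting predictions are written to the model's prediction data set.
 */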
public class OpenNLPMaxentModelExecutorTask implements Task {
private static final Logger logger = Logger.getLogger(OpenNLPMaxentModelExecutorTask.class);
private static final Pattern OUTCOME_PATTERN = Pattern.compile("\\+(\\d+)(\\S+)\\:(\\S+)");
private static final String BIN_GZ_EXT = ".bin.gz";
private File baseDirectory;
private DataSetPointDao dataSetPointDao;
private DocumentDao documentDao;
private ModelDao modelDao;
private TransactionTemplate transactionTemplate;
public OpenNLPMaxentModelExecutorTask() {
super();
}
public void setBaseDirectory(File baseDirectory) {
this.baseDirectory = baseDirectory;
}
public void setDataSetPointDao(final DataSetPointDao dataSetPointDao) {
this.dataSetPointDao = dataSetPointDao;
}
public void setDocumentDao(DocumentDao documentDao) {
this.documentDao = documentDao;
}
public void setModelDao(final ModelDao modelDao) {
this.modelDao = modelDao;
}
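// build a transaction template that runs each callback (model lookup, individual
// model execution) in a new transaction (PROPAGATION_REQUIRES_NEW)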
public void setTransactionManager(final PlatformTransactionManager transactionManager) {
final TransactionDefinition transactionDefinition = new DefaultTransactionDefinition(
DefaultTransactionDefinition.PROPAGATION_REQUIRES_NEW);
this.transactionTemplate = new TransactionTemplate(transactionManager, transactionDefinition);
}
@Override
public void execute() {
final StopWatch stopWatch = new StopWatch();
stopWatch.start();
final List<Long> modelIds = findModelsToExecute();
final int size = modelIds.size();
logger.info(String.format("Found %d models in state READY", size));
int executed = 0;
for (final Long modelId : modelIds) {
boolean success = execute(modelId);
if (success) {
executed++;
}
}
stopWatch.stop();
logger.info(String.format("Executed %d of %d models, took: %s", executed, size, stopWatch.toString()));
}
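// look up, in its own transaction, the ids of all MAXENT models currently in the READY state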
private List<Long> findModelsToExecute() {
return this.transactionTemplate.execute(new TransactionCallback<List<Long>>() {
@Override
public List<Long> doInTransaction(TransactionStatus arg0) {
final List<Model> models = OpenNLPMaxentModelExecutorTask.this.modelDao.findByTypeAndState(
ModelType.MAXENT, ModelState.READY);
final List<Long> modelIds = new ArrayList<Long>(models.size());
for (final Model model : models) {
modelIds.add(model.getId());
}
return modelIds;
}
});
}
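// execute a single model inside its own transaction; returns true if the model ran successfully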
private boolean execute(final Long modelId) {
return this.transactionTemplate.execute(new TransactionCallback<Boolean>() {
@Override
public Boolean doInTransaction(TransactionStatus arg0) {
final Model model = OpenNLPMaxentModelExecutorTask.this.modelDao.get(modelId);
return OpenNLPMaxentModelExecutorTask.this.doExecute(model);
}
});
}
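// load the serialized GIS model from disk, evaluate it against the current period's
// documents, and create/update the predicted points for each lookahead period,
// removing any stale points left over from earlier executions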
private boolean doExecute(final Model model) {
final StopWatch stopWatch = new StopWatch();
stopWatch.start();
boolean result = false;
try {
logger.info(String.format("Executing model %d", model.getId()));
// load model file
final File userDirectory = new File(baseDirectory, "" + model.getUserId());
final File modelFile = new File(userDirectory, model.getId() + BIN_GZ_EXT);
final AbstractModel maxentModel = new SuffixSensitiveGISModelReader(modelFile).getModel();
final Date now = new Date();
final TimeUnits timeUnits = model.getTimeUnits();
final Timestamp topOfPeriod = new Timestamp(TimeUnits.floor(now, timeUnits).getTime());
final Timestamp endOfPeriod = new Timestamp(topOfPeriod.getTime() + timeUnits.getDuration() - 1);
// first position: sum of changes predicted, second position: number
// of samples--will calculate a boring old mean...
final double[][] changeByPeriod = new double[model.getLookahead()][2];
// initialize
for (int i = 0; i < changeByPeriod.length; i++) {
changeByPeriod[i][0] = 0.0d;
changeByPeriod[i][1] = 0.0d;
}
// find the most recent point value
// FIXME - some sparse data sets may require executing the model on
// all documents since that point or applying some sort of
// dead-reckoning logic for smoothing
final DataSetPoint initial = this.dataSetPointDao.findLastPointBefore(model.getTrainingDataSet(),
endOfPeriod);
// let's get the corner cases out of the way
if (initial == null) {
logger.warn("Insufficient data to execute model!");
return false;
}
// happy path
// build consolidated context of events in this period
// find current value of training data set for this period
final double[] probs = eval(model, topOfPeriod, endOfPeriod, maxentModel);
// predict from the last available point, adjusted for time
// remaining in period
final double y0 = initial.getY();
// map outcomes to periods in the future (at least no earlier than
// this period)
for (int i = 0; i < probs.length; i++) {
// in the form +nU:+/-x, where n is the number of periods, U is
// the unit type for the period, +/- is the direction, and x is
// a discrete value from Model.OUTCOME_ARRAY
final String outcome = maxentModel.getOutcome(i);
final Matcher matcher = OUTCOME_PATTERN.matcher(outcome);
if (!matcher.matches()) {
logger.warn("Can't process outcome: " + outcome + "; skipping");
continue;
}
final int period = Integer.valueOf(matcher.group(1));
final String units = matcher.group(2);
final double percentChange = Double.valueOf(matcher.group(3));
// record the observation and the count of observations
changeByPeriod[period][0] += percentChange;
changeByPeriod[period][1] += 1.0d;
if (logger.isDebugEnabled()) {
final double yi = y0 * (1 + percentChange);
logger.debug(String.format("Outcome: %s, %s: +%d, change: %f, new value: %f, probability: %f",
outcome, units, period, percentChange, yi, probs[i]));
}
}
// build points for predictive data set
double yn = y0;
// we need to track the points and remove any that were not
// predicted by this execution of the model
final Timestamp endOfPredictionRange = new Timestamp(topOfPeriod.getTime()
+ (changeByPeriod.length * timeUnits.getDuration()));
final List<DataSetPoint> existingPoints = this.dataSetPointDao.findByTimeRange(
model.getPredictionDataSet(), topOfPeriod, endOfPredictionRange);
for (int period = 0; period < changeByPeriod.length; period++) {
final double totalPercentChange = changeByPeriod[period][0];
final double sampleCount = changeByPeriod[period][1];
double percentChange;
if (totalPercentChange == 0.0d || sampleCount == 0.0d) {
percentChange = 0.0d;
} else {
percentChange = totalPercentChange / sampleCount;
}
// apply adjustments only if the initial point is within the
// time period, and only for the first time period
boolean applyAdjustment = period == 0 && topOfPeriod.before(initial.getX());
if (applyAdjustment) {
final double adjustmentFactor = findAdjustmentFactor(initial.getX(), timeUnits);
// scale the already-computed mean change; recomputing the quotient here would
// yield NaN when this period has no samples
percentChange *= adjustmentFactor;
}
// figure out the next value and coerce to a sane number of
// decimal places (2)
final double newValue = (double) Math.round(yn * (1.0d + percentChange) * 100) / 100;
final Timestamp timestamp = new Timestamp(topOfPeriod.getTime() + (period * timeUnits.getDuration()));
if (logger.isDebugEnabled()) {
logger.debug(String.format("Model %d for data set %d predicted point: (%s, %f)", model.getId(),
model.getTrainingDataSet().getId(), DateFormat.getInstance().format(timestamp), newValue));
}
DataSetPoint ithPoint = this.dataSetPointDao.findByTimestamp(model.getPredictionDataSet(), timestamp);
// conditionally create
if (ithPoint == null) {
ithPoint = new DataSetPoint(model.getPredictionDataSet(), timestamp, newValue);
this.dataSetPointDao.add(ithPoint);
} else {
// or update
ithPoint.setY(newValue);
// updated points retained, other existing removed
existingPoints.remove(ithPoint);
}
// store current and use as starting point for next iteration
yn = newValue;
}
// remove stale points from an old model execution
for (final DataSetPoint toRemove : existingPoints) {
this.dataSetPointDao.remove(toRemove);
}
result = true;
} catch (final Exception e) {
logger.warn("Model " + model.getId() + " failed with: " + e, e);
result = false;
} finally {
stopWatch.stop();
logger.info(String.format("Executing model %d took %s", model.getId(), stopWatch.toString()));
}
return result;
}
/**
 * Returns 1 minus the fraction of the time period already elapsed, so that the
 * predicted percent change can be applied uniformly over the remainder of the
 * period.
 */
private double findAdjustmentFactor(final Date date, final TimeUnits timeUnits) {
final Date floor = TimeUnits.floor(date, timeUnits);
final double dividend = (double) (date.getTime() - floor.getTime());
final double divisor = (double) timeUnits.getDuration();
final double percentTimeComplete = dividend / divisor;
return 1.0d - percentTimeComplete;
}
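/**
 * Scrolls over all documents matching the model's topic search expression within
 * the given period, builds a consolidated context from them, and returns the
 * maxent model's outcome probabilities for that context.
 */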
private double[] eval(final Model model, final Timestamp topOfPeriod, final Timestamp endOfPeriod,
final MaxentModel maxentModel) {
final ScrollableResults scrollableResults = this.documentDao.scrollableSearch(model.getUserId(), null, model
.getTopic().getSearchExpression(), SortOrder.DATE_ASC, topOfPeriod, endOfPeriod);
// initialize with an estimated size to prevent a lot of resizing
final Map<String, Integer> contextMap = new LinkedHashMap<String, Integer>(64 * 1024);
try {
while (scrollableResults.next()) {
final Object[] row = scrollableResults.get();
final Document document = (Document) row[0];
if (document == null) {
// caused by stale index
continue;
} else {
OpenNLPMaxentContextBuilder.append(contextMap, document);
}
}
} finally {
scrollableResults.close();
}
final String[] context = new String[contextMap.size()];
final float[] values = new float[contextMap.size()];
// copy map to arrays
OpenNLPMaxentContextBuilder.buildContextArrays(contextMap, context, values);
// eval
return maxentModel.eval(context, values);
}
}