package edu.stanford.nlp.loglinear.benchmarks;
import edu.stanford.nlp.util.logging.Redwood;
import edu.stanford.nlp.loglinear.inference.CliqueTree;
import edu.stanford.nlp.loglinear.model.ConcatVector;
import edu.stanford.nlp.loglinear.model.ConcatVectorNamespace;
import edu.stanford.nlp.loglinear.model.GraphicalModel;
import java.io.IOException;
import java.util.*;
/**
* Created on 9/11/15.
* @author keenon
* <p>
* This simulates game-player-like activity, with a few CoNLL CliqueTrees playing host to lots and lots of manipulations
* by adding and removing human "observations". In real life, this kind of behavior occurs during sampling lookahead for
* LENSE-like systems.
* <p>
* In order to measure only the realistic parts of behavior, and not the random generation of numbers, we pre-cache a
* few hundred ConcatVectors representing human obs features, then our feature function is just indexing into that cache.
* The cache is designed to require a bit of L1 cache eviction to page through, so that we don't see artificial speed
* gains during dot products b/c we already have both features and weights in L1 cache.
*/
public class GamePlayerBenchmark {
  /** A logger for this class. */
  private static final Redwood.RedwoodChannels log = Redwood.channels(GamePlayerBenchmark.class);

  /** Root directory of the CoNLL data files read by {@link #main(String[])}. */
  static final String DATA_PATH = "/u/nlp/data/ner/conll/";

  /**
   * Loads the CoNLL train/testa/testb data, builds one GraphicalModel per training sentence,
   * pre-generates a pool of random "human observation" feature vectors and a matching random weight
   * vector, then times repeated gameplay-style mutation + marginal inference (see gameplay()).
   *
   * @param args unused
   * @throws IOException            if the CoNLL data or embeddings cannot be read
   * @throws ClassNotFoundException if the serialized embeddings cannot be deserialized
   */
  public static void main(String[] args) throws IOException, ClassNotFoundException {
    //////////////////////////////////////////////////////////////
    // Generate the CoNLL CliqueTrees to use during gameplay
    //////////////////////////////////////////////////////////////
    CoNLLBenchmark coNLL = new CoNLLBenchmark();
    List<CoNLLBenchmark.CoNLLSentence> train = coNLL.getSentences(DATA_PATH + "conll.iob.4class.train");
    List<CoNLLBenchmark.CoNLLSentence> testA = coNLL.getSentences(DATA_PATH + "conll.iob.4class.testa");
    List<CoNLLBenchmark.CoNLLSentence> testB = coNLL.getSentences(DATA_PATH + "conll.iob.4class.testb");

    List<CoNLLBenchmark.CoNLLSentence> allData = new ArrayList<>();
    allData.addAll(train);
    allData.addAll(testA);
    allData.addAll(testB);

    // Collect the set of NER tags seen anywhere in the data.
    Set<String> tagsSet = new HashSet<>();
    for (CoNLLBenchmark.CoNLLSentence sentence : allData) {
      for (String nerTag : sentence.ner) {
        tagsSet.add(nerTag);
      }
    }
    List<String> tags = new ArrayList<>(tagsSet);

    coNLL.embeddings = coNLL.getEmbeddings(DATA_PATH + "google-300-trimmed.ser.gz", allData);

    log.info("Making the training set...");
    ConcatVectorNamespace namespace = new ConcatVectorNamespace();
    int trainSize = train.size();
    GraphicalModel[] trainingSet = new GraphicalModel[trainSize];
    for (int i = 0; i < trainSize; i++) {
      if (i % 10 == 0) {
        log.info(i + "/" + trainSize);
      }
      trainingSet[i] = coNLL.generateSentenceModel(namespace, train.get(i), tags);
    }

    //////////////////////////////////////////////////////////////
    // Generate the random human observation feature vectors that we'll use
    //////////////////////////////////////////////////////////////
    Random r = new Random(10); // fixed seed so runs are comparable
    int numFeatures = 5;
    int featureLength = 30;
    // A mix of sparse and dense components, sized (per the class Javadoc) to force some L1 cache
    // eviction while paging through the pool during repeated dot products.
    ConcatVector[] humanFeatureVectors = new ConcatVector[1000];
    for (int i = 0; i < humanFeatureVectors.length; i++) {
      humanFeatureVectors[i] = new ConcatVector(numFeatures);
      for (int j = 0; j < numFeatures; j++) {
        if (r.nextBoolean()) {
          humanFeatureVectors[i].setSparseComponent(j, r.nextInt(featureLength), r.nextDouble());
        } else {
          double[] dense = new double[featureLength];
          for (int k = 0; k < dense.length; k++) {
            dense[k] = r.nextDouble();
          }
          humanFeatureVectors[i].setDenseComponent(j, dense);
        }
      }
    }

    // Random dense weights matching the feature-vector layout.
    ConcatVector weights = new ConcatVector(numFeatures);
    for (int i = 0; i < numFeatures; i++) {
      double[] dense = new double[featureLength];
      for (int j = 0; j < dense.length; j++) {
        dense[j] = r.nextDouble();
      }
      weights.setDenseComponent(i, dense);
    }

    //////////////////////////////////////////////////////////////
    // Actually perform gameplay-like random mutations
    //////////////////////////////////////////////////////////////
    log.info("Warming up the JIT...");
    for (int i = 0; i < 10; i++) {
      log.info(i);
      gameplay(r, trainingSet[i], weights, humanFeatureVectors);
    }

    log.info("Timing actual run...");
    long start = System.currentTimeMillis();
    for (int i = 0; i < 10; i++) {
      log.info(i);
      gameplay(r, trainingSet[i], weights, humanFeatureVectors);
    }
    long duration = System.currentTimeMillis() - start;
    log.info("Duration: " + duration);
  }

  //////////////////////////////////////////////////////////////
  // This is an implementation of something like MCTS, trying to take advantage of the general speed gains due to fast
  // CliqueTree caching of dot products. It doesn't actually do any clever selection, preferring to select observations
  // at random.
  //////////////////////////////////////////////////////////////

  /**
   * Runs 1000 "samples", each a 10-move-deep random walk down a tree of human observations. Every
   * move pushes an observation factor onto the model and (on the first visit to that state)
   * computes singleton marginals via the CliqueTree — the operation actually being benchmarked.
   * After each sample the model is restored to its original state.
   *
   * @param model               the model to mutate; restored to its initial factor set on return
   * @param weights             weight vector used by the CliqueTree
   * @param humanFeatureVectors pre-generated pool of observation feature vectors
   */
  private static void gameplay(Random r, GraphicalModel model, ConcatVector weights, ConcatVector[] humanFeatureVectors) {
    // Enumerate the model's variables and their sizes from the factor neighborhoods.
    List<Integer> variablesList = new ArrayList<>();
    List<Integer> variableSizesList = new ArrayList<>();
    for (GraphicalModel.Factor f : model.factors) {
      for (int i = 0; i < f.neigborIndices.length; i++) {
        int j = f.neigborIndices[i];
        if (!variablesList.contains(j)) {
          variablesList.add(j);
          variableSizesList.add(f.featuresTable.getDimensions()[i]);
        }
      }
    }
    int[] variables = variablesList.stream().mapToInt(i -> i).toArray();
    int[] variableSizes = variableSizesList.stream().mapToInt(i -> i).toArray();

    List<SampleState> childrenOfRoot = new ArrayList<>();
    CliqueTree tree = new CliqueTree(model, weights);
    int initialFactors = model.factors.size();

    // Run some "samples"
    long start = System.currentTimeMillis();
    long marginalsTime = 0;
    // Number of actual calculateMarginalsJustSingletons() calls (cache misses); revisited states
    // reuse their cached marginal, so this is data-dependent rather than a fixed 10000.
    int marginalsComputed = 0;
    for (int i = 0; i < 1000; i++) {
      log.info("\tTaking sample " + i);
      // ArrayDeque is the recommended replacement for the legacy synchronized java.util.Stack.
      Deque<SampleState> stack = new ArrayDeque<>();
      SampleState state = selectOrCreateChildAtRandom(r, model, variables, variableSizes, childrenOfRoot, humanFeatureVectors);
      long localMarginalsTime = 0;
      // Each "sample" is 10 moves deep
      for (int j = 0; j < 10; j++) {
        state.push(model);
        assert (model.factors.size() == initialFactors + j + 1);

        ///////////////////////////////////////////////////////////
        // This is the thing we're really benchmarking
        ///////////////////////////////////////////////////////////
        // Each SampleState occupies a unique position in the sample tree (it lives in exactly one
        // parent's children list), so its ancestor observations are fixed and the cached marginal
        // remains valid across revisits.
        if (state.cachedMarginal == null) {
          long s = System.currentTimeMillis();
          state.cachedMarginal = tree.calculateMarginalsJustSingletons();
          localMarginalsTime += System.currentTimeMillis() - s;
          marginalsComputed++;
        }

        stack.push(state);
        state = selectOrCreateChildAtRandom(r, model, variables, variableSizes, state.children, humanFeatureVectors);
      }
      log.info("\t\t" + localMarginalsTime + " ms");
      marginalsTime += localMarginalsTime;

      // Unwind this sample path, restoring the model to its original factor set.
      while (!stack.isEmpty()) {
        stack.pop().pop(model);
      }
      assert (model.factors.size() == initialFactors);
    }
    log.info("Marginals time: " + marginalsTime + " ms");
    // Bug fix: the average was previously computed over a hard-coded 200 marginals, which matches
    // neither the 1000x10 loop structure nor the cache-hit behavior. Divide by the number of
    // marginals actually computed instead (zero-guarded).
    log.info("Avg time per marginal: " + (marginalsTime / Math.max(1, marginalsComputed)) + " ms");
    log.info("Total time: " + (System.currentTimeMillis() - start));
  }

  /**
   * Picks a (variable, observation) pair uniformly at random. If {@code children} already contains
   * a state for that pair it is returned; otherwise a new SampleState is created for it — complete
   * with a detached "human observation" factor — added to {@code children}, and returned.
   *
   * @param r                   source of randomness
   * @param model               the model, consulted for existing variable indices and used to build
   *                            the new factor; its factor set is unchanged on return
   * @param variables           the model's variable indices
   * @param variableSizes       arity of each variable, parallel to {@code variables}
   * @param children            existing children at the current tree position (may be appended to)
   * @param humanFeatureVectors pre-generated pool of observation feature vectors
   * @return an existing or freshly created child state
   */
  private static SampleState selectOrCreateChildAtRandom(Random r,
                                                         GraphicalModel model,
                                                         int[] variables,
                                                         int[] variableSizes,
                                                         List<SampleState> children,
                                                         ConcatVector[] humanFeatureVectors) {
    int i = r.nextInt(variables.length);
    int variable = variables[i];
    int observation = r.nextInt(variableSizes[i]);

    // Reuse an existing child for this (variable, observation) pair if we've seen it before.
    for (SampleState s : children) {
      if (s.variable == variable && s.observation == observation) return s;
    }

    // Allocate a fresh variable index for the human-observation node: one past the largest index
    // referenced by any factor currently attached to the model.
    int humanObservationVariable = 0;
    for (GraphicalModel.Factor f : model.factors) {
      for (int j : f.neigborIndices) {
        if (j >= humanObservationVariable) humanObservationVariable = j + 1;
      }
    }

    // Build the factor through the model's API (so its feature table is constructed normally), then
    // immediately detach it: SampleState.push() re-attaches it when the state is applied. The
    // observation node is given the same arity as the observed variable, and features are looked up
    // in the pool by flat index. NOTE(review): this assumes variableSizes[i]^2 <=
    // humanFeatureVectors.length (1000) — fine for the small CoNLL tag set, but worth confirming if
    // this benchmark is reused with larger label spaces.
    GraphicalModel.Factor f = model.addFactor(new int[]{variable, humanObservationVariable}, new int[]{variableSizes[i], variableSizes[i]}, (assn) -> {
      int j = (assn[0] * variableSizes[i]) + assn[1];
      return humanFeatureVectors[j];
    });
    model.factors.remove(f);

    SampleState newState = new SampleState(f, variable, observation);
    children.add(newState);
    return newState;
  }

  /**
   * A single node in the sample tree: one human "observation" of one variable, together with the
   * factor representing it and the singleton marginals cached the first time it was applied.
   */
  public static class SampleState {
    /** The detached human-observation factor that push() attaches to the model. */
    public GraphicalModel.Factor addedFactor;
    /** The variable this observation applies to. */
    public int variable;
    /** The observed value for {@link #variable}. */
    public int observation;
    /** Child states reachable from this one, deduplicated by (variable, observation). */
    public List<SampleState> children = new ArrayList<>();
    /** Singleton marginals computed the first time this state was pushed; null until then. */
    public double[][] cachedMarginal = null;

    public SampleState(GraphicalModel.Factor addedFactor, int variable, int observation) {
      this.addedFactor = addedFactor;
      this.variable = variable;
      this.observation = observation;
    }

    /**
     * This applies this SampleState to the model. The name comes from an analogy to a stack. If we take a sample
     * path, involving a number of steps through the model, we push() each SampleState onto the model one at a time,
     * then when we return from the sample we can pop() each SampleState off the model, and be left with our
     * original model state.
     *
     * @param model the model to push this SampleState onto
     */
    public void push(GraphicalModel model) {
      assert (!model.factors.contains(addedFactor));
      model.factors.add(addedFactor);
      model.getVariableMetaDataByReference(variable).put(CliqueTree.VARIABLE_OBSERVED_VALUE, "" + observation);
    }

    /**
     * See push() for an explanation.
     *
     * @param model the model to pop this SampleState from
     */
    public void pop(GraphicalModel model) {
      assert (model.factors.contains(addedFactor));
      model.factors.remove(addedFactor);
      model.getVariableMetaDataByReference(variable).remove(CliqueTree.VARIABLE_OBSERVED_VALUE);
    }
  }
}