package edu.stanford.nlp.sequences;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.ling.HasWord;
import java.util.Random;
import java.util.List;
import java.util.ArrayList;
import java.util.Arrays;
import java.io.PrintStream;
// TODO: change so that it uses the scoresOf() method properly
/**
 * A Gibbs sampler for sequence models. Given a sequence model implementing the
 * {@link SequenceModel} interface, this class can draw sequences from the distribution over
 * sequences that the model defines. It can also use that sampling procedure to search for a
 * high-scoring sequence, either by keeping the best of many samples
 * ({@link #findBestUsingSampling}) or by simulated annealing ({@link #findBestUsingAnnealing}).
 *
 * <p>Every change the sampler makes to a sequence is reported to the {@link SequenceListener}
 * supplied at construction.
 *
 * @author grenager
 */
public class SequenceGibbsSampler implements BestSequenceFinder {

  /** Random number generator shared by all instances (initial sequences and Gibbs draws). */
  private static final Random random = new Random();

  /** Debug verbosity: 0 = silent, &gt;0 = progress messages, &gt;1 = also dump all samples. */
  public static int verbose = 0;

  /**
   * Tokens (cast to {@link HasWord}) used only to label rows in {@link #printSamples};
   * may be null.
   */
  private final List<?> document;
  /** Number of samples {@link #bestSequence} collects. */
  private final int numSamples;
  /** Number of full forward passes between consecutive samples kept by {@link #bestSequence}. */
  private final int sampleInterval;
  /** Notified of each initial sequence and of every single-position change. */
  private final SequenceListener listener;
  /**
   * If true, annealing returns the last sequence visited instead of the highest-scoring one,
   * and skips the per-iteration {@link SequenceModel#scoreOf} call.
   */
  public boolean returnLastFoundSequence = false;

  /**
   * Returns a new array holding the same values as {@code a}.
   *
   * @param a the array to copy; must not be null
   * @return a copy of {@code a}
   */
  public static int[] copy(int[] a) {
    return a.clone();
  }

  /**
   * Builds a random sequence: each position gets a value drawn uniformly from
   * {@code model.getPossibleValues(pos)}.
   *
   * @param model the model that defines the sequence length and the allowed values
   * @return a new random sequence of length {@code model.length()}
   */
  public static int[] getRandomSequence(SequenceModel model) {
    int[] result = new int[model.length()];
    for (int i = 0; i < result.length; i++) {
      int[] classes = model.getPossibleValues(i);
      result[i] = classes[random.nextInt(classes.length)];
    }
    return result;
  }

  /**
   * Finds the best sequence. It starts from a random sequence, collects the number of samples
   * configured at construction (each {@code sampleInterval} forward passes apart), scores them,
   * and returns the highest-scoring sample.
   *
   * @param model the model to sample from and score with
   * @return the highest-scoring sample, or null if no sample scored above negative infinity
   *     (for example when {@code numSamples} is 0)
   */
  public int[] bestSequence(SequenceModel model) {
    int[] initialSequence = getRandomSequence(model);
    return findBestUsingSampling(model, numSamples, sampleInterval, initialSequence);
  }

  /**
   * Finds the best sequence by collecting {@code numSamples} samples, scoring them, and
   * returning the highest-scoring one.
   *
   * @param model the model to sample from and score with
   * @param numSamples how many samples to collect
   * @param sampleInterval number of full forward passes between consecutive samples
   * @param initialSequence starting state of the chain; not modified
   * @return the highest-scoring sample, or null if no sample scored above negative infinity
   */
  public int[] findBestUsingSampling(SequenceModel model, int numSamples, int sampleInterval, int[] initialSequence) {
    List<int[]> samples = collectSamples(model, numSamples, sampleInterval, initialSequence);
    int[] best = null;
    double bestScore = Double.NEGATIVE_INFINITY;
    for (int[] sequence : samples) {
      double score = model.scoreOf(sequence);
      if (score > bestScore) {
        best = sequence;
        bestScore = score;
        // Diagnostic output, gated like every other message in this class.
        if (verbose > 0) {
          System.err.println("found new best (" + bestScore + ")");
          System.err.println(ArrayMath.toString(best));
        }
      }
    }
    return best;
  }

  /**
   * Runs simulated annealing from a random initial sequence.
   *
   * @param model the model to sample from and score with
   * @param schedule supplies the number of iterations and the temperature at each one
   * @return see {@link #findBestUsingAnnealing(SequenceModel, CoolingSchedule, int[])}
   */
  public int[] findBestUsingAnnealing(SequenceModel model, CoolingSchedule schedule) {
    int[] initialSequence = getRandomSequence(model);
    return findBestUsingAnnealing(model, schedule, initialSequence);
  }

  /**
   * Runs simulated annealing. Each iteration makes one forward Gibbs pass at the temperature
   * given by {@code schedule}.
   *
   * @param model the model to sample from and score with
   * @param schedule supplies the number of iterations and the temperature at each one
   * @param initialSequence starting state; not modified
   * @return the last sequence visited if {@link #returnLastFoundSequence} is set, otherwise
   *     the highest-scoring sequence visited; null if no iterations were run
   */
  public int[] findBestUsingAnnealing(SequenceModel model, CoolingSchedule schedule, int[] initialSequence) {
    if (verbose > 0) System.err.println("Doing annealing");
    listener.setInitialSequence(initialSequence);
    // Visited sequences are only needed for the verbose > 1 dump; don't retain them otherwise.
    boolean keepVisited = verbose > 1;
    List<int[]> visited = new ArrayList<int[]>();
    int[] sequence = initialSequence;
    int[] best = null;
    double bestScore = Double.NEGATIVE_INFINITY;
    for (int i = 0; i < schedule.numIterations(); i++) {
      // Fresh array each iteration, so neither the caller's initial sequence nor a
      // previously recorded best/visited sequence is overwritten.
      sequence = copy(sequence);
      double temperature = schedule.getTemperature(i);
      sampleSequenceForward(model, sequence, temperature); // modifies sequence in place
      if (keepVisited) visited.add(sequence);
      if (returnLastFoundSequence) {
        best = sequence;
      } else {
        double score = model.scoreOf(sequence);
        if (score > bestScore) {
          best = sequence;
          bestScore = score;
        }
      }
      if (verbose > 0) System.err.print(".");
    }
    if (keepVisited) {
      System.err.println();
      printSamples(visited, System.err);
    }
    if (verbose > 0) System.err.println("done.");
    return best;
  }

  /**
   * Collects {@code numSamples} samples from the distribution defined by {@code model},
   * starting from a random sequence. Consecutive samples are {@code sampleInterval} full
   * passes apart, which is meant to reduce autocorrelation.
   *
   * @param model the model to sample from
   * @param numSamples how many samples to collect
   * @param sampleInterval number of full forward passes between consecutive samples
   * @return the samples, in the order they were drawn
   */
  public List<int[]> collectSamples(SequenceModel model, int numSamples, int sampleInterval) {
    int[] initialSequence = getRandomSequence(model);
    return collectSamples(model, numSamples, sampleInterval, initialSequence);
  }

  /**
   * Collects {@code numSamples} samples from the distribution defined by {@code model}.
   * Consecutive samples are {@code sampleInterval} full passes apart, which is meant to reduce
   * autocorrelation.
   *
   * @param model the model to sample from
   * @param numSamples how many samples to collect
   * @param sampleInterval number of full forward passes between consecutive samples
   * @param initialSequence starting state of the chain; not modified
   * @return a List of the samples (each a distinct int array), in the order they were drawn
   */
  public List<int[]> collectSamples(SequenceModel model, int numSamples, int sampleInterval, int[] initialSequence) {
    if (verbose > 0) System.err.print("Collecting samples");
    listener.setInitialSequence(initialSequence);
    List<int[]> result = new ArrayList<int[]>();
    int[] sequence = initialSequence;
    for (int i = 0; i < numSamples; i++) {
      sequence = copy(sequence); // so we don't change the initial, or the one we just stored
      sampleSequenceRepeatedly(model, sequence, sampleInterval); // modifies sequence in place
      result.add(sequence); // save it to return later
      if (verbose > 0) {
        System.err.print(".");
        System.err.flush();
      }
    }
    if (verbose > 1) {
      System.err.println();
      printSamples(result, System.err);
    }
    if (verbose > 0) System.err.println("done.");
    return result;
  }

  /**
   * Makes {@code numSamples} full forward Gibbs passes over {@code sequence}.
   * Destructively modifies {@code sequence} in place. The listener is first re-initialized
   * with {@code sequence}.
   *
   * @param model the model to sample from
   * @param sequence the chain state; updated in place
   * @param numSamples number of full forward passes to make
   */
  public void sampleSequenceRepeatedly(SequenceModel model, int[] sequence, int numSamples) {
    // No defensive copy: collectSamples relies on seeing the sampled values in this array.
    listener.setInitialSequence(sequence);
    for (int iter = 0; iter < numSamples; iter++) {
      sampleSequenceForward(model, sequence);
    }
  }

  /**
   * Makes {@code numSamples} full forward Gibbs passes starting from a fresh random sequence.
   * The sampled sequence is not returned; its only observable effect is the notifications
   * sent to the listener.
   *
   * @param model the model to sample from
   * @param numSamples number of full forward passes to make
   */
  public void sampleSequenceRepeatedly(SequenceModel model, int numSamples) {
    int[] sequence = getRandomSequence(model);
    sampleSequenceRepeatedly(model, sequence, numSamples);
  }

  /**
   * Resamples every position once, from first to last, at temperature 1.0.
   * Destructively modifies the sequence in place.
   *
   * @param model the model to sample from
   * @param sequence the sequence to start with
   */
  public void sampleSequenceForward(SequenceModel model, int[] sequence) {
    sampleSequenceForward(model, sequence, 1.0);
  }

  /**
   * Resamples every position once, from first to last, at the given temperature.
   * Destructively modifies the sequence in place.
   *
   * @param model the model to sample from
   * @param sequence the sequence to start with
   * @param temperature see {@link #samplePosition(SequenceModel, int[], int, double)}
   */
  public void sampleSequenceForward(SequenceModel model, int[] sequence, double temperature) {
    for (int pos = 0; pos < sequence.length; pos++) {
      samplePosition(model, sequence, pos, temperature);
    }
  }

  /**
   * Resamples every position once, from last to first, at temperature 1.0.
   * Destructively modifies the sequence in place.
   *
   * @param model the model to sample from
   * @param sequence the sequence to start with
   */
  public void sampleSequenceBackward(SequenceModel model, int[] sequence) {
    sampleSequenceBackward(model, sequence, 1.0);
  }

  /**
   * Resamples every position once, from last to first, at the given temperature.
   * Destructively modifies the sequence in place.
   *
   * @param model the model to sample from
   * @param sequence the sequence to start with
   * @param temperature see {@link #samplePosition(SequenceModel, int[], int, double)}
   */
  public void sampleSequenceBackward(SequenceModel model, int[] sequence, double temperature) {
    for (int pos = sequence.length - 1; pos >= 0; pos--) {
      samplePosition(model, sequence, pos, temperature);
    }
  }

  /**
   * Resamples a single position at temperature 1.0.
   * Destructively modifies the sequence in place.
   *
   * @param model the model to sample from
   * @param sequence the sequence to start with
   * @param pos the position to sample
   * @return the probability of the newly chosen value under the conditional distribution
   */
  public double samplePosition(SequenceModel model, int[] sequence, int pos) {
    return samplePosition(model, sequence, pos, 1.0);
  }

  /**
   * Resamples a single position from its conditional distribution given the rest of the
   * sequence, then notifies the listener of the change.
   * Destructively modifies the sequence in place.
   *
   * <p>NOTE(review): the sampled index into {@code model.scoresOf(sequence, pos)} is stored
   * directly as the new value. This assumes entry {@code v} of that array is the score of
   * value {@code v}; confirm against the model implementations.
   *
   * @param model the model to sample from
   * @param sequence the sequence to start with
   * @param pos the position to sample
   * @param temperature scores are divided by this before normalizing; values below 1 sharpen
   *     and values above 1 flatten the distribution, and 0 always picks the argmax
   * @return the probability of the newly chosen value under the tempered conditional
   *     distribution
   */
  public double samplePosition(SequenceModel model, int[] sequence, int pos, double temperature) {
    // Treated as unnormalized log scores: they are log-normalized and exponentiated below.
    double[] distribution = model.scoresOf(sequence, pos);
    if (temperature != 1.0) {
      if (temperature == 0.0) {
        // Put all probability mass on the argmax (log-prob 0, all others log-prob -inf).
        int argmax = ArrayMath.argmax(distribution);
        Arrays.fill(distribution, Double.NEGATIVE_INFINITY);
        distribution[argmax] = 0.0;
      } else {
        // Raising probabilities to the power 1/T is multiplying log scores by 1/T.
        ArrayMath.multiplyInPlace(distribution, 1.0 / temperature);
      }
    }
    ArrayMath.logNormalize(distribution);
    ArrayMath.expInPlace(distribution);
    int oldTag = sequence[pos];
    int newTag = ArrayMath.sampleFromDistribution(distribution, random);
    sequence[pos] = newTag;
    listener.updateSequenceElement(sequence, pos, oldTag);
    return distribution[newTag];
  }

  /**
   * Prints the samples as a table. There is one row per position, labeled with the
   * corresponding word from the document, and one column per sample. If no document was
   * given, rows are labeled "null" and the row count is taken from the first sample.
   *
   * @param samples the samples to print; each element must be an {@code int[]}
   * @param out where to print
   */
  public void printSamples(List<?> samples, PrintStream out) {
    int numRows;
    if (document != null) {
      numRows = document.size();
    } else if (samples.isEmpty()) {
      numRows = 0;
    } else {
      numRows = ((int[]) samples.get(0)).length;
    }
    for (int i = 0; i < numRows; i++) {
      HasWord word = (document == null) ? null : (HasWord) document.get(i);
      String s = (word == null) ? "null" : word.word();
      out.print(StringUtils.padOrTrim(s, 10));
      for (int j = 0; j < samples.size(); j++) {
        int[] sequence = (int[]) samples.get(j);
        out.print(" " + StringUtils.padLeft(sequence[i], 2));
      }
      out.println();
    }
  }

  /**
   * @param numSamples number of samples {@link #bestSequence} collects
   * @param sampleInterval number of full passes between samples kept by {@link #bestSequence}
   * @param listener notified of each initial sequence and every single-position change
   * @param document the underlying document, a list of HasWord (may be null); a slight
   *     abstraction violation, but useful for debugging output
   * @param returnLastFoundSequence initial value of {@link #returnLastFoundSequence}
   */
  public SequenceGibbsSampler(int numSamples, int sampleInterval, SequenceListener listener, List<?> document, boolean returnLastFoundSequence) {
    this.numSamples = numSamples;
    this.sampleInterval = sampleInterval;
    this.listener = listener;
    this.document = document;
    this.returnLastFoundSequence = returnLastFoundSequence;
  }

  /** Same as the five-argument constructor with {@code returnLastFoundSequence = false}. */
  public SequenceGibbsSampler(int numSamples, int sampleInterval, SequenceListener listener, List<?> document) {
    this(numSamples, sampleInterval, listener, document, false);
  }

  /** Same as the four-argument constructor with no document. */
  public SequenceGibbsSampler(int numSamples, int sampleInterval, SequenceListener listener) {
    this(numSamples, sampleInterval, listener, null);
  }
}