/*
* To change this template, choose Tools | Templates
* and open the template in the editor.
*/
package edu.columbia.stat.wood.bnol;
import edu.columbia.stat.wood.bnol.hpyp.HPYP;
import edu.columbia.stat.wood.bnol.hpyp.IntHPYP;
import edu.columbia.stat.wood.bnol.util.Context;
import edu.columbia.stat.wood.bnol.util.GammaDistribution;
import edu.columbia.stat.wood.bnol.util.IntGeometricDistribution;
import edu.columbia.stat.wood.bnol.util.MersenneTwisterFast;
import edu.columbia.stat.wood.bnol.util.MutableDouble;
import edu.columbia.stat.wood.bnol.util.MutableInt;
import edu.columbia.stat.wood.bnol.util.SampleWithoutReplacement;
import gnu.trove.list.array.TIntArrayList;
import java.io.Serializable;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map.Entry;
/**
* Object to represent the complex deterministic program used to generate the
* next emission state. The start state is 1 and all states are positive
* integers.
* @author nicholasbartlett
*/
public class Machine implements Serializable {

    private static final long serialVersionUID = 1;

    /** Number of discount / concentration parameters handed to the HPYP prior; the value is arbitrary. */
    private static final int PRIOR_DEPTH = 11;

    /** Number of times the HPYP prior is re-sampled at the start of every sweep; the value is arbitrary. */
    private static final int PRIOR_SAMPLES_PER_SWEEP = 5;

    /** Transition map: (current machine state, emission) -> next machine state. */
    private HashMap<StateEmissionPair, MutableInt> delta = new HashMap<StateEmissionPair, MutableInt>();

    /** Prior over the values stored in {@link #delta}, keyed by the emission of each delta entry. */
    private HPYP prior;

    /** {@code key} identifies this machine in a machine-key array; {@code H} bounds the context length. */
    private int key, H;

    private MersenneTwisterFast rng;

    /***********************constructor methods********************************/

    /**
     * Creates a machine with an empty delta map and a freshly built HPYP prior.
     * The random number generator is the shared {@code BNOL.rng}.
     * @param key identifier of this machine, matched against machine key arrays
     * @param H maximum number of preceding emissions consumed when computing a state
     * @param p first argument passed to the {@code IntGeometricDistribution} base distribution
     */
    public Machine(int key, int H, double p) {
        this.key = key;
        this.H = H;
        rng = BNOL.rng;

        MutableDouble[] discounts = new MutableDouble[PRIOR_DEPTH];
        MutableDouble[] concentrations = new MutableDouble[PRIOR_DEPTH];
        for (int i = 0; i < discounts.length; i++) {
            // discounts are filled back to front: the last slot holds 0.7, the first 0.7^PRIOR_DEPTH
            discounts[discounts.length - 1 - i] = new MutableDouble(Math.pow(0.7, i + 1));
            concentrations[i] = new MutableDouble(1.0);
        }

        prior = new IntHPYP(discounts, concentrations, new IntGeometricDistribution(p, 1), new GammaDistribution(1, 100));
    }

    /**
     * Creates a machine without a prior or random number generator; only the
     * (empty) delta map is initialised. Methods that touch the prior cannot be
     * used on an instance built this way until those fields are populated.
     */
    public Machine() {
    }

    /***********************public methods*************************************/

    /**
     * Gets the next machine state given the previous emissions. Each machine
     * starts in a given start state (1) and then transitions deterministically
     * to give the output. Transitions missing from the delta map are drawn from
     * the prior and stored, so this call may modify the machine.
     * @param emissions emissions
     * @param index index of time at which we want to get the next machine state
     * @return next machine state
     */
    public int get(int[][] emissions, int index) {
        return get(emissions, index, null);
    }

    /**
     * Samples the underlying machine, i.e. the delta matrix, and then samples
     * the HPYP being used as a prior on that delta matrix. Every entry of the
     * delta map is visited once per sweep, in random order, with a
     * Metropolis-Hastings step whose proposal is a draw from the prior.
     * @param emissions entire set of emissions for data
     * @param machineKeys array of which machine is used at each step
     * @param emissionDistributions emission distributions for each machine state
     * @param sweeps number of MH sweeps
     * @param temp temperature of sampling steps
     * @return the value returned by a final {@code prior.sample(temp)} call.
     * NOTE(review): the emission log evidence is not added to this value, see
     * {@link #score(int[][], S_EmissionDistribution, int[])} for the joint score
     */
    public double sample(int[][] emissions, int[] machineKeys, S_EmissionDistribution emissionDistributions, int sweeps, double temp) {
        // get indices for this particular machine
        int[] indices = getIndices(machineKeys);

        // clean delta matrix first
        clean(emissions, indices);

        for (int sweep = 0; sweep < sweeps; sweep++) {
            for (int j = 0; j < PRIOR_SAMPLES_PER_SWEEP; j++) {
                prior.sample(temp);
            }

            // snapshot the keys and values of the current delta matrix in a random visiting order
            int size = delta.size();
            StateEmissionPair[] keys = new StateEmissionPair[size];
            MutableInt[] values = new MutableInt[size];
            int[] randomIndex = SampleWithoutReplacement.sampleWithoutReplacement(size, rng);
            int i = 0;
            for (Entry<StateEmissionPair, MutableInt> entry : delta.entrySet()) {
                int slot = randomIndex[i++];
                keys[slot] = entry.getKey();
                values[slot] = entry.getValue();
            }

            // go through each mapped key value pair and sample them
            double currentLogEvidence = logEvidence(emissions, emissionDistributions, indices);
            for (int j = 0; j < keys.length; j++) {
                int[] context = keys[j].emission;
                int currentValue = values[j].value();

                // swap the current value for a proposal drawn from the prior
                prior.unseat(context, currentValue);
                int proposal = prior.draw(context);
                values[j].set(proposal);

                double proposedLogEvidence = logEvidence(emissions, emissionDistributions, indices);
                double r = Math.exp(proposedLogEvidence - currentLogEvidence);
                // the ternary (rather than Math.min) is deliberate: a NaN ratio is mapped to 1.0
                r = Math.pow(r < 1.0 ? r : 1.0, 1.0 / temp);

                if (rng.nextBoolean(r)) {
                    // accepted: temporarily restore the old value to read off the old machine states
                    prior.unseat(context, proposal);
                    prior.seat(context, currentValue);
                    values[j].set(currentValue);
                    int[] oldMachineKeys = getAllMachineStates(emissions, indices);

                    // re-apply the proposal and read off the new machine states
                    prior.unseat(context, currentValue);
                    prior.seat(context, proposal);
                    values[j].set(proposal);
                    int[] newMachineKeys = getAllMachineStates(emissions, indices);

                    // move every emission whose machine state changed to its new state
                    adjustEmissionsDistributions(oldMachineKeys, newMachineKeys, emissionDistributions, emissions, indices);
                    currentLogEvidence = proposedLogEvidence;
                } else {
                    // rejected: put the old value back
                    prior.unseat(context, proposal);
                    prior.seat(context, currentValue);
                    values[j].set(currentValue);
                }
            }

            assert checkCounts();

            // clean the delta matrix
            clean(emissions, indices);
        }

        return prior.sample(temp);
    }

    /**
     * Gets the joint score of the HPYP and data.
     * @param emissions emission data
     * @param emissionDistributions emission distributions
     * @param machineKeys machine keys for emissions
     * @return score
     */
    public double score(int[][] emissions, S_EmissionDistribution emissionDistributions, int[] machineKeys) {
        return prior.score(true) + logEvidence(emissions, emissionDistributions, getIndices(machineKeys));
    }

    /**
     * Removes from the delta map any entries which are not used given the data.
     * Entries that are dropped are also unseated from the prior, after which
     * empty nodes are removed from the prior.
     * @param emissions emission data
     * @param indices indices of emission from this machine
     */
    public void clean(int[][] emissions, int[] indices) {
        HashMap<StateEmissionPair, MutableInt> newDelta = new HashMap<StateEmissionPair, MutableInt>();

        // walking the machine over the data records every transition actually used in newDelta
        for (int i = 0; i < indices.length; i++) {
            get(emissions, indices[i], newDelta);
        }

        for (Entry<StateEmissionPair, MutableInt> entry : delta.entrySet()) {
            if (!newDelta.containsKey(entry.getKey())) {
                prior.unseat(entry.getKey().emission, entry.getValue().value());
            }
        }

        delta = newDelta;
        prior.removeEmptyNodes();
    }

    /**
     * Checks the counts in the machine to make sure that the number of customers
     * in the HPYP and their locations are correct given the delta map. Empty
     * nodes are removed from the prior before the comparison is made.
     * @return true if counts are in agreement; false if a count is off or if the
     * delta map contains a context that the HPYP does not report at all
     */
    public boolean checkCounts() {
        prior.removeEmptyNodes();
        HashMap<Context, MutableInt> data = prior.getImpliedData();

        for (StateEmissionPair deltaKey : delta.keySet()) {
            MutableInt count = data.get(new Context(deltaKey.emission));
            if (count == null) {
                // the delta map holds an entry the HPYP knows nothing about
                return false;
            }
            count.decrement();
        }

        for (MutableInt value : data.values()) {
            if (value.value() != 0) {
                return false;
            }
        }

        return true;
    }

    /***********************private methods************************************/

    /**
     * Computes the machine state at every index belonging to this machine.
     * @param emissions emission data
     * @param indices indices where this machine is used
     * @return machine states, aligned with {@code indices}
     */
    private int[] getAllMachineStates(int[][] emissions, int[] indices) {
        int[] machineStates = new int[indices.length];
        for (int i = 0; i < indices.length; i++) {
            machineStates[i] = get(emissions, indices[i]);
        }
        return machineStates;
    }

    /**
     * Moves each emission whose machine state changed from its old state to its
     * new state in the emission distributions.
     * @param oldMachineStates machine states before the change, aligned with {@code indices}
     * @param newMachineStates machine states after the change, aligned with {@code indices}
     * @param emissionDistributions emission distributions to update
     * @param emissions emission data
     * @param indices indices where this machine is used
     */
    private void adjustEmissionsDistributions(int[] oldMachineStates, int[] newMachineStates, S_EmissionDistribution emissionDistributions, int[][] emissions, int[] indices) {
        assert oldMachineStates.length == newMachineStates.length;

        for (int i = 0; i < oldMachineStates.length; i++) {
            if (oldMachineStates[i] != newMachineStates[i]) {
                emissionDistributions.unseat(oldMachineStates[i], emissions[indices[i]]);
                emissionDistributions.seat(newMachineStates[i], emissions[indices[i]]);
            }
        }
    }

    /**
     * Gets the next machine state given the previous emissions. Each machine
     * starts in a given start state (1) and then transitions deterministically
     * to give the output. At most {@code H} emissions immediately preceding
     * {@code index} are consumed.
     * @param emissions emissions
     * @param index index of time at which we want to get the next machine state
     * @param newDelta hash map for new delta if this is during a cleaning step
     * @return next machine state
     */
    private int get(int[][] emissions, int index, HashMap<StateEmissionPair, MutableInt> newDelta) {
        int machineState = 1;
        int contextLength = Math.min(H, index);
        int start = index - contextLength;

        for (int i = 0; i < contextLength; i++) {
            machineState = deltaGet(new StateEmissionPair(machineState, emissions[start + i]), newDelta);
        }

        return machineState;
    }

    /**
     * Gets the indices into the argument arrays which pertain to this machine.
     * @param machineKeys array of machine keys
     * @return array of indices pertaining to this machine
     */
    private int[] getIndices(int[] machineKeys) {
        TIntArrayList indices = new TIntArrayList();
        for (int i = 0; i < machineKeys.length; i++) {
            if (machineKeys[i] == key) {
                indices.add(i);
            }
        }
        return indices.toArray();
    }

    /**
     * Does a get from the delta map, but if nothing is found it makes a draw
     * from the prior and adds it to the map.
     * @param pair state emission pair to look up
     * @param newDelta hash map for new delta if this is during a cleaning step
     * @return retrieved or generated value
     */
    private int deltaGet(StateEmissionPair pair, HashMap<StateEmissionPair, MutableInt> newDelta) {
        MutableInt value = delta.get(pair);

        if (value == null) {
            value = new MutableInt(prior.draw(pair.emission));
            delta.put(pair, value);
        }

        if (newDelta != null) {
            // share the same MutableInt so both maps observe later updates to the value
            newDelta.put(pair, value);
        }

        return value.value();
    }

    /**
     * Gets the log evidence of the particular delta configuration given the
     * emissions and the emission distributions.
     * @param emissions emission data
     * @param emissionDistributions emission distributions for each machine state
     * @param indices indices where this machine is used
     * @return log evidence
     */
    private double logEvidence(int[][] emissions, S_EmissionDistribution emissionDistributions, int[] indices) {
        double total = 0.0;
        for (int i = 0; i < indices.length; i++) {
            total += emissionDistributions.logProbability(get(emissions, indices[i]), emissions[indices[i]]);
        }
        return total;
    }

    /***********************private classes************************************/

    /**
     * Convenient class to hold the state emission pairs for the delta map.
     * The emission array is stored by reference, not copied, and takes part in
     * {@link #equals(Object)} and {@link #hashCode()}; it must therefore not be
     * modified while the pair is used as a map key.
     */
    private static class StateEmissionPair implements Serializable {

        int state;
        int[] emission;

        /***********************constructor methods****************************/

        /**
         * Constructor for state emission pair which sets the internal fields.
         * @param state state
         * @param emission emission
         */
        public StateEmissionPair(int state, int[] emission) {
            this.state = state;
            this.emission = emission;
        }

        /***********************public methods*********************************/

        /**
         * {@inheritDoc}
         * Two pairs are equal when their states match and their emission arrays
         * have equal contents.
         */
        @Override
        public boolean equals(Object o) {
            if (o == null || o.getClass() != getClass()) {
                return false;
            }
            StateEmissionPair other = (StateEmissionPair) o;
            return state == other.state && Arrays.equals(emission, other.emission);
        }

        /**
         * {@inheritDoc}
         */
        @Override
        public int hashCode() {
            int hash = 5;
            hash = 37 * hash + this.state;
            hash = 37 * hash + Arrays.hashCode(this.emission);
            return hash;
        }
    }
}