package context.arch.intelligibility.hmm;

import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.Reader;
import java.util.ArrayList;
import java.util.List;

import be.ac.ulg.montefiore.run.jahmm.Hmm;
import be.ac.ulg.montefiore.run.jahmm.ObservationInteger;
import be.ac.ulg.montefiore.run.jahmm.ObservationVector;
import be.ac.ulg.montefiore.run.jahmm.io.FileFormatException;
import be.ac.ulg.montefiore.run.jahmm.io.HmmWriter;
import be.ac.ulg.montefiore.run.jahmm.io.ObservationIntegerReader;
import be.ac.ulg.montefiore.run.jahmm.io.ObservationSequencesReader;
import be.ac.ulg.montefiore.run.jahmm.io.ObservationVectorReader;
import be.ac.ulg.montefiore.run.jahmm.io.OpdfVector;
import be.ac.ulg.montefiore.run.jahmm.io.OpdfVectorWriter;

/**
 * Learns a HMM from supervised (labeled) data, by counting state transitions and
 * emissions over one labeled training sequence and normalizing the counts into the
 * standard HMM parameters (Pi, A, B). This class was written since JAHMM only had
 * classes to do unsupervised learning.
 * Currently only limited to ObservationVector as observation.
 * @author Brian Y. Lim
 */
public class HmmSupervisedLearner {

	/** Observation vectors, one per time step, used for training. */
	protected List<ObservationVector> trainingObservations;
	/** State labels, aligned 1:1 with {@link #trainingObservations}. */
	protected List<ObservationInteger> trainingStates;

	/** number of hidden states */
	protected int NUM_STATES;
	/** number of observations in a sequence, T (i.e. length of sequence) for the training data */
	protected int NUM_OBSERVATIONS;
	/** number of sensors per observation; number of dimensions in observation */
	protected int NUM_OBSERVATION_DIM;
	/** number of values an observation sensor can take; 2=binary for Kasteren dataset */
	protected int NUM_OBSERVATION_VALS;
	/**
	 * number of permutations of observation vectors;
	 * NUM_OBSERVATION_VALS^NUM_OBSERVATION_DIM in general (2^NUM_OBSERVATION_DIM for binary sensors)
	 */
	protected int NUM_OBSERVATION_PERMS;

	/*
	 * The following fields are counts to help calculate the HMM probability parameters.
	 * Measured from training set.
	 */
	/** from i to j; [NUM_STATES][NUM_STATES] square matrix; to calculate A parameter */
	protected int[][] N_STATE_TO_STATE;
	/** number of times state[j] occurred in training set */
	protected int[] N_STATE;
	/**
	 * number of times state[j] leads to observation[k]; where k is the integer equivalent
	 * of the observation vector interpreted in base NUM_OBSERVATION_VALS
	 */
	protected int[][] N_STATE_TO_OBS; // [NUM_STATES][NUM_OBSERVATION_PERMS]
	/** number of times state[j] leads to attribute[r] being active; assumes independence across attributes */
	protected int[][] N_STATE_TO_ATTR; // [NUM_STATES][NUM_OBSERVATION_DIM]

	/*
	 * The following are HMM parameters
	 */
	protected double[] pi;        // to model probabilities of states at t=1
	protected double[][] a;       // to model A matrix: state transition probabilities
	protected double[][] b;       // to model B matrix: emission probabilities from states to observations
	protected double[][] b_naive; // simplified/modified per-attribute emission probabilities; see printB_naive

	/**
	 * @param numStates number of hidden states, N
	 * @param obsDimension number of sensors per observation vector
	 * @param numObservationValues number of discrete values each sensor can take (2 = binary)
	 */
	public HmmSupervisedLearner(int numStates, int obsDimension, int numObservationValues) {
		NUM_STATES = numStates;
		NUM_OBSERVATION_DIM = obsDimension;
		NUM_OBSERVATION_VALS = numObservationValues; // 2; specified in dataset
	}

	/**
	 * Learns the HMM parameters from aligned observation and state sequences.
	 * @param trainingObservations observation vectors, one per time step
	 * @param trainingStates state labels, aligned 1:1 with the observations
	 * @return the learned HMM
	 */
	public Hmm<ObservationVector> learn(List<ObservationVector> trainingObservations, List<ObservationInteger> trainingStates) {
		this.trainingObservations = trainingObservations;
		this.trainingStates = trainingStates;

		initCounters();
		generateCountsFromTraining();
		calculateHmmParams();

		return new Hmm<ObservationVector>(pi, a, generateOpdfVectors());
	}

	/**
	 * Learns the HMM parameters from sequence files in JAHMM's text format.
	 * @param observationSequencesFile file of observation vector sequences
	 * @param stateSequencesFile file of state label sequences, aligned with the observations
	 * @return the learned HMM
	 */
	public Hmm<ObservationVector> learn(File observationSequencesFile, File stateSequencesFile) {
		List<ObservationVector> trainingObservations =
				readObservationsSequencesFromFile(observationSequencesFile, NUM_OBSERVATION_DIM);
		List<ObservationInteger> trainingStates =
				readStateSequencesFromFile(stateSequencesFile);
		return learn(trainingObservations, trainingStates);
	}

	/**
	 * Wraps each row of the learned B matrix into an OpdfVector, one per hidden state.
	 */
	protected List<OpdfVector> generateOpdfVectors() {
		List<OpdfVector> opdfVectors = new ArrayList<OpdfVector>();
		for (int state = 0; state < NUM_STATES; state++) {
			opdfVectors.add(new OpdfVector(b[state]));
		}
		return opdfVectors;
	}

	/** Prints the Pi vector (initial state probabilities) to the console. */
	public void printPi() {
		System.out.println("Pi = [");
		for (int i = 0; i < pi.length; i++) {
			System.out.println(pi[i] + "\t");
		}
		System.out.println("]");
	}

	/** Prints the A matrix (state transition probabilities) to the console. */
	public void printA() {
		System.out.println("A = [");
		for (int i = 0; i < a.length; i++) {
			for (int j = 0; j < a[i].length; j++) {
				System.out.print(a[i][j] + "\t");
			}
			System.out.println();
		}
		System.out.println("]");
	}

	/**
	 * Prints the B matrix (emission probabilities) to the console.
	 * Use this sparingly, as it takes up a lot of space in the console and may overflow
	 * such that previous output is not viewable.
	 * TODO: want to have one B per attribute for leave-one-attribute explanation strategy
	 */
	public void printB() {
		System.out.println("B = [");
		for (int j = 0; j < b.length; j++) {
			for (int k = 0; k < NUM_OBSERVATION_PERMS; k++) {
				System.out.print(b[j][k] + "\t");
			}
			System.out.println();
		}
		System.out.println("]");
	}

	/** Prints the per-attribute naive emission probabilities to the console. */
	public void printB_naive() {
		System.out.println("B_naive = [");
		for (int j = 0; j < b_naive.length; j++) {
			for (int f = 0; f < NUM_OBSERVATION_DIM; f++) {
				System.out.print(b_naive[j][f] + "\t");
			}
			System.out.println();
		}
		System.out.println("]");
	}

	/**
	 * (Re)allocates the count arrays, sized from the training data and the
	 * observation space dimensions. Must be called after the training
	 * observations are set and before counting.
	 */
	protected void initCounters() {
		NUM_OBSERVATIONS = trainingObservations.size();

		N_STATE_TO_STATE = new int[NUM_STATES][NUM_STATES];
		N_STATE = new int[NUM_STATES];

		// Base must match OpdfVector.getIntegerEquivalent(..., NUM_OBSERVATION_VALS) used when
		// counting emissions; was previously hard-coded to 2, which would overflow
		// N_STATE_TO_OBS for non-binary sensors. Equivalent for the binary (VALS==2) case.
		NUM_OBSERVATION_PERMS = (int) Math.pow(NUM_OBSERVATION_VALS, NUM_OBSERVATION_DIM);
		N_STATE_TO_OBS = new int[NUM_STATES][NUM_OBSERVATION_PERMS];
		N_STATE_TO_ATTR = new int[NUM_STATES][NUM_OBSERVATION_DIM];
	}

	/**
	 * Prepares the counts from training data before calculating the HMM parameters
	 */
	public void generateCountsFromTraining() {
		// iterate through observations together with states
		int prevState = -1;
		for (int obs = 0; obs < NUM_OBSERVATIONS; obs++) {
			// count states and state transitions
			ObservationInteger stateObj = trainingStates.get(obs); // state(t)
			int state = stateObj.value;
			if (obs == 0) {
				// no predecessor at t=0; self-transition is counted once as a harmless artifact
				prevState = state;
			}
			N_STATE_TO_STATE[prevState][state]++; // increment count for transition from prevState to state
			N_STATE[state]++; // increment count for this state; do for all t

			// count emissions of state to observation
			ObservationVector observation = trainingObservations.get(obs);
			double[] observationVal = observation.values(); // vector of form e.g.: [0 0 1 1 0 1 0 ...]

			// index the observation vector by its integer equivalent in base NUM_OBSERVATION_VALS
			N_STATE_TO_OBS[state][OpdfVector.getIntegerEquivalent(observationVal, NUM_OBSERVATION_VALS)]++;

			// iterate through features to see which is activated
			for (int f = 0; f < observationVal.length; f++) {
				if (observationVal[f] == 1) {
					N_STATE_TO_ATTR[state][f]++;
				}
			}

			// update prevState to current state
			prevState = state;
		}
	}

	/**
	 * Calculate the HMM parameters: Pi, A, B
	 */
	public void calculateHmmParams() {
		// calculate Pi, the probabilities for states at t=1
		pi = new double[NUM_STATES];
		for (int i = 0; i < NUM_STATES; i++) {
			int count = N_STATE[i];
			count = count == 0 ? 1 : count; // Laplace smoothing to prevent log(0)
			pi[i] = ((double) count) / NUM_OBSERVATIONS;
		}

		// calculate A, the state transition probabilities
		a = new double[NUM_STATES][NUM_STATES];
		for (int i = 0; i < NUM_STATES; i++) {
			for (int j = 0; j < NUM_STATES; j++) {
				int count = N_STATE_TO_STATE[i][j];
				count = count == 0 ? 1 : count; // Laplace smoothing to prevent log(0)
				a[i][j] = ((double) count) / N_STATE[i];
			}
		}

		// calculate B, the emission probabilities of a state resulting in an observation
		b = new double[NUM_STATES][NUM_OBSERVATION_PERMS];
		b_naive = new double[NUM_STATES][NUM_OBSERVATION_DIM];
		for (int j = 0; j < NUM_STATES; j++) {

			// calculate across all observation permutations
			for (int k = 0; k < NUM_OBSERVATION_PERMS; k++) {
				double count = N_STATE_TO_OBS[j][k];
				count = count == 0 ? 1e-15 : count; // smoothing to prevent log(0)
				b[j][k] = count / N_STATE[j];
			}

			// calculate across features, assuming attribute independence
			for (int f = 0; f < NUM_OBSERVATION_DIM; f++) {
				int count = N_STATE_TO_ATTR[j][f];
				count = count == 0 ? 1 : count; // Laplace smoothing to prevent log(0)
				b_naive[j][f] = ((double) count) / N_STATE[j];
			}
		}
	}

	/**
	 * Formats a double array as "[a, b, c]".
	 * Safe for empty arrays (returns "[]"); previously threw StringIndexOutOfBoundsException.
	 */
	public static String toDoubleArrayString(double[] vector) {
		StringBuilder ret = new StringBuilder("[");
		for (int i = 0; i < vector.length; i++) {
			if (i > 0) {
				ret.append(", ");
			}
			ret.append(vector[i]);
		}
		return ret.append("]").toString();
	}

	/**
	 * Formats an int array as "[a, b, c]".
	 * Safe for empty arrays (returns "[]"); previously threw StringIndexOutOfBoundsException.
	 */
	public static String toIntArrayString(int[] vector) {
		StringBuilder ret = new StringBuilder("[");
		for (int i = 0; i < vector.length; i++) {
			if (i > 0) {
				ret.append(", ");
			}
			ret.append(vector[i]);
		}
		return ret.append("]").toString();
	}

	/**
	 * Reads observation vector sequences from a file in JAHMM's text format.
	 * Ordinarily, the result is a number of sequences of a number of observations.
	 * However, for home activity recognition, we take a long contiguous sequence and just
	 * use a sliding window of fixed length to get the sequences, so only the first
	 * sequence in the file is returned.
	 * @param f file to read
	 * @param dimension number of sensors per observation vector
	 * @return the first observation sequence in the file, or null on read/format failure
	 */
	public static List<ObservationVector> readObservationsSequencesFromFile(File f, int dimension) {
		Reader reader = null;
		try {
			reader = new FileReader(f);
			List<List<ObservationVector>> v = ObservationSequencesReader.readSequences(
					new ObservationVectorReader(dimension), reader);
			return v.get(0); // dataset is one long contiguous sequence; TODO get rid of magic number
		} catch (IOException e) {
			e.printStackTrace();
		} catch (FileFormatException e) {
			e.printStackTrace();
		} finally {
			// close even if readSequences throws; previously leaked on the exception path
			if (reader != null) {
				try {
					reader.close();
				} catch (IOException ignored) {
					// best-effort close
				}
			}
		}
		return null;
	}

	/**
	 * Reads state label sequences from a file in JAHMM's text format.
	 * As with {@link #readObservationsSequencesFromFile(File, int)}, only the first
	 * (long contiguous) sequence in the file is returned.
	 * @param f file to read
	 * @return the first state sequence in the file, or null on read/format failure
	 */
	public static List<ObservationInteger> readStateSequencesFromFile(File f) {
		Reader reader = null;
		try {
			reader = new FileReader(f);
			List<List<ObservationInteger>> s = ObservationSequencesReader.readSequences(
					new ObservationIntegerReader(), reader);
			return s.get(0);
		} catch (IOException e) {
			e.printStackTrace();
		} catch (FileFormatException e) {
			e.printStackTrace();
		} finally {
			// close even if readSequences throws; previously leaked on the exception path
			if (reader != null) {
				try {
					reader.close();
				} catch (IOException ignored) {
					// best-effort close
				}
			}
		}
		return null;
	}

	/*
	 * For testing, loading and saving
	 */
	public static void main(String[] args) {
		// learning HMM from the Kasteren dataset: 8 activities, 14 binary sensors
		HmmSupervisedLearner learner = new HmmSupervisedLearner(8, 14, 2);
		Hmm<ObservationVector> hmm = learner.learn(
				new File("demos/home-hmm/kasteren-jahmm-observations.seq"),
				new File("demos/home-hmm/kasteren-jahmm-states.seq"));

		try {
			/*
			 * Save HMM to file
			 * Quite slow!
			 */
			HmmWriter.write(
					new FileWriter("demos/home-hmm/kasteren-jahmm.model"),
					new OpdfVectorWriter(),
					hmm);

			/*
			 * Load HMM from file
			 * Also slow...but since this is just an init step, it may be ok
			 * ~6.4s
			 */
//			long start = System.currentTimeMillis();
//			Hmm<ObservationVector> hmm = HmmReader.read(
//					new FileReader("demos/home-hmm/kasteren-jahmm.model"),
//					new OpdfVectorReader());
//			long end = System.currentTimeMillis();
//
//			System.out.println("hmm = " + hmm.getPi(0));
//			System.out.println("duration = " + (float) (end - start) / 1000);
		} catch (IOException e) {
			e.printStackTrace();
//		} catch (FileFormatException e) {
//			e.printStackTrace();
		}
	}

}