package context.arch.intelligibility.hmm;
import java.util.List;
import be.ac.ulg.montefiore.run.jahmm.Hmm;
import be.ac.ulg.montefiore.run.jahmm.ObservationVector;
import be.ac.ulg.montefiore.run.jahmm.io.OpdfVector;
import context.arch.discoverer.query.HmmWrapper;
import context.arch.enactor.HmmEnactor;
import context.arch.intelligibility.Explainer;
import context.arch.intelligibility.expression.DNF;
import context.arch.intelligibility.expression.Parameter;
import context.arch.intelligibility.expression.Reason;
import context.arch.intelligibility.query.Query;
import context.arch.widget.SequenceWidget;
public class HmmExplainer extends Explainer {
protected Hmm<ObservationVector> hmm;
protected HmmWrapper hmmWrapper;
protected HmmEnactor<ObservationVector> hmmEnactor;
/*
* The following are HMM parameters copied from the model
*/
protected double[] pi; // to model probabilities of states at t=1
protected double[][] a; // to model A matrix: state transition probabilities
protected double[][] b; // to model B matrix: emission probabilities from states to observations
protected double[][] b_naive; // simplified/modified emission probabilities; see comments at top of class
protected String[] OUTPUT_NAMES; // names of states
protected String[] INPUT_NAMES; // names of each feature input
public HmmExplainer(HmmEnactor<ObservationVector> enactor) {
super(enactor);
this.hmmEnactor = enactor; // save casted form
this.hmmWrapper = enactor.getHMM();
this.hmm = hmmWrapper.getHmm();
portParamsToMatrices();
N = hmm.nbStates();
T = hmmWrapper.getSequenceLength();
n = hmmWrapper.getInputNames().length;
N_pow_T = Math.pow(N, T);
F = new double[n];
}
public static final int F_VERSION_FULL = 0; // show evidence due to observations over time and features
public static final int F_VERSION_BY_FEATURE = 1; // show evidence due to observations summed over time; get to distinguish by features only
public static final int F_VERSION_BY_TIME = 2; // show evidence due to observations summed over features; get to distinguish by time only
protected int F_VERSION = F_VERSION_FULL;
public int getFeatureVersion() {
return F_VERSION;
}
public void setFeatureVersion(int F_VERSION) {
this.F_VERSION = F_VERSION;
}
/* =======================================================================================================================================
* Whole bunch of code to port from Jahmm objects to arrays (which are probably more efficient)
*/
/**
* Port parameters of Hmm to matrices.
*/
protected void portParamsToMatrices() {
int NUM_STATES = hmm.nbStates();
pi = new double[NUM_STATES];
for (int i = 0; i < NUM_STATES; i++) {
pi[i] = hmm.getPi(i);
}
a = new double[NUM_STATES][NUM_STATES];
for (int i = 0; i < NUM_STATES; i++) {
for (int j = 0; j < NUM_STATES; j++) {
a[i][j] = hmm.getAij(i, j);
}
}
int VECTOR_DIM = ((OpdfVector)hmm.getOpdf(0)).dimension(); // TODO: should do some checking
int VECTOR_PERMS = (int)Math.pow(2, VECTOR_DIM); // TODO: warning: magic number representing NUM_OBSERVATION_VALS
b = new double[NUM_STATES][VECTOR_PERMS];
b_naive = new double[NUM_STATES][VECTOR_DIM];
for (int i = 0; i < NUM_STATES; i++) {
for (int k = 0; k < VECTOR_PERMS; k++) {
b[i][k] = ((OpdfVector)hmm.getOpdf(i)).probability(k);
}
/*
* Derive naive emission matrix
*/
for (int f = 0; f < VECTOR_DIM; f++) {
b_naive[i][f] = 0;
for (int k = 0; k < VECTOR_PERMS; k++) {
if (((1 << f) & k) > 0) { // bit mask of whether k has bit value = 1 in element f
b_naive[i][f] += b[i][k]; // then add it
}
}
}
}
OUTPUT_NAMES = hmmWrapper.getOutputNames();
INPUT_NAMES = hmmWrapper.getInputNames();
}
/**
* TODO: this is currently a hack; make more 'native'
* Convert from list to 2d array
* @param o_naive_l
* @return
*/
protected double[][] portToObservationsArray(List<ObservationVector> o_naive_l) {
if (o_naive_l.isEmpty()) { return new double[0][0]; }
if (!(o_naive_l.get(0) instanceof ObservationVector)) { throw new RuntimeException("Only works with ObservationVector for now"); }
double[][] o_naive = new double[o_naive_l.size()][((ObservationVector)o_naive_l.get(0)).dimension()];
for (int t = 0; t < o_naive.length; t++) {
for (int r = 0; r < o_naive[t].length; r++) {
o_naive[t][r] = ((ObservationVector)o_naive_l.get(t)).value(r);
}
}
return o_naive;
}
/* ======================================================================================================================================= */
/**
* g(...) = ...
* @param o_naive observation sequence
* @param x state sequence (actual, or target/desired)
* @param F_VERSION
* @return
*/
protected Reason getWhyExplanation(List<ObservationVector> o_naive_l, int[] x, int F_VERSION) {
Reason reason = new Reason();
// TODO: make more 'native'
double[][] o_naive = portToObservationsArray(o_naive_l);
// total sum of evidence
double totalEvidence = 0; // sum through method
int numComponents = 0;
// prior probability
double priorEvidence = getPriorEvidence(x[0], T);
//String prior_name = "t" + 0 + "(" + OUTPUT_NAMES[x[0]] + ")"; // e.g. t0(bed)
String prior_name = SequenceWidget.getTPrepend(0); // "__T0_"
reason.add(Parameter.instance(prior_name, priorEvidence));
totalEvidence += priorEvidence;
numComponents++;
// transition probabilities
for (int t = 1; t < T; t++) { // t=1 to T; skip first
double transitionEvidence = getTransitionEvidence(x[t-1], x[t]);
//String name = "t" + (t) + "((" + OUTPUT_NAMES[x[t-1]] + ")_to_(" + OUTPUT_NAMES[x[t]] + "))"; // e.g. t1((bed)_to_(breakfast))
String name = SequenceWidget.getTPrepend(t); // "__T#_"
reason.add(Parameter.instance(name, transitionEvidence));
totalEvidence += transitionEvidence;
numComponents++;
}
/*
* emission probabilities: 3 versions
*/
if (F_VERSION == F_VERSION_BY_FEATURE) { // get to distinguish by features, r
for (int r = 0; r < n; r++) {
double featureEvidence = 0;
String name = "(" + INPUT_NAMES[r] + "="; // e.g. (microwave=1,0,0,...)
for (int t = 0; t < T; t++) {
featureEvidence += getFeatureEvidence(x[t], o_naive[t], r);
name += (int)o_naive[t][r] + (t < T-1 ? "," : ")");
}
reason.add(Parameter.instance(name, featureEvidence));
totalEvidence += featureEvidence;
numComponents++;
}
}
else if (F_VERSION == F_VERSION_BY_TIME) { // get to distinguish by time, t
for (int t = 0; t < T; t++) {
double featureEvidence = 0;
String name = "t" + t + "("; // e.g. t1(microwave=1,toilet=0,...)
for (int r = 0; r < n; r++) {
featureEvidence += getFeatureEvidence(x[t], o_naive[t], r);
name += INPUT_NAMES[r] + "=" + o_naive[t][r] + (t < T-1 ? "," : ")");
}
reason.add(Parameter.instance(name, featureEvidence));
totalEvidence += featureEvidence;
numComponents++;
}
}
else { // F_VERSION == F_VERSION_FULL
for (int t = 0; t < T; t++) {
for (int r = 0; r < n; r++) {
double featureEvidence = getFeatureEvidence(x[t], o_naive[t], r);
// String name = "t"+t+"(" + INPUT_NAMES[r] + "=" + o_naive[t][r] + ")"; // e.g. t1(microwave=1)
String name = SequenceWidget.getTPrepend(t) + INPUT_NAMES[r]; // "__T#_Name"
reason.add(Parameter.instance(name, featureEvidence));
totalEvidence += featureEvidence;
numComponents++;
}
}
}
// average evidence; use average instead of total, to normalize the "lengths" for visualization
// int n = NUM_OBSERVATION_DIM;
// double avgEvidence = totalEvidence / (1 + (T-1) + T*n); // obsolete
double avgEvidence = totalEvidence / numComponents;
reason.add(0, Parameter.instance("average", avgEvidence)); // add to front
return reason;
}
/**
* delta_g = g(xTarget) - g(xActual)
* @param o_naive
* @param xActual
* @param xTarget
* @param F_VERSION
* @return
*/
protected Reason getWhyNotExplanation(List<ObservationVector> o_naive_l, int[] xActual, int[] xTarget, int F_VERSION) {
Reason list = new Reason();
// TODO: make more 'native'
double[][] o_naive = portToObservationsArray(o_naive_l);
// total sum of evidence
double dTotalEvidence = 0; // sum through method
int numComponents = 0;
// calculate and add evidence due to prior
double whyPriorEvidence = getPriorEvidence(xActual[0], T);
double whyNotPriorEvidence = getPriorEvidence(xTarget[0], T);
double dPriorEvidence = whyNotPriorEvidence - whyPriorEvidence; // delta = target - actual
// String prior_name = "t" + 0 + "(" + OUTPUT_NAMES[xTarget[0]] + " vs. " + OUTPUT_NAMES[xActual[0]] + ")"; // e.g. t0(target vs. actual)
String prior_name = SequenceWidget.getTPrepend(0); // "__T0_"
list.add(Parameter.instance(prior_name, dPriorEvidence));
dTotalEvidence += dPriorEvidence;
numComponents++;
// transition probabilities
for (int t = 1; t < T; t++) { // t=1 to T; skip first
double whyTransitionEvidence = getTransitionEvidence(xActual[t-1], xActual[t]);
double whyNotTransitionEvidence = getTransitionEvidence(xTarget[t-1], xTarget[t]);
// System.out.println("whyNotTransitionEvidence = " +
// "getTransitionEvidence(xTarget["+(t-1)+"], xTarget["+t+"]) = " +
// "getTransitionEvidence("+xTarget[t-1]+", "+xTarget[t]+") = " +
// getTransitionEvidence(xTarget[t-1], xTarget[t]));
double dTransitionEvidence = whyNotTransitionEvidence - whyTransitionEvidence; // delta = target - actual
// String name = "t" + (t) + "(((" + OUTPUT_NAMES[xTarget[t-1]] + ")_to_(" + OUTPUT_NAMES[xTarget[t]] + ")) vs. " +
// "((" + OUTPUT_NAMES[xActual[t-1]] + ")_to_(" + OUTPUT_NAMES[xActual[t]] + ")))"; // e.g. t1(((bed)_to_(breakfast)) vs. ((bed)_to_(breakfast)))
String name = SequenceWidget.getTPrepend(t); // "__T#_"
list.add(Parameter.instance(name, dTransitionEvidence));
dTotalEvidence += dTransitionEvidence;
numComponents++;
}
/*
* emission probabilities: 3 versions
*/
if (F_VERSION == F_VERSION_BY_FEATURE) { // get to distinguish by features, r
for (int r = 0; r < n; r++) {
double whyFeatureEvidence = 0, whyNotFeatureEvidence = 0;
String name = "(" + INPUT_NAMES[r] + "="; // e.g. (microwave=1,0,0,...)
for (int t = 0; t < T; t++) {
whyFeatureEvidence += getFeatureEvidence(xActual[t], o_naive[t], r);
whyNotFeatureEvidence += getFeatureEvidence(xTarget[t], o_naive[t], r);
name += (int)o_naive[t][r] + (t < T-1 ? "," : ")");
}
double dFeatureEvidence = whyNotFeatureEvidence - whyFeatureEvidence; // delta = target - actual
list.add(Parameter.instance(name, dFeatureEvidence));
dTotalEvidence += dFeatureEvidence;
numComponents++;
}
}
else if (F_VERSION == F_VERSION_BY_TIME) { // get to distinguish by time, t
for (int t = 0; t < T; t++) {
double whyFeatureEvidence = 0, whyNotFeatureEvidence = 0;
String name = "t" + t + "("; // e.g. t1(microwave=1,toilet=0,...)
for (int r = 0; r < n; r++) {
whyFeatureEvidence += getFeatureEvidence(xActual[t], o_naive[t], r);
whyNotFeatureEvidence += getFeatureEvidence(xTarget[t], o_naive[t], r);
name += INPUT_NAMES[r] + "=" + o_naive[t][r] + (t < T-1 ? "," : ")");
}
double dFeatureEvidence = whyNotFeatureEvidence - whyFeatureEvidence; // delta = target - actual
list.add(Parameter.instance(name, dFeatureEvidence));
dTotalEvidence += dFeatureEvidence;
numComponents++;
}
}
else { // F_VERSION == F_VERSION_FULL
for (int t = 0; t < T; t++) {
for (int r = 0; r < n; r++) {
double whyFeatureEvidence = getFeatureEvidence(xActual[t], o_naive[t], r);
double whyNotFeatureEvidence = getFeatureEvidence(xTarget[t], o_naive[t], r);
// String name = "t"+t+"(" + INPUT_NAMES[r] + "=" + o_naive[t][r] + ")"; // e.g. t1(microwave=1)
String name = SequenceWidget.getTPrepend(t) + INPUT_NAMES[r]; // "__T#_Name"
double dFeatureEvidence = whyNotFeatureEvidence - whyFeatureEvidence; // delta = target - actual
list.add(Parameter.instance(name, dFeatureEvidence));
dTotalEvidence += dFeatureEvidence;
numComponents++;
}
}
}
// average evidence; use average instead of total, to normalize the "lengths" for visualization
// int n = NUM_OBSERVATION_DIM;
double avgEvidence = dTotalEvidence / numComponents;
list.add(0, Parameter.instance("average", avgEvidence)); // add to front
return list;
}
int N; // number of states
int T; // sequence length
int n; // number of features
/** Commonly used factor, so store it; N^T */
double N_pow_T;
/** Is a constant that only needs to be set once */
protected double H = Double.NaN;
/**
* @param x0
* @return
*/
protected double getPriorEvidence(int x0, int T) {
if (Double.isNaN(H)) { // not yet set
// H = [sum(j=1..N){log(pi[j])}]^T
H = 1;
for (double p : pi) { H += Math.log(p); }
H = Math.pow(H, T);
}
// h = (N^T)*log(pi[x0]) - H
double evidence = N_pow_T * Math.log(pi[x0]) - H;
// System.out.println("N_pow_T: " + N_pow_T);
// System.out.println("H: " + H);
// System.out.println("N_pow_T * Math.log(pi[x0]): " + N_pow_T * Math.log(pi[x0]));
return evidence;
}
/** Is a constant that only needs to be set once */
protected double U = Double.NaN;
/**
*
* @param x1 from this state; x[t-1]
* @param x2 to this state; x[t]
* @return
*/
protected double getTransitionEvidence(int x1, int x2) {
// U = [sum(j1=1..N,j2=1..N){log(a[j1][j2])}]^(T-2)
if (Double.isNaN(U)) {
U = 1;
for (int j1 = 0; j1 < N; j1++) {
for (int j2 = 0; j2 < N; j2++) {
U += Math.log(a[j1][j2]);
}
}
U = Math.pow(U, T-2);
}
// u = (N^T)*log(a[x1][x2]) - U
double evidence = N_pow_T * Math.log(a[x1][x2]) - U;
// System.out.println("U: " + U);
// System.out.println("Math.log(a[x1][x2]): " + Math.log(a[x1][x2]));
return evidence;
}
/** Is a constant that only needs to be set once */
protected double[] F;
/**
*
* @param xt state at time t
* @param ot observation vector at time t
* @param r index of feature of observation we care about here
* @return
*/
protected double getFeatureEvidence(int xt, double[] ot, int r) {
if (F[r] == 0) { // not yet set
double Fr = 1;
// F = n^(T-1) * [sum(j=1..N){log(b_naive[j][r])}]^T
for (int j = 0; j < N; j++) { // iterate states
double p;
if (ot[r] == 1) { p = b_naive[j][r]; }
else { p = 1 - b_naive[j][r]; } // probability of not
// System.out.println("b_naive["+j+"]["+r+"] = " + b_naive[j][r]);
if (p > 0) {
Fr += Math.log(p);
}
else {
/*
* because p <= 0, probably due to some addition error arising from
* counting too many zeros even with Laplace smoothing
*/
Fr += 0;
}
// System.out.println("F"+r+" = " + Fr + ", j = " + j + ", p = " + p + ", N = " + N); // why is p negative? b_naive > 1
}
Fr = Math.pow(Fr, T);
Fr *= Math.pow(n, T-1);
F[r] = Fr;
// System.out.println("F["+r+"] = " + F[r]);
}
// f = (N^T)*log(b_naive[i][r]) - F
double evidence = N_pow_T * Math.log(b_naive[xt][r]) - F[r];
// System.out.println("F: " + F);
// System.out.println("Math.log(b_naive[i][r]): " + Math.log(b_naive[xt][r]));
return evidence;
}
// TODO increase efficiency by using caching
@SuppressWarnings("serial")
@Override
public DNF getWhyExplanation() {
final List<ObservationVector> o = hmmEnactor.getObservations();
final int[] x = hmm.mostLikelyStateSequence(o);
return new DNF() {{ add(getWhyExplanation(o, x, F_VERSION)); }};
}
/**
*
* @param altOutcomeValue; assumes format: "# # #..." up to sequence length
* @return
*/
protected int[] parseOutcomeValueSequence(String altOutcomeValue) {
String[] altOutcomeValueSequence = altOutcomeValue.split(" ");
int[] xTarget = new int[altOutcomeValueSequence.length];
for (int i = 0; i < xTarget.length; i++) {
try {
xTarget[i] = Integer.parseInt(altOutcomeValueSequence[i]);
} catch (NumberFormatException e) {
e.printStackTrace();
}
}
return xTarget;
}
@Override
public DNF getWhyNotExplanation(String altOutcomeValue) {
return getWhyNotExplanations(parseOutcomeValueSequence(altOutcomeValue));
}
public DNF getWhyNotExplanations(int[] xTarget) {
List<ObservationVector> o = hmmEnactor.getObservations();
int[] x = hmm.mostLikelyStateSequence(o);
Reason conj = getWhyNotExplanation(o, x, xTarget, F_VERSION);
DNF dnf = new DNF();
dnf.add(conj);
return dnf;
}
@Override
public DNF getHowToExplanation(String altOutcomeValue) {
return getHowToExplanations(parseOutcomeValueSequence(altOutcomeValue));
}
public DNF getHowToExplanations(int[] xTarget) {
List<ObservationVector> o = hmmEnactor.getObservations();
Reason conj = getWhyExplanation(o, xTarget, F_VERSION);
DNF dnf = new DNF();
dnf.add(conj);
return dnf;
}
@Override
public DNF getCertaintyExplanation() {
List<ObservationVector> o = hmmEnactor.getObservations();
int[] x = hmm.mostLikelyStateSequence(o);
double certainty = hmm.probability(o, x);
return new DNF(Parameter.instance(Query.QUESTION_CERTAINTY, certainty));
}
/* ================================================================================
* Internal methods to calculate evidences due to the HMM
* ================================================================================ */
}