/*********************************************************************************************
* Copyright (c) 2014-2015 Software Behaviour Analysis Lab, Concordia University, Montreal, Canada
*
* All rights reserved. This program and the accompanying materials
* are made available under the terms of Eclipse Public License v1.0 License which
* accompanies this distribution, and is available at http://www.eclipse.org/legal/epl-v10.html
*
* Contributors:
* Syed Shariyar Murtaza -- Initial design and implementation
**********************************************************************************************/
package org.eclipse.tracecompass.internal.totalads.algorithms.hiddenmarkovmodel;
import java.util.Arrays;
import java.util.Random;
import org.apache.mahout.classifier.sequencelearning.hmm.HmmAlgorithms;
import org.apache.mahout.classifier.sequencelearning.hmm.HmmModel;
import org.apache.mahout.classifier.sequencelearning.hmm.HmmTrainer;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.eclipse.tracecompass.totalads.dbms.IDBCursor;
import org.eclipse.tracecompass.totalads.dbms.IDBRecord;
import org.eclipse.tracecompass.totalads.dbms.IDataAccessObject;
import org.eclipse.tracecompass.totalads.exceptions.TotalADSDBMSException;
import org.eclipse.tracecompass.totalads.exceptions.TotalADSGeneralException;
import com.google.gson.Gson;
import com.google.gson.JsonObject;
import com.google.gson.JsonPrimitive;
/**
* This class implements the HMM algorithm using Apache Mahout library
*
* @author <p>
* Syed Shariyar Murtaza justsshary@hotmail.com
* </p>
*
*/
class HmmMahout {
    /** Convergence threshold handed to Mahout's Baum-Welch trainer. */
    private static final double BAUM_WELCH_EPSILON = 0.0001;
    /**
     * Lower bound of the random values used by the custom initialization; kept
     * strictly positive so no transition or emission starts with probability 0.
     */
    private static final double RANDOM_LOWER_BOUND = 0.0001;
    /** Upper bound of the random values used by the custom initialization. */
    private static final double RANDOM_UPPER_BOUND = 1.0000;
    /** Number of models averaged when merging: the stored one and the current one. */
    private static final double NUM_MERGED_MODELS = 2.0;

    private HmmModel fHmm;

    /**
     * Initializes Hidden Markov Model with random initial probabilities
     *
     * @param numSymbols
     *            number of symbols
     * @param numStates
     *            Number of states
     *
     */
    public void initializeHMM(int numSymbols, int numStates) {
        fHmm = new HmmModel(numStates, numSymbols);
    }

    /**
     * Initializes HMM with the customized transition, emission and initial
     * probabilities rather than using Mahout's initialization. Transition and
     * emission rows are random (strictly positive) values normalized to sum to
     * 1, and the initial state probabilities are all equal (1 / numStates).
     *
     * @param numSymbols
     *            Number of Symbols
     * @param numStates
     *            Number of States
     */
    public void initializeHMMWithCustomizeInitialValues(int numSymbols, int numStates) {
        Random random = new Random();
        DenseMatrix transitionProbabilities = randomRowStochasticMatrix(numStates, numStates, random);
        DenseMatrix emissionProbabilities = randomRowStochasticMatrix(numStates, numSymbols, random);

        // Initial state probabilities Pi (i.e. probabilities at time 1) are uniform
        double[] uniform = new double[numStates];
        Arrays.fill(uniform, 1.0 / numStates);
        DenseVector initialProbabilities = new DenseVector(uniform);

        fHmm = new HmmModel(transitionProbabilities, emissionProbabilities, initialProbabilities);
    }

    /**
     * Creates a matrix filled with random strictly positive values and
     * normalizes each row so that it sums to 1.
     *
     * @param numRows
     *            number of rows
     * @param numCols
     *            number of columns
     * @param random
     *            source of randomness
     * @return a row-stochastic dense matrix
     */
    private static DenseMatrix randomRowStochasticMatrix(int numRows, int numCols, Random random) {
        DenseMatrix matrix = new DenseMatrix(numRows, numCols);
        for (int row = 0; row < numRows; row++) {
            double rowSum = 0.0;
            for (int col = 0; col < numCols; col++) {
                double value = getRandomRealNumber(RANDOM_LOWER_BOUND, RANDOM_UPPER_BOUND, random);
                matrix.set(row, col, value);
                rowSum += value;
            }
            for (int col = 0; col < numCols; col++) {
                matrix.set(row, col, matrix.get(row, col) / rowSum);
            }
        }
        return matrix;
    }

    /**
     * Returns a uniformly distributed random real number in the range
     * [start, end) (up to floating-point rounding).
     *
     * @param start
     *            inclusive lower bound
     * @param end
     *            exclusive upper bound
     * @param random
     *            source of randomness
     * @return a random value between start and end
     */
    private static double getRandomRealNumber(double start, double end, Random random) {
        // nextDouble() is in [0, 1), so the scaled fraction is in [0, end - start)
        double fraction = (end - start) * random.nextDouble();
        return fraction + start;
    }

    /**
     * Validates settings and saves them into the database after creating a new
     * database if required
     *
     * @param settings
     *            SettingsForm array of alternating setting names and values
     * @param database
     *            Database name
     * @param connection
     *            IDataAccessObject object
     * @param isNewSettings
     *            True if settings are inserted first time, else false if
     *            existing fields are updated
     * @param isNewDB
     *            if new database has to be created
     * @throws TotalADSGeneralException
     *             Validation exception
     * @throws TotalADSDBMSException
     *             DBMS exception
     */
    public void verifySaveSettingsCreateDb(String[] settings, String database, IDataAccessObject connection, Boolean isNewSettings, Boolean isNewDB) throws TotalADSGeneralException, TotalADSDBMSException {
        JsonObject settingObject = new JsonObject();
        for (int i = 0; i < settings.length; i += 2) {
            String name = settings[i];
            // A trailing name without a value is reported like an unparsable value
            String value = (i + 1 < settings.length) ? settings[i + 1] : null;

            if (SettingsCollection.NUM_STATES.toString().equalsIgnoreCase(name)) {
                int numStates = parseIntSetting(value, Messages.HmmMahout_SelectIntStates);
                settingObject.add(SettingsCollection.NUM_STATES.toString(), new JsonPrimitive(numStates));
            } else if (SettingsCollection.NUM_SYMBOLS.toString().equalsIgnoreCase(name)) {
                int numSymbols = parseIntSetting(value, Messages.HmmMahout_SelectIntSymbols);
                settingObject.add(SettingsCollection.NUM_SYMBOLS.toString(), new JsonPrimitive(numSymbols));
            } else if (SettingsCollection.SEQ_LENGTH.toString().equalsIgnoreCase(name)) {
                int seqLength = parseIntSetting(value, Messages.HmmMahout_SelectIntSeq);
                settingObject.add(SettingsCollection.SEQ_LENGTH.toString(), new JsonPrimitive(seqLength));
            } else if (SettingsCollection.LOG_LIKELIHOOD.toString().equalsIgnoreCase(name)) {
                double logLikelihood = parseDoubleSetting(value, Messages.HmmMahout_SelectDecForLog);
                // The log of a probability can never be positive
                if (logLikelihood > 0.0) {
                    throw new TotalADSGeneralException(Messages.HmmMahout_SelectNegForLog);
                }
                settingObject.add(SettingsCollection.LOG_LIKELIHOOD.toString(), new JsonPrimitive(logLikelihood));
            } else if (SettingsCollection.NUMBER_OF_ITERATIONS.toString().equalsIgnoreCase(name)) {
                int iterations = parseIntSetting(value, Messages.HmmMahout_SelectIntIteration);
                if (iterations <= 0) {
                    throw new TotalADSGeneralException(Messages.HmmMahout_SelectIterations);
                }
                settingObject.add(SettingsCollection.NUMBER_OF_ITERATIONS.toString(), new JsonPrimitive(iterations));
            } else if (SettingsCollection.KEY.toString().equalsIgnoreCase(name)) {
                settingObject.add(SettingsCollection.KEY.toString(), new JsonPrimitive("hmm")); //$NON-NLS-1$
            }
        }
        // creating id for query searching
        JsonObject jsonKey = new JsonObject();
        jsonKey.addProperty(SettingsCollection.KEY.toString(), "hmm"); //$NON-NLS-1$
        if (isNewDB) {
            String[] collectionNames = { HmmModelCollection.COLLECTION_NAME.toString(), SettingsCollection.COLLECTION_NAME.toString(),
                    NameToIDCollection.COLLECTION_NAME.toString() };
            connection.createDatabase(database, collectionNames);
        }
        if (isNewSettings) {
            connection.insertOrUpdateUsingJSON(database, jsonKey, settingObject, SettingsCollection.COLLECTION_NAME.toString());
        } else {
            connection.updateFieldsInExistingDocUsingJSON(database, jsonKey, settingObject, SettingsCollection.COLLECTION_NAME.toString());
        }
    }

    /**
     * Parses an integer setting value.
     *
     * @param value
     *            raw value, may be null when the value is missing
     * @param errorMessage
     *            message of the exception thrown on failure
     * @return the parsed integer
     * @throws TotalADSGeneralException
     *             if the value is missing or not an integer
     */
    private static int parseIntSetting(String value, String errorMessage) throws TotalADSGeneralException {
        try {
            // Integer.parseInt(null) also throws NumberFormatException
            return Integer.parseInt(value);
        } catch (NumberFormatException ex) {
            throw new TotalADSGeneralException(errorMessage);
        }
    }

    /**
     * Parses a decimal setting value, rejecting NaN.
     *
     * @param value
     *            raw value, may be null when the value is missing
     * @param errorMessage
     *            message of the exception thrown on failure
     * @return the parsed number
     * @throws TotalADSGeneralException
     *             if the value is missing, not a number, or NaN
     */
    private static double parseDoubleSetting(String value, String errorMessage) throws TotalADSGeneralException {
        // Double.parseDouble(null) throws NullPointerException, so check explicitly
        if (value == null) {
            throw new TotalADSGeneralException(errorMessage);
        }
        double parsed;
        try {
            parsed = Double.parseDouble(value);
        } catch (NumberFormatException ex) {
            throw new TotalADSGeneralException(errorMessage);
        }
        // NaN would pass every range comparison and cannot be stored as JSON
        if (Double.isNaN(parsed)) {
            throw new TotalADSGeneralException(errorMessage);
        }
        return parsed;
    }

    /**
     * Loads settings from the database
     *
     * @param database
     *            Database or model name
     * @param dataAccessObject
     *            Data access object
     * @return Settings as an array of alternating names and values, or null if
     *         no settings document exists
     * @throws TotalADSDBMSException
     *             DBMS Exception
     */
    public String[] loadSettings(String database, IDataAccessObject dataAccessObject) throws TotalADSDBMSException {
        try (IDBCursor cursor = dataAccessObject.selectAll(database, SettingsCollection.COLLECTION_NAME.toString())) {
            if (!cursor.hasNext()) {
                return null;
            }
            IDBRecord dbObject = cursor.next();
            String[] names = { SettingsCollection.NUM_STATES.toString(),
                    SettingsCollection.NUMBER_OF_ITERATIONS.toString(),
                    SettingsCollection.NUM_SYMBOLS.toString(),
                    SettingsCollection.LOG_LIKELIHOOD.toString(),
                    SettingsCollection.SEQ_LENGTH.toString() };
            String[] settings = new String[names.length * 2];
            for (int i = 0; i < names.length; i++) {
                settings[2 * i] = names[i];
                settings[2 * i + 1] = dbObject.get(names[i]).toString();
            }
            return settings;
        }
    }

    /**
     * Trains the HMM on a sequence using the Baum-Welch algorithm
     *
     * @param numIterations
     *            Maximum number of iterations
     * @param observedSequence
     *            The sequence
     */
    public void learnUsingBaumWelch(Integer numIterations, Integer[] observedSequence) {
        int[] seq = new int[observedSequence.length];
        for (int i = 0; i < seq.length; i++) {
            seq[i] = observedSequence[i];
        }
        learnUsingBaumWelch(numIterations, seq);
    }

    /**
     * Trains the HMM on a sequence using the Baum-Welch algorithm
     *
     * @param numIterations
     *            Maximum number of iterations
     * @param observedSequence
     *            The sequence
     */
    public void learnUsingBaumWelch(Integer numIterations, int[] observedSequence) {
        // Mahout trains a copy of the given model and returns it, so the result
        // must be stored back; otherwise training has no effect on fHmm
        fHmm = HmmTrainer.trainBaumWelch(fHmm, observedSequence, BAUM_WELCH_EPSILON, numIterations, true);
    }

    /**
     * Returns the natural-log likelihood of the observation sequence under the
     * current model, i.e. log P(sequence | model).
     *
     * @param sequence
     *            Integer array of sequences
     * @return Log Likelihood (negative infinity if the sequence is impossible)
     */
    public double observationLikelihood(int[] sequence) {
        // With scaled == true Mahout returns log(alpha): one row per time step,
        // one column per hidden state. The likelihood is the sum of alpha over the
        // states of the last time step, computed here with log-sum-exp.
        Matrix logAlpha = HmmAlgorithms.forwardAlgorithm(fHmm, sequence, true);
        int lastRow = logAlpha.numRows() - 1;
        int numStates = logAlpha.numCols();

        double max = Double.NEGATIVE_INFINITY;
        for (int state = 0; state < numStates; state++) {
            max = Math.max(max, logAlpha.getQuick(lastRow, state));
        }
        if (max == Double.NEGATIVE_INFINITY) {
            return Double.NEGATIVE_INFINITY;
        }
        double sum = 0.0;
        for (int state = 0; state < numStates; state++) {
            sum += Math.exp(logAlpha.getQuick(lastRow, state) - max);
        }
        return max + Math.log(sum);
    }

    /**
     * Update HMM based on an incremental version as described in
     * http://goanna.cs.rmit.edu.au/~jiankun/Sample_Publication/ICON04_Dau.pdf
     * <p>
     * The current model's parameters are averaged with the model stored in the
     * database (if any) and the result is saved back to the database. This
     * object's own model is not modified.
     *
     * @param sequence
     *            The sequence (currently not used by the merge)
     * @param dataAccessObject
     *            Data access object
     * @param database
     *            Model name
     * @throws TotalADSDBMSException
     *             DBMS exception
     */
    public void updatePreviousModel(Integer[] sequence, IDataAccessObject dataAccessObject, String database) throws TotalADSDBMSException {
        // Copies, so the saved model never aliases this object's matrices
        Matrix transition = fHmm.getTransitionMatrix().clone();
        Matrix emission = fHmm.getEmissionMatrix().clone();
        Vector initial = fHmm.getInitialProbabilities().clone();

        HmmMahout storedHmm = new HmmMahout();
        storedHmm.loadHmm(dataAccessObject, database);
        if (storedHmm.fHmm != null) {
            // Both models are row-stochastic, so their mean is too; a plain sum
            // would make every row add up to 2 and is not a valid HMM
            transition = storedHmm.fHmm.getTransitionMatrix().plus(transition).divide(NUM_MERGED_MODELS);
            emission = storedHmm.fHmm.getEmissionMatrix().plus(emission).divide(NUM_MERGED_MODELS);
            initial = storedHmm.fHmm.getInitialProbabilities().plus(initial).divide(NUM_MERGED_MODELS);
        }
        HmmMahout newHMM = new HmmMahout();
        newHMM.fHmm = new HmmModel(transition, emission, initial);
        newHMM.saveHMM(database, dataAccessObject);
    }

    /**
     * Loads the model directly from a database. If no model is stored, the
     * current model is left unchanged.
     *
     * @param dao
     *            Data access object
     * @param modelName
     *            Model (or database) name
     * @throws TotalADSDBMSException
     *             DBMS exception
     */
    public void loadHmm(IDataAccessObject dao, String modelName) throws TotalADSDBMSException {
        try (IDBCursor cursor = dao.selectAll(modelName, HmmModelCollection.COLLECTION_NAME.toString())) {
            if (!cursor.hasNext()) {
                return;
            }
            Gson gson = new Gson();
            IDBRecord dbObject = cursor.next();
            Object emissionProb = dbObject.get(HmmModelCollection.EMISSIONPROB.toString());
            Object transitionProb = dbObject.get(HmmModelCollection.TRANSITIONPROB.toString());
            Object initialProb = dbObject.get(HmmModelCollection.INTITIALPROB.toString());
            DenseMatrix emissionMatrix = gson.fromJson(emissionProb.toString(), DenseMatrix.class);
            DenseMatrix transitionMatrix = gson.fromJson(transitionProb.toString(), DenseMatrix.class);
            DenseVector initialProbVector = gson.fromJson(initialProb.toString(), DenseVector.class);
            fHmm = new HmmModel(transitionMatrix, emissionMatrix, initialProbVector);
        }
    }

    /**
     * Saves the Mahout HMM model into the database, replacing any model stored
     * under the "hmm" key.
     *
     * @param database
     *            Model (or database) name
     * @param dao
     *            Data access object
     * @throws TotalADSDBMSException
     *             DBMS exception
     */
    public void saveHMM(String database, IDataAccessObject dao) throws TotalADSDBMSException {
        // Inserting the states and probabilities
        String key = "hmm"; //$NON-NLS-1$
        Gson gson = new Gson();
        // loadHmm deserializes DenseMatrix, so the stored form must be DenseMatrix too
        DenseMatrix emissionMatrix = (DenseMatrix) fHmm.getEmissionMatrix();
        DenseMatrix transitionMatrix = (DenseMatrix) fHmm.getTransitionMatrix();
        Vector initialProb = fHmm.getInitialProbabilities();
        JsonObject hmmDoc = new JsonObject();
        hmmDoc.add(HmmModelCollection.KEY.toString(), new JsonPrimitive(key));
        hmmDoc.add(HmmModelCollection.EMISSIONPROB.toString(), gson.toJsonTree(emissionMatrix));
        hmmDoc.add(HmmModelCollection.TRANSITIONPROB.toString(), gson.toJsonTree(transitionMatrix));
        hmmDoc.add(HmmModelCollection.INTITIALPROB.toString(), gson.toJsonTree(initialProb));
        // Creating id for query searching
        JsonObject jsonTheKey = new JsonObject();
        jsonTheKey.addProperty(HmmModelCollection.KEY.toString(), key);
        dao.insertOrUpdateUsingJSON(database, jsonTheKey, hmmDoc, HmmModelCollection.COLLECTION_NAME.toString());
    }

    /**
     * Prints the model
     *
     * @return HMM model in the textual representation
     */
    @Override
    public String toString() {
        return Messages.HmmMahout_HmmModel + "\n" + //$NON-NLS-1$
                Messages.HmmMahout_HiddenStates + fHmm.getNrOfHiddenStates() + "\n" + //$NON-NLS-1$
                Messages.HmmMahout_ObservableEvents + fHmm.getNrOfOutputStates() + "\n" + //$NON-NLS-1$
                Messages.HmmMahout_EmisssionProbs + fHmm.getEmissionMatrix().toString() + "\n" + //$NON-NLS-1$
                Messages.HmmMahout_Transition + fHmm.getTransitionMatrix().toString() + "\n" + //$NON-NLS-1$
                Messages.HmmMahout_Initial + fHmm.getInitialProbabilities();
    }
}