/*********************************************************************************************
* Copyright (c) 2014-2015 Software Behaviour Analysis Lab, Concordia University, Montreal, Canada
*
* All rights reserved. This program and the accompanying materials
* are made available under the terms of Eclipse Public License v1.0 License which
* accompanies this distribution, and is available at http://www.eclipse.org/legal/epl-v10.html
*
* Contributors:
* Syed Shariyar Murtaza -- Initial design and implementation
**********************************************************************************************/
package org.eclipse.tracecompass.internal.totalads.algorithms.hiddenmarkovmodel;
import java.util.LinkedList;
import org.eclipse.osgi.util.NLS;
import org.eclipse.tracecompass.totalads.algorithms.AlgorithmFactory;
import org.eclipse.tracecompass.totalads.algorithms.AlgorithmTypes;
import org.eclipse.tracecompass.totalads.algorithms.IAlgorithmOutStream;
import org.eclipse.tracecompass.totalads.algorithms.IDetectionAlgorithm;
import org.eclipse.tracecompass.totalads.algorithms.Results;
import org.eclipse.tracecompass.totalads.dbms.IDataAccessObject;
import org.eclipse.tracecompass.totalads.exceptions.TotalADSDBMSException;
import org.eclipse.tracecompass.totalads.exceptions.TotalADSGeneralException;
import org.eclipse.tracecompass.totalads.exceptions.TotalADSReaderException;
import org.eclipse.tracecompass.totalads.readers.ITraceIterator;
import org.swtchart.Chart;
/**
* This class implements the Hidden Markov Model for anomaly detection
*
* @author <p>
* Syed Shariyar Murtaza justsshary@hotmail.com
* </p>
*
*/
public class HiddenMarkovModel implements IDetectionAlgorithm {
private Integer fSeqLength;
private HmmMahout fHmm;
private NameToIDMapper fNameToID;
private boolean fIsTrainIntialized = false, fIsTestInitialized = false;
private int fNumStates, fNumSymbols, fNumIterations, fTestNameToIDSize;
private Double fTotalTestAnomalies = 0.0, fTotalTestTraces = 0.0, fLogThresholdTest = 0.0;
private LinkedList<Integer> fBatchLargeTrainingSeq;
private Double fTestTraceMinThreshold;
/**
* Constructor
*/
public HiddenMarkovModel() {
fNameToID = new NameToIDMapper();
fSeqLength = 1000;
}
/**
* Self registration of the model with the modelFactory
*
* @throws TotalADSGeneralException
* Validation exception
*/
public static void registerAlgorithm() throws TotalADSGeneralException {
AlgorithmFactory modelFactory = AlgorithmFactory.getInstance();
HiddenMarkovModel hmm = new HiddenMarkovModel();
modelFactory.registerModelWithFactory(AlgorithmTypes.ANOMALY, hmm);
}
/*
* (non-Javadoc)
*
* @see org.eclipse.tracecompass.totalads.algorithms.IDetectionAlgorithm#
* initializeModelAndSettings(java.lang.String,
* org.eclipse.tracecompass.totalads.dbms.IDataAccessObject,
* java.lang.String[])
*/
@Override
public void initializeModelAndSettings(String modelName, IDataAccessObject dataAccessObject, String[] trainingSettings) throws TotalADSDBMSException, TotalADSGeneralException {
String[] setting = null;
if (trainingSettings != null) {
setting = new String[trainingSettings.length + 8];
setting[0] = SettingsCollection.KEY.toString();
setting[1] = "HMM"; //$NON-NLS-1$
for (int i = 0; i < trainingSettings.length; i++) {
setting[i + 2] = trainingSettings[i];
}
int idx = trainingSettings.length + 1;
setting[++idx] = SettingsCollection.NUM_SYMBOLS.toString();
setting[++idx] = "0"; //$NON-NLS-1$
setting[++idx] = SettingsCollection.LOG_LIKELIHOOD.toString();
setting[++idx] = "0.0"; //$NON-NLS-1$
setting[++idx] = SettingsCollection.SEQ_LENGTH.toString();
setting[++idx] = fSeqLength.toString();
} else {
String[] settings = { SettingsCollection.KEY.toString(), "Hmm", //$NON-NLS-1$
SettingsCollection.NUM_STATES.toString(), "5", //$NON-NLS-1$
SettingsCollection.NUMBER_OF_ITERATIONS.toString(), "10", //$NON-NLS-1$
SettingsCollection.NUM_SYMBOLS.toString(), "0", //$NON-NLS-1$
SettingsCollection.LOG_LIKELIHOOD.toString(), "0.0", //$NON-NLS-1$
SettingsCollection.SEQ_LENGTH.toString(), fSeqLength.toString() };
setting = settings;
}
HmmMahout hmm = new HmmMahout();
hmm.verifySaveSettingsCreateDb(setting, modelName, dataAccessObject, true, true);
}
/*
* (non-Javadoc)
*
* @see org.eclipse.tracecompass.totalads.algorithms.IDetectionAlgorithm#
* getTrainingOptions()
*/
@Override
public String[] getTrainingSettings() {
String[] trainingSettings = new String[4];
trainingSettings[0] = SettingsCollection.NUM_STATES.toString();
trainingSettings[1] = "5"; //$NON-NLS-1$
trainingSettings[2] = SettingsCollection.NUMBER_OF_ITERATIONS.toString();
trainingSettings[3] = "10"; //$NON-NLS-1$
return trainingSettings;
}
/*
* (non-Javadoc)
*
* @see org.eclipse.tracecompass.totalads.algorithms.IDetectionAlgorithm#
* getTestingOptions(java.lang.String,
* org.eclipse.tracecompass.totalads.dbms.IDataAccessObject)
*/
@Override
public String[] getTestSettings(String database, IDataAccessObject dataAccessObject) throws TotalADSDBMSException {
HmmMahout hmm = new HmmMahout();
String[] settings = hmm.loadSettings(database, dataAccessObject);
if (settings == null) {
return null;
}
String[] testingSettings = new String[4];
testingSettings[0] = SettingsCollection.LOG_LIKELIHOOD.toString();
testingSettings[1] = settings[7]; // probability
testingSettings[2] = SettingsCollection.SEQ_LENGTH.toString();
testingSettings[3] = settings[9]; // probability
return testingSettings;
}
/*
* (non-Javadoc)
*
* @see org.eclipse.tracecompass.totalads.algorithms.IDetectionAlgorithm#
* saveTestingOptions(java.lang.String[], java.lang.String,
* org.eclipse.tracecompass.totalads.dbms.IDataAccessObject)
*/
@Override
public void saveTestSettings(String[] options, String database, IDataAccessObject dataAccessObject) throws TotalADSGeneralException, TotalADSDBMSException
{
HmmMahout hmm = new HmmMahout();
hmm.verifySaveSettingsCreateDb(options, database, dataAccessObject, false, false);
}
/*
* (non-Javadoc)
*
* @see org.eclipse.tracecompass.totalads.algorithms.IDetectionAlgorithm#
* getSettingsToDisplay()
*/
@Override
public String[] getSettingsToDisplay(String database, IDataAccessObject dataAccessObject) throws TotalADSDBMSException {
HmmMahout hmm = new HmmMahout();
String[] settings = hmm.loadSettings(database, dataAccessObject);
return settings;
}
/*
* (non-Javadoc)
*
* @see
* org.eclipse.tracecompass.totalads.algorithms.IDetectionAlgorithm#train
* (org.eclipse.tracecompass.totalads.readers.ITraceIterator,
* java.lang.Boolean, java.lang.String,
* org.eclipse.tracecompass.totalads.dbms.IDataAccessObject,
* org.eclipse.tracecompass.totalads.algorithms.IAlgorithmOutStream)
*/
@Override
public void train(ITraceIterator trace, Boolean isLastTrace, String database, IDataAccessObject dataAccessObject, IAlgorithmOutStream outStream) throws TotalADSGeneralException, TotalADSDBMSException, TotalADSReaderException {
if (trace == null || isLastTrace == null || database == null || dataAccessObject == null || outStream == null) {
throw new TotalADSGeneralException(Messages.HiddenMarkovModel_NullArguments);
}
batchTraining(trace, isLastTrace, database, dataAccessObject, outStream);
}
/**
* Trains an HMM on a collection of traces at once; i.e., in a batch
*
* @param trace
* Trace iterator
* @param isLastTrace
* True if it is the last trace
* @param database
* Database name
* @param dao
* data access object
* @param outStream
* output stream to display message
* @throws TotalADSGeneralException
* Validation exception
* @throws TotalADSDBMSException
* DBMS exception
* @throws TotalADSReaderException
* Reader exception
*/
private void batchTraining(ITraceIterator trace, Boolean isLastTrace, String database, IDataAccessObject dao, IAlgorithmOutStream outStream) throws TotalADSGeneralException, TotalADSDBMSException, TotalADSReaderException {
if (!fIsTrainIntialized) {
fHmm = new HmmMahout();
String[] options = fHmm.loadSettings(database, dao);// get settings
// from db
fNumStates = Integer.parseInt(options[1]);
fNumIterations = Integer.parseInt(options[3]);
fNameToID.loadMap(dao, database);
fIsTrainIntialized = true;
fBatchLargeTrainingSeq = new LinkedList<>();
}
outStream.addOutputEvent(Messages.HiddenMarkovModel_ExtractionMsg);
outStream.addNewLine();
String event = null;
while (trace.advance()) {
event = trace.getCurrentEvent();
fBatchLargeTrainingSeq.add(fNameToID.getId(event));
}
if (isLastTrace) {
fNumSymbols = fNameToID.getSize();
outStream.addOutputEvent(Messages.HiddenMarkovModel_BaumWelchMsg);
outStream.addNewLine();
int[] seq = new int[fBatchLargeTrainingSeq.size()];
for (int i = 0; i < fBatchLargeTrainingSeq.size(); i++) {
seq[i] = fBatchLargeTrainingSeq.get(i);
}
fBatchLargeTrainingSeq.clear();// clear memory
fHmm = trainBaumWelch(seq, fNumStates, fNumSymbols, fNumIterations);
outStream.addOutputEvent(Messages.HiddenMarkovModel_SaveHMMMsg);
outStream.addNewLine();
fHmm.saveHMM(database, dao);
// Get settings n update them
String[] settings = new String[2];
settings[0] = SettingsCollection.NUM_SYMBOLS.toString();
settings[1] = Integer.toString(fNumSymbols);
fHmm.verifySaveSettingsCreateDb(settings, database, dao, false, false);
outStream.addOutputEvent(fHmm.toString());
outStream.addNewLine();
// fHmm.saveHMM(database, connection);
fNameToID.saveMap(dao, database);
}
}
/**
* Trains using BaumWelch
*
* @param seq
* Sequence
* @param numStates
* Number of states
* @param numSymbols
* Number of symbols
* @param numIterations
* Number of iteration
* @return returns a trained HMM
* @throws TotalADSGeneralException
* validation exception
*/
private HmmMahout trainBaumWelch(int[] seq, int numStates, int numSymbols, int numIterations) throws TotalADSGeneralException {
try {
HmmMahout hmm = new HmmMahout();
hmm.initializeHMMWithCustomizeInitialValues(numSymbols, numStates);
hmm.learnUsingBaumWelch(numIterations, seq);
return hmm;
} catch (Exception ex) {
if (fNameToID.getSize() > numSymbols) {
throw new TotalADSGeneralException(Messages.HiddenMarkovModel_EventsOverlaodMsg
+ Messages.HiddenMarkovModel_HMMErrorMsg);
}
throw new TotalADSGeneralException(ex);
}
}
/*
* (non-Javadoc)
*
* @see
* org.eclipse.tracecompass.totalads.algorithms.IDetectionAlgorithm#validate
* (org.eclipse.tracecompass.totalads.readers.ITraceIterator,
* java.lang.String,
* org.eclipse.tracecompass.totalads.dbms.IDataAccessObject,
* java.lang.Boolean,
* org.eclipse.tracecompass.totalads.algorithms.IAlgorithmOutStream)
*/
@Override
public void validate(ITraceIterator trace, String database, IDataAccessObject dataAccessObject,
Boolean isLastTrace, IAlgorithmOutStream outStream) throws TotalADSGeneralException, TotalADSDBMSException, TotalADSReaderException {
if (trace == null || isLastTrace == null || database == null || dataAccessObject == null || outStream == null) {
throw new TotalADSGeneralException(Messages.HiddenMarkovModel_NullArguments);
}
int winWidth = 0;
Double logThreshold;
String[] options = fHmm.loadSettings(database, dataAccessObject);
logThreshold = Double.parseDouble(options[7]);
fSeqLength = Integer.parseInt(options[9]);
LinkedList<Integer> newSequence = new LinkedList<>();
outStream.addOutputEvent(Messages.HiddenMarkovModel_ValidationStart);
outStream.addNewLine();
Boolean isValidated = false;
outStream.addOutputEvent(Messages.HiddenMarkovModel_SequenceEvalMsg);
outStream.addNewLine();
String event = null;
while (trace.advance()) {
event = trace.getCurrentEvent();
newSequence.add(fNameToID.getId(event));
winWidth++;
isValidated = false;
if (winWidth >= fSeqLength) {
isValidated = true;
winWidth--;
int[] seq = new int[fSeqLength];
for (int i = 0; i < newSequence.size(); i++) {
seq[i] = newSequence.get(i);
}
// searching and adding to db
logThreshold = validationEvaluation(outStream, logThreshold, seq);
newSequence.remove(0);
}
}
if (!isValidated) {
int[] seq = new int[fSeqLength];
for (int i = 0; i < newSequence.size(); i++) {
seq[i] = newSequence.get(i);
}
newSequence.clear();// clear memory
logThreshold = validationEvaluation(outStream, logThreshold, seq);
}
options[7] = logThreshold.toString();
outStream.addOutputEvent(Messages.HiddenMarkovModel_MinLogLikeliHood + logThreshold.toString());
outStream.addNewLine();
outStream.addOutputEvent(Messages.HiddenMarkovModel_ValidationFinished);
outStream.addNewLine();
fHmm.verifySaveSettingsCreateDb(options, database, dataAccessObject, false, false);
}
/**
* Performs the evaluation for a likelihood of a sequence during validation
*
* @param outStream
* Output stream to display messages
* @param logThreshold
* threshold value
* @param seq
* Sequence
* @return Loglikelihood
*/
private Double validationEvaluation(IAlgorithmOutStream outStream, Double logThreshold, int[] seq) {
Double prob = 1.0;
Double logLikelihood = logThreshold;
try {
prob = fHmm.observationLikelihood(seq);
} catch (Exception ex) {
outStream.addOutputEvent(Messages.HiddenMarkovModel_ReTrainHMM);
outStream.addNewLine();
}
if (prob < logLikelihood) {
logLikelihood = prob;
// console.printTextLn(Arrays.toString(seq));
outStream.addOutputEvent("Min Log Likelihood Threshold: " + logThreshold.toString()); //$NON-NLS-1$
outStream.addNewLine();
}
return logLikelihood;
}
/*
* (non-Javadoc)
*
* @see
* org.eclipse.tracecompass.totalads.algorithms.IDetectionAlgorithm#test
* (org.eclipse.tracecompass.totalads.readers.ITraceIterator,
* java.lang.String,
* org.eclipse.tracecompass.totalads.dbms.IDataAccessObject,
* org.eclipse.tracecompass.totalads.algorithms.IAlgorithmOutStream)
*/
@Override
public Results test(ITraceIterator trace, String database, IDataAccessObject dataAccessObject, IAlgorithmOutStream outputStream) throws TotalADSGeneralException, TotalADSDBMSException, TotalADSReaderException {
if (trace == null || database == null || dataAccessObject == null || outputStream == null) {
throw new TotalADSGeneralException(Messages.HiddenMarkovModel_NullArguments);
}
int winWidth = 0;
String[] options;
if (!fIsTestInitialized) {
fHmm = new HmmMahout();
options = fHmm.loadSettings(database, dataAccessObject);
fLogThresholdTest = Double.parseDouble(options[7]);
fSeqLength = Integer.parseInt(options[9]);
fHmm.loadHmm(dataAccessObject, database);
fNameToID.loadMap(dataAccessObject, database);
fTestNameToIDSize = fNameToID.getSize();
fIsTestInitialized = true;
}
Results results = new Results();
LinkedList<Integer> newSequence = new LinkedList<>();
Boolean isTested = false;
fTotalTestTraces++;
String event = null;
outputStream.addOutputEvent(Messages.HiddenMarkovModel_SequenceExtrractionMsg);
outputStream.addNewLine();
int seqCount = 1;
fTestTraceMinThreshold = 0.0;
while (trace.advance()) {
event = trace.getCurrentEvent();
newSequence.add(fNameToID.getId(event));
winWidth++;
isTested = false;
if (winWidth >= fSeqLength) {
isTested = true;
winWidth--;
int[] seq = new int[fSeqLength];
for (int i = 0; i < newSequence.size(); i++) {
seq[i] = newSequence.get(i);
}
if (seqCount % 10000 == 0) {
outputStream.addOutputEvent(NLS.bind(Messages.HiddenMarkovModel_SpecificSeq, seqCount));
outputStream.addNewLine();
}
if (testEvaluation(results, fLogThresholdTest, seq) == true) {
break;
}
newSequence.remove(0);
seqCount++;
}
}
if (!isTested) {// when it is the last sequence
int[] seq = new int[fSeqLength];
for (int i = 0; i < newSequence.size(); i++) {
seq[i] = newSequence.get(i);
}
newSequence.clear();// clear memory
testEvaluation(results, fLogThresholdTest, seq);
}
String logLikelihoodValue = ""; //$NON-NLS-1$
if (fTestTraceMinThreshold != 0.0) {
logLikelihoodValue = fTestTraceMinThreshold.toString();
} else {
logLikelihoodValue = "NA"; //$NON-NLS-1$
}
results.setDetails("\nLog Likelihood: " + logLikelihoodValue + "\n"); //$NON-NLS-1$ //$NON-NLS-2$
outputStream.addOutputEvent("Log Likelihood: " + logLikelihoodValue); //$NON-NLS-1$
outputStream.addNewLine();
outputStream.addOutputEvent(Messages.HiddenMarkovModel_Anomaly + results.getAnomaly());
outputStream.addNewLine();
if (results.getAnomaly() == true) {
fTotalTestAnomalies++;
}
outputStream.addOutputEvent(Messages.HiddenMarkovModel_Finish);
outputStream.addNewLine();
return results;
}
/**
* Helper function for testing
*
* @param result
* Results
* @param logThreshold
* log likelihood
* @param seq
* sequence
* @return Return true if anomaly, else returns false
*/
private boolean testEvaluation(Results result, Double logThreshold, int[] seq) {
Double loglikelihood = 1.0;
Double logThresholdValue = logThreshold;
try {
loglikelihood = fHmm.observationLikelihood(seq);
} catch (Exception ex) {
result.setAnomaly(true);
if (fNameToID.getSize() > fTestNameToIDSize) {
Integer diff = fNameToID.getSize() - fTestNameToIDSize;
if (diff > 100) {
result.setDetails(Messages.HiddenMarkovModel_AdditionalEvents);
} else {
result.setDetails(Messages.HiddenMarkovModel_UnkownEvents);
}
int eventCount = 0;
for (int i = fTestNameToIDSize; i < fTestNameToIDSize + diff; i++) {// All
// these
// events
// are
// unknown
result.setDetails(fNameToID.getKey(i) + ", "); //$NON-NLS-1$
eventCount++;
if ((eventCount) % 10 == 0) {
result.setDetails("\n"); //$NON-NLS-1$
}
}
}
// fTotalTestAnomalies++;
return true;
}
if (loglikelihood < fTestTraceMinThreshold) {
fTestTraceMinThreshold = loglikelihood;
}
if (loglikelihood < logThresholdValue) {
logThresholdValue = loglikelihood;
result.setDetails(Messages.HiddenMarkovModel_AnomalousPatterns);
int firstRange = 10;
if (seq.length < 10) {
firstRange = seq.length;
}
for (int id = 0; id < firstRange; id++) {
result.setDetails(fNameToID.getKey(seq[id]) + ", "); //$NON-NLS-1$
}
int secondRange = seq.length / 2;
if (secondRange + 10 < seq.length) {
result.setDetails("\n.........................................................\n"); //$NON-NLS-1$
for (int id = secondRange; id < secondRange + 10; id++) {
result.setDetails(fNameToID.getKey(seq[id]) + ", "); //$NON-NLS-1$
}
}
int thirdRange = seq.length;
if (thirdRange - 10 > secondRange + 10) {
result.setDetails("\n.........................................................\n"); //$NON-NLS-1$
for (int id = secondRange; id < secondRange + 10; id++) {
result.setDetails(fNameToID.getKey(seq[id]) + ", "); //$NON-NLS-1$
}
}
result.setAnomaly(true);
return true;
}
result.setAnomaly(false);
return false;
}
/*
* (non-Javadoc)
*
* @see org.eclipse.tracecompass.totalads.algorithms.IDetectionAlgorithm#
* getTotalAnomalyPercentage()
*/
@Override
public Double getTotalAnomalyPercentage() {
Double anomalyPercentage = (fTotalTestAnomalies / fTotalTestTraces) * 100;
return anomalyPercentage;
}
@Override
public Chart graphicalResults(ITraceIterator traceIterator) {
return null;
}
/*
* (non-Javadoc)
*
* @see org.eclipse.tracecompass.totalads.algorithms.IDetectionAlgorithm#
* createInstance()
*/
@Override
public IDetectionAlgorithm createInstance() {
return new HiddenMarkovModel();
}
/*
* (non-Javadoc)
*
* @see
* org.eclipse.tracecompass.totalads.algorithms.IDetectionAlgorithm#getName
* ()
*/
@Override
public String getName() {
return "Hidden Markov Model (HMM)"; //$NON-NLS-1$
}
/*
* (non-Javadoc)
*
* @see
* org.eclipse.tracecompass.totalads.algorithms.IDetectionAlgorithm#getAcronym
* ()
*/
@Override
public String getAcronym() {
return "HMM"; //$NON-NLS-1$
}
/*
* (non-Javadoc)
*
* @see org.eclipse.tracecompass.totalads.algorithms.IDetectionAlgorithm#
* getDescription()
*/
@Override
public String getDescription() {
return Messages.HiddenMarkovModel_Description;
}
/*
* (non-Javadoc)
*
* @see org.eclipse.tracecompass.totalads.algorithms.IDetectionAlgorithm#
* isOnlineLearningSupported()
*/
@Override
public boolean isOnlineLearningSupported() {
return false;
}
}