/*********************************************************************************************
* Copyright (c) 2014-2015 Software Behaviour Analysis Lab, Concordia University, Montreal, Canada
*
* All rights reserved. This program and the accompanying materials
* are made available under the terms of Eclipse Public License v1.0 License which
* accompanies this distribution, and is available at http://www.eclipse.org/legal/epl-v10.html
*
* Contributors:
* Syed Shariyar Murtaza -- Initial design and implementation
**********************************************************************************************/
package org.eclipse.tracecompass.totalads.algorithms;
import java.io.File;
import java.util.HashMap;
import java.util.concurrent.TimeUnit;
import org.eclipse.tracecompass.internal.totalads.readers.ctfreaders.CTFLTTngSysCallTraceReader;
import org.eclipse.tracecompass.internal.totalads.ssh.SSHConnector;
import org.eclipse.tracecompass.totalads.dbms.DBMSFactory;
import org.eclipse.tracecompass.totalads.dbms.IDataAccessObject;
import org.eclipse.tracecompass.totalads.exceptions.TotalADSDBMSException;
import org.eclipse.tracecompass.totalads.exceptions.TotalADSGeneralException;
import org.eclipse.tracecompass.totalads.exceptions.TotalADSNetException;
import org.eclipse.tracecompass.totalads.exceptions.TotalADSReaderException;
import org.eclipse.tracecompass.totalads.readers.ITraceIterator;
import org.eclipse.tracecompass.totalads.readers.ITraceTypeReader;
import org.eclipse.tracecompass.totalads.algorithms.AlgorithmFactory;
import org.eclipse.tracecompass.totalads.algorithms.IAlgorithmOutStream;
import org.eclipse.tracecompass.totalads.algorithms.IAlgorithmUtilityResultsListener;
import org.eclipse.tracecompass.totalads.algorithms.IDetectionAlgorithm;
import org.eclipse.tracecompass.totalads.algorithms.Messages;
import org.eclipse.tracecompass.totalads.algorithms.Results;
import org.eclipse.osgi.util.NLS;
import org.eclipse.tracecompass.totalads.readers.TraceTypeFactory;
import com.google.common.base.Stopwatch;
/**
* This is the utility class to execute algorithms by implementing common
* recurring tasks required by all the algorithms.
*
* @author <p>
* Syed Shariyar Murtaza justsshary@hotmail.com
* </p>
*
*/
public class AlgorithmUtility {
private static volatile boolean fIsExecuting = true;
private AlgorithmUtility() {
}
/**
* Creates a model in the database with training settings
*
* @param modelName
* Model name
* @param algorithm
* Algorithm type
* @param trainingSettings
* Training Settings
* @throws TotalADSDBMSException
* An exception related to DBMS
* @throws TotalADSGeneralException
* An exception related to validation of parameters
*/
public static void createModel(String modelName, IDetectionAlgorithm algorithm, String[] trainingSettings) throws TotalADSDBMSException, TotalADSGeneralException {
if (modelName == null || modelName.isEmpty()) {
throw new TotalADSGeneralException(Messages.AlgorithmUtility_EmptyModel);
}
IDataAccessObject dao = DBMSFactory.INSTANCE.getDataAccessObject();
if (dao == null || !dao.isConnected()) {
throw new TotalADSDBMSException(Messages.AlgorithmUtility_NoDB);
}
if (algorithm == null) {
throw new TotalADSGeneralException(Messages.AlgorithmUtility_NullAlgorithm);
}
String model = modelName + "_" + algorithm.getAcronym(); //$NON-NLS-1$
model = model.toUpperCase();
algorithm.initializeModelAndSettings(model, dao, trainingSettings);
}
/**
* Returns the algorithm for a given model name
*
* @param modelName
* Name of the model
* @return An object of type IDetectionAlgorithm
* @throws TotalADSGeneralException
* An exception related to validation of parameters
*/
public static IDetectionAlgorithm getAlgorithmFromModelName(String modelName) throws TotalADSGeneralException {
if (modelName == null) {
throw new TotalADSGeneralException(Messages.AlgorithmUtility_NullModel);
}
String[] modelParts = modelName.split("_"); //$NON-NLS-1$
if (modelParts == null || modelParts.length < 2) {
throw new TotalADSGeneralException(Messages.AlgorithmUtility_InvalidModel);
}
String algorithmAcronym = modelParts[1];
IDetectionAlgorithm algorithm = AlgorithmFactory.getInstance().getAlgorithmByAcronym(algorithmAcronym);
if (algorithm == null) {
throw new TotalADSGeneralException(Messages.AlgorithmUtility_InvalidModelTotalADS);
}
return algorithm;
}
// /////////////////////////////////////////////////////////////////////////////////////
// //////// Training and Validation
// ///////////////////////////////////////////////////////////////////////////////////
/**
* This function trains and validate models
*
* @param trainDirectory
* Train Directory
* @param validationDirectory
* Validation Directory
* @param traceReader
* Trace Reader
* @param modelsNames
* Names of models as an array
* @param outStream
* Output stream where the algorithm would display its output
* @throws TotalADSGeneralException
* An exception related to validation of parameters
* @throws TotalADSDBMSException
* An exception related to DBMS
* @throws TotalADSReaderException
* An exception related to the trace reader
*/
public static void trainAndValidateModels(String trainDirectory, String validationDirectory, ITraceTypeReader traceReader,
String[] modelsNames, IAlgorithmOutStream outStream)
throws TotalADSGeneralException, TotalADSDBMSException, TotalADSReaderException {
if (trainDirectory == null || validationDirectory == null || traceReader == null
|| modelsNames == null || outStream == null) {
throw new TotalADSGeneralException(Messages.AlgorithmUtility_NullArguments);
}
if (trainDirectory.isEmpty() || validationDirectory.isEmpty()) {
throw new TotalADSGeneralException(Messages.AlgorithmUtility_EmptyDirectories);
}
IDataAccessObject dataAcessObject = DBMSFactory.INSTANCE.getDataAccessObject();
if (!dataAcessObject.isConnected()) {
throw new TotalADSDBMSException(Messages.AlgorithmUtility_NoDB);
}
for (int i = 0; i < modelsNames.length; i++) {
if (dataAcessObject.datbaseExists(modelsNames[i]) == false) {
throw new TotalADSDBMSException(NLS.bind(Messages.AlgorithmUtility_NoDBofTypeFound, modelsNames[i]));
}
}
Stopwatch stopwatch = Stopwatch.createStarted();
for (int i = 0; i < modelsNames.length; i++) {
Boolean isLastTrace = false;
String modelName = modelsNames[i];
outStream.addOutputEvent(NLS.bind(Messages.AlgorithmUtility_ModelingOn, modelName));
outStream.addNewLine();
// //////////////////
// /File verifications of traces
// /////////////////
// Check for valid trace type reader and training traces before
// creating a database
// Get a file handler
File fileList[] = getDirectoryHandler(trainDirectory, traceReader);
try (ITraceIterator it = traceReader.getTraceIterator(fileList[0])) {
} catch (TotalADSReaderException ex) {
stopwatch.stop();
String message = Messages.AlgorithmUtility_InvalidTrainingTraces + ex.getMessage();
throw new TotalADSGeneralException(message);
}
// Check for valid trace type reader and validation traces before
// creating a database
File validationFileList[] = getDirectoryHandler(validationDirectory, traceReader);
try (ITraceIterator it = traceReader.getTraceIterator(validationFileList[0]);) {
} catch (TotalADSReaderException ex) {
stopwatch.stop();
String message = Messages.AlgorithmUtility_InvalidValidationTraces + ex.getMessage();
throw new TotalADSGeneralException(message);
}
// /////////
// Start training
// ////////
outStream.addOutputEvent(Messages.AlgorithmUtility_ModelTraining);
outStream.addNewLine();
IDetectionAlgorithm algorithm = getAlgorithmFromModelName(modelName);
for (int trcCnt = 0; trcCnt < fileList.length; trcCnt++) {
if (trcCnt == fileList.length - 1) {
isLastTrace = true;
}
// Get the trace
try (ITraceIterator trace = traceReader.getTraceIterator(fileList[trcCnt])) {
outStream.addOutputEvent(NLS.bind(Messages.AlgorithmUtility_CurrentTrainingTrace, (trcCnt + 1), fileList[trcCnt].getName()));
outStream.addNewLine();
algorithm.train(trace, isLastTrace, modelName, dataAcessObject, outStream);
}
// Check if user has asked to stop modeling
if (Thread.currentThread().isInterrupted()) {
break;
}
}
// Start validation
validateModels(validationFileList, traceReader, algorithm, modelName, outStream, dataAcessObject);
// Check if user has asked to stop modeling
if (Thread.currentThread().isInterrupted()) {
break;
}
}
stopwatch.stop();
Long elapsedMins = stopwatch.elapsed(TimeUnit.MINUTES);
Long elapsedSecs = stopwatch.elapsed(TimeUnit.SECONDS);
String msg = NLS.bind(Messages.AlgorithmUtility_TotalTime, elapsedMins, elapsedSecs);
outStream.addOutputEvent(msg);
outStream.addNewLine();
}
/**
* This functions validates a model for a given database of that model
*
* @param fileList
* Array of files
* @param traceReader
* trace reader
* @param algorithm
* Algorithm object
* @param database
* Database name
* @param outStream
* console object
* @throws TotalADSGeneralException
* An exception related to validation of parameters
* @throws TotalADSReaderException
* An exception related to the trace reader
* @throws TotalADSDBMSException
* An exception related to the DBMS
*/
private static void validateModels(File[] fileList, ITraceTypeReader traceReader, IDetectionAlgorithm algorithm,
String database, IAlgorithmOutStream outStream, IDataAccessObject dao) throws TotalADSGeneralException, TotalADSReaderException,
TotalADSDBMSException {
// process now
outStream.addOutputEvent(Messages.AlgorithmUtility_Validation);
outStream.addNewLine();
Boolean isLastTrace = false;
for (int trcCnt = 0; trcCnt < fileList.length; trcCnt++) {
// Check if user has asked to stop modeling
if (Thread.currentThread().isInterrupted()) {
break;
}
// get the trace
if (trcCnt == fileList.length - 1) {
isLastTrace = true;
}
try (ITraceIterator trace = traceReader.getTraceIterator(fileList[trcCnt])) {
outStream.addOutputEvent(NLS.bind(Messages.AlgorithmUtility_CurrentValidationTrace, (trcCnt + 1), fileList[trcCnt].getName()));
outStream.addNewLine();
algorithm.validate(trace, database, dao, isLastTrace, outStream);
}
}
}
// //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// / Test models
// /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
/**
* Tests algorithms against a set of traces in a folder
*
* @param testDirectory
* Test directory
* @param traceReader
* Trace reader
* @param models
* Model names
* @param outStream
* Output stream to print the output
* @param resultListener
* The resultListener object will receives messages about the
* results in real time
* @return Model names and the total number of anomalies found in the test
* folder for each model
* @throws TotalADSGeneralException
* General exception (usually validation errors)
* @throws TotalADSReaderException
* Exception related to trace reading
* @throws TotalADSDBMSException
* Exception related to database
*
*/
public static HashMap<String, Double> testModels(String testDirectory, ITraceTypeReader traceReader,
String[] models, IAlgorithmOutStream outStream, IAlgorithmUtilityResultsListener resultListener) throws TotalADSGeneralException, TotalADSReaderException, TotalADSDBMSException {
// First verify selections
Integer totalFiles;
if (testDirectory == null || traceReader == null || models == null || outStream == null ||
resultListener == null) {
throw new TotalADSGeneralException(Messages.AlgorithmUtility_NullArguments);
}
IDataAccessObject dao = DBMSFactory.INSTANCE.getDataAccessObject();
if (!dao.isConnected()) {
throw new TotalADSDBMSException(Messages.AlgorithmUtility_NoDB);
}
for (int i = 0; i < models.length; i++) {
if (dao.datbaseExists(models[i]) == false) {
throw new TotalADSDBMSException(NLS.bind(Messages.AlgorithmUtility_NoDBofTypeFound, models[i]));
}
}
if (testDirectory.isEmpty()) {
throw new TotalADSGeneralException(Messages.AlgorithmUtility_EmptyTestDir);
}
// Get a file and a db handler
File fileList[] = getDirectoryHandler(testDirectory, traceReader);
// Check for valid trace type reader and traces before creating a
// fDatabase
try (ITraceIterator it = traceReader.getTraceIterator(fileList[0])) {
} catch (TotalADSReaderException ex) {
// this is just a validation error, cast it to UI exception
String message = NLS.bind(Messages.Algorithms_InvalidTrace, ex.getMessage());
throw new TotalADSGeneralException(message);
}
// Second, get all the algorithm instances related to models
IDetectionAlgorithm[] algorithm = new IDetectionAlgorithm[models.length];
for (int i = 0; i < models.length; i++) {
algorithm[i] = getAlgorithmFromModelName(models[i]);
}
// Third, start testing
totalFiles = fileList.length;
HashMap<String, Double> modelsAndAnomalyCount = new HashMap<>();
int anomCount = 0;
// for each trace
for (int trcCnt = 0; trcCnt < totalFiles; trcCnt++) {// totalFiles
outStream.addOutputEvent(NLS.bind(Messages.Algorithms_TraceCountMessage, trcCnt, fileList[trcCnt]));
outStream.addNewLine();
// for each selected model
HashMap<String, Results> modelResults = new HashMap<>();
final String traceName = fileList[trcCnt].getName();
for (int modelCnt = 0; modelCnt < models.length; modelCnt++) {
outStream.addOutputEvent(NLS.bind(Messages.Algorithms_ModelEval, models[modelCnt]));
outStream.addNewLine();
try (ITraceIterator trace =
traceReader.getTraceIterator(fileList[trcCnt])) {// get
// the
// trace
Results results = algorithm[modelCnt].test(trace, models[modelCnt], dao, outStream);
modelResults.put(models[modelCnt], results);
}
// get total anomalies so far for each instance of the algorithm
Double totalAnoms = algorithm[modelCnt].getTotalAnomalyPercentage();
modelsAndAnomalyCount.put(models[modelCnt], totalAnoms);
// update the listener about the results of te trace
resultListener.listenTestResults(traceName, modelResults, modelsAndAnomalyCount);
// Check if Executor has been stopped by the user
if (Thread.currentThread().isInterrupted()) {
break;
}
}
// Check if Executor has been stopped by the user
if (Thread.currentThread().isInterrupted()) {
break;
}
}
outStream.addNewLine();
outStream.addOutputEvent(NLS.bind(Messages.AlgorithmUtility_anomalies, anomCount));
outStream.addNewLine();
return modelsAndAnomalyCount;
}
// /////////////////////////////////////////////////////////////////////////////////////////////////
// /Online/Live Testing and Training
// ///////////////////////////////////////////////////////////////////////////////////////////////
/**
* Function to start online modeling. Once started, it will continuously
* collect a trace from a remote system, and train and test models. This
* function must be launched in a new thread otherwise it will block the
* running thread. To stop it call stopOnlineModeling function.
*
* @param userAtHost
* User and host name in the format user@hostname
* @param password
* Password
* @param port
* Port number
* @param snapshotDuration
* Time to collect trace
* @param intervalBetweenSnapshots
* Interval between the start of collection of two traces
* @param directoryToStoreTraces
* directory to store traces
* @param models
* Model (database) names
* @param outStream
* Output stream to display messages of processing
* @param resultListener
* Result listener to display results in the real time
* @param isTrainAndEval
* True for both training and testing, or false for only testing
* @throws TotalADSGeneralException
* Validation exception of parameters
*/
public static synchronized void startOnlineModeling(String userAtHost, String password, Integer port,
Integer snapshotDuration, Integer intervalBetweenSnapshots,
String directoryToStoreTraces, String[] models, IAlgorithmOutStream outStream,
IAlgorithmUtilityResultsListener resultListener, Boolean isTrainAndEval)
throws TotalADSGeneralException {
if (userAtHost == null || password == null || port == null || snapshotDuration == null ||
intervalBetweenSnapshots == null || directoryToStoreTraces == null ||
models == null || outStream == null || resultListener == null || isTrainAndEval == null) {
throw new TotalADSGeneralException(Messages.AlgorithmUtility_NullArguments);
}
SSHConnector ssh = new SSHConnector();
outStream.addOutputEvent(Messages.AlgorithmUtility_StartSsh);
outStream.addNewLine();
try {
// Connecting to SSH
ssh.openSSHConnectionUsingPassword(userAtHost, password, port, outStream, snapshotDuration);
ITraceTypeReader traceReader = TraceTypeFactory.getInstance().getCTFKernelReaderOrSimpleTextReader(true);
while (fIsExecuting) {
String tracePath = ssh.collectATrace(directoryToStoreTraces);
outStream.addOutputEvent(Messages.AlgorithmUtility_StartTest);
outStream.addNewLine();
testModels(tracePath, traceReader, models, outStream, resultListener);
if (isTrainAndEval) {// if it is both training and evaluation
outStream.addOutputEvent(Messages.AlgorithmUtility_StartTrain);
outStream.addNewLine();
trainModels(traceReader, tracePath, models, outStream);
}
}
} catch (TotalADSGeneralException | TotalADSReaderException |
TotalADSNetException | TotalADSDBMSException ex) {
outStream.addOutputEvent(ex.getMessage());
outStream.addNewLine();
} finally {
ssh.close();
}
}
/**
* Stops Online Modeling that was started by calling startOnlineModeling
*
* @param outStream
* OuputStream to display processing messages
*/
public static synchronized void stopOnlineModeling(IAlgorithmOutStream outStream) {
fIsExecuting = false;
outStream.addOutputEvent(Messages.AlgorithmUtility_StopMonitor);
outStream.addNewLine();
outStream.addOutputEvent(Messages.AlgorithmUtility_TimeToStopMonitor);
outStream.addNewLine();
outStream.addOutputEvent(Messages.AlgorithmUtility_Wait);
outStream.addNewLine();
}
/**
* Trains already existing models on a trace. Used by startOnlineModeling
* function
*
* @param traceReader
* Trace reader
* @param tracePath
* Trace Path
* @param models
* Model list
* @param outStream
* Output stream to display messages
* @throws TotalADSGeneralException
* Validation exception
* @throws TotalADSDBMSException
* DBMS related exception
* @throws TotalADSReaderException
* Reader related exception
*/
private static void trainModels(ITraceTypeReader traceReader, String tracePath, String[] models, IAlgorithmOutStream outStream) throws TotalADSGeneralException, TotalADSDBMSException, TotalADSReaderException {
IDataAccessObject dao = DBMSFactory.INSTANCE.getDataAccessObject();
if (!dao.isConnected()) {
throw new TotalADSDBMSException(Messages.AlgorithmUtility_NoDB);
}
for (int i = 0; i < models.length; i++) {
// Check if Executor has been stopped by the user
if (Thread.currentThread().isInterrupted()) {
break;
}
IDetectionAlgorithm algorithm = getAlgorithmFromModelName(models[i]);
try (ITraceIterator traceIterator =
traceReader.getTraceIterator(new File(tracePath))) {
outStream.addOutputEvent(NLS.bind(Messages.AlgorithmUtility_UpdateModelSpecific, models[i], algorithm.getName()));
outStream.addNewLine();
outStream.addOutputEvent(Messages.AlgorithmUtility_UpdateModel);
outStream.addNewLine();
algorithm.train(traceIterator, true, models[i], dao, outStream);
}
}
}
// //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// File Handling functions
// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
/**
*
* @param directory
* The name of the directory
* @param traceReader
* An object of the trace reader
* @return An array list of traces suited for the appropriate type
* @throws TotalADSGeneralException
* An exception related to the validation of parameters
*/
private static File[] getDirectoryHandler(String directory, ITraceTypeReader traceReader) throws TotalADSGeneralException {
File traces = new File(directory);
CTFLTTngSysCallTraceReader kernelReader = new CTFLTTngSysCallTraceReader();
if (traceReader.getAcronym().equals(kernelReader.getAcronym())) {
return getDirectoryHandlerforLTTngTraces(traces);
}
return getDirectoryHandlerforTextTraces(traces);
}
/**
* Get an array of trace list for a directory or just one file handler if
* there is only one file
*
* @param traces
* File object representing traces
* @return the file handler to the correct path
* @throws TotalADSGeneralException
* An exception related to validation of parameters
*/
private static File[] getDirectoryHandlerforTextTraces(File traces) throws TotalADSGeneralException {
File[] fileList;
if (traces.isDirectory()) { // if it is a directory return the list of
// all files
Boolean isAllFiles = false, isAllFolders = false;
fileList = traces.listFiles();
for (File file : fileList) {
if (file.isDirectory()) {
isAllFolders = true;
} else if (file.isFile()) {
isAllFiles = true;
}
if (isAllFolders) {
throw new TotalADSGeneralException(NLS.bind(Messages.AlgorithmUtility_FolderContainsDirectories, traces.getName()));
}
}
if (!isAllFiles && !isAllFolders) {
throw new TotalADSGeneralException(Messages.AlgorithmUtility_EmptyDir + traces.getName());
}
}
else {// if it is a single file return the single file; however, this
// code will never be reached
// as in GUI we are only using a directory handle, but if in
// future we decide to make changes then this could come handy
fileList = new File[1];
fileList[0] = traces;
}
return fileList;
}
/**
* Gets an array of list of directories
*
* @param traces
* File object representing traces
* @return Handler to the correct path of files
* @throws TotalADSGeneralException
* An exception related to validation of parameters
*/
private static File[] getDirectoryHandlerforLTTngTraces(File traces) throws TotalADSGeneralException {
if (traces.isDirectory()) {
File[] fileList = traces.listFiles();
File[] fileHandler;
Boolean isAllFiles = false, isAllFolders = false;
for (File file : fileList) {
if (file.isDirectory()) {
if (!file.getName().equalsIgnoreCase("index")) { //$NON-NLS-1$
isAllFolders = true;
}
} else if (file.isFile()) {
isAllFiles = true;
}
if (isAllFiles && isAllFolders) {
throw new TotalADSGeneralException(NLS.bind(Messages.AlgorithmUtility_FolderContainsMixture, traces.getName()));
}
}
// if it has reached this far
if (!isAllFiles && !isAllFolders) {
throw new TotalADSGeneralException(NLS.bind(Messages.AlgorithmUtility_EmptyDir, traces.getName()));
} else if (isAllFiles) { // return the name of folder as a trace
fileHandler = new File[1];
fileHandler[0] = traces;
} else {
fileHandler = fileList;
}
return fileHandler;
}
// this code may not be reached
throw new TotalADSGeneralException(NLS.bind(Messages.AlgorithmUtility_SelectFolder, traces.getName()));
}
}