/*
* beymani: Outlier and anomaly detection
* Author: Pranab Ghosh
*
* Licensed under the Apache License, Version 2.0 (the "License"); you
* may not use this file except in compliance with the License. You may
* obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. See the License for the specific language governing
* permissions and limitations under the License.
*/
package org.beymani.predictor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Scanner;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.chombo.storm.Cache;
import org.chombo.storm.MessageQueue;
import org.chombo.util.Pair;
import redis.clients.jedis.Jedis;
/**
 * Outlier predictor based on a first order Markov state transition model.
 *
 * The model is read from the cache as text: the first line is the comma
 * separated list of state names and every following line is one row of the
 * state transition probability matrix. Each incoming record contributes one
 * state (the field at the configured ordinal of the comma separated record).
 * Per entity, the most recent records are kept and a metric is computed from
 * the observed state transitions, either over a whole window (local
 * predictor) or as a running ratio over all transitions seen so far (global
 * predictor). When the metric exceeds the configured threshold a message is
 * sent to the output queue.
 *
 * @author pranab
 */
public class MarkovModelPredictor extends ModelBasedPredictor {
    private static final Logger LOG = Logger.getLogger(MarkovModelPredictor.class);

    /** Supported ways of turning observed transitions into a metric. */
    private enum DetectionAlgorithm {
        MissProbability,
        MissRate,
        EntropyReduction
    }

    // state names, in the order used to index the transition matrix
    private List<String> states;
    // stateTranstionProb[i][j] : probability of moving from state i to state j
    private double[][] stateTranstionProb;
    // most recent records per entity, oldest first
    private Map<String, List<String>> records = new HashMap<String, List<String>>();
    private boolean localPredictor;
    private int stateSeqWindowSize;
    private int stateOrdinal;
    private DetectionAlgorithm detectionAlgorithm;
    // running (numerator, denominator) per entity, only used by the global predictor
    private Map<String, Pair<Double, Double>> globalParams;
    private double metricThreshold;
    private int numStates;
    // for each source state, index of the most probable target state (miss rate only)
    private int[] maxStateProbIndex;
    // for each source state, entropy of its transition distribution (entropy reduction only)
    private double[] entropy;
    private String outputQueue;
    private MessageQueue outQueue;
    private Cache cache;
    private boolean debugOn;

    /**
     * Builds the predictor from configuration. The following keys are read:
     * <code>debug</code> (optional, "on" enables debug logging),
     * <code>redis.output.queue</code>, <code>redis.markov.model.key</code>,
     * <code>local.predictor</code>, <code>state.seq.window.size</code> (local
     * predictor only), <code>state.ordinal</code>,
     * <code>detection.algorithm</code> (one of missProbability, missRate,
     * entropyReduction) and <code>metric.threshold</code>.
     *
     * @param conf configuration
     * @throws IllegalStateException if no model is stored under the model key
     * @throws IllegalArgumentException if the detection algorithm is not
     * recognized or a model row does not have one value per state
     */
    public MarkovModelPredictor(Map conf) {
        // a missing debug setting simply means debug is off
        Object debug = conf.get("debug");
        if (debug != null && debug.toString().equals("on")) {
            LOG.setLevel(Level.DEBUG);
            debugOn = true;
        }

        outputQueue = conf.get("redis.output.queue").toString();
        outQueue = MessageQueue.createMessageQueue(conf, outputQueue);
        cache = Cache.createCache(conf);

        //model
        String modelKey = conf.get("redis.markov.model.key").toString();
        String model = cache.get(modelKey);
        if (model == null) {
            throw new IllegalStateException("markov model not found in cache for key " + modelKey);
        }
        Scanner scanner = new Scanner(model);
        try {
            int lineCount = 0;
            int row = 0;
            while (scanner.hasNextLine()) {
                String line = scanner.nextLine();
                if (0 == lineCount) {
                    //states
                    String[] items = line.split(",");
                    states = Arrays.asList(items);
                    numStates = items.length;
                    stateTranstionProb = new double[numStates][numStates];
                    LOG.info("numStates:" + numStates);
                } else {
                    //populate state transition probability
                    deseralizeTableRow(stateTranstionProb, line, ",", row, numStates);
                    ++row;
                }
                ++lineCount;
            }
        } finally {
            scanner.close();
        }

        if (debugOn) {
            for (int i = 0; i < numStates; ++i) {
                for (int j = 0; j < numStates; ++j) {
                    LOG.info("state trans prob[" + i + "][" + j +"]=" + stateTranstionProb[i][j]);
                }
            }
        }

        localPredictor = Boolean.parseBoolean(conf.get("local.predictor").toString());
        if (localPredictor) {
            stateSeqWindowSize = Integer.parseInt(conf.get("state.seq.window.size").toString());
            if (debugOn) {
                LOG.info("local predictor window size:" + stateSeqWindowSize );
            }
        } else {
            stateSeqWindowSize = 5;
            globalParams = new HashMap<String, Pair<Double, Double>>();
        }

        //state value ordinal within record
        stateOrdinal = Integer.parseInt(conf.get("state.ordinal").toString());

        //detection algorithm
        String algorithm = conf.get("detection.algorithm").toString();
        if (debugOn) {
            LOG.info("detection algorithm:" + algorithm);
        }
        if (algorithm.equals("missProbability")) {
            detectionAlgorithm = DetectionAlgorithm.MissProbability;
        } else if (algorithm.equals("missRate")) {
            detectionAlgorithm = DetectionAlgorithm.MissRate;

            //max probability state index
            maxStateProbIndex = new int[numStates];
            for (int i = 0; i < numStates; ++i) {
                int maxProbIndex = -1;
                double maxProb = -1;
                for (int j = 0; j < numStates; ++j) {
                    if (stateTranstionProb[i][j] > maxProb) {
                        maxProb = stateTranstionProb[i][j];
                        maxProbIndex = j;
                    }
                }
                maxStateProbIndex[i] = maxProbIndex;
            }
        } else if (algorithm.equals("entropyReduction")) {
            detectionAlgorithm = DetectionAlgorithm.EntropyReduction;

            //entropy per source state
            entropy = new double[numStates];
            for (int i = 0; i < numStates; ++i) {
                double ent = 0;
                for (int j = 0; j < numStates; ++j) {
                    ent += entropyTerm(stateTranstionProb[i][j]);
                }
                entropy[i] = ent;
            }
        } else {
            // fail fast instead of silently running with no algorithm selected
            throw new IllegalArgumentException("invalid detection algorithm: " + algorithm);
        }

        //metric threshold
        metricThreshold = Double.parseDouble(conf.get("metric.threshold").toString());
    }

    /**
     * Parses one delimited line of numbers into a row of a table.
     *
     * @param table table to populate
     * @param data delimited string holding the values of one row
     * @param delim delimiter, used as a regular expression by String.split
     * @param row index of the row to populate
     * @param numCol expected number of columns
     * @throws IllegalArgumentException if the number of tokens differs from numCol
     */
    public void deseralizeTableRow(double[][] table, String data, String delim, int row, int numCol) {
        String[] items = data.split(delim);
        if (items.length != numCol) {
            throw new IllegalArgumentException(
                "Row serialization failed, number of tokens in string does not match with number of columns");
        }
        for (int c = 0; c < numCol; ++c) {
            table[row][c] = Double.parseDouble(items[c]);
        }
    }

    /**
     * Adds the record to the history of the entity, computes the metric and
     * sends a message to the output queue when the metric exceeds the
     * threshold.
     *
     * @param entityID entity the record belongs to
     * @param record comma separated record containing the state
     * @return the metric, or 0 while there are not yet enough records for the
     * entity (a full window for the local predictor, 2 records for the global
     * predictor)
     */
    @Override
    public double execute(String entityID, String record) {
        double score = 0;

        List<String> recordSeq = records.get(entityID);
        if (null == recordSeq) {
            recordSeq = new ArrayList<String>();
            records.put(entityID, recordSeq);
        }

        //add and maintain size
        recordSeq.add(record);
        if (recordSeq.size() > stateSeqWindowSize) {
            recordSeq.remove(0);
        }

        String[] stateSeq = null;
        if (localPredictor) {
            //local metric
            if (debugOn) {
                LOG.info("local metric,  seq size " + recordSeq.size());
            }
            if (recordSeq.size() == stateSeqWindowSize) {
                stateSeq = new String[stateSeqWindowSize];
                for (int i = 0; i < stateSeqWindowSize; ++i) {
                    stateSeq[i] = extractState(recordSeq.get(i));
                }
                score = getLocalMetric(stateSeq);
            }
        } else {
            //global metric
            if (debugOn) {
                LOG.info("global metric");
            }
            int seqSize = recordSeq.size();
            if (seqSize >= 2) {
                // last two records; indexed from the actual size because the
                // history may still be shorter than the window
                stateSeq = new String[2];
                stateSeq[0] = extractState(recordSeq.get(seqSize - 2));
                stateSeq[1] = extractState(recordSeq.get(seqSize - 1));

                Pair<Double,Double> params = globalParams.get(entityID);
                if (null == params) {
                    params = new Pair<Double,Double>(0.0, 0.0);
                    globalParams.put(entityID, params);
                }
                score = getGlobalMetric(stateSeq, params);
            }
        }

        //outlier
        if (debugOn) {
            LOG.info("metric  " + entityID + ":" + score);
        }
        // stateSeq is null while no metric could be computed; nothing to report then
        if (stateSeq != null && score > metricThreshold) {
            StringBuilder stBld = new StringBuilder(entityID);
            stBld.append(" : ");
            for (String st : stateSeq) {
                stBld.append(st).append(" ");
            }
            stBld.append(": ");
            stBld.append(score);
            outQueue.send(stBld.toString());
        }
        return score;
    }

    /**
     * Extracts the state field from a comma separated record.
     *
     * @param record comma separated record
     * @return field at the configured state ordinal
     */
    private String extractState(String record) {
        return record.split(",")[stateOrdinal];
    }

    /**
     * Contribution of one probability to an entropy sum. A probability that
     * is not positive contributes 0, since -p * log(p) would otherwise
     * evaluate to NaN for p = 0.
     *
     * @param prob probability
     * @return -prob * ln(prob), or 0 if prob is not positive
     */
    private static double entropyTerm(double prob) {
        return prob > 0 ? -prob * Math.log(prob) : 0;
    }

    /**
     * Runs the configured detection algorithm on a state sequence.
     *
     * @param stateSeq state sequence
     * @return two element array : metric numerator and denominator
     */
    private double[] computeParams(String[] stateSeq) {
        double[] params = new double[2];
        if (detectionAlgorithm == DetectionAlgorithm.MissProbability) {
            missProbability(stateSeq, params);
        } else if (detectionAlgorithm == DetectionAlgorithm.MissRate) {
            missRate(stateSeq, params);
        } else {
            entropyReduction(stateSeq, params);
        }
        return params;
    }

    /**
     * Metric for one window of states.
     *
     * @param stateSeq states of the window, oldest first
     * @return numerator over denominator as produced by the detection algorithm
     */
    private double getLocalMetric(String[] stateSeq) {
        double[] params = computeParams(stateSeq);
        return params[0] / params[1];
    }

    /**
     * Running metric over all transitions seen so far for an entity.
     *
     * @param stateSeq previous and current state
     * @param globParams running numerator and denominator of the entity,
     * updated in place with the contribution of this transition
     * @return updated numerator over updated denominator
     */
    private double getGlobalMetric(String[] stateSeq, Pair<Double,Double> globParams) {
        double[] params = computeParams(stateSeq);
        globParams.setLeft(globParams.getLeft() + params[0]);
        globParams.setRight(globParams.getRight() + params[1]);
        return globParams.getLeft() / globParams.getRight();
    }

    /**
     * For each transition adds to params[0] the total probability of all
     * target states other than the observed one, and adds 1 to params[1].
     * The local predictor uses every transition in the sequence, the global
     * predictor only the last one.
     *
     * @param stateSeq state sequence, oldest first
     * @param params accumulator : [0] summed miss probability, [1] transition count
     */
    private void missProbability(String[] stateSeq, double[] params) {
        int start = localPredictor ? 1 : stateSeq.length - 1;
        for (int i = start; i < stateSeq.length; ++i) {
            int prState = states.indexOf(stateSeq[i - 1]);
            int cuState = states.indexOf(stateSeq[i]);
            if (debugOn) {
                LOG.info("state prob index:" + prState + " " + cuState);
            }

            //add all probability except target state
            for (int j = 0; j < states.size(); ++j) {
                if (j != cuState) {
                    params[0] += stateTranstionProb[prState][j];
                }
            }
            params[1] += 1;
        }
        if (debugOn) {
            LOG.info("params:" + params[0] + ":" + params[1]);
        }
    }

    /**
     * For each transition adds 1 to params[0] when the observed target state
     * is not the most probable target of the source state, and adds 1 to
     * params[1]. The local predictor uses every transition in the sequence,
     * the global predictor only the last one.
     *
     * @param stateSeq state sequence, oldest first
     * @param params accumulator : [0] miss count, [1] transition count
     */
    private void missRate(String[] stateSeq, double[] params) {
        int start = localPredictor ? 1 : stateSeq.length - 1;
        for (int i = start; i < stateSeq.length; ++i) {
            int prState = states.indexOf(stateSeq[i - 1]);
            int cuState = states.indexOf(stateSeq[i]);
            params[0] += (cuState == maxStateProbIndex[prState] ? 0 : 1);
            params[1] += 1;
        }
    }

    /**
     * Sets params[0] to the ratio of the entropy of the source state
     * distributions with the observed target states left out to their full
     * entropy, summed over the transitions considered, and sets params[1]
     * to 1. The local predictor uses every transition in the sequence, the
     * global predictor only the last one.
     *
     * @param stateSeq state sequence, oldest first
     * @param params result : [0] entropy ratio, [1] always 1
     */
    private void entropyReduction(String[] stateSeq, double[] params) {
        int start = localPredictor ? 1 : stateSeq.length - 1;
        double entropyWoTargetState = 0;
        double totalEntropy = 0;
        for (int i = start; i < stateSeq.length; ++i) {
            int prState = states.indexOf(stateSeq[i - 1]);
            int cuState = states.indexOf(stateSeq[i]);
            if (debugOn) {
                LOG.info("state prob index:" + prState + " " + cuState);
            }

            //entropy without target state
            for (int j = 0; j < states.size(); ++j) {
                if (j != cuState) {
                    entropyWoTargetState += entropyTerm(stateTranstionProb[prState][j]);
                }
            }

            //full entropy, precomputed per source state in the constructor
            totalEntropy += entropy[prState];
        }

        // the ratio is undefined when the source distributions have no entropy;
        // report 0 rather than NaN, which would also corrupt the global running sums
        params[0] = totalEntropy > 0 ? entropyWoTargetState / totalEntropy : 0;
        params[1] = 1;
    }
}