package com.cse10.analyzer;
import com.cse10.database.HibernateUtil;
import mltk.core.*;
import mltk.predictor.glm.ElasticNetLearner;
import mltk.predictor.glm.GLM;
import mltk.util.MathUtils;
import org.apache.commons.math.stat.regression.SimpleRegression;
import org.hibernate.Criteria;
import org.hibernate.SQLQuery;
import org.hibernate.Session;
import org.hibernate.Transaction;
import java.math.BigDecimal;
import java.util.*;
/**
 * Main predictor class.
 *
 * <p>Reads per-quarter crime counts from {@code news_statistics}, grouped by the
 * configured {@code fields}, asks the configured {@link PredictorAlgorithm} for a
 * forecast of every group and writes one row per group into the prediction table.
 */
public class Predictor{
    private String table;
    private String[] fields;
    private String[] quarters;
    private PredictorAlgorithm predictorAlgo;

    /**
     * Creates a predictor and empties the prediction table.
     *
     * <p>NOTE: construction has a side effect: {@code table} is truncated.
     *
     * @param predictorAlgo strategy used to forecast each series
     * @param table         name of the table that receives the predictions
     * @param fields        grouping columns; each distinct combination is one time series
     */
    public Predictor(PredictorAlgorithm predictorAlgo, String table, String[] fields){
        this.predictorAlgo = predictorAlgo;
        this.table = table;
        this.fields = fields;
        // empty the table
        Session session = HibernateUtil.getSessionFactory().openSession();
        try {
            Transaction tx = session.beginTransaction();
            session.createSQLQuery("truncate table "+table).executeUpdate();
            tx.commit();
        } finally {
            // release the session even when the statement fails
            session.close();
        }
    }

    /**
     * Forecasts {@code targetQuarter} for every time series and stores the results.
     *
     * <p>The input rows are ordered by the grouping fields, so a change in any
     * field value marks the start of a new series. Quarters without a row keep
     * the count 0 provided by {@link #getSeriesHolder()}.
     *
     * @param quarters      quarters that form the series; the first and last
     *                      entries bound the rows that are read
     * @param targetQuarter quarter key stored with each prediction; its first
     *                      four characters are stored as the year
     */
    public void predict(String[] quarters, String targetQuarter){
        this.quarters = quarters;
        List results = getInput();
        if (results.isEmpty()){
            // no input rows, hence no series to forecast
            return;
        }
        HashMap<String,Integer> series = getSeriesHolder();
        HashMap pre = (HashMap) results.get(0);
        //for each time series
        for (Object row : results){
            HashMap ele = (HashMap) row;
            if (!sameGroup(pre, ele)){
                // the previous series is complete: forecast it and start the next one
                insertToDB(pre, targetQuarter, predictorAlgo.predict(series));
                series = getSeriesHolder();
                pre = ele;
            }
            // Number instead of BigDecimal: the numeric type of SUM() depends on the column type
            series.put((String) ele.get("crime_yearquarter"), ((Number) ele.get("count")).intValue());
        }
        // the last series has no following row to trigger the flush above
        insertToDB(pre, targetQuarter, predictorAlgo.predict(series));
    }

    /**
     * Computes the mean squared error between the counts in {@code predictionTable}
     * and the counts in {@code news_statistics}.
     *
     * <p>Only rows with {@code crime_year > 2013} and {@code crime_year < 2015} are
     * compared. A prediction without a matching {@code news_statistics} row is
     * compared against 0.
     *
     * @param predictionTable table holding the predicted counts
     * @return the mean squared error, or {@code NaN} when the query yields no value
     *         (for example when the prediction table has no row in the compared range)
     */
    public double getMeanSqureError(String predictionTable){
        StringBuilder sql = new StringBuilder(
                "SELECT Sum(val) sum," +
                " Count(val) count," +
                " Sum(val) / Count(val) err" +
                " FROM (SELECT");
        appendFields(sql, "pd.");
        sql.append(" pd.crime_yearquarter, Pow(count1 - Ifnull(count2, 0), 2) val" +
                " FROM (SELECT");
        appendFields(sql, "");
        sql.append(" crime_yearquarter, Sum(`crime_count`) count1" +
                " FROM ").append(predictionTable).append(
                " WHERE crime_year < 2015" +
                " AND crime_year > 2013" +
                " GROUP BY ");
        appendFields(sql, "");
        sql.append(" `crime_yearquarter`) pd" +
                " LEFT JOIN (SELECT");
        appendFields(sql, "");
        sql.append(" crime_yearquarter, Sum(`crime_count`) count2" +
                " FROM `news_statistics`" +
                " WHERE crime_year < 2015" +
                " AND crime_year > 2013" +
                " GROUP BY");
        appendFields(sql, "");
        sql.append(" crime_yearquarter) ns" +
                " ON ");
        sql.append(" ns.").append(fields[0]).append(" = pd.").append(fields[0]);
        for (int i=1; i<fields.length; i++){
            sql.append(" AND ns.").append(fields[i]).append(" = pd.").append(fields[i]);
        }
        sql.append(" AND ns.`crime_yearquarter` = pd.`crime_yearquarter` ) tbl");

        List results;
        Session session = HibernateUtil.getSessionFactory().openSession();
        try {
            SQLQuery query = session.createSQLQuery(sql.toString());
            query.setResultTransformer(Criteria.ALIAS_TO_ENTITY_MAP);
            results = query.list();
        } finally {
            session.close();
        }
        // Sum/Count is NULL when nothing was compared; report that as NaN
        Object err = ((Map) results.get(0)).get("err");
        return (err == null) ? Double.NaN : ((Number) err).doubleValue();
    }

    /**
     * Loads the summed crime counts per grouping fields and quarter.
     *
     * @return rows as alias-to-value maps with the grouping fields,
     *         {@code crime_yearquarter} and {@code count}, ordered by the
     *         grouping fields and then by quarter
     */
    protected List getInput(){
        String fieldNames = joinFields();
        // the quarter bounds are bound as parameters, never spliced into the SQL text
        String sql = "SELECT "+fieldNames+", crime_yearquarter, sum(crime_count) count" +
                " from news_statistics" +
                " where crime_yearquarter >= :fromQuarter AND crime_yearquarter <= :toQuarter" +
                " group by "+fieldNames+", crime_yearquarter" +
                " order by "+fieldNames+", crime_yearquarter";
        Session session = HibernateUtil.getSessionFactory().openSession();
        try {
            SQLQuery query = session.createSQLQuery(sql);
            query.setParameter("fromQuarter", quarters[0]);
            query.setParameter("toQuarter", quarters[quarters.length-1]);
            query.setResultTransformer(Criteria.ALIAS_TO_ENTITY_MAP);
            return query.list();
        } finally {
            session.close();
        }
    }

    /**
     * Replaces the forecasting strategy used by later {@link #predict} calls.
     *
     * @param predictorAlgo the new strategy
     */
    public void setPredictorAlgorithm (PredictorAlgorithm predictorAlgo){
        this.predictorAlgo = predictorAlgo;
    }

    /**
     * Stores one prediction row.
     *
     * @param ele   row whose grouping field values identify the series
     * @param key   predicted quarter; its first four characters are stored as
     *              {@code crime_year}
     * @param count predicted crime count
     */
    protected void insertToDB(HashMap ele, String key, int count){
        StringBuilder columns = new StringBuilder();
        StringBuilder placeholders = new StringBuilder();
        for (int i=0; i<fields.length; i++){
            columns.append(fields[i]).append(", ");
            placeholders.append(":f").append(i).append(", ");
        }
        // column names cannot be bound, the values can: a quote inside a value
        // must not be able to change the statement
        String sql = "INSERT INTO "+table+" ("+columns+"crime_year, crime_yearquarter, crime_count) " +
                " VALUES ("+placeholders+":crimeYear, :crimeYearQuarter, :crimeCount)";

        Session session = HibernateUtil.getSessionFactory().openSession();
        try {
            Transaction tx = session.beginTransaction();
            try {
                SQLQuery insert = session.createSQLQuery(sql);
                for (int i=0; i<fields.length; i++){
                    insert.setParameter("f"+i, ele.get(fields[i]));
                }
                insert.setParameter("crimeYear", key.substring(0,4));
                insert.setParameter("crimeYearQuarter", key);
                insert.setParameter("crimeCount", Integer.valueOf(count));
                insert.executeUpdate();
                tx.commit();
            } catch (RuntimeException e){
                tx.rollback();
                throw e;
            }
        } finally {
            session.close();
        }
    }

    /**
     * Creates an empty series.
     *
     * @return a map holding every configured quarter with the count 0
     */
    protected HashMap<String,Integer> getSeriesHolder(){
        HashMap<String,Integer> series = new HashMap<String,Integer>();
        for (int i=0; i<quarters.length; i++){
            series.put(quarters[i],0);
        }
        return series;
    }

    /** Tells whether two rows carry the same value in every grouping field. */
    private boolean sameGroup(HashMap first, HashMap second){
        for (String field : fields){
            // Objects.equals tolerates NULL column values
            if (!Objects.equals(first.get(field), second.get(field))){
                return false;
            }
        }
        return true;
    }

    /** Joins the grouping fields into a comma separated column list. */
    private String joinFields(){
        StringBuilder names = new StringBuilder(fields[0]);
        for (int i=1; i<fields.length; i++){
            names.append(", ").append(fields[i]);
        }
        return names.toString();
    }

    /** Appends {@code " <prefix><field>,"} to {@code sql} for every grouping field. */
    private void appendFields(StringBuilder sql, String prefix){
        for (String field : fields){
            sql.append(' ').append(prefix).append(field).append(',');
        }
    }
}
/**
 * Strategy interface: every forecasting algorithm usable by the predictor
 * implements this contract.
 */
interface PredictorAlgorithm{
    /**
     * Forecasts the value that follows the given series.
     *
     * @param series observed values keyed by period label; the implementations in
     *               this file order the observations by sorting the keys
     * @return the forecast value
     */
    int predict(HashMap<String,Integer> series);
}
/**
 * Elastic net regression algorithm class.
 *
 * <p>Fits an mltk {@code ElasticNetLearner} regressor on the points
 * (position of the key in sorted order, value) and evaluates it at the position
 * that follows the last observation.
 */
class ENLPredictorAlgorithm implements PredictorAlgorithm{

    /**
     * Forecasts the value following the series.
     *
     * @param series observed values keyed by period label; keys are sorted to
     *               obtain the x positions 0..n-1
     * @return the rounded forecast, never below 0
     */
    public int predict(HashMap<String,Integer> series){
        // column 1 of every parsed row is the regression target
        final int classIndex = 1;
        // x position of the value to forecast
        final int nextPosition = series.size();

        List<String> orderedKeys = new ArrayList<String>(series.keySet());
        Collections.sort(orderedKeys);

        // the attribute list is handed over empty and filled once the rows exist
        List<Attribute> attributes = new ArrayList<>();
        Instances trainingSet = new Instances(attributes);
        ElasticNetLearner learner = new ElasticNetLearner();

        int position = 0;
        for (String key : orderedKeys){
            double x = (double) position;
            double y = (double) series.get(key);
            String[] row = {Double.toString(x), Double.toString(y)};
            trainingSet.add(parseDenseInstance(row, classIndex));
            position++;
        }

        int featureCount = trainingSet.get(0).getValues().length;
        for (int idx = 0; idx < featureCount; idx++) {
            Attribute feature = new NumericalAttribute("f" + idx);
            feature.setIndex(idx);
            attributes.add(feature);
        }
        // classIndex is non-negative here, so a target attribute is always assigned
        assignTargetAttribute(trainingSet);
        //end of instance creation

        //build regressor
        // NOTE(review): 100, 0.0, 0.0 are presumably the iteration limit and the
        // regularisation settings; confirm against the mltk documentation
        GLM model = learner.buildRegressor(trainingSet, 100, 0.0, 0.0);

        //create new instance for prediction
        int[] queryIndices = {0};
        double[] queryValues = {nextPosition};
        Instance query = new Instance(queryIndices, queryValues);

        //predict the value
        double forecast = model.regress(query);
        return Math.max((int) Math.round(forecast), 0);
    }

    /**
     * Converts textual values into a dense instance.
     *
     * @param data       one number per column, as text
     * @param classIndex column holding the target, or a negative value when the
     *                   row has no target column
     * @return an instance with the remaining columns as features; the target is
     *         {@code NaN} when no column matches {@code classIndex}
     */
    protected Instance parseDenseInstance(String[] data, int classIndex) {
        // with a target column the feature vector is one entry shorter than the row
        int featureCount = (classIndex < 0) ? data.length : data.length - 1;
        double[] features = new double[featureCount];
        double target = Double.NaN;
        int next = 0;
        for (int col = 0; col < data.length; col++) {
            double parsed = Double.parseDouble(data[col]);
            if (col == classIndex) {
                target = parsed;
            } else {
                features[next++] = parsed;
            }
        }
        return new Instance(features, target);
    }

    /**
     * Sets the target attribute of {@code instances}: nominal, with the sorted
     * distinct target values as states, when every target is an integer;
     * numerical otherwise.
     *
     * @param instances data set whose target attribute is assigned
     */
    protected void assignTargetAttribute(Instances instances) {
        if (!hasOnlyIntegerTargets(instances)) {
            instances.setTargetAttribute(new NumericalAttribute("target"));
            return;
        }
        // TreeSet keeps the distinct target values in ascending order
        TreeSet<Integer> distinct = new TreeSet<>();
        for (Instance instance : instances) {
            distinct.add((int) instance.getTarget());
        }
        String[] states = new String[distinct.size()];
        int slot = 0;
        for (Integer value : distinct) {
            states[slot++] = value.toString();
        }
        instances.setTargetAttribute(new NominalAttribute("target", states));
    }

    /** Tells whether {@code MathUtils.isInteger} holds for every target value. */
    private boolean hasOnlyIntegerTargets(Instances instances) {
        for (Instance instance : instances) {
            if (!MathUtils.isInteger(instance.getTarget())) {
                return false;
            }
        }
        return true;
    }
}
/**
 * Linear regression class.
 *
 * <p>Fits a commons-math {@code SimpleRegression} on the points (position of
 * the key in sorted order, value) and evaluates it at the position that follows
 * the last observation.
 */
class LRPredictorAlgorithm implements PredictorAlgorithm{

    /**
     * Forecasts the value following the series. The fitted intercept and slope
     * are printed to standard output, one per line.
     *
     * @param series observed values keyed by period label; keys are sorted to
     *               obtain the x positions 0..n-1
     * @return the rounded forecast, never below 0
     */
    public int predict(HashMap<String,Integer> series) {
        // x position of the value to forecast
        final int nextPosition = series.size();

        List<String> orderedKeys = new ArrayList<String>(series.keySet());
        Collections.sort(orderedKeys);

        SimpleRegression regression = new SimpleRegression();
        regression.clear();
        int position = 0;
        for (String key : orderedKeys){
            regression.addData(position, series.get(key));
            position++;
        }

        double fittedIntercept = regression.getIntercept();
        double fittedSlope = regression.getSlope();
        System.out.println(fittedIntercept);
        System.out.println(fittedSlope);

        double forecast = regression.predict(nextPosition);
        return Math.max((int) Math.round(forecast), 0);
    }
}