package qa.qcri.aidr.predict;
import java.sql.Connection;
import java.sql.Driver;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Date;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import javax.naming.InitialContext;
import javax.naming.NamingException;
import org.apache.log4j.Logger;
import org.json.JSONObject;
import qa.qcri.aidr.dbmanager.dto.DocumentDTO;
import qa.qcri.aidr.dbmanager.dto.DocumentNominalLabelDTO;
import qa.qcri.aidr.dbmanager.dto.DocumentNominalLabelIdDTO;
import qa.qcri.aidr.dbmanager.ejb.remote.facade.TaskManagerRemote;
import qa.qcri.aidr.predict.classification.nominal.Model;
import qa.qcri.aidr.predict.classification.nominal.ModelNominalLabelPerformance;
import qa.qcri.aidr.predict.classification.nominal.NominalLabelBC;
import qa.qcri.aidr.predict.common.Helpers;
import qa.qcri.aidr.predict.common.TaggerConfigurationProperty;
import qa.qcri.aidr.predict.common.TaggerConfigurator;
import qa.qcri.aidr.predict.common.TaggerErrorLog;
import qa.qcri.aidr.predict.data.Document;
import qa.qcri.aidr.predict.dbentities.ModelFamilyEC;
import qa.qcri.aidr.predict.dbentities.NominalAttributeEC;
import qa.qcri.aidr.predict.dbentities.NominalLabelEC;
import qa.qcri.aidr.predict.dbentities.TaggerDocument;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisPool;
import redis.clients.jedis.JedisPoolConfig;
import snaq.db.ConnectionPool;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.SparseInstance;
/**
* Wrapper class for database communication (both MySQL and Redis).
*
* @author jrogstadius
* @author koushik
*/
public class DataStore {
/** Remote EJB facade used to persist documents and labels; set by {@link #initTaskManager()}. */
public static TaskManagerRemote<DocumentDTO, Long> taskManager = null;
private static Logger logger = Logger.getLogger(DataStore.class);
/** JNDI name of the remote task manager EJB, read from the tagger configuration. */
private static final String remoteEJBJNDIName = TaggerConfigurator
        .getInstance().getProperty(
                TaggerConfigurationProperty.REMOTE_TASK_MANAGER_JNDI_NAME);
/**
 * Minimum time between two throughput log lines, in milliseconds. The configured value is in
 * minutes; the arithmetic is done in {@code long} so a large configured value cannot overflow
 * {@code int} before being widened.
 */
private static final long LOG_INTERVAL = Long
        .parseLong(TaggerConfigurator.getInstance().getProperty(
                TaggerConfigurationProperty.LOG_INTERVAL_MINUTES)) * 60L * 1000L;
/** Number of documents saved since the counter was last reset by a throughput log line. */
private static int saveNewDocumentsCount = 0;
/** Start (ms since epoch) of the current throughput log interval; 0 until {@link #canLog()} first runs. */
private static long lastSaveTime = 0;
/** Shared Redis connection pool; created by {@link #initializeJedisPool()}. */
private static JedisPool jedisPool = null;
/** Shared MySQL connection pool; created by {@link #initializeMySqlPool()}. */
private static ConnectionPool mySqlPool = null;
/** Failure counter for MySQL pool (re)creation; reset to 0 after a successful pool creation or borrow. */
private static int attempts = 0;
/** The MySQL pool is only recreated while {@link #attempts} is below this value. */
private static final int MAX_RECREATE_POOL_ATTEMPTS = 3;
// NOTE(review): not referenced in the visible part of this file; presumably used further down.
private static Object lockObject = new Object();
/** Cache of nominal attributes (with their labels) keyed by nominal attribute ID. */
private static HashMap<Integer,NominalAttributeEC> attLabels = new HashMap<Integer,NominalAttributeEC>();
/**
 * Creates the shared Redis connection pool for the configured Redis host if it
 * does not exist yet. If a pool is already active, only a warning is logged.
 *
 * @throws Exception if the pool cannot be constructed
 */
public static synchronized void initializeJedisPool() throws Exception {
    if (jedisPool != null) {
        logger.warn("Attempting to initialize an active JedisPool!");
        return;
    }
    String redisHost = TaggerConfigurator.getInstance().getProperty(
            TaggerConfigurationProperty.REDIS_HOST);
    jedisPool = new JedisPool(new JedisPoolConfig(), redisHost);
    logger.info("Initialized jedisPool = " + jedisPool);
}
/**
 * Initializes the Redis pool and then the MySQL pool. The two steps fail
 * independently: each failure is logged and passed to
 * {@code TaggerErrorLog.sendErrorMail}, and the next step still runs.
 */
public static void initDBPools() {
    try {
        initializeJedisPool();
    } catch (Exception redisFailure) {
        logger.error("Unable to allocate JEDIS Pool!");
        logger.error("Exception", redisFailure);
        String detail = "Could not establish Redis connection. " + redisFailure.getMessage();
        TaggerErrorLog.sendErrorMail("Redis", detail);
    }

    try {
        initializeMySqlPool();
    } catch (Exception mysqlFailure) {
        logger.error("Unable to allocate MySQL Pool!");
        logger.error("Exception", mysqlFailure);
        String detail = "Could not initialize mysql connection. " + mysqlFailure.getMessage();
        TaggerErrorLog.sendErrorMail("Mysql connection", detail);
    }
}
/**
 * Looks up the remote TaskManager EJB under {@code remoteEJBJNDIName} using a
 * default {@link InitialContext} and stores it in {@link #taskManager}.
 * If {@code taskManager} is already set, only a warning is logged.
 * A {@link NamingException} during the lookup is logged and not rethrown.
 */
@SuppressWarnings("unchecked")
public static void initTaskManager() {
    if (taskManager != null) {
        logger.warn("taskManager has already been initialized: " + taskManager
                + ". Hence, skipping taskManager initialization attempt...");
        return;
    }
    try {
        final long lookupStart = System.currentTimeMillis();
        InitialContext context = new InitialContext();
        // The unchecked cast is unavoidable: JNDI lookups return Object.
        taskManager = (TaskManagerRemote<DocumentDTO, Long>) context.lookup(DataStore.remoteEJBJNDIName);
        long elapsedMillis = System.currentTimeMillis() - lookupStart;
        logger.info("taskManager: " + taskManager + ", time taken to initialize = " + elapsedMillis);
        if (taskManager == null) {
            return;
        }
        logger.info("Success in connecting to remote EJB to initialize taskManager");
    } catch (NamingException lookupFailure) {
        logger.error("Error in JNDI lookup for initializing remote EJB", lookupFailure);
    }
}
/*
 * TODO: Rename all database columns and tables to use underscore_notation.
 * Everything was initially created with camelCaseNotation, but MySQL can be
 * configured to force all names to lowercase on database creation. This
 * broke all queries when the code was moved to the production server, so the
 * current naming is... ugly.
 */
/**
 * Payload announcing that new human-labeled training samples were stored for a
 * crisis. It is serialized to JSON by sendNewLabeledDocumentNotification and
 * pushed onto the Redis training-sample queue.
 */
static class TrainingSampleNotification {
    /** ID of the crisis (collection) the labeled documents belong to. */
    public int crisisID;
    /** IDs of the nominal attributes that received new labels. */
    public Collection<Integer> attributeIDs;

    public TrainingSampleNotification(int crisisID, Collection<Integer> attributeIDs) {
        this.attributeIDs = attributeIDs;
        this.crisisID = crisisID;
    }
}
/* REDIS */
/**
 * Borrows a Redis connection from the shared pool, creating the pool first if
 * it does not exist. Any failure is logged and passed to
 * {@code TaggerErrorLog.sendErrorMail}, and {@code null} is returned.
 *
 * @return a pooled Jedis connection, or {@code null} if none could be obtained;
 *         callers must handle a null result
 */
public static Jedis getJedisConnection() {
    try {
        if (jedisPool == null) {
            logger.error("Jedis Pool is NULL!");
            initializeJedisPool();
        }
        return jedisPool.getResource();
    } catch (Exception failure) {
        logger.error("Could not establish Redis connection. Is the Redis server running?");
        logger.error("Exception", failure);
        TaggerErrorLog.sendErrorMail("Redis", "Could not establish Redis connection. " + failure.getMessage());
        return null;
    }
}
/**
 * Returns a Redis connection obtained from {@link #getJedisConnection()} to the
 * shared pool.
 *
 * <p>{@code getJedisConnection()} returns {@code null} on failure, so a
 * {@code null} resource is silently ignored instead of being passed to the pool.
 * A non-null resource with no pool to return it to is logged.
 *
 * @param resource the connection to return; may be {@code null}
 */
public static void close(Jedis resource) {
    if (resource == null) {
        return;
    }
    if (jedisPool == null) {
        logger.warn("Cannot return Redis connection: Jedis Pool is NULL");
        return;
    }
    jedisPool.returnResource(resource);
}
/**
 * Deletes the Redis keys of the tagger pipeline queues: classification,
 * extraction, output, label-task-write and training-sample-info.
 *
 * <p>If no Redis connection can be obtained, an error is logged and nothing is
 * deleted. The connection is always returned to the pool, even if a delete fails.
 */
public static void clearRedisPipeline() {
    Jedis redis = getJedisConnection();
    if (redis == null) {
        logger.error("Cannot clear Redis pipeline: no Redis connection available");
        return;
    }
    TaggerConfigurator config = TaggerConfigurator.getInstance();
    try {
        redis.del(config.getProperty(
                TaggerConfigurationProperty.REDIS_FOR_CLASSIFICATION_QUEUE));
        redis.del(config.getProperty(
                TaggerConfigurationProperty.REDIS_FOR_EXTRACTION_QUEUE));
        redis.del(config.getProperty(
                TaggerConfigurationProperty.REDIS_FOR_OUTPUT_QUEUE));
        redis.del(config.getProperty(
                TaggerConfigurationProperty.REDIS_LABEL_TASK_WRITE_QUEUE));
        redis.del(config.getProperty(
                TaggerConfigurationProperty.REDIS_TRAINING_SAMPLE_INFO_QUEUE));
    } finally {
        close(redis);
    }
}
/**
 * Sentinel returned by saveModelToDatabase when no ID for the newly inserted
 * model could be obtained (for example because the insert failed).
 */
public static final int MODEL_ID_ERROR = -1;
/* MYSQL */
/**
 * Creates the shared MySQL connection pool ("aidr-backend") if it does not
 * exist yet, after registering the MySQL JDBC driver. The JDBC URL, user name
 * and password are read from the tagger configuration. If a pool is already
 * active, only a warning is logged.
 *
 * <p>A successful creation resets {@code attempts} to 0. If the driver cannot be
 * loaded, {@code attempts} is incremented and the failure is reported to the
 * caller. Swallowing it would leave the pool {@code null}, and the next caller
 * would hit a NullPointerException.
 *
 * @throws SQLException if the JDBC driver cannot be loaded or registered
 */
public static synchronized void initializeMySqlPool() throws SQLException {
    if (null != mySqlPool) {
        logger.warn("Attempting to initialize an active MySqlPool, attempts = " + attempts);
        return;
    }
    try {
        Class<?> c = Class.forName("com.mysql.jdbc.Driver");
        // Constructor.newInstance replaces the deprecated Class.newInstance.
        Driver driver = (Driver) c.getDeclaredConstructor().newInstance();
        DriverManager.registerDriver(driver);
    } catch (ReflectiveOperationException e) {
        logger.error("Exception when initializing MySQL connection");
        logger.error("Exception:", e);
        ++attempts;
        throw new SQLException("Could not load MySQL JDBC driver com.mysql.jdbc.Driver", e);
    }
    mySqlPool = new ConnectionPool("aidr-backend",
            10, // min-pool default = 1
            40, // max-pool default = 5
            200, // max-size default 30
            500, // timeout (sec)
            TaggerConfigurator.getInstance().getProperty(
                    TaggerConfigurationProperty.MYSQL_PATH),
            TaggerConfigurator.getInstance().getProperty(
                    TaggerConfigurationProperty.MYSQL_USERNAME),
            TaggerConfigurator.getInstance().getProperty(
                    TaggerConfigurationProperty.MYSQL_PASSWORD));
    logger.info("Initialized mySQLPool = " + mySqlPool);
    attempts = 0;
}
/**
 * Borrows a MySQL connection from the shared pool, with a 30000 ms timeout.
 *
 * <p>If the pool does not exist, it is created first. If the pool hands out no
 * connection or throws, and the failure counter {@code attempts} is below
 * {@link #MAX_RECREATE_POOL_ATTEMPTS}, the pool is discarded and recreated once,
 * and the borrow is retried. A successful borrow resets the counter.
 *
 * @return a pooled connection, or {@code null} if none could be obtained
 * @throws SQLException if the pool cannot be (re)created, or the retry after
 *         recreating the pool fails
 */
public static synchronized Connection getMySqlConnection() throws SQLException {
    final long timeout = 30000;
    Connection con = null;
    try {
        if (mySqlPool == null) {
            logger.error("MySql Pool is NULL");
            initializeMySqlPool();
        }
        con = borrowPooledMySqlConnection(timeout);
        if (con != null) {
            attempts = 0;
            return con;
        }
        if (attempts < MAX_RECREATE_POOL_ATTEMPTS) {
            discardAndRecreateMySqlPool();
            con = borrowPooledMySqlConnection(timeout);
            logger.warn("MySQL Pool reallocated!");
            if (null == con) {
                logger.error("The created MySQL connection is null even AFTER Reinitializing the MySQL Pool!!! Giving up...");
            } else {
                attempts = 0; //reset
            }
            return con;
        }
    } catch (Exception e) {
        logger.error("Exception", e);
        if (attempts < MAX_RECREATE_POOL_ATTEMPTS) {
            discardAndRecreateMySqlPool();
            con = borrowPooledMySqlConnection(timeout);
            if (con != null) {
                attempts = 0; //reset
            }
        }
    }
    return con;
}

/** Borrows from the current pool; fails with a SQLException instead of an NPE when no pool exists. */
private static Connection borrowPooledMySqlConnection(long timeout) throws SQLException {
    if (mySqlPool == null) {
        throw new SQLException("MySQL connection pool is not initialized");
    }
    return mySqlPool.getConnection(timeout);
}

/**
 * Releases the current pool (unless it is already released), forgets it, and
 * creates a new one. The null check precedes every dereference. An
 * already-released pool is also discarded, so initializeMySqlPool does not
 * keep it.
 */
private static void discardAndRecreateMySqlPool() throws SQLException {
    if (mySqlPool != null) {
        if (!mySqlPool.isReleased()) {
            mySqlPool.release();
        }
        mySqlPool = null;
        ++attempts;
    }
    initializeMySqlPool();
}
/**
 * Closes a JDBC connection. For pooled connections, this hands the connection
 * back to the pool. {@code null} is ignored, and a failure to close is logged
 * rather than thrown.
 *
 * @param con the connection to close; may be {@code null}
 */
public static void close(Connection con) {
    if (con != null) {
        try {
            con.close();
        } catch (SQLException closeFailure) {
            logger.error("Exception when returning MySQL connection", closeFailure);
        }
    }
}
/**
 * Closes a JDBC statement. {@code null} is ignored, and a failure to close is
 * logged rather than thrown.
 *
 * @param statement the statement to close; may be {@code null}
 */
public static void close(Statement statement) {
    if (statement != null) {
        try {
            statement.close();
        } catch (SQLException closeFailure) {
            logger.error("Could not close statement", closeFailure);
        }
    }
}
/**
 * Closes a JDBC result set. {@code null} is ignored, and a failure to close is
 * logged rather than thrown.
 *
 * @param resultset the result set to close; may be {@code null}
 */
public static void close(ResultSet resultset) {
    if (resultset != null) {
        try {
            resultset.close();
        } catch (SQLException closeFailure) {
            logger.error("Could not close resultSet", closeFailure);
        }
    }
}
/**
 * Looks up the ID of the label with code {@code 'null'} for a nominal attribute.
 *
 * @param attributeID the nominal attribute to search
 * @return the label ID, or {@code null} if the attribute has no such label, no
 *         connection is available, or the query fails
 */
public static Integer getNullLabelID(int attributeID) {
    // Bind the attribute ID instead of concatenating it into the SQL text.
    String sql = "select nominalLabelID from nominal_label where nominalAttributeID=?"
            + " and nominalLabelCode='null'";
    Connection conn = null;
    PreparedStatement query = null;
    ResultSet result = null;
    try {
        conn = getMySqlConnection();
        if (conn == null) {
            logger.error("Error in executing SQL statement: " + sql + " (no MySQL connection available)");
            return null;
        }
        query = conn.prepareStatement(sql);
        query.setInt(1, attributeID);
        result = query.executeQuery();
        if (result.next()) {
            return result.getInt(1);
        }
    } catch (SQLException ex) {
        logger.error("Error in executing SQL statement: " + sql + " [attributeID=" + attributeID + "]", ex);
    } finally {
        close(result);
        close(query);
        close(conn);
    }
    return null;
}
/**
 * Loads the rows of {@code nominal_label_training_data} for one
 * crisis/attribute pair and converts them into a Weka dataset.
 *
 * @param crisisID the crisis whose samples are loaded
 * @param attributeID the nominal attribute whose samples are loaded
 * @return the training dataset built by createInstances
 * @throws Exception propagated from createInstances (e.g. feature/label count mismatch)
 */
public static Instances getTrainingSet(int crisisID, int attributeID)
        throws Exception {
    final String query = "SELECT wordFeatures, nominalLabelID FROM nominal_label_training_data WHERE crisisID = "
            + crisisID + " AND nominalAttributeID = " + attributeID;
    final ArrayList<String[]> features = new ArrayList<>();
    final ArrayList<String> classLabels = new ArrayList<>();
    getLabeledSet(query, features, classLabels);
    return createInstances(features, classLabels);
}
/**
 * Loads the rows of {@code nominal_label_evaluation_data} for one
 * crisis/attribute pair and formats them with the header of
 * {@code trainingData}, so that a model trained on it can score them.
 *
 * @param crisisID the crisis whose samples are loaded
 * @param attributeID the nominal attribute whose samples are loaded
 * @param trainingData dataset whose attributes and class values define the header
 * @return the evaluation dataset built by createFormattedInstances
 * @throws Exception propagated from createFormattedInstances
 */
public static Instances getEvaluationSet(int crisisID, int attributeID,
        Instances trainingData) throws Exception {
    final String query = "SELECT wordFeatures, nominalLabelID FROM nominal_label_evaluation_data WHERE crisisID = "
            + crisisID + " AND nominalAttributeID = " + attributeID;
    final ArrayList<String[]> features = new ArrayList<>();
    final ArrayList<String> classLabels = new ArrayList<>();
    getLabeledSet(query, features, classLabels);
    return createFormattedInstances(trainingData, features, classLabels);
}
/**
 * Runs {@code sql}, which must return {@code wordFeatures} and
 * {@code nominalLabelID} columns. For each row, appends the label ID to
 * {@code labels} (as a string) and the {@code "words"} array of the unescaped
 * {@code wordFeatures} JSON to {@code wordVectors}.
 *
 * <p>Errors are logged, not thrown. Rows read before the failure stay in the
 * lists. NOTE(review): the label is appended before its features are parsed, so
 * a parse failure leaves {@code labels} one element longer than
 * {@code wordVectors}.
 */
static void getLabeledSet(String sql, ArrayList<String[]> wordVectors,
        ArrayList<String> labels) {
    Connection conn = null;
    PreparedStatement statement = null;
    ResultSet rows = null;
    try {
        conn = getMySqlConnection();
        statement = conn.prepareStatement(sql);
        rows = statement.executeQuery();
        while (rows.next()) {
            // Weka nominal class values must be strings.
            String labelValue = Integer.toString(rows.getInt("nominalLabelID"));
            labels.add(labelValue);
            String escapedFeatures = rows.getString("wordFeatures");
            JSONObject featureJson = new JSONObject(Helpers.unescapeJson(escapedFeatures));
            String[] words = Helpers.toStringArray(featureJson.getJSONArray("words"));
            wordVectors.add(words);
        }
    } catch (SQLException e) {
        logger.error("Exception while fetching dataset. ", e);
    } catch (Exception e) {
        logger.error("Exception while fetching dataset", e);
    } finally {
        close(rows);
        close(statement);
        close(conn);
    }
}
/**
 * Builds a Weka dataset from tokenized documents as a binary bag of words.
 * There is one numeric attribute per distinct word: 1 if the document contains
 * it, otherwise the missing value is replaced by 0. There is also a nominal
 * class attribute {@code "___aidrclass___"} whose values are the distinct labels.
 *
 * @param wordVectors the words of each document
 * @param labels the class label of each document, parallel to {@code wordVectors}
 * @return the dataset, with its class attribute set
 * @throws IllegalArgumentException if the two lists differ in size
 * @throws Exception declared for compatibility with existing callers
 */
static Instances createInstances(ArrayList<String[]> wordVectors,
        ArrayList<String> labels) throws Exception {
    if (wordVectors.size() != labels.size()) {
        throw new IllegalArgumentException("Mismatched training data: "
                + wordVectors.size() + " word vectors but " + labels.size() + " labels");
    }
    // Build a dictionary of all words in the documents
    HashSet<String> uniqueWords = new HashSet<String>();
    for (String[] words : wordVectors) {
        uniqueWords.addAll(Arrays.asList(words));
    }
    // Create one attribute per dictionary word
    ArrayList<Attribute> attributes = new ArrayList<Attribute>();
    for (String word : uniqueWords) {
        attributes.add(new Attribute(word));
    }
    // Make the class attribute from the distinct labels
    HashSet<String> uniqueLabels = new HashSet<String>(labels);
    ArrayList<String> uniqueLabelsList = new ArrayList<String>(uniqueLabels);
    Attribute classAttribute = new Attribute("___aidrclass___",
            uniqueLabelsList);
    attributes.add(classAttribute);
    // Create the dataset
    Instances instances = new Instances("data", attributes,
            wordVectors.size());
    // All-zero replacement values for attributes left unset below.
    double[] missingVal = new double[attributes.size()];
    instances.setClass(classAttribute);
    // Add each document as an instance
    for (int i = 0; i < wordVectors.size(); i++) {
        Instance item = new SparseInstance(instances.numAttributes());
        item.setDataset(instances);
        for (String word : wordVectors.get(i)) {
            Attribute attribute = instances.attribute(word);
            if (attribute != null) {
                item.setValue(attribute, 1);
            }
        }
        item.setValue(classAttribute, labels.get(i));
        item.replaceMissingValues(missingVal);
        instances.add(item);
    }
    return instances;
}
/**
 * Builds a dataset that reuses the header (attributes and class values) of
 * {@code headerSet}. getEvaluationSet passes the training set here, so a model
 * trained on that header can score the result.
 *
 * <p>Words not present in the header are ignored. Documents whose label is not
 * one of the header's class values are dropped with an error log.
 *
 * @param headerSet dataset whose last attribute is the class attribute
 * @param wordVectors the words of each document
 * @param labels the class label of each document, parallel to {@code wordVectors}
 * @return the formatted dataset
 * @throws IllegalArgumentException if the two lists differ in size
 * @throws Exception declared for compatibility with existing callers
 */
static Instances createFormattedInstances(Instances headerSet,
        ArrayList<String[]> wordVectors, ArrayList<String> labels)
        throws Exception {
    if (wordVectors.size() != labels.size()) {
        throw new IllegalArgumentException("Mismatched evaluation data: "
                + wordVectors.size() + " word vectors but " + labels.size() + " labels");
    }
    // Create the dataset with the same header as headerSet
    Instances instances = new Instances(headerSet, wordVectors.size());
    // All-zero replacement values for attributes left unset below.
    double[] missingVal = new double[headerSet.numAttributes()];
    // The class attribute is the last one (as created by createInstances)
    instances.setClassIndex(headerSet.numAttributes() - 1);
    Attribute classAttribute = instances.classAttribute();
    // Collect the class values known to the header
    HashSet<String> classValues = new HashSet<String>();
    Enumeration<?> classEnum = classAttribute.enumerateValues();
    while (classEnum.hasMoreElements()) {
        classValues.add((String) classEnum.nextElement());
    }
    // Add each document as an instance
    for (int i = 0; i < wordVectors.size(); i++) {
        if (!classValues.contains(labels.get(i))) {
            /*
             * TODO: Handle unseen labels in a better way, as this will
             * over-estimate classification performance. Adding new values
             * to class attributes requires recreation of the header and
             * copying of all data to a new Instances. See:
             * http://comments.gmane.org/gmane.comp.ai.weka/7806
             */
            logger.error("New class label found in evaluation set. Discarding value.");
            continue;
        }
        Instance item = new DenseInstance(instances.numAttributes());
        item.setDataset(instances);
        // Words
        for (String word : wordVectors.get(i)) {
            Attribute attribute = instances.attribute(word);
            if (attribute != null) {
                item.setValue(attribute, 1);
            }
        }
        item.setValue(classAttribute, labels.get(i));
        item.replaceMissingValues(missingVal);
        instances.add(item);
    }
    return instances;
}
/**
 * Saves a single document, and its human labels, by delegating to
 * saveDocumentsToDatabase with a one-element list.
 *
 * @param item the document to save
 */
public static void saveDocumentToDatabase(Document item) {
    List<Document> singleDocument = new ArrayList<Document>(1);
    singleDocument.add(item);
    saveDocumentsToDatabase(singleDocument);
}
/**
 * Throttle for the periodic log lines. The first call starts the interval clock
 * and returns {@code true}. Later calls return {@code true} only once more than
 * {@code LOG_INTERVAL} milliseconds have passed since {@code lastSaveTime}.
 * The callers in this class reset {@code lastSaveTime} after logging.
 *
 * @return whether a throttled log line may be written now
 */
public static boolean canLog() {
    final long now = System.currentTimeMillis();
    if (lastSaveTime == 0) {
        lastSaveTime = now;
        return true;
    }
    return now - lastSaveTime > LOG_INTERVAL;
}
/**
 * Persists new documents through the remote task manager and stores each
 * generated ID back on its {@link Document}. Logs a throttled count of saved
 * documents, then saves the documents' human labels via saveHumanLabels.
 *
 * <p>A {@code null} or {@code -1} generated ID is logged as an error. Any
 * exception is logged and aborts the remaining saves in this batch; human labels
 * are still processed afterwards.
 *
 * @param items the documents to save
 */
public static void saveDocumentsToDatabase(List<Document> items) {
    try {
        for (Document item : items) {
            TaggerDocument doc = Document.fromDocumentToTaggerDocument(item);
            Long docID = taskManager.saveNewTask(TaggerDocument.toDocumentDTO(doc), doc.getCrisisID());
            ++saveNewDocumentsCount;
            // Guard against a null ID: unboxing it would throw and abort the whole batch.
            if (docID != null && docID.longValue() != -1) {
                // Update document with auto generated Doc
                item.setDocumentID(docID);
            } else {
                logger.error("Something went wrong in saving document: " + item.getDocumentID() + ", for collection = " + item.getCrisisCode());
            }
        }
        if (canLog()) {
            logger.info("In interval " + new Date(lastSaveTime) + " - " + new Date() + ", save NEW documents count = " + saveNewDocumentsCount);
            lastSaveTime = System.currentTimeMillis();
            saveNewDocumentsCount = 0;
        }
    } catch (Exception e) {
        logger.error("Exception when attempting to write Document to database", e);
    }
    saveHumanLabels(items);
}
/**
 * Saves human-provided labels for a document and sends a notification via a
 * redis queue to the model controller.
 *
 * <p>Each label is stored through {@code taskManager.saveDocumentNominalLabel}.
 * Documents without human labels are skipped. Labeled documents that have no
 * document ID are skipped with a warning rather than letting a
 * NullPointerException abort the whole batch. When no label was saved, no
 * notification is sent. Exceptions are logged, not thrown.
 *
 * @param documents A list of human-annotated documents.
 */
static void saveHumanLabels(List<Document> documents) {
    try {
        ArrayList<Integer> docsWithLabels = new ArrayList<>();
        ArrayList<TrainingSampleNotification> notifications = new ArrayList<>();
        int rows = 0;
        for (Document d : documents) {
            List<NominalLabelBC> labels = d.getHumanLabels(NominalLabelBC.class);
            if (labels.isEmpty()) {
                continue; // Skip document if it has no human-provided labels
            }
            if (d.getDocumentID() == null) {
                logger.warn("Skipping human labels of a document without document ID, for collection = " + d.getCrisisCode());
                continue;
            }
            docsWithLabels.add(d.getDocumentID().intValue());
            // default labeler : 'System' user (userID = 1 in DB)
            Long userID = d.getUserID() != null ? d.getUserID() : 1L;
            for (NominalLabelBC label : labels) {
                DocumentNominalLabelIdDTO idDTO = new DocumentNominalLabelIdDTO(d.getDocumentID(), Long.valueOf(label.getNominalLabelID()), userID);
                DocumentNominalLabelDTO dto = new DocumentNominalLabelDTO();
                dto.setIdDTO(idDTO);
                // NOTE(review): this shares the throttle state with saveDocumentsToDatabase
                // and also resets its saved-documents counter.
                if (canLog()) {
                    logger.info("Attempting to save LABELED document: " + dto.getIdDTO().getDocumentId() + " with nominal labelID=" + dto.getIdDTO().getNominalLabelId() + ", for collection = " + d.getCrisisCode() + ", userID = " + dto.getIdDTO().getUserId());
                    lastSaveTime = System.currentTimeMillis();
                    saveNewDocumentsCount = 0;
                }
                taskManager.saveDocumentNominalLabel(dto);
                rows++;
            }
            notifications.add(new TrainingSampleNotification(d.getCrisisID().intValue(), getAttributeIDs(labels)));
        }
        if (rows == 0) {
            return;
        }
        logger.info("Saved " + rows + " human labels for " + docsWithLabels.size()
                + " documents");
        sendNewLabeledDocumentNotification(notifications);
    } catch (Exception e) {
        logger.error("Exception when attempting to insert new document labels", e);
    }
}
/**
 * Collects the distinct attribute IDs referenced by the given labels.
 *
 * @param labels labels to inspect
 * @return a set of attribute IDs (duplicates removed)
 */
private static Collection<Integer> getAttributeIDs(List<NominalLabelBC> labels) {
    HashSet<Integer> distinctAttributeIds = new HashSet<Integer>();
    for (NominalLabelBC nominalLabel : labels) {
        Integer attributeId = nominalLabel.getAttributeID();
        distinctAttributeIds.add(attributeId);
    }
    return distinctAttributeIds;
}
/**
 * Pushes one JSON message per notification, of the form
 * {@code { "crisis_id": <id>, "attributes": [<ids>] }}, onto the Redis
 * training-sample-info queue.
 *
 * <p>If no Redis connection is available, an error is logged and nothing is
 * sent. The connection is always returned to the pool, even if a push fails.
 *
 * @param notifications the notifications to send
 */
public static void sendNewLabeledDocumentNotification(
        Collection<TrainingSampleNotification> notifications) {
    Jedis redis = DataStore.getJedisConnection();
    if (redis == null) {
        logger.error("Cannot send training sample notifications: no Redis connection available");
        return;
    }
    String queue = TaggerConfigurator.getInstance().getProperty(
            TaggerConfigurationProperty.REDIS_TRAINING_SAMPLE_INFO_QUEUE);
    try {
        for (TrainingSampleNotification n : notifications) {
            String message = "{ \"crisis_id\": " + n.crisisID
                    + ", \"attributes\": [" + Helpers.join(n.attributeIDs, ",")
                    + "] }";
            redis.rpush(queue, message);
        }
    } finally {
        DataStore.close(redis);
    }
}
/**
 * Loads every active model family whose collection has the classifier enabled.
 * Each family is returned with its nominal attribute, that attribute's labels,
 * the current model ID (if any), and a training-example count.
 *
 * <p>The query returns one row per (crisis, attribute, label), and consecutive
 * rows with the same modelFamilyID are folded into one {@link ModelFamilyEC}.
 * The explicit ORDER BY guarantees that contiguity; MySQL 8 no longer sorts
 * implicitly by GROUP BY columns. A family's training-example count is the sum of
 * its rows' {@code labeledItemCount} values.
 *
 * @return the active model families; empty if no connection is available or
 *         the query fails
 */
public static ArrayList<ModelFamilyEC> getActiveModels() {
    ArrayList<ModelFamilyEC> modelFamilies = new ArrayList<>();
    Connection conn = null;
    PreparedStatement sql = null;
    ResultSet result = null;
    try {
        conn = getMySqlConnection();
        if (conn == null) {
            logger.error("Exception when getting model state: no MySQL connection available");
            return modelFamilies;
        }
        // Session setting; the statement is closed even if it fails.
        Statement sessionSetup = conn.createStatement();
        try {
            sessionSetup.execute("SET group_concat_max_len = 10240");
        } finally {
            close(sessionSetup);
        }
        sql = conn
                .prepareStatement(
                        "SELECT \n"
                        + " fam.modelFamilyID, \n"
                        + " fam.crisisID, \n"
                        + " col.code AS crisisCode, \n"
                        + " col.name AS crisisName, \n"
                        + " fam.nominalAttributeID, \n"
                        + " attr.code AS nominalAttributeCode, \n"
                        + " attr.name AS nominalAttributeName, \n"
                        + " attr.description AS nominalAttributeDescription, \n"
                        + " mdl.modelID, \n"
                        + " lbl.nominalLabelID,\n"
                        + " lbl.nominalLabelCode,\n"
                        + " lbl.name as nominalLabelName,\n"
                        + " lbl.description as nominLabelDescription, \n"
                        + " COUNT(DISTINCT dnl.documentID) AS labeledItemCount\n"
                        + "FROM model_family fam \n"
                        + "LEFT JOIN model mdl on mdl.modelFamilyID = fam.modelFamilyID \n"
                        + "JOIN collection col on col.id = fam.crisisID and col.classifier_enabled = 1 \n"
                        + "JOIN nominal_attribute attr ON attr.nominalAttributeID = fam.nominalAttributeID \n"
                        + "JOIN nominal_label lbl ON lbl.nominalAttributeID = fam.nominalAttributeID \n"
                        + "LEFT JOIN document doc ON doc.crisisID=fam.crisisID \n"
                        + "LEFT JOIN document_nominal_label dnl ON dnl.documentID=doc.documentID AND dnl.nominalLabelID=lbl.nominalLabelID \n"
                        + "WHERE fam.isActive AND (mdl.modelID IS NULL OR mdl.isCurrentModel) \n"
                        + "GROUP BY crisisID, nominalAttributeID, nominalLabelID \n"
                        + "ORDER BY crisisID, nominalAttributeID, nominalLabelID ");
        result = sql.executeQuery();
        ModelFamilyEC family = null;
        NominalAttributeEC attribute = null;
        HashMap<ModelFamilyEC, Integer> familyLabelCount = new HashMap<>();
        while (result.next()) {
            if (family == null || family.getModelFamilyID() != result.getInt("modelFamilyID")) {
                // First row of a new family: create its attribute and the family itself
                attribute = new NominalAttributeEC();
                attribute.setNominalAttributeID(result.getInt("nominalAttributeID"));
                attribute.setCode(result.getString("nominalAttributeCode"));
                attribute.setDescription(result.getString("nominalAttributeDescription"));
                attribute.setName(result.getString("nominalAttributeName"));
                family = new ModelFamilyEC();
                family.setCrisisID(result.getInt("crisisID"));
                int tmpModelID = result.getInt("modelID");
                if (!result.wasNull()) {
                    family.setCurrentModelID(tmpModelID);
                }
                family.setIsActive(true);
                family.setModelFamilyID(result.getInt("modelFamilyID"));
                family.setNominalAttribute(attribute);
                familyLabelCount.put(family, 0);
                modelFamilies.add(family);
            }
            // Every row contributes one label to the current family's attribute
            NominalLabelEC label = new NominalLabelEC();
            label.setDescription(result.getString("nominLabelDescription"));
            label.setName(result.getString("nominalLabelName"));
            label.setNominalAttribute(attribute);
            label.setNominalLabelCode(result.getString("nominalLabelCode"));
            label.setNominalLabelID(result.getInt("nominalLabelID"));
            attribute.addNominalLabel(label);
            int count = familyLabelCount.get(family);
            familyLabelCount.put(family, count + result.getInt("labeledItemCount"));
        }
        // Sum training sample counts per attribute
        for (Map.Entry<ModelFamilyEC, Integer> entry : familyLabelCount.entrySet()) {
            entry.getKey().setTrainingExampleCount(entry.getValue());
        }
    } catch (SQLException e) {
        logger.error("Exception when getting model state", e);
    } finally {
        close(result);
        close(sql);
        close(conn);
    }
    return modelFamilies;
}
/**
 * Refreshes the caller-owned model-family maps from the database.
 *
 * <p>For every active family, the method creates the family entry if it is
 * missing. It attaches the family's nominal attribute from the {@code attLabels}
 * cache, reloading the cache when the attribute or a label is unknown. Finally it
 * sets the training-example count to the sum of the per-label
 * {@code labeledItemCount} values.
 *
 * <p>Rows whose attribute cannot be resolved even after reloading are skipped
 * with a warning; an NPE there would otherwise abort the whole refresh. Errors
 * are logged, not thrown.
 *
 * @param modelFamilies crisisID -> (nominalAttributeID -> family); updated in place
 * @param modelFamiliesByCode crisisID -> (attribute code -> family); updated in place
 */
public static void getActiveModelsDocCount(HashMap<Integer, HashMap<Integer, ModelFamilyEC>> modelFamilies, HashMap<Integer,
        HashMap<String, ModelFamilyEC>> modelFamiliesByCode) {
    Connection conn = null;
    PreparedStatement sql = null;
    ResultSet result = null;
    try {
        conn = getMySqlConnection();
        if (conn == null) {
            logger.error("Exception when getting model state :: no MySQL connection available");
            return;
        }
        sql = conn
                .prepareStatement(
                        "SELECT \n"
                        + " fam.modelFamilyID, \n"
                        + " fam.crisisID, \n"
                        + " fam.nominalAttributeID, \n"
                        + " mdl.modelID, \n"
                        + " lbl.nominalLabelID,\n"
                        + " COUNT(DISTINCT dnl.documentID) AS labeledItemCount\n"
                        + "FROM model_family fam \n"
                        + "LEFT JOIN model mdl on mdl.modelFamilyID = fam.modelFamilyID \n"
                        + "JOIN nominal_label lbl ON lbl.nominalAttributeID = fam.nominalAttributeID \n"
                        + "LEFT JOIN document doc ON doc.crisisID=fam.crisisID \n"
                        + "LEFT JOIN document_nominal_label dnl ON dnl.documentID=doc.documentID AND dnl.nominalLabelID=lbl.nominalLabelID \n"
                        + "WHERE fam.isActive AND (mdl.modelID IS NULL OR mdl.isCurrentModel) \n"
                        + "GROUP BY crisisID, nominalAttributeID, nominalLabelID ");
        result = sql.executeQuery();
        HashMap<ModelFamilyEC, Integer> familyLabelCount = new HashMap<>();
        while (result.next()) {
            Integer crisisID = result.getInt("crisisID");
            Integer attributeID = result.getInt("nominalAttributeID");
            Integer nominalLabelID = result.getInt("nominalLabelID");
            if (!modelFamilies.containsKey(crisisID)) {
                modelFamilies.put(crisisID, new HashMap<Integer, ModelFamilyEC>());
                modelFamiliesByCode.put(crisisID, new HashMap<String, ModelFamilyEC>());
            }
            ModelFamilyEC family = modelFamilies.get(crisisID).get(attributeID);
            if (family == null) {
                // create model family
                family = new ModelFamilyEC();
                family.setCrisisID(crisisID);
                int tmpModelID = result.getInt("modelID");
                if (!result.wasNull()) {
                    family.setCurrentModelID(tmpModelID);
                }
                family.setIsActive(true);
                family.setModelFamilyID(result.getInt("modelFamilyID"));
            }
            NominalAttributeEC attribute = family.getNominalAttribute();
            if (attribute == null) {
                if (attLabels.containsKey(attributeID)) {
                    attribute = attLabels.get(attributeID);
                    family.setNominalAttribute(attribute);
                } else {
                    synchronized (attLabels) {
                        // Reload the attribute/label cache and try again
                        getAttributesLabels();
                        if (attLabels.containsKey(attributeID)) {
                            attribute = attLabels.get(attributeID);
                            family.setNominalAttribute(attribute);
                        }
                    }
                }
            }
            if (attribute == null) {
                logger.warn("Unknown nominal attribute " + attributeID + " for crisis " + crisisID + "; skipping row");
                continue;
            }
            NominalLabelEC label = attribute.getNominalLabel(nominalLabelID);
            if (label == null) {
                synchronized (attLabels) {
                    updateLabels(attributeID);
                    // Keep the previous attribute if the reload did not produce one.
                    // NOTE(review): the family's own attribute reference is not updated here.
                    NominalAttributeEC refreshed = attLabels.get(attributeID);
                    if (refreshed != null) {
                        attribute = refreshed;
                    }
                }
            }
            if (familyLabelCount.get(family) == null) {
                familyLabelCount.put(family, 0);
            }
            modelFamilies.get(crisisID).put(attributeID, family);
            modelFamiliesByCode.get(crisisID).put(attribute.getCode(), family);
            int count = familyLabelCount.get(family);
            familyLabelCount.put(family, count + result.getInt("labeledItemCount"));
        }
        // Sum training sample counts per attribute
        for (Map.Entry<ModelFamilyEC, Integer> entry : familyLabelCount.entrySet()) {
            entry.getKey().setTrainingExampleCount(entry.getValue());
        }
    } catch (SQLException e) {
        logger.error("Exception when getting model state ::", e);
    } catch (Exception e) {
        logger.error("Exception in getActiveModelsDocCount ::", e);
    } finally {
        close(result);
        close(sql);
        close(conn);
    }
}
/**
 * Deletes the row of the given model from the {@code model} table. Failures are
 * logged together with their cause and are not rethrown.
 *
 * @param modelID the ID of the model to delete
 */
public static void deleteModel(int modelID) {
    Connection conn = null;
    PreparedStatement sql = null;
    try {
        conn = getMySqlConnection();
        if (conn == null) {
            logger.error("Exception while deleting model " + modelID + ": no MySQL connection available");
            return;
        }
        sql = conn.prepareStatement("DELETE FROM model WHERE modelID = ?");
        sql.setInt(1, modelID);
        sql.executeUpdate();
    } catch (SQLException e) {
        logger.error("Exception while deleting model " + modelID, e);
    } finally {
        close(sql);
        close(conn);
    }
}
/**
 * Maps the code of every collection with the classifier enabled to its ID.
 *
 * @return crisis code -> crisis ID; empty if no connection is available or the
 *         query fails (errors are logged)
 */
public static HashMap<String, Integer> getCrisisIDs() {
    HashMap<String, Integer> crisisIDs = new HashMap<String, Integer>();
    Connection conn = null;
    PreparedStatement sql = null;
    ResultSet result = null;
    try {
        conn = getMySqlConnection();
        if (conn == null) {
            logger.error("Exception when getting crisis IDs: no MySQL connection available");
            return crisisIDs;
        }
        sql = conn.prepareStatement("select id, code from collection where classifier_enabled = 1;");
        result = sql.executeQuery();
        while (result.next()) {
            crisisIDs.put(result.getString("code"), result.getInt("id"));
        }
    } catch (SQLException e) {
        logger.error("Exception when getting crisis IDs", e);
    } finally {
        close(result);
        close(sql);
        close(conn);
    }
    return crisisIDs;
}
/**
 * Asks the remote task manager to truncate the labeling-task buffer of a crisis,
 * passing {@code maxLength} and an error margin of 0, and logs the returned
 * deleted-document count.
 *
 * @param crisisID the crisis whose buffer is truncated; must be non-negative
 * @param maxLength the length passed to the task manager; must be non-negative
 * @throws IllegalArgumentException if either argument is negative
 * @throws IllegalStateException if {@link #taskManager} has not been initialized
 */
public static void truncateLabelingTaskBufferForCrisis(int crisisID, int maxLength) {
    if (maxLength < 0 || crisisID < 0) {
        String message = "Cannot truncate the labeling task buffer - negative parameter(s): crisisID="
                + crisisID + ", maxLength=" + maxLength;
        logger.error(message);
        throw new IllegalArgumentException(message);
    }
    if (taskManager == null) {
        String message = "Cannot truncate the labeling task buffer - taskManager is not initialized";
        logger.error(message);
        throw new IllegalStateException(message);
    }
    final int ERROR_MARGIN = 0; // if less than this, then skip delete
    int deleteCount = taskManager.truncateLabelingTaskBufferForCrisis(crisisID, maxLength, ERROR_MARGIN);
    logger.info("Truncation results for crisis " + crisisID + ", deleted doc count = " + deleteCount);
}
/**
 * Stores a newly trained model and its per-label performance, then marks it as
 * the current model of its family. The flag is cleared on the family's
 * previously current model.
 *
 * <p>Metrics are formatted with {@link NumberFormat} for {@link Locale#US}, so
 * the decimal separator is always '.' (with NumberFormat's default of at most
 * three fraction digits). Grouping is disabled so no value can be rendered with
 * a ',' separator inside the generated SQL. If no generated ID is returned for
 * the model row, the label-performance and current-model steps are skipped.
 *
 * @param crisisID the crisis the model belongs to
 * @param nominalAttributeID the nominal attribute the model classifies
 * @param model the trained model and its evaluation metrics
 * @return the new model's ID, or {@link #MODEL_ID_ERROR} if it could not be
 *         obtained; failures are logged, not thrown
 */
public static int saveModelToDatabase(int crisisID, int nominalAttributeID,
        Model model) {
    int modelID = MODEL_ID_ERROR;
    Connection conn = null;
    PreparedStatement modelInsert = null, mfUpdate = null;
    PreparedStatement modelLabelPerfInsert = null;
    ResultSet result = null;
    NumberFormat format = NumberFormat.getNumberInstance(Locale.US);
    format.setGroupingUsed(false);
    String selectModelFamilyID = "(SELECT modelFamilyID FROM model_family WHERE crisisID = "
            + crisisID + " and nominalAttributeID = " + nominalAttributeID + ")";
    try {
        conn = getMySqlConnection();
        if (conn == null) {
            logger.error("Exception while saving model to database: no MySQL connection available");
            return modelID;
        }
        logger.debug("AUC: " + model.getMeanAuc() + ", AUC formatted: " + format.format(model.getMeanAuc()));
        // Insert the model object (not yet the current model)
        String modelInsertSql =
                "INSERT INTO model (modelFamilyID, avgPrecision, avgRecall, avgAuc, isCurrentModel, trainingCount, trainingTime) VALUES "
                + "(" + selectModelFamilyID + ", "
                + format.format(model.getMeanPrecision())
                + ", "
                + format.format(model.getMeanRecall())
                + ", "
                + format.format(model.getMeanAuc())
                + ",false,"
                + model.getTrainingSampleCount()
                + ", UTC_TIMESTAMP())";
        modelInsert = conn.prepareStatement(modelInsertSql, Statement.RETURN_GENERATED_KEYS);
        modelInsert.executeUpdate();
        result = modelInsert.getGeneratedKeys();
        if (result != null && result.next()) {
            modelID = result.getInt(1);
        }
        if (modelID == MODEL_ID_ERROR) {
            logger.error("No model ID was generated for crisis " + crisisID + ", attribute "
                    + nominalAttributeID + "; skipping label performance and current-model update");
            return modelID;
        }
        logger.info("Inserted a new model with model ID " + modelID);
        // Insert per-label classification performance of this model
        List<ModelNominalLabelPerformance> labelPerformaceList = model.getLabelPerformanceList();
        String perfInsertSql = "INSERT INTO model_nominal_label (`modelID`, `nominalLabelID`, `labelPrecision`, `labelRecall`, `labelAuc`, `classifiedDocumentCount`) "
                + " VALUES (?,?,?,?,?,0);";
        modelLabelPerfInsert = conn.prepareStatement(perfInsertSql);
        for (ModelNominalLabelPerformance perf : labelPerformaceList) {
            modelLabelPerfInsert.setInt(1, modelID);
            modelLabelPerfInsert.setInt(2, perf.getNominalLabelID());
            modelLabelPerfInsert.setString(3, format.format(perf.getPrecision()));
            modelLabelPerfInsert.setString(4, format.format(perf.getRecall()));
            modelLabelPerfInsert.setString(5, format.format(perf.getAuc()));
            modelLabelPerfInsert.executeUpdate();
        }
        // Make the new model the only current model of its family
        mfUpdate = conn
                .prepareStatement("UPDATE model SET isCurrentModel = (modelID = " + modelID + ") "
                        + "WHERE modelID = " + modelID + " OR (isCurrentModel AND modelFamilyID = "
                        + selectModelFamilyID + ")");
        mfUpdate.executeUpdate();
    } catch (SQLException e) {
        logger.error("Exception while saving model to database", e);
    } finally {
        close(modelLabelPerfInsert);
        close(result);
        close(modelInsert);
        close(mfUpdate);
        close(conn);
    }
    return modelID;
}
/**
 * Computes, for every nominal label, the added value of getting one more
 * training sample of that label, calculated as 1-p(label).
 *
 * <p>p(label) is this label's share of all {@code document_nominal_label} rows
 * of its attribute. When the attribute has no such rows at all, the weight is
 * 0.5 (from the query's {@code coalesce(..., 0.5)}).
 *
 * @return attributeID -> (labelID -> weight); empty if no connection is
 *         available or the query fails (errors are logged)
 */
public static HashMap<Integer, HashMap<Integer, Double>> getNominalLabelTrainingValues() {
    //<attributeID,<labelID, trainingValueWeight>>
    HashMap<Integer, HashMap<Integer, Double>> scores = new HashMap<>();
    Connection conn = null;
    PreparedStatement sql = null;
    ResultSet result = null;
    try {
        conn = getMySqlConnection();
        if (conn == null) {
            logger.error("Exception when getting nominal label training values: no MySQL connection available");
            return scores;
        }
        sql = conn
                .prepareStatement("select nl.nominalAttributeID, nl.nominalLabelID, 1-coalesce(count(dnl.nominalLabelID)/totalCount, 0.5) as weight \n"
                        + "from nominal_label nl \n"
                        + "left join document_nominal_label dnl on dnl.nominalLabelID=nl.nominalLabelID \n"
                        + "left join (select nominalAttributeID, greatest(count(1),1) as totalCount \n"
                        + " from document_nominal_label natural join nominal_label group by 1) lc on lc.nominalAttributeID=nl.nominalAttributeID \n"
                        + "group by 1,2");
        result = sql.executeQuery();
        while (result.next()) {
            int attrID = result.getInt("nominalAttributeID");
            int labelID = result.getInt("nominalLabelID");
            double weight = result.getDouble("weight");
            if (!scores.containsKey(attrID)) {
                scores.put(attrID, new HashMap<Integer, Double>());
            }
            scores.get(attrID).put(labelID, weight);
        }
    } catch (SQLException e) {
        logger.error("Exception when getting nominal label training values", e);
    } finally {
        close(result);
        close(sql);
        close(conn);
    }
    return scores;
}
/**
 * Adds newly classified document counts to the running totals stored in
 * {@code model_nominal_label.classifiedDocumentCount}.
 * <p>
 * Each (modelID, nominalLabelID) pair in {@code data} results in one UPDATE
 * that increments the stored count by the given amount. A pair with no matching
 * row is unaffected, because its UPDATE matches zero rows. This method does no
 * explicit transaction handling. Database errors are logged, not propagated.
 *
 * @param data map of modelID to (nominalLabelID to number of newly classified
 *             documents). {@code null} or empty input is a no-op. Entries with a
 *             {@code null} model ID, label map, label ID or count are skipped
 *             and logged.
 */
public static void saveClassifiedDocumentCounts(HashMap<Integer, HashMap<Integer, Integer>> data) {
    if (data == null || data.isEmpty()) {
        // Nothing to persist; avoid acquiring a database connection for a no-op.
        return;
    }
    String updateQuery = "UPDATE model_nominal_label SET classifiedDocumentCount = classifiedDocumentCount + ? "
            + "WHERE modelID = ? AND nominalLabelID = ?";
    Connection conn = null;
    PreparedStatement updateStatement = null;
    try {
        conn = getMySqlConnection();
        // One prepared statement is reused for every (model, label) pair.
        updateStatement = conn.prepareStatement(updateQuery);
        for (Map.Entry<Integer, HashMap<Integer, Integer>> modelDocCounts : data.entrySet()) {
            Integer modelID = modelDocCounts.getKey();
            HashMap<Integer, Integer> labelCounts = modelDocCounts.getValue();
            if (modelID == null || labelCounts == null) {
                logger.warn("Skipping classified document counts with missing modelID or label map: " + modelDocCounts);
                continue;
            }
            for (Map.Entry<Integer, Integer> labelDocCount : labelCounts.entrySet()) {
                Integer labelID = labelDocCount.getKey();
                Integer docCount = labelDocCount.getValue();
                if (labelID == null || docCount == null) {
                    // Unboxing null would throw an uncaught NullPointerException mid-way.
                    logger.warn("Skipping classified document count with missing labelID or count for modelID "
                            + modelID + ": " + labelDocCount);
                    continue;
                }
                updateStatement.setInt(1, docCount);
                updateStatement.setInt(2, modelID);
                updateStatement.setInt(3, labelID);
                updateStatement.executeUpdate();
            }
        }
    } catch (SQLException e) {
        logger.error("Exception when attempting to write ClassifiedDocumentCount to database : " + data, e);
    } finally {
        close(updateStatement);
        close(conn);
    }
}
/**
 * Reloads the cached nominal labels of one nominal attribute from the
 * {@code nominal_label} table.
 * <p>
 * The attribute must already be present in {@code attLabels}. If it is not, or
 * {@code attributeID} is {@code null}, a warning is logged and nothing changes.
 * All rows are read first. Only after that is the attribute's label list reset
 * (exactly once) and repopulated, so a database error leaves the previously
 * cached labels untouched. Database errors are logged, not propagated.
 *
 * @param attributeID ID of the nominal attribute whose labels should be reloaded
 */
private static void updateLabels(Integer attributeID){
    if (attributeID == null) {
        logger.warn("Cannot update nominal labels: attributeID is null");
        return;
    }
    NominalAttributeEC attribute = attLabels.get(attributeID);
    if (attribute == null) {
        logger.warn("Cannot update nominal labels: attribute " + attributeID + " is not cached");
        return;
    }
    Connection conn = null;
    PreparedStatement selectStatement = null;
    ResultSet result = null;
    String selectQuery = "SELECT nominalLabelID,l.nominalLabelCode,l.name as nominalLabelName,"
            + "l.description as nominLabelDescription FROM nominal_label l where l.nominalAttributeID = ?";
    try {
        conn = getMySqlConnection();
        selectStatement = conn.prepareStatement(selectQuery);
        selectStatement.setInt(1, attributeID);
        result = selectStatement.executeQuery();
        // Build the complete new label set before touching the cached attribute.
        List<NominalLabelEC> freshLabels = new ArrayList<>();
        HashSet<Integer> seenLabelIDs = new HashSet<>();
        while (result.next()) {
            int labelID = result.getInt("nominalLabelID");
            if (!seenLabelIDs.add(labelID)) {
                continue; // defensive: never add the same label twice
            }
            NominalLabelEC label = new NominalLabelEC();
            label.setDescription(result.getString("nominLabelDescription"));
            label.setName(result.getString("nominalLabelName"));
            label.setNominalAttribute(attribute);
            label.setNominalLabelCode(result.getString("nominalLabelCode"));
            label.setNominalLabelID(labelID);
            freshLabels.add(label);
        }
        // Reset exactly once so every label returned by the query is retained.
        attribute.resetNominalLabels();
        for (NominalLabelEC label : freshLabels) {
            attribute.addNominalLabel(label);
        }
    } catch (SQLException e) {
        logger.error("Exception while updating nominal labels ::", e);
    } finally {
        close(result);
        close(selectStatement);
        close(conn);
    }
}
/**
 * Loads every nominal attribute together with its nominal labels into the
 * {@code attLabels} cache. Intended to be re-invoked periodically.
 * <p>
 * An attribute not yet cached is created from the row's attribute columns. A
 * label is appended to its attribute only if the attribute does not already
 * hold a label with the same ID. Cached attributes and labels that already
 * exist are never modified or removed here. Attributes without any label are
 * not returned by the inner join and are therefore not loaded. Database errors
 * are logged, not propagated.
 */
public static void getAttributesLabels() {
    final String attributeLabelQuery = "SELECT a.nominalAttributeID, a.code as nominalAttributeCode,"
            + " a.description as nominalAttributeDescription, a.name as nominalAttributeName, l.nominalLabelID, l.nominalLabelCode,l.name as nominalLabelName,"
            + "l.description as nominLabelDescription FROM nominal_attribute a join nominal_label l on a.nominalAttributeID = l.nominalAttributeID";
    Connection connection = null;
    PreparedStatement statement = null;
    ResultSet rows = null;
    try {
        connection = getMySqlConnection();
        statement = connection.prepareStatement(attributeLabelQuery);
        rows = statement.executeQuery();
        while (rows.next()) {
            final int attributeId = rows.getInt("nominalAttributeID");
            final int labelId = rows.getInt("nominalLabelID");

            NominalAttributeEC owner;
            if (attLabels.containsKey(attributeId)) {
                owner = attLabels.get(attributeId);
            } else {
                // First time this attribute is seen: build and cache it.
                owner = new NominalAttributeEC();
                owner.setNominalAttributeID(attributeId);
                owner.setCode(rows.getString("nominalAttributeCode"));
                owner.setDescription(rows.getString("nominalAttributeDescription"));
                owner.setName(rows.getString("nominalAttributeName"));
                attLabels.put(attributeId, owner);
            }

            if (owner.getNominalLabel(labelId) != null) {
                continue; // label already cached for this attribute
            }
            NominalLabelEC freshLabel = new NominalLabelEC();
            freshLabel.setDescription(rows.getString("nominLabelDescription"));
            freshLabel.setName(rows.getString("nominalLabelName"));
            freshLabel.setNominalAttribute(owner);
            freshLabel.setNominalLabelCode(rows.getString("nominalLabelCode"));
            freshLabel.setNominalLabelID(labelId);
            owner.addNominalLabel(freshLabel);
        }
    } catch (SQLException e) {
        logger.error("Exception while creating nominal attributes ::", e);
    } finally {
        close(rows);
        close(statement);
        close(connection);
    }
}
}