package qa.qcri.aidr.predict.classification.nominal;

import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.List;

import org.apache.log4j.Logger;
import org.json.JSONArray;
import org.json.JSONObject;

import redis.clients.jedis.Jedis;

import qa.qcri.aidr.predict.DataStore;
import qa.qcri.aidr.predict.classification.nominal.CrisisAttributePair;
import qa.qcri.aidr.predict.common.Event;
import qa.qcri.aidr.predict.common.TaggerConfigurationProperty;
import qa.qcri.aidr.predict.common.TaggerConfigurator;
import qa.qcri.aidr.predict.dbentities.ModelFamilyEC;

/**
 * ModelRetrainTrigger is notified of new training samples through a Redis
 * queue. When a sufficient number of samples has arrived, it fires an event
 * that triggers a rebuild of the relevant model.
 *
 * @author jrogstadius
 */
class ModelRetrainTrigger implements Runnable {

    private static Logger logger = Logger.getLogger(ModelRetrainTrigger.class);

    public Event<CrisisAttributePair> onRetrain = new Event<CrisisAttributePair>();

    // Threshold in milliseconds; also bounds how long one queue-draining pass may run.
    int timeThreshold = 60000; // 1000*60*10; TODO: the re-training threshold should be dynamic

    // crisisID -> (attributeID -> number of new samples since the last rebuild)
    HashMap<Integer, HashMap<Integer, Integer>> newSampleCounts = new HashMap<Integer, HashMap<Integer, Integer>>();
    // crisisID -> (attributeID -> timestamp of the last rebuild)
    HashMap<Integer, HashMap<Integer, Long>> rebuildTimestamps = new HashMap<Integer, HashMap<Integer, Long>>();
    // crisisID -> attribute IDs for which a retrain was explicitly requested
    HashMap<Integer, List<Integer>> forceRetrains = new HashMap<Integer, List<Integer>>();

    public void run() {
        while (true) {
            parseTrainingSamples();
            checkRetrainThresholds();
            forceRetrains.clear();
            try {
                Thread.sleep(10000); // sleep for 10s before the next attempt
            } catch (InterruptedException e) {
                logger.error("Exception: ", e);
            }
        }
    }

    public void initialize(ArrayList<ModelFamilyEC> modelStates) {
        for (ModelFamilyEC m : modelStates) {
            increment(
                    m.getCrisisID(),
                    new int[] { m.getNominalAttribute().getNominalAttributeID() },
                    m.getTrainingExampleCount()
                            % Integer.parseInt(TaggerConfigurator
                                    .getInstance()
                                    .getProperty(
                                            TaggerConfigurationProperty.SAMPLE_COUNT_THRESHOLD)));
        }
    }

    private int parseTrainingSamples() {
        Jedis redis = DataStore.getJedisConnection();
        int newSampleCount = 0;
        String line = null;
        long consumptionStart = new Date().getTime();
        try {
            while ((line = getInfoMessage(redis)) != null
                    && new Date().getTime() - consumptionStart < timeThreshold) {
                logger.info("A training sample has arrived");

                // Parse the notification containing the crisis ID and attribute IDs
                JSONObject obj = new JSONObject(line);
                int crisisID = obj.getInt("crisis_id");
                JSONArray attrArr = obj.getJSONArray("attributes");
                int[] attributeIDs = new int[attrArr.length()];
                for (int i = 0; i < attrArr.length(); i++) {
                    attributeIDs[i] = attrArr.getInt(i);
                }

                increment(crisisID, attributeIDs, 1);

                if (obj.has("force_retrain") && obj.getBoolean("force_retrain")) {
                    if (!forceRetrains.containsKey(crisisID)) {
                        forceRetrains.put(crisisID, new ArrayList<Integer>());
                    }
                    for (int attributeID : attributeIDs) {
                        forceRetrains.get(crisisID).add(attributeID);
                    }
                }
                newSampleCount++;
            }
        } catch (Exception e) {
            logger.error("Exception while processing training sample message queue", e);
        } finally {
            DataStore.close(redis);
        }
        return newSampleCount;
    }

    private String getInfoMessage(Jedis redis) {
        List<String> result = redis.blpop(5, TaggerConfigurator
                .getInstance()
                .getProperty(
                        TaggerConfigurationProperty.REDIS_TRAINING_SAMPLE_INFO_QUEUE));
        if (result == null || result.size() != 2) {
            return null; // result is null on timeout; otherwise its size is always 2
        }
        return result.get(1);
    }
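
    // Illustrative sketch of the JSON notification format that parseTrainingSamples()
    // expects to find on the Redis training-sample queue. The crisis and attribute IDs
    // below are hypothetical; a producer would push msg.toString() onto the queue named
    // by TaggerConfigurationProperty.REDIS_TRAINING_SAMPLE_INFO_QUEUE.
    static String buildExampleNotification() throws Exception {
        JSONObject msg = new JSONObject();
        msg.put("crisis_id", 117);                                // hypothetical crisis ID
        msg.put("attributes", new JSONArray(new int[] { 3, 8 })); // hypothetical attribute IDs
        msg.put("force_retrain", false);                          // optional: bypasses count/time thresholds
        return msg.toString();
    }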
    private void increment(int crisisID, int[] attributeIDs, int incrementValue) {
        for (int id : attributeIDs) {
            if (!newSampleCounts.containsKey(crisisID)) {
                newSampleCounts.put(crisisID, new HashMap<Integer, Integer>());
                rebuildTimestamps.put(crisisID, new HashMap<Integer, Long>());
            }
            if (!newSampleCounts.get(crisisID).containsKey(id)) {
                newSampleCounts.get(crisisID).put(id, incrementValue);
                rebuildTimestamps.get(crisisID).put(id, (long) 0);
            } else {
                newSampleCounts.get(crisisID).put(id,
                        newSampleCounts.get(crisisID).get(id) + incrementValue);
            }
        }
    }

    private void checkRetrainThresholds() {
        // For each crisis and attribute with enough new training samples, where enough
        // time has passed since the last rebuild, retrain the corresponding model.
        for (int crisisID : newSampleCounts.keySet()) {
            HashMap<Integer, Integer> eventMap = newSampleCounts.get(crisisID);
            for (int attributeID : eventMap.keySet()) {
                long now = new Date().getTime();
                if (eventMap.get(attributeID) >= Integer.parseInt(TaggerConfigurator
                        .getInstance()
                        .getProperty(TaggerConfigurationProperty.SAMPLE_COUNT_THRESHOLD))
                        && (now - rebuildTimestamps.get(crisisID).get(attributeID)) >= timeThreshold) {
                    retrain(crisisID, attributeID);
                } else if (forceRetrains.containsKey(crisisID)
                        && forceRetrains.get(crisisID).contains(attributeID)) {
                    // An explicit retrain request bypasses both thresholds.
                    retrain(crisisID, attributeID);
                }
            }
        }
    }

    private void retrain(int crisisID, int attributeID) {
        logger.info("Time to retrain model for crisis " + crisisID + ", attribute " + attributeID);
        onRetrain.fire(this, new CrisisAttributePair(crisisID, attributeID));
        newSampleCounts.get(crisisID).put(attributeID, 0);
        rebuildTimestamps.get(crisisID).put(attributeID, new Date().getTime());
    }
}
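
/*
 * Illustrative usage sketch: a pipeline component would typically load the current model
 * families, register a handler on onRetrain, and run the trigger on its own thread, roughly
 * as outlined below. The handler-registration call is an assumption, since the Event API is
 * not shown in this file.
 *
 *   ModelRetrainTrigger trigger = new ModelRetrainTrigger();
 *   trigger.initialize(modelFamilies);               // ArrayList<ModelFamilyEC> loaded elsewhere
 *   trigger.onRetrain.addListener(handler);          // rebuild the model for each fired pair (API assumed)
 *   new Thread(trigger, "model-retrain-trigger").start();
 */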