package edu.harvard.econcs.turkserver.server;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.MapMaker;
import com.google.common.collect.Multimap;
import com.google.common.collect.Multimaps;
import com.google.inject.Singleton;
import edu.harvard.econcs.turkserver.api.*;
/**
* Handles callback events for experiments
*
* TODO fix the triggered methods here to properly handle
* superclass methods with the same signature, superclass methods
* are currently added but not called at all.
*
* @author mao
*
*/
@Singleton
public class EventAnnotationManager {
protected final Logger logger = LoggerFactory.getLogger(this.getClass().getSimpleName());
final ConcurrentMap<String, Object> beans;
final Multimap<Class<?>, Object> beanClasses;
final Multimap<Class<?>, Method> starts;
final Multimap<Class<?>, Method> rounds;
final Multimap<Class<?>, Method> intervals;
final Multimap<Class<?>, Method> timeouts;
final Multimap<Class<?>, Method> connects;
final Multimap<Class<?>, Method> disconnects;
final Multimap<Class<?>, Method> broadcasts;
final Multimap<Class<?>, Method> services;
EventAnnotationManager() {
beans = new MapMaker().makeMap();
beanClasses = Multimaps.synchronizedSetMultimap(
HashMultimap.<Class<?>, Object>create());
starts = Multimaps.synchronizedListMultimap(
ArrayListMultimap.<Class<?>, Method>create());
rounds = Multimaps.synchronizedListMultimap(
ArrayListMultimap.<Class<?>, Method>create());
intervals = Multimaps.synchronizedListMultimap(
ArrayListMultimap.<Class<?>, Method>create());
timeouts = Multimaps.synchronizedListMultimap(
ArrayListMultimap.<Class<?>, Method>create());
connects = Multimaps.synchronizedListMultimap(
ArrayListMultimap.<Class<?>, Method>create());
disconnects = Multimaps.synchronizedListMultimap(
ArrayListMultimap.<Class<?>, Method>create());
broadcasts = Multimaps.synchronizedListMultimap(
ArrayListMultimap.<Class<?>, Method>create());
services = Multimaps.synchronizedListMultimap(
ArrayListMultimap.<Class<?>, Method>create());
}
/**
* Add an experiment and process events
* @param expId
* @param exp
*/
boolean processExperiment(String expId, Object exp) {
if( exp == null )
throw new RuntimeException("Refusing to make mappings for a null experiment");
ExperimentServer e = exp.getClass().getAnnotation(ExperimentServer.class);
if( e == null )
logger.warn("Class {} does not have @Experiment annotation, but trying mappings anyway", exp.getClass().toString());
if( beanClasses.get(exp.getClass()).size() > 0 ) {
beans.put(expId, exp);
beanClasses.put(exp.getClass(), exp);
return true;
}
boolean result = processCallbacks(exp.getClass());
if( result ) {
beans.put(expId, exp);
beanClasses.put(exp.getClass(), exp);
}
else {
throw new IllegalArgumentException("Passed in a class with no callbacks");
}
return result;
}
private boolean processCallbacks(Class<?> klass) {
boolean result = false;
for (Class<?> c = klass; c != Object.class; c = c.getSuperclass())
{
Method[] methods = c.getDeclaredMethods();
for (Method method : methods)
{
result |= processVoid(c, method,
starts, StartExperiment.class);
result |= processVoid(c, method,
timeouts, TimeLimit.class);
result |= processVoid(c, method,
intervals, IntervalEvent.class);
result |= processInt(c, method,
rounds, StartRound.class);
result |= processWorkerActivity(c, method,
connects, WorkerConnect.class);
result |= processWorkerActivity(c, method,
disconnects, WorkerDisconnect.class);
result |= processBroadcastMessage(c, method,
broadcasts, BroadcastMessage.class);
result |= processServiceMessage(c, method,
services, ServiceMessage.class);
}
}
return result;
}
/**
* Test all the callbacks of a class. This should process the same way as a real experiment.
* @param klass
* @return
*/
public static boolean testCallbacks(Class<?> klass) {
Multimap<Class<?>, Method> testMap = ArrayListMultimap.<Class<?>, Method>create();
boolean result = false;
for (Class<?> c = klass; c != Object.class; c = c.getSuperclass())
{
Method[] methods = c.getDeclaredMethods();
for (Method method : methods)
{
result |= processVoid(c, method,
testMap, StartExperiment.class);
result |= processVoid(c, method,
testMap, TimeLimit.class);
result |= processVoid(c, method,
testMap, IntervalEvent.class);
result |= processInt(c, method,
testMap, StartRound.class);
result |= processWorkerActivity(c, method,
testMap, WorkerConnect.class);
result |= processWorkerActivity(c, method,
testMap, WorkerDisconnect.class);
result |= processBroadcastMessage(c, method,
testMap, BroadcastMessage.class);
result |= processServiceMessage(c, method,
testMap, ServiceMessage.class);
}
}
return result;
}
private static boolean processVoid(Class<?> klass, Method method,
Multimap<Class<?>, Method> map, Class<? extends Annotation> annot) {
if( method.getAnnotation(annot) == null ) return false;
if (method.getReturnType() != Void.TYPE)
throw new RuntimeException("Invalid " + annot.toString() + " method " + method + ": it must have void return type");
if (method.getParameterTypes().length > 0)
throw new RuntimeException("Invalid " + annot.toString() + " method " + method + ": it must have no parameters");
if (Modifier.isStatic(method.getModifiers()))
throw new RuntimeException("Invalid " + annot.toString() + " method " + method + ": it must not be static");
map.put(klass, method);
return true;
}
private static boolean processInt(Class<?> klass, Method method,
Multimap<Class<?>, Method> map, Class<? extends Annotation> annot) {
if( method.getAnnotation(annot) == null ) return false;
if (method.getReturnType() != Void.TYPE)
throw new RuntimeException("Invalid " + annot.toString() + " method " + method + ": it must have void return type");
Class<?>[] types = method.getParameterTypes();
if (types.length != 1 || !int.class.isAssignableFrom(types[0]) )
throw new RuntimeException("Invalid " + annot.toString() + " method " + method + ": it must accept an int");
if (Modifier.isStatic(method.getModifiers()))
throw new RuntimeException("Invalid " + annot.toString() + " method " + method + ": it must not be static");
map.put(klass, method);
return true;
}
private static boolean processWorkerActivity(Class<?> klass, Method method,
Multimap<Class<?>, Method> map, Class<? extends Annotation> annot) {
if( method.getAnnotation(annot) == null ) return false;
if (method.getReturnType() != Void.TYPE)
throw new RuntimeException("Invalid " + annot.toString() + " method " + method + ": it must have void return type");
Class<?>[] types = method.getParameterTypes();
if (types.length != 1 || !HITWorker.class.isAssignableFrom(types[0]) )
throw new RuntimeException("Invalid " + annot.toString() + " method " + method + ": it must accept a HITWorker");
if (Modifier.isStatic(method.getModifiers()))
throw new RuntimeException("Invalid " + annot.toString() + " method " + method + ": it must not be static");
map.put(klass, method);
return true;
}
private static boolean processBroadcastMessage(Class<?> klass, Method method,
Multimap<Class<?>, Method> map, Class<? extends Annotation> annot) {
if( method.getAnnotation(annot) == null ) return false;
if (method.getReturnType() != Boolean.TYPE)
throw new RuntimeException("Invalid " + annot.toString() + " method " + method + ": it must have boolean return type");
Class<?>[] types = method.getParameterTypes();
if (types.length != 2 || !HITWorker.class.isAssignableFrom(types[0]) ||
!Map.class.isAssignableFrom(types[1]) )
throw new RuntimeException("Invalid " + annot.toString() + " method " + method + ": " +
"it must accept a HITWorker, then a Map<String, Object> message");
if (Modifier.isStatic(method.getModifiers()))
throw new RuntimeException("Invalid " + annot.toString() + " method " + method + ": it must not be static");
map.put(klass, method);
return true;
}
private static boolean processServiceMessage(Class<?> klass, Method method,
Multimap<Class<?>, Method> map, Class<? extends Annotation> annot) {
if( method.getAnnotation(annot) == null ) return false;
if (method.getReturnType() != Void.TYPE)
throw new RuntimeException("Invalid " + annot.toString() + " method " + method + ": it must have void return type");
Class<?>[] types = method.getParameterTypes();
if (types.length != 2 || !HITWorker.class.isAssignableFrom(types[0]) ||
!Map.class.isAssignableFrom(types[1]) )
throw new RuntimeException("Invalid " + annot.toString() + " method " + method + ": " +
"it must accept a HITWorker, then a Map<String, Object> message");
if (Modifier.isStatic(method.getModifiers()))
throw new RuntimeException("Invalid " + annot.toString() + " method " + method + ": it must not be static");
map.put(klass, method);
return true;
}
private Object invokeMethod(Object bean, Method m, Object... args) {
boolean accessible = m.isAccessible();
try {
// TODO robust-ify the accessibility issue here
m.setAccessible(true);
return m.invoke(bean, args);
} catch (Exception e) {
logger.warn("Exception invoking {} on {}, ignoring", m, bean.getClass().toString());
e.printStackTrace();
return null;
} finally {
m.setAccessible(accessible);
}
}
List<Method> getIntervalEvents(String expId) {
Object bean = beans.get(expId);
List<Method> l = new LinkedList<>();
synchronized(intervals) {
l.addAll(intervals.get(bean.getClass()));
}
return l;
}
/**
* Delivers a broadcast message to an experiment
* @param expId
* @param message
*/
boolean deliverBroadcastMsg(String expId, HITWorker source, Map<String, Object> message) {
Object bean = beans.get(expId);
boolean forward = false;
synchronized(broadcasts) {
for( Method m : broadcasts.get(bean.getClass())) {
BroadcastMessage ann = m.getAnnotation(BroadcastMessage.class);
if( ann.key().length > 0 ) {
if( message == null ) continue;
String key = ann.key()[0];
if ( !message.containsKey(key) ) continue;
if ( ann.value().length > 0 && !message.get(key).equals(ann.value()[0])) continue;
}
forward |= (Boolean) invokeMethod(bean, m, source, message);
}
}
return forward;
}
/**
* Delivers a service message to an experiment
* @param expId
* @param source
* @param message
* @return
*/
void deliverServiceMsg(String expId, HITWorker source, Map<String, Object> message) {
Object bean = beans.get(expId);
synchronized(services) {
for( Method m : services.get(bean.getClass())) {
ServiceMessage ann = m.getAnnotation(ServiceMessage.class);
if( ann.key().length > 0 ) {
String key = ann.key()[0];
if ( !message.containsKey(key) ) continue;
if ( ann.value().length > 0 && !message.get(key).equals(ann.value()[0])) continue;
}
invokeMethod(bean, m, source, message);
}
}
}
void triggerStart(String expId) {
Object bean = beans.get(expId);
synchronized(starts) {
for( Method m : starts.get(bean.getClass())) {
invokeMethod(bean, m);
}
}
}
void triggerRound(String expId, int round) {
Object bean = beans.get(expId);
synchronized(rounds) {
for( Method m : rounds.get(bean.getClass())) {
invokeMethod(bean, m, round);
}
}
}
/**
* Special method to activate interval events, which are not all the same
* @param expId
* @param method
*/
void triggerInterval(String expId, Method method) {
Object bean = beans.get(expId);
invokeMethod(bean, method);
}
void triggerWorkerConnect(String expId, HITWorkerImpl source) {
Object bean = beans.get(expId);
synchronized(connects) {
for( Method m : connects.get(bean.getClass())) {
invokeMethod(bean, m, source);
}
}
}
void triggerWorkerDisconnect(String expId, HITWorkerImpl source) {
Object bean = beans.get(expId);
synchronized(disconnects) {
for( Method m : disconnects.get(bean.getClass())) {
invokeMethod(bean, m, source);
}
}
}
void triggerTimelimit(String expId) {
Object bean = beans.get(expId);
synchronized(timeouts) {
for( Method m : timeouts.get(bean.getClass())) {
invokeMethod(bean, m);
}
}
}
/**
* Remove callback tracking for an experimentId
* @param experimentId
*/
void deprocessExperiment(String experimentId) {
// No null experiments should have been mapped anyway
Object bean;
if( (bean = beans.remove(experimentId)) == null ) return;
beanClasses.remove(bean.getClass(), bean);
if( beanClasses.get(bean.getClass()).size() > 0 ) return;
// no more beans of this type, de-register callbacks
for (Class<?> c = bean.getClass(); c != Object.class; c = c.getSuperclass())
{
starts.removeAll(c);
timeouts.removeAll(c);
rounds.removeAll(c);
connects.removeAll(c);
disconnects.removeAll(c);
broadcasts.removeAll(c);
services.removeAll(c);
}
}
}