package com.googlecode.totallylazy;
import com.googlecode.totallylazy.functions.Curried2;
import com.googlecode.totallylazy.functions.Function1;
import com.googlecode.totallylazy.numbers.Numbers;
import com.googlecode.totallylazy.predicates.LogicalPredicate;
import com.googlecode.totallylazy.predicates.Predicate;
import com.googlecode.totallylazy.reflection.Methods;
import java.lang.reflect.Method;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import static com.googlecode.totallylazy.functions.Callables.toClass;
import static com.googlecode.totallylazy.reflection.Methods.methodName;
import static com.googlecode.totallylazy.reflection.Methods.parameterTypes;
import static com.googlecode.totallylazy.predicates.Predicates.is;
import static com.googlecode.totallylazy.predicates.Predicates.not;
import static com.googlecode.totallylazy.predicates.Predicates.notNullValue;
import static com.googlecode.totallylazy.predicates.Predicates.nullValue;
import static com.googlecode.totallylazy.predicates.Predicates.where;
import static com.googlecode.totallylazy.Sequences.sequence;
import static com.googlecode.totallylazy.comparators.Comparators.by;
import static com.googlecode.totallylazy.numbers.Numbers.ascending;
import static com.googlecode.totallylazy.numbers.Numbers.minimum;
import static com.googlecode.totallylazy.numbers.Numbers.sum;
/**
 * Picks a method reflectively from the runtime classes of the supplied arguments and invokes it.
 *
 * <p>Candidates come from {@code Methods.allMethods(aClass)}, narrowed first by the configured
 * predicate and then to those whose parameter types are assignable from the argument classes.
 * The remaining candidates are sorted by ascending "distance" from the argument classes and the
 * first one is used. The lookup result is cached per list of argument classes.
 */
public class Dispatcher {
    private final Class<?> aClass;
    // null when the dispatcher was created from a class rather than from an instance
    private final Object instance;
    private final Predicate<? super Method> predicate;
    // lookup results (including "no method found") keyed by the list of argument classes
    private final ConcurrentMap<List<?>, Option<Method>> cache = new ConcurrentHashMap<>();

    private Dispatcher(Class<?> aClass, Object instance, Predicate<? super Method> predicate) {
        this.aClass = aClass;
        this.instance = instance;
        this.predicate = predicate;
    }

    /**
     * Creates a dispatcher over the methods of {@code aClass} that are called {@code name},
     * with no target instance.
     */
    public static Dispatcher dispatcher(Class<?> aClass, String name) {
        return dispatcher(aClass, named(name));
    }

    /**
     * Creates a dispatcher over the methods of {@code aClass} accepted by {@code predicate},
     * with no target instance.
     */
    public static Dispatcher dispatcher(Class<?> aClass, Predicate<? super Method> predicate) {
        return dispatcher(aClass, null, predicate);
    }

    /**
     * Creates a dispatcher over the methods of {@code instance}'s class that are called
     * {@code name}, invoking them on {@code instance}.
     */
    public static Dispatcher dispatcher(Object instance, String name) {
        return dispatcher(instance, named(name));
    }

    /**
     * Creates a dispatcher over the methods of {@code instance}'s class accepted by
     * {@code predicate}, invoking them on {@code instance}.
     */
    public static Dispatcher dispatcher(Object instance, Predicate<? super Method> predicate) {
        return dispatcher(instance.getClass(), instance, predicate);
    }

    /**
     * Creates a dispatcher from all three parts.
     *
     * @param aClass    class whose methods are considered
     * @param instance  object the chosen method is invoked on (may be {@code null})
     * @param predicate filter a method must pass to be considered
     */
    public static Dispatcher dispatcher(Class<?> aClass, Object instance, Predicate<? super Method> predicate) {
        return new Dispatcher(aClass, instance, predicate);
    }

    /** Builds the predicate shared by the by-name factories: method name is {@code name}. */
    private static Predicate<? super Method> named(String name) {
        return where(methodName(), is(name));
    }

    /**
     * Dispatches on {@code args} and unwraps the result with {@code Option.get()}.
     * NOTE(review): what happens when no method matches depends on {@code Option.get()},
     * which is not visible here - presumably it throws.
     */
    public <T> T invoke(Object... args) {
        final Option<T> result = invokeOption(args);
        return result.get();
    }

    /**
     * Dispatches on {@code args}: looks up (or computes and caches) the closest matching method
     * for the argument classes and, if there is one, maps it through
     * {@code Methods.invokeOn(instance, args)}.
     *
     * @return the mapped option; its emptiness when nothing matches follows {@code Option.map}
     */
    public <T> Option<T> invokeOption(Object... args) {
        final List<Class<?>> argumentClasses = sequence(args).map(toClass()).toList();
        final Option<Method> target =
                cache.computeIfAbsent(argumentClasses, ignored -> closestMethod(argumentClasses));
        return target.map(Methods.<T>invokeOn(instance, args));
    }

    /**
     * Finds the candidate method with the smallest distance from {@code argumentClasses},
     * considering only methods that pass the predicate and whose parameters accept the arguments.
     */
    private Option<Method> closestMethod(final List<Class<?>> argumentClasses) {
        return Methods.allMethods(aClass)
                .filter(predicate)
                .filter(where(parameterTypes(), matches(argumentClasses)))
                .sort(by(distanceFrom(argumentClasses), ascending()))
                .headOption();
    }

    /** Function giving a method's total distance from {@code argumentClasses}. */
    private static Function1<Method, Number> distanceFrom(final Iterable<Class<?>> argumentClasses) {
        return candidate -> distanceFrom(argumentClasses, sequence(candidate.getParameterTypes()));
    }

    /**
     * Total distance between argument classes and parameter types: the two are zipped together
     * and the pairwise {@link #distanceBetween(Class, Class)} values are reduced with {@code sum}.
     */
    static Number distanceFrom(Iterable<Class<?>> argumentClasses, Iterable<Class<?>> parameterTypes) {
        return sequence(argumentClasses)
                .zip(parameterTypes)
                .map(distanceBetween().pair())
                .reduce(sum);
    }

    /** {@link #distanceBetween(Class, Class)} as a function object. */
    private static Curried2<Class<?>, Class<?>, Number> distanceBetween() {
        return (argument, parameterType) -> distanceBetween(argument, parameterType);
    }

    /**
     * Distance from an argument class to a parameter type.
     *
     * <p>Equal classes have distance 0. Otherwise the result is a step cost (1 when the parameter
     * type is an interface, 1.1 when it is not) plus the minimum distance to the parameter type
     * from any of the argument's directly implemented interfaces or its superclass.
     * NOTE(review): the value produced when the argument has no supertypes at all is whatever
     * {@code reduce(minimum)} yields for an empty sequence - not visible here.
     */
    static Number distanceBetween(Class<?> argument, Class<?> parameterType) {
        if (argument.equals(parameterType)) {
            return 0;
        }
        // kept as a double so both alternatives carry the same numeric type
        final double stepCost = parameterType.isInterface() ? 1 : 1.1;
        return Numbers.add(stepCost, sequence(argument.getInterfaces())
                .append(argument.getSuperclass())
                .filter(not(nullValue()))
                .map(distanceBetween().flip().apply(parameterType))
                .reduce(minimum));
    }

    /**
     * Predicate over a method's parameter types: compares them with {@code argumentClasses} via
     * {@code Sequence.equals}, pairing each parameter type with the corresponding argument class
     * and requiring the parameter type to be assignable from it.
     */
    private static LogicalPredicate<Class<?>[]> matches(final Iterable<Class<?>> argumentClasses) {
        // pair.first() is the declared parameter type, pair.second() the argument class
        final LogicalPredicate<Pair<Class<?>, Class<?>>> accepts = new LogicalPredicate<Pair<Class<?>, Class<?>>>() {
            @Override
            public boolean matches(Pair<Class<?>, Class<?>> pair) {
                final Class<?> declaredType = pair.first();
                final Class<?> argumentClass = pair.second();
                return declaredType.isAssignableFrom(argumentClass);
            }
        };
        return new LogicalPredicate<Class<?>[]>() {
            @Override
            public boolean matches(Class<?>[] declaredTypes) {
                return sequence(declaredTypes).equals(argumentClasses, accepts);
            }
        };
    }
}