package edu.berkeley.nlp.util.functional;
import edu.berkeley.nlp.util.CollectionUtils;
import edu.berkeley.nlp.util.Factory;
import edu.berkeley.nlp.util.LazyIterable;
import edu.berkeley.nlp.util.Pair;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.*;
/**
* A collection of functional utilities of the kind you'd
* find in any functional programming language:
* map, filter, reduce, etc.
*
* Created by IntelliJ IDEA.
* User: aria42
* Date: Oct 7, 2008
* Time: 1:06:08 PM
*/
public class FunctionalUtils {

  /**
   * Returns at most the first {@code n} elements produced by {@code it}.
   * The iterator is advanced only past the elements that are returned.
   *
   * @param it the source iterator
   * @param n the maximum number of elements to take; a non-positive value yields an empty list
   * @return a new list holding the taken elements, in iteration order
   */
  public static <T> List<T> take(Iterator<T> it, int n) {
    List<T> result = new ArrayList<T>();
    for (int i = 0; i < n && it.hasNext(); ++i) {
      result.add(it.next());
    }
    return result;
  }

  /**
   * Finds a zero-argument method declared directly on {@code c} (inherited methods are not
   * searched) whose name equals {@code "get" + field} or {@code field}, ignoring case.
   * Methods with parameters are skipped because callers always invoke the result with no
   * arguments.
   *
   * @return the matching method, or {@code null} if none is declared
   */
  private static Method getMethod(Class<?> c, String field) {
    String getterName = "get" + field;
    for (Method m : c.getDeclaredMethods()) {
      if (m.getParameterTypes().length != 0) {
        continue;
      }
      String name = m.getName();
      if (name.equalsIgnoreCase(getterName) || name.equalsIgnoreCase(field)) {
        return m;
      }
    }
    return null;
  }

  /**
   * Finds a field declared directly on {@code c} (inherited fields are not searched) whose
   * name equals {@code fieldName}, ignoring case.
   *
   * @return the matching field, or {@code null} if none is declared
   */
  private static Field getField(Class<?> c, String fieldName) {
    for (Field f : c.getDeclaredFields()) {
      if (f.getName().equalsIgnoreCase(fieldName)) {
        return f;
      }
    }
    return null;
  }

  /**
   * Returns the element of {@code xs} with the largest {@code fn} score, paired with that
   * score. Ties keep the earliest element; NaN scores never win. If {@code xs} is empty (or
   * every score is NaN) the result is {@code (null, -Infinity)}.
   *
   * @param xs the candidates
   * @param fn scoring function; must not return {@code null}
   * @return a pair of (arg-max, max)
   */
  public static <T> Pair<T, Double> findMax(Iterable<T> xs, Function<T, Double> fn) {
    double max = Double.NEGATIVE_INFINITY;
    T argMax = null;
    boolean found = false;
    for (T x : xs) {
      double val = fn.apply(x);
      // The non-strict test for the first candidate lets an element scoring exactly
      // -Infinity be reported instead of null; NaN fails both comparisons.
      if (val > max || (!found && val >= max)) {
        max = val;
        argMax = x;
        found = true;
      }
    }
    return Pair.newPair(argMax, max);
  }

  /**
   * Returns the element of {@code xs} with the smallest {@code fn} score, paired with that
   * score. Ties keep the earliest element; NaN scores never win. If {@code xs} is empty (or
   * every score is NaN) the result is {@code (null, +Infinity)}.
   *
   * @param xs the candidates
   * @param fn scoring function; must not return {@code null}
   * @return a pair of (arg-min, min)
   */
  public static <T> Pair<T, Double> findMin(Iterable<T> xs, Function<T, Double> fn) {
    double min = Double.POSITIVE_INFINITY;
    T argMin = null;
    boolean found = false;
    for (T x : xs) {
      double val = fn.apply(x);
      // The non-strict test for the first candidate lets an element scoring exactly
      // +Infinity be reported instead of null; NaN fails both comparisons.
      if (val < min || (!found && val <= min)) {
        min = val;
        argMin = x;
        found = true;
      }
    }
    return Pair.newPair(argMin, min);
  }

  /**
   * Applies {@code fn} to every value of {@code map}, producing a new {@link HashMap}. Keys are
   * filtered with {@code Predicates.getTruePredicate()} (presumably accepting every key).
   */
  public static <K, I, V> Map<K, V> compose(Map<K, I> map, Function<I, V> fn) {
    return map(map, fn, (Predicate<K>) Predicates.getTruePredicate(), new HashMap<K, V>());
  }

  /**
   * Applies {@code fn} to the value of every entry whose key satisfies {@code pred},
   * producing a new {@link HashMap}.
   */
  public static <K, I, V> Map<K, V> compose(Map<K, I> map, Function<I, V> fn, Predicate<K> pred) {
    return map(map, fn, pred, new HashMap<K, V>());
  }

  /**
   * Creates a list of {@code k} instances by calling {@code factory.newInstance()} {@code k}
   * times.
   *
   * @param factory source of instances
   * @param k number of instances; a non-positive value yields an empty list
   * @return a new mutable list of the created instances
   */
  public static <C> List<C> make(Factory<C> factory, int k) {
    // ArrayList rejects a negative capacity, so clamp it.
    List<C> insts = new ArrayList<C>(Math.max(k, 0));
    for (int i = 0; i < k; i++) {
      insts.add(factory.newInstance());
    }
    return insts;
  }

  /**
   * For every entry of {@code map} whose key satisfies {@code pred}, puts
   * {@code key -> fn(value)} into {@code resultMap}.
   *
   * @return {@code resultMap}, after being filled
   */
  public static <K, I, V> Map<K, V> map(Map<K, I> map, Function<I, V> fn, Predicate<K> pred,
      Map<K, V> resultMap) {
    for (Map.Entry<K, I> entry : map.entrySet()) {
      K key = entry.getKey();
      if (pred.apply(key)) {
        resultMap.put(key, fn.apply(entry.getValue()));
      }
    }
    return resultMap;
  }

  /**
   * Builds a new {@link HashMap} from each element of {@code lst} to {@code fn(element)}.
   * Later duplicates overwrite earlier ones.
   */
  public static <I, O> Map<I, O> mapPairs(Iterable<I> lst, Function<I, O> fn) {
    return mapPairs(lst, fn, new HashMap<I, O>());
  }

  /**
   * Puts {@code element -> fn(element)} into {@code resultMap} for each element of
   * {@code lst}.
   *
   * @return {@code resultMap}, after being filled
   */
  public static <I, O> Map<I, O> mapPairs(Iterable<I> lst, Function<I, O> fn,
      Map<I, O> resultMap) {
    for (I input : lst) {
      resultMap.put(input, fn.apply(input));
    }
    return resultMap;
  }

  /**
   * Returns a new list of {@code fn} applied to each element of {@code lst}. Outputs are
   * filtered with {@code Predicates.getTruePredicate()} (presumably keeping all of them).
   */
  public static <I, O> List<O> map(Iterable<I> lst, Function<I, O> fn) {
    return map(lst, fn, (Predicate<O>) Predicates.getTruePredicate());
  }

  /**
   * Lazy counterpart of {@link #map(Iterable, Function)}, backed by {@link LazyIterable}.
   */
  public static <I, O> Iterable<O> lazyMap(Iterable<I> lst, Function<I, O> fn) {
    return lazyMap(lst, fn, (Predicate<O>) Predicates.getTruePredicate());
  }

  /**
   * Wraps {@code lst}, {@code fn} and {@code pred} in a {@link LazyIterable}.
   * NOTE(review): the constant 20 is interpreted by LazyIterable, presumably as a
   * buffer/batch size; confirm there.
   */
  public static <I, O> Iterable<O> lazyMap(Iterable<I> lst, Function<I, O> fn, Predicate<O> pred) {
    return new LazyIterable<O, I>(lst, fn, pred, 20);
  }

  /**
   * Maps each element of {@code lst} to a list with {@code fn} and concatenates the results.
   * Inner lists are filtered with {@code Predicates.getTruePredicate()} (presumably keeping
   * all of them).
   */
  public static <I, O> List<O> flatMap(Iterable<I> lst, Function<I, List<O>> fn) {
    Predicate<List<O>> p = Predicates.getTruePredicate();
    return flatMap(lst, fn, p);
  }

  /**
   * Maps each element of {@code lst} to a list with {@code fn}, keeps the lists that satisfy
   * {@code pred}, and concatenates them in order into a new list. The lists returned by
   * {@code fn} are not modified.
   */
  public static <I, O> List<O> flatMap(Iterable<I> lst, Function<I, List<O>> fn,
      Predicate<List<O>> pred) {
    List<O> result = new ArrayList<O>();
    for (List<O> chunk : map(lst, fn, pred)) {
      result.addAll(chunk);
    }
    return result;
  }

  /**
   * Left fold: starting from {@code initial}, replaces the accumulator with
   * {@code fn((accumulator, input))} for each input in order.
   *
   * @return the final accumulator ({@code initial} if {@code inputs} is empty)
   */
  public static <I, O> O reduce(Iterable<I> inputs, O initial, Function<Pair<O, I>, O> fn) {
    O output = initial;
    for (I input : inputs) {
      output = fn.apply(Pair.newPair(output, input));
    }
    return output;
  }

  /**
   * Returns a new list of {@code fn(element)} for each element of {@code lst}, keeping only
   * outputs that satisfy {@code pred}. Note that the predicate tests the mapped value, not
   * the input.
   */
  public static <I, O> List<O> map(Iterable<I> lst, Function<I, O> fn, Predicate<O> pred) {
    List<O> outputs = new ArrayList<O>();
    for (I input : lst) {
      O output = fn.apply(input);
      if (pred.apply(output)) {
        outputs.add(output);
      }
    }
    return outputs;
  }

  /** Returns a new list of the elements of {@code lst} that satisfy {@code pred}, in order. */
  public static <I> List<I> filter(final Iterable<I> lst, final Predicate<I> pred) {
    List<I> ret = new ArrayList<I>();
    for (I input : lst) {
      if (pred.apply(input)) {
        ret.add(input);
      }
    }
    return ret;
  }

  /**
   * Builds a reflective accessor for {@code field} on class {@code c}. A zero-argument method
   * declared on {@code c} named {@code "get" + field} or {@code field} (case-insensitive) is
   * used if present. Otherwise a declared field of that name (case-insensitive) is read.
   * Inherited members are not searched.
   *
   * @param field the property name
   * @param c the class whose declared members are searched
   * @return a function reading the property from its argument. It throws
   *     {@code RuntimeException}, carrying the reflective failure as its cause, when access
   *     fails.
   * @throws RuntimeException if neither a matching method nor a matching field exists
   */
  public static <O, T> Function getAccessor(String field, Class<?> c) {
    final Method trgMeth = getMethod(c, field);
    final Field trgField = getField(c, field);
    if (trgMeth == null && trgField == null) {
      throw new RuntimeException("Couldn't find field or method to access " + field);
    }
    return new Function<O, T>() {
      @SuppressWarnings("unchecked")
      public T apply(O input) {
        try {
          return (T) (trgMeth != null ? trgMeth.invoke(input) : trgField.get(input));
        } catch (Exception e) {
          // Keep the reflective failure as the cause instead of printing and dropping it.
          throw new RuntimeException("Error accessing Method or target", e);
        }
      }
    };
  }

  /**
   * Groups {@code objs} by {@code groupFn(obj)} into a new {@link HashMap}. The value
   * collections are new {@link ArrayList}s.
   */
  public static <K, O> Map<K, Collection<O>> groupBy(Iterable<O> objs, Function<O, K> groupFn) {
    return groupBy(objs, groupFn, new Factory<Collection<O>>() {
      public Collection<O> newInstance(Object... args) {
        return new ArrayList<O>();
      }
    });
  }

  /**
   * Groups {@code objs} by the property {@code field}. The accessor is resolved via
   * {@link #getAccessor(String, Class)} against the class of the first element. That means it
   * first tries a zero-argument method named {@code getField} or {@code field} (ignoring case),
   * then direct access to a field with that name.
   *
   * @param objs the objects to group; iterated twice, so it must be re-iterable
   * @param field the property to group by
   * @return a new map from property value to the objects having it; empty if {@code objs} is
   *     empty
   * @throws RuntimeException if the first element's class has no such method or field
   */
  public static <K, O> Map<K, Collection<O>> groupBy(Iterable<O> objs, String field) {
    Iterator<O> it = objs.iterator();
    if (!it.hasNext()) {
      // There is no element to resolve the accessor's class from.
      return new HashMap<K, Collection<O>>();
    }
    return groupBy(objs, getAccessor(field, it.next().getClass()));
  }

  /**
   * Groups {@code objs} by {@code groupFn(obj)} into a new {@link HashMap}. Each object is
   * added via {@code CollectionUtils.addToValueCollection(map, key, obj, fact)}, which
   * presumably creates the key's collection from {@code fact} on first use. Exceptions
   * thrown by {@code groupFn} propagate to the caller.
   *
   * @param objs the objects to group
   * @param groupFn computes each object's group key
   * @param fact supplies the per-key collections
   * @return the grouping; empty if {@code objs} is empty
   */
  public static <K, O, C extends Collection<O>> Map<K, C>
      groupBy(Iterable<O> objs, Function<O, K> groupFn, final Factory<C> fact) {
    Map<K, C> map = new HashMap<K, C>();
    for (O obj : objs) {
      CollectionUtils.addToValueCollection(map, groupFn.apply(obj), obj, fact);
    }
    return map;
  }

  /**
   * Returns the first element of {@code objs} satisfying {@code pred}, or {@code null} if
   * none does.
   */
  public static <T> T first(Iterable<T> objs, Predicate<T> pred) {
    for (T obj : objs) {
      if (pred.apply(obj)) {
        return obj;
      }
    }
    return null;
  }

  /**
   * Returns the elements of {@code coll} whose property {@code field} equals {@code value}
   * (with {@code null} equal only to {@code null}). The property is resolved on the class of
   * the first element, as in {@link #getAccessor(String, Class)}. Elements whose property
   * cannot be read reflectively (e.g. of a different class) are excluded.
   *
   * @param coll the elements to filter; iterated twice, so it must be re-iterable
   * @param field the property name
   * @param value the value to match
   * @return a new list of matching elements; empty if {@code coll} is empty
   * @throws RuntimeException if the first element's class has no such method or field
   */
  public static <O, K> List<O> filter(Iterable<O> coll, final String field, final K value)
      throws Exception {
    Iterator<O> it = coll.iterator();
    if (!it.hasNext()) {
      return new ArrayList<O>();
    }
    Class<?> c = it.next().getClass();
    final Method m = getMethod(c, field);
    final Field f = getField(c, field);
    if (m == null && f == null) {
      // Fail loudly instead of silently filtering everything out on a misspelled name.
      throw new RuntimeException("Couldn't find field or method to access " + field);
    }
    return filter(coll, new Predicate<O>() {
      public Boolean apply(O input) {
        Object inputVal;
        try {
          inputVal = (m != null) ? m.invoke(input) : f.get(input);
        } catch (Exception ignored) {
          // Best effort: an element whose property can't be read simply doesn't match.
          return false;
        }
        return inputVal == null ? value == null : inputVal.equals(value);
      }
    });
  }

  /**
   * Returns the list {@code [0, 1, ..., n-1]}; empty when {@code n <= 0}.
   */
  public static List<Integer> range(int n) {
    List<Integer> result = new ArrayList<Integer>(Math.max(n, 0));
    for (int i = 0; i < n; i++) {
      result.add(i);
    }
    return result;
  }

  /**
   * Testing Purposes
   */
  private static class Person {
    public String prefix;
    public String name;

    public Person(String name) {
      this.name = name;
      this.prefix = name.substring(0, 3);
    }

    public String toString() {
      return "Person(" + name + ")";
    }
  }

  /**
   * Returns the first element of {@code elems} satisfying {@code pred}, or {@code null} if
   * none does. Same as {@link #first(Iterable, Predicate)}.
   */
  public static <T> T find(Iterable<T> elems, Predicate<T> pred) {
    return first(elems, pred);
  }

  /** Small demo: groups a few {@link Person}s by their three-letter name prefix. */
  public static void main(String[] args) throws Exception {
    List<Person> objs = CollectionUtils.makeList(
        new Person("david"),
        new Person("davs"),
        new Person("maria"),
        new Person("marshia")
    );
    Map<String, Collection<Person>> grouped = groupBy(objs, getAccessor("prefix", Person.class));
    System.out.printf("groupd: %s", grouped);
  }
}