package edu.berkeley.nlp.util.functional; import edu.berkeley.nlp.util.CollectionUtils; import edu.berkeley.nlp.util.Factory; import edu.berkeley.nlp.util.LazyIterable; import edu.berkeley.nlp.util.Pair; import java.lang.reflect.Field; import java.lang.reflect.Method; import java.util.*; /** * Collection of Functional Utilities you'd * find in any functional programming language. * Things like map, filter, reduce, etc.. * * Created by IntelliJ IDEA. * User: aria42 * Date: Oct 7, 2008 * Time: 1:06:08 PM */ public class FunctionalUtils { public static <T> List<T> take(Iterator<T> it, int n) { List<T> result = new ArrayList<T>(); for (int i=0; i < n && it.hasNext(); ++i) { result.add(it.next()); } return result; } private static Method getMethod(Class c, String field) { Method[] methods = c.getDeclaredMethods(); String trgMethName = "get" + field; Method trgMeth = null; for (Method m: methods) { if (m.getName().equalsIgnoreCase(trgMethName) || m.getName().equalsIgnoreCase(field)) { return m; } } return null; } private static Field getField(Class c, String fieldName) { Field[] fields = c.getDeclaredFields(); for (Field f: fields) { if (f.getName().equalsIgnoreCase(fieldName)) { return f; } } return null; } public static <T> Pair<T,Double> findMax(Iterable<T> xs, Function<T,Double> fn) { double max = Double.NEGATIVE_INFINITY; T argMax = null; for (T x : xs) { double val = fn.apply(x); if (val > max) { max = val ; argMax = x; } } return Pair.newPair(argMax,max); } public static <T> Pair<T,Double> findMin(Iterable<T> xs, Function<T,Double> fn) { double min= Double.POSITIVE_INFINITY; T argMin = null; for (T x : xs) { double val = fn.apply(x); if (val < min) { min= val ; argMin = x; } } return Pair.newPair(argMin,min); } public static<K,I,V> Map<K,V> compose(Map<K,I> map, Function<I,V> fn) { return map(map,fn, (Predicate<K>) Predicates.getTruePredicate(),new HashMap<K,V>()); } public static<K,I,V> Map<K,V> compose(Map<K,I> map, Function<I,V> fn, Predicate<K> pred) { return map(map,fn,pred,new HashMap<K,V>()); } public static <C> List make(Factory<C> factory, int k ) { List<C> insts = new ArrayList<C>(); for (int i = 0; i < k; i++) { insts.add(factory.newInstance()); } // Fuck you cvs return insts; } public static<K,I,V> Map<K,V> map(Map<K,I> map, Function<I,V> fn, Predicate<K> pred, Map<K,V> resultMap) { for (Map.Entry<K,I> entry: map.entrySet()) { K key = entry.getKey(); I inter = entry.getValue(); if (pred.apply(key)) resultMap.put(key, fn.apply(inter)); } return resultMap; } public static<I,O> Map<I,O> mapPairs(Iterable<I> lst, Function<I,O> fn) { return mapPairs(lst,fn,new HashMap<I,O>()); } public static<I,O> Map<I,O> mapPairs(Iterable<I> lst, Function<I,O> fn, Map<I,O> resultMap) { for (I input: lst) { O output = fn.apply(input); resultMap.put(input,output); } return resultMap; } public static<I,O> List<O> map(Iterable<I> lst, Function<I,O> fn) { return map(lst,fn,(Predicate<O>) Predicates.getTruePredicate()); } public static<I,O> Iterable<O> lazyMap(Iterable<I> lst, Function<I,O> fn) { return lazyMap(lst,fn,(Predicate<O>) Predicates.getTruePredicate()); } public static<I,O> Iterable<O> lazyMap(Iterable<I> lst, Function<I,O> fn, Predicate<O> pred) { return new LazyIterable<O,I>(lst,fn,pred,20); } public static<I,O> List<O> flatMap(Iterable<I> lst, Function<I,List<O>> fn) { Predicate<List<O>> p = Predicates.getTruePredicate(); return flatMap(lst,fn,p); } public static<I,O> List<O> flatMap(Iterable<I> lst, Function<I,List<O>> fn, Predicate<List<O>> pred) { List<List<O>> lstOfLsts = map(lst,fn,pred); List<O> init = new ArrayList<O>(); return reduce(lstOfLsts, init, new Function<Pair<List<O>, List<O>>, List<O>>() { public List<O> apply(Pair<List<O>, List<O>> input) { List<O> result = input.getFirst(); result.addAll(input.getSecond()); return result; } }); } public static<I,O> O reduce(Iterable<I> inputs, O initial, Function<Pair<O,I>,O> fn) { O output = initial; for (I input: inputs) { output = fn.apply(Pair.newPair(output,input)); } return output; } public static<I,O> List<O> map(Iterable<I> lst, Function<I,O> fn, Predicate<O> pred) { List<O> outputs = new ArrayList(); for (I input: lst) { O output = fn.apply(input); if (pred.apply(output)) { outputs.add(output); } } return outputs; } public static<I> List<I> filter(final Iterable<I> lst, final Predicate<I> pred) { List<I> ret = new ArrayList<I>(); for (I input : lst) { if (pred.apply(input)) ret.add(input); } return ret; } public static <O,T> Function getAccessor(String field, Class c) { final Method trgMeth = getMethod(c, field); final Field trgField = getField(c, field); if (trgMeth == null && trgField == null) { throw new RuntimeException("Couldn't find field or method to access " + field); } return new Function<O,T>() { public T apply(O input) { try { return (T) (trgMeth != null ? trgMeth.invoke(input) : trgField.get(input)); } catch (Exception e) { e.printStackTrace(); throw new RuntimeException("Error accessing Method or target"); } } }; } public static<K,O> Map<K, Collection<O>> groupBy(Iterable<O> objs, Function<O,K> groupFn) { return groupBy(objs,groupFn, new Factory<Collection<O>>() { public Collection<O> newInstance(Object... args) { return new ArrayList<O>(); } }); } public static<K,O> Map<K, Collection<O>> groupBy(Iterable<O> objs, String field) { return groupBy(objs,getAccessor(field,objs.iterator().next().getClass())); } /** * Groups <code>objs</code> by the field <code>field</code>. Tries * to find public method getField, ignoring case, then to directly * access the field if that fails. * @param objs * @param field * @return */ public static<K,O, C extends Collection<O>> Map<K,C> groupBy(Iterable<O> objs, Function<O,K> groupFn, final Factory<C> fact) { Iterator<O> it = objs.iterator(); if (!it.hasNext()) return new HashMap<K,C>(); Map<K,C> map = new HashMap<K,C>(); for (O obj: objs) { K key = null; try { key = (K) groupFn.apply(obj); } catch (Exception e) { e.printStackTrace(); return null; } CollectionUtils.addToValueCollection(map,key,obj, fact); } return map; } public static <T> T first(Iterable<T> objs, Predicate<T> pred) { for (T obj : objs) { if (pred.apply(obj)) return obj; } return null; } public static<O,K> List<O> filter(Iterable<O> coll, final String field, final K value) throws Exception { Iterator<O> it = coll.iterator(); if (!it.hasNext()) return new ArrayList<O>(); Class c = it.next().getClass(); final Method m = getMethod(c,field); final Field f = getField(c,field); return filter(coll, new Predicate<O>() { public Boolean apply(O input) { try { K inputVal = (K)(m != null ? m.invoke(input) : f.get(input)); return inputVal.equals(value); } catch (Exception e) { } return false; } }); } public static List<Integer> range(int n) { List<Integer> result = new ArrayList<Integer>(); for (int i = 0; i < n; i++) { result.add(i); } return result; } /** * Testing Purposes */ private static class Person { public String prefix ; public String name; public Person(String name) { this.name = name; this.prefix = name.substring(0,3); } public String toString() { return "Person(" + name + ")"; } } public static <T> T find(Iterable<T> elems, Predicate<T> pred) { for (T elem : elems) { if (pred.apply(elem)) return elem; } return null; } public static void main(String[] args) throws Exception { List<Person> objs = CollectionUtils.makeList( new Person("david"), new Person("davs"), new Person("maria"), new Person("marshia") ); Map<String, Collection<Person>> grouped = groupBy(objs,getAccessor("prefix",Person.class)); System.out.printf("groupd: %s",grouped); } }