package org.core4j;
import java.lang.reflect.Array;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
public class Enumerable<T> implements Iterable<T> {
private final Iterable<T> values;
protected Enumerable(Iterable<T> values) {
if (values == null) {
throw new RuntimeException("values cannot be null");
}
this.values = values;
}
public static <T> Enumerable<T> create(T... values) {
return new Enumerable<T>(new ArrayIterable<T>(values));
}
public static <T> Enumerable<T> create(Iterable<T> values) {
return new Enumerable<T>(values);
}
@SuppressWarnings("unchecked")
public static <T> Enumerable<T> create(Class<T> clazz, Enumeration<?> e) {
List<T> rt = new ArrayList<T>();
while (e.hasMoreElements()) {
rt.add((T) e.nextElement());
}
return new Enumerable<T>(rt);
}
public static <T> Enumerable<T> createFromIterator(final Func<Iterator<T>> fn) {
return new Enumerable<T>(makeIterable(fn));
}
@SuppressWarnings("unchecked")
public T[] toArray(Class<T> clazz) {
List<T> rt = toList();
T[] array = (T[]) Array.newInstance(clazz, rt.size());
for (int i = 0; i < array.length; i++) {
array[i] = rt.get(i);
}
return array;
}
public List<T> toList() {
List<T> rt = new ArrayList<T>();
for (T value : values) {
rt.add(value);
}
return rt;
}
public Set<T> toSet() {
Set<T> rt = new HashSet<T>();
for (T value : values) {
rt.add(value);
}
return rt;
}
public SortedSet<T> toSortedSet() {
SortedSet<T> rt = new TreeSet<T>();
for (T value : values) {
rt.add(value);
}
return rt;
}
public SortedSet<T> toSortedSet(Comparator<? super T> comparator) {
SortedSet<T> rt = new TreeSet<T>(comparator);
for (T value : values) {
rt.add(value);
}
return rt;
}
public <K> Map<K, T> toMap(Func1<T, K> keyFn) {
Map<K, T> rt = new HashMap<K, T>();
for (T value : values) {
rt.put(keyFn.apply(value), value);
}
return rt;
}
public int count() {
int rt = 0;
for (@SuppressWarnings("unused")
T value : values) {
rt++;
}
return rt;
}
public T first() {
for (T value : values) {
return value;
}
throw new RuntimeException("No elements");
}
public T first(Predicate1<T> predicate) {
for (T value : values) {
if (predicate.apply(value)) {
return value;
}
}
throw new RuntimeException("No elements match the predicate");
}
public T firstOrNull() {
for (T value : values) {
return value;
}
return null;
}
public T firstOrNull(Predicate1<T> predicate) {
for (T value : values) {
if (predicate.apply(value)) {
return value;
}
}
return null;
}
public Enumerable<T> where(Predicate1<T> predicate) {
return new Enumerable<T>(new PredicateIterable<T>(this, predicate));
}
public <TOutput> Enumerable<TOutput> select(Func1<T, TOutput> projection) {
return new Enumerable<TOutput>(new FuncIterable<T, TOutput>(this, projection));
}
public static <K, E> Map<K, List<E>> group(Collection<E> c, final Func1<E, K> projection) {
Map<K, List<E>> map = new HashMap<K, List<E>>();
for (E e : c) {
K key = projection.apply(e);
if (key != null) {
List<E> list = map.get(key);
if (list == null) {
list = new ArrayList<E>();
map.put(key, list);
}
list.add(e);
}
}
return map;
}
public T last() {
T rt = null;
boolean empty = true;
for (T value : values) {
empty = false;
rt = value;
}
if (empty) {
throw new RuntimeException("No elements");
}
return rt;
}
public Iterator<T> iterator() {
return values.iterator();
}
public Enumerable<T> reverse() {
List<T> rt = this.toList();
Collections.reverse(rt);
return new Enumerable<T>(rt);
}
private List<Iterable<T>> thisThenOthers(Iterable<T>... others) {
List<Iterable<T>> rt = new ArrayList<Iterable<T>>();
rt.add(this);
for (Iterable<T> other : others) {
rt.add(other);
}
return rt;
}
@SuppressWarnings("unchecked")
public Enumerable<T> concat(Iterable<T> other) {
return concat(new Iterable[] { other });
}
public Enumerable<T> concat(Iterable<T>... others) {
List<Iterable<T>> rt = thisThenOthers(others);
return new Enumerable<T>(new ConcatIterable<T>(rt));
}
@SuppressWarnings("unchecked")
public Enumerable<T> concat(T... others) {
return concat(new Enumerable[] { Enumerable.create(others) });
}
public Enumerable<T> take(final int count) {
return createFromIterator(new Func<Iterator<T>>() {
public Iterator<T> apply() {
return new TakeIterator<T>(Enumerable.this, count);
}
});
}
private static class TakeIterator<T> extends ReadOnlyIterator<T> {
private int left;
private Iterator<T> iterator;
public TakeIterator(Iterable<T> values, int count) {
iterator = values.iterator();
left = count;
}
@Override
protected IterationResult<T> advance() throws Exception {
if (left <= 0) {
return IterationResult.done();
}
if (!iterator.hasNext()) {
return IterationResult.done();
}
left--;
return IterationResult.next(iterator.next());
}
}
public boolean any(Predicate1<T> predicate) {
for (T value : values) {
if (predicate.apply(value)) {
return true;
}
}
return false;
}
public boolean all(Predicate1<T> predicate) {
for (T value : values) {
if (!predicate.apply(value)) {
return false;
}
}
return true;
}
public boolean contains(T value) {
for (T existingValue : values) {
if (existingValue.equals(value)) {
return true;
}
}
return false;
}
public T elementAt(int index) {
int i = 0;
for (T value : values) {
if (index == i++) {
return value;
}
}
throw new RuntimeException("No element at index " + index);
}
public T elementAtOrNull(int index) {
int i = 0;
for (T value : values) {
if (index == i++) {
return value;
}
}
return null;
}
public <TReturn> TReturn aggregate(Class<TReturn> clazz, Func2<T, TReturn, TReturn> aggregation) {
return aggregate(clazz, null, aggregation);
}
public <TReturn> TReturn aggregate(Class<TReturn> clazz, TReturn initialValue, Func2<T, TReturn, TReturn> aggregation) {
TReturn rt = initialValue;
for (T value : values) {
rt = aggregation.apply(value, rt);
}
return rt;
}
public <TReturn> TReturn sum(final Class<TReturn> clazz) {
if (clazz.equals(Double.class) || clazz.equals(Integer.class) || clazz.equals(BigDecimal.class)) {
Func2<T, TReturn, TReturn> aggregation = new Func2<T, TReturn, TReturn>() {
@SuppressWarnings("unchecked")
public TReturn apply(T input1, TReturn input2) {
Number n1 = (Number) input1; // assumes T (this) is Number
Number n2 = (Number) input2; // this is safe, one of Double,Integer,BigDecimal
// TODO better way?
if (clazz.equals(Double.class)) {
Double rt = n1.doubleValue() + (n2 == null ? 0 : n2.doubleValue());
return (TReturn) rt;
}
if (clazz.equals(Integer.class)) {
Integer rt = n1.intValue() + (n2 == null ? 0 : n2.intValue());
return (TReturn) rt;
}
if (clazz.equals(BigDecimal.class)) {
if (n1 instanceof Integer) {
n1 = BigDecimal.valueOf((Integer) n1);
}
if (n1 instanceof Double) {
n1 = BigDecimal.valueOf((Double) n1);
}
BigDecimal bd1 = n1 == null ? BigDecimal.ZERO : (BigDecimal) n1;
BigDecimal bd2 = n2 == null ? BigDecimal.ZERO : (BigDecimal) n2;
BigDecimal rt = bd1.add(bd2);
return (TReturn) rt;
}
throw new UnsupportedOperationException("No default aggregation for class " + clazz.getSimpleName());
}
};
return aggregate(clazz, aggregation);
}
throw new UnsupportedOperationException("No default aggregation for class " + clazz.getSimpleName());
}
public <TReturn> TReturn sum(Class<TReturn> clazz, Func1<T, TReturn> projection) {
Enumerable<TReturn> rt = this.select(projection);
return rt.sum(clazz);
}
private static class ArrayIterable<T> implements Iterable<T> {
private final T[] values;
public ArrayIterable(T[] values) {
this.values = values;
}
public Iterator<T> iterator() {
return new ArrayIterator<T>(values);
}
}
private static class ArrayIterator<T> implements Iterator<T> {
private final T[] values;
private int current = -1;
public ArrayIterator(T[] values) {
this.values = values;
}
public boolean hasNext() {
return current < (values.length - 1);
}
public T next() {
try {
return values[++current];
} catch (ArrayIndexOutOfBoundsException e) {
throw new NoSuchElementException();
}
}
public void remove() {
throw new UnsupportedOperationException("remove()");
}
}
private static class PredicateIterable<T> implements Iterable<T> {
private final Iterable<T> iterable;
private final Predicate1<T> predicate;
public PredicateIterable(Iterable<T> i, Predicate1<T> p) {
this.iterable = i;
this.predicate = p;
}
public Iterator<T> iterator() {
return new PredicateIterator<T>(iterable.iterator(), predicate);
}
}
private static class PredicateIterator<T> extends ReadOnlyIterator<T> {
private final Iterator<T> iterator;
private final Predicate1<T> predicate;
private boolean useEx = true; // ( slightly faster in perf tests)
public PredicateIterator(Iterator<T> i, Predicate1<T> p) {
this.iterator = i;
this.predicate = p;
}
@Override
protected IterationResult<T> advance() {
// exception-backed method
if (useEx) {
try {
T rt = iterator.next();
while (!predicate.apply(rt)) {
rt = iterator.next();
}
return IterationResult.next(rt);
} catch (NoSuchElementException e) {
return IterationResult.done();
}
} else {
// non-exception-backed method
if (iterator.hasNext()) {
T rt = iterator.next();
while (!predicate.apply(rt)) {
if (iterator.hasNext()) {
rt = iterator.next();
} else {
return IterationResult.done();
}
}
return IterationResult.next(rt);
} else {
return IterationResult.done();
}
}
}
}
private static class FuncIterable<X, Y> implements Iterable<Y> {
private final Iterable<X> iterable;
private final Func1<X, Y> projection;
public FuncIterable(Iterable<X> iterable, Func1<X, Y> projection) {
this.iterable = iterable;
this.projection = projection;
}
public Iterator<Y> iterator() {
return new FuncIterator<X, Y>(iterable.iterator(), projection);
}
}
private static class FuncIterator<X, Y> implements Iterator<Y> {
private final Iterator<X> iterator;
private final Func1<X, Y> projection;
public FuncIterator(Iterator<X> iterator, Func1<X, Y> projection) {
this.iterator = iterator;
this.projection = projection;
}
public boolean hasNext() {
return iterator.hasNext();
}
public Y next() {
return projection.apply(iterator.next());
}
public void remove() {
iterator.remove();
}
}
private static class ConcatIterable<T> implements Iterable<T> {
private final Iterable<Iterable<T>> iterables;
public ConcatIterable(Iterable<Iterable<T>> iterables) {
this.iterables = iterables;
}
public Iterator<T> iterator() {
return new ConcatIterator<T>(Enumerable.create(iterables).select(new Func1<Iterable<T>, Iterator<T>>() {
public Iterator<T> apply(Iterable<T> x) {
return x.iterator();
}
}).toList());
}
}
private static class ConcatIterator<T> implements Iterator<T> {
private final List<Iterator<T>> iterators;
private int current;
public ConcatIterator(List<Iterator<T>> iterators) {
this.iterators = iterators;
}
public boolean hasNext() {
boolean rt = iterators.get(current).hasNext();
while (!rt) {
if (current == iterators.size() - 1) {
return rt;
}
current++;
rt = iterators.get(current).hasNext();
}
return rt;
}
public T next() {
while (true) {
try {
return iterators.get(current).next();
} catch (NoSuchElementException e) {
if (current == iterators.size() - 1) {
throw new NoSuchElementException();
}
current++;
}
}
}
public void remove() {
throw new UnsupportedOperationException("remove()");
}
}
public <TKey extends Comparable<TKey>> Enumerable<T> orderBy(final Func1<T, TKey> projection) {
return orderBy(new Comparator<T>() {
public int compare(T o1, T o2) {
TKey lhs = projection.apply(o1);
TKey rhs = projection.apply(o2);
return lhs.compareTo(rhs);
}
});
}
public Enumerable<T> orderBy(Comparator<T> comparator) {
List<T> rt = this.toList();
Collections.sort(rt, comparator);
return Enumerable.create(rt);
}
public Enumerable<T> orderBy() {
return orderBy(new Comparator<T>() {
@SuppressWarnings("unchecked")
public int compare(T o1, T o2) {
Comparable<T> lhs = (Comparable<T>) o1;
return lhs.compareTo(o2);
}
});
}
public String join(String separator) {
StringBuilder rt = new StringBuilder();
boolean isFirst = true;
for (T value : this) {
if (isFirst) {
isFirst = false;
} else {
rt.append(separator);
}
rt.append(value == null ? "" : value.toString());
}
return rt.toString();
}
@SuppressWarnings("unchecked")
public static <T> Enumerable<T> empty(Class<T> clazz) {
return Enumerable.create();
}
public static Enumerable<Integer> range(final int start, final int count) {
return createFromIterator(new Func<Iterator<Integer>>() {
public Iterator<Integer> apply() {
return new RangeIterator(start, count);
}
});
}
private static class RangeIterator extends ReadOnlyIterator<Integer> {
private final int end;
private Integer current;
public RangeIterator(int start, int count) {
current = start;
end = start + count - 1;
}
@Override
protected IterationResult<Integer> advance() throws Exception {
if (current == null) {
return IterationResult.done();
}
int rt = current;
if (rt == end) {
current = null;
} else {
current = rt + 1;
}
return IterationResult.next(rt);
}
}
public <TOutput> Enumerable<TOutput> cast(Class<TOutput> clazz) {
return this.select(new Func1<T, TOutput>() {
@SuppressWarnings("unchecked")
public TOutput apply(T input) {
return (TOutput) input;
}
});
}
public <TOutput> Enumerable<TOutput> ofType(Class<TOutput> clazz) {
final Class<TOutput> finalClazz = clazz;
return this.where(new Predicate1<T>() {
public boolean apply(T input) {
return input != null && finalClazz.isAssignableFrom(input.getClass());
}
}).cast(clazz);
}
public Enumerable<T> skip(int count) {
return Enumerable.create(new SkipEnumerable<T>(this, count));
}
private static class SkipEnumerable<T> implements Iterable<T> {
private final Enumerable<T> target;
private final int count;
public SkipEnumerable(Enumerable<T> target, int count) {
this.target = target;
this.count = count;
}
public Iterator<T> iterator() {
Iterator<T> rt = target.iterator();
for (int i = 0; i < count; i++) {
if (!rt.hasNext()) {
return rt;
}
rt.next();
}
return rt;
}
}
public Enumerable<T> skipWhile(final Predicate1<T> predicate) {
final Boolean[] skipping = new Boolean[] { true };
return this.where(new Predicate1<T>() {
public boolean apply(T input) {
if (!skipping[0]) {
return true;
}
if (!predicate.apply(input)) {
skipping[0] = false;
return true;
}
return false;
}
});
}
@SuppressWarnings("unchecked")
public Enumerable<T> intersect(Enumerable<T> other) {
return intersect(new Enumerable[] { other });
}
public Enumerable<T> intersect(Enumerable<T>... others) {
List<T> rt = this.distinct().toList();
for (Enumerable<T> other : others) {
Set<T> set = other.toSet();
for (T value : Enumerable.create(rt).toList()) {
if (!set.contains(value)) {
rt.remove(value);
}
}
}
return Enumerable.create(rt);
}
@SuppressWarnings("unchecked")
public Enumerable<T> union(Enumerable<T> other) {
return union(new Enumerable[] { other });
}
public Enumerable<T> union(Enumerable<T>... others) {
final List<Iterable<T>> rt = thisThenOthers(others);
return Enumerable.create(makeIterable(new Func<Iterator<T>>() {
public Iterator<T> apply() {
return new UnionIterator<T>(rt);
}
}));
}
private static class UnionIterator<T> extends ReadOnlyIterator<T> {
private final List<Iterable<T>> involved;
public UnionIterator(List<Iterable<T>> involved) {
this.involved = involved;
}
private Set<T> seen;
private int currentIndex = -1;
private Iterator<T> currentIterator;
@Override
protected IterationResult<T> advance() {
if (seen == null) {
seen = new HashSet<T>();
}
while (true) {
if (currentIterator == null) {
currentIndex++;
if (currentIndex >= involved.size()) {
return IterationResult.done();
}
currentIterator = involved.get(currentIndex).iterator();
}
if (!currentIterator.hasNext()) {
currentIterator = null;
} else {
T value = currentIterator.next();
if (!seen.contains(value)) {
seen.add(value);
return IterationResult.next(value);
}
}
}
}
}
private static <T> Iterable<T> makeIterable(final Func<Iterator<T>> fn) {
return new Iterable<T>() {
public Iterator<T> iterator() {
return fn.apply();
}
};
}
public <TResult> Enumerable<TResult> selectMany(final Func1<T, Enumerable<TResult>> selector) {
return Enumerable.createFromIterator(new Func<Iterator<TResult>>() {
public Iterator<TResult> apply() {
return new SelectManyIterator<T, TResult>(Enumerable.this, selector);
}
});
}
private static class SelectManyIterator<TSource, TResult> extends ReadOnlyIterator<TResult> {
private final Iterator<TSource> sourceIterator;
private final Func1<TSource, Enumerable<TResult>> selector;
private Iterator<TResult> resultIterator;
public SelectManyIterator(Iterable<TSource> source, Func1<TSource, Enumerable<TResult>> selector) {
this.selector = selector;
this.sourceIterator = source.iterator();
}
@Override
protected IterationResult<TResult> advance() throws Exception {
while (true) {
if (resultIterator == null) {
if (!sourceIterator.hasNext()) {
return IterationResult.done();
}
TSource source = sourceIterator.next();
resultIterator = selector.apply(source).iterator();
}
if (!resultIterator.hasNext()) {
resultIterator = null;
} else {
return IterationResult.next(resultIterator.next());
}
}
}
}
public Enumerable<T> distinct() {
return Enumerable.createFromIterator(new Func<Iterator<T>>() {
public Iterator<T> apply() {
return new DistinctIterator<T>(Enumerable.this);
}
});
}
private static class DistinctIterator<T> extends ReadOnlyIterator<T> {
private final Iterator<T> iterator;
private Set<T> seen;
public DistinctIterator(Iterable<T> source) {
iterator = source.iterator();
}
@Override
protected IterationResult<T> advance() throws Exception {
if (seen == null) {
seen = new HashSet<T>();
}
while (iterator.hasNext()) {
T value = iterator.next();
if (!seen.contains(value)) {
seen.add(value);
return IterationResult.next(value);
}
}
return IterationResult.done();
}
}
public <TKey> Enumerable<Grouping<TKey, T>> groupBy(Func1<T, TKey> keySelector) {
List<TKey> ordering = new ArrayList<TKey>();
final Map<TKey, List<T>> map = new HashMap<TKey, List<T>>();
for (T value : this) {
TKey key = keySelector.apply(value);
if (!ordering.contains(key)) {
ordering.add(key);
map.put(key, new ArrayList<T>());
}
map.get(key).add(value);
}
return Enumerable.create(ordering).select(new Func1<TKey, Grouping<TKey, T>>() {
public Grouping<TKey, T> apply(TKey input) {
return new Grouping<TKey, T>(input, Enumerable.create(map.get(input)));
}
});
}
public <TResult extends Comparable<TResult>> TResult max(Func1<T, TResult> fn) {
TResult rt = null;
for (T value : this) {
TResult newValue = fn.apply(value);
if (newValue == null) {
continue;
}
if (rt == null) {
rt = newValue;
} else {
if (newValue.compareTo(rt) > 0) {
rt = newValue;
}
}
}
return rt;
}
public <TResult extends Comparable<TResult>> TResult min(Func1<T, TResult> fn) {
TResult rt = null;
for (T value : this) {
TResult newValue = fn.apply(value);
if (newValue == null) {
continue;
}
if (rt == null) {
rt = newValue;
} else {
if (newValue.compareTo(rt) < 0) {
rt = newValue;
}
}
}
return rt;
}
}