/* * Copyright 2015 Goldman Sachs. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.gs.collections.impl.bag.sorted.mutable; import java.io.Externalizable; import java.io.IOException; import java.io.ObjectInput; import java.io.ObjectOutput; import java.util.Arrays; import java.util.Comparator; import java.util.Iterator; import java.util.Map; import java.util.Set; import com.gs.collections.api.RichIterable; import com.gs.collections.api.bag.Bag; import com.gs.collections.api.bag.sorted.MutableSortedBag; import com.gs.collections.api.bag.sorted.SortedBag; import com.gs.collections.api.block.function.Function; import com.gs.collections.api.block.function.Function0; import com.gs.collections.api.block.predicate.Predicate; import com.gs.collections.api.block.predicate.Predicate2; import com.gs.collections.api.block.predicate.primitive.IntPredicate; import com.gs.collections.api.block.procedure.Procedure; import com.gs.collections.api.block.procedure.Procedure2; import com.gs.collections.api.block.procedure.primitive.ObjectIntProcedure; import com.gs.collections.api.map.sorted.MutableSortedMap; import com.gs.collections.api.ordered.OrderedIterable; import com.gs.collections.api.set.sorted.MutableSortedSet; import com.gs.collections.api.stack.MutableStack; import com.gs.collections.api.tuple.Pair; import com.gs.collections.impl.Counter; import com.gs.collections.impl.block.factory.Comparators; import com.gs.collections.impl.block.procedure.checked.CheckedProcedure2; import com.gs.collections.impl.map.sorted.mutable.TreeSortedMap; import com.gs.collections.impl.multimap.bag.sorted.mutable.TreeBagMultimap; import com.gs.collections.impl.set.sorted.mutable.TreeSortedSet; import com.gs.collections.impl.stack.mutable.ArrayStack; import com.gs.collections.impl.utility.Iterate; import com.gs.collections.impl.utility.ListIterate; import com.gs.collections.impl.utility.OrderedIterate; import com.gs.collections.impl.utility.internal.IterableIterate; import com.gs.collections.impl.utility.internal.SortedBagIterables; /** * A TreeBag is a MutableSortedBag which uses a SortedMap as its underlying data store. Each key in the SortedMap represents some item, * and the value in the map represents the current number of occurrences of that item. * * @since 4.2 */ public class TreeBag<T> extends AbstractMutableSortedBag<T> implements Externalizable { private static final Function0<Counter> NEW_COUNTER_BLOCK = new Function0<Counter>() { public Counter value() { return new Counter(); } }; private static final long serialVersionUID = 1L; private MutableSortedMap<T, Counter> items; private int size; public TreeBag() { this.items = TreeSortedMap.newMap(); } private TreeBag(MutableSortedMap<T, Counter> map) { this.items = map; this.size = (int) map.valuesView().sumOfInt(Counter.TO_COUNT); } public TreeBag(Comparator<? super T> comparator) { this.items = TreeSortedMap.newMap(comparator); } public TreeBag(SortedBag<T> sortedBag) { this(sortedBag.comparator(), sortedBag); } public TreeBag(Comparator<? super T> comparator, Iterable<? extends T> iterable) { this(comparator); this.addAllIterable(iterable); } public static <E> TreeBag<E> newBag() { return new TreeBag<E>(); } public static <E> TreeBag<E> newBag(Comparator<? super E> comparator) { return new TreeBag<E>(comparator); } public static <E> TreeBag<E> newBag(Iterable<? extends E> source) { if (source instanceof SortedBag<?>) { return new TreeBag<E>((SortedBag<E>) source); } return Iterate.addAllTo(source, TreeBag.<E>newBag()); } public static <E> TreeBag<E> newBag(Comparator<? super E> comparator, Iterable<? extends E> iterable) { return new TreeBag<E>(comparator, iterable); } public static <E> TreeBag<E> newBagWith(E... elements) { //noinspection SSBasedInspection return TreeBag.newBag(Arrays.asList(elements)); } public static <E> TreeBag<E> newBagWith(Comparator<? super E> comparator, E... elements) { //noinspection SSBasedInspection return TreeBag.newBag(comparator, Arrays.asList(elements)); } @Override public TreeBag<T> clone() { return new TreeBag<T>(this); } @Override public boolean equals(Object other) { if (this == other) { return true; } if (!(other instanceof Bag)) { return false; } final Bag<?> bag = (Bag<?>) other; if (this.sizeDistinct() != bag.sizeDistinct()) { return false; } return this.items.keyValuesView().allSatisfy(new Predicate<Pair<T, Counter>>() { public boolean accept(Pair<T, Counter> each) { return bag.occurrencesOf(each.getOne()) == each.getTwo().getCount(); } }); } @Override public int hashCode() { final Counter counter = new Counter(); this.forEachWithOccurrences(new ObjectIntProcedure<T>() { public void value(T each, int count) { counter.add((each == null ? 0 : each.hashCode()) ^ count); } }); return counter.getCount(); } @Override protected RichIterable<T> getKeysView() { return this.items.keysView(); } public int sizeDistinct() { return this.items.size(); } public void forEachWithOccurrences(final ObjectIntProcedure<? super T> procedure) { this.items.forEachKeyValue(new Procedure2<T, Counter>() { public void value(T item, Counter count) { procedure.value(item, count.getCount()); } }); } public MutableSortedBag<T> selectByOccurrences(final IntPredicate predicate) { MutableSortedMap<T, Counter> map = this.items.select(new Predicate2<T, Counter>() { public boolean accept(T each, Counter occurrences) { return predicate.accept(occurrences.getCount()); } }); return new TreeBag<T>(map); } public int occurrencesOf(Object item) { Counter counter = this.items.get(item); return counter == null ? 0 : counter.getCount(); } @Override public boolean isEmpty() { return this.items.isEmpty(); } public boolean remove(Object item) { Counter counter = this.items.get(item); if (counter != null) { if (counter.getCount() > 1) { counter.decrement(); } else { this.items.remove(item); } this.size--; return true; } return false; } public void clear() { this.items.clear(); this.size = 0; } @Override public boolean contains(Object o) { return this.items.containsKey(o); } public int compareTo(SortedBag<T> otherBag) { return SortedBagIterables.compare(this, otherBag); } public void writeExternal(final ObjectOutput out) throws IOException { out.writeObject(this.comparator()); out.writeInt(this.items.size()); try { this.items.forEachKeyValue(new CheckedProcedure2<T, Counter>() { public void safeValue(T object, Counter parameter) throws Exception { out.writeObject(object); out.writeInt(parameter.getCount()); } }); } catch (RuntimeException e) { if (e.getCause() instanceof IOException) { throw (IOException) e.getCause(); } throw e; } } public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { this.items = new TreeSortedMap<T, Counter>((Comparator<T>) in.readObject()); int size = in.readInt(); for (int i = 0; i < size; i++) { this.addOccurrences((T) in.readObject(), in.readInt()); } } public void each(final Procedure<? super T> procedure) { this.items.forEachKeyValue(new Procedure2<T, Counter>() { public void value(T key, Counter value) { for (int i = 0; i < value.getCount(); i++) { procedure.value(key); } } }); } @Override public void forEachWithIndex(final ObjectIntProcedure<? super T> objectIntProcedure) { final Counter index = new Counter(); this.items.forEachKeyValue(new Procedure2<T, Counter>() { public void value(T key, Counter value) { for (int i = 0; i < value.getCount(); i++) { objectIntProcedure.value(key, index.getCount()); index.increment(); } } }); } public void forEach(int fromIndex, int toIndex, Procedure<? super T> procedure) { ListIterate.rangeCheck(fromIndex, toIndex, this.size); if (fromIndex > toIndex) { throw new IllegalArgumentException("fromIndex must not be greater than toIndex"); } Iterator<Map.Entry<T, Counter>> iterator = this.items.entrySet().iterator(); int i = 0; while (iterator.hasNext() && i < fromIndex) { Map.Entry<T, Counter> entry = iterator.next(); Counter value = entry.getValue(); int count = value.getCount(); if (i + count < fromIndex) { i += count; } else { for (int j = 0; j < count; j++) { if (i >= fromIndex && i <= toIndex) { procedure.value(entry.getKey()); } i++; } } } while (iterator.hasNext() && i <= toIndex) { Map.Entry<T, Counter> entry = iterator.next(); Counter value = entry.getValue(); int count = value.getCount(); for (int j = 0; j < count; j++) { if (i <= toIndex) { procedure.value(entry.getKey()); } i++; } } } public void forEachWithIndex(int fromIndex, int toIndex, ObjectIntProcedure<? super T> objectIntProcedure) { ListIterate.rangeCheck(fromIndex, toIndex, this.size); if (fromIndex > toIndex) { throw new IllegalArgumentException("fromIndex must not be greater than toIndex"); } Iterator<Map.Entry<T, Counter>> iterator = this.items.entrySet().iterator(); int i = 0; while (iterator.hasNext() && i < fromIndex) { Map.Entry<T, Counter> entry = iterator.next(); Counter value = entry.getValue(); int count = value.getCount(); if (i + count < fromIndex) { i += count; } else { for (int j = 0; j < count; j++) { if (i >= fromIndex && i <= toIndex) { objectIntProcedure.value(entry.getKey(), i); } i++; } } } while (iterator.hasNext() && i <= toIndex) { Map.Entry<T, Counter> entry = iterator.next(); Counter value = entry.getValue(); int count = value.getCount(); for (int j = 0; j < count; j++) { if (i <= toIndex) { objectIntProcedure.value(entry.getKey(), i); } i++; } } } @Override public <P> void forEachWith(final Procedure2<? super T, ? super P> procedure, final P parameter) { this.items.forEachKeyValue(new Procedure2<T, Counter>() { public void value(T key, Counter value) { for (int i = 0; i < value.getCount(); i++) { procedure.value(key, parameter); } } }); } public Iterator<T> iterator() { return new InternalIterator(); } public void addOccurrences(T item, int occurrences) { if (occurrences < 0) { throw new IllegalArgumentException("Cannot add a negative number of occurrences"); } if (occurrences > 0) { this.items.getIfAbsentPut(item, NEW_COUNTER_BLOCK).add(occurrences); this.size += occurrences; } } public boolean removeOccurrences(Object item, int occurrences) { if (occurrences < 0) { throw new IllegalArgumentException("Cannot remove a negative number of occurrences"); } if (occurrences == 0) { return false; } Counter counter = this.items.get(item); if (counter == null) { return false; } int startCount = counter.getCount(); if (occurrences >= startCount) { this.items.remove(item); this.size -= startCount; return true; } counter.add(occurrences * -1); this.size -= occurrences; return true; } public boolean setOccurrences(T item, int occurrences) { if (occurrences < 0) { throw new IllegalArgumentException("Cannot set a negative number of occurrences"); } int originalOccurrences = this.occurrencesOf(item); if (originalOccurrences == occurrences) { return false; } if (occurrences == 0) { this.items.remove(item); } else { this.items.put(item, new Counter(occurrences)); } this.size -= originalOccurrences - occurrences; return true; } public TreeBag<T> without(T element) { this.remove(element); return this; } public TreeBag<T> withAll(Iterable<? extends T> iterable) { this.addAllIterable(iterable); return this; } public TreeBag<T> withoutAll(Iterable<? extends T> iterable) { this.removeAllIterable(iterable); return this; } public TreeBag<T> with(T element) { this.add(element); return this; } public MutableSortedBag<T> newEmpty() { return TreeBag.newBag(this.items.comparator()); } public boolean removeIf(Predicate<? super T> predicate) { boolean changed = false; Set<Map.Entry<T, Counter>> entries = this.items.entrySet(); for (Iterator<Map.Entry<T, Counter>> iterator = entries.iterator(); iterator.hasNext(); ) { Map.Entry<T, Counter> entry = iterator.next(); if (predicate.accept(entry.getKey())) { this.size -= entry.getValue().getCount(); iterator.remove(); changed = true; } } return changed; } public <P> boolean removeIfWith(Predicate2<? super T, ? super P> predicate, P parameter) { boolean changed = false; Set<Map.Entry<T, Counter>> entries = this.items.entrySet(); for (Iterator<Map.Entry<T, Counter>> iterator = entries.iterator(); iterator.hasNext(); ) { Map.Entry<T, Counter> entry = iterator.next(); if (predicate.accept(entry.getKey(), parameter)) { this.size -= entry.getValue().getCount(); iterator.remove(); changed = true; } } return changed; } public boolean removeAllIterable(Iterable<?> iterable) { int oldSize = this.size; for (Object each : iterable) { Counter removed = this.items.remove(each); if (removed != null) { this.size -= removed.getCount(); } } return this.size != oldSize; } public int size() { return this.size; } public int indexOf(Object object) { if (this.items.containsKey(object)) { long result = this.items.headMap((T) object).values().sumOfInt(Counter.TO_COUNT); if (result > Integer.MAX_VALUE) { throw new IllegalStateException(); } return (int) result; } return -1; } public MutableSortedSet<Pair<T, Integer>> zipWithIndex() { final Comparator<? super T> comparator = this.items.comparator(); return this.zipWithIndex(TreeSortedSet.newSet(new Comparator<Pair<T, Integer>>() { public int compare(Pair<T, Integer> o1, Pair<T, Integer> o2) { int compare = comparator == null ? Comparators.nullSafeCompare(o1, o2) : comparator.compare(o1.getOne(), o2.getOne()); if (compare != 0) { return compare; } return o1.getTwo().compareTo(o2.getTwo()); } })); } public MutableSortedSet<T> distinct() { return TreeSortedSet.newSet(this.comparator(), this.items.keySet()); } public <V> TreeBagMultimap<V, T> groupBy(Function<? super T, ? extends V> function) { return this.groupBy(function, TreeBagMultimap.<V, T>newMultimap(this.comparator())); } public <V> TreeBagMultimap<V, T> groupByEach(Function<? super T, ? extends Iterable<V>> function) { return this.groupByEach(function, TreeBagMultimap.<V, T>newMultimap(this.comparator())); } public int detectIndex(Predicate<? super T> predicate) { return Iterate.detectIndex(this, predicate); } public <S> boolean corresponds(OrderedIterable<S> other, Predicate2<? super T, ? super S> predicate) { return OrderedIterate.corresponds(this, other, predicate); } public MutableStack<T> toStack() { return ArrayStack.newStack(this); } public MutableSortedBag<T> take(int count) { if (count < 0) { throw new IllegalArgumentException("Count must be greater than zero, but was: " + count); } return IterableIterate.take(this, Math.min(this.size(), count), this.newEmpty()); } public MutableSortedBag<T> drop(int count) { if (count < 0) { throw new IllegalArgumentException("Count must be greater than zero, but was: " + count); } return IterableIterate.drop(this, count, this.newEmpty()); } public Comparator<? super T> comparator() { return this.items.comparator(); } public TreeBag<T> with(T... elements) { this.addAll(Arrays.asList(elements)); return this; } public TreeBag<T> with(T element1, T element2) { this.add(element1); this.add(element2); return this; } public boolean add(T item) { Counter counter = this.items.getIfAbsentPut(item, NEW_COUNTER_BLOCK); counter.increment(); this.size++; return true; } public TreeBag<T> with(T element1, T element2, T element3) { this.add(element1); this.add(element2); this.add(element3); return this; } private class InternalIterator implements Iterator<T> { private final Iterator<T> iterator = TreeBag.this.items.keySet().iterator(); private T currentItem; private int occurrences; private boolean canRemove; public boolean hasNext() { return this.occurrences > 0 || this.iterator.hasNext(); } public T next() { if (this.occurrences == 0) { this.currentItem = this.iterator.next(); this.occurrences = TreeBag.this.occurrencesOf(this.currentItem); } this.occurrences--; this.canRemove = true; return this.currentItem; } public void remove() { if (!this.canRemove) { throw new IllegalStateException(); } if (this.occurrences == 0) { this.iterator.remove(); TreeBag.this.size--; } else { TreeBag.this.remove(this.currentItem); } this.canRemove = false; } } }