package uk.ac.rhul.cs.collections;

import uk.ac.rhul.cs.utils.StringUtils;

import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

/**
 * Tree-based multiset class.
 *
 * <p>Each distinct element is stored as a key of a {@link TreeMap} that maps it
 * to its number of occurrences. The map is created without a comparator, so
 * elements are ordered by (and must support) their natural ordering.
 *
 * @author tamas
 *
 * @param <E> the types of elements to be stored in this multiset
 */
public class TreeMultiset<E> implements Multiset<E> {

    /**
     * Snapshot of one distinct element and its count, taken when the entry
     * was created. Later changes to the multiset are not reflected here.
     */
    public class Entry implements Multiset.Entry<E> {
        E object;
        int count;

        private Entry(Map.Entry<E, Integer> entry) {
            this.object = entry.getKey();
            this.count = entry.getValue();
        }

        /** Returns the number of occurrences captured in this entry. */
        public int getCount() {
            return count;
        }

        /** Returns the element captured in this entry. */
        public E getElement() {
            return object;
        }

        /**
         * Returns the element's string form; when the count is not 1 it is
         * followed by {@code " x "} and the count.
         */
        @Override
        public String toString() {
            if (count == 1)
                return object.toString();
            StringBuilder sb = new StringBuilder();
            sb.append(object.toString());
            sb.append(" x ");
            sb.append(count);
            return sb.toString();
        }
    }

    /**
     * Internal storage area: maps each distinct element to its number of
     * occurrences. The methods of this class never store a zero count; an
     * element whose count drops to zero is removed from the map.
     */
    protected TreeMap<E, Integer> data;

    /** Creates an empty multiset ordered by the elements' natural ordering. */
    public TreeMultiset() {
        data = new TreeMap<E, Integer>();
    }

    /**
     * Adds the given number of occurrences of an element.
     *
     * @param element      the element to add
     * @param occurrences  how many occurrences to add; zero is a no-op
     * @return the count of the element before the operation
     * @throws IllegalArgumentException if {@code occurrences} is negative, or if
     *         the resulting count would exceed {@link Integer#MAX_VALUE}
     */
    public int add(E element, int occurrences) {
        if (occurrences < 0)
            throw new IllegalArgumentException("occurrences must not be negative");
        int count = this.count(element);
        if (occurrences > 0) {
            // Without this guard, count + occurrences could wrap around and
            // store a negative count, corrupting the multiset.
            if (occurrences > Integer.MAX_VALUE - count)
                throw new IllegalArgumentException("too many occurrences: "
                        + count + " + " + occurrences + " exceeds Integer.MAX_VALUE");
            data.put(element, count + occurrences);
        }
        return count;
    }

    /**
     * Returns the number of occurrences of an element.
     *
     * @param element the element to look up
     * @return the element's count, or 0 if it is not present
     */
    public int count(Object element) {
        Integer count = data.get(element);
        if (count == null)
            return 0;
        return count;
    }

    /**
     * Returns the distinct elements of this multiset in ascending order.
     * The returned set is the key set of the backing map, so it is a live
     * view: removing an element from it removes all of its occurrences.
     */
    public Set<E> elementSet() {
        return data.keySet();
    }

    /**
     * Returns a snapshot of (element, count) entries, one per distinct element.
     * Entries are returned in the ascending order of their elements.
     */
    public Set<Multiset.Entry<E>> entrySet() {
        // LinkedHashSet keeps the sorted iteration order of the TreeMap. A
        // HashSet of identity-hashed entries would scramble it, and with it
        // the output of toString().
        Set<Multiset.Entry<E>> result = new LinkedHashSet<Multiset.Entry<E>>();
        for (Map.Entry<E, Integer> entry: data.entrySet())
            result.add(new Entry(entry));
        return result;
    }

    /**
     * Removes up to the given number of occurrences of an element. If the
     * element has at most {@code occurrences} occurrences, it is removed
     * entirely.
     *
     * @param element      the element to remove
     * @param occurrences  how many occurrences to remove; zero is a no-op
     * @return the count of the element before the operation
     * @throws IllegalArgumentException if {@code occurrences} is negative
     */
    public int remove(E element, int occurrences) {
        if (occurrences < 0)
            throw new IllegalArgumentException("occurrences must not be negative");
        int count = this.count(element);
        if (occurrences > 0) {
            if (count <= occurrences)
                data.remove(element);
            else
                data.put(element, count - occurrences);
        }
        return count;
    }

    /**
     * Sets the count of an element; a count of zero removes the element.
     *
     * @param element the element whose count is set
     * @param count   the new count
     * @return the count of the element before the operation
     * @throws IllegalArgumentException if {@code count} is negative
     */
    public int setCount(E element, int count) {
        if (count < 0)
            throw new IllegalArgumentException("count must not be negative");
        int oldCount = this.count(element);
        if (count == 0)
            data.remove(element);
        else
            data.put(element, count);
        return oldCount;
    }

    /**
     * Adds a single occurrence of an element.
     *
     * @param element the element to add
     * @return always {@code true}: adding an occurrence always changes a
     *         multiset, as required by {@link Collection#add(Object)}
     * @throws IllegalArgumentException if the element's count would exceed
     *         {@link Integer#MAX_VALUE}
     */
    public boolean add(E element) {
        add(element, 1);
        return true;
    }

    /**
     * Adds one occurrence of every element of the given collection.
     *
     * @return {@code true} if the multiset changed
     */
    public boolean addAll(Collection<? extends E> elements) {
        boolean result = false;
        // Non-short-circuit '|' so that every element is added.
        for (E element: elements)
            result = result | this.add(element);
        return result;
    }

    /** Removes every occurrence of every element. */
    public void clear() {
        data.clear();
    }

    /** Returns whether at least one occurrence of the given element is present. */
    public boolean contains(Object key) {
        return data.containsKey(key);
    }

    /** Returns whether every element of the given collection is present. */
    public boolean containsAll(Collection<?> keys) {
        for (Object o: keys)
            if (!data.containsKey(o))
                return false;
        return true;
    }

    /** Returns whether the multiset contains no elements. */
    public boolean isEmpty() {
        return data.isEmpty();
    }

    /**
     * Returns an iterator over the distinct elements in ascending order.
     * Each element is returned once, regardless of its count. Removing
     * through the iterator removes all occurrences of that element.
     *
     * <p>NOTE(review): unlike {@link #size()} and {@link #toArray()}, this
     * iterator does not repeat elements with a count greater than one.
     * Kept as-is because callers may rely on distinct iteration.
     */
    public Iterator<E> iterator() {
        return data.keySet().iterator();
    }

    /**
     * Removes a single occurrence of an element.
     *
     * @return {@code true} if an occurrence was removed
     */
    @SuppressWarnings("unchecked")
    public boolean remove(Object element) {
        int count = this.count(element);
        if (count == 0)
            return false;
        if (count == 1) {
            data.remove(element);
            return true;
        }
        // Safe: count > 0 means element is already a key of the map, hence an E.
        data.put((E) element, count - 1);
        return true;
    }

    /**
     * Removes one occurrence for each element of the given collection
     * (duplicates in the argument remove multiple occurrences).
     *
     * <p>NOTE(review): this mirrors {@link #addAll(Collection)} rather than
     * the {@link Collection#removeAll(Collection)} contract, which would
     * remove all occurrences.
     *
     * @return {@code true} if the multiset changed
     */
    public boolean removeAll(Collection<?> elements) {
        boolean result = false;
        for (Object element: elements)
            result = result | this.remove(element);
        return result;
    }

    /**
     * Retains only the elements that are contained in the given collection.
     * Every occurrence of any other element is removed.
     *
     * @return {@code true} if the multiset changed
     */
    public boolean retainAll(Collection<?> elements) {
        // Removing a key from the backing map's key set drops all its occurrences.
        return data.keySet().retainAll(elements);
    }

    /**
     * Returns the total number of occurrences of all elements, saturated at
     * {@link Integer#MAX_VALUE} as required by {@link Collection#size()}.
     */
    public int size() {
        // Accumulate in a long so that large counts cannot wrap around.
        long totalSize = 0;
        for (Map.Entry<E, Integer> entry: data.entrySet())
            totalSize += entry.getValue();
        return (int) Math.min(totalSize, Integer.MAX_VALUE);
    }

    /**
     * Returns an array with every occurrence of every element, in ascending
     * order. Each element is repeated as many times as its count.
     */
    public Object[] toArray() {
        Object[] result = new Object[this.size()];
        int i = 0;
        for (Map.Entry<E, Integer> entry: data.entrySet()) {
            E obj = entry.getKey();
            int n = entry.getValue();
            for (int j = 0; j < n; j++, i++)
                result[i] = obj;
        }
        return result;
    }

    /**
     * Returns every occurrence of every element in an array of the runtime
     * type of {@code array}, following the {@link Collection#toArray(Object[])}
     * contract.
     *
     * <p>If {@code array} is large enough, it is filled and returned, and the
     * slot right after the last element is set to {@code null}. Otherwise a new
     * array of the same component type is allocated.
     *
     * @throws ArrayStoreException if an element is not assignable to the array's
     *         component type
     */
    @SuppressWarnings("unchecked")
    public <T> T[] toArray(T[] array) {
        Object[] elements = this.toArray();
        if (array.length < elements.length) {
            // Allocate an array of the caller's component type; a bare cast of
            // Object[] would fail with ClassCastException at the call site.
            return (T[]) Arrays.copyOf(elements, elements.length, array.getClass());
        }
        System.arraycopy(elements, 0, array, 0, elements.length);
        if (array.length > elements.length)
            array[elements.length] = null;
        return array;
    }

    /**
     * Returns the entries in ascending order, e.g. {@code [a x 2, b]}.
     */
    @Override
    public String toString() {
        return "[" + StringUtils.join(this.entrySet().iterator(), ", ") + "]";
    }
}