/* * Copyright (C) 2010 The Guava Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.google.common.collect.testing; import java.io.Serializable; import java.util.Collection; import java.util.Comparator; import java.util.Iterator; import java.util.NavigableSet; import java.util.SortedSet; import java.util.TreeSet; /** * A wrapper around {@code TreeSet} that aggressively checks to see if elements * are mutually comparable. This implementation passes the navigable set test * suites. * * @author Louis Wasserman */ public final class SafeTreeSet<E> implements Serializable, NavigableSet<E> { @SuppressWarnings("unchecked") private static final Comparator NATURAL_ORDER = new Comparator<Comparable>() { @Override public int compare(Comparable o1, Comparable o2) { return o1.compareTo(o2); } }; private final NavigableSet<E> delegate; public SafeTreeSet() { this(new TreeSet<E>()); } public SafeTreeSet(Collection<? extends E> collection) { this(new TreeSet<E>(collection)); } public SafeTreeSet(Comparator<? super E> comparator) { this(new TreeSet<E>(comparator)); } public SafeTreeSet(SortedSet<E> set) { this(new TreeSet<E>(set)); } private SafeTreeSet(NavigableSet<E> delegate) { this.delegate = delegate; for (E e : this) { checkValid(e); } } @Override public boolean add(E element) { return delegate.add(checkValid(element)); } @Override public boolean addAll(Collection<? extends E> collection) { for (E e : collection) { checkValid(e); } return delegate.addAll(collection); } @Override public E ceiling(E e) { return delegate.ceiling(checkValid(e)); } @Override public void clear() { delegate.clear(); } @SuppressWarnings("unchecked") @Override public Comparator<? super E> comparator() { Comparator<? super E> comparator = delegate.comparator(); if (comparator == null) { comparator = NATURAL_ORDER; } return comparator; } @Override public boolean contains(Object object) { return delegate.contains(checkValid(object)); } @Override public boolean containsAll(Collection<?> c) { return delegate.containsAll(c); } @Override public Iterator<E> descendingIterator() { return delegate.descendingIterator(); } @Override public NavigableSet<E> descendingSet() { return new SafeTreeSet<E>(delegate.descendingSet()); } @Override public E first() { return delegate.first(); } @Override public E floor(E e) { return delegate.floor(checkValid(e)); } @Override public SortedSet<E> headSet(E toElement) { return headSet(toElement, false); } @Override public NavigableSet<E> headSet(E toElement, boolean inclusive) { return new SafeTreeSet<E>( delegate.headSet(checkValid(toElement), inclusive)); } @Override public E higher(E e) { return delegate.higher(checkValid(e)); } @Override public boolean isEmpty() { return delegate.isEmpty(); } @Override public Iterator<E> iterator() { return delegate.iterator(); } @Override public E last() { return delegate.last(); } @Override public E lower(E e) { return delegate.lower(checkValid(e)); } @Override public E pollFirst() { return delegate.pollFirst(); } @Override public E pollLast() { return delegate.pollLast(); } @Override public boolean remove(Object object) { return delegate.remove(checkValid(object)); } @Override public boolean removeAll(Collection<?> c) { return delegate.removeAll(c); } @Override public boolean retainAll(Collection<?> c) { return delegate.retainAll(c); } @Override public int size() { return delegate.size(); } @Override public NavigableSet<E> subSet( E fromElement, boolean fromInclusive, E toElement, boolean toInclusive) { return new SafeTreeSet<E>( delegate.subSet(checkValid(fromElement), fromInclusive, checkValid(toElement), toInclusive)); } @Override public SortedSet<E> subSet(E fromElement, E toElement) { return subSet(fromElement, true, toElement, false); } @Override public SortedSet<E> tailSet(E fromElement) { return tailSet(fromElement, true); } @Override public NavigableSet<E> tailSet(E fromElement, boolean inclusive) { return delegate.tailSet(checkValid(fromElement), inclusive); } @Override public Object[] toArray() { return delegate.toArray(); } @Override public <T> T[] toArray(T[] a) { return delegate.toArray(a); } private <T> T checkValid(T t) { // a ClassCastException is what's supposed to happen! @SuppressWarnings("unchecked") E e = (E) t; comparator().compare(e, e); return t; } @Override public boolean equals(Object obj) { return delegate.equals(obj); } @Override public int hashCode() { return delegate.hashCode(); } @Override public String toString() { return delegate.toString(); } private static final long serialVersionUID = 0L; }