/* * Copyright (C) 2010 The Guava Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.google.common.collect.testing; import java.io.Serializable; import java.util.Collection; import java.util.Comparator; import java.util.Map; import java.util.NavigableMap; import java.util.NavigableSet; import java.util.Set; import java.util.SortedMap; import java.util.TreeMap; /** * A wrapper around {@code TreeMap} that aggressively checks to see if keys are * mutually comparable. This implementation passes the navigable map test * suites. * * @author Louis Wasserman */ public final class SafeTreeMap<K, V> implements Serializable, NavigableMap<K, V> { @SuppressWarnings("unchecked") private static final Comparator NATURAL_ORDER = new Comparator<Comparable>() { @Override public int compare(Comparable o1, Comparable o2) { return o1.compareTo(o2); } }; private final NavigableMap<K, V> delegate; public SafeTreeMap() { this(new TreeMap<K, V>()); } public SafeTreeMap(Comparator<? super K> comparator) { this(new TreeMap<K, V>(comparator)); } public SafeTreeMap(Map<? extends K, ? extends V> map) { this(new TreeMap<K, V>(map)); } public SafeTreeMap(SortedMap<K, ? extends V> map) { this(new TreeMap<K, V>(map)); } private SafeTreeMap(NavigableMap<K, V> delegate) { this.delegate = delegate; if (delegate == null) { throw new NullPointerException(); } for (K k : keySet()) { checkValid(k); } } @Override public Entry<K, V> ceilingEntry(K key) { return delegate.ceilingEntry(checkValid(key)); } @Override public K ceilingKey(K key) { return delegate.ceilingKey(checkValid(key)); } @Override public void clear() { delegate.clear(); } @SuppressWarnings("unchecked") @Override public Comparator<? super K> comparator() { Comparator<? super K> comparator = delegate.comparator(); if (comparator == null) { comparator = NATURAL_ORDER; } return comparator; } @Override public boolean containsKey(Object key) { try { return delegate.containsKey(checkValid(key)); } catch (NullPointerException e) { return false; } catch (ClassCastException e) { return false; } } @Override public boolean containsValue(Object value) { return delegate.containsValue(value); } @Override public NavigableSet<K> descendingKeySet() { return delegate.descendingKeySet(); } @Override public NavigableMap<K, V> descendingMap() { return new SafeTreeMap<K, V>(delegate.descendingMap()); } @Override public Set<Entry<K, V>> entrySet() { return delegate.entrySet(); } @Override public Entry<K, V> firstEntry() { return delegate.firstEntry(); } @Override public K firstKey() { return delegate.firstKey(); } @Override public Entry<K, V> floorEntry(K key) { return delegate.floorEntry(checkValid(key)); } @Override public K floorKey(K key) { return delegate.floorKey(checkValid(key)); } @Override public V get(Object key) { return delegate.get(checkValid(key)); } @Override public SortedMap<K, V> headMap(K toKey) { return headMap(toKey, false); } @Override public NavigableMap<K, V> headMap(K toKey, boolean inclusive) { return new SafeTreeMap<K, V>( delegate.headMap(checkValid(toKey), inclusive)); } @Override public Entry<K, V> higherEntry(K key) { return delegate.higherEntry(checkValid(key)); } @Override public K higherKey(K key) { return delegate.higherKey(checkValid(key)); } @Override public boolean isEmpty() { return delegate.isEmpty(); } @Override public NavigableSet<K> keySet() { return navigableKeySet(); } @Override public Entry<K, V> lastEntry() { return delegate.lastEntry(); } @Override public K lastKey() { return delegate.lastKey(); } @Override public Entry<K, V> lowerEntry(K key) { return delegate.lowerEntry(checkValid(key)); } @Override public K lowerKey(K key) { return delegate.lowerKey(checkValid(key)); } @Override public NavigableSet<K> navigableKeySet() { return delegate.navigableKeySet(); } @Override public Entry<K, V> pollFirstEntry() { return delegate.pollFirstEntry(); } @Override public Entry<K, V> pollLastEntry() { return delegate.pollLastEntry(); } @Override public V put(K key, V value) { return delegate.put(checkValid(key), value); } @Override public void putAll(Map<? extends K, ? extends V> map) { for (K key : map.keySet()) { checkValid(key); } delegate.putAll(map); } @Override public V remove(Object key) { return delegate.remove(checkValid(key)); } @Override public int size() { return delegate.size(); } @Override public NavigableMap<K, V> subMap( K fromKey, boolean fromInclusive, K toKey, boolean toInclusive) { return new SafeTreeMap<K, V>(delegate.subMap( checkValid(fromKey), fromInclusive, checkValid(toKey), toInclusive)); } @Override public SortedMap<K, V> subMap(K fromKey, K toKey) { return subMap(fromKey, true, toKey, false); } @Override public SortedMap<K, V> tailMap(K fromKey) { return tailMap(fromKey, true); } @Override public NavigableMap<K, V> tailMap(K fromKey, boolean inclusive) { return new SafeTreeMap<K, V>( delegate.tailMap(checkValid(fromKey), inclusive)); } @Override public Collection<V> values() { return delegate.values(); } private <T> T checkValid(T t) { // a ClassCastException is what's supposed to happen! @SuppressWarnings("unchecked") K k = (K) t; comparator().compare(k, k); return t; } @Override public boolean equals(Object obj) { return delegate.equals(obj); } @Override public int hashCode() { return delegate.hashCode(); } @Override public String toString() { return delegate.toString(); } private static final long serialVersionUID = 0L; }