package edu.stanford.nlp.util.concurrent;
import java.io.Serializable;
import java.util.AbstractCollection;
import java.util.AbstractSet;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import edu.stanford.nlp.math.SloppyMath;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.util.Factory;
import edu.stanford.nlp.util.logging.PrettyLogger;
import edu.stanford.nlp.util.logging.Redwood.RedwoodChannels;
/**
 * A thread-safe {@link Counter} implemented as a lightweight wrapper around a
 * {@link ConcurrentHashMap} from keys to {@link AtomicDouble} counts.
 *
 * <p>Per-key updates are performed with compare-and-set on the key's
 * {@code AtomicDouble}. The running total is kept in a separate
 * {@code AtomicDouble} that is adjusted only after the per-key update has
 * succeeded, so {@link #totalCount()} can briefly lag behind the individual
 * counts while other threads are writing.
 *
 * <p>A stored count of exactly {@code 0.0} is also the value that
 * {@code remove} writes into a holder before detaching it from the map.
 * The update methods therefore never compare-and-set a holder that reads
 * {@code 0.0}; they swap in a fresh holder instead.
 *
 * @author Spence Green
 *
 * @param <E> the type of the keys being counted
 */
public class ConcurrentHashCounter<E> implements Serializable, Counter<E>, Iterable<E> {

  private static final long serialVersionUID = -8077192206562696111L;

  /** Initial capacity of the backing map used by the no-argument constructor. */
  private static final int DEFAULT_CAPACITY = 100;

  /** Backing store: one mutable holder per key. */
  private final ConcurrentMap<E, AtomicDouble> map;

  /** Running sum of all per-key counts. */
  private final AtomicDouble totalCount;

  /** Value reported for absent keys; volatile so a change is visible to all threads. */
  private volatile double defaultReturnValue = 0.0;

  /** Creates an empty counter with the default initial capacity. */
  public ConcurrentHashCounter() {
    this(DEFAULT_CAPACITY);
  }

  /**
   * Creates an empty counter.
   *
   * @param initialCapacity initial capacity passed to the backing {@link ConcurrentHashMap}
   */
  public ConcurrentHashCounter(int initialCapacity) {
    map = new ConcurrentHashMap<>(initialCapacity);
    totalCount = new AtomicDouble();
  }

  /**
   * Iterates over the keys of this counter.
   *
   * @return an iterator over the (unmodifiable) key set
   */
  @Override
  public Iterator<E> iterator() {
    return keySet().iterator();
  }

  /**
   * Returns a factory that creates new, empty {@code ConcurrentHashCounter}s.
   *
   * @return a factory for counters of this type
   */
  @Override
  public Factory<Counter<E>> getFactory() {
    return new Factory<Counter<E>>() {
      private static final long serialVersionUID = 6076144467752914760L;

      @Override
      public Counter<E> create() {
        return new ConcurrentHashCounter<>();
      }
    };
  }

  /**
   * Sets the value returned by {@link #getCount(Object)} and {@code remove} for absent keys.
   *
   * @param value the new default return value
   */
  @Override
  public void setDefaultReturnValue(double value) {
    defaultReturnValue = value;
  }

  /**
   * @return the value reported for keys that are not in this counter
   */
  @Override
  public double defaultReturnValue() {
    return defaultReturnValue;
  }

  /**
   * Returns the current count of {@code key}.
   *
   * @param key the key to look up
   * @return the stored count, or the default return value if the key is absent
   */
  @Override
  public double getCount(Object key) {
    final AtomicDouble holder = map.get(key);
    return holder == null ? defaultReturnValue : holder.get();
  }

  /**
   * Sets the count of {@code key} to {@code value} and adjusts the total by the difference.
   *
   * @param key the key to update
   * @param value the new count
   */
  @Override
  public void setCount(E key, double value) {
    // Retry structure inspired by Guava's AtomicLongMap.
    outer: for (;;) {
      AtomicDouble atomic = map.get(key);
      if (atomic == null) {
        atomic = map.putIfAbsent(key, new AtomicDouble(value));
        if (atomic == null) {
          // We installed the first holder for this key.
          totalCount.addAndGet(value);
          return;
        }
      }
      for (;;) {
        final double oldValue = atomic.get();
        if (oldValue == 0.0) {
          // A zero holder may already be detached by remove(); never CAS it,
          // swap in a fresh holder instead and start over if that fails.
          if (map.replace(key, atomic, new AtomicDouble(value))) {
            totalCount.addAndGet(value);
            return;
          }
          continue outer;
        }
        if (atomic.compareAndSet(oldValue, value)) {
          totalCount.addAndGet(value - oldValue);
          return;
        }
      }
    }
  }

  /**
   * Adds {@code value} to the count of {@code key}.
   *
   * @param key the key to update
   * @param value the amount to add
   * @return the count of {@code key} after the update
   */
  @Override
  public double incrementCount(E key, double value) {
    // Retry structure inspired by Guava's AtomicLongMap.
    outer: for (;;) {
      AtomicDouble atomic = map.get(key);
      if (atomic == null) {
        atomic = map.putIfAbsent(key, new AtomicDouble(value));
        if (atomic == null) {
          totalCount.addAndGet(value);
          return value;
        }
      }
      for (;;) {
        final double oldValue = atomic.get();
        if (oldValue == 0.0) {
          // A zero holder may already be detached by remove(); never CAS it.
          if (map.replace(key, atomic, new AtomicDouble(value))) {
            totalCount.addAndGet(value);
            return value;
          }
          continue outer;
        }
        final double newValue = oldValue + value;
        if (atomic.compareAndSet(oldValue, newValue)) {
          totalCount.addAndGet(value);
          return newValue;
        }
      }
    }
  }

  /**
   * Adds {@code 1.0} to the count of {@code key}.
   *
   * @param key the key to update
   * @return the count of {@code key} after the update
   */
  @Override
  public double incrementCount(E key) {
    return incrementCount(key, 1.0);
  }

  /**
   * Subtracts {@code value} from the count of {@code key}.
   *
   * @param key the key to update
   * @param value the amount to subtract
   * @return the count of {@code key} after the update
   */
  @Override
  public double decrementCount(E key, double value) {
    return incrementCount(key, -value);
  }

  /**
   * Subtracts {@code 1.0} from the count of {@code key}.
   *
   * @param key the key to update
   * @return the count of {@code key} after the update
   */
  @Override
  public double decrementCount(E key) {
    return incrementCount(key, -1.0);
  }

  /**
   * Replaces the count of {@code key} with
   * {@code SloppyMath.logAdd(oldCount, value)}. If the key is absent, or its
   * holder currently reads {@code 0.0}, the count is simply set to {@code value}.
   *
   * @param key the key to update
   * @param value the amount to combine with the existing count
   * @return the count of {@code key} after the update
   */
  @Override
  public double logIncrementCount(E key, double value) {
    // Retry structure inspired by Guava's AtomicLongMap.
    outer: for (;;) {
      AtomicDouble atomic = map.get(key);
      if (atomic == null) {
        atomic = map.putIfAbsent(key, new AtomicDouble(value));
        if (atomic == null) {
          totalCount.addAndGet(value);
          return value;
        }
      }
      for (;;) {
        final double oldValue = atomic.get();
        if (oldValue == 0.0) {
          // A zero holder may already be detached by remove(); never CAS it.
          if (map.replace(key, atomic, new AtomicDouble(value))) {
            totalCount.addAndGet(value);
            return value;
          }
          continue outer;
        }
        final double newValue = SloppyMath.logAdd(oldValue, value);
        if (atomic.compareAndSet(oldValue, newValue)) {
          // The stored count moved from oldValue to newValue, so the total must
          // move by that same difference (not by the raw increment).
          totalCount.addAndGet(newValue - oldValue);
          return newValue;
        }
      }
    }
  }

  /**
   * Adds every count of {@code counter} to this counter via {@link Counters#addInPlace}.
   *
   * @param counter the counter whose counts are added
   */
  @Override
  public void addAll(Counter<E> counter) {
    Counters.addInPlace(this, counter);
  }

  /**
   * Removes {@code key} and subtracts its count from the total.
   *
   * @param key the key to remove
   * @return the count the key had, or the default return value if it was absent
   */
  @Override
  public double remove(E key) {
    final AtomicDouble atomic = map.get(key);
    if (atomic == null) {
      return defaultReturnValue;
    }
    for (;;) {
      final double oldValue = atomic.get();
      if (oldValue == 0.0 || atomic.compareAndSet(oldValue, 0.0)) {
        // Only detach after zeroing, so concurrent updaters see the zero marker.
        map.remove(key, atomic);
        // Succeed even if the map removal fails: the value was already zeroed.
        totalCount.addAndGet(-oldValue);
        return oldValue;
      }
    }
  }

  /**
   * @param key the key to test
   * @return {@code true} if the backing map contains {@code key}
   */
  @Override
  public boolean containsKey(E key) {
    return map.containsKey(key);
  }

  /**
   * @return an unmodifiable view of the keys of this counter
   */
  @Override
  public Set<E> keySet() {
    return Collections.unmodifiableSet(map.keySet());
  }

  /**
   * Returns a live view of the counts as boxed {@code Double}s.
   *
   * <p>The view supports queries ({@code contains}, {@code containsAll},
   * {@code toArray}, iteration) and removal through its iterator. The
   * collection-level mutators throw {@link UnsupportedOperationException}.
   *
   * @return a view of the counts stored in this counter
   */
  @Override
  public Collection<Double> values() {
    // AbstractCollection derives toArray/containsAll/toString from iterator(),
    // so callers get Double elements rather than the internal holders.
    return new AbstractCollection<Double>() {
      @Override
      public int size() {
        return map.size();
      }

      @Override
      public boolean isEmpty() {
        return map.isEmpty();
      }

      @Override
      public boolean contains(Object o) {
        if (o instanceof Double) {
          final double wanted = (Double) o;
          for (AtomicDouble holder : map.values()) {
            if (holder.get() == wanted) {
              return true;
            }
          }
        }
        return false;
      }

      @Override
      public Iterator<Double> iterator() {
        return new Iterator<Double>() {
          private final Iterator<Entry<E, AtomicDouble>> entries = map.entrySet().iterator();
          private E lastKey;
          private boolean canRemove;

          @Override
          public boolean hasNext() {
            return entries.hasNext();
          }

          @Override
          public Double next() {
            final Entry<E, AtomicDouble> entry = entries.next();
            lastKey = entry.getKey();
            canRemove = true;
            return entry.getValue().get();
          }

          @Override
          public void remove() {
            if (!canRemove) {
              throw new IllegalStateException();
            }
            canRemove = false;
            // Go through the counter so the running total is adjusted as well.
            ConcurrentHashCounter.this.remove(lastKey);
          }
        };
      }

      @Override
      public boolean add(Double e) {
        throw new UnsupportedOperationException();
      }

      @Override
      public boolean remove(Object o) {
        throw new UnsupportedOperationException();
      }

      @Override
      public boolean addAll(Collection<? extends Double> c) {
        throw new UnsupportedOperationException();
      }

      @Override
      public boolean removeAll(Collection<?> c) {
        throw new UnsupportedOperationException();
      }

      @Override
      public boolean retainAll(Collection<?> c) {
        throw new UnsupportedOperationException();
      }

      @Override
      public void clear() {
        throw new UnsupportedOperationException();
      }
    };
  }

  /**
   * Returns a live view of the (key, count) pairs. Entries support
   * {@code setValue}, which is routed through {@link #setCount}; removal
   * through the iterator is not supported.
   *
   * @return a set view of the entries of this counter
   */
  @Override
  public Set<Entry<E, Double>> entrySet() {
    return new AbstractSet<Map.Entry<E, Double>>() {
      @Override
      public Iterator<Entry<E, Double>> iterator() {
        return new Iterator<Entry<E, Double>>() {
          private final Iterator<Entry<E, AtomicDouble>> inner = map.entrySet().iterator();

          @Override
          public boolean hasNext() {
            return inner.hasNext();
          }

          @Override
          public Entry<E, Double> next() {
            final Entry<E, AtomicDouble> backing = inner.next();
            return new Entry<E, Double>() {
              @Override
              public E getKey() {
                return backing.getKey();
              }

              @Override
              public Double getValue() {
                // Prefer the holder currently mapped to the key: setCount may have
                // swapped in a new one. Fall back to the holder seen at iteration time.
                final AtomicDouble live = map.get(backing.getKey());
                return (live == null ? backing.getValue() : live).get();
              }

              @Override
              public Double setValue(Double value) {
                final double old = getValue();
                // setCount updates both the per-key holder and the running total;
                // writing to the holder directly would bypass the total.
                setCount(backing.getKey(), value);
                return old;
              }
            };
          }

          @Override
          public void remove() {
            throw new UnsupportedOperationException();
          }
        };
      }

      @Override
      public int size() {
        return map.size();
      }
    };
  }

  /**
   * Resets the total to zero and removes all keys. The two steps are not
   * performed as one atomic action.
   */
  @Override
  public void clear() {
    for (;;) {
      totalCount.set(0.0);
      // Retry if another thread bumped the total between the set and this check.
      if (totalCount.get() == 0.0) {
        map.clear();
        return;
      }
    }
  }

  /**
   * @return the number of keys in this counter
   */
  @Override
  public int size() {
    return map.size();
  }

  /**
   * @return the running total of all counts
   */
  @Override
  public double totalCount() {
    return totalCount.get();
  }

  /**
   * Compares this counter with another {@code ConcurrentHashCounter}. Two
   * counters are equal when their totals match and they map the same keys to
   * the same numeric counts.
   *
   * @param o the object to compare with
   * @return {@code true} if {@code o} is an equal {@code ConcurrentHashCounter}
   */
  @Override
  public boolean equals(Object o) {
    if (this == o) {
      return true;
    }
    if (!(o instanceof ConcurrentHashCounter)) {
      return false;
    }
    final ConcurrentHashCounter<?> other = (ConcurrentHashCounter<?>) o;
    if (totalCount.get() != other.totalCount.get() || map.size() != other.map.size()) {
      return false;
    }
    // Compare the numeric values explicitly instead of relying on equals() of
    // the mutable holder objects.
    for (Entry<E, AtomicDouble> entry : map.entrySet()) {
      final AtomicDouble theirs = other.map.get(entry.getKey());
      if (theirs == null || Double.compare(entry.getValue().get(), theirs.get()) != 0) {
        return false;
      }
    }
    return true;
  }

  /**
   * Returns a hash code computed from the keys and their numeric counts, so
   * that it is consistent with {@link #equals(Object)}.
   *
   * @return A hashCode.
   */
  @Override
  public int hashCode() {
    int hash = 0;
    for (Entry<E, AtomicDouble> entry : map.entrySet()) {
      // Order-independent sum, in the manner of Map.hashCode().
      hash += entry.getKey().hashCode() ^ Double.valueOf(entry.getValue().get()).hashCode();
    }
    return hash;
  }

  /**
   * Returns a String representation of the Counter, as formatted by
   * the underlying Map.
   *
   * @return A String representation of the Counter.
   */
  @Override
  public String toString() {
    return map.toString();
  }

  /**
   * Logs the underlying map through {@link PrettyLogger}.
   *
   * @param channels the channels to log to
   * @param description a description passed along to the logger
   */
  @Override
  public void prettyLog(RedwoodChannels channels, String description) {
    PrettyLogger.log(channels, description, map);
  }
}