/*******************************************************************************
 * Copyright 2016 Observational Health Data Sciences and Informatics
 * 
 * This file is part of WhiteRabbit
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *     http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 ******************************************************************************/
package org.ohdsi.utilities.collections;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;

/**
 * Class for counting recurring objects.
 * <p>
 * Every distinct key is stored once together with a mutable {@link Count}. The
 * {@link Set} view of this class (size, iteration, contains, ...) covers the
 * distinct keys only; the number of times a key was added is available through
 * {@link #getCount(Object)}.
 * <p>
 * Note: {@code equals} and {@code hashCode} are not overridden, so instances
 * compare by identity rather than by content.
 * 
 * @author schuemie
 * @param <T>
 *            the type of the objects being counted
 */
public class CountingSet<T> implements Set<T> {

    /** Backing map from each distinct key to its running count. */
    public Map<T, Count> key2count;

    /**
     * Creates an empty counting set.
     */
    public CountingSet() {
        key2count = new HashMap<T, Count>();
    }

    /**
     * Creates an empty counting set whose backing map has the given initial
     * capacity.
     * 
     * @param capacity
     *            initial capacity passed to the backing {@link HashMap}
     */
    public CountingSet(int capacity) {
        key2count = new HashMap<T, Count>(capacity);
    }

    /**
     * Creates a copy of another counting set. The counts are copied into new
     * {@link Count} objects, so later changes to one set do not affect the
     * other.
     * 
     * @param set
     *            the counting set to copy
     */
    public CountingSet(CountingSet<T> set) {
        key2count = new HashMap<T, Count>(set.key2count.size());
        for (Map.Entry<T, Count> entry : set.key2count.entrySet()) {
            key2count.put(entry.getKey(), new Count(entry.getValue().count));
        }
    }

    /**
     * Returns the count stored for a key.
     * 
     * @param key
     *            the key to look up
     * @return the count of the key, or 0 if the key is not in the set
     */
    public int getCount(T key) {
        Count count = key2count.get(key);
        if (count == null) {
            return 0;
        }
        return count.count;
    }

    /**
     * Computes the sum of the counts.
     * 
     * @return the sum over all keys; accumulated in an {@code int}, so it wraps
     *         around if the total exceeds the int range
     */
    public int getSum() {
        int sum = 0;
        for (Count count : key2count.values()) {
            sum += count.count;
        }
        return sum;
    }

    /**
     * Returns the maximum count.
     * 
     * @return the largest count that is greater than 0, or 0 if there is none
     *         (including when the set is empty)
     */
    public int getMax() {
        int max = 0;
        for (Count count : key2count.values()) {
            if (count.count > max) {
                max = count.count;
            }
        }
        return max;
    }

    /**
     * Computes the mean of the counts.
     * 
     * @return the sum of the counts divided by the number of distinct keys;
     *         {@code NaN} when the set is empty
     */
    public double getMean() {
        return (getSum() / (double) key2count.size());
    }

    /**
     * Computes the standard deviation of the counts. The squared deviations are
     * divided by the number of distinct keys (not by that number minus one).
     * 
     * @return the standard deviation; {@code NaN} when the set is empty
     */
    public double getSD() {
        double mean = getMean();
        double sum = 0;
        for (Count count : key2count.values()) {
            sum += sqr(count.count - mean);
        }
        return Math.sqrt(sum / (double) key2count.size());
    }

    /** Returns the square of {@code d}. */
    private double sqr(double d) {
        return d * d;
    }

    /**
     * @return the number of distinct keys
     */
    @Override
    public int size() {
        return key2count.size();
    }

    /**
     * @return true if no key has been added
     */
    @Override
    public boolean isEmpty() {
        return key2count.isEmpty();
    }

    /**
     * @param arg0
     *            the key to look for
     * @return true if the key is present in the set
     */
    @Override
    public boolean contains(Object arg0) {
        return key2count.containsKey(arg0);
    }

    /**
     * @return an iterator over the distinct keys
     */
    @Override
    public Iterator<T> iterator() {
        return key2count.keySet().iterator();
    }

    /**
     * @return an array holding the distinct keys
     */
    @Override
    public Object[] toArray() {
        return key2count.keySet().toArray();
    }

    /**
     * Copies the distinct keys into an array of the runtime type of the given
     * array, following the contract of {@link Set#toArray(Object[])}.
     * 
     * @param arg0
     *            the array to fill, or whose runtime type is used to allocate a
     *            new array
     * @return an array holding the distinct keys
     */
    @Override
    public <E> E[] toArray(E[] arg0) {
        return key2count.keySet().toArray(arg0);
    }

    /**
     * Counts one occurrence of a key. A key that is not yet present is inserted
     * with count 1; otherwise its count is incremented by one.
     * 
     * @param arg0
     *            the key to count
     * @return true if the key was not yet present, false if only its count was
     *         incremented
     */
    @Override
    public boolean add(T arg0) {
        Count count = key2count.get(arg0);
        if (count == null) {
            // A fresh Count starts at 1, which accounts for this occurrence
            key2count.put(arg0, new Count());
            return true;
        }
        count.count++;
        return false;
    }

    /**
     * Counts several occurrences of a key at once. A key that is not yet
     * present is inserted with count {@code inc}; otherwise {@code inc} is
     * added to its count.
     * 
     * @param arg0
     *            the key to count
     * @param inc
     *            the amount to add to the count of the key
     * @return true if the key was not yet present, false if only its count was
     *         changed
     */
    public boolean add(T arg0, int inc) {
        Count count = key2count.get(arg0);
        if (count == null) {
            key2count.put(arg0, new Count(inc));
            return true;
        }
        count.count += inc;
        return false;
    }

    /**
     * Removes a key together with its count, regardless of how high the count
     * is.
     * 
     * @param arg0
     *            the key to remove
     * @return true if the key was present
     */
    @Override
    public boolean remove(Object arg0) {
        return (key2count.remove(arg0) != null);
    }

    /**
     * @param arg0
     *            the keys to look for
     * @return true if every element of the collection is a key of this set
     */
    @Override
    public boolean containsAll(Collection<?> arg0) {
        return key2count.keySet().containsAll(arg0);
    }

    /**
     * Counts one occurrence of every element of the collection, see
     * {@link #add(Object)}.
     * 
     * @param arg0
     *            the elements to count
     * @return true if at least one element was not yet present
     */
    @Override
    public boolean addAll(Collection<? extends T> arg0) {
        boolean changed = false;
        for (T object : arg0) {
            if (add(object)) {
                changed = true;
            }
        }
        return changed;
    }

    /**
     * Removes every key (with its count) that is not in the given collection.
     * 
     * @param arg0
     *            the keys to keep
     * @return true if at least one key was removed
     */
    @Override
    public boolean retainAll(Collection<?> arg0) {
        return key2count.keySet().retainAll(arg0);
    }

    /**
     * Removes every key (with its count) that is in the given collection.
     * 
     * @param arg0
     *            the keys to remove
     * @return true if at least one key was removed
     */
    @Override
    public boolean removeAll(Collection<?> arg0) {
        return key2count.keySet().removeAll(arg0);
    }

    /**
     * Removes all keys and counts.
     */
    @Override
    public void clear() {
        key2count.clear();
    }

    /**
     * Keep the n most frequent values, remove the rest. When the set holds n or
     * fewer keys nothing changes. Which of several keys with equal counts
     * survive depends on the iteration order of the backing map.
     * 
     * @param n
     *            the number of keys to keep; must not be negative
     * @throws IllegalArgumentException
     *             if n is negative
     */
    public void keepTopN(int n) {
        if (n < 0) {
            throw new IllegalArgumentException("n must not be negative: " + n);
        }
        if (size() <= n) {
            // Everything already fits, so there is nothing to drop
            return;
        }
        Map<T, Count> newMap = new HashMap<T, Count>(n);
        for (Map.Entry<T, Count> entry : entriesByDescendingCount().subList(0, n)) {
            newMap.put(entry.getKey(), entry.getValue());
        }
        key2count = newMap;
    }

    /**
     * Holder for a mutable count. A count created without an explicit value
     * starts at 1.
     */
    public static class Count {
        public int count = 1;

        public Count() {
        }

        public Count(int count) {
            this.count = count;
        }
    }

    /**
     * Prints one line per key to standard output, consisting of the key, a tab
     * and the count, in the order produced by
     * {@link #entriesByDescendingCount()}.
     */
    public void printCounts() {
        for (Map.Entry<T, Count> entry : entriesByDescendingCount()) {
            System.out.println(entry.getKey() + "\t" + entry.getValue().count);
        }
    }

    /**
     * Returns a snapshot list of the entries, sorted with the count comparison
     * reversed.
     * 
     * @return a new list of the map entries; the entries themselves are the
     *         live entries of the backing map
     */
    private List<Map.Entry<T, Count>> entriesByDescendingCount() {
        List<Map.Entry<T, Count>> entries = new ArrayList<Map.Entry<T, Count>>(key2count.entrySet());
        Collections.sort(entries, new Comparator<Map.Entry<T, Count>>() {
            @Override
            public int compare(Entry<T, Count> arg0, Entry<T, Count> arg1) {
                // Arguments are swapped on purpose. NOTE(review): presumably
                // IntegerComparator.compare orders ints ascending, which makes
                // this a highest-count-first sort - confirm against that class.
                return IntegerComparator.compare(arg1.getValue().count, arg0.getValue().count);
            }
        });
        return entries;
    }
}