package org.wikibrain.phrases; import com.typesafe.config.Config; import org.wikibrain.conf.Configuration; import org.wikibrain.conf.ConfigurationException; import org.wikibrain.conf.Configurator; import java.util.*; /** * A simple pruner for PrunedCounts. * Prunes by count, rank, and fraction of total count. * * @author Shilad Sen */ public class SimplePruner<K> implements PrunedCounts.Pruner<K> { private final int minCount; private final int maxRank; private final double minFrac; public SimplePruner(int minCount, int maxRank, double minFrac) { this.minCount = minCount; this.maxRank = maxRank; this.minFrac = minFrac; } /** * Prunes counts down. * Returns the pruned counts (i.e. with some keys removed, sorted by count, and total unchanged) * or null if the entry should not appear in the database at all. The resulting hashmap * is in decreasing order by size. * @param allCounts All counts. * @return */ @Override public PrunedCounts<K> prune(final Map<K, Integer> allCounts) { List<K> keys = new ArrayList<K>(allCounts.keySet()); Collections.sort(keys, new Comparator<K>() { @Override public int compare(K key1, K key2) { return -1 * (allCounts.get(key1) - allCounts.get(key2)); } }); int sum = 0; for (Integer c : allCounts.values()) { sum += c; } PrunedCounts<K> pruned = new PrunedCounts<K>(sum); for (K key : keys) { int c = allCounts.get(key); if (pruned.size() >= maxRank || c < minCount || 1.0 * c / sum < minFrac) break; pruned.put(key, c); } if (pruned.isEmpty()) { return null; } else { return pruned; } } public static class Provider extends org.wikibrain.conf.Provider<PrunedCounts.Pruner> { public Provider(Configurator configurator, Configuration config) throws ConfigurationException { super(configurator, config); } @Override public Class getType() { return PrunedCounts.Pruner.class; } @Override public String getPath() { return "phrases.pruners"; } @Override public PrunedCounts.Pruner get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException { if (!config.getString("type").equals("simple")) { return null; } int minCount = config.getInt("minCount"); int maxRank = config.getInt("maxRank"); double minFraction = config.getDouble("minFraction"); return new SimplePruner(minCount, maxRank, minFraction); } } }