/**
* Copyright 2014, Emory University
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package edu.emory.clir.clearnlp.collection.ngram;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import edu.emory.clir.clearnlp.collection.map.ObjectIntHashMap;
import edu.emory.clir.clearnlp.collection.pair.ObjectDoublePair;
import edu.emory.clir.clearnlp.collection.pair.ObjectIntPair;
import edu.emory.clir.clearnlp.util.MathUtils;
/**
 * Counts occurrences of keys and derives relative frequencies from those counts.
 * It also tracks the key with the highest count seen so far.
 *
 * <p>Counts are stored in an {@link ObjectIntHashMap}. A running total of all increments
 * is kept in {@code i_total} and used as the denominator for every probability.
 *
 * <p>NOTE: the field names below are part of the serialized form (this class is
 * {@link Serializable} with a fixed {@code serialVersionUID}). Do not rename or retype them.
 *
 * @param <T> the key type.
 */
public class Unigram<T> implements Serializable
{
	private static final long serialVersionUID = 2431106431004828434L;
	/** Per-key counts. */
	private ObjectIntHashMap<T> g_map;
	/** Sum of every increment passed to {@link #add(Object, int)}. */
	private int i_total;
	/** Key with the highest count, or {@code null} if nothing has been added yet. */
	private T t_best;

	/** Creates an empty unigram with no counts and no best key. */
	public Unigram()
	{
		g_map  = new ObjectIntHashMap<>();
		t_best = null;
		i_total = 0;
	}

	/**
	 * Increments the count of {@code key} by 1.
	 * @param key the key to count.
	 */
	public void add(T key)
	{
		add(key, 1);
	}

	/**
	 * Increments the count of {@code key} by {@code inc} and adds {@code inc} to the total.
	 * Keeps the best (highest-count) key up to date, including when {@code inc} is negative.
	 * @param key the key to count.
	 * @param inc the amount to add to the key's count. It may be negative.
	 */
	public void add(T key, int inc)
	{
		int c = g_map.add(key, inc);
		i_total += inc;

		if (t_best == null || get(t_best) < c)
		{
			t_best = key;
		}
		else if (inc < 0 && t_best.equals(key))
		{
			// The current best key was decremented, so another key may now lead.
			// Counts of other keys did not change, so a rescan is needed only in this case.
			updateBest();
		}
	}

	/** Rescans every entry and sets {@code t_best} to a key with the maximum count. */
	private void updateBest()
	{
		T best = null;
		int max = 0;

		for (ObjectIntPair<T> p : g_map)
		{
			// Ties resolve to the first maximum in the map's iteration order.
			if (best == null || p.i > max)
			{
				best = p.o;
				max  = p.i;
			}
		}

		t_best = best;
	}

	/**
	 * @param key the key to look up.
	 * @return the count stored for {@code key} in the backing map
	 * (presumably 0 for an absent key; this depends on {@code ObjectIntHashMap.get}).
	 */
	public int get(T key)
	{
		return g_map.get(key);
	}

	/**
	 * @return the key with the highest count, paired with its count divided by the total
	 * (via {@code MathUtils.divide}), or {@code null} if nothing has been added.
	 */
	public ObjectDoublePair<T> getBest()
	{
		return (t_best != null) ? new ObjectDoublePair<T>(t_best, MathUtils.divide(get(t_best), i_total)) : null;
	}

	/**
	 * @param key the key to check.
	 * @return {@code true} if the backing map contains {@code key}.
	 */
	public boolean contains(T key)
	{
		return g_map.containsKey(key);
	}

	/**
	 * @param key the key to look up.
	 * @return the count of {@code key} divided by the total count (via {@code MathUtils.divide}).
	 * The result when the total is 0 depends on {@code MathUtils.divide}.
	 */
	public double getProbability(T key)
	{
		return MathUtils.divide(get(key), i_total);
	}

	/**
	 * @param cutoff exclusive lower bound on counts.
	 * @return a new list of (key, count) pairs whose count is strictly greater than {@code cutoff}.
	 */
	public List<ObjectIntPair<T>> toList(int cutoff)
	{
		List<ObjectIntPair<T>> list = new ArrayList<>();

		for (ObjectIntPair<T> p : g_map)
		{
			if (p.i > cutoff)
			{
				list.add(p);
			}
		}

		return list;
	}

	/**
	 * @param threshold exclusive lower bound on probabilities.
	 * @return a new list of (key, probability) pairs whose probability (count / total)
	 * is strictly greater than {@code threshold}.
	 */
	public List<ObjectDoublePair<T>> toList(double threshold)
	{
		List<ObjectDoublePair<T>> list = new ArrayList<>();
		double d;

		for (ObjectIntPair<T> p : g_map)
		{
			d = MathUtils.divide(p.i, i_total);

			if (d > threshold)
			{
				list.add(new ObjectDoublePair<T>(p.o, d));
			}
		}

		return list;
	}

	/**
	 * @return the keys returned by {@code keySet(0)} on the backing map,
	 * i.e. (per {@link #keySet(int)}) keys whose counts are greater than 0.
	 */
	public Set<T> keySet()
	{
		return g_map.keySet(0);
	}

	/**
	 * @param cutoff exclusive lower bound on counts.
	 * @return a set of keys whose values are greater than the specific cutoff.
	 */
	public Set<T> keySet(int cutoff)
	{
		return g_map.keySet(cutoff);
	}

	/**
	 * @param threshold exclusive lower bound on probabilities.
	 * @return a new set of keys whose probability (count / total) is strictly greater
	 * than {@code threshold}.
	 */
	public Set<T> keySet(double threshold)
	{
		Set<T> set = new HashSet<>();
		double d;

		for (ObjectIntPair<T> p : g_map)
		{
			d = MathUtils.divide(p.i, i_total);

			if (d > threshold)
			{
				set.add(p.o);
			}
		}

		return set;
	}
}