package edu.umd.hooka;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Map;
import java.util.TreeMap;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.Writable;
/**
 * A sorted map from {@link Phrase} to a mutable float count that can be
 * serialized through Hadoop's {@link Writable} interface.
 *
 * <p>No constructor is declared, so the map uses {@link TreeMap}'s default
 * ordering (the natural ordering of the keys). Counts are stored as
 * {@link FloatWritable} instances and are updated in place by
 * {@link #plusEquals(Phrase2CountMap)} and {@link #normalize()}.
 */
public class Phrase2CountMap extends TreeMap<Phrase,FloatWritable>
        implements Writable {
    private static final long serialVersionUID = 1093017050863486402L;

    /**
     * Adds every count in {@code rhs} to the count stored under the same key
     * in this map. Keys that are missing from this map are inserted first
     * with a fresh zero-valued counter, so the {@link FloatWritable} objects
     * held by {@code rhs} are never shared with this map.
     *
     * @param rhs the map whose counts are added into this one
     */
    public final void plusEquals(Phrase2CountMap rhs) {
        for (Map.Entry<Phrase, FloatWritable> ri : rhs.entrySet()) {
            FloatWritable cv = this.get(ri.getKey());
            if (cv == null) {
                cv = new FloatWritable(0);
                this.put(ri.getKey(), cv);
            }
            cv.set(cv.get() + ri.getValue().get());
        }
    }

    /**
     * Stores {@code value} as the count for {@code key}, replacing any
     * previous counter object held under that key.
     *
     * @param key   the phrase whose count is set
     * @param value the new count
     */
    public final void setPhraseCount(Phrase key, float value) {
        this.put(key, new FloatWritable(value));
    }

    /**
     * Returns the count stored for {@code key}.
     *
     * @param key the phrase to look up
     * @return the stored count, or {@code 0.0f} when the key is absent
     */
    public final float getPhraseCount(Phrase key) {
        FloatWritable x = this.get(key);
        if (x == null) {
            return 0.0f;
        }
        return x.get();
    }

    /**
     * Divides every count by the sum of all counts, in place.
     *
     * @throws IllegalStateException (a {@code RuntimeException}) when the sum
     *         of the counts is not strictly positive; this includes an empty
     *         map, a zero or negative sum, and a NaN sum. The message carries
     *         the map contents and the actual sum.
     */
    public void normalize() {
        float total = 0.0f;
        for (FloatWritable count : this.values()) {
            total += count.get();
        }
        // Written as a negated comparison so that a NaN total is rejected too.
        if (!(total > 0.0f)) {
            // Report the real total: the guard also fires for negative and NaN
            // sums, for which a hard-coded "0.0" would be misleading.
            throw new IllegalStateException(
                    this + "\ntotal=" + total + " : please implement uniform distribution");
        }
        for (FloatWritable count : this.values()) {
            count.set(count.get() / total);
        }
    }

    /**
     * Replaces the contents of this map with entries read from {@code in}.
     * The expected layout is the one produced by {@link #write(DataOutput)}:
     * an int entry count followed by that many (phrase, count) pairs.
     *
     * @param in the input to deserialize from
     * @throws IOException if reading fails or the recorded entry count is
     *         negative
     */
    public void readFields(DataInput in) throws IOException {
        this.clear();
        int size = in.readInt();
        if (size < 0) {
            // write() only ever emits size() >= 0, so a negative value means
            // the stream is corrupt or mis-positioned; fail loudly instead of
            // silently yielding an empty map.
            throw new IOException("Corrupt Phrase2CountMap: negative entry count " + size);
        }
        for (int i = 0; i < size; i++) {
            Phrase p = new Phrase();
            FloatWritable c = new FloatWritable();
            p.readFields(in);
            c.readFields(in);
            this.put(p, c);
        }
    }

    /**
     * Writes the number of entries followed by each key and its count, in the
     * map's iteration order.
     *
     * @param out the output to serialize to
     * @throws IOException if writing fails
     */
    public void write(DataOutput out) throws IOException {
        out.writeInt(this.size());
        for (Map.Entry<Phrase, FloatWritable> i : this.entrySet()) {
            i.getKey().write(out);
            i.getValue().write(out);
        }
    }
}