package clear.util.cluster;

import clear.util.IOUtil;
import clear.util.tuple.JObjectDoubleTuple;
import com.carrotsearch.hppc.ObjectDoubleOpenHashMap;
import com.carrotsearch.hppc.cursors.ObjectCursor;
import java.io.PrintStream;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;

/**
 * Two-level count map. Each 2nd-degree key ({@code key2d}) maps to a {@link Prob1dMap} holding
 * 1st-degree counts. This map also keeps a running total of all increments made through
 * {@link #increment(String, String)} and {@link #increment(String, Collection)}.
 *
 * <p>P(2D) is computed as {@code get(key2d).n_total / n_total}. P(1D|2D) is delegated to
 * {@code Prob1dMap.getProb}.
 */
@SuppressWarnings("serial")
public class Prob2dMap extends HashMap<String, Prob1dMap> {
    /** Running total of increments recorded through the {@code increment} methods. */
    private int n_total;

    /** Creates an empty map with a zero total. */
    public Prob2dMap() {
        n_total = 0;
    }

    /**
     * @param key2d the 2nd-degree key.
     * @return P(2D) for {@code key2d}; {@code 0} if {@code key2d} is unknown or nothing has been
     *     counted yet.
     */
    public double get2dProb(String key2d) {
        return prob2d(get(key2d));
    }

    /**
     * @param key2d the 2nd-degree key.
     * @param key1d the 1st-degree key.
     * @return P(1D|2D) * P(2D); {@code 0} if {@code key2d} is unknown or nothing has been counted.
     */
    public double get2dProb(String key2d, String key1d) {
        Prob1dMap map = get(key2d);
        double prob2d = prob2d(map);
        // Short-circuit so an unknown key2d (map == null) never dereferences the map.
        return (prob2d == 0) ? 0 : prob2d * map.getProb(key1d);
    }

    /**
     * @param key2d the 2nd-degree key.
     * @param key1d the 1st-degree key.
     * @return P(1D|2D), as reported by the 1st-degree map; {@code 0} if {@code key2d} is unknown.
     */
    public double get1dProb(String key2d, String key1d) {
        Prob1dMap map = get(key2d);
        return (map == null) ? 0 : map.getProb(key1d);
    }

    /**
     * @param key2d the 2nd-degree key.
     * @return probabilistic map P(1D|2D) for {@code key2d}, or {@code null} if the key is unknown.
     */
    public ObjectDoubleOpenHashMap<String> getProb1dMap(String key2d) {
        Prob1dMap map1d = get(key2d);

        if (map1d == null) {
            return null;
        }

        return map1d.getProbMap();
    }

    /**
     * @param key2d the 2nd-degree key.
     * @return probabilistic map P(1D|2D) * P(2D) for {@code key2d}, or {@code null} if the key is
     *     unknown.
     */
    public ObjectDoubleOpenHashMap<String> getProb2dMap(String key2d) {
        Prob1dMap map1d = get(key2d);

        if (map1d == null) {
            return null;
        }

        ObjectDoubleOpenHashMap<String> map = new ObjectDoubleOpenHashMap<>(map1d.size());
        // P(2D) is the same for every 1st-degree key, so compute it once outside the loop.
        double prob2d = prob2d(map1d);

        for (ObjectCursor<String> cur : map1d.keys()) {
            map.put(cur.value, map1d.getProb(cur.value) * prob2d);
        }

        return map;
    }

    /**
     * @param key2d the 2nd-degree key.
     * @return sorted list generated from the 1st-degree map P(1D|2D), or {@code null} if the key
     *     is unknown.
     */
    public ArrayList<JObjectDoubleTuple<String>> getProb1dList(String key2d) {
        return map2list(getProb1dMap(key2d));
    }

    /**
     * @param key2d the 2nd-degree key.
     * @return sorted list generated from the map P(1D|2D) * P(2D), or {@code null} if the key is
     *     unknown.
     */
    public ArrayList<JObjectDoubleTuple<String>> getProb2dList(String key2d) {
        return map2list(getProb2dMap(key2d));
    }

    /**
     * Converts a key-to-probability map into a list of tuples sorted by the tuples' natural order.
     *
     * @return the sorted list, or {@code null} if {@code map} is {@code null}.
     */
    private ArrayList<JObjectDoubleTuple<String>> map2list(ObjectDoubleOpenHashMap<String> map) {
        if (map == null) {
            return null;
        }

        ArrayList<JObjectDoubleTuple<String>> list = new ArrayList<>(map.size());

        for (ObjectCursor<String> cur : map.keys()) {
            list.add(new JObjectDoubleTuple<>(cur.value, map.get(cur.value)));
        }

        Collections.sort(list);
        return list;
    }

    /**
     * Computes P(2D) for the given 1st-degree map.
     *
     * @return {@code map1d.n_total / n_total}; {@code 0} when {@code map1d} is {@code null} or
     *     the overall total is zero, instead of throwing an NPE or returning NaN/Infinity.
     */
    private double prob2d(Prob1dMap map1d) {
        if (map1d == null || n_total == 0) {
            return 0;
        }

        return (double) map1d.n_total / n_total;
    }

    /**
     * Increments the overall total by one and counts {@code key1d} once in the 1st-degree map of
     * {@code key2d}, creating that map if necessary.
     */
    public void increment(String key2d, String key1d) {
        Prob1dMap map = get1dMap(key2d);

        n_total++;
        map.increment(key1d);
    }

    /**
     * Counts every key in {@code keys1d} in the 1st-degree map of {@code key2d}, creating that map
     * if necessary. The overall total grows by {@code keys1d.size()}, one per key counted.
     */
    public void increment(String key2d, Collection<String> keys1d) {
        Prob1dMap map = get1dMap(key2d);
        n_total += keys1d.size();

        for (String key1d : keys1d) {
            map.increment(key1d);
        }
    }

    /**
     * Returns the 1st-degree map stored under {@code key2d}. If none is stored yet, an empty one
     * is created and stored first. This method does not change any counts.
     */
    public Prob1dMap get1dMap(String key2d) {
        Prob1dMap map1d = get(key2d);

        if (map1d == null) {
            map1d = new Prob1dMap();
            put(key2d, map1d);
        }

        return map1d;
    }

    /** Removes all 2nd-degree entries and resets the overall total to zero. */
    @Override
    public void clear() {
        super.clear();
        n_total = 0;
    }

    /**
     * Writes {@link #toString(DecimalFormat)} to {@code filename}. The stream comes from
     * {@code IOUtil.createPrintFileStream} and is closed afterwards.
     */
    public void print(String filename, DecimalFormat format) {
        try (PrintStream fout = IOUtil.createPrintFileStream(filename)) {
            fout.print(toString(format));
        }
    }

    /** Same as {@link #toString(DecimalFormat)} using the pattern {@code "#0.0000"}. */
    @Override
    public String toString() {
        return toString(new DecimalFormat("#0.0000"));
    }

    /**
     * @return one line per 2nd-degree key, keys sorted in natural order. Each line is produced by
     *     {@link #toString(String, DecimalFormat)} and terminated by a newline.
     */
    public String toString(DecimalFormat format) {
        ArrayList<String> keys2d = new ArrayList<>(keySet());
        StringBuilder build = new StringBuilder();
        Collections.sort(keys2d);

        for (String key2d : keys2d) {
            build.append(toString(key2d, format));
            build.append("\n");
        }

        return build.toString();
    }

    /**
     * @return {@code key2d} followed by space-separated {@code key1d:P(1D|2D)*P(2D)} pairs; just
     *     {@code key2d} when the key is unknown.
     */
    public String toString(String key2d, DecimalFormat format) {
        ArrayList<JObjectDoubleTuple<String>> list = getProb2dList(key2d);
        StringBuilder build = new StringBuilder();

        build.append(key2d);

        // getProb2dList returns null for an unknown key; emit the bare key instead of an NPE.
        if (list != null) {
            for (JObjectDoubleTuple<String> tup : list) {
                build.append(" ");
                build.append(tup.object);
                build.append(":");
                build.append(format.format(tup.value));
            }
        }

        return build.toString();
    }
}