package clear.experiment; import clear.dep.DepNode; import clear.dep.DepTree; import clear.dep.srl.SRLHead; import clear.dep.srl.SRLInfo; import clear.reader.SRLReader; import clear.util.cluster.Prob2dMap; import clear.util.tuple.JObjectDoubleTuple; import java.io.FileOutputStream; import java.io.IOException; import java.io.ObjectOutputStream; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; public class SRLTopicCluster { HashMap<String, Prob2dMap> m_ta, m_at; HashMap<String, HashSet<String>> s_verbs; public SRLTopicCluster() { m_ta = new HashMap<>(); m_at = new HashMap<>(); s_verbs = new HashMap<>(); } public void retrieveTopics(DepTree tree) { DepNode node, pred; SRLInfo info; String feat; Prob2dMap pTA, pAT; for (int i = 1; i < tree.size(); i++) { node = tree.get(i); info = node.srlInfo; if (!node.isPosx("NN.*")) { continue; } for (SRLHead head : info.heads) { pred = tree.get(head.headId); if ((feat = pred.getFeat("ct")) == null) { continue; } pTA = getSubMap(m_ta, feat); pTA.increment(head.label, node.lemma); pAT = getSubMap(m_at, feat); pAT.increment(node.lemma, head.label); getSubSet(s_verbs, feat).add(pred.lemma); } } } public void getTopics(ArrayList<HashSet<String>> aTopics, String argKey, double threshold, int num) { ArrayList<JObjectDoubleTuple<String>> aTA; Prob2dMap pTA, pAT; HashSet<String> topics, clone; outer: for (String id : m_ta.keySet()) { pTA = m_ta.get(id); pAT = m_at.get(id); if ((aTA = pTA.getProb1dList(argKey)) == null) { continue; } topics = new HashSet<>(); for (JObjectDoubleTuple<String> tup : aTA) { tup.value *= pAT.get1dProb(tup.object, argKey); if (tup.value >= threshold) { topics.add(tup.object); } } if (topics.size() >= num) { for (HashSet<String> pSet : aTopics) { clone = new HashSet<>(topics); clone.removeAll(pSet); if (clone.size() < num) { continue outer; } } aTopics.add(topics); } } } private Prob2dMap getSubMap(HashMap<String, Prob2dMap> mTa, String key) { Prob2dMap submap; if (mTa.containsKey(key)) { submap = mTa.get(key); } else { submap = new Prob2dMap(); mTa.put(key, submap); } return submap; } private HashSet<String> getSubSet(HashMap<String, HashSet<String>> mTa, String key) { HashSet<String> subset; if (mTa.containsKey(key)) { subset = mTa.get(key); } else { subset = new HashSet<>(); mTa.put(key, subset); } return subset; } static public void main(String[] args) { String inputFile = args[0]; String outputFile = args[1]; ArrayList<HashSet<String>> aTopics = new ArrayList<>(); SRLTopicCluster tbuild = new SRLTopicCluster(); SRLReader reader = new SRLReader(inputFile, true); DepTree tree; while ((tree = reader.nextTree()) != null) { tbuild.retrieveTopics(tree); } // tbuild.getTopics(aTopics, "A0", 0.005, 10); tbuild.getTopics(aTopics, "A1", 0.005, 10); try { try (ObjectOutputStream outputStream = new ObjectOutputStream(new FileOutputStream(outputFile))) { outputStream.writeObject(aTopics); } } catch (IOException e) { e.printStackTrace(); } } }