package clear.experiment; import clear.dep.DepNode; import clear.dep.DepTree; import clear.dep.srl.SRLHead; import clear.dep.srl.SRLInfo; import clear.morph.MorphEnAnalyzer; import clear.reader.SRLReader; import clear.util.cluster.Prob1dMap; import clear.util.cluster.Prob2dMap; import clear.util.cluster.SRLClusterBuilder; import clear.util.tuple.JObjectDoubleTuple; import com.carrotsearch.hppc.ObjectDoubleOpenHashMap; import com.carrotsearch.hppc.cursors.ObjectCursor; import java.text.DecimalFormat; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; import java.util.regex.Pattern; public class SRLProbE { Pattern P_NOUN = Pattern.compile("NN.*"); Pattern P_STOP = Pattern.compile("\\$#ORD#\\$|\\$#CRD#\\$"); MorphEnAnalyzer m_morph; String[] a_topics; Prob1dMap m_verbs; Prob2dMap m_prob; Prob1dMap m_topicsT_VA; Prob1dMap m_topicsT_V; Prob2dMap m_topicsA_VT; public SRLProbE(String dicFile, String[] topics) { m_morph = new MorphEnAnalyzer(dicFile); a_topics = topics; m_verbs = new Prob1dMap(); m_prob = new Prob2dMap(); m_topicsT_VA = new Prob1dMap(); m_topicsT_V = new Prob1dMap(); m_topicsA_VT = new Prob2dMap(); for (String topic : topics) { m_topicsT_VA.put(topic, 0); } } public void lemmatize(DepTree tree) { DepNode node; for (int i = 1; i < tree.size(); i++) { node = tree.get(i); if (node.isPredicate()) { node.lemma = m_morph.getLemma(node.form, "VB"); } else if (P_NOUN.matcher(node.pos).matches()) { node.lemma = m_morph.getLemma(node.form, "NN"); } } } private String getArgKey(DepNode node, SRLHead head) { String dir = (node.id < head.headId) ? "-" : "+"; return dir + head.label; } private String getArgLemma(DepTree tree, DepNode pred, DepNode node, SRLHead head) { DepNode tmp; if (head.equals("A0") && node.isLemma("by") && (tmp = tree.getRightNearestDependent(node.id)) != null) { return tmp.lemma; } else { return node.lemma; } } // ========================= 1st-iteration ========================= public void retrieveVerbs(DepTree tree, String argKey) { DepNode node, pred; String topic, verb; SRLInfo sInfo; for (int i = 1; i < tree.size(); i++) { node = tree.get(i); sInfo = node.srlInfo; for (SRLHead head : sInfo.heads) { if (!argKey.equals(getArgKey(node, head))) { continue; } pred = tree.get(head.headId); topic = getArgLemma(tree, pred, node, head); verb = pred.lemma; if (m_topicsT_VA.containsKey(topic)) { m_verbs.increment(verb); } } } } public void printVerbs() { System.out.println(m_verbs.toStringProb()); System.out.println(); } public void clearVerbs() { m_verbs.clear(); } // ========================= 2nd-iteration ========================= public void retrieveTopics(DepTree tree, String argKey) { DepNode node, pred; String topic; SRLInfo sInfo; for (int i = 1; i < tree.size(); i++) { node = tree.get(i); sInfo = node.srlInfo; if (!P_NOUN.matcher(node.pos).matches()) { continue; } if (P_STOP.matcher(node.lemma).matches()) { continue; } for (SRLHead head : sInfo.heads) { pred = tree.get(head.headId); topic = getArgLemma(tree, pred, node, head); if (m_verbs.containsKey(pred.lemma)) { if (argKey.equals(getArgKey(node, head))) { m_topicsT_VA.increment(topic); m_topicsA_VT.increment(topic, argKey); } else { for (int j = 0; j < 5; j++) { m_topicsA_VT.increment(topic, "REST"); } } m_topicsT_V.increment(topic); } } } } public ArrayList<JObjectDoubleTuple<String>> trimTopics(double threshold, String argKey) { ArrayList<JObjectDoubleTuple<String>> topics = new ArrayList<>(); String topic; double prob, total = 0; for (ObjectCursor<String> cur : m_topicsT_VA.keys()) { topic = cur.value; prob = m_topicsT_VA.getProb(topic) * m_topicsA_VT.get1dProb(topic, argKey) * m_topicsT_V.getProb(topic); total += prob; topics.add(new JObjectDoubleTuple<>(topic, prob)); } ArrayList<JObjectDoubleTuple<String>> remove = new ArrayList<>(); HashSet<String> sTopics = new HashSet<>(Arrays.asList(a_topics)); Collections.sort(topics); for (JObjectDoubleTuple<String> tup : topics) { topic = tup.object; tup.value /= total; if (!sTopics.contains(topic) && tup.value < threshold) { remove.add(tup); m_topicsT_VA.remove(topic); } } topics.removeAll(remove); return topics; } public void printTopics(ArrayList<JObjectDoubleTuple<String>> topics) { StringBuilder build1 = new StringBuilder(); StringBuilder build2 = new StringBuilder(); DecimalFormat format = new DecimalFormat("#0.0000"); for (JObjectDoubleTuple<String> tup : topics) { build1.append(tup.object); build1.append("|"); build2.append(tup.object); build2.append("\t"); build2.append(format.format(tup.value)); build2.append("\n"); } build1.append("\n"); System.out.println(build1.toString()); System.out.println(build2.toString()); } // ========================= 4th-iteration ========================= public void retrieveArgs(DepTree tree) { DepNode node, pred; String arg; SRLInfo sInfo; for (int i = 1; i < tree.size(); i++) { node = tree.get(i); sInfo = node.srlInfo; for (SRLHead head : sInfo.heads) { pred = tree.get(head.headId); if (!m_verbs.containsKey(pred.lemma)) { continue; } if (head.label.matches("AM-MOD|AM-NEG|R-.*")) { continue; } arg = getArgLemma(tree, pred, node, head) + ":" + head.label; m_prob.increment(pred.lemma, arg); } } } public ObjectDoubleOpenHashMap<String> getTopicweights(ArrayList<JObjectDoubleTuple<String>> topics, String argLabel) { ObjectDoubleOpenHashMap<String> map = new ObjectDoubleOpenHashMap<>(); for (JObjectDoubleTuple<String> tup : topics) { map.put(tup.object + argLabel, Math.exp(tup.value)); } return map; } public void weightArgMs(double weight) { Prob1dMap map; String label; for (String verb : m_prob.keySet()) { map = m_prob.get(verb); for (ObjectCursor<String> arg : map.keys()) { label = arg.value; if (label.contains("AM")) { map.put(label, (int) Math.ceil(weight * map.get(label))); } } } } // ========================= 5th-iteration ========================= public void retrieveCluster(double threshold, ObjectDoubleOpenHashMap<String> mWeights) { SRLClusterBuilder build = new SRLClusterBuilder(threshold); build.cluster(m_prob, mWeights); } static public void main(String[] args) { String inputFile = args[0]; String dicFile = args[1]; String argKey = args[2]; // String[] topic = "scotty,doctor".split(","); // String[] topic = "scotty,rachel,doctor,warren,father,prevot,president,boy,boxell".split(","); // String[] topic = "time".split(","); // String[] topic = "time,grip,appetite,reason,chance,appearance,remainder,spot,desire,need,ammo,exercise,fact,number,concussion,attitude,moment,glow".split(","); // String[] topic = "figure".split(","); String[] topic = "parent".split(","); SRLReader reader = new SRLReader(inputFile, true); SRLProbE prob = new SRLProbE(dicFile, topic); DepTree tree; // brown double topicThreshold = 0.02; double clusterThreshold = 0.15; // wsj-brown // double topicThreshold = 0.003; // double clusterThreshold = 0.28; System.out.println(Arrays.toString(topic) + ": " + argKey + "\n"); System.out.println("== Related verbs ==\n"); while ((tree = reader.nextTree()) != null) { prob.lemmatize(tree); prob.retrieveVerbs(tree, argKey); } prob.printVerbs(); reader.close(); System.out.println("== Related topics ==\n"); reader.open(inputFile); while ((tree = reader.nextTree()) != null) { prob.lemmatize(tree); prob.retrieveTopics(tree, argKey); } ArrayList<JObjectDoubleTuple<String>> topics = prob.trimTopics(topicThreshold, argKey); prob.printTopics(topics); reader.close(); System.out.println("== More related verbs ==\n"); reader.open(inputFile); prob.clearVerbs(); while ((tree = reader.nextTree()) != null) { prob.lemmatize(tree); prob.retrieveVerbs(tree, argKey); } prob.printVerbs(); reader.close(); System.out.println("== Verb clusters ==\n"); reader.open(inputFile); while ((tree = reader.nextTree()) != null) { prob.lemmatize(tree); prob.retrieveArgs(tree); } System.out.println("Clustering"); ObjectDoubleOpenHashMap<String> mWeights = prob.getTopicweights(topics, argKey.substring(1)); prob.retrieveCluster(clusterThreshold, mWeights); reader.close(); } }