package clear.experiment;
import clear.parse.VoiceDetector;
import clear.propbank.PBArg;
import clear.propbank.PBInstance;
import clear.propbank.PBLoc;
import clear.propbank.PBReader;
import clear.treebank.TBNode;
import clear.treebank.TBReader;
import clear.treebank.TBTree;
import clear.util.IOUtil;
import clear.util.tuple.JObjectDoubleTuple;
import clear.util.tuple.JObjectIntTuple;
import com.carrotsearch.hppc.ObjectIntOpenHashMap;
import com.carrotsearch.hppc.cursors.ObjectCursor;
import java.io.File;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.regex.Pattern;
/**
 * Gathers count statistics over verb instances and their arguments read from paired
 * {@code prop}/{@code parse} directories, then writes a tab-separated report.
 *
 * <p>Three analyses are implemented (NumberedAdjunct, VerbPrepPMI, RequiredArgument). Only the
 * RequiredArgument analysis is wired into {@link #print(String)} and {@link #readProps(String)};
 * the others are kept available and can be re-enabled at the commented call sites.
 *
 * <p>All work (reading the corpus and writing the report) happens in the constructor.
 */
public class AnalyzePBArgs {
    /** Matches a label that is exactly {@code ARG} followed by a single digit. */
    private static final Pattern NUMBERED_ARG = Pattern.compile("ARG\\d");

    /** Reserved map key under which each counter map accumulates its grand total. */
    final String TOTAL = "TOTAL";

    /** Trees of the parse file that was read most recently by {@link #readParses(String)}. */
    ArrayList<TBTree> ls_trees;
    /** Tree of the instance currently being processed. */
    TBTree tb_tree;

    /** NumberedAdjunct: one counter map per argument number (index = digit of {@code ARGn}). */
    ArrayList<ObjectIntOpenHashMap<String>> ls_numberedAdjuncts;

    // Counter maps shared by the VerbPrepPMI and RequiredArgument analyses. The two analyses
    // use different key schemes for some of them (see the respective process* methods).
    ObjectIntOpenHashMap<String> m_verbs;
    ObjectIntOpenHashMap<String> m_verbArgs;
    ObjectIntOpenHashMap<String> m_verbArgNs;
    ObjectIntOpenHashMap<String> m_verbArgMs;
    ObjectIntOpenHashMap<String> m_preps;
    ObjectIntOpenHashMap<String> m_verbPreps;
    ObjectIntOpenHashMap<String> m_verbPrepNs;
    ObjectIntOpenHashMap<String> m_verbPrepMs;

    /** Number of instances accepted by {@link #processRequiredArgument(PBInstance)}. */
    int n_count = 0;

    /**
     * Runs the whole pipeline: initialise counters, read every corpus under {@code rootPath},
     * and write the report to {@code outputFile}.
     *
     * @param rootPath directory whose sub-directories each contain {@code prop} and
     *     {@code parse} directories
     * @param outputFile path of the report file to create
     * @throws IllegalArgumentException if {@code rootPath} cannot be listed as a directory
     */
    public AnalyzePBArgs(String rootPath, String outputFile) {
        init();
        read(rootPath);
        // trimVerbPrep();
        print(outputFile);
    }

    /** Creates the counter maps needed by the enabled analyses. */
    void init() {
        // initNumberedAdjunct();
        initVerbPrepPMI();
        // Re-creates four of the maps created just above; harmless because they are still empty.
        initRequiredArgument();
    }

    /** Writes the report of the enabled analysis to {@code outputFile}. */
    void print(String outputFile) {
        // printNumberedAdjunct(outputFile);
        // printVerbPrepPMI(outputFile);
        printRequiredArgument(outputFile);
    }

    /**
     * Walks {@code rootPath/<corpus>/prop}, and for every file found there reads the matching
     * file in {@code rootPath/<corpus>/parse} (same name, trailing {@code .prop} replaced by
     * {@code .parse}) followed by the prop file itself.
     *
     * <p>Entries are visited in sorted order so that runs are reproducible across platforms.
     * Progress is reported on standard output (one dot per 100 files).
     *
     * @param rootPath directory containing one sub-directory per corpus
     * @throws IllegalArgumentException if {@code rootPath} cannot be listed as a directory
     */
    void read(String rootPath) {
        String[] corpusDirs = new File(rootPath).list();
        if (corpusDirs == null) {
            // File.list() returns null (not an empty array) for a non-directory or on I/O error.
            throw new IllegalArgumentException("Not a readable directory: " + rootPath);
        }
        Arrays.sort(corpusDirs);

        for (String corpusDir : corpusDirs) {
            String corpusPath = rootPath + File.separator + corpusDir;
            String propDir = corpusPath + File.separator + "prop";
            String parseDir = corpusPath + File.separator + "parse";

            File prop = new File(propDir);
            if (!prop.isDirectory()) {
                continue;
            }
            String[] propFiles = prop.list();
            if (propFiles == null) {
                System.err.println("Cannot list directory, skipping: " + propDir);
                continue;
            }
            Arrays.sort(propFiles);

            System.out.println(propDir);
            int count = 0;
            for (String fileName : propFiles) {
                if (++count % 100 == 0) {
                    System.out.print(".");
                }
                // Build the parse path from its parts instead of rewriting "/prop/" inside the
                // full path: that only worked with '/' separators and also rewrote any
                // "/prop/" occurring in rootPath itself.
                String parseName = fileName.replaceAll("\\.prop$", ".parse");
                readParses(parseDir + File.separator + parseName);
                readProps(propDir + File.separator + fileName);
            }
            System.out.println();
        }
    }

    /** Replaces {@link #ls_trees} with all trees read from {@code parseFile}, in file order. */
    void readParses(String parseFile) {
        TBReader reader = new TBReader(parseFile);
        ls_trees = new ArrayList<>();
        for (TBTree tree = reader.nextTree(); tree != null; tree = reader.nextTree()) {
            ls_trees.add(tree);
        }
        ls_trees.trimToSize();
    }

    /**
     * Reads every instance of {@code propFile} and feeds the eligible ones to the enabled
     * analysis. An instance is eligible when its type ends with {@code "-v"} and it has more
     * than one argument. The {@code "-v"} suffix is stripped from {@code instance.type} before
     * processing, and {@link #tb_tree} is set to the tree the instance refers to.
     *
     * <p>Requires {@link #readParses(String)} to have loaded the corresponding parse file.
     */
    void readProps(String propFile) {
        PBReader reader = new PBReader(propFile);
        PBInstance instance;
        while ((instance = reader.nextInstance()) != null) {
            if (!instance.type.endsWith("-v") || instance.getArgs().size() <= 1) {
                continue;
            }
            instance.type = instance.type.substring(0, instance.type.length() - 2);
            tb_tree = ls_trees.get(instance.treeIndex);
            // processNumberedAdjunct(instance);
            // processVerbPrepPMI(instance);
            processRequiredArgument(instance);
        }
    }

    // ----------------------------- NumberedAdjunct -----------------------------

    /** Creates the per-argument-number counter maps for ARG0 through ARG5. */
    void initNumberedAdjunct() {
        ls_numberedAdjuncts = new ArrayList<>();
        for (int i = 0; i <= 5; i++) {
            ls_numberedAdjuncts.add(new ObjectIntOpenHashMap<String>());
        }
    }

    /**
     * For every numbered argument ({@code ARGn}) of {@code instance}, counts the lower-cased
     * form of the first {@code IN} child of each {@code PP} location, and counts the argument
     * itself under {@link #TOTAL}. Locations that are empty categories are ignored.
     */
    void processNumberedAdjunct(PBInstance instance) {
        for (PBArg arg : instance.getArgs()) {
            if (!NUMBERED_ARG.matcher(arg.label).matches()) {
                continue;
            }
            int argNumber = arg.label.charAt(3) - '0';
            // The pattern admits ARG6..ARG9 but only six maps are pre-created; grow on demand
            // instead of failing with IndexOutOfBoundsException.
            while (ls_numberedAdjuncts.size() <= argNumber) {
                ls_numberedAdjuncts.add(new ObjectIntOpenHashMap<String>());
            }
            ObjectIntOpenHashMap<String> map = ls_numberedAdjuncts.get(argNumber);

            for (PBLoc loc : arg.getLocs()) {
                TBNode node = tb_tree.getNode(loc.terminalId, loc.height);
                if (node.isEmptyCategoryRec() || !node.isPos("PP")) {
                    continue;
                }
                for (TBNode child : node.getChildren()) {
                    if (child.isPos("IN")) {
                        increment(map, child.form.toLowerCase(Locale.ROOT));
                        break; // only the first IN child counts
                    }
                }
            }
            increment(map, TOTAL);
        }
    }

    /**
     * Writes, for each argument number, a header line {@code ARGn<TAB>total} followed by one
     * line per counted form: {@code form<TAB>count<TAB>percentage of total}, in the order
     * defined by {@code JObjectIntTuple}'s natural ordering.
     */
    void printNumberedAdjunct(String outputFile) {
        try (PrintStream fout = IOUtil.createPrintFileStream(outputFile)) {
            for (int i = 0; i < ls_numberedAdjuncts.size(); i++) {
                ObjectIntOpenHashMap<String> map = ls_numberedAdjuncts.get(i);
                List<JObjectIntTuple<String>> list = new ArrayList<>();
                for (ObjectCursor<String> cur : map.keys()) {
                    String key = cur.value;
                    if (!key.equals(TOTAL)) {
                        list.add(new JObjectIntTuple<>(key, map.get(key)));
                    }
                }
                Collections.sort(list);

                int total = map.get(TOTAL);
                fout.println("ARG" + i + "\t" + total);
                for (JObjectIntTuple<String> tup : list) {
                    double percent = (double) tup.integer * 100 / total;
                    fout.println(tup.object + "\t" + tup.integer + "\t" + percent);
                }
            }
            fout.flush();
        }
    }

    // ----------------------------- VerbPrepPMI -----------------------------

    /** Creates all counter maps used by the VerbPrepPMI analysis. */
    void initVerbPrepPMI() {
        m_verbs = new ObjectIntOpenHashMap<>();
        m_preps = new ObjectIntOpenHashMap<>();
        m_verbPreps = new ObjectIntOpenHashMap<>();
        m_verbPrepNs = new ObjectIntOpenHashMap<>();
        m_verbPrepMs = new ObjectIntOpenHashMap<>();
        m_verbArgs = new ObjectIntOpenHashMap<>();
        m_verbArgNs = new ObjectIntOpenHashMap<>();
        m_verbArgMs = new ObjectIntOpenHashMap<>();
    }

    /**
     * Updates the VerbPrepPMI counters for one instance.
     *
     * <p>Key scheme: {@code m_verbs}, {@code m_verbArgs}, {@code m_verbArgNs} and
     * {@code m_verbArgMs} are keyed by verb lemma; {@code m_preps} by preposition;
     * {@code m_verbPreps}, {@code m_verbPrepNs} and {@code m_verbPrepMs} by
     * {@code lemma + "_" + preposition}. The "N" maps count numbered arguments ({@code ARGn}),
     * the "M" maps count every other counted argument.
     *
     * <p>Arguments whose label does not start with {@code ARG}, and {@code ARGM-MOD} /
     * {@code ARGM-NEG}, are not counted. The preposition of an argument location is the
     * lower-cased form of the first {@code IN} child of a {@code PP} node. A {@code by}-phrase
     * labelled {@code ARG0} of a predicate detected as passive is not counted as a verb-prep
     * pair.
     */
    void processVerbPrepPMI(PBInstance instance) {
        boolean isPassive = VoiceDetector.getPassive(tb_tree.getNode(instance.predicateId, 0)) > 0;
        String vLemma = instance.type;

        increment(m_verbs, vLemma);
        increment(m_verbs, TOTAL);

        for (PBArg arg : instance.getArgs()) {
            if (!arg.label.startsWith("ARG") || isModalOrNegation(arg.label)) {
                continue;
            }
            boolean isNumbered = NUMBERED_ARG.matcher(arg.label).matches();
            increment(isNumbered ? m_verbArgNs : m_verbArgMs, vLemma);
            increment(m_verbArgs, vLemma);
            increment(m_verbArgs, TOTAL);

            for (PBLoc loc : arg.getLocs()) {
                TBNode node = tb_tree.getNode(loc.terminalId, loc.height);
                if (node.isEmptyCategoryRec() || !node.isPos("PP")) {
                    continue;
                }
                for (TBNode child : node.getChildren()) {
                    if (!child.isPos("IN")) {
                        continue;
                    }
                    String pLemma = child.form.toLowerCase(Locale.ROOT);
                    boolean passiveAgent =
                            isPassive && arg.label.equals("ARG0") && pLemma.equals("by");
                    if (!passiveAgent) {
                        String key = vLemma + "_" + pLemma;
                        increment(m_verbPreps, key);
                        increment(isNumbered ? m_verbPrepNs : m_verbPrepMs, key);
                        increment(m_preps, pLemma);
                        increment(m_preps, TOTAL);
                    }
                    break; // only the first IN child is considered
                }
            }
        }
    }

    /**
     * Removes verb-preposition pairs that were seen at most once and subtracts their counts
     * from the related VerbPrepPMI counters.
     *
     * <p>The N/M sub-counters are reduced by their own share of the pair count, and the
     * verb-keyed maps are addressed by the verb lemma (not by the {@code verb_prep} key).
     */
    void trimVerbPrep() {
        // Snapshot the affected keys first so no map is modified while being iterated.
        List<String> rareKeys = new ArrayList<>();
        for (ObjectCursor<String> cur : m_verbPreps.keys()) {
            if (m_verbPreps.get(cur.value) <= 1) {
                rareKeys.add(cur.value);
            }
        }

        for (String key : rareKeys) {
            String[] parts = splitKey(key);
            String verb = parts[0];
            String prep = parts[1];
            int pairCount = m_verbPreps.get(key);
            int pairCountN = m_verbPrepNs.get(key);
            int pairCountM = m_verbPrepMs.get(key);

            decrement(m_verbPreps, key, pairCount);
            decrement(m_verbPrepNs, key, pairCountN);
            decrement(m_verbPrepMs, key, pairCountM);
            decrement(m_verbArgs, verb, pairCount);
            decrement(m_verbArgs, TOTAL, pairCount);
            decrement(m_verbArgNs, verb, pairCountN);
            decrement(m_verbArgMs, verb, pairCountM);
            // NOTE(review): m_verbs counts instances rather than arguments; reducing it by the
            // pair count is kept from the original implementation - confirm it is intended.
            decrement(m_verbs, verb, pairCount);
            decrement(m_verbs, TOTAL, pairCount);
            decrement(m_preps, prep, pairCount);
            decrement(m_preps, TOTAL, pairCount);
        }
    }

    /**
     * Writes one line per verb-preposition pair with a non-zero count:
     * {@code key<TAB>pN<TAB>pM<TAB>score}, where {@code pN} / {@code pM} are the smoothed
     * fractions of the verb's numbered / other arguments that carry this preposition and
     * {@code score = ln(pN / pM)}. Lines follow {@code JObjectDoubleTuple}'s natural ordering.
     */
    void printVerbPrepPMI(String outputFile) {
        final double smooth = 0.000001;
        try (PrintStream fout = IOUtil.createPrintFileStream(outputFile)) {
            List<JObjectDoubleTuple<String>> list = new ArrayList<>();
            for (ObjectCursor<String> cur : m_verbPreps.keys()) {
                String key = cur.value;
                if (m_verbPreps.get(key) == 0) {
                    continue; // pair was trimmed
                }
                String verb = splitKey(key)[0];
                double pnv = smoothedRatio(m_verbPrepNs.get(key), m_verbArgNs.get(verb), smooth);
                double pmv = smoothedRatio(m_verbPrepMs.get(key), m_verbArgMs.get(verb), smooth);
                list.add(new JObjectDoubleTuple<>(key, Math.log(pnv / pmv)));
            }
            Collections.sort(list);

            for (JObjectDoubleTuple<String> tup : list) {
                String key = tup.object;
                String verb = splitKey(key)[0];
                double pnv = smoothedRatio(m_verbPrepNs.get(key), m_verbArgNs.get(verb), smooth);
                double pmv = smoothedRatio(m_verbPrepMs.get(key), m_verbArgMs.get(verb), smooth);
                fout.println(key + "\t" + pnv + "\t" + pmv + "\t" + tup.value);
            }
            fout.flush();
        }
    }

    // ----------------------------- RequiredArgument -----------------------------

    /** Creates the counter maps used by the RequiredArgument analysis. */
    void initRequiredArgument() {
        m_verbs = new ObjectIntOpenHashMap<>();
        m_preps = new ObjectIntOpenHashMap<>();
        m_verbArgNs = new ObjectIntOpenHashMap<>();
        m_verbArgMs = new ObjectIntOpenHashMap<>();
    }

    /**
     * Updates the RequiredArgument counters for one instance. Only instances whose predicate
     * is not detected as passive and whose sentence group is {@code "SQ"} are counted.
     *
     * <p>Key scheme: {@code m_verbs} is keyed by verb lemma; {@code m_verbArgNs} (numbered
     * arguments) and {@code m_verbArgMs} (other counted arguments) by
     * {@code lemma + "_" + label}; {@code m_preps} by the label of non-numbered arguments.
     * Labels not starting with {@code ARG}, and {@code ARGM-MOD} / {@code ARGM-NEG}, are
     * ignored.
     */
    void processRequiredArgument(PBInstance instance) {
        TBNode predicate = tb_tree.getNode(instance.predicateId, 0);
        boolean isPassive = VoiceDetector.getPassive(predicate) > 0;
        if (isPassive || !"SQ".equals(predicate.getSentenceGroup())) {
            return;
        }
        n_count++;

        String vLemma = instance.type;
        increment(m_verbs, vLemma);
        increment(m_verbs, TOTAL);

        for (PBArg arg : instance.getArgs()) {
            if (!arg.label.startsWith("ARG") || isModalOrNegation(arg.label)) {
                continue;
            }
            // Diagnostic trace for one specific verb/label combination (kept from the original).
            if (vLemma.equals("buy") && arg.label.equals("ARGM-LOC")) {
                System.out.println(instance.predicateId + " " + arg.getLocs() + " " + tb_tree.getRootNode().toWords());
            }
            String key = vLemma + "_" + arg.label;
            if (NUMBERED_ARG.matcher(arg.label).matches()) {
                increment(m_verbArgNs, key);
            } else {
                increment(m_verbArgMs, key);
                increment(m_preps, arg.label);
                increment(m_preps, TOTAL);
            }
        }
    }

    /**
     * Prints {@link #n_count} to standard output and writes one line per verb:
     * {@code verb<TAB>count} followed by one tab-separated tuple per retained argument label.
     *
     * <p>Numbered arguments are scored by the percentage of the verb's instances in which the
     * label occurs. Other arguments are scored by {@code ln(pv / p) / -(ln(pv) + ln(v))} with
     * {@code pv = count(verb,label)/count(verb)}, {@code p = count(label)/total labels} and
     * {@code v = count(verb)/total verbs}; only strictly positive scores are kept (this also
     * drops NaN/infinite scores arising when the normaliser is zero).
     */
    void printRequiredArgument(String outputFile) {
        try (PrintStream fout = IOUtil.createPrintFileStream(outputFile)) {
            System.out.println(n_count);
            final double thresh = 0;

            Map<String, List<JObjectDoubleTuple<String>>> map = new HashMap<>();
            for (ObjectCursor<String> cur : m_verbs.keys()) {
                if (!cur.value.equals(TOTAL)) {
                    map.put(cur.value, new ArrayList<JObjectDoubleTuple<String>>());
                }
            }

            for (ObjectCursor<String> cur : m_verbArgNs.keys()) {
                String key = cur.value;
                String[] parts = splitKey(key);
                int nVerb = m_verbs.get(parts[0]);
                double percent = (double) m_verbArgNs.get(key) * 100 / nVerb;
                if (percent > thresh) {
                    map.get(parts[0]).add(new JObjectDoubleTuple<>(parts[1], percent));
                }
            }

            int nVerbTotal = m_verbs.get(TOTAL);
            int nPrepTotal = m_preps.get(TOTAL);
            for (ObjectCursor<String> cur : m_verbArgMs.keys()) {
                String key = cur.value;
                String[] parts = splitKey(key);
                int nVerb = m_verbs.get(parts[0]);
                double pv = (double) m_verbArgMs.get(key) / nVerb;
                double p = (double) m_preps.get(parts[1]) / nPrepTotal;
                double v = (double) nVerb / nVerbTotal;
                double pmi = getPMI(pv, p);
                pmi /= -(Math.log(pv) + Math.log(v));
                if (pmi > 0) {
                    map.get(parts[0]).add(new JObjectDoubleTuple<>(parts[1], pmi));
                }
            }

            for (Map.Entry<String, List<JObjectDoubleTuple<String>>> entry : map.entrySet()) {
                List<JObjectDoubleTuple<String>> list = entry.getValue();
                Collections.sort(list);
                StringBuilder build = new StringBuilder();
                build.append(entry.getKey());
                build.append("\t");
                build.append(m_verbs.get(entry.getKey()));
                for (JObjectDoubleTuple<String> tup : list) {
                    build.append("\t");
                    build.append(tup.toString());
                }
                fout.println(build.toString());
            }
            fout.flush();
        }
    }

    // ----------------------------- helpers -----------------------------

    /** Returns {@code ln(pxy / px)}. */
    double getPMI(double pxy, double px) {
        return Math.log(pxy / px);
    }

    /** Adds one to the count stored under {@code key} (a missing key is read via {@code map.get}). */
    void increment(ObjectIntOpenHashMap<String> map, String key) {
        map.put(key, map.get(key) + 1);
    }

    /** Subtracts {@code dec} from the count stored under {@code key}. */
    void decrement(ObjectIntOpenHashMap<String> map, String key, int dec) {
        map.put(key, map.get(key) - dec);
    }

    /** Returns the base-2 logarithm of {@code d}. */
    double log2(double d) {
        return Math.log(d) / Math.log(2);
    }

    /** Returns true for labels starting with {@code ARGM} other than ARGM-NEG and ARGM-MOD. */
    boolean isArgM(String label) {
        return label.startsWith("ARGM") && !label.equals("ARGM-NEG") && !label.equals("ARGM-MOD");
    }

    /** Returns true when {@code label} is exactly {@code ARGM-MOD} or {@code ARGM-NEG}. */
    private static boolean isModalOrNegation(String label) {
        return label.equals("ARGM-MOD") || label.equals("ARGM-NEG");
    }

    /**
     * Splits a composite {@code left_right} key at its LAST underscore, so that a left part
     * which itself contains underscores is kept intact. Returns {@code {key, ""}} when the key
     * has no underscore.
     */
    private static String[] splitKey(String key) {
        int sep = key.lastIndexOf('_');
        if (sep < 0) {
            return new String[] {key, ""};
        }
        return new String[] {key.substring(0, sep), key.substring(sep + 1)};
    }

    /** Returns {@code smooth + count / total}, or just {@code smooth} when {@code total} is 0. */
    private static double smoothedRatio(int count, int total, double smooth) {
        return total == 0 ? smooth : smooth + (double) count / total;
    }

    /**
     * Command-line entry point.
     *
     * @param args {@code args[0]} = corpus root directory, {@code args[1]} = output file
     */
    public static void main(String[] args) {
        if (args.length < 2) {
            System.err.println("Usage: java clear.experiment.AnalyzePBArgs <rootPath> <outputFile>");
            System.exit(1);
        }
        new AnalyzePBArgs(args[0], args[1]);
    }
}