package clear.experiment;
import clear.dep.DepNode;
import clear.dep.DepTree;
import clear.dep.srl.SRLHead;
import clear.dep.srl.SRLInfo;
import clear.reader.SRLReader;
import clear.util.IOUtil;
import com.carrotsearch.hppc.IntOpenHashSet;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
/**
 * Command-line tool that filters SRL argument annotations by whether their predicate was seen in
 * training.
 *
 * <p>It first collects the lemma of every predicate node in a training file. It then rewrites a
 * test file so that every SRL argument link whose head is a predicate with one of those lemmas is
 * removed. The remaining argument links therefore belong only to predicates whose lemmas never
 * occur as predicates in the training file. Lemma matching is exact, case-sensitive string
 * equality. Predicate nodes themselves are not modified; only the head links that point at them
 * are dropped.
 */
public class ExcludeUnknowVB {

    /**
     * Runs the whole filtering pass.
     *
     * @param trainFile  SRL-annotated file whose predicate lemmas count as "known"
     * @param testFile   SRL-annotated file to filter
     * @param outputFile file the filtered trees are written to
     */
    public ExcludeUnknowVB(String trainFile, String testFile, String outputFile) {
        removeKnownVerbs(testFile, outputFile, getKnownVerbs(trainFile));
    }

    /**
     * Collects the lemma of every predicate node in {@code trainFile}.
     *
     * @param trainFile SRL-annotated input file
     * @return the set of predicate lemmas found (possibly empty)
     */
    private Set<String> getKnownVerbs(String trainFile) {
        Set<String> lemmas = new HashSet<>();
        // NOTE(review): the meaning of the boolean flag is defined by SRLReader; kept as-is.
        // The reader is never explicitly closed here; if SRLReader holds a file handle, consider
        // closing it once its API is confirmed.
        SRLReader reader = new SRLReader(trainFile, true);
        DepTree tree;
        while ((tree = reader.nextTree()) != null) {
            // Index 0 is skipped, presumably an artificial root node (TODO confirm).
            for (int i = 1; i < tree.size(); i++) {
                DepNode node = tree.get(i);
                if (node.isPredicate()) {
                    lemmas.add(node.lemma);
                }
            }
        }
        return lemmas;
    }

    /**
     * Copies every tree of {@code testFile} to {@code outputFile}, removing each SRL head link that
     * points at a predicate whose lemma is in {@code verbs}. Each tree is printed followed by an
     * extra empty line.
     *
     * @param testFile   SRL-annotated input file
     * @param outputFile destination file
     * @param verbs      predicate lemmas whose argument links must be removed
     */
    private void removeKnownVerbs(String testFile, String outputFile, Set<String> verbs) {
        // NOTE(review): as above, the reader is not explicitly closed; confirm SRLReader's API.
        SRLReader reader = new SRLReader(testFile, true);
        try (PrintStream fout = IOUtil.createPrintFileStream(outputFile)) {
            DepTree tree;
            while ((tree = reader.nextTree()) != null) {
                IntOpenHashSet knownPredicateIds = collectKnownPredicateIds(tree, verbs);
                // Nothing to strip when the tree has no known predicates.
                if (!knownPredicateIds.isEmpty()) {
                    dropHeads(tree, knownPredicateIds);
                }
                fout.println(tree + "\n");
            }
        }
    }

    /** Returns the ids of the predicate nodes in {@code tree} whose lemma is in {@code verbs}. */
    private static IntOpenHashSet collectKnownPredicateIds(DepTree tree, Set<String> verbs) {
        IntOpenHashSet ids = new IntOpenHashSet();
        for (int i = 1; i < tree.size(); i++) {
            DepNode node = tree.get(i);
            if (node.isPredicate() && verbs.contains(node.lemma)) {
                ids.add(node.id);
            }
        }
        return ids;
    }

    /** Removes, from every non-root node of {@code tree}, each SRL head whose id is in {@code predicateIds}. */
    private static void dropHeads(DepTree tree, IntOpenHashSet predicateIds) {
        for (int i = 1; i < tree.size(); i++) {
            // Iterator.remove avoids allocating a temporary list per node and does not depend on
            // how SRLHead implements equals (unlike Collection.removeAll).
            Iterator<SRLHead> it = tree.get(i).srlInfo.heads.iterator();
            while (it.hasNext()) {
                if (predicateIds.contains(it.next().headId)) {
                    it.remove();
                }
            }
        }
    }

    /**
     * Entry point.
     *
     * <p>Usage: {@code ExcludeUnknowVB <trainFile> <testFile> <outputFile>}. Prints usage to stderr
     * and exits with status 1 when fewer than three arguments are supplied.
     *
     * @param args training file, test file, output file
     */
    public static void main(String[] args) {
        if (args.length < 3) {
            System.err.println("Usage: java " + ExcludeUnknowVB.class.getName()
                    + " <trainFile> <testFile> <outputFile>");
            System.exit(1);
        }
        new ExcludeUnknowVB(args[0], args[1], args[2]);
    }
}