package edu.stanford.nlp.parser.lexparser;
import java.util.Set;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Pair;
/**
 * Tallies unary and binary rewrite rules observed at the internal nodes of
 * trees and converts the tallies into a {@link UnaryGrammar} /
 * {@link BinaryGrammar} pair.
 *
 * <p>Each internal node contributes one rule: a {@link UnaryRule} when the node
 * has exactly one child, otherwise a {@link BinaryRule} built from its first two
 * children. Node labels are interned into the shared {@code stateIndex} as they
 * are encountered.
 *
 * <p>NOTE(review): nodes with more than two children are not rejected; only the
 * first two children are read, so this class appears to assume already-binarized
 * trees. Confirm against callers.
 */
public class BinaryGrammarExtractor extends AbstractTreeExtractor<Pair<UnaryGrammar,BinaryGrammar>> {

  /**
   * When {@code op.trainOptions.compactGrammar()} is at least this value, rules
   * are emitted with their raw (weighted) counts as the score instead of a log
   * relative frequency.
   */
  private static final int RAW_COUNT_COMPACTION_LEVEL = 4;

  /** Index mapping state (category label) strings to the ints stored in rules. */
  protected Index<String> stateIndex;

  /** Weighted count of each unary rule seen so far. */
  private final ClassicCounter<UnaryRule> unaryRuleCounter = new ClassicCounter<>();

  /** Weighted count of each binary rule seen so far. */
  private final ClassicCounter<BinaryRule> binaryRuleCounter = new ClassicCounter<>();

  /** Weighted count of each parent symbol, summed over unary and binary rules. */
  protected ClassicCounter<String> symbolCounter = new ClassicCounter<>();

  /** Distinct binary rules seen; these instances are scored and added to the grammar. */
  private final Set<BinaryRule> binaryRules = Generics.newHashSet();

  /** Distinct unary rules seen; these instances are scored and added to the grammar. */
  private final Set<UnaryRule> unaryRules = Generics.newHashSet();

  /**
   * Creates an extractor that interns states into the supplied index.
   *
   * @param op    parser options; {@code op.trainOptions} is consulted in {@link #formResult()}
   * @param index state index that rules produced by this extractor refer to
   */
  public BinaryGrammarExtractor(Options op, Index<String> index) {
    super(op);
    this.stateIndex = index;
  }

  /**
   * Records the rule rooted at {@code lt}, adding {@code weight} to both the
   * rule's count and its parent symbol's count.
   *
   * @param lt     the local tree (internal node) to tally
   * @param weight the amount by which to increment the counts
   */
  @Override
  protected void tallyInternalNode(Tree lt, double weight) {
    // Fetch the child array once; labels are interned parent-first, then
    // children left to right, exactly as before.
    Tree[] kids = lt.children();
    int parent = stateIndex.addToIndex(lt.label().value());
    int left = stateIndex.addToIndex(kids[0].label().value());

    if (kids.length == 1) {
      UnaryRule ur = new UnaryRule(parent, left);
      symbolCounter.incrementCount(stateIndex.get(ur.parent), weight);
      unaryRuleCounter.incrementCount(ur, weight);
      unaryRules.add(ur);
    } else {
      // Any children past the second are ignored here.
      int right = stateIndex.addToIndex(kids[1].label().value());
      BinaryRule br = new BinaryRule(parent, left, right);
      symbolCounter.incrementCount(stateIndex.get(br.parent), weight);
      binaryRuleCounter.incrementCount(br, weight);
      binaryRules.add(br);
    }
  }

  /**
   * Builds the grammars from everything tallied so far.
   *
   * <p>The boundary tag is added to the state index first. Each rule's score is
   * then set to either its raw count (when {@code compactGrammar()} is at least
   * {@value #RAW_COUNT_COMPACTION_LEVEL}) or the natural log of its count divided
   * by the count of its parent symbol. For binary rules only,
   * {@code op.trainOptions.ruleDiscount} is subtracted from the count before
   * dividing.
   *
   * <p>NOTE(review): if the discount is greater than or equal to a binary rule's
   * count, the logarithm's argument is non-positive and the score becomes
   * {@code -Infinity} or {@code NaN}; no guard exists for that here.
   *
   * @return a pair of (unary grammar, binary grammar) over {@code stateIndex}
   */
  @Override
  public Pair<UnaryGrammar,BinaryGrammar> formResult() {
    stateIndex.addToIndex(Lexicon.BOUNDARY_TAG);
    BinaryGrammar bg = new BinaryGrammar(stateIndex);
    UnaryGrammar ug = new UnaryGrammar(stateIndex);

    // Loop-invariant: decide once which kind of score to emit, and compute
    // only that one rather than computing a log score and overwriting it.
    final boolean useRawCounts =
        op.trainOptions.compactGrammar() >= RAW_COUNT_COMPACTION_LEVEL;

    // add unaries
    for (UnaryRule ur : unaryRules) {
      double count = unaryRuleCounter.getCount(ur);
      if (useRawCounts) {
        ur.score = (float) count;
      } else {
        ur.score = (float) Math.log(count / parentCount(ur.parent));
      }
      ug.addRule(ur);
    }

    // add binaries
    for (BinaryRule br : binaryRules) {
      double count = binaryRuleCounter.getCount(br);
      if (useRawCounts) {
        br.score = (float) count;
      } else {
        br.score = (float) Math.log((count - op.trainOptions.ruleDiscount) / parentCount(br.parent));
      }
      bg.addRule(br);
    }

    return new Pair<>(ug, bg);
  }

  /** Returns the accumulated count of the symbol stored at {@code parent} in the state index. */
  private double parentCount(int parent) {
    return symbolCounter.getCount(stateIndex.get(parent));
  }

} // end class BinaryGrammarExtractor