package edu.stanford.nlp.coref.statistical;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import edu.stanford.nlp.coref.data.Dictionaries.MentionType;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
/**
 * Filters mention and mention-pair features by prefix and expands them into feature
 * conjunctions.
 *
 * <p>A conjunction is a copy of a feature counter in which every feature name has a suffix
 * appended; the suffix encodes the role of the mention (first or second) and/or an identifier
 * derived from the mentions (see {@link #getFeatures}). Which conjunctions are produced is
 * configured through {@link Builder}.
 *
 * @author Kevin Clark
 */
public class MetaFeatureExtractor {
  /**
   * Conjunctions applied to the pairwise features of an example. With {@code id1}/{@code id2}
   * the identifiers of the first/second mention, the appended suffixes are:
   * {@code FIRST}: {@code _m1=id1}; {@code LAST}: {@code _m2=id2};
   * {@code BOTH}: {@code _ms=id1_id2}.
   */
  public enum PairConjunction {FIRST, LAST, BOTH}

  /**
   * Conjunctions applied to the per-mention features of an example. Features of the first
   * mention always get a {@code _1} marker and those of the second mention a {@code _2} marker,
   * followed by: nothing ({@code INDEX}); {@code _m=} plus the mention's own identifier
   * ({@code INDEX_CURRENT}); {@code _m=} plus the other mention's identifier
   * ({@code INDEX_OTHER}); {@code _ms=id1_id2} ({@code INDEX_BOTH}); {@code _m2=id2}
   * ({@code INDEX_LAST}).
   */
  public enum SingleConjunction {INDEX, INDEX_CURRENT, INDEX_OTHER, INDEX_BOTH, INDEX_LAST}

  /** Feature-name prefix whose remainder identifies a pronominal mention. */
  private static final String HEAD_WORD_PREFIX = "head-word=";
  /** Feature-name prefix whose remainder is combined with the type of a proper mention. */
  private static final String HEAD_NE_TYPE_PREFIX = "head-ne-type=";

  private final boolean neTypeConjunction;
  private final boolean anaphoricityClassifier;
  private final Set<PairConjunction> pairConjunctions;
  private final Set<SingleConjunction> singleConjunctions;
  private final List<String> disallowedPrefixes;
  private final String str;

  /**
   * Fluent configuration for a {@link MetaFeatureExtractor}.
   *
   * <p>Defaults: not an anaphoricity classifier; pair conjunctions {@code LAST, FIRST, BOTH};
   * single conjunctions {@code INDEX, INDEX_CURRENT, INDEX_BOTH}; no disallowed prefixes;
   * NE-type identifiers enabled.
   */
  public static class Builder {
    // NOTE(review): this builder is handed to StatisticalCorefTrainer.fieldValues(...) by the
    // MetaFeatureExtractor constructor; presumably that reflects over these fields, so their
    // names, order and types are deliberately left as they were -- confirm before renaming.
    private boolean anaphoricityClassifier = false;
    private List<PairConjunction> pairConjunctions = Arrays.asList(
        PairConjunction.LAST, PairConjunction.FIRST, PairConjunction.BOTH);
    private List<SingleConjunction> singleConjunctions = Arrays.asList(
        SingleConjunction.INDEX, SingleConjunction.INDEX_CURRENT, SingleConjunction.INDEX_BOTH);
    private List<String> disallowedPrefixes = new ArrayList<>();
    private boolean useNEType = true;

    /** Sets whether the extractor is used as an anaphoricity classifier. */
    public Builder anaphoricityClassifier(boolean anaphoricityClassifier) {
      this.anaphoricityClassifier = anaphoricityClassifier;
      return this;
    }

    /** Sets the conjunctions to apply to pairwise features. */
    public Builder pairConjunctions(PairConjunction[] pairConjunctions) {
      this.pairConjunctions = Arrays.asList(pairConjunctions);
      return this;
    }

    /** Sets the conjunctions to apply to per-mention features. */
    public Builder singleConjunctions(SingleConjunction[] singleConjunctions) {
      this.singleConjunctions = Arrays.asList(singleConjunctions);
      return this;
    }

    /** Sets the feature-name prefixes that are filtered out before conjunctions are built. */
    public Builder disallowedPrefixes(String[] disallowedPrefixes) {
      this.disallowedPrefixes = Arrays.asList(disallowedPrefixes);
      return this;
    }

    /** Sets whether proper mentions are identified by their head NE type. */
    public Builder useNEType(boolean useNEType) {
      this.useNEType = useNEType;
      return this;
    }

    /** Creates an extractor from the current settings. */
    public MetaFeatureExtractor build() {
      return new MetaFeatureExtractor(this);
    }
  }

  /** Returns a builder initialised with the default settings. */
  public static Builder newBuilder() {
    return new Builder();
  }

  /**
   * Creates an extractor from the builder's settings. When the builder is configured as an
   * anaphoricity classifier, its pair conjunctions are ignored (the set is left empty).
   *
   * @param builder the configuration to copy
   */
  public MetaFeatureExtractor(Builder builder) {
    anaphoricityClassifier = builder.anaphoricityClassifier;
    if (anaphoricityClassifier) {
      pairConjunctions = new HashSet<>();
    } else {
      pairConjunctions = new HashSet<>(builder.pairConjunctions);
    }
    singleConjunctions = new HashSet<>(builder.singleConjunctions);
    // Copy: the builder's list may be an Arrays.asList view over the caller's array, so
    // keeping the reference would let later writes to that array change this extractor.
    disallowedPrefixes = new ArrayList<>(builder.disallowedPrefixes);
    neTypeConjunction = builder.useNEType;
    str = StatisticalCorefTrainer.fieldValues(builder);
  }

  /**
   * Creates the extractor configuration used for anaphoricity classification: single
   * conjunctions {@code INDEX} and {@code INDEX_LAST}, features starting with
   * {@code parent-word} filtered out, and the anaphoricity flag set.
   */
  public static MetaFeatureExtractor anaphoricityMFE() {
    return MetaFeatureExtractor.newBuilder()
        .singleConjunctions(new SingleConjunction[] {
            SingleConjunction.INDEX, SingleConjunction.INDEX_LAST})
        .disallowedPrefixes(new String[] {"parent-word"})
        .anaphoricityClassifier(true)
        .build();
  }

  /**
   * Returns a new counter holding the entries of {@code c} whose key starts with none of the
   * given prefixes. {@code c} itself is not modified.
   *
   * @param c the counter to filter
   * @param disallowedPrefixes key prefixes to drop
   * @return a fresh counter with the surviving entries and their counts
   */
  public static Counter<String> filterOut(Counter<String> c, List<String> disallowedPrefixes) {
    Counter<String> kept = new ClassicCounter<>();
    for (Map.Entry<String, Double> entry : c.entrySet()) {
      if (!startsWithAny(entry.getKey(), disallowedPrefixes)) {
        kept.incrementCount(entry.getKey(), entry.getValue());
      }
    }
    return kept;
  }

  /** Returns whether {@code key} starts with at least one of {@code prefixes}. */
  private static boolean startsWithAny(String key, List<String> prefixes) {
    for (String prefix : prefixes) {
      if (key.startsWith(prefix)) {
        return true;
      }
    }
    return false;
  }

  /**
   * Builds the final feature counter for an example.
   *
   * <p>For an example with an antecedent (not a new link) the result contains the pairwise
   * features plus every configured pair and single conjunction. For a new-link example the
   * result contains the second mention's features (with an added {@code bias} feature) and
   * their {@code _m=id2} conjunction, with {@code _NEW} appended to every feature name.
   *
   * @param example the example to featurize
   * @param mentionFeatures compressed per-mention features, keyed by mention id
   * @param compressor used to uncompress the stored feature vectors
   * @return the feature counter for {@code example}
   */
  public Counter<String> getFeatures(Example example,
      Map<Integer, CompressedFeatureVector> mentionFeatures, Compressor<String> compressor) {
    Counter<String> features = new ClassicCounter<>();
    Counter<String> pairFeatures = new ClassicCounter<>();
    Counter<String> features1 = new ClassicCounter<>();
    Counter<String> features2 = compressor.uncompress(mentionFeatures.get(example.mentionId2));

    boolean newLink = example.isNewLink();
    if (newLink) {
      features2.incrementCount("bias");
    } else {
      // Examples with an antecedent are not expected in anaphoricity-classifier mode.
      assert !anaphoricityClassifier;
      pairFeatures = compressor.uncompress(example.pairwiseFeatures);
      features1 = compressor.uncompress(mentionFeatures.get(example.mentionId1));
    }

    if (!disallowedPrefixes.isEmpty()) {
      features1 = filterOut(features1, disallowedPrefixes);
      features2 = filterOut(features2, disallowedPrefixes);
      pairFeatures = filterOut(pairFeatures, disallowedPrefixes);
    }

    // For a new link ids1 is empty, so the conjunction loop below does not execute and only
    // the new-link block at the end contributes features.
    List<String> ids1 = newLink ? new ArrayList<>()
        : identifiers(features1, example.mentionType1);
    List<String> ids2 = identifiers(features2, example.mentionType2);

    features.addAll(pairFeatures);
    for (String id1 : ids1) {
      for (String id2 : ids2) {
        addPairConjunctions(features, pairFeatures, id1, id2);
        addSingleConjunctions(features, features1, features2, id1, id2);
      }
    }

    if (newLink) {
      features.addAll(features2);
      features.addAll(getConjunction(features2, "_m=" + ids2.get(0)));
      features = getConjunction(features, "_NEW");
    }
    return features;
  }

  /** Adds the configured pair conjunctions of {@code pairFeatures} to {@code target}. */
  private void addPairConjunctions(Counter<String> target, Counter<String> pairFeatures,
      String id1, String id2) {
    if (pairConjunctions.contains(PairConjunction.FIRST)) {
      target.addAll(getConjunction(pairFeatures, "_m1=" + id1));
    }
    if (pairConjunctions.contains(PairConjunction.LAST)) {
      target.addAll(getConjunction(pairFeatures, "_m2=" + id2));
    }
    if (pairConjunctions.contains(PairConjunction.BOTH)) {
      target.addAll(getConjunction(pairFeatures, "_ms=" + id1 + "_" + id2));
    }
  }

  /** Adds the configured single conjunctions of both mentions' features to {@code target}. */
  private void addSingleConjunctions(Counter<String> target, Counter<String> features1,
      Counter<String> features2, String id1, String id2) {
    if (singleConjunctions.contains(SingleConjunction.INDEX)) {
      target.addAll(getConjunction(features1, "_1"));
      target.addAll(getConjunction(features2, "_2"));
    }
    if (singleConjunctions.contains(SingleConjunction.INDEX_CURRENT)) {
      target.addAll(getConjunction(features1, "_1_m=" + id1));
      target.addAll(getConjunction(features2, "_2_m=" + id2));
    }
    if (singleConjunctions.contains(SingleConjunction.INDEX_LAST)) {
      target.addAll(getConjunction(features1, "_1_m2=" + id2));
      target.addAll(getConjunction(features2, "_2_m2=" + id2));
    }
    if (singleConjunctions.contains(SingleConjunction.INDEX_OTHER)) {
      target.addAll(getConjunction(features1, "_1_m=" + id2));
      target.addAll(getConjunction(features2, "_2_m=" + id1));
    }
    if (singleConjunctions.contains(SingleConjunction.INDEX_BOTH)) {
      target.addAll(getConjunction(features1, "_1_ms=" + id1 + "_" + id2));
      target.addAll(getConjunction(features2, "_2_ms=" + id1 + "_" + id2));
    }
  }

  /**
   * Returns the single-element list of identifiers used in conjunction suffixes for a mention.
   *
   * <p>A pronominal mention is identified by the value of its first {@code head-word=}
   * feature; a proper mention (when NE-type identifiers are enabled) by its mention type
   * joined with the value of its first {@code head-ne-type=} feature. In every other case,
   * including when no such feature is present, the identifier is the mention type's string
   * form.
   *
   * @param features the mention's features
   * @param mentionType the mention's type
   * @return a list holding exactly one identifier
   */
  private List<String> identifiers(Counter<String> features, MentionType mentionType) {
    List<String> identifiers = new ArrayList<>();
    if (mentionType == MentionType.PRONOMINAL) {
      String headWord = firstValueWithPrefix(features, HEAD_WORD_PREFIX);
      if (headWord != null) {
        identifiers.add(headWord);
        return identifiers;
      }
    } else if (neTypeConjunction && mentionType == MentionType.PROPER) {
      String neType = firstValueWithPrefix(features, HEAD_NE_TYPE_PREFIX);
      if (neType != null) {
        identifiers.add(mentionType.toString() + "_" + neType);
        return identifiers;
      }
    }
    identifiers.add(mentionType.toString());
    return identifiers;
  }

  /**
   * Returns the remainder, after {@code prefix}, of the first feature name in {@code features}
   * that starts with {@code prefix}, or {@code null} if there is none. Only the leading prefix
   * is removed; later occurrences of the same text inside the name are kept.
   */
  private static String firstValueWithPrefix(Counter<String> features, String prefix) {
    for (String feature : features.keySet()) {
      if (feature.startsWith(prefix)) {
        return feature.substring(prefix.length());
      }
    }
    return null;
  }

  /**
   * Returns a new counter with the counts of {@code original} and {@code suffix} appended to
   * every feature name.
   */
  private static Counter<String> getConjunction(Counter<String> original, String suffix) {
    Counter<String> conjunction = new ClassicCounter<>();
    for (Map.Entry<String, Double> entry : original.entrySet()) {
      conjunction.incrementCount(entry.getKey() + suffix, entry.getValue());
    }
    return conjunction;
  }

  /**
   * Returns the string computed from the builder at construction time by
   * {@code StatisticalCorefTrainer.fieldValues}.
   */
  @Override
  public String toString() {
    return str;
  }
}