package edu.stanford.nlp.semparse.open.model.feature;
import java.util.regex.Pattern;
import edu.stanford.nlp.semparse.open.dataset.Example;
import edu.stanford.nlp.semparse.open.ling.LingUtils;
import edu.stanford.nlp.semparse.open.ling.QueryTypeTable;
import edu.stanford.nlp.semparse.open.ling.WordNetClusterTable;
import edu.stanford.nlp.semparse.open.model.FeatureMatcher;
import edu.stanford.nlp.semparse.open.model.FeatureVector;
import edu.stanford.nlp.semparse.open.model.candidate.Candidate;
import edu.stanford.nlp.semparse.open.model.candidate.CandidateGroup;
import fig.basic.Fmt;
import fig.basic.LogInfo;
import fig.basic.Option;
/**
 * Feature post-processor that replaces a feature vector with one whose features are
 * conjoined with a prefix derived from the example's query phrase.
 *
 * <p>The prefix is either the query type itself (looked up through
 * {@link QueryTypeTable} or {@link WordNetClusterTable}), or, when
 * {@link Options#cjQueryTypeName} is set, the marker {@code "I"} / {@code "O"} depending on
 * whether the query type equals that name. Features left unconjoined are added under the
 * constant prefix {@code "ALL"}.
 *
 * <p>Nothing happens unless {@link Options#useConjoin} is enabled.
 */
public class FeaturePostProcessorConjoin extends FeaturePostProcessor {

  /** Command-line options controlling the conjoining behavior. */
  public static class Options {
    @Option(gloss = "conjoin features with an abstract representation of the query")
    public boolean useConjoin = false;
    // If non-empty, the prefix becomes "I" when the query type equals this name and "O" otherwise.
    @Option public String cjQueryTypeName = null;
    // Use the WordNet cluster of the phrase's head word instead of the query type table.
    @Option public boolean cjConjoinWithWordNetClusters = false;
    // Only features whose full name matches this regex are conjoined; null or empty means all.
    @Option public String cjRegExConjoin = "^(ling|entity).*";
    // Also add every original feature under the "ALL" prefix.
    @Option public boolean cjKeepOriginalFeatures = false;
    // Scale passed to FeatureVector.addConjoin for conjoined features when different from 1.0.
    @Option public double cjScaleConjoinFeatures = 1.0;
  }

  public static Options opts = new Options();

  /**
   * Most recently compiled form of {@code opts.cjRegExConjoin}. Kept so that the regex is
   * not recompiled for every candidate; volatile because {@link Pattern} instances are shared
   * read-only and the field may be read by whichever thread calls {@code process}.
   */
  private static volatile Pattern cachedConjoinPattern;

  /** Logs a human-readable summary of the current conjoining options. */
  public static void debugPrintOptions() {
    if (!isNullOrEmpty(opts.cjQueryTypeName)) {
      LogInfo.logs("Conjoining query type: %s", opts.cjQueryTypeName);
    } else {
      LogInfo.log("Conjoining ALL query types");
    }
    if (!isNullOrEmpty(opts.cjRegExConjoin)) {
      LogInfo.logs("Conjoining features matching regex: %s", opts.cjRegExConjoin);
    } else {
      LogInfo.log("Conjoining ALL features");
    }
    if (opts.cjKeepOriginalFeatures) {
      LogInfo.log("... also keep original features");
    }
    if (opts.cjScaleConjoinFeatures != 1.0) {
      LogInfo.logs("... also scale conjoined features by %s", Fmt.D(opts.cjScaleConjoinFeatures));
    }
  }

  /**
   * Replaces {@code candidate.features} with its conjoined version.
   * Does nothing when {@code opts.useConjoin} is false.
   */
  @Override
  public void process(Candidate candidate) {
    if (!opts.useConjoin) return;
    String prefix = getConjoiningPrefix(candidate.ex);
    candidate.features = getConjoinedFeatureVector(candidate.features, prefix);
  }

  /**
   * Replaces {@code group.features} with its conjoined version.
   * Does nothing when {@code opts.useConjoin} is false.
   */
  @Override
  public void process(CandidateGroup group) {
    if (!opts.useConjoin) return;
    String prefix = getConjoiningPrefix(group.ex);
    group.features = getConjoinedFeatureVector(group.features, prefix);
  }

  // ============================================================
  // Compute the abstract representation g(query)
  // ============================================================

  /** Returns the query type of the example's phrase. */
  private String getQueryType(Example ex) {
    return getQueryType(ex.phrase);
  }

  /**
   * Returns the query type of {@code phrase}: the WordNet cluster of its head word when
   * {@code opts.cjConjoinWithWordNetClusters} is set, otherwise the entry from
   * {@link QueryTypeTable}.
   *
   * @return the query type; never null, because a null lookup result is rendered as the
   *     string {@code "null"}
   */
  private String getQueryType(String phrase) {
    String queryType;
    if (opts.cjConjoinWithWordNetClusters) {
      queryType = WordNetClusterTable.getCluster(LingUtils.findHeadWord(phrase, true));
    } else {
      queryType = QueryTypeTable.getQueryType(phrase);
    }
    // Concatenation deliberately maps a null lookup result to the literal "null".
    return "" + queryType;
  }

  /**
   * Returns the prefix to conjoin features with: {@code "I"} or {@code "O"} when a specific
   * query type name is configured (equal / not equal to the example's query type), otherwise
   * the query type itself.
   */
  private String getConjoiningPrefix(Example ex) {
    if (isNullOrEmpty(opts.cjQueryTypeName)) {
      return getQueryType(ex);
    }
    return opts.cjQueryTypeName.equals(getQueryType(ex)) ? "I" : "O";
  }

  // ============================================================
  // Converting feature f to (g(query), f)
  // ============================================================

  /**
   * Feature matcher that accepts a feature when its whole name matches a regular expression,
   * or, in inverse mode, when it does not.
   *
   * <p>Declared static: it uses no state of the enclosing processor, so it should not hold a
   * reference to one.
   */
  static class RegExFeatureMatcher implements FeatureMatcher {
    public final Pattern regex;
    public final boolean inverse;

    /** Creates a non-inverted matcher from a regex string. */
    public RegExFeatureMatcher(String regex) {
      this(Pattern.compile(regex), false);
    }

    /** Creates a non-inverted matcher from a compiled pattern. */
    public RegExFeatureMatcher(Pattern regex) {
      this(regex, false);
    }

    /** Creates a matcher from a regex string; {@code inverse} negates the result. */
    public RegExFeatureMatcher(String regex, boolean inverse) {
      this(Pattern.compile(regex), inverse);
    }

    /** Creates a matcher from a compiled pattern; {@code inverse} negates the result. */
    public RegExFeatureMatcher(Pattern regex, boolean inverse) {
      this.regex = regex;
      this.inverse = inverse;
    }

    /**
     * Tests the entire feature name against the pattern (full match, not a substring search).
     *
     * @return the match result, negated when this matcher is inverse
     */
    @Override
    public boolean matches(String feature) {
      boolean match = regex.matcher(feature).matches();
      return inverse ? !match : match;
    }
  }

  /**
   * Builds a new feature vector from {@code vOld} using {@code queryType} as the conjoining
   * prefix.
   *
   * <p>Without a regex (null or empty {@code opts.cjRegExConjoin}), every feature is added
   * under {@code queryType}, plus under {@code "ALL"} when original features are kept. With a
   * regex, matching features are added under {@code queryType} (with the configured scale when
   * it differs from 1.0); under {@code "ALL"} go either all features (when original features
   * are kept) or only the non-matching ones.
   *
   * <p>An empty regex is treated like a missing one so that the behavior agrees with what
   * {@link #debugPrintOptions()} reports ("Conjoining ALL features"); compiled literally, an
   * empty regex would fully match only an empty feature name.
   */
  private FeatureVector getConjoinedFeatureVector(FeatureVector vOld, String queryType) {
    FeatureVector v = new FeatureVector();
    // Read the option once so the emptiness check and the compilation see the same value.
    String regex = opts.cjRegExConjoin;
    if (isNullOrEmpty(regex)) {
      if (opts.cjKeepOriginalFeatures) v.addConjoin(vOld, "ALL");
      v.addConjoin(vOld, queryType);
      return v;
    }
    // Both matchers share one compiled pattern.
    Pattern pattern = getConjoinPattern(regex);
    FeatureMatcher matcher = new RegExFeatureMatcher(pattern);
    FeatureMatcher invMatcher = new RegExFeatureMatcher(pattern, true);
    if (opts.cjKeepOriginalFeatures) {
      v.addConjoin(vOld, "ALL");
    } else {
      v.addConjoin(vOld, "ALL", invMatcher);
    }
    if (opts.cjScaleConjoinFeatures != 1.0) {
      v.addConjoin(vOld, queryType, matcher, opts.cjScaleConjoinFeatures);
    } else {
      v.addConjoin(vOld, queryType, matcher);
    }
    return v;
  }

  /**
   * Returns the compiled form of {@code regex}, reusing the cached pattern when it was
   * compiled from the same string and recompiling when the option value has changed.
   */
  private static Pattern getConjoinPattern(String regex) {
    // Work on a local copy so a concurrent update cannot change it between check and return.
    Pattern pattern = cachedConjoinPattern;
    if (pattern == null || !pattern.pattern().equals(regex)) {
      pattern = Pattern.compile(regex);
      cachedConjoinPattern = pattern;
    }
    return pattern;
  }

  /** Returns true when {@code s} is null or has length zero. */
  private static boolean isNullOrEmpty(String s) {
    return s == null || s.isEmpty();
  }
}