package edu.stanford.nlp.semparse.open.model.candidate;
import java.util.*;
import edu.stanford.nlp.semparse.open.dataset.Example;
import edu.stanford.nlp.semparse.open.model.feature.FeatureExtractor;
import edu.stanford.nlp.semparse.open.model.tree.KNode;
import fig.basic.LogInfo;
import fig.basic.Option;
public class CandidateGenerator {
public static class Options {
@Option(gloss = "Maximum number of XPath entries to toggle the indices")
public int maxTweakDepth = 8;
@Option(gloss = "Minimum number of selected entities to be considered a valid candidate")
public int minNumCandidateEntity = 2;
@Option(gloss = "Use the advanced tree traverser")
public boolean useAdvancedTreeTraverser = false;
@Option(gloss = "Maximum number of XPath entries that can be toggled to wildcard")
public int allowWildcards = 0;
@Option(gloss = "Maximum number of XPath entries that can be end-cut")
public int allowEndCuts = 0;
@Option(gloss = "Maximum depth of XPath entries that can be advancedly tweaked")
public int maxAdvancedTweakDepth = 4;
}
public static Options opts = new Options();
public final List<String> BLACKLISTED_TAGS = Arrays.asList(
"html", "head", "body", "script", "noscript", "link", "style"
);
public void process(Example ex) {
if (ex.candidates != null) {
LogInfo.warnings("Example %s already has a candidate list", ex);
return;
}
LogInfo.begin_track("Extracting candidates ...");
ex.candidateGroups = new ArrayList<>();
ex.candidates = new ArrayList<>();
new CandidatePopulator(ex).populateCandidates();
LogInfo.logs("Found %d candidates (%d groups)", ex.candidates.size(), ex.candidateGroups.size());
LogInfo.end_track();
LogInfo.begin_track("Extracting features ...");
for (CandidateGroup group : ex.candidateGroups)
FeatureExtractor.featureExtractor.extract(group);
for (Candidate candidate : ex.candidates)
FeatureExtractor.featureExtractor.extract(candidate);
LogInfo.end_track();
}
// ============================================================
// Find candidates
// ============================================================
class CandidatePopulator {
Example ex;
public CandidatePopulator(Example ex) {
this.ex = ex;
}
void populateCandidates() {
populateCandidates(ex.tree);
}
private void populateCandidates(KNode rootNode) {
// Only start from the top <html> tags
if (rootNode.type != KNode.Type.TAG || !rootNode.value.equals("html")) {
for (KNode child : rootNode.getChildren()) {
populateCandidates(child);
}
return;
}
// Traverse the knowledge tree and collect all possible paths
TreeTraverser traverser = opts.useAdvancedTreeTraverser ? new AdvancedTreeTraverser(rootNode)
: new BasicTreeTraverser(rootNode);
Map<List<KNode>, CandidateGroup> nodesToCandidateGroup = new HashMap<>();
for (List<PathEntry> path : traverser.getFoundPaths()) {
// Execute the path and check if the path is valid.
List<KNode> nodes = new ArrayList<>();
PathUtils.executePath(path, rootNode, nodes);
if (nodes.size() > opts.minNumCandidateEntity) {
CandidateGroup group = nodesToCandidateGroup.get(nodes);
if (group == null) {
ex.candidateGroups.add(group = new CandidateGroup(ex, nodes));
nodesToCandidateGroup.put(nodes, group);
}
ex.candidates.add(group.addCandidate(new TreePattern(rootNode, path, group.selectedNodes)));
}
}
}
}
interface TreeTraverser {
public Collection<List<PathEntry>> getFoundPaths();
}
class BasicTreeTraverser implements TreeTraverser {
List<PathEntry> ancestors;
Set<List<PathEntry>> foundPaths;
public BasicTreeTraverser(KNode rootNode) {
ancestors = new ArrayList<>();
foundPaths = new HashSet<>();
traverseTree(rootNode);
}
private void traverseTree(KNode currentNode) {
if (currentNode.parent.countChildren(currentNode.value) > 1)
ancestors.add(new PathEntry(currentNode.value, currentNode.getChildIndexOfSameTag()));
else
ancestors.add(new PathEntry(currentNode.value));
// Process current node
if (!isBlacklisted(currentNode))
tweakPaths(1);
// Traverse children
for (KNode child : currentNode.getChildren()) {
if (child.type == KNode.Type.TAG) traverseTree(child);
}
ancestors.remove(ancestors.size() - 1);
}
private boolean isBlacklisted(KNode node) {
if (node.fullText == null || node.fullText.isEmpty()) return true;
if (BLACKLISTED_TAGS.contains(node.value)) return true;
return false;
}
/**
* Toggle the indices of the xpath entries.
* For example, /html/body/div[3]/a[1] will produce
* - /html/body/div[3]/a[1]
* - /html/body/div[3]/a
* - /html/body/div/a[1]
* - /html/body/div/a
*
* Implemented using recursion on depth:
* The toggled xpath entry is xpath[xpath.length - depth]
* (depth = 1, 2, ..., opts.maxTweakDepth)
*
* Exception: the first entry (html) will not be toggled.
*/
private void tweakPaths(int depth) {
if (depth > opts.maxTweakDepth || depth >= ancestors.size()) {
foundPaths.add(new ArrayList<>(ancestors));
return;
}
tweakPaths(depth + 1);
PathEntry swap = ancestors.get(ancestors.size() - depth);
if (swap.index != -1) {
ancestors.set(ancestors.size() - depth, swap.getNoIndexVersion());
tweakPaths(depth + 1);
ancestors.set(ancestors.size() - depth, swap);
}
}
@Override
public Collection<List<PathEntry>> getFoundPaths() {
return foundPaths;
}
}
// ============================================================
// Advanced Tree Traverser
// ============================================================
class PathEntryAugmented {
public final String tag;
public final int childIndex, childIndexOfTag; // 0-indexed
public final int numSiblings, numSiblingsOfTag; // including self too
public PathEntryAugmented(KNode node) {
this.tag = node.value;
int countChildIndex = -1, countChildIndexOfTag = -1, countNumSiblings = 0, countNumSiblingsOfTag = 0;
for (KNode sibling : node.parent.getChildren()) {
if (sibling == node) {
countChildIndex = countNumSiblings;
countChildIndexOfTag = countNumSiblingsOfTag;
}
if (sibling.type == KNode.Type.TAG) {
countNumSiblings++;
if (node.value.equals(sibling.value)) countNumSiblingsOfTag++;
}
}
childIndex = countChildIndex;
childIndexOfTag = countChildIndexOfTag;
numSiblings = countNumSiblings;
numSiblingsOfTag = countNumSiblingsOfTag;
if (countChildIndex == -1 || countChildIndexOfTag == -1)
LogInfo.fails("WTF? %s %s %s", node, tag, node.fullText);
}
@Override
public String toString() {
return String.format("%s[%d/%d]", tag, childIndexOfTag, numSiblingsOfTag);
}
@Override
public boolean equals(Object obj) {
if (obj == this)
return true;
if (obj == null || obj.getClass() != this.getClass())
return false;
PathEntryAugmented that = (PathEntryAugmented) obj;
return (this.tag.equals(that.tag)
&& this.childIndex == that.childIndex
&& this.childIndexOfTag == that.childIndexOfTag
&& this.numSiblings == that.numSiblings
&& this.numSiblingsOfTag == that.numSiblingsOfTag);
}
@Override
public int hashCode() {
return tag.hashCode() | childIndex << 24 | childIndexOfTag << 16 | numSiblings << 8 | numSiblingsOfTag;
}
}
class AdvancedTreeTraverser implements TreeTraverser {
List<PathEntryAugmented> ancestors;
Set<List<PathEntryAugmented>> foundRawPaths;
List<PathEntryAugmented> currentRawPath;
List<PathEntry> currentTweakedPath;
Set<List<PathEntry>> foundTweakedPaths;
public AdvancedTreeTraverser(KNode rootNode) {
ancestors = new ArrayList<>();
foundRawPaths = new HashSet<>();
foundTweakedPaths = new HashSet<>();
traverseTree(rootNode);
LogInfo.logs("Found %d raw paths", foundRawPaths.size());
for (List<PathEntryAugmented> rawPath : foundRawPaths) {
//LogInfo.log(rawPath);
currentRawPath = rawPath;
createInitialTweakedPath();
tweakPaths(1);
}
LogInfo.logs("Found %d tweaked paths", foundTweakedPaths.size());
}
private void traverseTree(KNode currentNode) {
ancestors.add(new PathEntryAugmented(currentNode));
// Process current node
if (!isBlacklisted(currentNode))
savePath();
// Traverse children
for (KNode child : currentNode.getChildren()) {
if (child.type == KNode.Type.TAG) traverseTree(child);
}
ancestors.remove(ancestors.size() - 1);
}
private boolean isBlacklisted(KNode node) {
if (node.fullText == null || node.fullText.isEmpty()) return true;
if (BLACKLISTED_TAGS.contains(node.value)) return true;
return false;
}
private void savePath() {
foundRawPaths.add(new ArrayList<>(ancestors));
}
int numWildCards = 0;
int numEndCuts = 0;
private void createInitialTweakedPath() {
currentTweakedPath = new ArrayList<>();
for (PathEntryAugmented entry : currentRawPath) {
if (entry.numSiblingsOfTag == 1) {
currentTweakedPath.add(new PathEntry(entry.tag));
} else {
currentTweakedPath.add(new PathEntry(entry.tag, entry.childIndexOfTag));
}
}
numWildCards = numEndCuts = 0;
}
/**
* Tweak currentTweakedPath[n - depth]
* (depth = 1, 2, ..., opts.maxTweakDepth)
*/
private void tweakPaths(int depth) {
int n = currentTweakedPath.size();
if (depth > opts.maxTweakDepth || depth >= n) {
foundTweakedPaths.add(new ArrayList<>(currentTweakedPath));
return;
}
tweakPaths(depth + 1);
PathEntry swap = currentTweakedPath.get(n - depth);
if (swap.index != -1) {
currentTweakedPath.set(n - depth, swap.getNoIndexVersion());
tweakPaths(depth + 1);
currentTweakedPath.set(n - depth, swap);
}
if (depth <= opts.maxAdvancedTweakDepth && numEndCuts < opts.allowEndCuts) {
int numSiblings = swap.tag.equals("*") ? currentRawPath.get(n - depth).numSiblings :
currentRawPath.get(n - depth).numSiblingsOfTag;
if (numSiblings > 1) {
numEndCuts++;
currentTweakedPath.set(n - depth, new PathEntryWithRange(swap.tag, 1, 0));
tweakPaths(depth + 1);
currentTweakedPath.set(n - depth, new PathEntryWithRange(swap.tag, 0, 1));
tweakPaths(depth + 1);
currentTweakedPath.set(n - depth, swap);
numEndCuts--;
}
}
if (depth <= opts.maxAdvancedTweakDepth && numWildCards < opts.allowWildcards && !swap.tag.equals("*")) {
numWildCards++;
if (currentRawPath.get(n - depth).numSiblings == 1) {
currentTweakedPath.set(n - depth, new PathEntry("*"));
} else {
currentTweakedPath.set(n - depth, new PathEntry("*", currentRawPath.get(n - depth).childIndex));
}
tweakPaths(depth);
currentTweakedPath.set(n - depth, swap);
numWildCards--;
}
}
@Override
public Collection<List<PathEntry>> getFoundPaths() {
//debugPrint();
return foundTweakedPaths;
}
protected void debugPrint() {
if (CandidateGenerator.iter == 0) {
LogInfo.begin_track("Found paths");
List<String> paths = new ArrayList<>();
for (List<PathEntry> path : foundTweakedPaths) {
paths.add(PathUtils.getXPathString(path));
}
Collections.sort(paths);
for (String path : paths) {
LogInfo.log(path);
}
LogInfo.end_track();
}
CandidateGenerator.iter++;
}
}
private static int iter = 0;
}