package edu.stanford.nlp.ling.tokensregex;
import edu.stanford.nlp.util.*;
import java.util.*;
import java.util.logging.Level;
import java.util.logging.Logger;
import static edu.stanford.nlp.ling.tokensregex.SequenceMatcher.FindType.FIND_NONOVERLAPPING;
/**
* A generic sequence matcher.
*
* <p>
* Similar to Java's {@code Matcher} except it matches sequences over an arbitrary type {@code T}
* instead of characters.
* For a type {@code T} to be matchable, it has to have a corresponding {@code NodePattern<T>} that indicates
* whether a node is matched or not.
* </p>
*
* <p>
* A matcher is created as follows:
* <pre><code>
* SequencePattern<T> p = SequencePattern<T>.compile("...");
* SequenceMatcher<T> m = p.getMatcher(List<T> sequence);
* </code></pre>
* </p>
*
* <p>
* Functions for searching
* <pre><code>
* boolean matches()
* boolean find()
* boolean find(int start)
* </code></pre>
* Functions for retrieving matched patterns
* <pre><code>
* int groupCount()
* List<T> groupNodes(), List<T> groupNodes(int g)
* String group(), String group(int g)
* int start(), int start(int g), int end(), int end(int g)
* </code></pre>
* Functions for replacing
* <pre><code>
* List<T> replaceFirst(List<T> seq), List<T> replaceAll(List<T> seq)
* List<T> replaceFirstExtended(List<MatchReplacement<T>> seq), List<T> replaceAllExtended(List<MatchReplacement<T>> seq)
* </code></pre>
* Functions for defining the region of the sequence to search over
* (default region is entire sequence)
* <pre><code>
* void region(int start, int end)
* int regionStart()
* int regionEnd()
* </code></pre>
* </p>
*
* <p>
* NOTE: When find is used, matches are attempted starting from the specified start index of the sequence.
* The match with the earliest starting index is returned.
* </p>
*
* @author Angel Chang
*/
public class SequenceMatcher<T> extends BasicSequenceMatchResult<T> {
// Logger used for matcher diagnostics
private static final Logger logger = Logger.getLogger(SequenceMatcher.class.getName());
// If false, find(int, boolean) skips matches whose group() is empty
boolean includeEmptyMatches = false;
// True once a find/matches attempt has finished (whether or not it succeeded)
boolean matchingCompleted = false;
// True if the most recent find/matches attempt succeeded
boolean matched = false;
boolean matchWithResult = false; // If result of matches should be kept
// Index at which the next find() starts searching; find0 sets it to -1 when nothing was found
int nextMatchStart = 0;
// Start index of the search region (set by region(), restored to 0 by reset())
int regionStart = 0;
// End index (exclusive) of the search region; the constructor sets it to elements.size()
int regionEnd = -1;
// TODO: Check and fix implementation for FIND_ALL
/**
* Type of search to perform
* <ul>
* <li>FIND_NONOVERLAPPING - Find nonoverlapping matches (default)</li>
* <li>FIND_ALL - Find all potential matches
* Greedy/reluctant quantifiers are not enforced
* (perhaps should add syntax where some of them are enforced...)</li>
* </ul>
*/
public enum FindType { FIND_NONOVERLAPPING, FIND_ALL }
// Search mode consulted by find() and findMatchStart()
FindType findType = FIND_NONOVERLAPPING;
// For FIND_ALL
// Iterator over the match indices of curMatchStates that findNextAll has not reported yet
Iterator<Integer> curMatchIter = null;
// States of the most recent non-backtracking attempt (saved by findMatchStartNoBacktracking)
MatchedStates<T> curMatchStates = null;
// Signatures of matches already reported by findNextAll, used to skip duplicates
Set<String> prevMatchedSignatures = new HashSet<>();
// Branching limit for searching with back tracking. Higher value makes the search faster but uses more memory.
// A negative value disables splitting in findMatchStartBacktracking.
int branchLimit = 32;
/**
 * Creates a matcher for the given pattern over the given sequence.
 * The search region is initialised to the whole sequence, and priority, score and
 * variable group bindings are taken from the pattern.
 *
 * @param pattern  pattern to match with
 * @param elements sequence to match against; must not be modified while matching
 * @throws IllegalArgumentException if {@code elements} is null
 */
protected SequenceMatcher(SequencePattern<T> pattern, List<? extends T> elements) {
  if (elements == null) {
    throw new IllegalArgumentException("Cannot match against null elements");
  }
  // NOTE: It is important elements DO NOT change as we do matches
  // TODO: Should we just make a copy of the elements?
  this.elements = elements;
  this.regionEnd = elements.size();

  this.pattern = pattern;
  this.priority = pattern.priority;
  this.score = pattern.weight;
  this.varGroupBindings = pattern.varGroupBindings;
  this.matchedGroups = new MatchedGroup[pattern.totalGroups];
}
/**
 * Sets the branching limit used by the backtracking search.
 *
 * @param blimit new branch limit
 */
public void setBranchLimit(int blimit) {
  branchLimit = blimit;
}
/**
 * Describes one piece of what a matched pattern should be replaced with.
 *
 * @param <T> type of the sequence elements
 */
public interface MatchReplacement<T> {
  /**
   * Appends this replacement's items to the replacement list.
   *
   * @param match current matched sequence
   * @param list  replacement list being built
   */
  void append(SequenceMatchResult<T> match, List list);
}
/**
 * Replacement made of a fixed sequence of items.
 *
 * @param <T> type of the sequence elements
 */
public static class BasicMatchReplacement<T> implements MatchReplacement<T> {
  // Items appended for every match
  List<T> replacement;

  /**
   * @param replacement items to append, wrapped as a list view of the given array
   */
  @SafeVarargs
  public BasicMatchReplacement(T... replacement) {
    this(Arrays.asList(replacement));
  }

  /**
   * @param replacement items to append (the list is stored as is, not copied)
   */
  public BasicMatchReplacement(List<T> replacement) {
    this.replacement = replacement;
  }

  /**
   * Appends the fixed replacement items to the replacement list; the match is not consulted.
   *
   * @param match current matched sequence
   * @param list  replacement list being built
   */
  @Override
  public void append(SequenceMatchResult<T> match, List list) {
    list.addAll(this.replacement);
  }
}
/**
 * Replacement that copies the nodes of a matched group, selected by group name.
 *
 * @param <T> type of the sequence elements
 */
public static class NamedGroupMatchReplacement<T> implements MatchReplacement<T> {
  // Name of the group whose nodes are appended
  String groupName;

  /**
   * @param groupName name of the group to copy from each match
   */
  public NamedGroupMatchReplacement(String groupName) {
    this.groupName = groupName;
  }

  /**
   * Appends the nodes of the named group of {@code match} to the replacement list.
   *
   * @param match current matched sequence
   * @param list  replacement list being built
   */
  public void append(SequenceMatchResult<T> match, List list) {
    list.addAll(match.groupNodes(this.groupName));
  }
}
/**
 * Replacement that copies the nodes of a matched group, selected by group id.
 *
 * @param <T> type of the sequence elements
 */
public static class GroupMatchReplacement<T> implements MatchReplacement<T> {
  // Id of the group whose nodes are appended
  int group;

  /**
   * @param group id of the group to copy from each match
   */
  public GroupMatchReplacement(int group) {
    this.group = group;
  }

  /**
   * Appends the nodes of the group with the configured id to the replacement list.
   *
   * @param match current matched sequence
   * @param list  replacement list being built
   */
  public void append(SequenceMatchResult<T> match, List list) {
    list.addAll(match.groupNodes(this.group));
  }
}
/**
 * Replaces all occurrences of the pattern with the specified list
 * of replacement items (can include matched groups).
 * <p>
 * The search is always done in {@code FIND_NONOVERLAPPING} mode; the previously
 * configured find type is restored before returning, even if matching or a
 * replacement throws. The matcher is not reset first, so searching continues from
 * its current position.
 * </p>
 * @param replacement What to replace the matched sequence with
 * @return New list with all occurrences of the pattern replaced
 * @see #replaceAll(java.util.List)
 * @see #replaceFirst(java.util.List)
 * @see #replaceFirstExtended(java.util.List)
 */
public List<T> replaceAllExtended(List<MatchReplacement<T>> replacement) {
  List<T> res = new ArrayList<>();
  FindType oldFindType = findType;
  findType = FindType.FIND_NONOVERLAPPING;
  try {
    int index = 0;
    while (find()) {
      // Copy the unmatched elements between the previous match and this one
      res.addAll(elements().subList(index, start()));
      for (MatchReplacement<T> r : replacement) {
        r.append(this, res);
      }
      index = end();
    }
    // Copy whatever follows the last match
    res.addAll(elements().subList(index, elements().size()));
  } finally {
    // Restore the caller's find type even on failure
    findType = oldFindType;
  }
  return res;
}
/**
 * Replaces the first occurrence of the pattern with the specified list
 * of replacement items (can include matched groups).
 * <p>
 * The search is always done in {@code FIND_NONOVERLAPPING} mode; the previously
 * configured find type is restored before returning, even if matching or a
 * replacement throws. The matcher is not reset first, so searching continues from
 * its current position.
 * </p>
 * @param replacement What to replace the matched sequence with
 * @return New list with the first occurrence of the pattern replaced
 * @see #replaceFirst(java.util.List)
 * @see #replaceAll(java.util.List)
 * @see #replaceAllExtended(java.util.List)
 */
public List<T> replaceFirstExtended(List<MatchReplacement<T>> replacement) {
  List<T> res = new ArrayList<>();
  FindType oldFindType = findType;
  findType = FindType.FIND_NONOVERLAPPING;
  try {
    int index = 0;
    if (find()) {
      // Copy the unmatched elements that precede the match
      res.addAll(elements().subList(index, start()));
      for (MatchReplacement<T> r : replacement) {
        r.append(this, res);
      }
      index = end();
    }
    // Copy whatever follows the match (or everything when there was no match)
    res.addAll(elements().subList(index, elements().size()));
  } finally {
    // Restore the caller's find type even on failure
    findType = oldFindType;
  }
  return res;
}
/**
 * Replaces all occurrences of the pattern with the specified list.
 * Use {@link #replaceAllExtended(java.util.List)} to replace with matched groups.
 * <p>
 * The search is always done in {@code FIND_NONOVERLAPPING} mode; the previously
 * configured find type is restored before returning, even if matching throws.
 * The matcher is not reset first, so searching continues from its current position.
 * </p>
 * @param replacement What to replace the matched sequence with
 * @return New list with all occurrences of the pattern replaced
 * @see #replaceAllExtended(java.util.List)
 * @see #replaceFirst(java.util.List)
 * @see #replaceFirstExtended(java.util.List)
 */
public List<T> replaceAll(List<T> replacement) {
  List<T> res = new ArrayList<>();
  FindType oldFindType = findType;
  findType = FindType.FIND_NONOVERLAPPING;
  try {
    int index = 0;
    while (find()) {
      // Copy the unmatched elements between the previous match and this one
      res.addAll(elements().subList(index, start()));
      res.addAll(replacement);
      index = end();
    }
    // Copy whatever follows the last match
    res.addAll(elements().subList(index, elements().size()));
  } finally {
    // Restore the caller's find type even on failure
    findType = oldFindType;
  }
  return res;
}
/**
 * Replaces the first occurrence of the pattern with the specified list.
 * Use {@link #replaceFirstExtended(java.util.List)} to replace with matched groups.
 * <p>
 * The search is always done in {@code FIND_NONOVERLAPPING} mode; the previously
 * configured find type is restored before returning, even if matching throws.
 * The matcher is not reset first, so searching continues from its current position.
 * </p>
 * @param replacement What to replace the matched sequence with
 * @return New list with the first occurrence of the pattern replaced
 * @see #replaceAll(java.util.List)
 * @see #replaceAllExtended(java.util.List)
 * @see #replaceFirstExtended(java.util.List)
 */
public List<T> replaceFirst(List<T> replacement) {
  List<T> res = new ArrayList<>();
  FindType oldFindType = findType;
  findType = FindType.FIND_NONOVERLAPPING;
  try {
    int index = 0;
    if (find()) {
      // Copy the unmatched elements that precede the match
      res.addAll(elements().subList(index, start()));
      res.addAll(replacement);
      index = end();
    }
    // Copy whatever follows the match (or everything when there was no match)
    res.addAll(elements().subList(index, elements().size()));
  } finally {
    // Restore the caller's find type even on failure
    findType = oldFindType;
  }
  return res;
}
/**
 * @return the type of search currently performed by {@link #find()}
 */
public FindType getFindType() {
  return this.findType;
}
/**
 * Sets the type of search performed by {@link #find()}.
 *
 * @param findType search type to use from now on
 */
public void setFindType(FindType findType) {
  this.findType = findType;
}
/**
 * @return the current value of the {@code matchWithResult} flag
 *         (whether the result of matches should be kept)
 */
public boolean isMatchWithResult() {
  return this.matchWithResult;
}
/**
 * Sets whether the result of matches should be kept.
 *
 * @param matchWithResult new value of the flag
 */
public void setMatchWithResult(boolean matchWithResult) {
  this.matchWithResult = matchWithResult;
}
/**
 * Resets the matcher (which also restores the region to the entire sequence)
 * and then searches for the pattern from the specified start index.
 *
 * @param start index at which to start the search
 * @return true if a match is found (false otherwise)
 * @throws IndexOutOfBoundsException if start is {@literal <} 0 or larger than the size of the sequence
 * @see #find()
 */
public boolean find(int start) {
  boolean withinSequence = start >= 0 && start <= elements.size();
  if (!withinSequence) {
    throw new IndexOutOfBoundsException("Invalid region start=" + start + ", need to be between 0 and " + elements.size());
  }
  reset();
  return find(start, false);
}
/**
 * Searches from {@code start}, skipping empty matches unless
 * {@code includeEmptyMatches} is set. After an empty match that is skipped,
 * the search is retried from the next start index.
 *
 * @param start      index at which to start the search
 * @param matchStart if true, only a match beginning exactly at the start index is attempted
 * @return true if an acceptable match is found (false otherwise)
 */
protected boolean find(int start, boolean matchStart) {
  int from = start;
  while (find0(from, matchStart)) {
    boolean nonEmpty = !this.group().isEmpty();
    if (nonEmpty || includeEmptyMatches) {
      return true;
    }
    // Empty match that we are not allowed to report: move on by one
    from = from + 1;
  }
  return false;
}
/**
 * Performs a single search attempt and updates the matcher state
 * ({@code matched}, {@code matchingCompleted}, {@code nextMatchStart}).
 *
 * @param start      index at which to start the search
 * @param matchStart if true, only {@code start} is tried; otherwise every index
 *                   from {@code start} up to the region end is tried in order
 * @return true if a match is found (false otherwise)
 */
protected boolean find0(int start, boolean matchStart) {
  matched = false;
  matchingCompleted = false;

  boolean found = false;
  if (matchStart) {
    found = findMatchStart(start, false);
  } else {
    // Try successive start positions until one matches
    int candidate = start;
    while (!found && candidate < regionEnd) {
      found = findMatchStart(candidate, false);
      candidate++;
    }
  }

  matched = found;
  matchingCompleted = true;
  if (!found) {
    nextMatchStart = -1;
  } else if (findType == FindType.FIND_NONOVERLAPPING) {
    // Next search resumes after this match
    nextMatchStart = end();
  } else {
    // Next search resumes one past the start of this match
    nextMatchStart = start() + 1;
  }
  return found;
}
/**
 * Searches for the pattern starting at the next match start index.
 * A negative next index (set after a failed search) means no further search is made.
 *
 * @return true if a match is found (false otherwise)
 */
private boolean findNextNonOverlapping() {
  return nextMatchStart >= 0 && find(nextMatchStart, false);
}
/**
 * Finds the next match in FIND_ALL mode.
 * First walks the remaining match indices of the current attempt and reports the
 * first one whose signature has not been reported yet; once those are exhausted,
 * runs a new search from the next match start index.
 *
 * @return true if a match is found (false otherwise)
 */
private boolean findNextAll() {
  if (curMatchIter != null) {
    while (curMatchIter.hasNext()) {
      int branch = curMatchIter.next();
      curMatchStates.setMatchedGroups(branch);
      String signature = getMatchedSignature();
      // Set.add returns true only for signatures not seen before
      if (prevMatchedSignatures.add(signature)) {
        return true;
      }
    }
  }
  if (nextMatchStart < 0) {
    return false;
  }
  // Starting a fresh attempt: previously seen signatures no longer apply
  prevMatchedSignatures.clear();
  boolean found = find(nextMatchStart, false);
  if (!found) {
    return false;
  }
  Collection<Integer> matchedBranches = curMatchStates.getMatchIndices();
  curMatchIter = matchedBranches.iterator();
  int first = curMatchIter.next();
  curMatchStates.setMatchedGroups(first);
  prevMatchedSignatures.add(getMatchedSignature());
  return true;
}
/**
 * Applies the matcher and lazily returns the successive matches found by {@link #find()}.
 * Each returned result is a copy made with {@link #toBasicSequenceMatchResult()}.
 *
 * @return an Iterable of match results
 */
public Iterable<SequenceMatchResult<T>> findAllNonOverlapping() {
  Iterator<SequenceMatchResult<T>> iter = new Iterator<SequenceMatchResult<T>>() {
    // Match fetched by hasNext() but not yet handed out by next()
    SequenceMatchResult<T> pending;

    private SequenceMatchResult<T> fetch() {
      if (!find()) {
        return null;
      }
      return toBasicSequenceMatchResult();
    }

    @Override
    public boolean hasNext() {
      if (pending == null) {
        pending = fetch();
      }
      return pending != null;
    }

    @Override
    public SequenceMatchResult<T> next() {
      if (!hasNext()) {
        throw new NoSuchElementException();
      }
      SequenceMatchResult<T> result = pending;
      pending = null;
      return result;
    }

    public void remove() {
      throw new UnsupportedOperationException();
    }
  };
  return new IterableIterator<>(iter);
}
/**
 * Searches for the next occurrence of the pattern, using the search
 * strategy selected by the current find type.
 *
 * @return true if a match is found (false otherwise)
 * @throws UnsupportedOperationException if the find type is not supported
 * @see #find(int)
 */
public boolean find() {
  final FindType mode = findType;
  switch (mode) {
    case FIND_ALL:
      return findNextAll();
    case FIND_NONOVERLAPPING:
      return findNextNonOverlapping();
    default:
      throw new UnsupportedOperationException("Unsupported findType " + mode);
  }
}
/**
 * Attempts a match beginning at {@code start}, dispatching on the current find type:
 * backtracking search for FIND_NONOVERLAPPING, non-backtracking search for FIND_ALL.
 *
 * @param start          index at which the match must begin
 * @param matchAllTokens passed through to the selected search
 * @return true if a match is found (false otherwise)
 * @throws UnsupportedOperationException if the find type is not supported
 */
protected boolean findMatchStart(int start, boolean matchAllTokens) {
  final FindType mode = findType;
  switch (mode) {
    case FIND_ALL:
      // TODO: Should use backtracking here too, need to keep track of todo stack
      //   so we can recover after finding a match
      return findMatchStartNoBacktracking(start, matchAllTokens);
    case FIND_NONOVERLAPPING:
      return findMatchStartBacktracking(start, matchAllTokens);
    default:
      throw new UnsupportedOperationException("Unsupported findType " + mode);
  }
}
/**
 * Attempts a match beginning at {@code start} without backtracking -
 * alternative matches are stored as we go. The states used are saved in
 * {@code curMatchStates} for FIND_ALL.
 *
 * @param start          index at which the match must begin
 * @param matchAllTokens if true, matching does not stop early at the first complete match
 * @return true if a match is found (false otherwise)
 */
protected boolean findMatchStartNoBacktracking(int start, boolean matchAllTokens) {
  final boolean matchAll = true;
  MatchedStates<T> states = getStartStates();
  states.matchLongest = matchAllTokens;
  // Save states for FIND_ALL ....
  curMatchStates = states;

  int pos = start;
  while (pos < regionEnd) {
    states.match(pos);
    if (states.size() == 0) {
      // No live states left: nothing more can match
      break;
    }
    if (!matchAllTokens) {
      boolean accepted = matchAll ? states.isAllMatch() : states.isMatch();
      if (accepted) {
        states.completeMatch();
        return true;
      }
    }
    pos++;
  }
  states.completeMatch();
  return states.isMatch();
}
/**
 * Attempts a match beginning at {@code start}, with some backtracking.
 * Whenever the number of branches exceeds {@code branchLimit} (and the limit is
 * not negative), part of the states is split off and pushed on a stack to be
 * explored later if the current states fail to match.
 *
 * @param start          index at which the match must begin
 * @param matchAllTokens if true, matching does not stop early at the first complete match
 * @return true if a match is found (false otherwise)
 * @throws RuntimeInterruptedException if the current thread was interrupted
 */
protected boolean findMatchStartBacktracking(int start, boolean matchAllTokens) {
  final boolean matchAll = true;
  Stack<MatchedStates> pending = new Stack<>();
  MatchedStates current = getStartStates();
  current.matchLongest = matchAllTokens;
  current.curPosition = start - 1;
  pending.push(current);

  while (!pending.empty()) {
    current = pending.pop();
    // Resume right after the last position these states have consumed
    for (int pos = current.curPosition + 1; pos < regionEnd; pos++) {
      if (Thread.interrupted()) {
        throw new RuntimeInterruptedException();
      }
      current.match(pos);
      if (current.size() == 0) {
        break;
      }
      if (!matchAllTokens) {
        boolean accepted = matchAll ? current.isAllMatch() : current.isMatch();
        if (accepted) {
          current.completeMatch();
          return true;
        }
      }
      if (branchLimit >= 0 && current.branchSize() > branchLimit) {
        // Too many branches: set some aside for later
        MatchedStates deferred = current.split(branchLimit);
        pending.push(deferred);
      }
    }
    if (current.isMatch()) {
      current.completeMatch();
      return true;
    }
    current.clean();
  }
  return false;
}
/**
 * Checks if the pattern matches the entire region
 * (the entire sequence unless {@link #region(int, int)} was used).
 * <p>
 * Matching is started at the region start. Starting at index 0 regardless of
 * the region could never satisfy the "match begins at regionStart" check below
 * once a region with a non-zero start had been set.
 * </p>
 * @return true if the entire region is matched (false otherwise)
 * @see #find()
 */
public boolean matches() {
  matched = false;
  matchingCompleted = false;
  boolean status = findMatchStart(regionStart, true);
  if (status) {
    // Check if entire region is matched
    status = (matchedGroups[0].matchBegin == regionStart)
        && (matchedGroups[0].matchEnd == regionEnd);
  }
  matchingCompleted = true;
  matched = status;
  return status;
}
/**
 * Clears all matched groups and, if present, all matched node results.
 */
private void clearMatched() {
  Arrays.fill(matchedGroups, null);
  if (matchedResults == null) {
    return;
  }
  Arrays.fill(matchedResults, null);
}
/**
 * @return a short description of the current matching state, used as the
 *         message of the IllegalStateException thrown by the result accessors
 */
private String getStateMessage() {
  if (matchingCompleted) {
    return matched ? "Match successful" : "No match found";
  }
  return "Matching not completed";
}
/**
 * Sets the region to search in; the next search starts at the region start.
 *
 * @param start start index
 * @param end   end index (exclusive)
 * @throws IndexOutOfBoundsException if either index is outside 0..size of the sequence,
 *                                   or if start is larger than end
 */
public void region(int start, int end) {
  boolean startInBounds = start >= 0 && start <= elements.size();
  if (!startInBounds) {
    throw new IndexOutOfBoundsException("Invalid region start=" + start + ", need to be between 0 and " + elements.size());
  }
  boolean endInBounds = end >= 0 && end <= elements.size();
  if (!endInBounds) {
    throw new IndexOutOfBoundsException("Invalid region end=" + end + ", need to be between 0 and " + elements.size());
  }
  if (end < start) {
    throw new IndexOutOfBoundsException("Invalid region end=" + end + ", need to be larger then start=" + start);
  }
  this.regionStart = start;
  this.regionEnd = end;
  this.nextMatchStart = start;
}
/**
 * @return end index (exclusive) of the region being searched
 */
public int regionEnd() {
  return this.regionEnd;
}
/**
 * @return start index of the region being searched
 */
public int regionStart() {
  return this.regionStart;
}
/**
 * Returns a copy of the current match results. Use this method to save
 * match results for later use, since further operations on this
 * SequenceMatcher change the match results.
 *
 * @return copy of the current match results
 * @throws IllegalStateException if matching has not completed or no match was found
 */
public BasicSequenceMatchResult<T> toBasicSequenceMatchResult() {
  if (!matchingCompleted || !matched) {
    throw new IllegalStateException(getStateMessage());
  }
  return super.toBasicSequenceMatchResult();
}
/**
 * Delegates to the superclass once a successful match is available.
 *
 * @param group group id
 * @return result of {@code super.start(group)}
 * @throws IllegalStateException if matching has not completed or no match was found
 */
public int start(int group) {
  if (!matchingCompleted || !matched) {
    throw new IllegalStateException(getStateMessage());
  }
  return super.start(group);
}
/**
 * Delegates to the superclass once a successful match is available.
 *
 * @param group group id
 * @return result of {@code super.end(group)}
 * @throws IllegalStateException if matching has not completed or no match was found
 */
public int end(int group) {
  if (!matchingCompleted || !matched) {
    throw new IllegalStateException(getStateMessage());
  }
  return super.end(group);
}
/**
 * Delegates to the superclass once a successful match is available.
 *
 * @param group group id
 * @return result of {@code super.groupNodes(group)}
 * @throws IllegalStateException if matching has not completed or no match was found
 */
public List<T> groupNodes(int group) {
  if (!matchingCompleted || !matched) {
    throw new IllegalStateException(getStateMessage());
  }
  return super.groupNodes(group);
}
/**
 * Delegates to the superclass once a successful match is available.
 *
 * @param group group id
 * @return result of {@code super.groupValue(group)}
 * @throws IllegalStateException if matching has not completed or no match was found
 */
public Object groupValue(int group) {
  if (!matchingCompleted || !matched) {
    throw new IllegalStateException(getStateMessage());
  }
  return super.groupValue(group);
}
/**
 * Delegates to the superclass once a successful match is available.
 *
 * @param group group id
 * @return result of {@code super.groupInfo(group)}
 * @throws IllegalStateException if matching has not completed or no match was found
 */
public MatchedGroupInfo<T> groupInfo(int group) {
  if (!matchingCompleted || !matched) {
    throw new IllegalStateException(getStateMessage());
  }
  return super.groupInfo(group);
}
/**
 * Delegates to the superclass once a successful match is available.
 *
 * @param group group id
 * @return result of {@code super.groupMatchResults(group)}
 * @throws IllegalStateException if matching has not completed or no match was found
 */
public List<Object> groupMatchResults(int group) {
  if (!matchingCompleted || !matched) {
    throw new IllegalStateException(getStateMessage());
  }
  return super.groupMatchResults(group);
}
/**
 * Delegates to the superclass once a successful match is available.
 *
 * @param group group id
 * @param index index passed through to the superclass
 * @return result of {@code super.groupMatchResult(group, index)}
 * @throws IllegalStateException if matching has not completed or no match was found
 */
public Object groupMatchResult(int group, int index) {
  if (!matchingCompleted || !matched) {
    throw new IllegalStateException(getStateMessage());
  }
  return super.groupMatchResult(group, index);
}
/**
 * Delegates to the superclass once a successful match is available.
 *
 * @param index index passed through to the superclass
 * @return result of {@code super.nodeMatchResult(index)}
 * @throws IllegalStateException if matching has not completed or no match was found
 */
public Object nodeMatchResult(int index) {
  if (!matchingCompleted || !matched) {
    throw new IllegalStateException(getStateMessage());
  }
  return super.nodeMatchResult(index);
}
/**
 * Clears the matcher: forgets any match, clears the matched groups and the
 * FIND_ALL bookkeeping, and resets the region to the entire sequence.
 */
public void reset() {
  // Match status
  matched = false;
  matchingCompleted = false;
  clearMatched();

  // Region and next search position
  regionStart = 0;
  regionEnd = elements.size();
  nextMatchStart = 0;

  // Clearing for FIND_ALL
  curMatchStates = null;
  curMatchIter = null;
  prevMatchedSignatures.clear();
}
/**
 * Returns the element at the given index of the sequence being matched.
 *
 * @param i index
 * @return ith element
 */
public T get(int i) {
  T element = elements.get(i);
  return element;
}
/**
 * Creates the initial matched states for a new matching attempt,
 * built from this matcher and the pattern's root state.
 *
 * @return a new, non-null MatchedStates
 */
private MatchedStates<T> getStartStates() {
  MatchedStates<T> initial = new MatchedStates<>(this, pattern.root);
  return initial;
}
/**
 * Contains information about a branch of running the NFA matching
 */
private static class BranchState {
  // Branch id
  int bid;
  // Parent branch state
  BranchState parent;
  // Map of group id to matched group
  Map<Integer,MatchedGroup> matchedGroups;
  // Map of sequence index id to matched node result
  Map<Integer,Object> matchedResults;
  // Map of state to object storing information about the state for this branch of execution
  // Used for states corresponding to
  //    repeating patterns: key is RepeatState, object is Pair<Integer,Boolean>
  //                        pair indicates sequence index and whether the match was complete
  //    multinode patterns: key is MultiNodePatternState, object is Interval<Integer>
  //                        the interval indicates the start and end node indices for the multinode pattern
  //    conjunction patterns: key is ConjStartState, object is ConjMatchStateInfo
  Map<SequencePattern.State, Object> matchStateInfo;
  // Branch ids to collapse together with this branch.
  // Used for conjunction states, which requires multiple paths through the NFA to hold
  Set<Integer> bidsToCollapse;
  // Set of branch ids that has already been collapsed ...
  // assumes that after being collapsed no more collapsing required
  Set<Integer> collapsedBids;

  /**
   * Creates a branch state without a parent.
   * @param bid branch id
   */
  public BranchState(int bid) {
    this(bid, null);
  }

  /**
   * Creates a branch state that starts out with copies of the parent's maps and sets
   * (anything that is null in the parent stays null here).
   * @param bid branch id
   * @param parent parent branch state, may be null
   */
  public BranchState(int bid, BranchState parent) {
    this.bid = bid;
    this.parent = parent;
    if (parent != null) {
      if (parent.matchedGroups != null) {
        matchedGroups = new LinkedHashMap<>(parent.matchedGroups);
      }
      if (parent.matchedResults != null) {
        matchedResults = new LinkedHashMap<>(parent.matchedResults);
      }
      if (parent.matchStateInfo != null) {
        matchStateInfo = new LinkedHashMap<>(parent.matchStateInfo);
      }
      if (parent.bidsToCollapse != null) {
        bidsToCollapse = new ArraySet<>(parent.bidsToCollapse.size());
        bidsToCollapse.addAll(parent.bidsToCollapse);
      }
      if (parent.collapsedBids != null) {
        collapsedBids = new ArraySet<>(parent.collapsedBids.size());
        collapsedBids.addAll(parent.collapsedBids);
      }
    }
  }

  // Add to list of related branch ids that we would like to keep...
  private void updateKeepBids(BitSet bids) {
    if (matchStateInfo == null) {
      return;
    }
    // TODO: Make values of matchStateInfo more organized (implement some interface) so we don't
    // need this kind of specialized code
    for (Map.Entry<SequencePattern.State, Object> entry : matchStateInfo.entrySet()) {
      if (entry.getKey() instanceof SequencePattern.ConjStartState) {
        SequencePattern.ConjMatchStateInfo info = (SequencePattern.ConjMatchStateInfo) entry.getValue();
        info.updateKeepBids(bids);
      }
    }
  }

  // Records the given branch ids (except our own) as branches to collapse into this one
  private void addBidsToCollapse(int[] bids) {
    if (bidsToCollapse == null) {
      bidsToCollapse = new ArraySet<>(bids.length);
    }
    for (int b : bids) {
      if (b != bid) {
        bidsToCollapse.add(b);
      }
    }
  }

  // Copies over matched groups for group ids we do not have yet.
  // A null argument is ignored (as in addMatchedResults) and our own map is
  // created on demand, so merging into a branch without groups does not fail.
  private void addMatchedGroups(Map<Integer,MatchedGroup> g) {
    if (g == null) {
      return;
    }
    if (matchedGroups == null) {
      matchedGroups = new LinkedHashMap<>();
    }
    for (Map.Entry<Integer,MatchedGroup> entry : g.entrySet()) {
      // containsKey (not a null-value test) so an existing key is never overwritten
      if (!matchedGroups.containsKey(entry.getKey())) {
        matchedGroups.put(entry.getKey(), entry.getValue());
      }
    }
  }

  // Copies over matched node results for indices we do not have yet.
  // A null argument is ignored and our own map is created on demand.
  private void addMatchedResults(Map<Integer,Object> res) {
    if (res == null) {
      return;
    }
    if (matchedResults == null) {
      matchedResults = new LinkedHashMap<>();
    }
    for (Map.Entry<Integer,Object> entry : res.entrySet()) {
      // containsKey (not a null-value test) so an existing key is never overwritten
      if (!matchedResults.containsKey(entry.getKey())) {
        matchedResults.put(entry.getKey(), entry.getValue());
      }
    }
  }
}
/**
 * A pattern state paired with the id of the branch it is being followed on.
 * Two states are equal when both the branch id and the pattern state are equal.
 */
private static class State {
  // Branch id
  int bid;
  // State of the pattern NFA
  SequencePattern.State tstate;

  public State(int bid, SequencePattern.State tstate) {
    this.bid = bid;
    this.tstate = tstate;
  }

  @Override
  public boolean equals(Object o) {
    if (o == this) {
      return true;
    }
    if (o == null || o.getClass() != getClass()) {
      return false;
    }
    State other = (State) o;
    if (other.bid != bid) {
      return false;
    }
    if (tstate == null) {
      return other.tstate == null;
    }
    return tstate.equals(other.tstate);
  }

  @Override
  public int hashCode() {
    int tstateHash = (tstate == null) ? 0 : tstate.hashCode();
    return 31 * bid + tstateHash;
  }
}
/**
* Overall information about the branching of paths through the NFA
* (maintained for one attempt at matching, with multiple MatchedStates)
*/
static class BranchStates
{
// Index of global branch id to pair of parent branch id and branch index
// (the branch index is with respect to parent, from 1 to number of branches the parent has)
// TODO: This index can grow rather large, use index that allows for shrinkage
// (has remove function and generate new id every time)
HashIndex<Pair<Integer,Integer>> bidIndex = new HashIndex<>(512);
// Map of branch id to branch state
// (a branch id without an entry uses its nearest ancestor's state, see getBranchState(int))
Map<Integer,BranchState> branchStates = new HashMap<>();//Generics.newHashMap();
// The activeMatchedStates is only kept to determine what branch states are still needed
// It's okay if it overly conservative and has more states than needed,
// And while ideally a set, it's okay to have duplicates (esp if it is a bit faster for normal cases).
// Entries are added by link(), removed by unlink() and read by condense().
Collection<MatchedStates> activeMatchedStates = new ArrayList<>();//= Generics.newHashSet();
/**
 * Registers the given MatchedStates as active. The active list is what
 * determines which branch states still need to be kept.
 *
 * @param s matched states to register
 */
private void link(MatchedStates s) {
  this.activeMatchedStates.add(s);
}
/**
 * Unregisters the given MatchedStates. The active list may hold duplicates,
 * so removal is repeated until no instance of {@code s} is left.
 *
 * @param s matched states to unregister
 */
private void unlink(MatchedStates s) {
  boolean removed;
  do {
    removed = activeMatchedStates.remove(s);
  } while (removed);
}
/**
 * Looks up the branch id registered for a (parent branch id, branch index) pair.
 *
 * @param parent parent branch id
 * @param child  branch index with respect to the parent
 * @return result of the index lookup for that pair
 */
protected int getBid(int parent, int child) {
  Pair<Integer,Integer> key = new Pair<>(parent, child);
  return bidIndex.indexOf(key);
}
/**
 * Adds a (parent branch id, branch index) pair to the branch id index.
 *
 * @param parent parent branch id
 * @param child  branch index with respect to the parent
 * @return value returned by the index for the added pair
 */
protected int newBid(int parent, int child) {
  Pair<Integer,Integer> key = new Pair<>(parent, child);
  return bidIndex.addToIndexUnsafe(key);
}
/**
 * @return number of branch states currently stored
 */
protected int size() {
  return this.branchStates.size();
}
/**
* Removes branch states that are no longer needed.
* A branch state is kept if its branch id is used by a state of one of the
* active MatchedStates, if it is the state (possibly an ancestor's) that
* getBranchState resolves such a branch id to, or if updateKeepBids marks it.
* Branch states with pending bidsToCollapse are merged along the way.
*/
private void condense()
{
// Bit i is set when the branch state with branch id i has to be kept
BitSet keepBidStates = new BitSet();
// First pass: mark the branch id of every state of every active MatchedStates
for (MatchedStates ms:activeMatchedStates) {
// Trim out unneeded states info
List<State> states = ms.states;
if (logger.isLoggable(Level.FINEST)) {
logger.finest("Condense matched state: curPosition=" + ms.curPosition
+ ", totalTokens=" + ms.matcher.elements.size()
+ ", nStates=" + states.size());
}
for (State state: states) {
keepBidStates.set(state.bid);
}
}
// Second pass: also keep the branch state each of those ids resolves to
// (which may belong to an ancestor branch) plus anything it asks to keep,
// and merge branch states that have branches to collapse
for (MatchedStates ms : activeMatchedStates) {
for (State state : (List<State>) ms.states) {
int bid = state.bid;
BranchState bs = getBranchState(bid);
if (bs != null) {
keepBidStates.set(bs.bid);
bs.updateKeepBids(keepBidStates);
if (bs.bidsToCollapse != null) {
mergeBranchStates(bs);
}
}
}
}
// Drop every branch state that was not marked; removal goes through the
// iterator so the map can be modified while it is being traversed
Iterator<Integer> iter = branchStates.keySet().iterator();
while (iter.hasNext()) {
int bid = iter.next();
if (!keepBidStates.get(bid)) {
if (logger.isLoggable(Level.FINEST)) {
logger.finest("Remove state for bid=" + bid);
}
iter.remove();
}
}
// note[gabor]: the loop above replaced an earlier version that copied the key set
// and called branchStates.remove(bid) for every bid that is not kept
// TODO: We should be able to trim some bids from our bidIndex as well....
// (a warning for a bidIndex larger than 1000 entries used to be logged here)
}
/**
 * A safe version of {@link SequenceMatcher.BranchStates#getParents(int, Integer[])}
 * that does not depend on a fixed-size buffer.
 * Follows the parent links of {@code bid} until a link is missing or its parent
 * id is negative.
 *
 * @param bid branch id
 * @return ancestor branch ids, farthest ancestor first ({@code bid} itself is not included)
 */
private List<Integer> getParents(int bid) {
  List<Integer> ancestors = new ArrayList<>();
  for (Pair<Integer,Integer> link = bidIndex.get(bid);
       link != null && link.first() >= 0;
       link = bidIndex.get(link.first())) {
    ancestors.add(link.first());
  }
  // Collected nearest-first; callers get farthest-first
  Collections.reverse(ancestors);
  return ancestors;
}
/**
 * Given a branch id, returns the chain of branch ids leading to it, using the
 * supplied buffer as storage (the returned list is a view over the buffer).
 * The buffer is filled from its end backwards: first {@code bid} itself, then
 * its ancestors. If the buffer runs out of room, the unbuffered
 * {@code getParents(int)} is used instead.
 * <p>
 * NOTE(review): the buffered result ends with {@code bid} itself, while the
 * fallback result does not contain {@code bid} - confirm callers handle both.
 * </p>
 *
 * @param bid    branch id
 * @param buffer scratch array used to hold the result
 * @return list of branch ids, farthest ancestor first
 */
private List<Integer> getParents(int bid, Integer[] buffer) {
  int slot = buffer.length - 1;
  buffer[slot--] = bid;
  Pair<Integer,Integer> link = bidIndex.get(bid);
  while (link != null && link.first() >= 0) {
    buffer[slot--] = link.first();
    if (slot < 0) {
      return getParents(bid);  // optimization failed -- back off to the old version
    }
    link = bidIndex.get(link.first());
  }
  return Arrays.asList(buffer).subList(slot + 1, buffer.length);
}
/**
 * Returns the branch state for a given branch id. If there is no branch state
 * stored for that id, the parent links are followed and the state of the
 * nearest ancestor that has one is returned.
 *
 * @param bid branch id
 * @return BranchState associated with the given branch id (or an ancestor's),
 *         null if none is found
 */
protected BranchState getBranchState(int bid) {
  BranchState own = branchStates.get(bid);
  if (own != null) {
    return own;
  }
  BranchState inherited = null;
  int id = bid;
  while (inherited == null && id >= 0) {
    // Step up to the parent branch
    id = bidIndex.get(id).first;
    inherited = branchStates.get(id);
  }
  return inherited;
}
/**
 * Returns the branch state for a given branch id (the appropriate ancestor
 * branch state is returned if there is none for the given branch id).
 * If {@code add} is true, the returned branch state is guaranteed to belong to
 * the given branch id: a new one is created when needed (copying from the
 * ancestor's state if there is one) and stored.
 *
 * @param bid branch id
 * @param add whether a new branch state should be added
 * @return BranchState associated with the given branch id
 */
protected BranchState getBranchState(int bid, boolean add) {
  BranchState bs = getBranchState(bid);
  if (!add) {
    return bs;
  }
  if (bs == null) {
    bs = new BranchState(bid);
  } else if (bs.bid != bid) {
    // Found an ancestor's state: branch off our own copy
    bs = new BranchState(bid, bs);
  }
  branchStates.put(bid, bs);
  return bs;
}
/**
 * Returns the map of group id to matched group for a branch.
 *
 * @param bid branch id
 * @param add if true, the branch state and its map are created when missing
 * @return the map, or null if there is no branch state (or no map and {@code add} is false)
 */
protected Map<Integer,MatchedGroup> getMatchedGroups(int bid, boolean add) {
  BranchState state = getBranchState(bid, add);
  if (state == null) {
    return null;
  }
  if (state.matchedGroups == null && add) {
    state.matchedGroups = new LinkedHashMap<>();
  }
  return state.matchedGroups;
}
/**
 * Returns the matched group recorded for a group id on a branch.
 *
 * @param bid     branch id
 * @param groupId group id
 * @return the matched group, or null if nothing is recorded
 */
protected MatchedGroup getMatchedGroup(int bid, int groupId) {
  Map<Integer,MatchedGroup> groups = getMatchedGroups(bid, false);
  if (groups == null) {
    return null;
  }
  return groups.get(groupId);
}
/**
 * Records that a capture group starts at the given position on a branch.
 * Any group already recorded for the id is replaced by a new one with an
 * open end (-1). Does nothing for a negative capture group id.
 *
 * @param bid            branch id
 * @param captureGroupId capture group id
 * @param curPosition    position at which the group begins
 */
protected void setGroupStart(int bid, int captureGroupId, int curPosition) {
  if (captureGroupId >= 0) {
    Map<Integer,MatchedGroup> matchedGroups = getMatchedGroups(bid, true);
    // Only look up the old group and build the message when it will actually be
    // logged (same guard style as condense()); this runs for every group start.
    if (logger.isLoggable(Level.FINE)) {
      MatchedGroup mg = matchedGroups.get(captureGroupId);
      if (mg != null) {
        // This is possible if we have patterns like "( ... )+" in which case multiple nodes can match as the subgroup
        // We will match the first occurrence and use that as the subgroup (Java uses the last match as the subgroup)
        logger.fine("Setting matchBegin=" + curPosition + ": Capture group " + captureGroupId + " already exists: " + mg);
      }
    }
    matchedGroups.put(captureGroupId, new MatchedGroup(curPosition, -1, null));
  }
}
/**
 * Records that a capture group ends after the given position on a branch.
 * The end is only set on a group that has been started and is still open;
 * otherwise a warning is logged (unless the group already ended at the same place).
 * Does nothing for a negative capture group id.
 *
 * @param bid            branch id
 * @param captureGroupId capture group id
 * @param curPosition    position of the last element of the group
 * @param value          value to store with the matched group
 */
protected void setGroupEnd(int bid, int captureGroupId, int curPosition, Object value) {
  if (captureGroupId < 0) {
    return;
  }
  Map<Integer,MatchedGroup> groups = getMatchedGroups(bid, true);
  MatchedGroup mg = groups.get(captureGroupId);
  int end = curPosition + 1;
  if (mg == null) {
    logger.warning("Cannot set matchEnd=" + end + ": Capture group " + captureGroupId + " is null");
    return;
  }
  if (mg.matchEnd == -1) {
    // Group is still open: close it
    groups.put(captureGroupId, new MatchedGroup(mg.matchBegin, end, value));
  } else if (mg.matchEnd != end) {
    logger.warning("Cannot set matchEnd=" + end + ": Capture group " + captureGroupId + " already ended: " + mg);
  }
}
/**
 * Removes the matched group recorded for a capture group id on a branch, if any.
 * Does nothing for a negative capture group id.
 *
 * @param bid            branch id
 * @param captureGroupId capture group id
 */
protected void clearGroupStart(int bid, int captureGroupId) {
  if (captureGroupId < 0) {
    return;
  }
  Map<Integer,MatchedGroup> groups = getMatchedGroups(bid, false);
  if (groups != null) {
    groups.remove(captureGroupId);
  }
}
/**
 * Returns the map of sequence index to matched node result for a branch.
 *
 * @param bid branch id
 * @param add if true, the branch state and its map are created when missing
 * @return the map, or null if there is no branch state (or no map and {@code add} is false)
 */
protected Map<Integer,Object> getMatchedResults(int bid, boolean add) {
  BranchState state = getBranchState(bid, add);
  if (state == null) {
    return null;
  }
  if (state.matchedResults == null && add) {
    state.matchedResults = new LinkedHashMap<>();
  }
  return state.matchedResults;
}
/**
 * Returns the matched node result recorded for a sequence index on a branch.
 *
 * @param bid   branch id
 * @param index sequence index
 * @return the recorded result, or null if nothing is recorded
 */
protected Object getMatchedResult(int bid, int index) {
  Map<Integer,Object> results = getMatchedResults(bid, false);
  if (results == null) {
    return null;
  }
  return results.get(index);
}
/**
 * Records the matched node result for a sequence index on a branch.
 * An existing non-null result is overwritten after logging a warning.
 * Does nothing for a negative index.
 *
 * @param bid   branch id
 * @param index sequence index
 * @param obj   result to record
 */
protected void setMatchedResult(int bid, int index, Object obj) {
  if (index < 0) {
    return;
  }
  Map<Integer,Object> results = getMatchedResults(bid, true);
  Object previous = results.get(index);
  if (previous != null) {
    logger.warning("Setting matchedResult=" + obj + ": index " + index + " already exists: " + previous);
  }
  results.put(index, obj);
}
/**
 * Returns the branch id to use for one of the branches leaving a branch.
 * With a single outgoing branch the branch id stays the same; otherwise the id
 * registered for (bid, nextBranchIndex) is returned, after registering ids
 * for all {@code nextTotal} branches if that pair was not indexed yet.
 *
 * @param bid             current branch id
 * @param nextBranchIndex index of the outgoing branch, from 1 to {@code nextTotal}
 * @param nextTotal       total number of outgoing branches
 * @return branch id for the outgoing branch
 * @throws IllegalArgumentException if {@code nextBranchIndex} is not in 1..nextTotal
 */
protected int getBranchId(int bid, int nextBranchIndex, int nextTotal) {
  boolean validIndex = nextBranchIndex > 0 && nextBranchIndex <= nextTotal;
  if (!validIndex) {
    throw new IllegalArgumentException("Invalid nextBranchIndex=" + nextBranchIndex + ", nextTotal=" + nextTotal);
  }
  if (nextTotal == 1) {
    return bid;
  }
  Pair<Integer,Integer> key = new Pair<>(bid, nextBranchIndex);
  int id = bidIndex.indexOf(key);
  if (id >= 0) {
    return id;
  }
  // Not indexed yet: register all sibling branches, then look ours up again
  for (int child = 1; child <= nextTotal; child++) {
    bidIndex.add(new Pair<>(bid, child));
  }
  return bidIndex.indexOf(key);
}
/**
 * Returns the map of pattern state to per-branch state information for a branch.
 *
 * @param bid branch id
 * @param add if true, the branch state and its map are created when missing
 * @return the map, or null if there is no branch state (or no map and {@code add} is false)
 */
protected Map<SequencePattern.State,Object> getMatchStateInfo(int bid, boolean add) {
  BranchState state = getBranchState(bid, add);
  if (state == null) {
    return null;
  }
  if (state.matchStateInfo == null && add) {
    state.matchStateInfo = new LinkedHashMap<>();
  }
  return state.matchStateInfo;
}
/**
 * Looks up the information stored for a pattern state in a branch,
 * without creating any map.
 *
 * @param bid branch id
 * @param node pattern state used as the key
 * @return the stored object, or {@code null} if there is no info map or
 *         no entry for {@code node}
 */
protected Object getMatchStateInfo(int bid, SequencePattern.State node)
{
  Map<SequencePattern.State,Object> info = getMatchStateInfo(bid, false);
  if (info == null) {
    return null;
  }
  return info.get(node);
}
/**
 * Removes the information stored for a pattern state in a branch.
 * Does nothing when no non-null value is currently visible for the state.
 *
 * @param bid branch id
 * @param node pattern state whose entry is removed
 */
protected void removeMatchStateInfo(int bid, SequencePattern.State node)
{
  if (getMatchStateInfo(bid, node) == null) {
    return;
  }
  // The map is fetched with add=true before the removal
  getMatchStateInfo(bid, true).remove(node);
}
/**
 * Stores an object for a pattern state in a branch, allocating the info
 * map if necessary and replacing any previous entry.
 *
 * @param bid branch id
 * @param node pattern state used as the key
 * @param obj value to associate with the state
 */
protected void setMatchStateInfo(int bid, SequencePattern.State node, Object obj)
{
  getMatchStateInfo(bid, true).put(node, obj);
}
/**
 * Increments the counter kept for a pattern state in a branch by one;
 * a counter that does not exist yet is created with the value 1.
 *
 * @param bid branch id
 * @param node pattern state whose counter is updated
 */
protected void startMatchedCountInc(int bid, SequencePattern.State node) {
  final int firstValue = 1;
  final int step = 1;
  startMatchedCountInc(bid, node, firstValue, step);
}
/**
 * Decrements the counter kept for a pattern state in a branch by one;
 * a counter that does not exist yet is created with the value 0.
 *
 * @param bid branch id
 * @param node pattern state whose counter is updated
 */
protected void startMatchedCountDec(int bid, SequencePattern.State node) {
  final int firstValue = 0;
  final int step = -1;
  startMatchedCountInc(bid, node, firstValue, step);
}
/**
 * Updates the counter kept for a pattern state in a branch. The counter
 * is stored as a (count, flag) pair; this method always stores the flag
 * as {@code false}.
 *
 * @param bid branch id
 * @param node pattern state whose counter is updated
 * @param initialValue count to store when the state has no counter yet
 * @param delta amount added to the existing count otherwise
 */
protected void startMatchedCountInc(int bid, SequencePattern.State node, int initialValue, int delta)
{
  Map<SequencePattern.State,Object> counters = getMatchStateInfo(bid, true);
  Pair<Integer,Boolean> current = (Pair<Integer,Boolean>) counters.get(node);
  int updated = (current == null) ? initialValue : current.first() + delta;
  counters.put(node, new Pair<>(updated, false));
}
/**
 * Reads the counter kept for a pattern state in a branch and re-stores
 * it with its flag set to {@code true}.
 *
 * @param bid branch id
 * @param node pattern state whose counter is read
 * @return the current count, or 0 if the branch has no info map or the
 *         state has no counter
 */
protected int endMatchedCountInc(int bid, SequencePattern.State node)
{
  // Probe without creating anything first
  if (getMatchStateInfo(bid, false) == null) {
    return 0;
  }
  // The map that gets written is fetched with add=true
  Map<SequencePattern.State,Object> counters = getMatchStateInfo(bid, true);
  Pair<Integer,Boolean> current = (Pair<Integer,Boolean>) counters.get(node);
  if (current == null) {
    return 0;
  }
  int count = current.first();
  counters.put(node, new Pair<>(count, true));
  return count;
}
/**
 * Clears the counter kept for a pattern state in a branch by removing
 * the state's entry from the branch's state info.
 *
 * @param bid branch id
 * @param node pattern state whose counter is cleared
 */
protected void clearMatchedCount(int bid, SequencePattern.State node)
{
// Counters are stored in the state info map, so removing the entry clears the counter
removeMatchStateInfo(bid, node);
}
/**
 * Records an interval for a pattern state in a branch. If an entry is
 * already present for the state it is left untouched and a warning is
 * logged.
 *
 * @param bid branch id
 * @param node pattern state used as the key
 * @param interval interval to record
 */
protected void setMatchedInterval(int bid, SequencePattern.State node, HasInterval<Integer> interval)
{
  Map<SequencePattern.State,Object> info = getMatchStateInfo(bid, true);
  HasInterval<Integer> existing = (HasInterval<Integer>) info.get(node);
  if (existing != null) {
    logger.warning("Interval already exists for bid=" + bid);
    return;
  }
  info.put(node, interval);
}
/**
 * Returns the interval recorded for a pattern state in a branch.
 * Note that the info map is requested with add=true, so it is allocated
 * if it does not exist yet.
 *
 * @param bid branch id
 * @param node pattern state used as the key
 * @return the recorded interval, or {@code null} if none is stored
 */
protected HasInterval<Integer> getMatchedInterval(int bid, SequencePattern.State node)
{
  return (HasInterval<Integer>) getMatchStateInfo(bid, true).get(node);
}
/**
 * Registers branch ids to be collapsed into the given branch, creating
 * the branch state if needed.
 *
 * @param bid branch id that the other branches are collapsed into
 * @param bids branch ids to collapse
 */
protected void addBidsToCollapse(int bid, int[] bids)
{
  getBranchState(bid, true).addBidsToCollapse(bids);
}
/**
 * Folds the pending "to collapse" branches into a branch state: the
 * matched groups and matched results of every pending branch (other
 * than the state's own bid) are added to {@code bs}, the pending ids
 * are appended to {@code bs.collapsedBids}, and the pending set is
 * cleared. Does nothing if there are no pending ids.
 *
 * @param bs branch state to merge into
 */
private void mergeBranchStates(BranchState bs)
{
  if (bs.bidsToCollapse == null || bs.bidsToCollapse.size() == 0) {
    return;
  }
  for (int pendingBid : bs.bidsToCollapse) {
    if (pendingBid == bs.bid) {
      continue;
    }
    // Copy over the matched group info
    BranchState pending = getBranchState(pendingBid);
    if (pending == null) {
      logger.finest("Unable to find state info for bid=" + pendingBid);
      continue;
    }
    bs.addMatchedGroups(pending.matchedGroups);
    bs.addMatchedResults(pending.matchedResults);
  }
  // Remember what has been collapsed and reset the pending set
  if (bs.collapsedBids != null) {
    bs.collapsedBids.addAll(bs.bidsToCollapse);
  } else {
    bs.collapsedBids = bs.bidsToCollapse;
  }
  bs.bidsToCollapse = null;
}
}
/**
 * Builds a string that identifies the current match by the offsets of
 * its capture groups: the concatenation of {@code (begin,end)} for every
 * slot of {@code matchedGroups}, in order.
 *
 * <p>A slot holding no group is rendered as {@code ()} instead of being
 * dereferenced. NOTE(review): slots are presumably null for capture
 * groups that did not take part in the match (only matched groups are
 * written into the array) - confirm against {@code clearMatched()}.
 *
 * @return the signature, or {@code null} if {@code matchedGroups} is null
 */
private String getMatchedSignature() {
  if (matchedGroups == null) return null;
  StringBuilder sb = new StringBuilder();
  for (MatchedGroup g : matchedGroups) {
    if (g == null) {
      // Empty marker: cannot collide with a real "(begin,end)" entry
      sb.append("()");
    } else {
      sb.append('(').append(g.matchBegin).append(',').append(g.matchEnd).append(')');
    }
  }
  return sb.toString();
}
/**
* Utility class that helps us perform pattern matching against a sequence
* Keeps information about:
* <ul>
* <li>the states we need to visit</li>
* <li>the current position in the sequence we are at</li>
* <li>state for each branch we took</li>
* </ul>
* @param <T> Type of node that the matcher is operating on
*/
static class MatchedStates<T>
{
// Sequence matcher with pattern that we are matching against and sequence
final SequenceMatcher<T> matcher;
// Branch states
BranchStates branchStates;
// set of old states along with their branch ids (used to avoid reallocating mem)
List<State> oldStates;
// new states to be explored (along with their branch ids)
List<State> states;
// Current position to match
int curPosition = -1;
// Favor matching longest
boolean matchLongest;
protected MatchedStates(SequenceMatcher<T> matcher, SequencePattern.State state)
{
this(matcher, new BranchStates());
int bid = branchStates.newBid(-1, 0);
states.add(new State(bid,state));
}
private MatchedStates(SequenceMatcher<T> matcher, BranchStates branchStates) {
this.matcher = matcher;
states = new ArrayList<>();
oldStates = new ArrayList<>();
this.branchStates = branchStates;
branchStates.link(this);
}
protected BranchStates getBranchStates()
{
return branchStates;
}
/**
* Split part of the set of states to explore into another MatchedStates
* @param branchLimit - rough limit on the number of branches we want
* to keep in each MatchedStates
* @return new MatchedStates with part of the states still to be explored
*/
protected MatchedStates split(int branchLimit)
{
Set<Integer> curBidSet = new HashSet<>();//Generics.newHashSet();
for (State state:states) {
curBidSet.add(state.bid);
}
List<Integer> bids = new ArrayList<>(curBidSet);
Collections.sort(bids, (o1, o2) -> {
int res = compareMatches(o1, o2);
return res;
});
MatchedStates<T> newStates = new MatchedStates<>(matcher, branchStates);
int v = Math.min(branchLimit, (bids.size()+1)/2);
Set<Integer> keepBidSet = new HashSet<>();//Generics.newHashSet();
keepBidSet.addAll(bids.subList(0, v));
swapAndClear();
for (State s:oldStates) {
if (keepBidSet.contains(s.bid)) {
states.add(s);
} else {
newStates.states.add(s);
}
}
newStates.curPosition = curPosition;
branchStates.condense();
return newStates;
}
protected List<? extends T> elements()
{
return matcher.elements;
}
protected T get()
{
return matcher.get(curPosition);
}
protected int size()
{
return states.size();
}
protected int branchSize()
{
return branchStates.size();
}
private void swap()
{
List<State> tmpStates = oldStates;
oldStates = states;
states = tmpStates;
}
private void swapAndClear()
{
swap();
states.clear();
}
// Attempts to match element at the specified position
private boolean match(int position)
{
curPosition = position;
boolean matched = false;
swapAndClear();
// Start with old state, and try to match next element
// New states to search after successful match will be updated during the match process
for (State state:oldStates) {
if (state.tstate.match(state.bid, this)) {
matched = true;
}
}
// Run NFA to process non consuming states
boolean done = false;
while (!done) {
swapAndClear();
boolean matched0 = false;
for (State state:oldStates) {
if (state.tstate.match0(state.bid, this)) {
matched0 = true;
}
}
done = !matched0;
}
branchStates.condense();
return matched;
}
private final Integer[] p1Buffer = new Integer[128];
private final Integer[] p2Buffer = new Integer[128];
protected int compareMatches(int bid1, int bid2)
{
if (bid1 == bid2) return 0;
List<Integer> p1 = branchStates.getParents(bid1, p1Buffer);
// p1.add(bid1);
List<Integer> p2 = branchStates.getParents(bid2, p2Buffer);
// p2.add(bid2);
int n = Math.min(p1.size(), p2.size());
for (int i = 0; i < n; i++) {
if (p1.get(i) < p2.get(i)) return -1;
if (p1.get(i) > p2.get(i)) return 1;
}
if (p1.size() < p2.size()) return -1;
if (p1.size() > p2.size()) return 1;
return 0;
}
/**
* Returns index of state that results in match (-1 if no matches)
*/
private int getMatchIndex()
{
for (int i = 0; i < states.size(); i++) {
State state = states.get(i);
if (state.tstate.equals(SequencePattern.MATCH_STATE)) {
return i;
}
}
return -1;
}
/**
* Returns a set of indices that results in a match
*/
private Collection<Integer> getMatchIndices()
{
HashSet<Integer> allMatchIndices = new LinkedHashSet<>();// Generics.newHashSet();
for (int i = 0; i < states.size(); i++) {
State state = states.get(i);
if (state.tstate.equals(SequencePattern.MATCH_STATE)) {
allMatchIndices.add(i);
}
}
return allMatchIndices;
}
/**
* Of the potential match indices, selects one and returns it
* (returns -1 if no matches)
*/
private int selectMatchIndex()
{
int best = -1;
int bestbid = -1;
MatchedGroup bestMatched = null;
int bestMatchedLength = -1;
for (int i = 0; i < states.size(); i++) {
State state = states.get(i);
if (state.tstate.equals(SequencePattern.MATCH_STATE)) {
if (best < 0) {
best = i;
bestbid = state.bid;
bestMatched = branchStates.getMatchedGroup(bestbid, 0);
bestMatchedLength = (bestMatched != null)? bestMatched.matchLength() : -1;
} else {
// Compare if this match is better?
int bid = state.bid;
MatchedGroup mg = branchStates.getMatchedGroup(bid, 0);
int matchLength = (mg != null)? mg.matchLength() : -1;
// Select the branch that matched the most
// TODO: Do we need to roll the matchedLength to bestMatchedLength check into the compareMatches?
boolean better;
if (matchLongest) {
better = (matchLength > bestMatchedLength || (matchLength == bestMatchedLength && compareMatches(bestbid, bid) > 0));
} else {
better = compareMatches(bestbid, bid) > 0;
}
if (better) {
bestbid = bid;
best = i;
bestMatched = branchStates.getMatchedGroup(bestbid, 0);
bestMatchedLength = (bestMatched != null)? bestMatched.matchLength() : -1;
}
}
}
}
return best;
}
private void completeMatch()
{
int matchStateIndex = selectMatchIndex();
setMatchedGroups(matchStateIndex);
}
/**
* Set the indices of the matched groups
* @param matchStateIndex
*/
private void setMatchedGroups(int matchStateIndex)
{
matcher.clearMatched();
if (matchStateIndex >= 0) {
State state = states.get(matchStateIndex);
int bid = state.bid;
BranchState bs = branchStates.getBranchState(bid);
if (bs != null) {
branchStates.mergeBranchStates(bs);
Map<Integer,MatchedGroup> matchedGroups = bs.matchedGroups;
if (matchedGroups != null) {
for (int group:matchedGroups.keySet()) {
matcher.matchedGroups[group] = matchedGroups.get(group);
}
}
Map<Integer,Object> matchedResults = bs.matchedResults;
if (matchedResults != null) {
if (matcher.matchedResults == null) {
matcher.matchedResults = new Object[matcher.elements().size()];
}
for (int index:matchedResults.keySet()) {
matcher.matchedResults[index] = matchedResults.get(index);
}
}
}
}
}
private boolean isAllMatch()
{
boolean allMatch = true;
if (states.size() > 0) {
for (State state:states) {
if (!state.tstate.equals(SequencePattern.MATCH_STATE)) {
allMatch = false;
break;
}
}
} else {
allMatch = false;
}
return allMatch;
}
private boolean isMatch()
{
int matchStateIndex = getMatchIndex();
return (matchStateIndex >= 0);
}
protected void addStates(int bid, Collection<SequencePattern.State> newStates)
{
int i = 0;
for (SequencePattern.State s:newStates) {
i++;
int id = branchStates.getBranchId(bid, i, newStates.size());
states.add(new State(id, s));
}
}
protected void addState(int bid, SequencePattern.State state)
{
this.states.add(new State(bid, state));
}
private void clean()
{
branchStates.unlink(this);
branchStates = null;
}
protected void setGroupStart(int bid, int captureGroupId)
{
branchStates.setGroupStart(bid, captureGroupId, curPosition);
}
protected void setGroupEnd(int bid, int captureGroupId, Object value)
{
branchStates.setGroupEnd(bid, captureGroupId, curPosition, value);
}
protected void setGroupEnd(int bid, int captureGroupId, int position, Object value)
{
branchStates.setGroupEnd(bid, captureGroupId, position, value);
}
protected void clearGroupStart(int bid, int captureGroupId)
{
branchStates.clearGroupStart(bid, captureGroupId);
}
}
}