package org.rhegium.servlet.internal;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.StringTokenizer;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.rhegium.servlet.api.DispatchedAction;
import org.rhegium.servlet.api.RequestDispatcher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link RequestDispatcher} that keeps a registry of path patterns and hands a
 * request to the {@link DispatchedAction} whose pattern matches the request
 * URI with the highest weight.
 * <p>
 * Patterns are split on {@code /}. A segment is either a literal (compared
 * with {@code equals}) or {@code **}, which stands for zero or more segments.
 * Every literal segment that matched contributes {@link #WEIGHT_DIRECT}, every
 * wildcard contributes {@link #WEIGHT_WILDCARD}, so the most specific matching
 * pattern wins. Ties are resolved by the iteration order of the backing map.
 * <p>
 * The registry is guarded by a {@link ReentrantReadWriteLock}: dispatching
 * takes the read lock, registering and removing take the write lock. The
 * selected action is invoked while the read lock is still held, therefore an
 * action must not register or remove actions from inside
 * {@code handleRequest} (a read lock cannot be upgraded to a write lock).
 */
public class DefaultRequestDispatcher implements RequestDispatcher {

    private static final Logger LOG = LoggerFactory.getLogger(DefaultRequestDispatcher.class);

    private static final String PATH_SEPARATOR = "/";
    private static final String MULTI_SEGMENT_WILDCARD = "**";
    private static final String SINGLE_SEGMENT_WILDCARD = "*";

    /** Weight returned by {@link #weightPattern(String, String)} when the pattern does not match. */
    private static final int NO_MATCH = -1;
    private static final int WEIGHT_DIRECT = 2;
    private static final int WEIGHT_WILDCARD = 1;

    private final Map<String, DispatchedAction> dispatchedActions = new HashMap<>();
    private final ReadWriteLock lock = new ReentrantReadWriteLock();

    /**
     * Weights every registered pattern against {@code requestUri} and lets the
     * best matching action handle the request.
     *
     * @param requestUri the path to resolve an action for
     * @param request the current request, passed on to the action
     * @param response the current response, passed on to the action
     * @throws ServletException if no registered pattern matches {@code requestUri}
     */
    @Override
    public void dispatch(String requestUri, HttpServletRequest request, HttpServletResponse response)
            throws ServletException {

        Lock l = lock.readLock();
        // Acquire before entering try so that finally only unlocks a lock we really hold.
        l.lock();
        try {
            DispatchedAction action = null;
            // Starting at NO_MATCH means a non-matching pattern can never be selected,
            // while a legitimate weight of 0 (pattern "/" against URI "/") still can.
            int actionWeight = NO_MATCH;

            for (Entry<String, DispatchedAction> entry : dispatchedActions.entrySet()) {
                int weight = weightPattern(entry.getKey(), requestUri);
                LOG.debug("DispatchedAction with pattern {} finished in a weight of {}", entry.getKey(), weight);

                if (weight > actionWeight) {
                    actionWeight = weight;
                    action = entry.getValue();
                }
            }

            if (action == null) {
                throw new ServletException("No DispatchedAction was found for path " + requestUri);
            }

            action.handleRequest(request, response);
        }
        finally {
            l.unlock();
        }
    }

    /**
     * Registers {@code dispatchedAction} under {@code pattern}, replacing any
     * action previously registered under the same pattern.
     *
     * @param pattern the path pattern to register
     * @param dispatchedAction the action to invoke for matching requests
     */
    @Override
    public void registerDispatchedAction(String pattern, DispatchedAction dispatchedAction) {
        Lock l = lock.writeLock();
        l.lock();
        try {
            dispatchedActions.put(pattern, dispatchedAction);
        }
        finally {
            l.unlock();
        }
    }

    /**
     * Removes the action registered under {@code pattern}, if any.
     *
     * @param pattern the exact pattern string used at registration
     */
    @Override
    public void removeDispatchedAction(String pattern) {
        Lock l = lock.writeLock();
        l.lock();
        try {
            dispatchedActions.remove(pattern);
        }
        finally {
            l.unlock();
        }
    }

    /**
     * Removes the first registration found whose action is the very same
     * instance as {@code dispatchedAction} (identity comparison). Further
     * registrations of the same instance under other patterns are kept.
     *
     * @param dispatchedAction the action instance to unregister
     */
    @Override
    public void removeDispatchedAction(DispatchedAction dispatchedAction) {
        Lock l = lock.writeLock();
        l.lock();
        try {
            Iterator<Entry<String, DispatchedAction>> iterator = dispatchedActions.entrySet().iterator();
            while (iterator.hasNext()) {
                if (iterator.next().getValue() == dispatchedAction) {
                    iterator.remove();
                    break;
                }
            }
        }
        finally {
            l.unlock();
        }
    }

    /**
     * Matches {@code requestUri} against {@code pattern} and rates the match.
     * <p>
     * This is based on Spring's org.springframework.util.AntPathMatcher but
     * changed to weight the result of the matching to find the best matching
     * pattern. Only whole segments are compared: literals must be equal,
     * {@code **} matches any number of segments, and a lone trailing
     * {@code *} matches the empty segment of a URI that ends with a separator.
     *
     * @param pattern the registered pattern
     * @param requestUri the URI to test
     * @return {@link #NO_MATCH} if the pattern does not match, otherwise the
     *         sum of {@link #WEIGHT_DIRECT} per matched literal segment and
     *         {@link #WEIGHT_WILDCARD} per wildcard segment
     */
    private int weightPattern(String pattern, String requestUri) {
        if (requestUri.startsWith(PATH_SEPARATOR) != pattern.startsWith(PATH_SEPARATOR)) {
            return NO_MATCH;
        }

        String[] patternTokens = tokenize(pattern, PATH_SEPARATOR);
        String[] uriTokens = tokenize(requestUri, PATH_SEPARATOR);

        int patternIndexStart = 0;
        int patternIndexEnd = patternTokens.length - 1;
        int uriIndexStart = 0;
        int uriIndexEnd = uriTokens.length - 1;

        int weight = 0;

        // Match all leading literal segments up to the first **
        while (patternIndexStart <= patternIndexEnd && uriIndexStart <= uriIndexEnd) {
            String token = patternTokens[patternIndexStart];
            if (MULTI_SEGMENT_WILDCARD.equals(token)) {
                // The wildcard itself is weighted later, exactly once.
                break;
            }
            if (!token.equals(uriTokens[uriIndexStart])) {
                return NO_MATCH;
            }
            patternIndexStart++;
            uriIndexStart++;
            weight += WEIGHT_DIRECT;
        }

        if (uriIndexStart > uriIndexEnd) {
            // URI is exhausted.
            if (patternIndexStart > patternIndexEnd) {
                // Pattern is exhausted too: both must agree on the trailing separator.
                boolean sameTrailingSeparator =
                        pattern.endsWith(PATH_SEPARATOR) == requestUri.endsWith(PATH_SEPARATOR);
                return sameTrailingSeparator ? weight : NO_MATCH;
            }
            if (patternIndexStart == patternIndexEnd
                    && SINGLE_SEGMENT_WILDCARD.equals(patternTokens[patternIndexStart])
                    && requestUri.endsWith(PATH_SEPARATOR)) {
                return weight + WEIGHT_WILDCARD;
            }
            // Whatever is left of the pattern may only consist of **
            return weightRemainingWildcards(patternTokens, patternIndexStart, patternIndexEnd, weight);
        }
        if (patternIndexStart > patternIndexEnd) {
            // URI has segments left but the pattern has none.
            return NO_MATCH;
        }

        // From here on patternTokens[patternIndexStart] is a **.
        // Match all trailing literal segments back to the last **
        while (patternIndexStart <= patternIndexEnd && uriIndexStart <= uriIndexEnd) {
            String token = patternTokens[patternIndexEnd];
            if (MULTI_SEGMENT_WILDCARD.equals(token)) {
                break;
            }
            if (!token.equals(uriTokens[uriIndexEnd])) {
                return NO_MATCH;
            }
            patternIndexEnd--;
            uriIndexEnd--;
            weight += WEIGHT_DIRECT;
        }
        if (uriIndexStart > uriIndexEnd) {
            return weightRemainingWildcards(patternTokens, patternIndexStart, patternIndexEnd, weight);
        }

        // Both ends of the remaining pattern are ** now; locate each literal run
        // between two consecutive ** inside the remaining URI segments.
        while (patternIndexStart != patternIndexEnd && uriIndexStart <= uriIndexEnd) {
            int nextWildcardIndex = -1;
            for (int i = patternIndexStart + 1; i <= patternIndexEnd; i++) {
                if (MULTI_SEGMENT_WILDCARD.equals(patternTokens[i])) {
                    nextWildcardIndex = i;
                    break;
                }
            }

            if (nextWildcardIndex == patternIndexStart + 1) {
                // '**/**' situation: weight the one we leave behind and skip it.
                weight += WEIGHT_WILDCARD;
                patternIndexStart++;
                continue;
            }

            int literalCount = nextWildcardIndex - patternIndexStart - 1;
            int remainingUriTokens = uriIndexEnd - uriIndexStart + 1;
            int foundIndex = -1;

            for (int offset = 0; offset <= remainingUriTokens - literalCount; offset++) {
                if (matchPattern(patternTokens, uriTokens, patternIndexStart, uriIndexStart, literalCount, offset)) {
                    foundIndex = uriIndexStart + offset;
                    break;
                }
            }

            if (foundIndex == -1) {
                return NO_MATCH;
            }

            // The ** we are leaving plus every literal of the run that was located.
            weight += WEIGHT_WILDCARD + literalCount * WEIGHT_DIRECT;
            patternIndexStart = nextWildcardIndex;
            uriIndexStart = foundIndex + literalCount;
        }

        return weightRemainingWildcards(patternTokens, patternIndexStart, patternIndexEnd, weight);
    }

    /**
     * Adds {@link #WEIGHT_WILDCARD} to {@code weight} for every pattern token
     * in {@code [from, to]}, or reports {@link #NO_MATCH} if any of them is
     * not {@code **}.
     */
    private int weightRemainingWildcards(String[] patternTokens, int from, int to, int weight) {
        int result = weight;
        for (int i = from; i <= to; i++) {
            if (!MULTI_SEGMENT_WILDCARD.equals(patternTokens[i])) {
                return NO_MATCH;
            }
            result += WEIGHT_WILDCARD;
        }
        return result;
    }

    /**
     * Checks whether the {@code uriLength} pattern tokens following
     * {@code patternIndexStart} equal the URI tokens starting at
     * {@code uriIndexStart + index}.
     *
     * @return {@code true} only if every compared pair of tokens is equal
     */
    private boolean matchPattern(String[] patternTokens, String[] uriTokens, int patternIndexStart,
            int uriIndexStart, int uriLength, int index) {

        for (int i = 0; i < uriLength; i++) {
            String subPattern = patternTokens[patternIndexStart + i + 1];
            String subUri = uriTokens[uriIndexStart + index + i];
            if (!subPattern.equals(subUri)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Splits {@code path} on the characters of {@code delimiter}, trimming
     * each token and dropping empty ones.
     *
     * @return the tokens in order; an empty array if {@code path} is {@code null}
     */
    private String[] tokenize(String path, String delimiter) {
        if (path == null) {
            return new String[0];
        }

        StringTokenizer tokenizer = new StringTokenizer(path, delimiter);
        List<String> tokens = new ArrayList<>();
        while (tokenizer.hasMoreTokens()) {
            String token = tokenizer.nextToken().trim();
            if (token.length() > 0) {
                tokens.add(token);
            }
        }
        return tokens.toArray(new String[0]);
    }
}