/* The contents of this file are subject to the license and copyright terms * detailed in the license directory at the root of the source tree (also * available online at http://fedora-commons.org/license/). */ package fedora.server.security.servletfilters; import java.security.Principal; import java.util.Enumeration; import java.util.HashSet; import java.util.Hashtable; import java.util.Iterator; import java.util.Map; import java.util.Set; import javax.servlet.ServletContext; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequestWrapper; import org.apache.commons.codec.binary.Base64; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import fedora.server.errors.authorization.AuthzOperationalException; /** * @author Bill Niebel */ public class ExtendedHttpServletRequestWrapper extends HttpServletRequestWrapper implements ExtendedHttpServletRequest { private static Log log = LogFactory.getLog(ExtendedHttpServletRequestWrapper.class); private String username = null; private String password = null; private String authority; private Principal userPrincipal; private boolean wrapperWriteLocked = false; public final void lockWrapper() throws Exception { lockSponsoredUser(); wrapperWriteLocked = true; } private String sponsoredUser = null; // null == not yet set; "" == no sponsored user; other values valid private void setSponsoredUser(String sponsoredUser) throws Exception { if (this.sponsoredUser == null) { this.sponsoredUser = sponsoredUser; } } public void setSponsoredUser() throws Exception { String method = "setSponsoredUser"; String sponsoredUser = ""; log.debug(method + " , isSponsoredUserRequested()==" + isSponsoredUserRequested()); if (isSponsoredUserRequested()) { sponsoredUser = getFromHeader(); log.debug(method + " , sponsoredUser==" + sponsoredUser); } setSponsoredUser(sponsoredUser); } public void lockSponsoredUser() throws Exception { setSponsoredUser(""); } public void setAuthenticated(Principal userPrincipal, String authority) throws Exception { if (wrapperWriteLocked) { throw new Exception(); } if (isAuthenticated()) { throw new Exception(); } this.userPrincipal = userPrincipal; this.authority = authority; } public Principal getUserPrincipal() { //this order reinforces that container-supplied userPrincipal should not be overridden in setUserPrincipal() Principal userPrincipal = super.getUserPrincipal(); if (userPrincipal == null) { userPrincipal = this.userPrincipal; } return userPrincipal; } public final boolean isUserSponsored() { return !(sponsoredUser == null || "".equals(sponsoredUser)); } protected boolean isSponsoredUserRequested() { String sponsoredUser = getFromHeader(); boolean isSponsoredUserRequested = !(sponsoredUser == null || "".equals(sponsoredUser)); return isSponsoredUserRequested; } public final boolean isAuthenticated() { return getUserPrincipal() != null; } public String getRemoteUser() { String remoteUser = null; if (isUserSponsored()) { remoteUser = sponsoredUser; } else { remoteUser = super.getRemoteUser(); if (remoteUser == null && userPrincipal != null) { remoteUser = userPrincipal.getName(); } } return remoteUser; } private final Map authenticatedAttributes = new Hashtable(); private final Map sponsoredAttributes = new Hashtable(); public final void auditInnerMap(Map map) { if (log.isDebugEnabled()) { for (Iterator it = map.keySet().iterator(); it.hasNext();) { String key = (String) it.next(); Object value = map.get(key); StringBuffer sb = new StringBuffer(key + "=="); String comma = ""; if (value instanceof String) { sb.append(value); } else if (value instanceof String[]) { sb.append("["); for (int i = 0; i < ((String[]) value).length; i++) { Object o = ((String[]) value)[i]; if (o instanceof String) { sb.append(comma + o); comma = ","; } else { sb.append(comma + "UNKNOWN"); comma = ","; } } sb.append("]"); } else if (value instanceof Set) { sb.append("{"); for (Iterator it2 = ((Set) value).iterator(); it2.hasNext();) { Object o = it2.next(); if (o instanceof String) { sb.append(comma + o); comma = ","; } else { sb.append(comma + "UNKNOWN"); comma = ","; } } sb.append("}"); } else { sb.append("UNKNOWN"); } log.debug(sb.toString()); } } } public final void auditInnerSet(Set set) { if (log.isDebugEnabled()) { for (Iterator it = set.iterator(); it.hasNext();) { Object value = it.next(); if (value instanceof String) { log.debug(value); } else { log.debug("UNKNOWN"); } } } } public final void auditOuterMap(Map map, String desc) { if (log.isDebugEnabled()) { log.debug(""); log.debug("auditing " + desc); for (Iterator it = map.keySet().iterator(); it.hasNext();) { Object key = it.next(); Object inner = map.get(key); String authority = ""; if (key instanceof String) { authority = (String) key; } else { authority = "<authority not a string>"; } if (inner instanceof Map) { log.debug(authority + " maps to . . ."); auditInnerMap((Map) inner); } else if (inner instanceof Set) { log.debug(authority + " maps to . . ."); auditInnerSet((Set) inner); } else { log.debug(authority + " maps to an unknown object==" + map.getClass().getName()); } } } } public void audit() { if (log.isDebugEnabled()) { log.debug("\n===AUDIT==="); log.debug("auditing wrapped request"); auditOuterMap(authenticatedAttributes, "authenticatedAttributes"); auditOuterMap(sponsoredAttributes, "sponsoredAttributes"); log.debug("===AUDIT===\n"); } } public boolean getAttributeDefined(String key) throws AuthzOperationalException { boolean defined = false; Map map = null; if (isUserSponsored()) { map = sponsoredAttributes; } else { map = authenticatedAttributes; } for (Iterator iterator = map.values().iterator(); iterator.hasNext();) { Map attributesFromOneAuthority = (Map) iterator.next(); if (attributesFromOneAuthority.containsKey(key)) { defined = true; break; } } return defined; } public Set getAttributeValues(String key) throws AuthzOperationalException { Set accumulatedValues4Key = null; Map map = null; if (isUserSponsored()) { map = sponsoredAttributes; } else { map = authenticatedAttributes; } for (Iterator iterator = map.values().iterator(); iterator.hasNext();) { Map attributesFromOneAuthority = (Map) iterator.next(); if (attributesFromOneAuthority.containsKey(key)) { Set someValues4Key = (Set) attributesFromOneAuthority.get(key); if (someValues4Key != null && !someValues4Key.isEmpty()) { if (accumulatedValues4Key == null) { accumulatedValues4Key = new HashSet(); } accumulatedValues4Key.addAll(someValues4Key); } } } if (accumulatedValues4Key == null) { accumulatedValues4Key = IMMUTABLE_NULL_SET; } return accumulatedValues4Key; } public boolean hasAttributeValues(String key) throws AuthzOperationalException { Set temp = getAttributeValues(key); return !temp.isEmpty(); } public boolean isAttributeDefined(String key) throws AuthzOperationalException { boolean isAttributeDefined; isAttributeDefined = getAttributeDefined(key); return isAttributeDefined; } private void putIntoMap(Map map, String key, Object value) throws Exception { if (wrapperWriteLocked) { throw new Exception(); } if (!isAuthenticated()) { throw new Exception("can't collect user roles/attributes/groups until after authentication"); } if (map == null || key == null || value == null) { throw new Exception("null parm, map==" + map + ", key==" + key + ", value==" + value); } if (map.containsKey(key)) { throw new Exception("map already contains key==" + key); } log.debug("mapping " + key + " => " + value + " in " + map); map.put(key, value); } private void putMapIntoMap(Map map, String key, Object value) throws Exception { if (!(value instanceof Map)) { throw new Exception("input parm must be a map"); } putIntoMap(map, key, value); } public void addAttributes(String authority, Map attributes) throws Exception { if (isUserSponsored()) { // after user is sponsored, only sponsored-user roles/attributes/groups are collected putMapIntoMap(sponsoredAttributes, authority, attributes); } else { // before user is sponsored, only authenticated-user roles/attributes/groups are collected putMapIntoMap(authenticatedAttributes, authority, attributes); } } private Map getAllAttributes(Map attributeGroup) { Map all = new Hashtable(); for (Iterator it = attributeGroup.values().iterator(); it.hasNext();) { Map m = (Map) it.next(); all.putAll(m); } return all; } public Map getAllAttributes() throws AuthzOperationalException { Map all = null; if (isUserSponsored()) { all = getAllAttributes(sponsoredAttributes); } else { all = getAllAttributes(authenticatedAttributes); } return all; } public static final String BASIC = "Basic"; private final String[] parseUsernamePassword(String header) throws Exception { String here = "parseUsernamePassword():"; String[] usernamePassword = null; String msg = here + "header intact"; if (header == null || "".equals(header)) { String exceptionMsg = msg + FAILED; log.fatal(exceptionMsg + ", header==" + header); throw new Exception(exceptionMsg); } log.debug(msg + SUCCEEDED); String authschemeUsernamepassword[] = header.split("\\s+"); msg = here + "header split"; if (authschemeUsernamepassword.length != 2) { String exceptionMsg = msg + FAILED; log.fatal(exceptionMsg + ", header==" + header); throw new Exception(exceptionMsg); } log.debug(msg + SUCCEEDED); msg = here + "auth scheme"; String authscheme = authschemeUsernamepassword[0]; if (authscheme == null && !BASIC.equalsIgnoreCase(authscheme)) { String exceptionMsg = msg + FAILED; log.fatal(exceptionMsg + ", authscheme==" + authscheme); throw new Exception(exceptionMsg); } log.debug(msg + SUCCEEDED); msg = here + "digest non-null"; String usernamepassword = authschemeUsernamepassword[1]; if (usernamepassword == null || "".equals(usernamepassword)) { String exceptionMsg = msg + FAILED; log.fatal(exceptionMsg + ", usernamepassword==" + usernamepassword); throw new Exception(exceptionMsg); } log.debug(msg + SUCCEEDED + ", usernamepassword==" + usernamepassword); byte[] encoded = usernamepassword.getBytes(); msg = here + "digest base64-encoded"; if (!Base64.isArrayByteBase64(encoded)) { String exceptionMsg = msg + FAILED; log.fatal(exceptionMsg + ", encoded==" + encoded); throw new Exception(exceptionMsg); } if (log.isDebugEnabled()) { log.debug(msg + SUCCEEDED + ", encoded==" + encoded); } byte[] decodedAsByteArray = Base64.decodeBase64(encoded); log.debug(here + "got decoded bytes" + SUCCEEDED + ", decodedAsByteArray==" + decodedAsByteArray); String decoded = new String(decodedAsByteArray); //decodedAsByteArray.toString(); log.debug(here + "got decoded string" + SUCCEEDED + ", decoded==" + decoded); msg = here + "digest decoded"; if (decoded == null || "".equals(decoded)) { String exceptionMsg = msg + FAILED; log.fatal(exceptionMsg + ", digest decoded==" + decoded); throw new Exception(exceptionMsg); } log.debug(msg + SUCCEEDED); String DELIMITER = ":"; if (decoded == null) { log .error("decoded user/password is null . . . returning 0-length strings"); usernamePassword = new String[2]; usernamePassword[0] = ""; usernamePassword[1] = ""; } else if (decoded.indexOf(DELIMITER) < 0) { String exceptionMsg = "decoded user/password lacks delimiter"; log.fatal(exceptionMsg + " . . . throwing exception"); throw new Exception(exceptionMsg); } else if (decoded.startsWith(DELIMITER)) { log .error("decoded user/password is lacks user . . . returning 0-length strings"); usernamePassword = new String[2]; usernamePassword[0] = ""; usernamePassword[1] = ""; } else if (decoded.endsWith(DELIMITER)) { // no password, e.g., user == "guest" usernamePassword = new String[2]; usernamePassword[0] = decoded.substring(0, decoded.length() - 1); usernamePassword[1] = ""; } else { // usual, expected case usernamePassword = decoded.split(DELIMITER); } msg = here + "user/password split"; if (usernamePassword.length != 2) { String exceptionMsg = msg + FAILED; log.fatal(exceptionMsg + ", digest decoded==" + decoded); throw new Exception(exceptionMsg); } log.debug(msg + SUCCEEDED); return usernamePassword; } public static final String AUTHORIZATION = "Authorization"; public final String getAuthorizationHeader() { log.debug("getAuthorizationHeader()"); log.debug("getting this headers"); for (Enumeration enu = getHeaderNames(); enu.hasMoreElements();) { String name = (String) enu.nextElement(); log.debug("another headername==" + name); String value = getHeader(name); log.debug("another headervalue==" + value); } log.debug("getting super headers"); for (Enumeration enu = super.getHeaderNames(); enu.hasMoreElements();) { String name = (String) enu.nextElement(); log.debug("another headername==" + name); String value = super.getHeader(name); log.debug("another headervalue==" + value); } return getHeader(AUTHORIZATION); } public static final String FROM = "From"; public final String getFromHeader() { return getHeader(FROM); } public final String getUser() throws Exception { if (username == null) { log.debug("username==null, so will grok now"); String authorizationHeader = getAuthorizationHeader(); log.debug("authorizationHeader==" + authorizationHeader); if (authorizationHeader != null && !"".equals(authorizationHeader)) { log.debug("authorizationHeader is intact"); String[] usernamePassword = parseUsernamePassword(authorizationHeader); log.debug("usernamePassword[] length==" + usernamePassword.length); username = usernamePassword[0]; log.debug("username (usernamePassword[0])==" + username); if (super.getRemoteUser() == null) { log.debug("had none before"); } else if (super.getRemoteUser() == username || super.getRemoteUser().equals(username)) { log.debug("got same now"); } else { throw new Exception("somebody got it wrong"); } } } log.debug("return user==" + username); return username; } public final String getPassword() throws Exception { if (password == null) { String authorizationHeader = getAuthorizationHeader(); if (authorizationHeader != null && !"".equals(authorizationHeader)) { String[] usernamePassword = parseUsernamePassword(authorizationHeader); password = usernamePassword[1]; } } log.debug("return password==" + password); return password; } public final String getAuthority() { return authority; } public ExtendedHttpServletRequestWrapper(HttpServletRequest wrappedRequest) throws Exception { super(wrappedRequest); } /** * @deprecated As of Version 2.1 of the Java Servlet API, use * {@link ServletContext#getRealPath(java.lang.String)}. */ @Deprecated public String getRealPath(String path) { return super.getRealPath(path); } /** * @deprecated As of Version 2.1 of the Java Servlet API, use * {@link #isRequestedSessionIdFromURL()}. */ @Deprecated public boolean isRequestedSessionIdFromUrl() { return isRequestedSessionIdFromURL(); } public boolean isSecure() { log.debug("super.isSecure()==" + super.isSecure()); log.debug("this.getLocalPort()==" + getLocalPort()); log.debug("this.getProtocol()==" + getProtocol()); log.debug("this.getServerPort()==" + getServerPort()); log.debug("this.getRequestURI()==" + getRequestURI()); return super.isSecure(); } }