/* The contents of this file are subject to the license and copyright terms
* detailed in the license directory at the root of the source tree (also
* available online at http://fedora-commons.org/license/).
*/
package org.fcrepo.server.security.servletfilters;
import java.security.Principal;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import javax.servlet.ServletContext;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import org.apache.commons.codec.binary.Base64;
import org.fcrepo.server.errors.authorization.AuthzOperationalException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* @author Bill Niebel
*/
public class ExtendedHttpServletRequestWrapper
extends HttpServletRequestWrapper
implements ExtendedHttpServletRequest {
private static final Logger logger =
LoggerFactory.getLogger(ExtendedHttpServletRequestWrapper.class);
private String username = null;
private String password = null;
private String authority;
private Principal userPrincipal;
private boolean wrapperWriteLocked = false;
public final void lockWrapper() throws Exception {
lockSponsoredUser();
wrapperWriteLocked = true;
}
private String sponsoredUser = null; // null == not yet set; "" == no sponsored user; other values valid
private void setSponsoredUser(String sponsoredUser) throws Exception {
if (this.sponsoredUser == null) {
this.sponsoredUser = sponsoredUser;
}
}
public void setSponsoredUser() throws Exception {
String method = "setSponsoredUser";
String sponsoredUser = "";
logger.debug(method + " , isSponsoredUserRequested()=="
+ isSponsoredUserRequested());
if (isSponsoredUserRequested()) {
sponsoredUser = getFromHeader();
logger.debug(method + " , sponsoredUser==" + sponsoredUser);
}
setSponsoredUser(sponsoredUser);
}
public void lockSponsoredUser() throws Exception {
setSponsoredUser("");
}
public void setAuthenticated(Principal userPrincipal, String authority)
throws Exception {
if (wrapperWriteLocked) {
throw new Exception();
}
if (isAuthenticated()) {
throw new Exception();
}
this.userPrincipal = userPrincipal;
this.authority = authority;
}
@Override
public Principal getUserPrincipal() {
//this order reinforces that container-supplied userPrincipal should not be overridden in setUserPrincipal()
Principal userPrincipal = super.getUserPrincipal();
if (userPrincipal == null) {
userPrincipal = this.userPrincipal;
}
return userPrincipal;
}
public final boolean isUserSponsored() {
return !(sponsoredUser == null || sponsoredUser.isEmpty());
}
protected boolean isSponsoredUserRequested() {
String sponsoredUser = getFromHeader();
boolean isSponsoredUserRequested =
!(sponsoredUser == null || sponsoredUser.isEmpty());
return isSponsoredUserRequested;
}
public final boolean isAuthenticated() {
return getUserPrincipal() != null;
}
@Override
public String getRemoteUser() {
String remoteUser = null;
if (isUserSponsored()) {
remoteUser = sponsoredUser;
} else {
remoteUser = super.getRemoteUser();
if (remoteUser == null && userPrincipal != null) {
remoteUser = userPrincipal.getName();
}
}
return remoteUser;
}
private final Map<String, Map<String, Set<?>>> authenticatedAttributes =
new Hashtable<String, Map<String,Set<?>>>();
private final Map<String, Map<String, Set<?>>> sponsoredAttributes =
new Hashtable<String, Map<String,Set<?>>>();
public final void auditInnerMap(Map<String,?> map) {
if (logger.isDebugEnabled()) {
for (Iterator<String> it = map.keySet().iterator(); it.hasNext();) {
String key = it.next();
Object value = map.get(key);
StringBuffer sb = new StringBuffer(key + "==");
String comma = "";
if (value instanceof String) {
sb.append(value);
} else if (value instanceof String[]) {
sb.append("[");
for (int i = 0; i < ((String[]) value).length; i++) {
Object o = ((String[]) value)[i];
if (o instanceof String) {
sb.append(comma + o);
comma = ",";
} else {
sb.append(comma + "UNKNOWN");
comma = ",";
}
}
sb.append("]");
} else if (value instanceof Set) {
sb.append("{");
for (Iterator<?> it2 = ((Set<?>) value).iterator(); it2.hasNext();) {
Object o = it2.next();
if (o instanceof String) {
sb.append(comma + o);
comma = ",";
} else {
sb.append(comma + "UNKNOWN");
comma = ",";
}
}
sb.append("}");
} else {
sb.append("UNKNOWN");
}
logger.debug(sb.toString());
}
}
}
public final void auditInnerSet(Set<?> set) {
if (logger.isDebugEnabled()) {
for (Iterator<?> it = set.iterator(); it.hasNext();) {
Object value = it.next();
if (value instanceof String) {
logger.debug((String) value);
} else {
logger.debug("UNKNOWN");
}
}
}
}
public final void auditOuterMap(Map<String, Map<String, Set<?>>> authenticatedAttributes2,
String desc) {
if (logger.isDebugEnabled()) {
logger.debug("");
logger.debug("auditing " + desc);
for (Iterator<String> it = authenticatedAttributes2.keySet().iterator(); it.hasNext();) {
String authority = it.next();
Map<String,Set<?>> inner = authenticatedAttributes2.get(authority);
logger.debug("{} maps to . . .", authority);
auditInnerMap(inner);
}
}
}
public void audit() {
if (logger.isDebugEnabled()) {
logger.debug("\n===AUDIT===");
logger.debug("auditing wrapped request");
auditOuterMap(authenticatedAttributes, "authenticatedAttributes");
auditOuterMap(sponsoredAttributes, "sponsoredAttributes");
logger.debug("===AUDIT===\n");
}
}
public boolean getAttributeDefined(String key)
throws AuthzOperationalException {
boolean defined = false;
Map<String, Map<String,Set<?>>> map = null;
if (isUserSponsored()) {
map = sponsoredAttributes;
} else {
map = authenticatedAttributes;
}
for (Iterator<Map<String, Set<?>>> iterator = map.values().iterator(); iterator.hasNext();) {
Map<?,?> attributesFromOneAuthority = (Map<?,?>) iterator.next();
if (attributesFromOneAuthority.containsKey(key)) {
defined = true;
break;
}
}
return defined;
}
public Set<?> getAttributeValues(String key) throws AuthzOperationalException {
Set<Object> accumulatedValues4Key = null;
Map<String,Map<String,Set<?>>> map = null;
if (isUserSponsored()) {
map = sponsoredAttributes;
} else {
map = authenticatedAttributes;
}
for (Iterator<Map<String, Set<?>>> iterator = map.values().iterator(); iterator.hasNext();) {
Map<String, Set<?>> attributesFromOneAuthority = iterator.next();
if (attributesFromOneAuthority.containsKey(key)) {
Set<?> someValues4Key = (Set<?>) attributesFromOneAuthority.get(key);
if (someValues4Key != null && !someValues4Key.isEmpty()) {
if (accumulatedValues4Key == null) {
accumulatedValues4Key = new HashSet<Object>();
}
accumulatedValues4Key.addAll(someValues4Key);
}
}
}
if (accumulatedValues4Key == null) {
accumulatedValues4Key = Collections.emptySet();
}
return accumulatedValues4Key;
}
public boolean hasAttributeValues(String key)
throws AuthzOperationalException {
Set<?> temp = getAttributeValues(key);
return !temp.isEmpty();
}
public boolean isAttributeDefined(String key)
throws AuthzOperationalException {
boolean isAttributeDefined;
isAttributeDefined = getAttributeDefined(key);
return isAttributeDefined;
}
private void putMapIntoMap(Map<String, Map<String, Set<?>>> sponsoredAttributes2,
String key, Map<String, Set<?>> attributes) throws Exception {
if (wrapperWriteLocked) {
throw new Exception();
}
if (!isAuthenticated()) {
throw new Exception("can't collect user roles/attributes/groups until after authentication");
}
if (sponsoredAttributes2 == null || key == null || attributes == null) {
throw new Exception("null parm, map==" + sponsoredAttributes2 + ", key==" + key
+ ", value==" + attributes);
}
if (sponsoredAttributes2.containsKey(key)) {
throw new Exception("map already contains key==" + key);
}
logger.debug("mapping {} => {} in {}", key, attributes, sponsoredAttributes2);
sponsoredAttributes2.put(key, attributes);
}
@Override
public void addAttributes(String authority, Map<String, Set<?>> attributes)
throws Exception {
if (isUserSponsored()) {
// after user is sponsored, only sponsored-user roles/attributes/groups are collected
putMapIntoMap(sponsoredAttributes, authority, attributes);
} else {
// before user is sponsored, only authenticated-user roles/attributes/groups are collected
putMapIntoMap(authenticatedAttributes, authority, attributes);
}
}
private Map<String, Set<?>> getAllAttributes(Map<String, Map<String, Set<?>>> attributeGroup) {
Map<String, Set<?>> all = new Hashtable<String, Set<?>>();
for (Iterator<Map<String, Set<?>>> it = attributeGroup.values().iterator();
it.hasNext();) {
Map<String, Set<?>> m = it.next();
all.putAll(m);
}
return all;
}
public Map<String, Set<?>> getAllAttributes() throws AuthzOperationalException {
if (isUserSponsored()) {
return getAllAttributes(sponsoredAttributes);
} else {
return getAllAttributes(authenticatedAttributes);
}
}
public static final String BASIC = "Basic";
private final String[] parseUsernamePassword(String header)
throws Exception {
String here = "parseUsernamePassword():";
String[] usernamePassword = null;
String msg = here + "header intact";
if (header == null || header.isEmpty()) {
String exceptionMsg = msg + FAILED;
logger.error(exceptionMsg + ", header==" + header);
throw new Exception(exceptionMsg);
}
logger.debug("{}{}", msg, SUCCEEDED);
String authschemeUsernamepassword[] = header.split("\\s+");
msg = here + "header split";
if (authschemeUsernamepassword.length != 2) {
String exceptionMsg = msg + FAILED;
logger.error(exceptionMsg + ", header==" + header);
throw new Exception(exceptionMsg);
}
logger.debug("{}{}", msg, SUCCEEDED);
msg = here + "auth scheme";
String authscheme = authschemeUsernamepassword[0];
if (authscheme == null && !BASIC.equalsIgnoreCase(authscheme)) {
String exceptionMsg = msg + FAILED;
logger.error(exceptionMsg + ", authscheme==" + authscheme);
throw new Exception(exceptionMsg);
}
logger.debug("{}{}", msg, SUCCEEDED);
msg = here + "digest non-null";
String usernamepassword = authschemeUsernamepassword[1];
if (usernamepassword == null || usernamepassword.isEmpty()) {
String exceptionMsg = msg + FAILED;
logger.error(exceptionMsg + ", usernamepassword==" + usernamepassword);
throw new Exception(exceptionMsg);
}
logger.debug("{}{}, usernamepassword=={}", msg, SUCCEEDED,
usernamepassword);
byte[] encoded = usernamepassword.getBytes();
if (!Base64.isBase64(encoded)) {
String exceptionMsg = here + "digest base64-encoded" + FAILED;
logger.error(exceptionMsg + ", encoded==" + encoded);
throw new Exception(exceptionMsg);
}
if (logger.isDebugEnabled()) {
logger.debug("{}digest base64-encoded{}, encoded=={}",
here, SUCCEEDED,encoded);
}
byte[] decodedAsByteArray = Base64.decodeBase64(encoded);
logger.debug("{}got decoded bytes{}, decodedAsByteArray=={}",
here, SUCCEEDED, decodedAsByteArray);
String decoded = new String(decodedAsByteArray); //decodedAsByteArray.toString();
logger.debug("{}got decoded string{}, decoded=={}",
here, SUCCEEDED, decoded);
if (decoded == null || decoded.isEmpty()) {
String exceptionMsg = msg + FAILED;
logger.error(exceptionMsg + ", digest decoded==" + decoded);
throw new Exception(exceptionMsg);
}
logger.debug("{}digest decoded{}", here, SUCCEEDED);
char DELIMITER = ':';
if (decoded.indexOf(DELIMITER) < 0) {
String exceptionMsg = "decoded user/password lacks delimiter";
logger.error(exceptionMsg + " . . . throwing exception");
throw new Exception(exceptionMsg);
} else if (decoded.charAt(0) == DELIMITER) {
logger.error("decoded user/password is lacks user . . . returning 0-length strings");
usernamePassword = new String[2];
usernamePassword[0] = "";
usernamePassword[1] = "";
} else if (decoded.charAt(decoded.length()-1) == DELIMITER) { // no password, e.g., user == "guest"
usernamePassword = new String[2];
usernamePassword[0] = decoded.substring(0, decoded.length() - 1);
usernamePassword[1] = "";
} else { // usual, expected case
usernamePassword = new String[2];
int ix = decoded.indexOf(DELIMITER);
usernamePassword[0] = decoded.substring(0, ix);
usernamePassword[1] = decoded.substring(ix + 1);
}
if (usernamePassword.length != 2) {
String exceptionMsg = here + "user/password split" + FAILED;
logger.error(exceptionMsg + ", digest decoded==" + decoded);
throw new Exception(exceptionMsg);
}
logger.debug("{}user/password split{}", here, SUCCEEDED);
return usernamePassword;
}
public static final String AUTHORIZATION = "Authorization";
public final String getAuthorizationHeader() {
logger.debug("getAuthorizationHeader()");
logger.debug("getting this headers");
for (Enumeration<?> enu = getHeaderNames(); enu.hasMoreElements();) {
String name = (String) enu.nextElement();
logger.debug("another headername==" + name);
String value = getHeader(name);
logger.debug("another headervalue==" + value);
}
logger.debug("getting super headers");
for (Enumeration<?> enu = super.getHeaderNames(); enu.hasMoreElements();) {
String name = (String) enu.nextElement();
logger.debug("another headername==" + name);
String value = super.getHeader(name);
logger.debug("another headervalue==" + value);
}
return getHeader(AUTHORIZATION);
}
public static final String FROM = "From";
public final String getFromHeader() {
return getHeader(FROM);
}
public final String getUser() throws Exception {
if (username == null) {
logger.debug("username==null, so will grok now");
String authorizationHeader = getAuthorizationHeader();
logger.debug("authorizationHeader=={}", authorizationHeader);
if (authorizationHeader != null && !authorizationHeader.isEmpty()) {
logger.debug("authorizationHeader is intact");
String[] usernamePassword =
parseUsernamePassword(authorizationHeader);
logger.debug("usernamePassword[] length=="
+ usernamePassword.length);
username = usernamePassword[0];
logger.debug("username (usernamePassword[0])=={}", username);
if (super.getRemoteUser() == null) {
logger.debug("had none before");
} else if (super.getRemoteUser() == username
|| super.getRemoteUser().equals(username)) {
logger.debug("got same now");
} else {
throw new Exception("somebody got it wrong");
}
}
}
logger.debug("return user=={}", username);
return username;
}
public final String getPassword() throws Exception {
if (password == null) {
String authorizationHeader = getAuthorizationHeader();
if (authorizationHeader != null && !authorizationHeader.isEmpty()) {
String[] usernamePassword =
parseUsernamePassword(authorizationHeader);
password = usernamePassword[1];
}
}
logger.debug("return password=={}", password);
return password;
}
public final String getAuthority() {
return authority;
}
public ExtendedHttpServletRequestWrapper(HttpServletRequest wrappedRequest)
throws Exception {
super(wrappedRequest);
}
/**
* @deprecated As of Version 2.1 of the Java Servlet API, use
* {@link ServletContext#getRealPath(java.lang.String)}.
*/
@Override
@Deprecated
public String getRealPath(String path) {
return super.getRealPath(path);
}
/**
* @deprecated As of Version 2.1 of the Java Servlet API, use
* {@link #isRequestedSessionIdFromURL()}.
*/
@Override
@Deprecated
public boolean isRequestedSessionIdFromUrl() {
return isRequestedSessionIdFromURL();
}
@Override
public boolean isSecure() {
if (logger.isDebugEnabled()){
logger.debug("super.isSecure()=={}", super.isSecure());
logger.debug("this.getLocalPort()=={}", getLocalPort());
logger.debug("this.getProtocol()=={}", getProtocol());
logger.debug("this.getServerPort()=={}", getServerPort());
logger.debug("this.getRequestURI()=={}", getRequestURI());
}
return super.isSecure();
}
}