/**
* Copyright 2010 Newcastle University
*
* http://research.ncl.ac.uk/smart/
*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.amber.oauth2.rsfilter;
import java.io.IOException;
import java.security.Principal;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import org.apache.amber.oauth2.common.OAuth;
import org.apache.amber.oauth2.common.error.OAuthError;
import org.apache.amber.oauth2.common.exception.OAuthProblemException;
import org.apache.amber.oauth2.common.exception.OAuthSystemException;
import org.apache.amber.oauth2.common.message.OAuthResponse;
import org.apache.amber.oauth2.common.message.types.ParameterStyle;
import org.apache.amber.oauth2.rs.request.OAuthAccessResourceRequest;
import org.apache.amber.oauth2.rs.response.OAuthRSResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* @author Maciej Machulak (m.p.machulak@ncl.ac.uk)
* @author Lukasz Moren (lukasz.moren@ncl.ac.uk)
* @author Aad van Moorsel (aad.vanmoorsel@ncl.ac.uk)
*/
public class OAuthFilter implements Filter {
private Logger log = LoggerFactory.getLogger(getClass());
public static final String OAUTH_RS_PROVIDER_CLASS = "oauth.rs.provider-class";
public static final String RS_REALM = "oauth.rs.realm";
public static final String RS_REALM_DEFAULT = "OAuth Protected Service";
public static final String RS_TOKENS = "oauth.rs.tokens";
public static final ParameterStyle RS_TOKENS_DEFAULT = ParameterStyle.HEADER;
private static final String TOKEN_DELIMITER = ",";
private String realm;
private OAuthRSProvider provider;
private ParameterStyle[] parameterStyles;
@Override
public void init(FilterConfig filterConfig) throws ServletException {
provider = OAuthUtils.initiateServletContext(
filterConfig.getServletContext(), OAUTH_RS_PROVIDER_CLASS,
OAuthRSProvider.class);
realm = filterConfig.getServletContext().getInitParameter(RS_REALM);
if (OAuthUtils.isEmpty(realm)) {
realm = RS_REALM_DEFAULT;
}
String parameterStylesString = filterConfig.getServletContext()
.getInitParameter(RS_TOKENS);
if (OAuthUtils.isEmpty(parameterStylesString)) {
parameterStyles = new ParameterStyle[] { RS_TOKENS_DEFAULT };
} else {
String[] parameters = parameterStylesString.split(TOKEN_DELIMITER);
if (parameters != null && parameters.length > 0) {
parameterStyles = new ParameterStyle[parameters.length];
for (int i = 0; i < parameters.length; i++) {
ParameterStyle tempParameterStyle = ParameterStyle
.valueOf(parameters[i]);
if (tempParameterStyle != null) {
parameterStyles[i] = tempParameterStyle;
} else {
throw new ServletException("Incorrect ParameterStyle: "
+ parameters[i]);
}
}
}
}
}
@Override
public void doFilter(ServletRequest request, ServletResponse response,
FilterChain chain) throws IOException, ServletException {
HttpServletRequest req = (HttpServletRequest) request;
HttpServletResponse res = (HttpServletResponse) response;
try {
log.debug("Filtering {}", req.getRequestURI());
// Make an OAuth Request out of this servlet request
OAuthAccessResourceRequest oauthRequest = new OAuthAccessResourceRequest(
req, parameterStyles);
// Get the access token
String accessToken = oauthRequest.getAccessToken();
log.debug("Filtering token: {}", accessToken);
final OAuthDecision decision = provider.validateRequest(realm,
accessToken, req);
if (!decision.isAuthorized()) {
OAuthResponse oauthResponse = OAuthRSResponse
.errorResponse(HttpServletResponse.SC_UNAUTHORIZED)
.setRealm(realm).setError("Invalid token")
.setErrorDescription("Please authorize")
.buildHeaderMessage();
res.addHeader(OAuth.HeaderType.WWW_AUTHENTICATE, oauthResponse
.getHeader(OAuth.HeaderType.WWW_AUTHENTICATE));
res.sendError(oauthResponse.getResponseStatus());
return;
}
final Principal principal = decision.getPrincipal();
request = new HttpServletRequestWrapper(
(HttpServletRequest) request) {
@Override
public String getRemoteUser() {
return principal != null ? principal.getName() : null;
}
@Override
public Principal getUserPrincipal() {
return principal;
}
};
request.setAttribute(OAuth.OAUTH_CLIENT_ID, decision
.getOAuthClient().getClientId());
request.setAttribute(OAuth.OAUTH_TOKEN, accessToken);
chain.doFilter(request, response);
return;
} catch (OAuthSystemException e1) {
throw new ServletException(e1);
} catch (OAuthProblemException e) {
log.error("OAuth exception", e);
respondWithError(res, e);
return;
}
}
@Override
public void destroy() {
}
private void respondWithError(HttpServletResponse resp,
OAuthProblemException error) throws IOException, ServletException {
OAuthResponse oauthResponse = null;
try {
if (OAuthUtils.isEmpty(error.getError())) {
oauthResponse = OAuthRSResponse
.errorResponse(HttpServletResponse.SC_UNAUTHORIZED)
.setRealm(realm).buildHeaderMessage();
} else {
int responseCode = 401;
if (error.getError().equals(
OAuthError.CodeResponse.INVALID_REQUEST)) {
responseCode = 400;
} else if (error.getError().equals(
OAuthError.ResourceResponse.INSUFFICIENT_SCOPE)) {
responseCode = 403;
}
oauthResponse = OAuthRSResponse.errorResponse(responseCode)
.setRealm(realm).setError(error.getError())
.setErrorDescription(error.getDescription())
.setErrorUri(error.getUri()).buildHeaderMessage();
}
resp.addHeader(OAuth.HeaderType.WWW_AUTHENTICATE,
oauthResponse.getHeader(OAuth.HeaderType.WWW_AUTHENTICATE));
resp.sendError(oauthResponse.getResponseStatus());
} catch (OAuthSystemException e) {
throw new ServletException(e);
}
}
}