package org.xdi.oxauth.rp.demo; import com.google.common.base.Preconditions; import com.google.common.base.Strings; import org.apache.log4j.Logger; import org.xdi.oxauth.client.*; import org.xdi.oxauth.model.common.AuthenticationMethod; import org.xdi.oxauth.model.common.GrantType; import javax.servlet.*; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.io.IOException; import java.io.PrintWriter; /** * @author yuriyz on 07/19/2016. */ public class LoginFilter implements Filter { private static final Logger LOG = Logger.getLogger(LoginFilter.class); public static final String WELL_KNOWN_CONNECT_PATH = "/.well-known/openid-configuration"; private String authorizeParameters; private String redirectUri; private String authorizationServerHost; private String clientId; private String clientSecret; private OpenIdConfigurationResponse discoveryResponse; @Override public void init(FilterConfig filterConfig) throws ServletException { authorizeParameters = filterConfig.getInitParameter("authorizeParameters"); redirectUri = filterConfig.getInitParameter("redirectUri"); authorizationServerHost = filterConfig.getInitParameter("authorizationServerHost"); clientId = filterConfig.getInitParameter("clientId"); clientSecret = filterConfig.getInitParameter("clientSecret"); Preconditions.checkState(redirectUri.startsWith("https:"), "Redirect URI must use https protocol for client application_type=web."); } @Override public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException { HttpServletRequest request = (HttpServletRequest) servletRequest; HttpServletResponse response = (HttpServletResponse) servletResponse; boolean redirectForLogin = fetchTokenIfCodeIsPresent(request); Object accessToken = request.getSession(true).getAttribute("access_token"); if (accessToken == null) { if (redirectForLogin) { redirectToLogin(request, response); } else { LOG.trace("Login failed."); response.setContentType("text/html;charset=utf-8"); PrintWriter pw = response.getWriter(); pw.println("<h3>Login failed.</h3>"); } } else { LOG.trace("User is already authenticated."); filterChain.doFilter(servletRequest, servletResponse); } } private void fetchDiscovery(HttpServletRequest request) { try { if (discoveryResponse != null) { // already initialized return; } OpenIdConfigurationClient discoveryClient = new OpenIdConfigurationClient(authorizationServerHost + WELL_KNOWN_CONNECT_PATH); discoveryClient.setExecutor(Utils.createTrustAllExecutor()); discoveryResponse = discoveryClient.execOpenIdConfiguration(); LOG.trace("Discovery: " + discoveryResponse); if (discoveryResponse.getStatus() == 200) { return; } } catch (Exception e) { LOG.error(e.getMessage(), e); } throw new RuntimeException("Failed to fetch discovery information at : " + authorizationServerHost + WELL_KNOWN_CONNECT_PATH); } /** * @param request request * @return whether login is still required */ private boolean fetchTokenIfCodeIsPresent(HttpServletRequest request) { String code = request.getParameter("code"); if (code != null && !code.trim().isEmpty()) { LOG.trace("Fetching token for code " + code + " ..."); fetchDiscovery(request); TokenRequest tokenRequest = new TokenRequest(GrantType.AUTHORIZATION_CODE); tokenRequest.setCode(code); tokenRequest.setRedirectUri(redirectUri); tokenRequest.setAuthUsername(clientId); tokenRequest.setAuthPassword(clientSecret); tokenRequest.setAuthenticationMethod(AuthenticationMethod.CLIENT_SECRET_BASIC); TokenClient tokenClient = new TokenClient(discoveryResponse.getTokenEndpoint()); tokenClient.setExecutor(Utils.createTrustAllExecutor()); tokenClient.setRequest(tokenRequest); TokenResponse tokenResponse = tokenClient.exec(); if (!Strings.isNullOrEmpty(tokenResponse.getAccessToken())) { LOG.trace("Token is successfully fetched."); LOG.trace("Put in session access_token: " + tokenResponse.getAccessToken() + ", id_token: " + tokenResponse.getIdToken() + ", userinfo_endpoint: " + discoveryResponse.getUserInfoEndpoint()); request.getSession(true).setAttribute("access_token", tokenResponse.getAccessToken()); request.getSession(true).setAttribute("id_token", tokenResponse.getIdToken()); request.getSession(true).setAttribute("userinfo_endpoint", discoveryResponse.getUserInfoEndpoint()); } else { LOG.trace("Failed to obtain token. Status: " + tokenResponse.getStatus() + ", entity: " + tokenResponse.getEntity()); } return false; } return true; } private void redirectToLogin(HttpServletRequest request, HttpServletResponse response) throws IOException { fetchDiscovery(request); String redirectTo = discoveryResponse.getAuthorizationEndpoint() + "?redirect_uri=" + redirectUri + "&client_id=" + clientId + "&" + authorizeParameters; LOG.trace("Redirecting to authorization url : " + redirectTo); response.sendRedirect(redirectTo); } @Override public void destroy() { } }