/*
* JBoss, a division of Red Hat
* Copyright 2012, Red Hat Middleware, LLC, and individual
* contributors as indicated by the @authors tag. See the
* copyright.txt in the distribution for a full listing of
* individual contributors.
*
* This is free software; you can redistribute it and/or modify it
* under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1 of
* the License, or (at your option) any later version.
*
* This software is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this software; if not, write to the Free
* Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
* 02110-1301 USA, or see the FSF site: http://www.fsf.org.
*/
package org.gatein.sso.saml.plugin;
import org.apache.http.HttpEntity;
import org.apache.http.HttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.impl.client.DefaultHttpClient;
import org.apache.http.util.EntityUtils;
import org.apache.log4j.Logger;
import org.gatein.sso.plugin.RestCallbackCaller;
import javax.security.auth.Subject;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.auth.login.LoginException;
import javax.security.auth.spi.LoginModule;
import java.security.Principal;
import java.security.acl.Group;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
* Login module, which can be executed on SAML Identity provider side. It executes REST requests to GateIn to verify authentication of single user
* against GateIn or obtain list of roles from GateIn.
*
* @author <a href="mailto:mposolda@redhat.com">Marek Posolda</a>
*/
public class SAML2IdpLoginModule implements LoginModule
{
// This option can have two values: "STATIC" or "PORTAL_CALLBACK"
// "STATIC" means that roles of authenticated user will be statically obtained from "staticRolesList", which means that all users will have same list of roles.
// "PORTAL_CALLBACK" means that roles will be obtained from GateIn via callback request to GateIn REST service
private static final String OPTION_ROLES_PROCESSING = "rolesProcessing";
// This option is valid only if rolesProcessing is STATIC. It contains list of static roles, which will be assigned to each authenticated user.
private static final String OPTION_STATIC_ROLES_LIST = "staticRolesList";
// gateIn URL related property, which will be used to send REST callback requests
private static final String OPTION_GATEIN_URL = "gateInURL";
// HTTP method ("POST" or "GET") which will be used to send REST callback requests
private static final String OPTION_HTTP_METHOD = "httpMethod";
private static Logger log = Logger.getLogger(SAML2IdpLoginModule.class);
private Subject subject;
private CallbackHandler callbackHandler;
@SuppressWarnings("unchecked")
private Map sharedState;
@SuppressWarnings("unchecked")
private Map options;
private String gateInURL;
private String httpMethod;
private ROLES_PROCESSING_TYPE rolesProcessingType;
private List<String> staticRolesList;
public void initialize(Subject subject, CallbackHandler callbackHandler, Map<String, ?> sharedState, Map<String, ?> options)
{
this.subject = subject;
this.callbackHandler = callbackHandler;
this.sharedState = sharedState;
this.options = options;
// Read options for this login module
String rolesProcessingType = readOption(OPTION_ROLES_PROCESSING, "STATIC");
if ("STATIC".equals(rolesProcessingType) || "PORTAL_CALLBACK".equals(rolesProcessingType))
{
this.rolesProcessingType = ROLES_PROCESSING_TYPE.valueOf(rolesProcessingType);
}
else
{
this.rolesProcessingType = ROLES_PROCESSING_TYPE.STATIC;
}
String staticRoles = readOption(OPTION_STATIC_ROLES_LIST, "users");
this.staticRolesList = Arrays.asList(staticRoles.split(","));
this.gateInURL = readOption(OPTION_GATEIN_URL, "http://localhost:8080/portal");
this.httpMethod = readOption(OPTION_HTTP_METHOD, "POST");
}
public boolean login() throws LoginException
{
try
{
Callback[] callbacks = new Callback[2];
callbacks[0] = new NameCallback("Username");
callbacks[1] = new PasswordCallback("Password", false);
callbackHandler.handle(callbacks);
String username = ((NameCallback)callbacks[0]).getName();
String password = new String(((PasswordCallback)callbacks[1]).getPassword());
((PasswordCallback)callbacks[1]).clearPassword();
if (username == null || password == null)
{
return false;
}
boolean authenticationSuccess = validateUser(username, password);
if (authenticationSuccess)
{
log.debug("Successful REST login request for authentication of user " + username);
sharedState.put("javax.security.auth.login.name", username);
return true;
}
else
{
String message = "Remote login via REST failed for username " + username;
log.warn(message);
throw new LoginException(message);
}
}
catch (LoginException le)
{
throw le;
}
catch (Exception e)
{
log.warn("Exception during login: " + e.getMessage(), e);
throw new LoginException(e.getMessage());
}
}
public boolean commit() throws LoginException
{
String username = (String)sharedState.get("javax.security.auth.login.name");
Set<Principal> principals = subject.getPrincipals();
Group roleGroup = new SimpleGroup("Roles");
for (String role : getRoles(username))
{
roleGroup.addMember(new SimplePrincipal(role));
}
// group principal
principals.add(roleGroup);
// username principal
principals.add(new SimplePrincipal(username));
return true;
}
public boolean abort() throws LoginException
{
return true;
}
public boolean logout() throws LoginException
{
// Remove all principals from Subject
Set<Principal> principals = new HashSet(subject.getPrincipals());
for (Principal p : principals)
{
subject.getPrincipals().remove(p);
}
return true;
}
// ********** PROTECTED HELPER METHODS ****************************
protected boolean validateUser(String username, String password) throws Exception
{
RestCallbackCaller restCallbackCaller = new RestCallbackCaller(this.gateInURL, this.httpMethod);
return restCallbackCaller.executeRemoteCall(username, password);
}
protected Collection<String> getRoles(String username)
{
if (rolesProcessingType == ROLES_PROCESSING_TYPE.STATIC)
{
return staticRolesList;
}
else
{
// TODO: Use RestCallbackCaller here as well
// We need to execute REST callback to GateIn to ask for roles
StringBuilder urlBuffer = new StringBuilder();
urlBuffer.append(this.gateInURL
+ "/rest/sso/authcallback/roles/" + username);
String url = urlBuffer.toString();
log.debug("Execute callback HTTP request: " + url);
ResponseContext responseContext = this.executeRemoteCall(url);
if (responseContext.status == 200)
{
String rolesString = responseContext.response;
String[] roles = rolesString.split(",");
return Arrays.asList(roles);
}
else
{
log.warn("Incorrect response received from REST callback for roles. Status=" + responseContext.status + ", Response=" + responseContext.response);
return new ArrayList<String>();
}
}
}
// ********** PRIVATE HELPER METHODS ****************************
private String readOption(String key, String defaultValue)
{
String result = (String)options.get(key);
if (result == null)
{
result = defaultValue;
}
if (log.isTraceEnabled())
{
log.trace("Read option " + key + "=" + result);
}
return result;
}
private ResponseContext executeRemoteCall(String authUrl)
{
DefaultHttpClient client = new DefaultHttpClient();
HttpGet method;
try
{
method = new HttpGet(authUrl);
HttpResponse httpResponse = client.execute(method);
int status = httpResponse.getStatusLine().getStatusCode();
HttpEntity entity = httpResponse.getEntity();
String response = entity == null ? null : EntityUtils.toString(entity);
if (log.isTraceEnabled())
{
log.trace("Received response from REST call: status=" + status + ", response=" + response);
}
return new ResponseContext(status, response);
}
catch (Exception e)
{
log.warn("Error when sending request through HTTP client", e);
return new ResponseContext(1000, e.getMessage());
}
finally
{
client.getConnectionManager().shutdown();
}
}
private static class ResponseContext
{
private final int status;
private final String response;
private ResponseContext(int status, String response)
{
this.status = status;
this.response = response;
}
}
private static enum ROLES_PROCESSING_TYPE
{
STATIC,
PORTAL_CALLBACK
}
}