package org.atricore.idbus.capabilities.sso.main.binding;
import oasis.names.tc.saml._2_0.protocol.RequestAbstractType;
import org.apache.camel.Exchange;
import org.apache.camel.Message;
import org.apache.camel.ProducerTemplate;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.atricore.idbus.capabilities.sso.main.claims.SSOCredentialClaimsRequest;
import org.atricore.idbus.capabilities.sso.support.binding.SSOBinding;
import org.atricore.idbus.kernel.main.federation.metadata.EndpointDescriptor;
import org.atricore.idbus.kernel.main.mediation.*;
import org.atricore.idbus.kernel.main.mediation.camel.CamelIdentityMediationUnitContainer;
import org.atricore.idbus.kernel.main.mediation.camel.component.binding.AbstractMediationBinding;
import org.atricore.idbus.kernel.main.mediation.camel.component.binding.CamelMediationExchange;
import org.atricore.idbus.kernel.main.mediation.camel.component.binding.CamelMediationMessage;
import org.atricore.idbus.kernel.main.mediation.state.LocalState;
import org.atricore.idbus.kernel.main.mediation.state.ProviderStateContext;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.List;
/**
 * Mediation binding for exchanging SAML 2.0 messages locally through Camel.
 * <p>
 * Incoming exchanges whose in-message body is a SAML {@link RequestAbstractType} or an
 * {@link SSOCredentialClaimsRequest} are wrapped into {@link MediationMessage}s, re-attaching a
 * previously stored {@link LocalState} when one can be found. Outgoing messages are handed to a
 * Camel {@code direct:} endpoint built from the destination's location.
 *
 * @author <a href="mailto:sgonzalez@atricore.org">Sebastian Gonzalez Oyuela</a>
 * @version $Id$
 */
public class SamlR2LocalBinding extends AbstractMediationBinding {

    private static final Log logger = LogFactory.getLog(SamlR2LocalBinding.class);

    /** Alternative key used to look up local state by a SAML request's first session index. */
    private static final String IDP_SSO_SESSION_ID_KEY = "idpSsoSessionId";

    /**
     * Creates a local binding for the given channel, registered under the
     * {@link SSOBinding#SAMLR2_LOCAL} binding value.
     *
     * @param channel the channel this binding belongs to
     */
    public SamlR2LocalBinding(Channel channel) {
        super(SSOBinding.SAMLR2_LOCAL.getValue(), channel);
    }

    /**
     * Wraps the body of the exchange's in-message into a {@link MediationMessage}.
     * <p>
     * Two body types are supported:
     * <ul>
     *   <li>{@link RequestAbstractType}: if the concrete request exposes {@code getSessionIndex()}
     *       returning a non-empty list, the first entry is used to look up local state under the
     *       {@code idpSsoSessionId} key.</li>
     *   <li>{@link SSOCredentialClaimsRequest}: local state is looked up by the request's target
     *       relay state.</li>
     * </ul>
     * When no local state is found, a new mediation state is created from the exchange.
     *
     * @param message the Camel mediation message carrying the exchange
     * @return a message built from the in-message id, its body and the resolved state (the
     *         remaining constructor arguments are passed as {@code null})
     * @throws IllegalArgumentException if the in-message body is of any other type
     */
    public MediationMessage createMessage(CamelMediationMessage message) {
        CamelMediationExchange samlR2exchange = message.getExchange();
        Exchange exchange = samlR2exchange.getExchange();

        if (logger.isDebugEnabled())
            logger.debug("Create Message Body from exchange " + exchange.getClass().getName());

        // Converting from a local message to a SAML 2.0 mediation message
        Message in = exchange.getIn();
        Object content = in.getBody();

        LocalState lState;
        if (content instanceof RequestAbstractType) {
            lState = retrieveStateBySessionIndex((RequestAbstractType) content);
        } else if (content instanceof SSOCredentialClaimsRequest) {
            SSOCredentialClaimsRequest claimsRequest = (SSOCredentialClaimsRequest) content;
            ProviderStateContext ctx = createProviderStateContext();
            lState = ctx.retrieve(claimsRequest.getTargetRelayState());
        } else {
            throw new IllegalArgumentException("Unknown message type " + content);
        }

        MediationState state;
        if (lState == null) {
            // No stored state matched: start a fresh one for this exchange
            state = createMediationState(exchange);
        } else {
            state = new MediationStateImpl(lState);
        }

        return new MediationMessageImpl(
                in.getMessageId(),
                content,
                null,
                null,
                null,
                state);
    }

    /**
     * Looks up local state using the first session index of a SAML request, if it has one.
     * <p>
     * {@code getSessionIndex()} is resolved reflectively so that request types which do not
     * declare it are tolerated; for those, nothing is recovered.
     *
     * @param samlReq the incoming SAML request
     * @return the stored local state, or {@code null} if none was found or it could not be recovered
     */
    private LocalState retrieveStateBySessionIndex(RequestAbstractType samlReq) {
        try {
            Method getSessionIndex = samlReq.getClass().getMethod("getSessionIndex");

            // The element type cannot be verified at runtime (erasure); session indexes are strings
            @SuppressWarnings("unchecked")
            List<String> sessionIndexes = (List<String>) getSessionIndex.invoke(samlReq);
            if (sessionIndexes == null || sessionIndexes.isEmpty())
                return null;

            String sessionIndex = sessionIndexes.get(0);
            ProviderStateContext ctx = createProviderStateContext();
            LocalState lState = ctx.retrieve(IDP_SSO_SESSION_ID_KEY, sessionIndex);
            if (logger.isDebugEnabled())
                logger.debug("Local state was" + (lState == null ? " NOT" : "") + " retrieved for ssoSessionId " + sessionIndex);
            return lState;

        } catch (NoSuchMethodException e) {
            // Expected for request types without session indexes: nothing to recover
            if (logger.isTraceEnabled())
                logger.trace("SSO Request does not have session index : " + e.getMessage());
        } catch (InvocationTargetException e) {
            logger.error("Cannot recover local state : " + e.getMessage(), e);
        } catch (IllegalAccessException e) {
            logger.error("Cannot recover local state : " + e.getMessage(), e);
        }
        return null;
    }

    /**
     * Copies the outgoing mediation message into the exchange: its state is copied back and its
     * content becomes the out-message body.
     *
     * @param message  the outgoing Camel mediation message
     * @param exchange the exchange to populate
     */
    public void copyMessageToExchange(CamelMediationMessage message, Exchange exchange) {
        if (logger.isDebugEnabled())
            logger.debug("Copying SAML 2.0 LOCAL Message");

        MediationMessage outMsg = message.getMessage();
        copyBackState(outMsg.getState(), exchange);
        exchange.getOut().setBody(outMsg.getContent());
    }

    /**
     * Not implemented for the local binding: logs a warning and, if present, the fault carried by
     * the message. The exchange is left untouched.
     *
     * @param faultMessage the fault message
     * @param exchange     the exchange (unused)
     */
    public void copyFaultMessageToExchange(CamelMediationMessage faultMessage, Exchange exchange) {
        if (logger.isTraceEnabled())
            logger.trace("Copy Fault to Exchange for Local binding!");

        // TODO : Is copyFaultMessageToExchange necessary ?
        logger.warn("'copyFaultMessageToExchange' Not implemented !");

        MediationMessage m = faultMessage.getMessage();
        if (m.getFault() != null) {
            logger.error(m.getFault().getMessage(), m.getFault());
        }
    }

    /**
     * Sends the message content to the Camel endpoint {@code "direct:" + destination location}
     * using the unit container's producer template.
     *
     * @param message the message to send; its destination supplies the endpoint location
     * @return whatever the producer template returns for the send
     * @throws UnsupportedOperationException if the channel's unit container is not a
     *         {@link CamelIdentityMediationUnitContainer}
     */
    @Override
    public Object sendMessage(MediationMessage message) throws IdentityMediationException {
        if (logger.isTraceEnabled())
            logger.trace("Sending new SAML 2.0 message using Local Binding");

        IdentityMediationUnitContainer uc = channel.getUnitContainer();
        if (!(uc instanceof CamelIdentityMediationUnitContainer)) {
            throw new UnsupportedOperationException("Unit container type unknown " + uc);
        }

        EndpointDescriptor ed = message.getDestination();
        ProducerTemplate t = ((CamelIdentityMediationUnitContainer) uc).getTemplate();
        String camelEndpoint = "direct:" + ed.getLocation();

        if (logger.isTraceEnabled())
            logger.trace("Sending message content to [" + camelEndpoint + "]");

        Object o = t.sendBody(camelEndpoint, message.getContent());

        if (logger.isTraceEnabled())
            logger.trace("Received from [" + camelEndpoint + "] " + o);

        return o;
    }
}