package org.apache.cxf.clustering;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import org.apache.cxf.endpoint.Client;
import org.apache.cxf.endpoint.Endpoint;
import org.apache.cxf.endpoint.Retryable;
import org.apache.cxf.helpers.CastUtils;
import org.apache.cxf.interceptor.Fault;
import org.apache.cxf.message.Exchange;
import org.apache.cxf.message.Message;
import org.apache.cxf.message.MessageUtils;
import org.apache.cxf.transport.Conduit;
import org.apache.cxf.transport.http.HTTPConduit;
import org.apache.cxf.transports.http.configuration.HTTPClientPolicy;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.github.jaceko.circuitswitcher.Circuit;
public class CircuitSwitcherTargetSelector extends FailoverTargetSelector {
private static final Logger LOG = LoggerFactory
.getLogger(CircuitSwitcherTargetSelector.class);
private static final String IS_SELECTED = "org.apache.cxf.clustering.CircuitBreakerTargetSelector.IS_SELECTED";
private Collection<Circuit> circuits = new LinkedHashSet<Circuit>();
private long resetTimeout;
private int failureThreshold;
private Long receiveTimeout;
public CircuitSwitcherTargetSelector(List<String> addressList, long resetTimeout,
int failureThreshold, Long receiveTimeout) {
this.resetTimeout = resetTimeout;
this.failureThreshold = failureThreshold;
this.receiveTimeout = receiveTimeout;
if (addressList != null) {
setAddressList(new ArrayList<String>(addressList));
LOG.info("Failover nodes: " + addressList.toString());
}
LOG.info("Failover reset timeout: " + resetTimeout);
LOG.info("Failure threshold: " + failureThreshold);
LOG.info("Receive timeout: " + receiveTimeout);
}
/**
* Called when a Conduit is actually required.
*
* @param message
* @return the Conduit to use for mediation of the message
*/
public synchronized Conduit selectConduit(Message message) {
Conduit c = message.get(Conduit.class);
if (c == null) {
Exchange exchange = message.getExchange();
InvocationKey key = new InvocationKey(exchange);
InvocationContext invocation = getInvocation(key);
if ((invocation != null) && !invocation.getContext().containsKey(IS_SELECTED)) {
Endpoint target = getAvailableTarget();
if (target != null && targetChanged(message, target)) {
setEndpoint(target);
message.put(Message.ENDPOINT_ADDRESS, target.getEndpointInfo().getAddress());
overrideAddressProperty(invocation.getContext());
invocation.getContext().put(IS_SELECTED, "");
} else if (target == null) {
throw new Fault(new IOException("No available targets"));
}
}
message.put(CONDUIT_COMPARE_FULL_URL, Boolean.TRUE);
c = getSelectedConduit(message);
}
if (receiveTimeout != null) {
HTTPClientPolicy httpClientPolicy = ((HTTPConduit) c).getClient();
httpClientPolicy.setReceiveTimeout(receiveTimeout);
}
return c;
}
private boolean targetChanged(Message message, Endpoint target) {
Object endpoinAddress = message.get(Message.ENDPOINT_ADDRESS);
return endpoinAddress == null
|| !endpoinAddress.toString().contains(target.getEndpointInfo().getAddress());
}
/**
* Get the failover target endpoint, if a suitable one is available.
*
* @param exchange
* the current Exchange
* @param invocation
* the current InvocationContext
* @return a failover endpoint if one is available
*
*
*/
@Override
protected Endpoint getFailoverTarget(Exchange exchange, InvocationContext invocation) {
Endpoint failoverTarget = getAvailableTarget();
if (failoverTarget != null) {
LOG.error("Connection error, retrying " + failoverTarget.getEndpointInfo().getAddress());
} else {
LOG.error("No more failover nodes available");
}
return failoverTarget;
}
/**
* Get first available target endpoint, if a suitable one is available.
*
* @return healthy endpoint if one is available
*/
private Endpoint getAvailableTarget() {
if ((circuits == null) || (circuits.isEmpty())) {
LOG.error("No adresses configured");
return null;
}
Iterator<Circuit> iterator = circuits.iterator();
String alternateAddress = null;
LOG.info("Checking available targets:");
while (iterator.hasNext()) {
Circuit target = iterator.next();
LOG.info("Target: {}", target);
if (target.connectionAvailable()) {
alternateAddress = target.getTargetAddress();
LOG.info("Selecting: {}", target);
break;
}
}
if (alternateAddress != null) {
Endpoint distributionTarget = getEndpoint();
distributionTarget.getEndpointInfo().setAddress(alternateAddress);
return distributionTarget;
} else {
return null;
}
}
/**
* Called on completion of the MEP for which the Conduit was required.
*
* @param exchange
* represents the completed MEP
*/
public void complete(Exchange exchange) {
InvocationKey key = new InvocationKey(exchange);
InvocationContext invocation = getInvocation(key);
boolean failover = false;
Conduit old = (Conduit) exchange.getOutMessage().remove(Conduit.class.getName());
if (requiresFailover(exchange)) {
onFailure(invocation.getContext());
LOG.debug("Failover {}", invocation.getContext());
Endpoint failoverTarget = getFailoverTarget(exchange, invocation);
if (failoverTarget != null) {
setEndpoint(failoverTarget);
if (old != null) {
old.close();
conduits.remove(old);
}
failover = performFailover(exchange, invocation);
}
} else {
if (invocation != null) {
onSuccess(invocation.getContext());
}
}
if (!failover) {
LOG.debug("Failover not required");
synchronized (this) {
inProgress.remove(key);
}
if (MessageUtils.isTrue(exchange.get("KeepConduitAlive"))) {
return;
}
try {
if (exchange.getInMessage() != null) {
Conduit c = (Conduit) exchange.getOutMessage().get(Conduit.class);
if (c == null) {
getSelectedConduit(exchange.getInMessage()).close(exchange.getInMessage());
} else {
c.close(exchange.getInMessage());
}
}
} catch (IOException e) {
}
}
}
private boolean performFailover(Exchange exchange, InvocationContext invocation) {
Exception prevExchangeFault = (Exception) exchange.remove(Exception.class.getName());
Message outMessage = exchange.getOutMessage();
Exception prevMessageFault = outMessage.getContent(Exception.class);
outMessage.setContent(Exception.class, null);
overrideAddressProperty(invocation.getContext());
Retryable retry = exchange.get(Retryable.class);
exchange.clear();
boolean failover = false;
if (retry != null) {
try {
failover = true;
long delay = getDelayBetweenRetries();
if (delay > 0) {
Thread.sleep(delay);
}
Map<String, Object> context = invocation.getContext();
retry.invoke(invocation.getBindingOperationInfo(), invocation.getParams(), context,
exchange);
} catch (Exception e) {
if (exchange.get(Exception.class) != null) {
exchange.put(Exception.class, prevExchangeFault);
}
if (outMessage.getContent(Exception.class) != null) {
outMessage.setContent(Exception.class, prevMessageFault);
}
}
}
return failover;
}
@Override
protected long getDelayBetweenRetries() {
return 0;
}
protected InvocationContext getInvocation(InvocationKey key) {
InvocationContext invocation;
synchronized (this) {
invocation = inProgress.get(key);
}
return invocation;
}
private Circuit findCircuit(String address) {
Circuit foundCircuit = null;
for (Circuit circuit : circuits) {
if (address.contains(circuit.getTargetAddress())) {
foundCircuit = circuit;
break;
}
}
return foundCircuit;
}
protected void onSuccess(Map<String, Object> context) {
String address = getAddressFrom(context);
Circuit circuit = findCircuit(address);
if (circuit != null) {
circuit.handleSuccesfullConnection();
LOG.debug("onSuccess: address: {}, circuit: {}, context: {}", address, circuit, context);
}
}
protected void onFailure(Map<String, Object> context) {
String address = getAddressFrom(context);
Circuit circuit = findCircuit(address);
if (circuit != null) {
circuit.handleFailedConnection();
LOG.debug("onFailure: address: {}, circuit: {}, context: {}", address, circuit, context);
}
}
private String getAddressFrom(Map<String, Object> context) {
Map<String, Object> requestContext = CastUtils.cast((Map<?, ?>) context
.get(Client.REQUEST_CONTEXT));
return (String) requestContext.get(Message.ENDPOINT_ADDRESS);
}
final void setAddressList(List<String> addressList) {
circuits = new LinkedHashSet<Circuit>();
for (String address : addressList) {
circuits.add(new Circuit(address, this.failureThreshold, this.resetTimeout));
}
}
void setResetTimeout(long resetTimeout) {
this.resetTimeout = resetTimeout;
}
void setFailureThreshold(int failureThreshold) {
this.failureThreshold = failureThreshold;
}
void setReceiveTimeout(Long receiveTimeout) {
this.receiveTimeout = receiveTimeout;
}
}