package io.craft.atom.rpc;
import io.craft.atom.io.Channel;
import io.craft.atom.io.IoConnector;
import io.craft.atom.io.IoHandler;
import io.craft.atom.nio.NioOrderedDirectChannelEventDispatcher;
import io.craft.atom.nio.api.NioFactory;
import io.craft.atom.protocol.rpc.model.RpcMessage;
import io.craft.atom.rpc.api.RpcContext;
import io.craft.atom.rpc.spi.RpcChannel;
import io.craft.atom.rpc.spi.RpcConnector;
import io.craft.atom.rpc.spi.RpcProtocol;
import io.craft.atom.util.thread.NamedThreadFactory;
import java.io.IOException;
import java.net.SocketAddress;
import java.util.Collection;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import lombok.Getter;
import lombok.Setter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Default client-side {@link RpcConnector}.
 * <p>
 * Opens TCP connections to a single {@code address} through an NIO {@link IoConnector}, keeps every established
 * connection in {@code channels} keyed by its channel id, and spreads outgoing requests over those connections by
 * message id. Optionally sends periodic heartbeats and re-establishes dropped connections on a background thread.
 *
 * @author mindwind
 * @version 1.0, Aug 15, 2014
 */
public class DefaultRpcConnector implements RpcConnector {

    private static final Logger LOG = LoggerFactory.getLogger(DefaultRpcConnector.class);

    /** Maximum time {@link #connect()} waits for a connection to be established, in milliseconds. */
    @Getter @Setter private int connectTimeoutInMillis;
    // NOTE(review): not read by this class; send() uses the per-request timeout instead.
    @Getter @Setter private int rpcTimeoutInMillis;
    /** Heartbeat period in milliseconds; values {@code <= 0} disable heartbeats. */
    @Getter private int heartbeatInMillis;
    /** Pause between two reconnect attempts, in milliseconds. */
    @Getter @Setter private int reconnectDelay;
    /** Whether reconnect attempts may proceed; volatile because the reconnect thread polls it. */
    @Getter @Setter private volatile boolean allowReconnect;
    @Getter private SocketAddress address;
    /** Established connections keyed by channel id. */
    @Getter @Setter private Map<Long, DefaultRpcChannel> channels;
    @Getter @Setter private IoHandler ioHandler;
    @Getter @Setter private IoConnector ioConnector;
    @Getter @Setter private ScheduledExecutorService hbScheduler;
    @Getter @Setter private ExecutorService reconnectExecutor;
    @Getter private RpcProtocol protocol;

    // ~ ------------------------------------------------------------------------------------------------------------

    /**
     * Creates a connector with reconnect enabled (6000 ms delay), unbounded connect/rpc timeouts and heartbeats off.
     * <p>
     * The NIO connector is built here with the connect timeout value current at construction time
     * ({@link Integer#MAX_VALUE}); a later {@code setConnectTimeoutInMillis} only changes how long {@link #connect()}
     * waits on the connect future.
     */
    public DefaultRpcConnector() {
        reconnectDelay = 6000;
        allowReconnect = true;
        connectTimeoutInMillis = Integer.MAX_VALUE;
        rpcTimeoutInMillis = Integer.MAX_VALUE;
        heartbeatInMillis = 0;
        reconnectExecutor = Executors.newSingleThreadScheduledExecutor(new NamedThreadFactory("craft-atom-rpc-connector-reconnect"));
        channels = new ConcurrentHashMap<Long, DefaultRpcChannel>();
        // NOTE(review): `this` escapes to the handler before construction completes.
        ioHandler = new RpcClientIoHandler(this);
        ioConnector = NioFactory.newTcpConnectorBuilder(ioHandler)
                                .connectTimeoutInMillis(connectTimeoutInMillis)
                                .dispatcher(new NioOrderedDirectChannelEventDispatcher())
                                .build();
    }

    // ~ ------------------------------------------------------------------------------------------------------------

    /**
     * Opens one new connection to {@code address} and registers it in {@code channels}.
     *
     * @return the id of the new channel
     * @throws RpcException {@code CLIENT_TIMEOUT} if the connection is not established within
     *         {@code connectTimeoutInMillis}, {@code NETWORK} on I/O failure, {@code UNKNOWN} otherwise
     */
    @Override
    public long connect() throws RpcException {
        try {
            Future<Channel<byte[]>> future = ioConnector.connect(address);
            Channel<byte[]> channel = future.get(connectTimeoutInMillis, TimeUnit.MILLISECONDS);
            DefaultRpcChannel rpcChannel = new DefaultRpcChannel(channel, protocol.getRpcEncoder(), protocol.getRpcDecoder());
            rpcChannel.setFutures(new ConcurrentHashMap<Long, RpcFuture<?>>());
            channel.setAttribute(RpcIoHandler.RPC_CHANNEL, rpcChannel);
            long id = channel.getId();
            channels.put(id, rpcChannel);
            LOG.debug("[CRAFT-ATOM-RPC] Rpc client connector established connection, |channel={}|.", rpcChannel);
            return id;
        } catch (TimeoutException e) {
            throw new RpcException(RpcException.CLIENT_TIMEOUT, "client timeout", e);
        } catch (IOException e) {
            throw new RpcException(RpcException.NETWORK, "network error", e);
        } catch (ExecutionException e) {
            // The asynchronous connect attempt failed; classify I/O causes the same way as a direct IOException.
            Throwable cause = e.getCause();
            if (cause instanceof IOException) {
                throw new RpcException(RpcException.NETWORK, "network error", (IOException) cause);
            }
            throw new RpcException(RpcException.UNKNOWN, "unknown error", e);
        } catch (InterruptedException e) {
            // Restore the interrupt status so callers (e.g. the reconnect loop) can observe it.
            Thread.currentThread().interrupt();
            throw new RpcException(RpcException.UNKNOWN, "unknown error", e);
        } catch (Exception e) {
            throw new RpcException(RpcException.UNKNOWN, "unknown error", e);
        }
    }

    /**
     * Removes the given connection from {@code channels} and closes it.
     *
     * @return {@code true} if the connection was registered and has been closed, {@code false} otherwise
     */
    @Override
    public boolean disconnect(long connectionId) {
        DefaultRpcChannel channel = channels.remove(connectionId);
        if (channel != null) {
            channel.close();
            return true;
        }
        return false;
    }

    /**
     * Shuts the connector down: stops heartbeats and reconnects, closes every connection, then shuts down the NIO
     * connector.
     */
    @Override
    public void close() {
        // Stop background work first so no heartbeat or reconnect task races with the teardown below.
        if (hbScheduler != null) { hbScheduler.shutdownNow(); }
        reconnectExecutor.shutdownNow();
        brokeAll();
        channels.clear();
        ioConnector.shutdown();
    }

    /**
     * Writes {@code req} to a connection chosen from its message id.
     *
     * @param req   request to send
     * @param async if {@code true}, the response future is published through {@link RpcContext} and {@code null} is
     *              returned immediately
     * @return the response for a synchronous two-way request, {@code null} for one-way or async requests
     * @throws RpcException {@code NETWORK} if no connection is available or on I/O failure, {@code CLIENT_TIMEOUT}
     *         if the response does not arrive within the request's timeout, {@code UNKNOWN} otherwise
     */
    @Override
    public RpcMessage send(RpcMessage req, boolean async) throws RpcException {
        long mid = req.getId();
        DefaultRpcChannel channel = select(mid);
        if (channel == null) throw new RpcException(RpcException.NETWORK, "network error");

        try {
            boolean oneway = req.isOneway();
            RpcFuture<Object> future = null;
            if (!oneway) {
                // Register before writing so a fast response always finds its future.
                // NOTE(review): the future is not deregistered here on write failure or timeout - confirm that
                // DefaultRpcChannel cleans it up.
                future = new DefaultRpcFuture<Object>();
                channel.setRpcFuture(mid, future);
            }
            channel.write(req);

            // One way request, client does not expect response
            if (oneway) { return null; }

            if (async) {
                // async and set future
                RpcContext.getContext().setFuture(future);
                return null;
            }
            // sync and wait response
            future.await(req.getRpcTimeoutInMillis(), TimeUnit.MILLISECONDS);
            return future.getResponse();
        } catch (RpcException e) {
            throw e;
        } catch (IOException e) {
            throw new RpcException(RpcException.NETWORK, "network error", e);
        } catch (TimeoutException e) {
            throw new RpcException(RpcException.CLIENT_TIMEOUT, "client timeout", e);
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Wrapping alone would clear the caller's interrupt status; restore it.
                Thread.currentThread().interrupt();
            }
            throw new RpcException(RpcException.UNKNOWN, "unknown error", e);
        }
    }

    /**
     * Drops the given connection and, if it was still registered, keeps trying to open a replacement on the
     * reconnect executor, sleeping {@code reconnectDelay} ms between attempts.
     * <p>
     * While {@code allowReconnect} is {@code false}, attempts are skipped but the loop keeps waiting. The loop ends
     * on the first successful connect or when its thread is interrupted (for example by {@link #close()}).
     */
    void reconnect(final long connectionId) {
        // Only the call that actually removed the channel schedules a retry loop; repeated calls are no-ops.
        if (!disconnect(connectionId)) return;

        try {
            reconnectExecutor.execute(new Runnable() {
                @Override
                public void run() {
                    while (!Thread.currentThread().isInterrupted()) {
                        if (retryConnect()) {
                            return;
                        }
                        try {
                            Thread.sleep(reconnectDelay);
                        } catch (InterruptedException e) {
                            // shutdownNow() interrupts us; honour it instead of looping forever.
                            Thread.currentThread().interrupt();
                            return;
                        }
                    }
                }

                /** Makes one reconnect attempt; returns {@code true} only if a connection was established. */
                private boolean retryConnect() {
                    try {
                        if (!allowReconnect) return false;
                        long connId = connect();
                        if (connId > 0) {
                            LOG.debug("[CRAFT-ATOM-RPC] Rpc client connector reconnect success, |connectionId={}|", connId);
                            return true;
                        }
                        LOG.debug("[CRAFT-ATOM-RPC] Rpc client connector reconnect fail");
                        return false;
                    } catch (Exception e) {
                        LOG.debug("[CRAFT-ATOM-RPC] Rpc client connector reconnect fail", e);
                        return false;
                    }
                }
            });
        } catch (RejectedExecutionException e) {
            // The reconnect executor is shut down (e.g. by close()); there is nothing left to reconnect for.
            LOG.debug("[CRAFT-ATOM-RPC] Rpc client connector reconnect rejected, |connectionId={}|", connectionId);
        }
    }

    /**
     * Picks a connection for the given message id from a snapshot of {@code channels}.
     *
     * @return the selected channel, or {@code null} if there are no connections
     */
    private DefaultRpcChannel select(long id) {
        DefaultRpcChannel[] snapshot = channels.values().toArray(new DefaultRpcChannel[0]);
        if (snapshot.length == 0) return null;
        // abs of the remainder (not of id) so Long.MIN_VALUE cannot yield a negative index;
        // for every other id this equals Math.abs(id) % length.
        int i = (int) Math.abs(id % snapshot.length);
        return snapshot[i];
    }

    @Override
    public void setAddress(SocketAddress address) {
        this.address = address;
    }

    /**
     * Sets the heartbeat period and restarts the heartbeat schedule accordingly; {@code <= 0} disables heartbeats.
     */
    @Override
    public void setHeartbeatInMillis(int heartbeatInMillis) {
        this.heartbeatInMillis = heartbeatInMillis;
        heartbeat();
    }

    /**
     * Replaces any existing heartbeat scheduler. When {@code heartbeatInMillis > 0}, a new one writes a heartbeat
     * request to every registered channel immediately and then every {@code heartbeatInMillis} ms.
     */
    private void heartbeat() {
        if (hbScheduler != null) {
            hbScheduler.shutdown();
            hbScheduler = null;
        }
        if (heartbeatInMillis <= 0) {
            return;
        }

        hbScheduler = Executors.newSingleThreadScheduledExecutor(new NamedThreadFactory("craft-atom-rpc-connector-heartbeat"));
        hbScheduler.scheduleAtFixedRate(new Runnable() {
            @Override
            public void run() {
                for (RpcChannel channel : channels.values()) {
                    // One failing channel must not stop heartbeats to the others, or kill the periodic task.
                    try {
                        RpcMessage hbmsg = RpcMessages.newHbRequestRpcMessage();
                        channel.write(hbmsg);
                        LOG.debug("[CRAFT-ATOM-RPC] Rpc client connector heartbeat, |hbmsg={}, channel={}|", hbmsg, channel);
                    } catch (Exception e) {
                        LOG.warn("[CRAFT-ATOM-RPC] Rpc client connector heartbeat error", e);
                    }
                }
            }
        }, 0, heartbeatInMillis, TimeUnit.MILLISECONDS);
    }

    @Override
    public void setProtocol(RpcProtocol protocol) {
        this.protocol = protocol;
    }

    /** @return the sum of {@code waitCount()} over all registered channels. */
    @Override
    public int waitCount() {
        int wc = 0;
        for (DefaultRpcChannel ch : channels.values()) {
            wc += ch.waitCount();
        }
        return wc;
    }

    // ~ ----------------------------------------------------------------------------------------------------- for test

    /**
     * Closes every registered connection without removing it from {@code channels}.
     */
    public void brokeAll() {
        for (DefaultRpcChannel channel : channels.values()) {
            channel.close();
        }
    }

    /**
     * @return the number of registered connections that are open at the moment.
     */
    public int aliveConnectionNum() {
        int num = 0;
        for (DefaultRpcChannel channel : channels.values()) {
            if (channel.isOpen()) num++;
        }
        return num;
    }
}