package vnet.sms.gateway.server.framework.test;
import java.net.InetSocketAddress;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import org.jboss.netty.bootstrap.ClientBootstrap;
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelFuture;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.ChannelPipeline;
import org.jboss.netty.channel.ChannelPipelineFactory;
import org.jboss.netty.channel.Channels;
import org.jboss.netty.channel.MessageEvent;
import org.jboss.netty.channel.SimpleChannelUpstreamHandler;
import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory;
import org.jboss.netty.handler.codec.serialization.ClassResolvers;
import org.jboss.netty.handler.codec.serialization.ObjectDecoder;
import org.jboss.netty.handler.codec.serialization.ObjectEncoder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.security.authentication.BadCredentialsException;
import vnet.sms.common.messages.GsmPdu;
import vnet.sms.common.messages.LoginRequest;
import vnet.sms.common.messages.LoginResponse;
import vnet.sms.common.messages.PingRequest;
import vnet.sms.common.messages.PingResponse;
import vnet.sms.gateway.transports.serialization.ReferenceableMessageContainer;
/**
 * Minimal Netty 3 client used by integration tests to talk to a gateway
 * server.
 * <p>
 * Outgoing {@link GsmPdu}s are wrapped in a
 * {@link ReferenceableMessageContainer} and serialized with Netty's
 * {@link ObjectEncoder}; incoming objects are deserialized with an
 * {@link ObjectDecoder}. Tests observe inbound traffic by installing a
 * {@link MessageEventListener}; at most one such listener is installed at a
 * time (installing a new one replaces the previous one).
 * <p>
 * Typical usage: {@link #connect()}, optionally
 * {@link #login(int, String, String)}, send messages, then
 * {@link #disconnect()}.
 */
public class IntegrationTestClient {

	private final Logger log = LoggerFactory.getLogger(getClass());

	private final InetSocketAddress serverAddress;

	private ClientBootstrap bootstrap;

	private Channel serverConnection;

	/**
	 * Creates a client for the given server. No connection is opened until
	 * {@link #connect()} is called.
	 * 
	 * @param host
	 *            host name or IP address of the server
	 * @param port
	 *            TCP port of the server
	 */
	public IntegrationTestClient(final String host, final int port) {
		this.serverAddress = new InetSocketAddress(host, port);
	}

	/**
	 * Connects to the server without automatically answering
	 * {@link PingRequest}s.
	 * 
	 * @throws Exception
	 *             if the connection attempt fails
	 */
	public void connect() throws Exception {
		connect(false);
	}

	/**
	 * Connects to the server, blocking until the connection attempt has
	 * completed.
	 * 
	 * @param respondToPing
	 *            if {@code true}, every incoming {@link PingRequest} is
	 *            automatically answered with a {@link PingResponse}
	 * @throws RuntimeException
	 *             wrapping the cause if the connection attempt fails; the
	 *             bootstrap's external resources are released in that case
	 */
	public void connect(final boolean respondToPing) throws Exception {
		this.log.info("Connecting to {} ...", this.serverAddress);
		this.bootstrap = new ClientBootstrap(new NioClientSocketChannelFactory(
				Executors.newCachedThreadPool(),
				Executors.newCachedThreadPool()));
		this.bootstrap
				.setPipelineFactory(new IntegrationTestClientChannelPipelineFactory(
						respondToPing));
		final ChannelFuture channelConnected = this.bootstrap
				.connect(this.serverAddress);
		this.serverConnection = channelConnected.awaitUninterruptibly()
				.getChannel();
		if (!channelConnected.isSuccess()) {
			this.log.error("Failed to connect to " + this.serverAddress + ": "
					+ channelConnected.getCause().getMessage(),
					channelConnected.getCause());
			this.bootstrap.releaseExternalResources();
			this.bootstrap = null;
			throw new RuntimeException(channelConnected.getCause());
		}
		this.log.info("Connected to {}", this.serverAddress);
	}

	/**
	 * Sends a message and blocks until the write has completed. No response
	 * listener is installed.
	 * 
	 * @param messageReference
	 *            reference used to wrap the message
	 * @param gsmPdu
	 *            the message to send
	 * @throws Throwable
	 *             the write failure cause, if the write did not succeed
	 */
	public void sendMessage(final int messageReference, final GsmPdu gsmPdu)
			throws Throwable {
		sendMessage(messageReference, gsmPdu, null);
	}

	/**
	 * Sends a message and blocks until the write has completed.
	 * <p>
	 * Must not be called from a Netty I/O thread (e.g. from inside a channel
	 * handler), since it waits on the write future.
	 * 
	 * @param messageReference
	 *            reference used to wrap the message
	 * @param gsmPdu
	 *            the message to send
	 * @param responseListener
	 *            if non-null, installed (replacing any previous listener)
	 *            before the message is written
	 * @throws Throwable
	 *             the write failure cause, if the write did not succeed
	 * @throws IllegalStateException
	 *             if {@link #connect()} has not been called
	 */
	public void sendMessage(final int messageReference, final GsmPdu gsmPdu,
			final MessageEventListener responseListener) throws Throwable {
		this.log.debug("Sending message {} to {} via channel {}...",
				new Object[] { gsmPdu, this.serverAddress,
						this.serverConnection });
		maybeInstallMessageListener(responseListener);
		final ChannelFuture writeCompleted = getMandatoryServerConnection()
				.write(ReferenceableMessageContainer.wrap(messageReference,
						gsmPdu));
		writeCompleted.awaitUninterruptibly();
		if (!writeCompleted.isSuccess()) {
			this.log.error("Failed to send " + gsmPdu + ": "
					+ writeCompleted.getCause().getMessage(),
					writeCompleted.getCause());
			throw writeCompleted.getCause();
		}
		this.log.debug("Successfully sent message {} to {} via channel {}",
				new Object[] { gsmPdu, this.serverAddress,
						this.serverConnection });
	}

	/** Installs {@code messageListener} unless it is {@code null}. */
	private void maybeInstallMessageListener(
			final MessageEventListener messageListener) {
		if (messageListener != null) {
			installMessageListener(messageListener);
		}
	}

	/**
	 * Appends a {@link ResponseListenerChannelHandler} wrapping
	 * {@code messageListener} to the pipeline, first removing any previously
	 * installed one so that only a single listener is active.
	 */
	private void installMessageListener(
			final MessageEventListener messageListener) {
		if (getMandatoryServerConnection().getPipeline().get(
				ResponseListenerChannelHandler.NAME) != null) {
			getMandatoryServerConnection().getPipeline().remove(
					ResponseListenerChannelHandler.NAME);
		}
		getMandatoryServerConnection().getPipeline().addLast(
				ResponseListenerChannelHandler.NAME,
				new ResponseListenerChannelHandler(messageListener));
	}

	/**
	 * Sends a message and blocks (without timeout) until the next inbound
	 * message event arrives, which is treated as the response.
	 * 
	 * @return the first message received after installing the listener, cast
	 *         to {@link ReferenceableMessageContainer}
	 * @throws Throwable
	 *             if sending fails or waiting is interrupted
	 */
	public ReferenceableMessageContainer sendMessageAndWaitForResponse(
			final int messageReference, final GsmPdu gsmPdu) throws Throwable {
		final CountDownLatch responseReceived = new CountDownLatch(1);
		final AtomicReference<MessageEvent> receivedResponse = new AtomicReference<MessageEvent>();
		final MessageEventListener responseListener = new MessageEventListener() {
			@Override
			public void messageEventReceived(final MessageEvent e) {
				receivedResponse.set(e);
				responseReceived.countDown();
			}
		};
		sendMessage(messageReference, gsmPdu, responseListener);
		this.log.debug("Waiting for response to message {} sent to {}", gsmPdu,
				this.serverAddress);
		responseReceived.await();
		this.log.debug("Received response {} to message {}",
				receivedResponse.get(), gsmPdu);
		return ReferenceableMessageContainer.class.cast(receivedResponse.get()
				.getMessage());
	}

	/**
	 * Sends a message and blocks (without timeout) until an inbound message
	 * event satisfying {@code messageEventPredicate} arrives.
	 * 
	 * @return the message of the matching event, cast to
	 *         {@link ReferenceableMessageContainer}; the latest match is kept
	 *         if several arrive before this method reads it
	 * @throws Throwable
	 *             if sending fails or waiting is interrupted
	 */
	public ReferenceableMessageContainer sendMessageAndWaitForMatchingResponse(
			final int messageReference, final GsmPdu gsmPdu,
			final MessageEventPredicate messageEventPredicate) throws Throwable {
		final CountDownLatch responseReceived = new CountDownLatch(1);
		final AtomicReference<MessageEvent> receivedResponse = new AtomicReference<MessageEvent>();
		final MessageEventListener responseListener = new MessageEventListener() {
			@Override
			public void messageEventReceived(final MessageEvent e) {
				if (messageEventPredicate.evaluate(e)) {
					receivedResponse.set(e);
					responseReceived.countDown();
				}
			}
		};
		sendMessage(messageReference, gsmPdu, responseListener);
		this.log.debug("Waiting for response to message {} sent to {}", gsmPdu,
				this.serverAddress);
		responseReceived.await();
		this.log.debug("Received response {} to message {}",
				receivedResponse.get(), gsmPdu);
		return ReferenceableMessageContainer.class.cast(receivedResponse.get()
				.getMessage());
	}

	/**
	 * Installs {@code messageListener} to receive all inbound message events,
	 * replacing any previously installed listener.
	 */
	public void listen(final MessageEventListener messageListener) {
		installMessageListener(messageListener);
	}

	/**
	 * Installs a listener (replacing any previous one) that counts down the
	 * returned latch once an inbound message event satisfies
	 * {@code messageEventPredicate}.
	 * 
	 * @return a latch with an initial count of 1
	 */
	public CountDownLatch listen(
			final MessageEventPredicate messageEventPredicate) {
		final CountDownLatch matchingMessageEventReceived = new CountDownLatch(
				1);
		final MessageEventListener listenForMatchingEvent = new MessageEventListener() {
			@Override
			public void messageEventReceived(final MessageEvent e) {
				if (messageEventPredicate.evaluate(e)) {
					matchingMessageEventReceived.countDown();
				}
			}
		};
		installMessageListener(listenForMatchingEvent);
		return matchingMessageEventReceived;
	}

	/**
	 * Sends a {@link LoginRequest} and waits for the response.
	 * 
	 * @throws RuntimeException
	 *             if the response is not a {@link LoginResponse}
	 * @throws BadCredentialsException
	 *             if the {@link LoginResponse} reports failure
	 * @throws Throwable
	 *             if sending fails or waiting is interrupted
	 */
	public void login(final int messageReference, final String username,
			final String password) throws Throwable {
		final LoginRequest loginRequest = new LoginRequest(username, password);
		final ReferenceableMessageContainer loginResponseContainer = sendMessageAndWaitForResponse(
				messageReference, loginRequest);
		final GsmPdu response = loginResponseContainer.getMessage();
		if (!(response instanceof LoginResponse)) {
			throw new RuntimeException("Unexpected response to " + loginRequest
					+ ": " + response);
		}
		final LoginResponse loginResponse = LoginResponse.class.cast(response);
		if (!loginResponse.loginSucceeded()) {
			throw new BadCredentialsException(
					"Failed to login using username = " + username
							+ " and password = " + password);
		}
	}

	/**
	 * Disconnects from the server, blocking until the disconnect has
	 * completed, and then releases the bootstrap's external resources.
	 * 
	 * @throws IllegalStateException
	 *             if {@link #connect()} has not been called
	 * @throws RuntimeException
	 *             wrapping the cause if the disconnect fails
	 */
	public void disconnect() throws Exception {
		this.log.info("Disconnecting from {} ...", this.serverAddress);
		final ChannelFuture channelDisconnected = getMandatoryServerConnection()
				.disconnect();
		// isSuccess()/getCause() are only meaningful once the future is done,
		// and the bootstrap's resources should not be released while the
		// disconnect is still pending.
		channelDisconnected.awaitUninterruptibly();
		this.bootstrap.releaseExternalResources();
		this.bootstrap = null;
		if (!channelDisconnected.isSuccess()) {
			this.log.error("Failed to disconnect from " + this.serverAddress
					+ ": " + channelDisconnected.getCause().getMessage(),
					channelDisconnected.getCause());
			throw new RuntimeException(channelDisconnected.getCause());
		}
		this.log.info("Disconnected from {}", this.serverAddress);
	}

	/**
	 * @return the current server channel
	 * @throws IllegalStateException
	 *             if {@link #connect()} has never been called
	 */
	private Channel getMandatoryServerConnection() {
		if (this.serverConnection == null) {
			throw new IllegalStateException(
					"No server connection - did you remember to call connect()?");
		}
		return this.serverConnection;
	}

	/**
	 * Builds the client pipeline: object encoder/decoder, plus an optional
	 * automatic ping responder.
	 */
	private final class IntegrationTestClientChannelPipelineFactory implements
			ChannelPipelineFactory {

		private final boolean respondToPing;

		private final PingResponseChannelHandler pingResponseHandler = new PingResponseChannelHandler();

		IntegrationTestClientChannelPipelineFactory(final boolean respondToPing) {
			this.respondToPing = respondToPing;
		}

		@Override
		public ChannelPipeline getPipeline() throws Exception {
			final ChannelPipeline pipeline = Channels.pipeline();
			pipeline.addLast("encoder", new ObjectEncoder());
			pipeline.addLast("decoder",
					new ObjectDecoder(ClassResolvers.cacheDisabled(null)));
			if (this.respondToPing) {
				pipeline.addLast(PingResponseChannelHandler.NAME,
						this.pingResponseHandler);
			}
			return pipeline;
		}
	}

	/**
	 * Answers every inbound {@link PingRequest} (wrapped in a
	 * {@link ReferenceableMessageContainer}) with a {@link PingResponse}; all
	 * other messages are passed upstream unchanged.
	 */
	private final class PingResponseChannelHandler extends
			SimpleChannelUpstreamHandler {

		static final String NAME = "itest:ping-response";

		// Message references for ping responses start above this value.
		private final AtomicInteger nextMessageRef = new AtomicInteger(
				10000000);

		@Override
		public void messageReceived(final ChannelHandlerContext ctx,
				final MessageEvent e) throws Exception {
			final Object message = e.getMessage();
			final boolean isPingRequest = ReferenceableMessageContainer.class
					.isInstance(message)
					&& PingRequest.class
							.isInstance(ReferenceableMessageContainer.class
									.cast(message).getMessage());
			if (!isPingRequest) {
				super.messageReceived(ctx, e);
				return;
			}
			final PingRequest pingRequest = PingRequest.class
					.cast(ReferenceableMessageContainer.class.cast(message)
							.getMessage());
			final PingResponse pingResponse = PingResponse.accept(pingRequest);
			// This method runs on a Netty I/O thread, so it must not block on
			// the write future (as sendMessage(...) does): Netty's ChannelFuture
			// documentation advises against await() inside a ChannelHandler.
			// The response is therefore written without waiting; write
			// failures are not observed here.
			e.getChannel().write(
					ReferenceableMessageContainer.wrap(
							this.nextMessageRef.incrementAndGet(),
							pingResponse));
			IntegrationTestClient.this.log.debug("Sent {} in response to {}",
					pingResponse, message);
		}
	}

	/**
	 * Forwards every inbound message event to a {@link MessageEventListener}
	 * and then passes it further upstream.
	 */
	private final class ResponseListenerChannelHandler extends
			SimpleChannelUpstreamHandler {

		static final String NAME = "itest:response-listener";

		private final MessageEventListener listener;

		ResponseListenerChannelHandler(final MessageEventListener listener) {
			this.listener = listener;
		}

		@Override
		public void messageReceived(final ChannelHandlerContext ctx,
				final MessageEvent e) throws Exception {
			IntegrationTestClient.this.log.info("Received response {}", e);
			this.listener.messageEventReceived(e);
			super.messageReceived(ctx, e);
		}
	}
}