package vnet.sms.gateway.server.framework.test;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
import org.jboss.netty.bootstrap.ClientBootstrap;
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelFuture;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.ChannelPipeline;
import org.jboss.netty.channel.ChannelPipelineFactory;
import org.jboss.netty.channel.Channels;
import org.jboss.netty.channel.MessageEvent;
import org.jboss.netty.channel.SimpleChannelUpstreamHandler;
import org.jboss.netty.channel.local.DefaultLocalClientChannelFactory;
import org.jboss.netty.channel.local.LocalAddress;
import org.jboss.netty.handler.codec.serialization.ClassResolvers;
import org.jboss.netty.handler.codec.serialization.ObjectDecoder;
import org.jboss.netty.handler.codec.serialization.ObjectEncoder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.security.authentication.BadCredentialsException;
import vnet.sms.common.messages.GsmPdu;
import vnet.sms.common.messages.LoginRequest;
import vnet.sms.common.messages.LoginResponse;
import vnet.sms.gateway.transports.serialization.ReferenceableMessageContainer;
public class LocalClient {
public interface MessageListener {
void messageReceived(MessageEvent e);
}
private final Logger log = LoggerFactory.getLogger(getClass());
private final LocalAddress serverAddress;
private ClientBootstrap bootstrap;
private Channel serverConnection;
/**
* @param serverAddress
*/
public LocalClient(final LocalAddress serverAddress) {
this.serverAddress = serverAddress;
}
public void connect() throws Throwable {
this.log.info("Connecting to {} ...", this.serverAddress);
this.bootstrap = new ClientBootstrap(
new DefaultLocalClientChannelFactory());
this.bootstrap
.setPipelineFactory(new LocalClientChannelPipelineFactory());
final ChannelFuture channelConnected = this.bootstrap
.connect(this.serverAddress);
this.serverConnection = channelConnected.awaitUninterruptibly()
.getChannel();
if (!channelConnected.isSuccess()) {
this.log.error("Failed to connect to " + this.serverAddress + ": "
+ channelConnected.getCause().getMessage(),
channelConnected.getCause());
this.bootstrap.releaseExternalResources();
this.bootstrap = null;
throw channelConnected.getCause();
}
this.log.info("Connected to {}", this.serverAddress);
}
public void sendMessage(final int messageReference, final GsmPdu gsmPdu)
throws Throwable {
sendMessage(messageReference, gsmPdu, null);
}
public void sendMessage(final int messageReference, final GsmPdu gsmPdu,
final MessageListener responseListener) throws Throwable {
this.log.debug("Sending message {} to {} ...", gsmPdu,
this.serverAddress);
maybeInstallMessageListener(responseListener);
final ChannelFuture writeCompleted = getMandatoryServerConnection()
.write(ReferenceableMessageContainer.wrap(messageReference,
gsmPdu));
writeCompleted.awaitUninterruptibly();
if (!writeCompleted.isSuccess()) {
this.log.error("Failed to send " + gsmPdu + ": "
+ writeCompleted.getCause().getMessage(),
writeCompleted.getCause());
throw writeCompleted.getCause();
}
this.log.debug("Successfully sent message {} to {}", gsmPdu,
this.serverAddress);
}
private void maybeInstallMessageListener(
final MessageListener messageListener) {
if (messageListener != null) {
installMessageListener(messageListener);
}
}
private void installMessageListener(final MessageListener messageListener) {
if (getMandatoryServerConnection().getPipeline().get(
ResponseListenerChannelHandler.NAME) != null) {
getMandatoryServerConnection().getPipeline().remove(
ResponseListenerChannelHandler.NAME);
}
getMandatoryServerConnection().getPipeline().addLast(
ResponseListenerChannelHandler.NAME,
new ResponseListenerChannelHandler(messageListener));
}
public ReferenceableMessageContainer sendMessageAndWaitForResponse(
final int messageReference, final GsmPdu gsmPdu) throws Throwable {
final CountDownLatch responseReceived = new CountDownLatch(1);
final AtomicReference<MessageEvent> receivedResponse = new AtomicReference<MessageEvent>();
final MessageListener responseListener = new MessageListener() {
@Override
public void messageReceived(final MessageEvent e) {
receivedResponse.set(e);
responseReceived.countDown();
}
};
sendMessage(messageReference, gsmPdu, responseListener);
responseReceived.await();
return ReferenceableMessageContainer.class.cast(receivedResponse.get()
.getMessage());
}
public void listen(final MessageListener messageListener) {
installMessageListener(messageListener);
}
public void login(final int messageReference, final String username,
final String password) throws Throwable {
final LoginRequest loginRequest = new LoginRequest(username, password);
final ReferenceableMessageContainer loginResponseContainer = sendMessageAndWaitForResponse(
messageReference, loginRequest);
final GsmPdu response = loginResponseContainer.getMessage();
if (!(response instanceof LoginResponse)) {
throw new RuntimeException("Unexpected response to " + loginRequest
+ ": " + response);
}
final LoginResponse loginResponse = LoginResponse.class.cast(response);
if (!loginResponse.loginSucceeded()) {
throw new BadCredentialsException(
"Failed to login using username = " + username
+ " and password = " + password);
}
}
public void disconnect() throws Throwable {
this.log.info("Disconnecting from {} ...", this.serverAddress);
final ChannelFuture channelDisconnected = getMandatoryServerConnection()
.disconnect();
this.bootstrap.releaseExternalResources();
this.bootstrap = null;
if (!channelDisconnected.isSuccess()) {
this.log.error("Failed to disconnect from " + this.serverAddress
+ ": " + channelDisconnected.getCause().getMessage(),
channelDisconnected.getCause());
throw channelDisconnected.getCause();
}
this.log.info("Disconnected from {}", this.serverAddress);
}
private Channel getMandatoryServerConnection() {
if (this.serverConnection == null) {
throw new IllegalStateException(
"No server connection - did you remember to call connect()?");
}
return this.serverConnection;
}
private final class LocalClientChannelPipelineFactory implements
ChannelPipelineFactory {
@Override
public ChannelPipeline getPipeline() throws Exception {
final ChannelPipeline pipeline = Channels.pipeline();
pipeline.addLast("encoder", new ObjectEncoder());
pipeline.addLast("decoder",
new ObjectDecoder(ClassResolvers.cacheDisabled(null)));
return pipeline;
}
}
private final class ResponseListenerChannelHandler extends
SimpleChannelUpstreamHandler {
public static final String NAME = "test:response-listener";
private final MessageListener listener;
ResponseListenerChannelHandler(final MessageListener listener) {
this.listener = listener;
}
@Override
public void messageReceived(final ChannelHandlerContext ctx,
final MessageEvent e) throws Exception {
LocalClient.this.log.info("Received response {}", e);
this.listener.messageReceived(e);
super.messageReceived(ctx, e);
}
}
}