/* * TeleStax, Open Source Cloud Communications * Copyright 2011-2016, TeleStax Inc. and individual contributors * by the @authors tag. * * This program is free software: you can redistribute it and/or modify * under the terms of the GNU Affero General Public License as * published by the Free Software Foundation; either version 3 of * the License, or (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see <http://www.gnu.org/licenses/> */ package org.jdiameter.client.impl.transport.tls.netty; import java.io.IOException; import java.net.InetSocketAddress; import org.jdiameter.api.Configuration; import org.jdiameter.client.api.IMessage; import org.jdiameter.client.api.parser.IMessageParser; import org.jdiameter.common.api.concurrent.IConcurrentFactory; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import io.netty.bootstrap.Bootstrap; import io.netty.channel.Channel; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelPipeline; import io.netty.channel.EventLoopGroup; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; /** * * @author <a href="mailto:jqayyum@gmail.com"> Jehanzeb Qayyum </a> */ public class TLSTransportClient { private static final Logger logger = LoggerFactory.getLogger(TLSTransportClient.class); private TLSClientConnection parentConnection; private IConcurrentFactory concurrentFactory; private IMessageParser parser; private Configuration config; private InetSocketAddress destAddress; private InetSocketAddress origAddress; private String socketDescription = null; private Channel channel; private EventLoopGroup workerGroup; private volatile TlsHandshakingState tlsHandshakingState = TlsHandshakingState.INIT; enum TlsHandshakingState { INIT, SHAKING, SHAKEN } protected TLSTransportClient(TLSClientConnection parenConnection, IConcurrentFactory concurrentFactory, IMessageParser parser, Configuration config) { this.parentConnection = parenConnection; this.concurrentFactory = concurrentFactory; this.parser = parser; this.config = config; } public TLSTransportClient(TLSClientConnection parenConnection, IConcurrentFactory concurrentFactory, IMessageParser parser, Configuration config, InetSocketAddress destAddress, InetSocketAddress origAddress) { this(parenConnection, concurrentFactory, parser, config); if (destAddress == null) { throw new IllegalArgumentException("Destination address is required"); } this.destAddress = destAddress; this.origAddress = origAddress; this.socketDescription = origAddress.toString() + "->" + destAddress.toString(); logger.debug("Created TLSTransportClient (client) for {}", socketDescription); } public TLSTransportClient(TLSClientConnection parenConnection, IConcurrentFactory concurrentFactory, IMessageParser parser, Configuration config, Channel channel) { this(parenConnection, concurrentFactory, parser, config); if (channel == null) { throw new IllegalArgumentException("Channel is required"); } this.channel = channel; this.origAddress = (InetSocketAddress) this.channel.localAddress(); this.destAddress = (InetSocketAddress) this.channel.remoteAddress(); this.socketDescription = origAddress.toString() + "->" + destAddress.toString(); ChannelPipeline pipeline = this.channel.pipeline(); pipeline.addLast("startTlsServerHandler", new StartTlsServerHandler(this)); pipeline.addLast("decoder", new DiameterMessageDecoder(parenConnection, parser)); pipeline.addLast("msgHandler", new DiameterMessageHandler(parentConnection, true)); pipeline.addLast("encoder", new DiameterMessageEncoder(parser)); pipeline.addLast("inbandWriter", new InbandSecurityHandler()); logger.debug("Created TLSTransportClient (server) for {}", socketDescription); } // only client side public void start() throws InterruptedException { logger.debug("Staring client TLSTransportClient {} ", socketDescription); if (isConnected()) { logger.debug("Already connected TLSTransportClient {} ", socketDescription); return; } workerGroup = new NioEventLoopGroup(); Bootstrap bootstrap = new Bootstrap(); bootstrap.group(workerGroup).channel(NioSocketChannel.class).handler(new ChannelInitializer<SocketChannel>() { @Override protected void initChannel(SocketChannel channel) throws Exception { ChannelPipeline pipeline = channel.pipeline(); pipeline.addLast("decoder", new DiameterMessageDecoder(parentConnection, parser)); pipeline.addLast("msgHandler", new DiameterMessageHandler(parentConnection, false)); pipeline.addLast("startTlsInitiator", new StartTlsInitiator(config, TLSTransportClient.this)); pipeline.addLast("encoder", new DiameterMessageEncoder(parser)); pipeline.addLast("inbandWriter", new InbandSecurityHandler()); } }); this.channel = bootstrap.remoteAddress(destAddress).connect().sync().channel(); parentConnection.onConnected(); logger.debug("Started TLS Transport on Socket {}", socketDescription); } public TLSClientConnection getParent() { return parentConnection; } public InetSocketAddress getDestAddress() { return this.destAddress; } public void setDestAddress(InetSocketAddress address) { this.destAddress = address; if (logger.isDebugEnabled()) { logger.debug("Destination address is set to [{}] : [{}]", destAddress.getHostName(), destAddress.getPort()); } } public void setOrigAddress(InetSocketAddress address) { this.origAddress = address; if (logger.isDebugEnabled()) { logger.debug("Origin address is set to [{}] : [{}]", origAddress.getHostName(), origAddress.getPort()); } } public InetSocketAddress getOrigAddress() { return this.origAddress; } void sendMessage(IMessage message) throws IOException { if (!isConnected()) { throw new IOException("Failed to send message over [" + socketDescription + "]"); } if (this.tlsHandshakingState == TlsHandshakingState.SHAKING) { return; } logger.debug("About to send a message over the TLS socket [{}]", socketDescription); channel.writeAndFlush(message); } boolean isConnected() { return this.channel != null && this.channel.isActive(); } public void stop() { //logger.debug("Stopping TLS Transport {}", socketDescription); closeChannel(); closeWorkerGroup(); //logger.debug("TLS Transport is stopped {}", socketDescription); getParent().disconnect(); } public void release() throws Exception { stop(); destAddress = null; origAddress = null; } private void closeChannel() { if (channel != null && channel.isActive()) { try { channel.closeFuture().sync(); } catch (InterruptedException e) { logger.error("Error stopping socket " + socketDescription, e); } channel = null; } } private void closeWorkerGroup() { if (workerGroup != null && !workerGroup.isShuttingDown()) { try { workerGroup.shutdownGracefully().sync(); } catch (InterruptedException e) { logger.error("Error stopping socket " + socketDescription, e); } workerGroup = null; } } public TlsHandshakingState getTlsHandshakingState() { return tlsHandshakingState; } public void setTlsHandshakingState(TlsHandshakingState tlsHandshakingState) { this.tlsHandshakingState = tlsHandshakingState; } public TLSClientConnection getParentConnection() { return parentConnection; } public IMessageParser getParser() { return parser; } public Configuration getConfig() { return config; } }