/* * TeleStax, Open Source Cloud Communications * Copyright 2011-2016, TeleStax Inc. and individual contributors * by the @authors tag. * * This program is free software: you can redistribute it and/or modify * under the terms of the GNU Affero General Public License as * published by the Free Software Foundation; either version 3 of * the License, or (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see <http://www.gnu.org/licenses/> * * This file incorporates work covered by the following copyright and * permission notice: * * JBoss, Home of Professional Open Source * Copyright 2007-2011, Red Hat, Inc. and individual contributors * by the @authors tag. See the copyright.txt in the distribution for a * full listing of individual contributors. * * This is free software; you can redistribute it and/or modify it * under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2.1 of * the License, or (at your option) any later version. * * This software is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with this software; if not, write to the Free * Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA * 02110-1301 USA, or see the FSF site: http://www.fsf.org. */ package org.jdiameter.client.impl.transport.tls; import static org.jdiameter.client.impl.helpers.Parameters.CipherSuites; import static org.jdiameter.client.impl.helpers.Parameters.SDEnableSessionCreation; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.net.InetSocketAddress; import java.net.Socket; import java.nio.BufferOverflowException; import java.nio.ByteBuffer; import java.nio.channels.AsynchronousCloseException; import java.nio.channels.ClosedByInterruptException; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; import javax.net.ssl.HandshakeCompletedEvent; import javax.net.ssl.HandshakeCompletedListener; import javax.net.ssl.SSLSocket; import javax.net.ssl.SSLSocketFactory; import org.jdiameter.api.Avp; import org.jdiameter.api.AvpDataException; import org.jdiameter.api.AvpSet; import org.jdiameter.api.Message; import org.jdiameter.client.api.IMessage; import org.jdiameter.client.api.io.NotInitializedException; import org.jdiameter.client.api.parser.IMessageParser; import org.jdiameter.client.api.parser.ParseException; import org.jdiameter.common.api.concurrent.IConcurrentFactory; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * * @author <a href="mailto:baranowb@gmail.com"> Bartosz Baranowski </a> * @author <a href="mailto:brainslog@gmail.com"> Alexandre Mendonca </a> */ public class TLSTransportClient { // NOTE: SSL Does not provide channels, need to do plain old sync R/W :/ // So SSLSocket.getChannel() returns NULL! private static final Logger logger = LoggerFactory.getLogger(TLSTransportClient.class); private TLSClientConnection parentConnection; private IConcurrentFactory concurrentFactory; private boolean stop = false; // flag to indicate that initial shake did happen private boolean shaken; // flag indicating that SSL handshake is going on, while this is set to true, no messages can be exchanged. private boolean shaking; private Thread readThread; private InetSocketAddress destAddress; private InetSocketAddress origAddress; private String socketDescription = null; // sync streams to get data. private InputStream inputStream; private OutputStream outputStream; //private SSLSocket sslSocket; private Socket plainSocket; public static final int DEFAULT_BUFFER_SIZE = 4096; public static final int DEFAULT_STORAGE_SIZE = 4096; private int bufferSize = DEFAULT_BUFFER_SIZE; private ByteBuffer buffer = ByteBuffer.allocate(this.bufferSize); private int storageSize = DEFAULT_STORAGE_SIZE; private ByteBuffer storage = ByteBuffer.allocate(storageSize); private Lock lock = new ReentrantLock(); private IMessageParser parser; private final DiameterSSLHandshakeListener handshakeListener = new DiameterSSLHandshakeListener(); private final ReadTask readTash = new ReadTask(); //tell weather we are in a client mode private boolean client; private boolean receivedInband; /** * Default constructor * * @param parenConnection * connection created this transport */ public TLSTransportClient(TLSClientConnection parenConnection, IConcurrentFactory concurrentFactory, IMessageParser parser) { this.parentConnection = parenConnection; this.concurrentFactory = concurrentFactory; this.parser = parser; } public void initialize() throws IOException, NotInitializedException { if (destAddress == null) { throw new NotInitializedException("Destination address is not set"); } this.client = true; // SSLSocketFactory cltFct = parentConnection.getSSLFactory(); // this.sslSocket = (SSLSocket) cltFct.createSocket(); // // this.sslSocket.setEnableSessionCreation(parentConnection.getSSLConfig().getBooleanValue(SDEnableSessionCreation.ordinal(), true)); // this.sslSocket.setUseClientMode(true); // if (parentConnection.getSSLConfig().getStringValue(CipherSuites.ordinal(), null) != null) { // this.sslSocket.setEnabledCipherSuites(parentConnection.getSSLConfig().getStringValue(CipherSuites.ordinal(), null).split(",")); // } // // if (this.origAddress != null) { // this.sslSocket.bind(this.origAddress); // } // this.sslSocket.connect(this.destAddress); // // // now lets get streams. // this.sslInputStream = this.sslSocket.getInputStream(); // this.sslOutputStream = this.sslSocket.getOutputStream(); this.plainSocket = new Socket(); if (this.origAddress != null) { this.plainSocket.bind(this.origAddress); } this.plainSocket.connect(this.destAddress); this.inputStream = this.plainSocket.getInputStream(); this.outputStream = this.plainSocket.getOutputStream(); // now, we need to notify parent, this will START CER/CEA exchange // on CEA 2xxx we can enable TLS parentConnection.onConnected(); } public void initialize(Socket socket) throws IOException, NotInitializedException { logger.debug("Initialising TLSTransportClient for a socket on [{}]", socket); this.client = false; this.plainSocket = socket; this.socketDescription = socket.toString(); this.destAddress = new InetSocketAddress(socket.getInetAddress(), socket.getPort()); this.inputStream = this.plainSocket.getInputStream(); this.outputStream = this.plainSocket.getOutputStream(); } public void start() throws NotInitializedException { // for client if (this.socketDescription == null) { this.socketDescription = this.plainSocket.toString(); } logger.debug("Starting transport. Socket is {}", socketDescription); if (!this.plainSocket.isConnected()) { throw new NotInitializedException("Socket is not connected"); } if (getParent() == null) { throw new NotInitializedException("No parent connection is set is set"); } if (this.readThread == null || !this.readThread.isAlive()) { this.readThread = this.concurrentFactory.getThread("TLSReader", this.readTash); } if (!this.readThread.isAlive()) { this.readThread.setDaemon(true); this.readThread.start(); } } // ---------------- getters & setters --------------------- public TLSClientConnection getParent() { return parentConnection; } public InetSocketAddress getDestAddress() { return this.destAddress; } public void setDestAddress(InetSocketAddress address) { this.destAddress = address; if (logger.isDebugEnabled()) { logger.debug("Destination address is set to [{}] : [{}]", destAddress.getHostName(), destAddress.getPort()); } } public void setOrigAddress(InetSocketAddress address) { this.origAddress = address; if (logger.isDebugEnabled()) { logger.debug("Origin address is set to [{}] : [{}]", origAddress.getHostName(), origAddress.getPort()); } } public InetSocketAddress getOrigAddress() { return this.origAddress; } // ---------------- helper methods --------------------- void sendMessage(IMessage message) throws IOException, AvpDataException, NotInitializedException, ParseException { if (!isConnected()) { throw new IOException("Failed to send message over [" + socketDescription + "]"); } //switch to wait for SSL handshake to workout. if (!isExchangeAllowed()) { //TODO: do more? return; } doTLSPreSendProcessing(message); final ByteBuffer messageBuffer = this.parser.encodeMessage(message); if (logger.isDebugEnabled()) { logger.debug("About to send a byte buffer of size [{}] over the TLS socket [{}]", messageBuffer.array().length, socketDescription); } lock.lock(); try { this.outputStream.write(messageBuffer.array(), messageBuffer.position(), messageBuffer.limit()); doTLSPostSendProcessing(message); } catch (Exception e) { logger.debug("Unable to send message", e); throw new IOException("Error while sending message: " + e); } finally { lock.unlock(); } if (logger.isDebugEnabled()) { logger.debug("Sent a byte buffer of size [{}] over the TLS nio socket [{}]", messageBuffer.array().length, socketDescription); } } boolean isConnected() { return this.plainSocket != null && this.plainSocket.isConnected(); } void stop() throws Exception { logger.debug("Stopping transport. Socket is [{}]", socketDescription); stop = true; if (plainSocket != null && !plainSocket.isClosed()) { plainSocket.close(); } if (this.readThread != null) { this.readThread.join(100); } clearBuffer(); logger.debug("Transport is stopped. Socket is [{}]", socketDescription); } public void release() throws Exception { stop(); destAddress = null; } void append(byte[] data) { if (storage.position() + data.length >= storage.capacity()) { ByteBuffer tmp = ByteBuffer.allocate(storage.limit() + data.length * 2); byte[] tmpData = new byte[storage.position()]; storage.flip(); storage.get(tmpData); tmp.put(tmpData); storage = tmp; logger.warn("Increase storage size. Current size is {}", storage.array().length); } try { storage.put(data); } catch (BufferOverflowException boe) { logger.error("Buffer overflow occured", boe); } boolean messageReseived; do { messageReseived = seekMessage(storage); } while (messageReseived); } private boolean isExchangeAllowed() { this.lock.lock(); try { return !this.shaking; } finally { this.lock.unlock(); } } private boolean isSuccess(IMessage message) throws AvpDataException { Avp resultAvp = message.getResultCode(); if (resultAvp == null) { resultAvp = message.getAvps().getAvp(Avp.EXPERIMENTAL_RESULT); if (resultAvp == null) { // bad message, ignore if (logger.isDebugEnabled()) { logger.debug("Discarding message since SSL handshake has not been performed on [{}], dropped message [{}]. No result type avp.", socketDescription, message); } // TODO: anything else? return false; } resultAvp = resultAvp.getGrouped().getAvp(Avp.EXPERIMENTAL_RESULT_CODE); if (resultAvp == null) { // bad message, ignore if (logger.isDebugEnabled()) { logger.debug("Discarding message since SSL handshake has not been performed on [{}], dropped message [{}]. No result avp.", socketDescription, message); } } } long resultCode = resultAvp.getUnsigned32(); return resultCode >= 2000 && resultCode < 3000; } private boolean seekMessage(ByteBuffer localStorage) { if (storage.position() == 0) { return false; } storage.flip(); int tmp = localStorage.getInt(); localStorage.position(0); byte vers = (byte) (tmp >> 24); if (vers != 1) { return false; } int dataLength = (tmp & 0xFFFFFF); if (localStorage.limit() < dataLength) { localStorage.position(localStorage.limit()); localStorage.limit(localStorage.capacity()); return false; } byte[] data = new byte[dataLength]; localStorage.get(data); localStorage.position(dataLength); localStorage.compact(); ByteBuffer messageBuffer = ByteBuffer.wrap(data); try { if (logger.isDebugEnabled()) { logger.debug("Received message of size [{}]", data.length); } IMessage message = this.parser.createMessage(messageBuffer); // check if if (isExchangeAllowed()) { doTLSPreReceiveProcessing(message); getParent().onMessageReceived(message); } } catch (Exception e) { logger.debug("Garbage was received. Discarding."); storage.clear(); // not a best way. getParent().onAvpDataException(new AvpDataException(e)); } return true; } /** * @param message * @throws AvpDataException * @throws NotInitializedException */ private void doTLSPreReceiveProcessing(IMessage message) throws AvpDataException, NotInitializedException { if (this.shaken) { return; } if (this.client) { // if (CEA && message.isSuccess && message.has(inband)) { // startTLS(); // } if (message.isRequest()) { return; } if (message.getCommandCode() == Message.CAPABILITIES_EXCHANGE_ANSWER && isSuccess(message)) { AvpSet set = message.getAvps(); Avp inbandAvp = set.getAvp(Avp.INBAND_SECURITY_ID); if (inbandAvp != null && inbandAvp.getUnsigned32() == 1) { startTLS(); } } } else { // if (CER && message.has(inband)) { // this.receveidInband = true; // } if (!message.isRequest()) { return; } AvpSet set = message.getAvps(); Avp inbandAvp = set.getAvp(Avp.INBAND_SECURITY_ID); if (inbandAvp != null && inbandAvp.getUnsigned32() == 1) { this.receivedInband = true; } } } /** * @param message */ private void doTLSPreSendProcessing(IMessage message) { if (message.getCommandCode() == Message.CAPABILITIES_EXCHANGE_REQUEST) { AvpSet set = message.getAvps(); set.removeAvp(Avp.INBAND_SECURITY_ID); set.addAvp(Avp.INBAND_SECURITY_ID, 1); } } /** * @param message * @throws AvpDataException * @throws NotInitializedException */ private void doTLSPostSendProcessing(IMessage message) throws AvpDataException, NotInitializedException { // if ( !client && !shaken && CEA && message.isSuccess() && receivedInband) { // startTLS; // } if (this.shaken || this.client || this.plainSocket instanceof SSLSocket || message.isRequest() || message.getCommandCode() != Message.CAPABILITIES_EXCHANGE_ANSWER) { return; } if (this.receivedInband && isSuccess(message)) { this.receivedInband = false; startTLS(); } } /** * @throws NotInitializedException * */ private void startTLS() throws NotInitializedException { try { this.shaking = true; SSLSocketFactory cltFct = parentConnection.getSSLFactory(); SSLSocket sslSocket = (SSLSocket) cltFct.createSocket(this.plainSocket, null, this.plainSocket.getPort(), false); sslSocket.setEnableSessionCreation(parentConnection.getSSLConfig().getBooleanValue( SDEnableSessionCreation.ordinal(), true)); // only clients start shake if (parentConnection.getSSLConfig().getStringValue(CipherSuites.ordinal(), null) != null) { sslSocket.setEnabledCipherSuites(parentConnection.getSSLConfig().getStringValue(CipherSuites.ordinal(), null) .split(",")); } this.inputStream = sslSocket.getInputStream(); this.outputStream = sslSocket.getOutputStream(); this.plainSocket = sslSocket; if (this.client) { sslSocket.setUseClientMode(true); // TODO: catch this to check for failure sslSocket.addHandshakeCompletedListener(this.handshakeListener); sslSocket.startHandshake(); } else { sslSocket.addHandshakeCompletedListener(this.handshakeListener); sslSocket.setUseClientMode(false); } } catch (Exception e) { // TODO: ensure close? throw new NotInitializedException(e); } } private void clearBuffer() throws IOException { bufferSize = DEFAULT_BUFFER_SIZE; buffer = ByteBuffer.allocate(bufferSize); } // ---------------- helper classes --------------------- private class DiameterSSLHandshakeListener implements HandshakeCompletedListener { @Override public void handshakeCompleted(HandshakeCompletedEvent event) { // connected comes from here! try { lock.lock(); shaking = false; shaken = true; ((SSLSocket) plainSocket).removeHandshakeCompletedListener(this); getParent().onConnected(); } finally { lock.unlock(); } } } private class ReadTask implements Runnable { @Override public void run() { logger.debug("Transport is started. Socket is [{}]", socketDescription); try { while (!stop) { int dataLength = inputStream.read(buffer.array()); logger.debug("Just read [{}] bytes on [{}]", dataLength, socketDescription); if (dataLength == -1) { break; } buffer.position(dataLength); buffer.flip(); byte[] data = new byte[buffer.limit()]; buffer.get(data); append(data); buffer.clear(); } } catch (ClosedByInterruptException e) { logger.debug("Transport exception ", e); } catch (AsynchronousCloseException e) { logger.debug("Transport exception ", e); } catch (Throwable e) { logger.debug("Transport exception ", e); } finally { try { clearBuffer(); if (plainSocket != null && !plainSocket.isClosed()) { plainSocket.close(); } getParent().onDisconnect(); } catch (Exception e) { logger.debug("Error", e); } stop = false; logger.info("Read thread is stopped for socket [{}]", socketDescription); } } } }