/* * TeleStax, Open Source Cloud Communications * Copyright 2011-2014, TeleStax Inc. and individual contributors * by the @authors tag. * * This program is free software: you can redistribute it and/or modify * under the terms of the GNU Affero General Public License as * published by the Free Software Foundation; either version 3 of * the License, or (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see <http://www.gnu.org/licenses/> * * This file incorporates work covered by the following copyright and * permission notice: * * JBoss, Home of Professional Open Source * Copyright 2007-2011, Red Hat, Inc. and individual contributors * by the @authors tag. See the copyright.txt in the distribution for a * full listing of individual contributors. * * This is free software; you can redistribute it and/or modify it * under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2.1 of * the License, or (at your option) any later version. * * This software is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with this software; if not, write to the Free * Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA * 02110-1301 USA, or see the FSF site: http://www.fsf.org. 
*/

package org.jdiameter.client.impl.transport.tcp;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.nio.BufferOverflowException;
import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;
import java.nio.channels.AsynchronousCloseException;
import java.nio.channels.ClosedByInterruptException;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;
import java.nio.channels.spi.SelectorProvider;
import java.util.Iterator;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

import org.jdiameter.api.AvpDataException;
import org.jdiameter.client.api.io.NotInitializedException;
import org.jdiameter.common.api.concurrent.IConcurrentFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * TCP transport for a Diameter client connection.
 *
 * <p>Owns a {@link SocketChannel}, runs a reader thread ({@link #run()}) that reads bytes from the
 * channel, reassembles them into complete Diameter messages and hands each message to the parent
 * {@link TCPClientConnection}. Outgoing messages are written with {@link #sendMessage(ByteBuffer)},
 * which serializes concurrent writers with a lock.
 *
 * @author erick.svenson@yahoo.com
 * @author <a href="mailto:baranowb@gmail.com"> Bartosz Baranowski </a>
 * @author <a href="mailto:brainslog@gmail.com"> Alexandre Mendonca </a>
 */
public class TCPTransportClient implements Runnable {

  private TCPClientConnection parentConnection;
  private IConcurrentFactory concurrentFactory;

  public static final int DEFAULT_BUFFER_SIZE = 1024;
  public static final int DEFAULT_STORAGE_SIZE = 2048;

  /**
   * Set to request the reader loop to terminate. Written by {@link #stop()} on a caller thread and
   * read by the reader thread, hence volatile so the change is guaranteed to become visible.
   */
  protected volatile boolean stop = false;
  protected Thread selfThread;

  /** Scratch buffer a single channel read goes into. */
  protected int bufferSize = DEFAULT_BUFFER_SIZE;
  protected ByteBuffer buffer = ByteBuffer.allocate(this.bufferSize);

  protected InetSocketAddress destAddress;
  protected InetSocketAddress origAddress;

  protected SocketChannel socketChannel;
  /** Serializes writers so bytes of concurrently sent messages are not interleaved. */
  protected Lock lock = new ReentrantLock();

  /** Accumulates received bytes (kept in write mode between calls) until a full message is present. */
  protected int storageSize = DEFAULT_STORAGE_SIZE;
  protected ByteBuffer storage = ByteBuffer.allocate(storageSize);

  private String socketDescription = null;

  private static final Logger logger = LoggerFactory.getLogger(TCPTransportClient.class);

  // the channel is used in non-blocking mode; reads are driven by a Selector
  private static final boolean BLOCKING_IO = false;
  private static final long SELECT_TIMEOUT = 500; // milliseconds

  /** Size in bytes of the fixed Diameter header; the Message Length field includes it. */
  private static final int DIAMETER_HEADER_LENGTH = 20;

  /** Size in bytes of the Version + Message Length word at the start of every message. */
  private static final int VERSION_AND_LENGTH_SIZE = 4;

  public TCPTransportClient() {
  }

  /**
   * Default constructor
   *
   * @param concurrentFactory factory for create threads
   * @param parenConnection connection created this transport
   */
  TCPTransportClient(IConcurrentFactory concurrentFactory, TCPClientConnection parenConnection) {
    this.parentConnection = parenConnection;
    this.concurrentFactory = concurrentFactory;
  }

  /**
   * Opens a channel, optionally binds it to the origin address, connects it to the destination
   * address, switches it to non-blocking mode and notifies the parent connection.
   *
   * @throws NotInitializedException if no destination address is set
   * @throws IOException if binding or connecting fails; the channel is closed before rethrowing
   */
  public void initialize() throws IOException, NotInitializedException {
    logger.debug("Initialising TCPTransportClient. Origin address is [{}] and destination address is [{}]", origAddress, destAddress);
    if (destAddress == null) {
      throw new NotInitializedException("Destination address is not set");
    }
    socketChannel = SelectorProvider.provider().openSocketChannel();
    try {
      if (origAddress != null) {
        socketChannel.socket().bind(origAddress);
      }
      // connect while the channel is still blocking, then switch to non-blocking for the selector
      socketChannel.connect(destAddress);
      socketChannel.configureBlocking(BLOCKING_IO);
      getParent().onConnected();
    }
    catch (IOException e) {
      if (origAddress != null) {
        socketChannel.socket().close();
      }
      socketChannel.close();
      throw e;
    }
  }

  public TCPClientConnection getParent() {
    return parentConnection;
  }

  /**
   * Adopts an already connected socket (e.g. one accepted by a server) as this transport's channel.
   *
   * @param socket a connected socket that was created from a {@link SocketChannel}
   * @throws NotInitializedException if the socket has no associated channel
   * @throws IOException if the channel cannot be switched to non-blocking mode
   */
  public void initialize(Socket socket) throws IOException, NotInitializedException {
    logger.debug("Initialising TCPTransportClient for a socket on [{}]", socket);
    socketDescription = socket.toString();
    SocketChannel channel = socket.getChannel();
    if (channel == null) {
      // Socket.getChannel() returns null unless the socket was created via a channel
      throw new NotInitializedException("Socket has no associated channel");
    }
    socketChannel = channel;
    socketChannel.configureBlocking(BLOCKING_IO);
    destAddress = new InetSocketAddress(socket.getInetAddress(), socket.getPort());
  }

  /**
   * Starts the reader thread (created through the concurrent factory) unless it is already alive.
   *
   * @throws NotInitializedException if the channel is missing or unconnected, or no parent is set
   */
  public void start() throws NotInitializedException {
    // for client
    if (socketDescription == null && socketChannel != null) {
      socketDescription = socketChannel.socket().toString();
    }
    logger.debug("Starting transport. Socket is {}", socketDescription);
    if (socketChannel == null) {
      throw new NotInitializedException("Transport is not initialized");
    }
    if (!socketChannel.isConnected()) {
      throw new NotInitializedException("Socket channel is not connected");
    }
    if (getParent() == null) {
      throw new NotInitializedException("No parent connection is set");
    }
    if (selfThread == null || !selfThread.isAlive()) {
      selfThread = concurrentFactory.getThread("TCPReader", this);
    }
    if (!selfThread.isAlive()) {
      selfThread.setDaemon(true);
      selfThread.start();
    }
  }

  /**
   * Reader loop: selects for readability, reads available bytes and feeds them to the message
   * reassembly logic until {@link #stop} is set or the peer closes the connection (read returns -1).
   * On exit the selector and channel are closed and the parent is notified of the disconnect.
   */
  @Override
  public void run() {
    // Workaround for Issue #4 (http://code.google.com/p/jdiameter/issues/detail?id=4)
    // Give some time to initialization...
    int sleepTime = 250;
    logger.debug("Sleeping for {}ms before starting transport so that listeners can all be added and ready for messages", sleepTime);
    try {
      Thread.sleep(sleepTime);
    }
    catch (InterruptedException e) {
      // ignore: the pause is only a best-effort grace period
    }
    logger.debug("Finished sleeping for {}ms. By now, MutablePeerTableImpl should have added its listener", sleepTime);

    logger.debug("Transport is started. Socket is [{}]", socketDescription);
    Selector selector = null;
    try {
      selector = Selector.open();
      socketChannel.register(selector, SelectionKey.OP_READ);
      while (!stop) {
        // bounded wait so a stop request is noticed even without incoming traffic
        selector.select(SELECT_TIMEOUT);
        Iterator<SelectionKey> it = selector.selectedKeys().iterator();
        while (it.hasNext()) {
          SelectionKey selKey = it.next();
          // remove it from the selected set to mark it as processed
          it.remove();
          if (selKey.isValid() && selKey.isReadable()) {
            SocketChannel sChannel = (SocketChannel) selKey.channel();
            int dataLength = sChannel.read(buffer);
            logger.debug("Just read [{}] bytes on [{}]", dataLength, socketDescription);
            if (dataLength == -1) {
              // end of stream: the peer closed the connection
              stop = true;
              break;
            }
            buffer.flip();
            byte[] data = new byte[buffer.limit()];
            buffer.get(data);
            append(data);
            buffer.clear();
          }
        }
      }
    }
    catch (ClosedByInterruptException e) {
      logger.error("Transport exception ", e);
    }
    catch (AsynchronousCloseException e) {
      logger.error("Transport is closed");
    }
    catch (Throwable e) {
      logger.error("Transport exception ", e);
    }
    finally {
      try {
        clearBuffer();
        if (selector != null) {
          selector.close();
        }
        if (socketChannel != null && socketChannel.isOpen()) {
          socketChannel.close();
        }
        getParent().onDisconnect();
      }
      catch (Exception e) {
        logger.error("Error", e);
      }
      // reset so the transport can be started again
      stop = false;
      logger.info("Read thread is stopped for socket [{}]", socketDescription);
    }
  }

  /**
   * Requests the reader loop to stop, closes the channel and waits briefly (100 ms) for the reader
   * thread to finish.
   */
  public void stop() throws Exception {
    logger.debug("Stopping transport. Socket is [{}]", socketDescription);
    stop = true;
    if (socketChannel != null && socketChannel.isOpen()) {
      socketChannel.close();
    }
    if (selfThread != null) {
      selfThread.join(100);
    }
    clearBuffer();
    logger.debug("Transport is stopped. Socket is [{}]", socketDescription);
  }

  /** Stops the transport and forgets the destination address. */
  public void release() throws Exception {
    stop();
    destAddress = null;
  }

  /** Replaces the read buffer with a fresh one of the default size. */
  private void clearBuffer() {
    bufferSize = DEFAULT_BUFFER_SIZE;
    buffer = ByteBuffer.allocate(bufferSize);
  }

  public InetSocketAddress getDestAddress() {
    return this.destAddress;
  }

  /**
   * Sets the destination address used by {@link #initialize()}.
   *
   * @param address destination address; may be {@code null} to clear it
   */
  public void setDestAddress(InetSocketAddress address) {
    this.destAddress = address;
    if (address != null && logger.isDebugEnabled()) {
      logger.debug("Destination address is set to [{}] : [{}]", destAddress.getHostName(), destAddress.getPort());
    }
  }

  /**
   * Sets the local address the socket is bound to by {@link #initialize()}.
   *
   * @param address origin address; may be {@code null} to let the OS choose
   */
  public void setOrigAddress(InetSocketAddress address) {
    this.origAddress = address;
    if (address != null && logger.isDebugEnabled()) {
      logger.debug("Origin address is set to [{}] : [{}]", origAddress.getHostName(), origAddress.getPort());
    }
  }

  public InetSocketAddress getOrigAddress() {
    return this.origAddress;
  }

  /**
   * Writes all remaining bytes of {@code bytes} (from its position to its limit) to the channel.
   * Concurrent callers are serialized by {@link #lock}.
   *
   * @param bytes the message to send; its position is advanced to its limit
   * @throws IOException if writing fails (the original failure is kept as the cause)
   */
  public void sendMessage(ByteBuffer bytes) throws IOException {
    int size = bytes.remaining();
    if (logger.isDebugEnabled()) {
      logger.debug("About to send a byte buffer of size [{}] over the TCP nio socket [{}]", size, socketDescription);
    }
    int written = 0;
    // ZhixiaoLuo: Fix #28, without the lock the data in the socketChannel will get mixed in multi-threads.
    lock.lock();
    try {
      // the channel is non-blocking, so a single write() may accept only part of the buffer;
      // loop on what is actually left rather than on the backing array's length
      while (bytes.hasRemaining()) {
        written += socketChannel.write(bytes);
      }
    }
    catch (Exception e) {
      logger.error("Unable to send message", e);
      throw new IOException("Error while sending message: " + e, e);
    }
    finally {
      lock.unlock();
    }

    if (written == 0) {
      logger.error("socketChannel.write(bytes) - returned zero indicating that perhaps the write buffer is full");
    }

    if (logger.isDebugEnabled()) {
      logger.debug("Sent a byte buffer of size [{}] over the TCP nio socket [{}]", size, socketDescription);
    }
  }

  @Override
  public String toString() {
    StringBuilder sb = new StringBuilder();
    sb.append("Transport to ");
    if (this.destAddress != null) {
      sb.append(this.destAddress.getHostName());
      sb.append(":");
      sb.append(this.destAddress.getPort());
    }
    else {
      sb.append("null");
    }
    sb.append("@");
    sb.append(super.toString());
    return sb.toString();
  }

  boolean isConnected() {
    return socketChannel != null && socketChannel.isOpen() && socketChannel.isConnected();
  }

  /**
   * Adds data to storage, growing it if needed, then extracts and dispatches every complete
   * message now present.
   *
   * @param data data to add
   */
  private void append(byte[] data) {
    if (storage.position() + data.length >= storage.capacity()) {
      ByteBuffer tmp = ByteBuffer.allocate(storage.limit() + data.length * 2);
      // copy the bytes buffered so far into the larger buffer
      storage.flip();
      tmp.put(storage);
      storage = tmp;
      logger.warn("Increase storage size. Current size is {}", storage.array().length);
    }

    try {
      storage.put(data);
    }
    catch (BufferOverflowException boe) {
      logger.error("Buffer overflow occured", boe);
    }

    boolean messageReceived;
    do {
      messageReceived = seekMessage();
    } while (messageReceived);
  }

  /**
   * Restores storage to write mode after it was flipped for reading, keeping all buffered bytes.
   */
  private void resumeWriting() {
    storage.position(storage.limit());
    storage.limit(storage.capacity());
  }

  /**
   * Tries to extract one complete Diameter message from the front of storage and pass it to the
   * parent connection. Storage is always left in write mode.
   *
   * @return true if a message was consumed (more may follow), false if more data is needed or
   *         the buffered data was discarded as invalid
   */
  private boolean seekMessage() {
    // make sure there's actual data written on the buffer
    if (storage.position() == 0) {
      return false;
    }

    // switch to read mode: [0, limit) now holds the buffered bytes
    storage.flip();

    // first four bytes carry version and message length
    //  0                   1                   2                   3
    //  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
    // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    // |    Version    |                 Message Length                |
    // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    if (storage.remaining() < VERSION_AND_LENGTH_SIZE) {
      // we don't have enough data to read message length.. wait for more
      resumeWriting();
      logger.debug("Buffer underflow occured, waiting for more data.");
      return false;
    }
    // absolute read: position stays at 0 so the whole message can be read afterwards
    int tmp = storage.getInt(0);

    // check that version is 1, as per RFC 3588 - Section 3:
    // This Version field MUST be set to 1 to indicate Diameter Version 1
    byte vers = (byte) (tmp >> 24);
    if (vers != 1) {
      // ZhixiaoLuo: fix #28, if unlucky storage.limit < data.length(1024), then always failed to do storage.put(data)
      // ZhixiaoLuo: and get BufferOverflowException in append(data)
      storage.clear();
      logger.error("Invalid message version detected [{}]", vers);
      return false;
    }

    // extract the message length (it includes the header), so we know how much to read
    int messageLength = (tmp & 0xFFFFFF);
    if (messageLength < DIAMETER_HEADER_LENGTH) {
      // cannot be a valid message; a tiny length would also never advance the storage
      storage.clear();
      logger.error("Invalid message length detected [{}]", messageLength);
      return false;
    }

    // verify that we do have the whole message in the storage
    if (storage.limit() < messageLength) {
      // we don't have it all.. let's restore buffer to receive more
      resumeWriting();
      logger.debug("Received partial message, waiting for remaining (expected: {} bytes, got {} bytes).", messageLength, storage.position());
      return false;
    }

    // read the complete message and drop it from storage (compact() returns storage to write mode)
    byte[] data = new byte[messageLength];
    storage.get(data);
    storage.compact();

    try {
      // make a message out of data and process it
      logger.debug("Passing message on to parent");
      getParent().onMessageReceived(ByteBuffer.wrap(data));
      logger.debug("Finished passing message on to parent");
    }
    catch (AvpDataException e) {
      logger.debug("Garbage was received. Discarding.");
      storage.clear();
      getParent().onAvpDataException(e);
    }
    catch (BufferUnderflowException bue) {
      // raised while decoding the extracted message; storage is already compacted and consistent,
      // so it must not be repositioned here
      logger.debug("Buffer underflow occured while processing received message. Discarding.", bue);
    }
    return true;
  }
}