/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sshd.common.io.nio2; import java.io.IOException; import java.net.SocketAddress; import java.nio.ByteBuffer; import java.nio.channels.AsynchronousSocketChannel; import java.nio.channels.ClosedChannelException; import java.util.HashMap; import java.util.Map; import java.util.Objects; import java.util.Queue; import java.util.concurrent.LinkedTransferQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import org.apache.sshd.common.FactoryManager; import org.apache.sshd.common.RuntimeSshException; import org.apache.sshd.common.future.CloseFuture; import org.apache.sshd.common.io.IoHandler; import org.apache.sshd.common.io.IoSession; import org.apache.sshd.common.io.IoWriteFuture; import org.apache.sshd.common.util.GenericUtils; import org.apache.sshd.common.util.Readable; import org.apache.sshd.common.util.buffer.Buffer; import org.apache.sshd.common.util.closeable.AbstractCloseable; /** * @author <a href="mailto:dev@mina.apache.org">Apache MINA SSHD Project</a> */ public class Nio2Session extends AbstractCloseable implements IoSession { public static final int DEFAULT_READBUF_SIZE = 32 * 1024; private static final AtomicLong SESSION_ID_GENERATOR = new AtomicLong(100L); private final long id = SESSION_ID_GENERATOR.incrementAndGet(); private final Nio2Service service; private final IoHandler ioHandler; private final AsynchronousSocketChannel socketChannel; private final Map<Object, Object> attributes = new HashMap<>(); private final SocketAddress localAddress; private final SocketAddress remoteAddress; private final FactoryManager manager; private final Queue<Nio2DefaultIoWriteFuture> writes = new LinkedTransferQueue<>(); private final AtomicReference<Nio2DefaultIoWriteFuture> currentWrite = new AtomicReference<>(); public Nio2Session(Nio2Service service, FactoryManager manager, IoHandler handler, AsynchronousSocketChannel socket) throws IOException { this.service = Objects.requireNonNull(service, "No service instance"); this.manager = Objects.requireNonNull(manager, "No factory manager"); this.ioHandler = Objects.requireNonNull(handler, "No IoHandler"); this.socketChannel = Objects.requireNonNull(socket, "No socket channel"); this.localAddress = socket.getLocalAddress(); this.remoteAddress = socket.getRemoteAddress(); if (log.isDebugEnabled()) { log.debug("Creating IoSession on {} from {}", localAddress, remoteAddress); } } @Override public long getId() { return id; } @Override public Object getAttribute(Object key) { return attributes.get(key); } @Override public Object setAttribute(Object key, Object value) { return attributes.put(key, value); } @Override public SocketAddress getRemoteAddress() { return remoteAddress; } @Override public SocketAddress getLocalAddress() { return localAddress; } public AsynchronousSocketChannel getSocket() { return socketChannel; } public IoHandler getIoHandler() { return ioHandler; } public void suspend() { AsynchronousSocketChannel socket = getSocket(); try { socket.shutdownInput(); } catch (IOException e) { if (log.isDebugEnabled()) { log.debug("suspend({}) failed {{}) to shutdown input: {}", this, e.getClass().getSimpleName(), e.getMessage()); } } try { socket.shutdownOutput(); } catch (IOException e) { if (log.isDebugEnabled()) { log.debug("suspend({}) failed {{}) to shutdown output: {}", this, e.getClass().getSimpleName(), e.getMessage()); } } } @Override public IoWriteFuture write(Buffer buffer) { if (log.isDebugEnabled()) { log.debug("Writing {} bytes", buffer.available()); } ByteBuffer buf = ByteBuffer.wrap(buffer.array(), buffer.rpos(), buffer.available()); final Nio2DefaultIoWriteFuture future = new Nio2DefaultIoWriteFuture(null, buf); if (isClosing()) { Throwable exc = new ClosedChannelException(); future.setException(exc); exceptionCaught(exc); return future; } writes.add(future); startWriting(); return future; } protected void exceptionCaught(Throwable exc) { if (!closeFuture.isClosed()) { AsynchronousSocketChannel socket = getSocket(); if (isClosing() || !socket.isOpen()) { close(true); } else { IoHandler handler = getIoHandler(); try { if (log.isDebugEnabled()) { log.debug("exceptionCaught({}) caught {}[{}] - calling handler", this, exc.getClass().getSimpleName(), exc.getMessage()); } handler.exceptionCaught(this, exc); } catch (Throwable e) { Throwable t = GenericUtils.peelException(e); if (log.isDebugEnabled()) { log.debug("exceptionCaught({}) Exception handler threw {}, closing the session: {}", this, t.getClass().getSimpleName(), t.getMessage()); } if (log.isTraceEnabled()) { log.trace("exceptionCaught(" + this + ") exception handler failure details", t); } close(true); } } } } @Override protected CloseFuture doCloseGracefully() { return builder().when(writes).build().close(false); } @Override protected void doCloseImmediately() { for (;;) { Nio2DefaultIoWriteFuture future = writes.poll(); if (future != null) { future.setException(new ClosedChannelException()); } else { break; } } AsynchronousSocketChannel socket = getSocket(); try { socket.close(); } catch (IOException e) { log.info("doCloseImmediately(" + this + ") exception caught while closing socket", e); } service.sessionClosed(this); super.doCloseImmediately(); IoHandler handler = getIoHandler(); try { handler.sessionClosed(this); } catch (Throwable e) { if (log.isDebugEnabled()) { log.debug("doCloseImmediately({}) {} while calling IoHandler#sessionClosed: {}", this, e.getClass().getSimpleName(), e.getMessage()); } if (log.isTraceEnabled()) { log.trace("doCloseImmediately(" + this + ") IoHandler#sessionClosed failure details", e); } } } @Override // co-variant return public Nio2Service getService() { return service; } public void startReading() { startReading(manager.getIntProperty(FactoryManager.NIO2_READ_BUFFER_SIZE, DEFAULT_READBUF_SIZE)); } public void startReading(int bufSize) { startReading(new byte[bufSize]); } public void startReading(byte[] buf) { startReading(buf, 0, buf.length); } public void startReading(byte[] buf, int offset, int len) { startReading(ByteBuffer.wrap(buf, offset, len)); } public void startReading(ByteBuffer buffer) { doReadCycle(buffer, Readable.readable(buffer)); } protected void doReadCycle(ByteBuffer buffer, Readable bufReader) { Nio2CompletionHandler<Integer, Object> completion = Objects.requireNonNull(createReadCycleCompletionHandler(buffer, bufReader), "No completion handler created"); doReadCycle(buffer, completion); } protected Nio2CompletionHandler<Integer, Object> createReadCycleCompletionHandler(final ByteBuffer buffer, final Readable bufReader) { return new Nio2CompletionHandler<Integer, Object>() { @Override protected void onCompleted(Integer result, Object attachment) { handleReadCycleCompletion(buffer, bufReader, this, result, attachment); } @Override protected void onFailed(Throwable exc, Object attachment) { handleReadCycleFailure(buffer, bufReader, exc, attachment); } }; } protected void handleReadCycleCompletion( ByteBuffer buffer, Readable bufReader, Nio2CompletionHandler<Integer, Object> completionHandler, Integer result, Object attachment) { try { if (result >= 0) { if (log.isDebugEnabled()) { log.debug("handleReadCycleCompletion({}) read {} bytes", this, result); } buffer.flip(); IoHandler handler = getIoHandler(); handler.messageReceived(this, bufReader); if (!closeFuture.isClosed()) { // re-use reference for next iteration since we finished processing it buffer.clear(); doReadCycle(buffer, completionHandler); } else { if (log.isDebugEnabled()) { log.debug("handleReadCycleCompletion({}) IoSession has been closed, stop reading", this); } } } else { if (log.isDebugEnabled()) { log.debug("handleReadCycleCompletion({}) Socket has been disconnected (result={}), closing IoSession now", this, result); } close(true); } } catch (Throwable exc) { completionHandler.failed(exc, attachment); } } protected void handleReadCycleFailure(ByteBuffer buffer, Readable bufReader, Throwable exc, Object attachment) { exceptionCaught(exc); } protected void doReadCycle(ByteBuffer buffer, Nio2CompletionHandler<Integer, Object> completion) { AsynchronousSocketChannel socket = getSocket(); long readTimeout = manager.getLongProperty(FactoryManager.NIO2_READ_TIMEOUT, FactoryManager.DEFAULT_NIO2_READ_TIMEOUT); socket.read(buffer, readTimeout, TimeUnit.MILLISECONDS, null, completion); } protected void startWriting() { Nio2DefaultIoWriteFuture future = writes.peek(); if (future != null) { if (currentWrite.compareAndSet(null, future)) { try { AsynchronousSocketChannel socket = getSocket(); ByteBuffer buffer = future.getBuffer(); Nio2CompletionHandler<Integer, Object> handler = Objects.requireNonNull(createWriteCycleCompletionHandler(future, socket, buffer), "No write cycle completion handler created"); doWriteCycle(buffer, handler); } catch (Throwable e) { future.setWritten(); if (e instanceof RuntimeException) { throw (RuntimeException) e; } else { throw new RuntimeSshException(e); } } } } } protected void doWriteCycle(ByteBuffer buffer, Nio2CompletionHandler<Integer, Object> completion) { AsynchronousSocketChannel socket = getSocket(); long writeTimeout = manager.getLongProperty(FactoryManager.NIO2_MIN_WRITE_TIMEOUT, FactoryManager.DEFAULT_NIO2_MIN_WRITE_TIMEOUT); socket.write(buffer, writeTimeout, TimeUnit.MILLISECONDS, null, completion); } protected Nio2CompletionHandler<Integer, Object> createWriteCycleCompletionHandler( final Nio2DefaultIoWriteFuture future, final AsynchronousSocketChannel socket, final ByteBuffer buffer) { final int writeLen = buffer.remaining(); return new Nio2CompletionHandler<Integer, Object>() { @Override protected void onCompleted(Integer result, Object attachment) { handleCompletedWriteCycle(future, socket, buffer, writeLen, this, result, attachment); } @Override protected void onFailed(Throwable exc, Object attachment) { handleWriteCycleFailure(future, socket, buffer, writeLen, exc, attachment); } }; } protected void handleCompletedWriteCycle( Nio2DefaultIoWriteFuture future, AsynchronousSocketChannel socket, ByteBuffer buffer, int writeLen, Nio2CompletionHandler<Integer, Object> completionHandler, Integer result, Object attachment) { if (buffer.hasRemaining()) { try { socket.write(buffer, null, completionHandler); } catch (Throwable t) { if (log.isDebugEnabled()) { log.debug("handleCompletedWriteCycle(" + this + ") Exception caught while writing " + writeLen + " bytes", t); } future.setWritten(); finishWrite(future); } } else { if (log.isDebugEnabled()) { log.debug("handleCompletedWriteCycle({}) finished writing len={}", this, writeLen); } future.setWritten(); finishWrite(future); } } protected void handleWriteCycleFailure( Nio2DefaultIoWriteFuture future, AsynchronousSocketChannel socket, ByteBuffer buffer, int writeLen, Throwable exc, Object attachment) { if (log.isDebugEnabled()) { log.debug("handleWriteCycleFailure({}) failed ({}) to write {} bytes: {}", this, exc.getClass().getSimpleName(), writeLen, exc.getMessage()); } if (log.isTraceEnabled()) { log.trace("handleWriteCycleFailure(" + this + ") len=" + writeLen + " failure details", exc); } future.setException(exc); exceptionCaught(exc); finishWrite(future); } protected void finishWrite(Nio2DefaultIoWriteFuture future) { writes.remove(future); currentWrite.compareAndSet(future, null); startWriting(); } @Override public String toString() { return getClass().getSimpleName() + "[local=" + getLocalAddress() + ", remote=" + getRemoteAddress() + "]"; } }