/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sshd.common.session.helpers; import java.io.IOException; import java.io.InterruptedIOException; import java.net.SocketAddress; import java.net.SocketTimeoutException; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.Date; import java.util.EnumMap; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Queue; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CopyOnWriteArraySet; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import org.apache.sshd.common.AttributeStore; import org.apache.sshd.common.Closeable; import org.apache.sshd.common.Factory; import org.apache.sshd.common.FactoryManager; import org.apache.sshd.common.NamedFactory; import org.apache.sshd.common.NamedResource; import org.apache.sshd.common.PropertyResolver; import 
org.apache.sshd.common.PropertyResolverUtils; import org.apache.sshd.common.RuntimeSshException; import org.apache.sshd.common.Service; import org.apache.sshd.common.SshConstants; import org.apache.sshd.common.SshException; import org.apache.sshd.common.channel.ChannelListener; import org.apache.sshd.common.cipher.Cipher; import org.apache.sshd.common.cipher.CipherInformation; import org.apache.sshd.common.compression.Compression; import org.apache.sshd.common.compression.CompressionInformation; import org.apache.sshd.common.digest.Digest; import org.apache.sshd.common.forward.PortForwardingEventListener; import org.apache.sshd.common.future.DefaultKeyExchangeFuture; import org.apache.sshd.common.future.DefaultSshFuture; import org.apache.sshd.common.future.KeyExchangeFuture; import org.apache.sshd.common.io.IoSession; import org.apache.sshd.common.io.IoWriteFuture; import org.apache.sshd.common.kex.AbstractKexFactoryManager; import org.apache.sshd.common.kex.KexProposalOption; import org.apache.sshd.common.kex.KexState; import org.apache.sshd.common.kex.KeyExchange; import org.apache.sshd.common.mac.Mac; import org.apache.sshd.common.mac.MacInformation; import org.apache.sshd.common.random.Random; import org.apache.sshd.common.session.ReservedSessionMessagesHandler; import org.apache.sshd.common.session.Session; import org.apache.sshd.common.session.SessionListener; import org.apache.sshd.common.session.SessionWorkBuffer; import org.apache.sshd.common.util.EventListenerUtils; import org.apache.sshd.common.util.GenericUtils; import org.apache.sshd.common.util.Invoker; import org.apache.sshd.common.util.NumberUtils; import org.apache.sshd.common.util.Pair; import org.apache.sshd.common.util.Readable; import org.apache.sshd.common.util.ValidateUtils; import org.apache.sshd.common.util.buffer.Buffer; import org.apache.sshd.common.util.buffer.BufferUtils; import org.apache.sshd.common.util.buffer.ByteArrayBuffer; /** * <P> * The AbstractSession handles all the basic 
 * SSH protocol such as key exchange, authentication,
 * encoding and decoding. Both server side and client side sessions should inherit from this
 * abstract class. Some basic packet processing methods are defined but the actual call to these
 * methods should be done from the {@link #handleMessage(Buffer)}
 * method, which is dependent on the state and side of this session.
 * </P>
 *
 * TODO: if there is any very big packet, decoderBuffer and uncompressBuffer will get quite big
 * and they won't be resized down at any time. Though the packet size is really limited
 * by the channel max packet size
 *
 * @author <a href="mailto:dev@mina.apache.org">Apache MINA SSHD Project</a>
 */
public abstract class AbstractSession extends AbstractKexFactoryManager implements Session {
    /**
     * Name of the property where this session is stored in the attributes of the
     * underlying MINA session. See {@link #getSession(IoSession, boolean)}
     * and {@link #attachSession(IoSession, AbstractSession)}.
     */
    public static final String SESSION = "org.apache.sshd.session";

    /**
     * Client or server side
     */
    protected final boolean isServer;

    /**
     * The underlying MINA session
     */
    protected final IoSession ioSession;

    /**
     * The pseudo random generator. Code in this class synchronizes on it before use,
     * and it also serves as the monitor guarding the re-keying / SSH_MSG_IGNORE
     * padding configuration updates in {@code refreshConfiguration()}.
     */
    protected final Random random;

    /**
     * Boolean indicating if this session has been authenticated or not
     */
    protected boolean authed;

    /**
     * The name of the authenticated user
     */
    protected String username;

    /**
     * Session listeners container
     */
    protected final Collection<SessionListener> sessionListeners = new CopyOnWriteArraySet<>();
    // proxy built over sessionListeners in the constructor - presumably dispatches each event to every listener
    protected final SessionListener sessionListenerProxy;

    /**
     * Channel events listener container
     */
    protected final Collection<ChannelListener> channelListeners = new CopyOnWriteArraySet<>();
    // proxy built over channelListeners in the constructor
    protected final ChannelListener channelListenerProxy;

    /**
     * Port forwarding events listener container
     */
    protected final Collection<PortForwardingEventListener> tunnelListeners = new CopyOnWriteArraySet<>();
    // proxy built over tunnelListeners in the constructor
    protected final PortForwardingEventListener tunnelListenerProxy;

    /*
     * Key exchange support
     */
    protected byte[] sessionId;
    protected String serverVersion;
    protected String clientVersion;
    // if empty then means not-initialized
    protected final Map<KexProposalOption, String> serverProposal = new EnumMap<>(KexProposalOption.class);
    protected final Map<KexProposalOption, String> clientProposal = new EnumMap<>(KexProposalOption.class);
    // guarded by synchronizing on the map itself (see getNegotiatedKexParameter)
    protected final Map<KexProposalOption, String> negotiationResult = new EnumMap<>(KexProposalOption.class);
    protected byte[] i_c; // the payload of the client's SSH_MSG_KEXINIT
    protected byte[] i_s; // the payload of the server's SSH_MSG_KEXINIT
    protected KeyExchange kex;
    // non-null while the "first KEX packet follows" check is still pending - reset to null once checked
    protected Boolean firstKexPacketFollows;
    protected final AtomicReference<KexState> kexState = new AtomicReference<>(KexState.UNKNOWN);
    // future of a pending key exchange (if any) - completed with TRUE on SSH_MSG_NEWKEYS,
    // or with the failure / closing exception otherwise
    protected final AtomicReference<DefaultKeyExchangeFuture> kexFutureHolder = new AtomicReference<>(null);

    /*
     * SSH packets encoding / decoding support
     */
    protected Cipher outCipher;
    protected Cipher inCipher;
    protected int outCipherSize = 8;
    protected int inCipherSize = 8;
    protected Mac outMac;
    protected Mac inMac;
    protected byte[] inMacResult;
    protected Compression outCompression;
    protected Compression inCompression;
    protected long seqi; // presumably the incoming packets sequence number - confirm
    protected long seqo; // presumably the outgoing packets sequence number - confirm
    protected SessionWorkBuffer uncompressBuffer;
    protected final SessionWorkBuffer decoderBuffer;
    protected int decoderState;
    protected int decoderLength;
    // serializes packet encoding + writing so packets hit the wire in order (see doWritePacket)
    protected final Object encodeLock = new Object();
    // serializes handling of incoming data (see messageReceived)
    protected final Object decodeLock = new Object();
    // allows only one synchronous request(...) call at a time to await its reply
    protected final Object requestLock = new Object();

    // Session timeout measurements
    protected long authTimeoutStart = System.currentTimeMillis();
    protected long idleTimeoutStart = System.currentTimeMillis();
    protected final AtomicReference<TimeoutStatus> timeoutStatus = new AtomicReference<>(TimeoutStatus.NoTimeout);

    /*
     * Rekeying
     */
    protected final AtomicLong inPacketsCount = new AtomicLong(0L);
    protected final AtomicLong outPacketsCount = new AtomicLong(0L);
    protected final AtomicLong inBytesCount = new AtomicLong(0L);
    protected final AtomicLong outBytesCount = new AtomicLong(0L);
    protected final AtomicLong inBlocksCount = new AtomicLong(0L);
    protected final AtomicLong outBlocksCount = new AtomicLong(0L);
    protected final AtomicLong lastKeyTimeValue = new AtomicLong(0L);
    // These initial values act as the defaults that refreshConfiguration() (invoked from the
    // constructor) falls back on. NOTE: Java runs field initializers only after the super
    // constructor returns, so a super constructor calling methods that read them would still see 0.
    // NOTE: "maxRekyPackets" is misspelled but kept as-is since sub-classes may reference it
    protected long maxRekyPackets = FactoryManager.DEFAULT_REKEY_PACKETS_LIMIT;
    protected long maxRekeyBytes = FactoryManager.DEFAULT_REKEY_BYTES_LIMIT;
    protected long maxRekeyInterval = FactoryManager.DEFAULT_REKEY_TIME_LIMIT;
    // outgoing packets held back while a key exchange is running - flushed in handleNewKeys
    protected final Queue<PendingWriteFuture> pendingPackets = new LinkedList<>();

    protected Service currentService;

    // SSH_MSG_IGNORE stream padding
    protected int ignorePacketDataLength = FactoryManager.DEFAULT_IGNORE_MESSAGE_SIZE;
    protected long ignorePacketsFrequency = FactoryManager.DEFAULT_IGNORE_MESSAGE_FREQUENCY;
    protected int ignorePacketsVariance = FactoryManager.DEFAULT_IGNORE_MESSAGE_VARIANCE;

    // NOTE(review): derived from the bytes limit assuming 16-byte cipher blocks - confirm
    protected final AtomicLong maxRekeyBlocks = new AtomicLong(FactoryManager.DEFAULT_REKEY_BYTES_LIMIT / 16);
    // countdown (in outgoing packets) until the next SSH_MSG_IGNORE padding message is appended
    protected final AtomicLong ignorePacketsCount = new AtomicLong(FactoryManager.DEFAULT_IGNORE_MESSAGE_FREQUENCY);

    /**
     * The factory manager used to retrieve factories of Ciphers, Macs and other objects
     */
    private final FactoryManager factoryManager;

    /**
     * The session specific properties
     */
    private final Map<String, Object> properties = new ConcurrentHashMap<>();

    /**
     * Holds the result of the pending synchronous global request, if any. Its monitor
     * is waited on by the requesting thread; on session close it is set to
     * {@code GenericUtils.NULL} and notified so the waiter wakes up.
     */
    private final AtomicReference<Object> requestResult = new AtomicReference<>();

    /**
     * Session specific attributes
     */
    private final Map<AttributeKey<?>, Object> attributes = new ConcurrentHashMap<>();

    private ReservedSessionMessagesHandler reservedSessionMessagesHandler;

    /**
     * Create a new session.
* * @param isServer {@code true} if this is a server session, {@code false} if client one * @param factoryManager the factory manager * @param ioSession the underlying MINA session */ protected AbstractSession(boolean isServer, FactoryManager factoryManager, IoSession ioSession) { super(Objects.requireNonNull(factoryManager, "No factory manager provided")); this.isServer = isServer; this.factoryManager = factoryManager; this.ioSession = ioSession; this.decoderBuffer = new SessionWorkBuffer(this); attachSession(ioSession, this); Factory<Random> factory = ValidateUtils.checkNotNull(factoryManager.getRandomFactory(), "No random factory for %s", ioSession); random = ValidateUtils.checkNotNull(factory.create(), "No randomizer instance for %s", ioSession); refreshConfiguration(); ClassLoader loader = getClass().getClassLoader(); sessionListenerProxy = EventListenerUtils.proxyWrapper(SessionListener.class, loader, sessionListeners); channelListenerProxy = EventListenerUtils.proxyWrapper(ChannelListener.class, loader, channelListeners); tunnelListenerProxy = EventListenerUtils.proxyWrapper(PortForwardingEventListener.class, loader, tunnelListeners); } protected void signalSessionCreated(IoSession ioSession) throws Exception { try { invokeSessionSignaller(l -> { signalSessionCreated(l); return null; }); } catch (Throwable err) { Throwable e = GenericUtils.peelException(err); if (log.isDebugEnabled()) { log.debug("Failed ({}) to announce session={} created: {}", e.getClass().getSimpleName(), ioSession, e.getMessage()); } if (log.isTraceEnabled()) { log.trace("Session=" + ioSession + " creation failure details", e); } if (e instanceof Exception) { throw (Exception) e; } else { throw new RuntimeSshException(e); } } } protected void signalSessionCreated(SessionListener listener) { if (listener == null) { return; } listener.sessionCreated(this); } /** * Retrieve the session from the MINA session. 
* If the session has not been attached, an {@link IllegalStateException} * will be thrown * * @param ioSession the MINA session * @return the session attached to the MINA session * @see #getSession(IoSession, boolean) */ public static AbstractSession getSession(IoSession ioSession) { return getSession(ioSession, false); } /** * Retrieve the session from the MINA session. * If the session has not been attached and <tt>allowNull</tt> is <code>false</code>, * an {@link IllegalStateException} will be thrown, else a {@code null} will * be returned * * @param ioSession the MINA session * @param allowNull if <code>true</code>, a {@code null} value may be * returned if no session is attached * @return the session attached to the MINA session or {@code null} */ public static AbstractSession getSession(IoSession ioSession, boolean allowNull) { AbstractSession session = (AbstractSession) ioSession.getAttribute(SESSION); if ((session == null) && (!allowNull)) { throw new IllegalStateException("No session available"); } return session; } /** * Attach a session to the MINA session * * @param ioSession the MINA session * @param session the session to attach */ public static void attachSession(IoSession ioSession, AbstractSession session) { Objects.requireNonNull(ioSession, "No I/O session").setAttribute(SESSION, Objects.requireNonNull(session, "No SSH session")); } @Override public String getServerVersion() { return serverVersion; } @Override public String getClientVersion() { return clientVersion; } @Override public KeyExchange getKex() { return kex; } @Override public byte[] getSessionId() { // return a clone to avoid anyone changing the internal value return NumberUtils.isEmpty(sessionId) ? 
sessionId : sessionId.clone(); } @Override public IoSession getIoSession() { return ioSession; } /** * @param knownAddress Any externally set peer address - e.g., due to some * proxy mechanism meta-data * @return The external address if not {@code null} otherwise, the {@code IoSession} * peer address */ protected SocketAddress resolvePeerAddress(SocketAddress knownAddress) { if (knownAddress != null) { return knownAddress; } IoSession s = getIoSession(); return (s == null) ? null : s.getRemoteAddress(); } @Override public FactoryManager getFactoryManager() { return factoryManager; } @Override public PropertyResolver getParentPropertyResolver() { return getFactoryManager(); } @Override public Map<String, Object> getProperties() { return properties; } @Override public String getNegotiatedKexParameter(KexProposalOption paramType) { if (paramType == null) { return null; } synchronized (negotiationResult) { return negotiationResult.get(paramType); } } @Override public CipherInformation getCipherInformation(boolean incoming) { return incoming ? inCipher : outCipher; } @Override public CompressionInformation getCompressionInformation(boolean incoming) { return incoming ? inCompression : outCompression; } @Override public MacInformation getMacInformation(boolean incoming) { return incoming ? inMac : outMac; } @Override public boolean isAuthenticated() { return authed; } @Override public void setAuthenticated() throws IOException { this.authed = true; signalSessionEvent(SessionListener.Event.Authenticated); } /** * <P>Main input point for the MINA framework.</P> * * <P> * This method will be called each time new data is received on * the socket and will append it to the input buffer before * calling the {@link #decode()} method. 
* </P> * * @param buffer the new buffer received * @throws Exception if an error occurs while decoding or handling the data */ public void messageReceived(Readable buffer) throws Exception { synchronized (decodeLock) { decoderBuffer.putBuffer(buffer); // One of those property will be set by the constructor and the other // one should be set by the readIdentification method if (clientVersion == null || serverVersion == null) { if (readIdentification(decoderBuffer)) { decoderBuffer.compact(); } else { return; } } decode(); } } /** * Refresh whatever internal configuration is not {@code final} */ protected void refreshConfiguration() { synchronized (random) { // re-keying configuration maxRekeyBytes = this.getLongProperty(FactoryManager.REKEY_BYTES_LIMIT, maxRekeyBytes); maxRekeyInterval = this.getLongProperty(FactoryManager.REKEY_TIME_LIMIT, maxRekeyInterval); maxRekyPackets = this.getLongProperty(FactoryManager.REKEY_PACKETS_LIMIT, maxRekyPackets); // intermittent SSH_MSG_IGNORE stream padding ignorePacketDataLength = this.getIntProperty(FactoryManager.IGNORE_MESSAGE_SIZE, FactoryManager.DEFAULT_IGNORE_MESSAGE_SIZE); ignorePacketsFrequency = this.getLongProperty(FactoryManager.IGNORE_MESSAGE_FREQUENCY, FactoryManager.DEFAULT_IGNORE_MESSAGE_FREQUENCY); ignorePacketsVariance = this.getIntProperty(FactoryManager.IGNORE_MESSAGE_VARIANCE, FactoryManager.DEFAULT_IGNORE_MESSAGE_VARIANCE); if (ignorePacketsVariance >= ignorePacketsFrequency) { ignorePacketsVariance = 0; } ignorePacketsCount.set(calculateNextIgnorePacketCount(random, ignorePacketsFrequency, ignorePacketsVariance)); } } /** * Abstract method for processing incoming decoded packets. * The given buffer will hold the decoded packet, starting from * the command byte at the read position. * * @param buffer The {@link Buffer} containing the packet - it may be * re-used to generate the response once request has been decoded * @throws Exception if an exception occurs while handling this packet. 
     * @see #doHandleMessage(Buffer)
     */
    protected void handleMessage(Buffer buffer) throws Exception {
        try {
            // packets are handled one at a time while holding the (inherited) session lock
            synchronized (lock) {
                doHandleMessage(buffer);
            }
        } catch (Throwable e) {
            DefaultKeyExchangeFuture kexFuture = kexFutureHolder.get();
            // if have any ongoing KEX notify it about the failure
            if (kexFuture != null) {
                synchronized (kexFuture) {
                    Object value = kexFuture.getValue();
                    // record only the first outcome - never override an already completed future
                    if (value == null) {
                        kexFuture.setValue(e);
                    }
                }
            }

            // propagate Exception(s) as-is, wrap any other Throwable
            if (e instanceof Exception) {
                throw (Exception) e;
            } else {
                throw new RuntimeSshException(e);
            }
        }
    }

    /**
     * Dispatches a single decoded packet according to its command byte. Transport
     * layer messages are handled by the dedicated handlers below, commands in the
     * KEX range go to the key exchange, and anything else is forwarded to the current
     * service. When dispatching completes without an exception, {@code checkRekey()}
     * is invoked.
     *
     * @param buffer the packet - positioned at its command byte
     * @throws Exception if handling the packet fails
     * @throws IllegalStateException if the command is not a transport/KEX one and
     * there is no current service to process it
     */
    protected void doHandleMessage(Buffer buffer) throws Exception {
        int cmd = buffer.getUByte();
        if (log.isTraceEnabled()) {
            log.trace("doHandleMessage({}) process {}", this, SshConstants.getCommandMessageName(cmd));
        }

        switch (cmd) {
            case SshConstants.SSH_MSG_DISCONNECT:
                handleDisconnect(buffer);
                break;
            case SshConstants.SSH_MSG_IGNORE:
                handleIgnore(buffer);
                break;
            case SshConstants.SSH_MSG_UNIMPLEMENTED:
                handleUnimplemented(buffer);
                break;
            case SshConstants.SSH_MSG_DEBUG:
                handleDebug(buffer);
                break;
            case SshConstants.SSH_MSG_SERVICE_REQUEST:
                handleServiceRequest(buffer);
                break;
            case SshConstants.SSH_MSG_SERVICE_ACCEPT:
                handleServiceAccept(buffer);
                break;
            case SshConstants.SSH_MSG_KEXINIT:
                handleKexInit(buffer);
                break;
            case SshConstants.SSH_MSG_NEWKEYS:
                handleNewKeys(cmd, buffer);
                break;
            default:
                if ((cmd >= SshConstants.SSH_MSG_KEX_FIRST) && (cmd <= SshConstants.SSH_MSG_KEX_LAST)) {
                    // the "first KEX packet follows" flag is evaluated only for the first KEX packet
                    if (firstKexPacketFollows != null) {
                        try {
                            if (!handleFirstKexPacketFollows(cmd, buffer, firstKexPacketFollows)) {
                                // wrong guess - skip this packet (exits the switch, checkRekey still runs)
                                break;
                            }
                        } finally {
                            firstKexPacketFollows = null; // avoid re-checking
                        }
                    }

                    handleKexMessage(cmd, buffer);
                } else if (currentService != null) {
                    currentService.process(cmd, buffer);
                    resetIdleTimeout();
                } else {
                    throw new IllegalStateException("Unsupported command " + SshConstants.getCommandMessageName(cmd));
                }
                break;
        }
        checkRekey();
    }

    /**
     * Decides whether a KEX packet should be processed, given the "first KEX packet
     * follows" flag (presumably taken from the peer's SSH_MSG_KEXINIT - confirm).
     *
     * @param cmd the KEX command - used for logging
     * @param buffer the packet data - not read here
     * @param followFlag the "first KEX packet follows" flag value
     * @return {@code true} if the packet should be processed, {@code false} if it
     * must be silently ignored because the guessed KEX / host key algorithm was wrong
     */
    protected boolean handleFirstKexPacketFollows(int cmd, Buffer buffer, boolean followFlag) {
        if (!followFlag) {
            return true; // if 1st KEX packet does not follow then process the command
        }

        /*
         * According to RFC4253 section 7.1:
         *
         * If the other party's guess was wrong, and this field was TRUE,
         * the next packet MUST be silently ignored
         */
        for (KexProposalOption option : new KexProposalOption[]{KexProposalOption.ALGORITHMS, KexProposalOption.SERVERKEYS}) {
            Pair<String, String> result = comparePreferredKexProposalOption(option);
            if (result != null) {
                if (log.isDebugEnabled()) {
                    log.debug("handleFirstKexPacketFollows({})[{}] 1st follow KEX packet {} option mismatch: client={}, server={}",
                            this, SshConstants.getCommandMessageName(cmd), option, result.getFirst(), result.getSecond());
                }
                return false;
            }
        }

        return true;
    }

    /**
     * Compares the first (i.e., most preferred) comma-separated value of the client
     * and server proposals for the given option.
     *
     * @param option the proposal option to compare
     * @return {@code null} if both sides prefer the same value, otherwise a
     * {@link Pair} of the (client, server) preferred values
     */
    protected Pair<String, String> comparePreferredKexProposalOption(KexProposalOption option) {
        // NOTE(review): assumes both proposals hold a non-empty value for this option,
        // otherwise the [0] access below would fail - confirm callers guarantee it
        String[] clientPreferences = GenericUtils.split(clientProposal.get(option), ',');
        String clientValue = clientPreferences[0];
        String[] serverPreferences = GenericUtils.split(serverProposal.get(option), ',');
        String serverValue = serverPreferences[0];
        return clientValue.equals(serverValue) ? null : new Pair<>(clientValue, serverValue);
    }

    /**
     * Feeds a KEX range message to the current key exchange - the KEX state must be
     * {@code RUN}. Once the exchange reports completion, the keys are checked,
     * {@code sendNewKeys()} is called and the state moves to {@code KEYS}.
     *
     * @param cmd the KEX command
     * @param buffer the message data
     * @throws Exception if the key exchange fails
     */
    protected void handleKexMessage(int cmd, Buffer buffer) throws Exception {
        validateKexState(cmd, KexState.RUN);

        if (kex.next(cmd, buffer)) {
            if (log.isDebugEnabled()) {
                log.debug("handleKexMessage({})[{}] KEX processing complete after cmd={}",
                        this, kex.getName(), cmd);
            }
            checkKeys();
            sendNewKeys();
            kexState.set(KexState.KEYS);
        } else {
            if (log.isDebugEnabled()) {
                log.debug("handleKexMessage({})[{}] more KEX packets expected after cmd={}",
                        this, kex.getName(), cmd);
            }
        }
    }

    @Override
    public IoWriteFuture sendIgnoreMessage(byte... data) throws IOException {
        // a null payload is sent as an empty one
        data = (data == null) ? GenericUtils.EMPTY_BYTE_ARRAY : data;
        // presumably an initial size hint: payload + a few extra bytes for the length prefix
        Buffer buffer = createBuffer(SshConstants.SSH_MSG_IGNORE, data.length + Byte.SIZE);
        buffer.putBytes(data);
        return writePacket(buffer);
    }

    /**
     * Handles an incoming SSH_MSG_IGNORE: malformed messages are dropped, otherwise
     * the idle timeout is reset and the message is passed to the reserved messages handler.
     */
    protected void handleIgnore(Buffer buffer) throws Exception {
        // malformed ignore message - ignore (even though we don't have to, but we can be tolerant in this case)
        if (!buffer.isValidMessageStructure(byte[].class)) {
            if (log.isTraceEnabled()) {
                log.trace("handleIgnore({}) ignore malformed message", this);
            }
            return;
        }
        resetIdleTimeout();

        ReservedSessionMessagesHandler handler = resolveReservedSessionMessagesHandler();
        handler.handleIgnoreMessage(this, buffer);
    }

    /**
     * Handles an incoming SSH_MSG_UNIMPLEMENTED: malformed messages are dropped, otherwise
     * the idle timeout is reset and the message is passed to the reserved messages handler.
     */
    protected void handleUnimplemented(Buffer buffer) throws Exception {
        // malformed unimplemented message - ignore it
        if (!buffer.isValidMessageStructure(int.class)) {
            if (log.isTraceEnabled()) {
                log.trace("handleUnimplemented({}) ignore malformed message", this);
            }
            return;
        }
        resetIdleTimeout();

        ReservedSessionMessagesHandler handler = resolveReservedSessionMessagesHandler();
        handler.handleUnimplementedMessage(this, buffer);
    }

    @Override
    public IoWriteFuture sendDebugMessage(boolean display, Object msg, String lang) throws IOException {
        // a null message is sent as an empty string, as is a null language tag
        String text = Objects.toString(msg, "");
        lang = (lang == null) ? "" : lang;

        Buffer buffer = createBuffer(SshConstants.SSH_MSG_DEBUG,
                text.length() + lang.length() + Integer.SIZE /* a few extras */);
        buffer.putBoolean(display);
        buffer.putString(text);
        buffer.putString(lang);
        return writePacket(buffer);
    }

    /**
     * Handles an incoming SSH_MSG_DEBUG: malformed messages are dropped, otherwise
     * the idle timeout is reset and the message is passed to the reserved messages handler.
     */
    protected void handleDebug(Buffer buffer) throws Exception {
        // malformed debug message - ignore (even though we don't have to, but we can be tolerant in this case)
        if (!buffer.isValidMessageStructure(boolean.class, String.class, String.class)) {
            if (log.isTraceEnabled()) {
                log.trace("handleDebug({}) ignore malformed message", this);
            }
            return;
        }
        resetIdleTimeout();

        ReservedSessionMessagesHandler handler = resolveReservedSessionMessagesHandler();
        handler.handleDebugMessage(this, buffer);
    }

    /**
     * @return the configured {@link ReservedSessionMessagesHandler}, or
     * {@code ReservedSessionMessagesHandlerAdapter.DEFAULT} if none was set
     */
    protected ReservedSessionMessagesHandler resolveReservedSessionMessagesHandler() {
        ReservedSessionMessagesHandler handler = getReservedSessionMessagesHandler();
        return (handler == null) ? ReservedSessionMessagesHandlerAdapter.DEFAULT : handler;
    }

    /**
     * Decodes an SSH_MSG_DISCONNECT (reason code, message, language tag) and
     * delegates to {@link #handleDisconnect(int, String, String, Buffer)}.
     */
    protected void handleDisconnect(Buffer buffer) throws Exception {
        int code = buffer.getInt();
        String message = buffer.getString();
        String languageTag = buffer.getString();
        handleDisconnect(code, message, languageTag, buffer);
    }

    /**
     * Logs the disconnect reason and calls {@code close(true)}.
     */
    protected void handleDisconnect(int code, String msg, String lang, Buffer buffer) throws Exception {
        if (log.isDebugEnabled()) {
            log.debug("handleDisconnect({}) SSH_MSG_DISCONNECT reason={}, [lang={}] msg={}",
                    this, SshConstants.getDisconnectReasonName(code), lang, msg);
        }

        close(true);
    }

    /**
     * Decodes the service name of an SSH_MSG_SERVICE_REQUEST and delegates to
     * {@link #handleServiceRequest(String, Buffer)}.
     */
    protected void handleServiceRequest(Buffer buffer) throws Exception {
        String serviceName = buffer.getString();
        handleServiceRequest(serviceName, buffer);
    }

    /**
     * Starts the requested service (KEX state must be {@code DONE}) and replies with
     * SSH_MSG_SERVICE_ACCEPT. If starting the service fails, the session is disconnected
     * with SSH2_DISCONNECT_SERVICE_NOT_AVAILABLE.
     *
     * @param serviceName the requested service name
     * @param buffer the request data
     * @return {@code true} if the service was accepted, {@code false} if it was rejected
     * @throws Exception if sending the reply or the disconnect fails
     */
    protected boolean handleServiceRequest(String serviceName, Buffer buffer) throws Exception {
        if (log.isDebugEnabled()) {
            log.debug("handleServiceRequest({}) SSH_MSG_SERVICE_REQUEST '{}'", this, serviceName);
        }
        validateKexState(SshConstants.SSH_MSG_SERVICE_REQUEST, KexState.DONE);

        try {
            startService(serviceName);
        } catch (Throwable e) {
            if (log.isDebugEnabled()) {
                log.debug("handleServiceRequest({}) Service {} rejected: {} = {}",
                        this, serviceName, e.getClass().getSimpleName(), e.getMessage());
            }
            if (log.isTraceEnabled()) {
                log.trace("handleServiceRequest(" + this + ") service=" + serviceName + " rejection details", e);
            }
            disconnect(SshConstants.SSH2_DISCONNECT_SERVICE_NOT_AVAILABLE, "Bad service request: " + serviceName);
            return false;
        }

        if (log.isDebugEnabled()) {
            log.debug("handleServiceRequest({}) Accepted service {}", this, serviceName);
        }

        Buffer response = createBuffer(SshConstants.SSH_MSG_SERVICE_ACCEPT, Byte.SIZE + GenericUtils.length(serviceName));
        response.putString(serviceName);
        writePacket(response);
        return true;
    }

    /**
     * Decodes the service name of an SSH_MSG_SERVICE_ACCEPT and delegates to
     * {@link #handleServiceAccept(String, Buffer)}.
     */
    protected void handleServiceAccept(Buffer buffer) throws Exception {
        handleServiceAccept(buffer.getString(), buffer);
    }

    /**
     * Logs the accepted service and validates that the KEX state is {@code DONE}
     * (presumably extended by sub-classes that requested the service).
     */
    protected void handleServiceAccept(String serviceName, Buffer buffer) throws Exception {
        if (log.isDebugEnabled()) {
            log.debug("handleServiceAccept({}) SSH_MSG_SERVICE_ACCEPT service={}", this, serviceName);
        }
        validateKexState(SshConstants.SSH_MSG_SERVICE_ACCEPT, KexState.DONE);
    }

    /**
     * Handles the peer's SSH_MSG_KEXINIT: records it, sends our own SSH_MSG_KEXINIT
     * if the KEX state was {@code DONE}, moves the state to {@code RUN}, negotiates the
     * algorithms and initializes the negotiated key exchange.
     *
     * @param buffer the SSH_MSG_KEXINIT data
     * @throws IllegalStateException if the KEX state was neither {@code DONE} nor {@code INIT}
     * @throws Exception if negotiation or key exchange initialization fails
     */
    protected void handleKexInit(Buffer buffer) throws Exception {
        if (log.isDebugEnabled()) {
            log.debug("handleKexInit({}) SSH_MSG_KEXINIT", this);
        }
        receiveKexInit(buffer);
        if (kexState.compareAndSet(KexState.DONE, KexState.RUN)) {
            sendKexInit();
        } else if (!kexState.compareAndSet(KexState.INIT, KexState.RUN)) {
            throw new IllegalStateException("Received SSH_MSG_KEXINIT while key exchange is running");
        }

        Map<KexProposalOption, String> result = negotiate();
        String kexAlgorithm = result.get(KexProposalOption.ALGORITHMS);
        kex = ValidateUtils.checkNotNull(NamedFactory.create(getKeyExchangeFactories(), kexAlgorithm),
                "Unknown negotiated KEX algorithm: %s", kexAlgorithm);
        kex.init(this, serverVersion.getBytes(StandardCharsets.UTF_8), clientVersion.getBytes(StandardCharsets.UTF_8), i_s, i_c);
        // NOTE(review): fired right after kex.init() - i.e., when the exchange starts - despite the event name; confirm intended
        signalSessionEvent(SessionListener.Event.KexCompleted);
    }

    /**
     * Handles SSH_MSG_NEWKEYS (KEX state must be {@code KEYS}): calls {@code receiveNewKeys()},
     * completes any pending key exchange future with {@code TRUE}, flushes the packets
     * queued while the exchange was running, marks the state as {@code DONE} and wakes up
     * any threads waiting on the session lock.
     *
     * @param cmd the command - used for logging and state validation
     * @param buffer the message data
     * @throws Exception if activating the keys or flushing the queued packets fails
     */
    protected void handleNewKeys(int cmd, Buffer buffer) throws Exception {
        if (log.isDebugEnabled()) {
            log.debug("handleNewKeys({}) SSH_MSG_NEWKEYS command={}", this, SshConstants.getCommandMessageName(cmd));
        }
        validateKexState(cmd, KexState.KEYS);
        receiveNewKeys();

        DefaultKeyExchangeFuture kexFuture = kexFutureHolder.get();
        if (kexFuture != null) {
            synchronized (kexFuture) {
                Object value = kexFuture.getValue();
                if (value == null) {
                    kexFuture.setValue(Boolean.TRUE);
                }
            }
        }

        signalSessionEvent(SessionListener.Event.KeyEstablished);

        // DONE is set while still holding pendingPackets: writePacket re-checks the state
        // under the same monitor, so no packet can be queued after this flush
        synchronized (pendingPackets) {
            if (!pendingPackets.isEmpty()) {
                if (log.isDebugEnabled()) {
                    log.debug("handleNewKeys({}) Dequeing {} pending packets", this, pendingPackets.size());
                }
                synchronized (encodeLock) {
                    PendingWriteFuture future;
                    while ((future = pendingPackets.poll()) != null) {
                        doWritePacket(future.getBuffer()).addListener(future);
                    }
                }
            }
            kexState.set(KexState.DONE);
        }

        synchronized (lock) {
            lock.notifyAll();
        }
    }

    /**
     * Verifies that the current KEX state is the expected one.
     *
     * @param cmd the command being handled - used only in the error message
     * @param expected the required state
     * @throws IllegalStateException if the current state differs from {@code expected}
     */
    protected void validateKexState(int cmd, KexState expected) {
        KexState actual = kexState.get();
        if (!expected.equals(actual)) {
            throw new IllegalStateException("Received KEX command=" + SshConstants.getCommandMessageName(cmd)
                    + " while in state=" + actual + " instead of " + expected);
        }
    }

    /**
     * Handle any exceptions that occurred on this session.
     * Exceptions are ignored unless the session is open or closing gracefully. If the
     * session is open and the exception is an {@link SshException} with a positive
     * disconnect code, a disconnect packet is sent; otherwise the session is closed.
* * @param t the exception to process */ @Override public void exceptionCaught(Throwable t) { State curState = state.get(); // Ignore exceptions that happen while closing immediately if ((!State.Opened.equals(curState)) && (!State.Graceful.equals(curState))) { if (log.isDebugEnabled()) { log.debug("exceptionCaught({}) ignore {} due to state={}, message='{}'", this, t.getClass().getSimpleName(), curState, t.getMessage()); } if (log.isTraceEnabled()) { log.trace("exceptionCaught(" + this + ")[state=" + curState + "] ignored exception details", t); } return; } log.warn("exceptionCaught({})[state={}] {}: {}", this, curState, t.getClass().getSimpleName(), t.getMessage()); if (log.isDebugEnabled()) { log.debug("exceptionCaught(" + this + ")[state=" + curState + "] details", t); } signalExceptionCaught(t); if (State.Opened.equals(curState) && (t instanceof SshException)) { int code = ((SshException) t).getDisconnectCode(); if (code > 0) { try { disconnect(code, t.getMessage()); } catch (Throwable t2) { if (log.isDebugEnabled()) { log.debug("exceptionCaught({}) {} while disconnect with code={}: {}", this, t2.getClass().getSimpleName(), SshConstants.getDisconnectReasonName(code), t2.getMessage()); } if (log.isTraceEnabled()) { log.trace("exceptionCaught(" + this + ")[code=" + SshConstants.getDisconnectReasonName(code) + "] disconnect exception details", t2); } } return; } } close(true); } protected void signalExceptionCaught(Throwable t) { try { invokeSessionSignaller(l -> { signalExceptionCaught(l, t); return null; }); } catch (Throwable err) { Throwable e = GenericUtils.peelException(err); if (log.isDebugEnabled()) { log.debug("exceptionCaught(" + this + ") signal session exception details", e); } if (log.isTraceEnabled()) { Throwable[] suppressed = e.getSuppressed(); if (GenericUtils.length(suppressed) > 0) { for (Throwable s : suppressed) { log.trace("exceptionCaught(" + this + ") suppressed session exception signalling", s); } } } } } protected void 
signalExceptionCaught(SessionListener listener, Throwable t) { if (listener == null) { return; } listener.sessionException(this, t); } @Override protected Closeable getInnerCloseable() { return builder() .parallel(getServices()) .close(ioSession) .build(); } @Override protected void preClose() { DefaultKeyExchangeFuture kexFuture = kexFutureHolder.get(); if (kexFuture != null) { // if have any pending KEX then notify it about the closing session synchronized (kexFuture) { Object value = kexFuture.getValue(); if (value == null) { kexFuture.setValue(new SshException("Session closing while KEX in progress")); } } } // if anyone waiting for global response notify them about the closing session synchronized (requestResult) { requestResult.set(GenericUtils.NULL); requestResult.notify(); } // Fire 'close' event try { signalSessionClosed(); } finally { // clear the listeners since we are closing the session (quicker GC) this.sessionListeners.clear(); this.channelListeners.clear(); this.tunnelListeners.clear(); } super.preClose(); } protected void signalSessionClosed() { try { invokeSessionSignaller(l -> { signalSessionClosed(l); return null; }); } catch (Throwable err) { Throwable e = GenericUtils.peelException(err); log.warn("signalSessionClosed({}) {} while signal session closed: {}", this, e.getClass().getSimpleName(), e.getMessage()); if (log.isDebugEnabled()) { log.debug("signalSessionClosed(" + this + ") signal session closed exception details", e); } if (log.isTraceEnabled()) { Throwable[] suppressed = e.getSuppressed(); if (GenericUtils.length(suppressed) > 0) { for (Throwable s : suppressed) { log.trace("signalSessionClosed(" + this + ") suppressed session closed signalling", s); } } } } } protected void signalSessionClosed(SessionListener listener) { if (listener == null) { return; } listener.sessionClosed(this); } protected List<Service> getServices() { return (currentService != null) ? 
Collections.singletonList(currentService) : Collections.emptyList(); } @Override public <T extends Service> T getService(Class<T> clazz) { for (Service s : getServices()) { if (clazz.isInstance(s)) { return clazz.cast(s); } } throw new IllegalStateException("Attempted to access unknown service " + clazz.getSimpleName()); } @Override public IoWriteFuture writePacket(Buffer buffer) throws IOException { // While exchanging key, queue high level packets if (!KexState.DONE.equals(kexState.get())) { byte cmd = buffer.array()[buffer.rpos()]; if (cmd > SshConstants.SSH_MSG_KEX_LAST) { synchronized (pendingPackets) { if (!KexState.DONE.equals(kexState.get())) { if (pendingPackets.isEmpty()) { log.debug("writePacket({})[{}] Start flagging packets as pending until key exchange is done", this, SshConstants.getCommandMessageName(cmd & 0xFF)); } PendingWriteFuture future = new PendingWriteFuture(buffer); pendingPackets.add(future); return future; } } } } try { return doWritePacket(buffer); } finally { resetIdleTimeout(); checkRekey(); } } @SuppressWarnings("unchecked") @Override public IoWriteFuture writePacket(Buffer buffer, long timeout, TimeUnit unit) throws IOException { IoWriteFuture writeFuture = writePacket(buffer); DefaultSshFuture<IoWriteFuture> future = (DefaultSshFuture<IoWriteFuture>) writeFuture; ScheduledExecutorService executor = factoryManager.getScheduledExecutorService(); ScheduledFuture<?> sched = executor.schedule(() -> { Throwable t = new TimeoutException("Timeout writing packet: " + timeout + " " + unit); if (log.isDebugEnabled()) { log.debug("writePacket({}): {}", AbstractSession.this, t.getMessage()); } future.setValue(t); }, timeout, unit); future.addListener(future1 -> sched.cancel(false)); return writeFuture; } protected IoWriteFuture doWritePacket(Buffer buffer) throws IOException { Buffer ignoreBuf = null; int ignoreDataLen = resolveIgnoreBufferDataLength(); if (ignoreDataLen > 0) { ignoreBuf = createBuffer(SshConstants.SSH_MSG_IGNORE, ignoreDataLen 
+ Byte.SIZE); ignoreBuf.putInt(ignoreDataLen); int wpos = ignoreBuf.wpos(); synchronized (random) { random.fill(ignoreBuf.array(), wpos, ignoreDataLen); } ignoreBuf.wpos(wpos + ignoreDataLen); if (log.isDebugEnabled()) { log.debug("doWritePacket({}) append SSH_MSG_IGNORE message", this); } } int curPos = buffer.rpos(); byte[] data = buffer.array(); int cmd = data[curPos] & 0xFF; // usually the 1st byte is the command buffer = validateTargetBuffer(cmd, buffer); // Synchronize all write requests as needed by the encoding algorithm // and also queue the write request in this synchronized block to ensure // packets are sent in the correct order IoWriteFuture future; synchronized (encodeLock) { if (ignoreBuf != null) { ignoreBuf = encode(ignoreBuf); ioSession.write(ignoreBuf); } buffer = encode(buffer); future = ioSession.write(buffer); } return future; } protected int resolveIgnoreBufferDataLength() { if ((ignorePacketDataLength <= 0) || (ignorePacketsFrequency <= 0L) || (ignorePacketsVariance < 0)) { return 0; } long count = ignorePacketsCount.decrementAndGet(); if (count > 0L) { return 0; } synchronized (random) { ignorePacketsCount.set(calculateNextIgnorePacketCount(random, ignorePacketsFrequency, ignorePacketsVariance)); return ignorePacketDataLength + random.random(ignorePacketDataLength); } } protected long calculateNextIgnorePacketCount(Random r, long freq, int variance) { if ((freq <= 0L) || (variance < 0)) { return -1L; } if (variance == 0) { return freq; } int extra = r.random((variance < 0) ? (0 - variance) : variance); long count = (variance < 0) ? 
(freq - extra) : (freq + extra); if (log.isTraceEnabled()) { log.trace("calculateNextIgnorePacketCount({}) count={}", this, count); } return count; } @Override public Buffer request(String request, Buffer buffer, long timeout, TimeUnit unit) throws IOException { ValidateUtils.checkTrue(timeout > 0L, "Non-positive timeout requested: %d", timeout); long maxWaitMillis = TimeUnit.MILLISECONDS.convert(timeout, unit); if (maxWaitMillis <= 0L) { throw new IllegalArgumentException("Requested timeout for " + request + " below 1 msec: " + timeout + " " + unit); } if (log.isDebugEnabled()) { log.debug("request({}) request={}, timeout={} {}", this, request, timeout, unit); } Object result; synchronized (requestLock) { try { writePacket(buffer); synchronized (requestResult) { while (isOpen() && (maxWaitMillis > 0L) && (requestResult.get() == null)) { if (log.isTraceEnabled()) { log.trace("request({})[{}] remaining wait={}", this, request, maxWaitMillis); } long waitStart = System.nanoTime(); requestResult.wait(maxWaitMillis); long waitEnd = System.nanoTime(); long waitDuration = waitEnd - waitStart; long waitMillis = TimeUnit.NANOSECONDS.toMillis(waitDuration); if (waitMillis > 0L) { maxWaitMillis -= waitMillis; } else { maxWaitMillis--; } } result = requestResult.getAndSet(null); } } catch (InterruptedException e) { throw (InterruptedIOException) new InterruptedIOException("Interrupted while waiting for request=" + request + " result").initCause(e); } } if (!isOpen()) { throw new IOException("Session is closed or closing while awaiting reply for request=" + request); } if (log.isDebugEnabled()) { log.debug("request({}) request={}, timeout={} {}, result received={}", this, request, timeout, unit, result != null); } if (result == null) { throw new SocketTimeoutException("No response received after " + timeout + " " + unit + " for request=" + request); } if (result instanceof Buffer) { return (Buffer) result; } return null; } @Override public Buffer createBuffer(byte cmd) { 
return createBuffer(cmd, 0); } @Override public Buffer createBuffer(byte cmd, int len) { if (len <= 0) { return prepareBuffer(cmd, new ByteArrayBuffer()); } // Since the caller claims to know how many bytes they will need // increase their request to account for our headers/footers if // they actually send exactly this amount. // int bsize = outCipherSize; len += SshConstants.SSH_PACKET_HEADER_LEN; int pad = (-len) & (bsize - 1); if (pad < bsize) { pad += bsize; } len = len + pad - 4; if (outMac != null) { len += outMac.getBlockSize(); } return prepareBuffer(cmd, new ByteArrayBuffer(new byte[len + Byte.SIZE], false)); } @Override public Buffer prepareBuffer(byte cmd, Buffer buffer) { buffer = validateTargetBuffer(cmd & 0xFF, buffer); buffer.rpos(SshConstants.SSH_PACKET_HEADER_LEN); buffer.wpos(SshConstants.SSH_PACKET_HEADER_LEN); buffer.putByte(cmd); return buffer; } /** * Makes sure that the buffer used for output is not {@code null} or one * of the session's internal ones used for decoding and uncompressing * * @param <B> The {@link Buffer} type being validated * @param cmd The most likely command this buffer refers to (not guaranteed to be correct) * @param buffer The buffer to be examined * @return The validated target instance - default same as input * @throws IllegalArgumentException if any of the conditions is violated */ protected <B extends Buffer> B validateTargetBuffer(int cmd, B buffer) { ValidateUtils.checkNotNull(buffer, "No target buffer to examine for command=%d", cmd); ValidateUtils.checkTrue(buffer != decoderBuffer, "Not allowed to use the internal decoder buffer for command=%d", cmd); ValidateUtils.checkTrue(buffer != uncompressBuffer, "Not allowed to use the internal uncompress buffer for command=%d", cmd); return buffer; } /** * Encode a buffer into the SSH protocol. 
* This method need to be called into a synchronized block around encodeLock * * @param buffer the buffer to encode * @return The encoded buffer - may be different than original if input * buffer does not have enough room for {@link SshConstants#SSH_PACKET_HEADER_LEN}, * in which a substitute buffer will be created and used. * @throws IOException if an exception occurs during the encoding process */ protected Buffer encode(Buffer buffer) throws IOException { try { // Check that the packet has some free space for the header int curPos = buffer.rpos(); if (curPos < SshConstants.SSH_PACKET_HEADER_LEN) { byte[] data = buffer.array(); int cmd = data[curPos] & 0xFF; // usually the 1st byte is an SSH opcode log.warn("encode({}) command={} performance cost: available buffer packet header length ({}) below min. required ({})", this, SshConstants.getCommandMessageName(cmd), curPos, SshConstants.SSH_PACKET_HEADER_LEN); Buffer nb = new ByteArrayBuffer(buffer.available() + Long.SIZE, false); nb.wpos(SshConstants.SSH_PACKET_HEADER_LEN); nb.putBuffer(buffer); buffer = nb; curPos = buffer.rpos(); } // Grab the length of the packet (excluding the 5 header bytes) int len = buffer.available(); int off = curPos - SshConstants.SSH_PACKET_HEADER_LEN; // Debug log the packet if (log.isTraceEnabled()) { buffer.dumpHex(getSimplifiedLogger(), "encode(" + this + ") packet #" + seqo, this); } // Compress the packet if needed if ((outCompression != null) && outCompression.isCompressionExecuted() && (authed || (!outCompression.isDelayed()))) { outCompression.compress(buffer); len = buffer.available(); } // Compute padding length int bsize = outCipherSize; int oldLen = len; len += SshConstants.SSH_PACKET_HEADER_LEN; int pad = (-len) & (bsize - 1); if (pad < bsize) { pad += bsize; } len = len + pad - 4; // Write 5 header bytes buffer.wpos(off); buffer.putInt(len); buffer.putByte((byte) pad); // Fill padding buffer.wpos(off + oldLen + SshConstants.SSH_PACKET_HEADER_LEN + pad); synchronized (random) 
{ random.fill(buffer.array(), buffer.wpos() - pad, pad); } // Compute mac if (outMac != null) { int macSize = outMac.getBlockSize(); int l = buffer.wpos(); buffer.wpos(l + macSize); outMac.updateUInt(seqo); outMac.update(buffer.array(), off, l); outMac.doFinal(buffer.array(), l); } // Encrypt packet, excluding mac if (outCipher != null) { outCipher.update(buffer.array(), off, len + 4); int blocksCount = (len + 4) / outCipher.getBlockSize(); outBlocksCount.addAndGet(Math.max(1, blocksCount)); } // Increment packet id seqo = (seqo + 1) & 0xffffffffL; // Update stats outPacketsCount.incrementAndGet(); outBytesCount.addAndGet(len); // Make buffer ready to be read buffer.rpos(off); return buffer; } catch (IOException e) { throw e; } catch (Exception e) { throw new SshException(e); } } /** * Decode the incoming buffer and handle packets as needed. * * @throws Exception If failed to decode */ protected void decode() throws Exception { // Decoding loop for (;;) { // Wait for beginning of packet if (decoderState == 0) { // The read position should always be 0 at this point because we have compacted this buffer assert decoderBuffer.rpos() == 0; // If we have received enough bytes, start processing those if (decoderBuffer.available() > inCipherSize) { // Decrypt the first bytes if (inCipher != null) { inCipher.update(decoderBuffer.array(), 0, inCipherSize); int blocksCount = inCipherSize / inCipher.getBlockSize(); inBlocksCount.addAndGet(Math.max(1, blocksCount)); } // Read packet length decoderLength = decoderBuffer.getInt(); // Check packet length validity if ((decoderLength < SshConstants.SSH_PACKET_HEADER_LEN) || (decoderLength > (256 * 1024))) { log.warn("decode({}) Error decoding packet(invalid length): {}", this, decoderLength); decoderBuffer.dumpHex(getSimplifiedLogger(), "decode(" + this + ") invalid length packet", this); throw new SshException(SshConstants.SSH2_DISCONNECT_PROTOCOL_ERROR, "Invalid packet length: " + decoderLength); } // Ok, that's good, we can go to 
the next step decoderState = 1; } else { // need more data break; } // We have received the beginning of the packet } else if (decoderState == 1) { // The read position should always be 4 at this point assert decoderBuffer.rpos() == 4; int macSize = inMac != null ? inMac.getBlockSize() : 0; // Check if the packet has been fully received if (decoderBuffer.available() >= (decoderLength + macSize)) { byte[] data = decoderBuffer.array(); // Decrypt the remaining of the packet if (inCipher != null) { int updateLen = decoderLength + 4 - inCipherSize; inCipher.update(data, inCipherSize, updateLen); int blocksCount = updateLen / inCipher.getBlockSize(); inBlocksCount.addAndGet(Math.max(1, blocksCount)); } // Check the mac of the packet if (inMac != null) { // Update mac with packet id inMac.updateUInt(seqi); // Update mac with packet data inMac.update(data, 0, decoderLength + 4); // Compute mac result inMac.doFinal(inMacResult, 0); // Check the computed result with the received mac (just after the packet data) if (!BufferUtils.equals(inMacResult, 0, data, decoderLength + 4, macSize)) { throw new SshException(SshConstants.SSH2_DISCONNECT_MAC_ERROR, "MAC Error"); } } // Increment incoming packet sequence number seqi = (seqi + 1) & 0xffffffffL; // Get padding int pad = decoderBuffer.getUByte(); Buffer packet; int wpos = decoderBuffer.wpos(); // Decompress if needed if ((inCompression != null) && inCompression.isCompressionExecuted() && (authed || (!inCompression.isDelayed()))) { if (uncompressBuffer == null) { uncompressBuffer = new SessionWorkBuffer(this); } else { uncompressBuffer.forceClear(true); } decoderBuffer.wpos(decoderBuffer.rpos() + decoderLength - 1 - pad); inCompression.uncompress(decoderBuffer, uncompressBuffer); packet = uncompressBuffer; } else { decoderBuffer.wpos(decoderLength + 4 - pad); packet = decoderBuffer; } if (log.isTraceEnabled()) { packet.dumpHex(getSimplifiedLogger(), "decode(" + this + ") packet #" + seqi, this); } // Update stats 
inPacketsCount.incrementAndGet(); inBytesCount.addAndGet(packet.available()); // Process decoded packet handleMessage(packet); // Set ready to handle next packet decoderBuffer.rpos(decoderLength + 4 + macSize); decoderBuffer.wpos(wpos); decoderBuffer.compact(); decoderState = 0; } else { // need more data break; } } } } /** * Resolves the identification to send to the peer session by consulting * the associated {@link FactoryManager}. If a value is set, then it is * <U>appended</U> to the standard {@link #DEFAULT_SSH_VERSION_PREFIX}. * Otherwise a default value is returned consisting of the prefix and * the core artifact name + version in <U>uppercase</U> - e.g.,' * "SSH-2.0-SSHD-CORE-1.2.3.4" * * @param configPropName The property used to query the factory manager * @return The resolved identification value */ protected String resolveIdentificationString(String configPropName) { FactoryManager manager = getFactoryManager(); String ident = manager.getString(configPropName); return DEFAULT_SSH_VERSION_PREFIX + (GenericUtils.isEmpty(ident) ? manager.getVersion() : ident); } /** * Send our identification. * * @param ident our identification to send * @return {@link IoWriteFuture} that can be used to wait for notification * that identification has been send */ protected IoWriteFuture sendIdentification(String ident) { byte[] data = (ident + "\r\n").getBytes(StandardCharsets.UTF_8); if (log.isDebugEnabled()) { log.debug("sendIdentification({}): {}", this, ident.replace('\r', '|').replace('\n', '|')); } return ioSession.write(new ByteArrayBuffer(data)); } /** * Read the other side identification. * This method is specific to the client or server side, but both should call * {@link #doReadIdentification(Buffer, boolean)} and * store the result in the needed property. 
* * @param buffer The {@link Buffer} containing the remote identification * @return <code>true</code> if the identification has been fully read or * <code>false</code> if more data is needed * @throws IOException if an error occurs such as a bad protocol version */ protected abstract boolean readIdentification(Buffer buffer) throws IOException; /** * Read the remote identification from this buffer. * If more data is needed, the buffer will be reset to its original state * and a {@code null} value will be returned. Else the identification * string will be returned and the data read will be consumed from the buffer. * * @param buffer the buffer containing the identification string * @param server {@code true} if it is called by the server session, * {@code false} if by the client session * @return A {@link List} of all received remote identification lines until * the version line was read or {@code null} if more data is needed. * The identification line is the <U>last</U> one in the list */ protected List<String> doReadIdentification(Buffer buffer, boolean server) { int maxIdentSize = PropertyResolverUtils.getIntProperty(this, FactoryManager.MAX_IDENTIFICATION_SIZE, FactoryManager.DEFAULT_MAX_IDENTIFICATION_SIZE); List<String> ident = null; int rpos = buffer.rpos(); for (byte[] data = new byte[MAX_VERSION_LINE_LENGTH];;) { int pos = 0; // start accumulating line from scratch for (boolean needLf = false;;) { if (buffer.available() == 0) { // Need more data, so undo reading and return null buffer.rpos(rpos); return null; } byte b = buffer.getByte(); /* * According to RFC 4253 section 4.2: * * "The null character MUST NOT be sent" */ if (b == 0) { throw new IllegalStateException("Incorrect identification (null characters not allowed) - " + " at line " + (GenericUtils.size(ident) + 1) + " character #" + (pos + 1) + " after '" + new String(data, 0, pos, StandardCharsets.UTF_8) + "'"); } if (b == '\r') { needLf = true; continue; } if (b == '\n') { break; } if (needLf) { 
throw new IllegalStateException("Incorrect identification (bad line ending) " + " at line " + (GenericUtils.size(ident) + 1) + ": " + new String(data, 0, pos, StandardCharsets.UTF_8)); } if (pos >= data.length) { throw new IllegalStateException("Incorrect identification (line too long): " + " at line " + (GenericUtils.size(ident) + 1) + ": " + new String(data, 0, pos, StandardCharsets.UTF_8)); } data[pos++] = b; } String str = new String(data, 0, pos, StandardCharsets.UTF_8); if (log.isDebugEnabled()) { log.debug("doReadIdentification({}) line='{}'", this, str); } if (ident == null) { ident = new ArrayList<>(); } ident.add(str); // if this is a server then only one line is expected from the client if (server || str.startsWith("SSH-")) { return ident; } if (buffer.rpos() > maxIdentSize) { throw new IllegalStateException("Incorrect identification (too many header lines): size > " + maxIdentSize); } } } /** * Create our proposal for SSH negotiation * * @param hostKeyTypes The comma-separated list of supported host key types * @return The proposal {@link Map} */ protected Map<KexProposalOption, String> createProposal(String hostKeyTypes) { Map<KexProposalOption, String> proposal = new EnumMap<>(KexProposalOption.class); proposal.put(KexProposalOption.ALGORITHMS, NamedResource.getNames( ValidateUtils.checkNotNullAndNotEmpty(getKeyExchangeFactories(), "No KEX factories"))); proposal.put(KexProposalOption.SERVERKEYS, hostKeyTypes); String ciphers = NamedResource.getNames( ValidateUtils.checkNotNullAndNotEmpty(getCipherFactories(), "No cipher factories")); proposal.put(KexProposalOption.S2CENC, ciphers); proposal.put(KexProposalOption.C2SENC, ciphers); String macs = NamedResource.getNames( ValidateUtils.checkNotNullAndNotEmpty(getMacFactories(), "No MAC factories")); proposal.put(KexProposalOption.S2CMAC, macs); proposal.put(KexProposalOption.C2SMAC, macs); String compressions = NamedResource.getNames( ValidateUtils.checkNotNullAndNotEmpty(getCompressionFactories(), "No 
compression factories")); proposal.put(KexProposalOption.S2CCOMP, compressions); proposal.put(KexProposalOption.C2SCOMP, compressions); proposal.put(KexProposalOption.S2CLANG, ""); // TODO allow configuration proposal.put(KexProposalOption.C2SLANG, ""); // TODO allow configuration return proposal; } /** * Send the key exchange initialization packet. * This packet contains random data along with our proposal. * * @param proposal our proposal for key exchange negotiation * @return the sent packet data which must be kept for later use * when deriving the session keys * @throws IOException if an error occurred sending the packet */ protected byte[] sendKexInit(Map<KexProposalOption, String> proposal) throws IOException { if (log.isDebugEnabled()) { log.debug("sendKexInit({}) Send SSH_MSG_KEXINIT", this); } Buffer buffer = createBuffer(SshConstants.SSH_MSG_KEXINIT); int p = buffer.wpos(); buffer.wpos(p + SshConstants.MSG_KEX_COOKIE_SIZE); synchronized (random) { random.fill(buffer.array(), p, SshConstants.MSG_KEX_COOKIE_SIZE); } if (log.isTraceEnabled()) { log.trace("sendKexInit({}) cookie={}", this, BufferUtils.toHex(buffer.array(), p, SshConstants.MSG_KEX_COOKIE_SIZE, ':')); } for (KexProposalOption paramType : KexProposalOption.VALUES) { String s = proposal.get(paramType); if (log.isTraceEnabled()) { log.trace("sendKexInit({})[{}] {}", this, paramType.getDescription(), s); } buffer.putString(GenericUtils.trimToEmpty(s)); } buffer.putBoolean(false); // first kex packet follows buffer.putInt(0); // reserved (FFU) byte[] data = buffer.getCompactData(); writePacket(buffer); return data; } /** * Receive the remote key exchange init message. * The packet data is returned for later use. 
* * @param buffer the {@link Buffer} containing the key exchange init packet * @param proposal the remote proposal to fill * @return the packet data */ protected byte[] receiveKexInit(Buffer buffer, Map<KexProposalOption, String> proposal) { // Recreate the packet payload which will be needed at a later time byte[] d = buffer.array(); byte[] data = new byte[buffer.available() + 1 /* the opcode */]; data[0] = SshConstants.SSH_MSG_KEXINIT; int size = 6; int cookieStartPos = buffer.rpos(); System.arraycopy(d, cookieStartPos, data, 1, data.length - 1); // Skip random cookie data buffer.rpos(cookieStartPos + SshConstants.MSG_KEX_COOKIE_SIZE); size += SshConstants.MSG_KEX_COOKIE_SIZE; if (log.isTraceEnabled()) { log.trace("receiveKexInit({}) cookie={}", this, BufferUtils.toHex(d, cookieStartPos, SshConstants.MSG_KEX_COOKIE_SIZE, ':')); } // Read proposal for (KexProposalOption paramType : KexProposalOption.VALUES) { int lastPos = buffer.rpos(); String value = buffer.getString(); if (log.isTraceEnabled()) { log.trace("receiveKexInit({})[{}] {}", this, paramType.getDescription(), value); } int curPos = buffer.rpos(); int readLen = curPos - lastPos; proposal.put(paramType, value); size += readLen; } firstKexPacketFollows = buffer.getBoolean(); if (log.isTraceEnabled()) { log.trace("receiveKexInit({}) first kex packet follows: {}", this, firstKexPacketFollows); } long reserved = buffer.getUInt(); if (reserved != 0) { if (log.isTraceEnabled()) { log.trace("receiveKexInit({}) non-zero reserved value: {}", this, reserved); } } // Return data byte[] dataShrinked = new byte[size]; System.arraycopy(data, 0, dataShrinked, 0, size); return dataShrinked; } /** * Send a message to put new keys into use. 
* * @return An {@link IoWriteFuture} that can be used to wait and * check the result of sending the packet * @throws IOException if an error occurs sending the message */ protected IoWriteFuture sendNewKeys() throws IOException { if (log.isDebugEnabled()) { log.debug("sendNewKeys({}) Send SSH_MSG_NEWKEYS", this); } Buffer buffer = createBuffer(SshConstants.SSH_MSG_NEWKEYS, Byte.SIZE); return writePacket(buffer); } /** * Put new keys into use. * This method will initialize the ciphers, digests, macs and compression * according to the negotiated server and client proposals. * * @throws Exception if an error occurs */ protected void receiveNewKeys() throws Exception { byte[] k = kex.getK(); byte[] h = kex.getH(); Digest hash = kex.getHash(); if (sessionId == null) { sessionId = h.clone(); if (log.isDebugEnabled()) { log.debug("receiveNewKeys({}) session ID={}", this, BufferUtils.toHex(':', sessionId)); } } Buffer buffer = new ByteArrayBuffer(); buffer.putMPInt(k); buffer.putRawBytes(h); buffer.putByte((byte) 0x41); buffer.putRawBytes(sessionId); int pos = buffer.available(); byte[] buf = buffer.array(); hash.update(buf, 0, pos); byte[] iv_c2s = hash.digest(); int j = pos - sessionId.length - 1; buf[j]++; hash.update(buf, 0, pos); byte[] iv_s2c = hash.digest(); buf[j]++; hash.update(buf, 0, pos); byte[] e_c2s = hash.digest(); buf[j]++; hash.update(buf, 0, pos); byte[] e_s2c = hash.digest(); buf[j]++; hash.update(buf, 0, pos); byte[] mac_c2s = hash.digest(); buf[j]++; hash.update(buf, 0, pos); byte[] mac_s2c = hash.digest(); String value = getNegotiatedKexParameter(KexProposalOption.S2CENC); Cipher s2ccipher = ValidateUtils.checkNotNull(NamedFactory.create(getCipherFactories(), value), "Unknown s2c cipher: %s", value); e_s2c = resizeKey(e_s2c, s2ccipher.getBlockSize(), hash, k, h); s2ccipher.init(isServer ? 
Cipher.Mode.Encrypt : Cipher.Mode.Decrypt, e_s2c, iv_s2c); value = getNegotiatedKexParameter(KexProposalOption.S2CMAC); Mac s2cmac = NamedFactory.create(getMacFactories(), value); if (s2cmac == null) { throw new SshException(SshConstants.SSH2_DISCONNECT_MAC_ERROR, "Unknown s2c MAC: " + value); } mac_s2c = resizeKey(mac_s2c, s2cmac.getBlockSize(), hash, k, h); s2cmac.init(mac_s2c); value = getNegotiatedKexParameter(KexProposalOption.S2CCOMP); Compression s2ccomp = NamedFactory.create(getCompressionFactories(), value); if (s2ccomp == null) { throw new SshException(SshConstants.SSH2_DISCONNECT_COMPRESSION_ERROR, "Unknown s2c compression: " + value); } value = getNegotiatedKexParameter(KexProposalOption.C2SENC); Cipher c2scipher = ValidateUtils.checkNotNull(NamedFactory.create(getCipherFactories(), value), "Unknown c2s cipher: %s", value); e_c2s = resizeKey(e_c2s, c2scipher.getBlockSize(), hash, k, h); c2scipher.init(isServer ? Cipher.Mode.Decrypt : Cipher.Mode.Encrypt, e_c2s, iv_c2s); value = getNegotiatedKexParameter(KexProposalOption.C2SMAC); Mac c2smac = NamedFactory.create(getMacFactories(), value); if (c2smac == null) { throw new SshException(SshConstants.SSH2_DISCONNECT_MAC_ERROR, "Unknown c2s MAC: " + value); } mac_c2s = resizeKey(mac_c2s, c2smac.getBlockSize(), hash, k, h); c2smac.init(mac_c2s); value = getNegotiatedKexParameter(KexProposalOption.C2SCOMP); Compression c2scomp = NamedFactory.create(getCompressionFactories(), value); if (c2scomp == null) { throw new SshException(SshConstants.SSH2_DISCONNECT_COMPRESSION_ERROR, "Unknown c2s compression: " + value); } if (isServer) { outCipher = s2ccipher; outMac = s2cmac; outCompression = s2ccomp; inCipher = c2scipher; inMac = c2smac; inCompression = c2scomp; } else { outCipher = c2scipher; outMac = c2smac; outCompression = c2scomp; inCipher = s2ccipher; inMac = s2cmac; inCompression = s2ccomp; } outCipherSize = outCipher.getIVSize(); // TODO add support for configurable compression level 
outCompression.init(Compression.Type.Deflater, -1); inCipherSize = inCipher.getIVSize(); inMacResult = new byte[inMac.getBlockSize()]; // TODO add support for configurable compression level inCompression.init(Compression.Type.Inflater, -1); // see https://tools.ietf.org/html/rfc4344#section-3.2 int inBlockSize = inCipher.getBlockSize(); int outBlockSize = outCipher.getBlockSize(); // select the lowest cipher size int avgCipherBlockSize = Math.min(inBlockSize, outBlockSize); long recommendedByteRekeyBlocks = 1L << Math.min((avgCipherBlockSize * Byte.SIZE) / 4, 63); // in case (block-size / 4) > 63 maxRekeyBlocks.set(this.getLongProperty(FactoryManager.REKEY_BLOCKS_LIMIT, recommendedByteRekeyBlocks)); if (log.isDebugEnabled()) { log.debug("receiveNewKeys({}) inCipher={}, outCipher={}, recommended blocks limit={}, actual={}", this, inCipher, outCipher, recommendedByteRekeyBlocks, maxRekeyBlocks); } inBytesCount.set(0L); outBytesCount.set(0L); inPacketsCount.set(0L); outPacketsCount.set(0L); inBlocksCount.set(0L); outBlocksCount.set(0L); lastKeyTimeValue.set(System.currentTimeMillis()); firstKexPacketFollows = null; } /** * Method used while putting new keys into use that will resize the key used to * initialize the cipher to the needed length. 
* * @param e the key to resize * @param blockSize the cipher block size (in bytes) * @param hash the hash algorithm * @param k the key exchange k parameter * @param h the key exchange h parameter * @return the resized key * @throws Exception if a problem occur while resizing the key */ protected byte[] resizeKey(byte[] e, int blockSize, Digest hash, byte[] k, byte[] h) throws Exception { for (Buffer buffer = null; blockSize > e.length; buffer = BufferUtils.clear(buffer)) { if (buffer == null) { buffer = new ByteArrayBuffer(); } buffer.putMPInt(k); buffer.putRawBytes(h); buffer.putRawBytes(e); hash.update(buffer.array(), 0, buffer.available()); byte[] foo = hash.digest(); byte[] bar = new byte[e.length + foo.length]; System.arraycopy(e, 0, bar, 0, e.length); System.arraycopy(foo, 0, bar, e.length, foo.length); e = bar; } return e; } @Override public void disconnect(final int reason, final String msg) throws IOException { log.info("Disconnecting({}): {} - {}", this, SshConstants.getDisconnectReasonName(reason), msg); Buffer buffer = createBuffer(SshConstants.SSH_MSG_DISCONNECT, msg.length() + Short.SIZE); buffer.putInt(reason); buffer.putString(msg); buffer.putString(""); // TODO configure language... // Write the packet with a timeout to ensure a timely close of the session // in case the consumer does not read packets anymore. 
long disconnectTimeoutMs = this.getLongProperty(FactoryManager.DISCONNECT_TIMEOUT, FactoryManager.DEFAULT_DISCONNECT_TIMEOUT); writePacket(buffer, disconnectTimeoutMs, TimeUnit.MILLISECONDS).addListener(future -> { Throwable t = future.getException(); if (log.isDebugEnabled()) { if (t == null) { log.debug("disconnect({}) operation successfully completed for reason={} [{}]", AbstractSession.this, SshConstants.getDisconnectReasonName(reason), msg); } else { log.debug("disconnect({}) operation failed ({}) for reason={} [{}]: {}", AbstractSession.this, t.getClass().getSimpleName(), SshConstants.getDisconnectReasonName(reason), msg, t.getMessage()); } } if (t != null) { if (log.isTraceEnabled()) { log.trace("disconnect(" + AbstractSession.this + ") reason=" + SshConstants.getDisconnectReasonName(reason) + " failure details", t); } } close(true); }); } /** * Send a {@code SSH_MSG_UNIMPLEMENTED} packet. This packet should * contain the sequence id of the unsupported packet: this number * is assumed to be the last packet received. * * @return An {@link IoWriteFuture} that can be used to wait for packet write completion * @throws IOException if an error occurred sending the packet * @see #sendNotImplemented(long) */ protected IoWriteFuture notImplemented() throws IOException { return sendNotImplemented(seqi - 1); } /** * Sends a {@code SSH_MSG_UNIMPLEMENTED} message * * @param seqNoValue The referenced sequence number * @return An {@link IoWriteFuture} that can be used to wait for packet write completion * @throws IOException if an error occurred sending the packet */ protected IoWriteFuture sendNotImplemented(long seqNoValue) throws IOException { Buffer buffer = createBuffer(SshConstants.SSH_MSG_UNIMPLEMENTED, Byte.SIZE); buffer.putInt(seqNoValue); return writePacket(buffer); } /** * Compute the negotiated proposals by merging the client and * server proposal. The negotiated proposal will also be stored in * the {@link #negotiationResult} property. 
* * @return The negotiated options {@link Map} */ protected Map<KexProposalOption, String> negotiate() { Map<KexProposalOption, String> c2sOptions = Collections.unmodifiableMap(clientProposal); Map<KexProposalOption, String> s2cOptions = Collections.unmodifiableMap(serverProposal); signalNegotiationStart(c2sOptions, s2cOptions); Map<KexProposalOption, String> guess = new EnumMap<>(KexProposalOption.class); Map<KexProposalOption, String> negotiatedGuess = Collections.unmodifiableMap(guess); try { for (KexProposalOption paramType : KexProposalOption.VALUES) { String clientParamValue = c2sOptions.get(paramType); String serverParamValue = s2cOptions.get(paramType); String[] c = GenericUtils.split(clientParamValue, ','); String[] s = GenericUtils.split(serverParamValue, ','); for (String ci : c) { for (String si : s) { if (ci.equals(si)) { guess.put(paramType, ci); break; } } String value = guess.get(paramType); if (value != null) { break; } } // check if reached an agreement String value = guess.get(paramType); if (value == null) { String message = "Unable to negotiate key exchange for " + paramType.getDescription() + " (client: " + clientParamValue + " / server: " + serverParamValue + ")"; // OK if could not negotiate languages if (KexProposalOption.S2CLANG.equals(paramType) || KexProposalOption.C2SLANG.equals(paramType)) { if (log.isTraceEnabled()) { log.trace("negotiate({}) {}", this, message); } } else { throw new IllegalStateException(message); } } else { if (log.isTraceEnabled()) { log.trace("negotiate(" + this + ")[" + paramType.getDescription() + "] guess=" + value + " (client: " + clientParamValue + " / server: " + serverParamValue + ")"); } } } } catch (RuntimeException | Error e) { signalNegotiationEnd(c2sOptions, s2cOptions, negotiatedGuess, e); throw e; } signalNegotiationEnd(c2sOptions, s2cOptions, negotiatedGuess, null); return setNegotiationResult(guess); } protected void signalNegotiationStart(Map<KexProposalOption, String> c2sOptions, 
Map<KexProposalOption, String> s2cOptions) { try { invokeSessionSignaller(l -> { signalNegotiationStart(l, c2sOptions, s2cOptions); return null; }); } catch (Throwable err) { if (err instanceof RuntimeException) { throw (RuntimeException) err; } else if (err instanceof Error) { throw (Error) err; } else { throw new RuntimeException(err); } } } protected void signalNegotiationStart( SessionListener listener, Map<KexProposalOption, String> c2sOptions, Map<KexProposalOption, String> s2cOptions) { if (listener == null) { return; } listener.sessionNegotiationStart(this, c2sOptions, s2cOptions); } protected void signalNegotiationEnd( Map<KexProposalOption, String> c2sOptions, Map<KexProposalOption, String> s2cOptions, Map<KexProposalOption, String> negotiatedGuess, Throwable reason) { try { invokeSessionSignaller(l -> { signalNegotiationEnd(l, c2sOptions, s2cOptions, negotiatedGuess, reason); return null; }); } catch (Throwable err) { if (err instanceof RuntimeException) { throw (RuntimeException) err; } else if (err instanceof Error) { throw (Error) err; } else { throw new RuntimeException(err); } } } protected void signalNegotiationEnd(SessionListener listener, Map<KexProposalOption, String> c2sOptions, Map<KexProposalOption, String> s2cOptions, Map<KexProposalOption, String> negotiatedGuess, Throwable reason) { if (listener == null) { return; } listener.sessionNegotiationEnd(this, c2sOptions, s2cOptions, negotiatedGuess, null); } protected Map<KexProposalOption, String> setNegotiationResult(Map<KexProposalOption, String> guess) { synchronized (negotiationResult) { if (!negotiationResult.isEmpty()) { negotiationResult.clear(); // debug breakpoint } negotiationResult.putAll(guess); } if (log.isDebugEnabled()) { log.debug("setNegotiationResult({}) Kex: server->client {} {} {}", this, guess.get(KexProposalOption.S2CENC), guess.get(KexProposalOption.S2CMAC), guess.get(KexProposalOption.S2CCOMP)); log.debug("setNegotiationResult({}) Kex: client->server {} {} {}", this, 
guess.get(KexProposalOption.C2SENC), guess.get(KexProposalOption.C2SMAC), guess.get(KexProposalOption.C2SCOMP)); } return guess; } /** * Indicates the reception of a {@code SSH_MSG_REQUEST_SUCCESS} message * * @param buffer The {@link Buffer} containing the message data * @throws Exception If failed to handle the message */ protected void requestSuccess(Buffer buffer) throws Exception { // use a copy of the original data in case it is re-used on return Buffer resultBuf = ByteArrayBuffer.getCompactClone(buffer.array(), buffer.rpos(), buffer.available()); synchronized (requestResult) { requestResult.set(resultBuf); resetIdleTimeout(); requestResult.notify(); } } /** * Indicates the reception of a {@code SSH_MSG_REQUEST_FAILURE} message * * @param buffer The {@link Buffer} containing the message data * @throws Exception If failed to handle the message */ protected void requestFailure(Buffer buffer) throws Exception { synchronized (requestResult) { requestResult.set(GenericUtils.NULL); resetIdleTimeout(); requestResult.notify(); } } @Override @SuppressWarnings("unchecked") public <T> T getAttribute(AttributeKey<T> key) { return (T) attributes.get(Objects.requireNonNull(key, "No key")); } @Override @SuppressWarnings("unchecked") public <T> T setAttribute(AttributeKey<T> key, T value) { return (T) attributes.put( Objects.requireNonNull(key, "No key"), Objects.requireNonNull(value, "No value")); } @Override @SuppressWarnings("unchecked") public <T> T removeAttribute(AttributeKey<T> key) { return (T) attributes.remove(Objects.requireNonNull(key, "No key")); } @Override public <T> T resolveAttribute(AttributeKey<T> key) { return AttributeStore.resolveAttribute(this, key); } @Override public String getUsername() { return username; } @Override public void setUsername(String username) { this.username = username; } public Object getLock() { return lock; } @Override public ReservedSessionMessagesHandler getReservedSessionMessagesHandler() { return 
resolveEffectiveProvider(ReservedSessionMessagesHandler.class, reservedSessionMessagesHandler, getFactoryManager().getReservedSessionMessagesHandler()); } @Override public void setReservedSessionMessagesHandler(ReservedSessionMessagesHandler handler) { reservedSessionMessagesHandler = handler; } @Override public void addSessionListener(SessionListener listener) { SessionListener.validateListener(listener); // avoid race conditions on notifications while session is being closed if (!isOpen()) { log.warn("addSessionListener({})[{}] ignore registration while session is closing", this, listener); return; } if (this.sessionListeners.add(listener)) { if (log.isTraceEnabled()) { log.trace("addSessionListener({})[{}] registered", this, listener); } } else { if (log.isTraceEnabled()) { log.trace("addSessionListener({})[{}] ignored duplicate", this, listener); } } } @Override public void removeSessionListener(SessionListener listener) { if (listener == null) { return; } SessionListener.validateListener(listener); if (this.sessionListeners.remove(listener)) { if (log.isTraceEnabled()) { log.trace("removeSessionListener({})[{}] removed", this, listener); } } else { if (log.isTraceEnabled()) { log.trace("removeSessionListener({})[{}] not registered", this, listener); } } } @Override public SessionListener getSessionListenerProxy() { return sessionListenerProxy; } @Override public void addChannelListener(ChannelListener listener) { ChannelListener.validateListener(listener); // avoid race conditions on notifications while session is being closed if (!isOpen()) { log.warn("addChannelListener({})[{}] ignore registration while session is closing", this, listener); return; } if (this.channelListeners.add(listener)) { if (log.isTraceEnabled()) { log.trace("addChannelListener({})[{}] registered", this, listener); } } else { if (log.isTraceEnabled()) { log.trace("addChannelListener({})[{}] ignored duplicate", this, listener); } } } @Override public void 
removeChannelListener(ChannelListener listener) { if (listener == null) { return; } ChannelListener.validateListener(listener); if (this.channelListeners.remove(listener)) { if (log.isTraceEnabled()) { log.trace("removeChannelListener({})[{}] removed", this, listener); } } else { if (log.isTraceEnabled()) { log.trace("removeChannelListener({})[{}] not registered", this, listener); } } } @Override public ChannelListener getChannelListenerProxy() { return channelListenerProxy; } @Override public PortForwardingEventListener getPortForwardingEventListenerProxy() { return tunnelListenerProxy; } @Override public void addPortForwardingEventListener(PortForwardingEventListener listener) { PortForwardingEventListener.validateListener(listener); // avoid race conditions on notifications while session is being closed if (!isOpen()) { log.warn("addPortForwardingEventListener({})[{}] ignore registration while session is closing", this, listener); return; } if (this.tunnelListeners.add(listener)) { if (log.isTraceEnabled()) { log.trace("addPortForwardingEventListener({})[{}] registered", this, listener); } } else { if (log.isTraceEnabled()) { log.trace("addPortForwardingEventListener({})[{}] ignored duplicate", this, listener); } } } @Override public void removePortForwardingEventListener(PortForwardingEventListener listener) { if (listener == null) { return; } PortForwardingEventListener.validateListener(listener); if (this.tunnelListeners.remove(listener)) { if (log.isTraceEnabled()) { log.trace("removePortForwardingEventListener({})[{}] removed", this, listener); } } else { if (log.isTraceEnabled()) { log.trace("removePortForwardingEventListener({})[{}] not registered", this, listener); } } } /** * Sends a session event to all currently registered session listeners * * @param event The event to send * @throws IOException If any of the registered listeners threw an exception. 
*/ protected void signalSessionEvent(SessionListener.Event event) throws IOException { try { invokeSessionSignaller(l -> { signalSessionEvent(l, event); return null; }); } catch (Throwable err) { Throwable t = GenericUtils.peelException(err); if (log.isDebugEnabled()) { log.debug("sendSessionEvent({})[{}] failed ({}) to inform listeners: {}", this, event, t.getClass().getSimpleName(), t.getMessage()); } if (log.isTraceEnabled()) { log.trace("sendSessionEvent(" + this + ")[" + event + "] listener inform details", t); } if (t instanceof IOException) { throw (IOException) t; } else if (t instanceof RuntimeException) { throw (RuntimeException) t; } else { throw new IOException("Failed (" + t.getClass().getSimpleName() + ") to send session event: " + t.getMessage(), t); } } } protected void signalSessionEvent(SessionListener listener, SessionListener.Event event) throws IOException { if (listener == null) { return; } listener.sessionEvent(this, event); } protected void invokeSessionSignaller(Invoker<SessionListener, Void> invoker) throws Throwable { FactoryManager manager = getFactoryManager(); SessionListener[] listeners = { (manager == null) ? 
null : manager.getSessionListenerProxy(), getSessionListenerProxy() }; Throwable err = null; for (SessionListener l : listeners) { if (l == null) { continue; } try { invoker.invoke(l); } catch (Throwable t) { err = GenericUtils.accumulateException(err, t); } } if (err != null) { throw err; } } @Override public KeyExchangeFuture reExchangeKeys() throws IOException { requestNewKeysExchange(); return ValidateUtils.checkNotNull(kexFutureHolder.get(), "No current KEX future on state=%s", kexState.get()); } /** * Checks if a re-keying is required and if so initiates it * * @return A {@link KeyExchangeFuture} to wait for the initiated exchange * or {@code null} if no need to re-key or an exchange is already in progress * @throws IOException If failed to send the request * @see #isRekeyRequired() * @see #requestNewKeysExchange() */ protected KeyExchangeFuture checkRekey() throws IOException { return isRekeyRequired() ? requestNewKeysExchange() : null; } /** * Initiates a new keys exchange if one not already in progress * * @return A {@link KeyExchangeFuture} to wait for the initiated exchange * or {@code null} if an exchange is already in progress * @throws IOException If failed to send the request */ protected KeyExchangeFuture requestNewKeysExchange() throws IOException { if (!kexState.compareAndSet(KexState.DONE, KexState.INIT)) { if (log.isDebugEnabled()) { log.debug("requestNewKeysExchange({}) KEX state not DONE: {}", this, kexState.get()); } return null; } log.info("requestNewKeysExchange({}) Initiating key re-exchange", this); sendKexInit(); DefaultKeyExchangeFuture newFuture = new DefaultKeyExchangeFuture(null); DefaultKeyExchangeFuture kexFuture = kexFutureHolder.getAndSet(newFuture); if (kexFuture != null) { synchronized (kexFuture) { Object value = kexFuture.getValue(); if (value == null) { kexFuture.setValue(new SshException("New KEX started while previous one still ongoing")); } } } return newFuture; } protected boolean isRekeyRequired() { if ((!isOpen()) || 
isClosing() || isClosed()) { return false; } KexState curState = kexState.get(); if (!KexState.DONE.equals(curState)) { return false; } return isRekeyTimeIntervalExceeded() || isRekeyPacketCountsExceeded() || isRekeyBlocksCountExceeded() || isRekeyDataSizeExceeded(); } protected boolean isRekeyTimeIntervalExceeded() { if (maxRekeyInterval <= 0L) { return false; // disabled } long now = System.currentTimeMillis(); long rekeyDiff = now - lastKeyTimeValue.get(); boolean rekey = rekeyDiff > maxRekeyInterval; if (rekey) { if (log.isDebugEnabled()) { log.debug("isRekeyTimeIntervalExceeded({}) re-keying: last={}, now={}, diff={}, max={}", this, new Date(lastKeyTimeValue.get()), new Date(now), rekeyDiff, maxRekeyInterval); } } return rekey; } protected boolean isRekeyPacketCountsExceeded() { if (maxRekyPackets <= 0L) { return false; // disabled } boolean rekey = (inPacketsCount.get() > maxRekyPackets) || (outPacketsCount.get() > maxRekyPackets); if (rekey) { if (log.isDebugEnabled()) { log.debug("isRekeyPacketCountsExceeded({}) re-keying: in={}, out={}, max={}", this, inPacketsCount, outPacketsCount, maxRekyPackets); } } return rekey; } protected boolean isRekeyDataSizeExceeded() { if (maxRekeyBytes <= 0L) { return false; } boolean rekey = (inBytesCount.get() > maxRekeyBytes) || (outBytesCount.get() > maxRekeyBytes); if (rekey) { if (log.isDebugEnabled()) { log.debug("isRekeyDataSizeExceeded({}) re-keying: in={}, out={}, max={}", this, inBytesCount, outBytesCount, maxRekeyBytes); } } return rekey; } protected boolean isRekeyBlocksCountExceeded() { long maxBlocks = maxRekeyBlocks.get(); if (maxBlocks <= 0L) { return false; } boolean rekey = (inBlocksCount.get() > maxBlocks) || (outBlocksCount.get() > maxBlocks); if (rekey) { if (log.isDebugEnabled()) { log.debug("isRekeyBlocksCountExceeded({}) re-keying: in={}, out={}, max={}", this, inBlocksCount, outBlocksCount, maxBlocks); } } return rekey; } protected byte[] sendKexInit() throws IOException { String resolvedAlgorithms = 
resolveAvailableSignaturesProposal(); if (GenericUtils.isEmpty(resolvedAlgorithms)) { throw new SshException(SshConstants.SSH2_DISCONNECT_HOST_KEY_NOT_VERIFIABLE, "sendKexInit() no resolved signatures available"); } Map<KexProposalOption, String> proposal = createProposal(resolvedAlgorithms); byte[] seed = sendKexInit(proposal); if (log.isTraceEnabled()) { log.trace("sendKexInit({}) proposal={} seed: {}", this, proposal, BufferUtils.toHex(':', seed)); } setKexSeed(seed); return seed; } /** * @param seed The result of the KEXINIT handshake - required for correct * session key establishment */ protected abstract void setKexSeed(byte... seed); /** * @return A comma-separated list of all the signature protocols to be * included in the proposal - {@code null}/empty if no proposal * @see #getFactoryManager() * @see #resolveAvailableSignaturesProposal(FactoryManager) */ protected String resolveAvailableSignaturesProposal() { return resolveAvailableSignaturesProposal(getFactoryManager()); } /** * @param manager The {@link FactoryManager} * @return A comma-separated list of all the signature protocols to be * included in the proposal - {@code null}/empty if no proposal */ protected abstract String resolveAvailableSignaturesProposal(FactoryManager manager); /** * Indicates the the key exchange is completed and the exchanged keys * can now be verified - e.g., client can verify the server's key * * @throws IOException If validation failed */ protected abstract void checkKeys() throws IOException; protected void receiveKexInit(Buffer buffer) throws IOException { Map<KexProposalOption, String> proposal = new EnumMap<>(KexProposalOption.class); byte[] seed = receiveKexInit(buffer, proposal); receiveKexInit(proposal, seed); } protected abstract void receiveKexInit(Map<KexProposalOption, String> proposal, byte[] seed) throws IOException; // returns the proposal argument protected Map<KexProposalOption, String> mergeProposals(Map<KexProposalOption, String> current, 
Map<KexProposalOption, String> proposal) { if (current == proposal) { return proposal; // nothing to merge } synchronized (current) { if (!current.isEmpty()) { current.clear(); // debug breakpoint } if (GenericUtils.isEmpty(proposal)) { return proposal; // debug breakpoint } current.putAll(proposal); } return proposal; } /** * Checks whether the session has timed out (both auth and idle timeouts are checked). * If the session has timed out, a DISCONNECT message will be sent. * * @throws IOException If failed to check * @see #checkAuthenticationTimeout(long, long) * @see #checkIdleTimeout(long, long) */ protected void checkForTimeouts() throws IOException { if ((!isOpen()) || isClosing() || isClosed()) { if (log.isDebugEnabled()) { log.debug("checkForTimeouts({}) session closing", this); return; } } long now = System.currentTimeMillis(); Pair<TimeoutStatus, String> result = checkAuthenticationTimeout(now, getAuthTimeout()); if (result == null) { result = checkIdleTimeout(now, getIdleTimeout()); } TimeoutStatus status = (result == null) ? 
TimeoutStatus.NoTimeout : result.getFirst(); if ((status == null) || TimeoutStatus.NoTimeout.equals(status)) { return; } if (log.isDebugEnabled()) { log.debug("checkForTimeouts({}) disconnect - reason={}", this, status); } timeoutStatus.set(status); disconnect(SshConstants.SSH2_DISCONNECT_PROTOCOL_ERROR, result.getSecond()); } /** * Checks if authentication timeout expired * * @param now The current time in millis * @param authTimeoutMs The configured timeout in millis - if non-positive * then no timeout * @return A {@link Pair} specifying the timeout status and disconnect reason * message if timeout expired, {@code null} or {@code NoTimeout} if no timeout * occurred * @see #getAuthTimeout() */ protected Pair<TimeoutStatus, String> checkAuthenticationTimeout(long now, long authTimeoutMs) { long authDiff = now - authTimeoutStart; if ((!authed) && (authTimeoutMs > 0L) && (authDiff > authTimeoutMs)) { return new Pair<>(TimeoutStatus.AuthTimeout, "Session has timed out waiting for authentication after " + authTimeoutMs + " ms."); } else { return null; } } /** * Checks if idle timeout expired * * @param now The current time in millis * @param idleTimeoutMs The configured timeout in millis - if non-positive * then no timeout * @return A {@link Pair} specifying the timeout status and disconnect reason * message if timeout expired, {@code null} or {@code NoTimeout} if no timeout * occurred * @see #getIdleTimeout() */ protected Pair<TimeoutStatus, String> checkIdleTimeout(long now, long idleTimeoutMs) { long idleDiff = now - idleTimeoutStart; if ((idleTimeoutMs > 0L) && (idleDiff > idleTimeoutMs)) { return new Pair<>(TimeoutStatus.IdleTimeout, "User session has timed out idling after " + idleTimeoutMs + " ms."); } else { return null; } } @Override public void resetIdleTimeout() { this.idleTimeoutStart = System.currentTimeMillis(); } @Override public TimeoutStatus getTimeoutStatus() { return timeoutStatus.get(); } @Override public long getAuthTimeout() { return 
this.getLongProperty(FactoryManager.AUTH_TIMEOUT, FactoryManager.DEFAULT_AUTH_TIMEOUT); } @Override public long getIdleTimeout() { return this.getLongProperty(FactoryManager.IDLE_TIMEOUT, FactoryManager.DEFAULT_IDLE_TIMEOUT); } @Override public String toString() { IoSession ioSession = getIoSession(); SocketAddress peerAddress = (ioSession == null) ? null : ioSession.getRemoteAddress(); return getClass().getSimpleName() + "[" + getUsername() + "@" + peerAddress + "]"; } }