/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sshd.common.session.helpers; import java.io.IOException; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CopyOnWriteArraySet; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.IntUnaryOperator; import org.apache.sshd.agent.common.AgentForwardSupport; import org.apache.sshd.agent.common.DefaultAgentForwardSupport; import org.apache.sshd.client.channel.AbstractClientChannel; import org.apache.sshd.common.Closeable; import org.apache.sshd.common.FactoryManager; import org.apache.sshd.common.NamedFactory; import org.apache.sshd.common.SshConstants; import org.apache.sshd.common.SshException; import org.apache.sshd.common.channel.AbstractChannel; import org.apache.sshd.common.channel.Channel; import org.apache.sshd.common.channel.OpenChannelException; import org.apache.sshd.common.channel.RequestHandler; import org.apache.sshd.common.channel.Window; import org.apache.sshd.common.forward.PortForwardingEventListener; import org.apache.sshd.common.forward.PortForwardingEventListenerManager; import org.apache.sshd.common.forward.TcpipForwarder; import org.apache.sshd.common.forward.TcpipForwarderFactory; import org.apache.sshd.common.io.AbstractIoWriteFuture; import org.apache.sshd.common.io.IoWriteFuture; import org.apache.sshd.common.session.ConnectionService; import org.apache.sshd.common.session.Session; import org.apache.sshd.common.util.EventListenerUtils; import org.apache.sshd.common.util.GenericUtils; import org.apache.sshd.common.util.Int2IntFunction; import org.apache.sshd.common.util.ValidateUtils; import org.apache.sshd.common.util.buffer.Buffer; import org.apache.sshd.common.util.closeable.AbstractInnerCloseable; import org.apache.sshd.server.x11.DefaultX11ForwardSupport; import org.apache.sshd.server.x11.X11ForwardSupport; /** * Base implementation of ConnectionService. * * @param <S> Type of {@link AbstractSession} being used * @author <a href="mailto:dev@mina.apache.org">Apache MINA SSHD Project</a> */ public abstract class AbstractConnectionService<S extends AbstractSession> extends AbstractInnerCloseable implements ConnectionService { /** * Property that can be used to configure max. allowed concurrent active channels * * @see #registerChannel(Channel) */ public static final String MAX_CONCURRENT_CHANNELS_PROP = "max-sshd-channels"; /** * Default value for {@link #MAX_CONCURRENT_CHANNELS_PROP} is none specified */ public static final int DEFAULT_MAX_CHANNELS = Integer.MAX_VALUE; /** * Default growth factor function used to resize response buffers */ public static final IntUnaryOperator RESPONSE_BUFFER_GROWTH_FACTOR = Int2IntFunction.add(Byte.SIZE); /** * Map of channels keyed by the identifier */ protected final Map<Integer, Channel> channels = new ConcurrentHashMap<>(); /** * Next channel identifier */ protected final AtomicInteger nextChannelId = new AtomicInteger(0); private final AtomicReference<AgentForwardSupport> agentForwardHolder = new AtomicReference<>(); private final AtomicReference<X11ForwardSupport> x11ForwardHolder = new AtomicReference<>(); private final AtomicReference<TcpipForwarder> tcpipForwarderHolder = new AtomicReference<>(); private final AtomicBoolean allowMoreSessions = new AtomicBoolean(true); private final Collection<PortForwardingEventListener> listeners = new CopyOnWriteArraySet<>(); private final Collection<PortForwardingEventListenerManager> managersHolder = new CopyOnWriteArraySet<>(); private final PortForwardingEventListener listenerProxy; private final S sessionInstance; protected AbstractConnectionService(S session) { sessionInstance = Objects.requireNonNull(session, "No session"); listenerProxy = EventListenerUtils.proxyWrapper(PortForwardingEventListener.class, getClass().getClassLoader(), listeners); } @Override public PortForwardingEventListener getPortForwardingEventListenerProxy() { return listenerProxy; } @Override public void addPortForwardingEventListener(PortForwardingEventListener listener) { listeners.add(PortForwardingEventListener.validateListener(listener)); } @Override public void removePortForwardingEventListener(PortForwardingEventListener listener) { if (listener == null) { return; } listeners.remove(PortForwardingEventListener.validateListener(listener)); } @Override public Collection<PortForwardingEventListenerManager> getRegisteredManagers() { return managersHolder.isEmpty() ? Collections.emptyList() : new ArrayList<>(managersHolder); } @Override public boolean addPortForwardingEventListenerManager(PortForwardingEventListenerManager manager) { return managersHolder.add(Objects.requireNonNull(manager, "No manager")); } @Override public boolean removePortForwardingEventListenerManager(PortForwardingEventListenerManager manager) { if (manager == null) { return false; } return managersHolder.remove(manager); } public Collection<Channel> getChannels() { return channels.values(); } @Override public S getSession() { return sessionInstance; } @Override public void start() { // do nothing } @Override public TcpipForwarder getTcpipForwarder() { TcpipForwarder forwarder; S session = getSession(); synchronized (tcpipForwarderHolder) { forwarder = tcpipForwarderHolder.get(); if (forwarder != null) { return forwarder; } forwarder = ValidateUtils.checkNotNull(createTcpipForwarder(session), "No forwarder created for %s", session); tcpipForwarderHolder.set(forwarder); } if (log.isDebugEnabled()) { log.debug("getTcpipForwarder({}) created instance", session); } return forwarder; } @Override protected void preClose() { this.listeners.clear(); this.managersHolder.clear(); super.preClose(); } protected TcpipForwarder createTcpipForwarder(S session) { FactoryManager manager = Objects.requireNonNull(session.getFactoryManager(), "No factory manager"); TcpipForwarderFactory factory = Objects.requireNonNull(manager.getTcpipForwarderFactory(), "No forwarder factory"); TcpipForwarder forwarder = factory.create(this); forwarder.addPortForwardingEventListenerManager(this); return forwarder; } @Override public X11ForwardSupport getX11ForwardSupport() { X11ForwardSupport x11Support; S session = getSession(); synchronized (x11ForwardHolder) { x11Support = x11ForwardHolder.get(); if (x11Support != null) { return x11Support; } x11Support = ValidateUtils.checkNotNull(createX11ForwardSupport(session), "No X11 forwarder created for %s", session); x11ForwardHolder.set(x11Support); } if (log.isDebugEnabled()) { log.debug("getX11ForwardSupport({}) created instance", session); } return x11Support; } protected X11ForwardSupport createX11ForwardSupport(S session) { return new DefaultX11ForwardSupport(this); } @Override public AgentForwardSupport getAgentForwardSupport() { AgentForwardSupport agentForward; S session = getSession(); synchronized (agentForwardHolder) { agentForward = agentForwardHolder.get(); if (agentForward != null) { return agentForward; } agentForward = ValidateUtils.checkNotNull(createAgentForwardSupport(session), "No agent forward created for %s", session); agentForwardHolder.set(agentForward); } if (log.isDebugEnabled()) { log.debug("getAgentForwardSupport({}) created instance", session); } return agentForward; } protected AgentForwardSupport createAgentForwardSupport(S session) { return new DefaultAgentForwardSupport(this); } @Override protected Closeable getInnerCloseable() { return builder() .sequential(tcpipForwarderHolder.get(), agentForwardHolder.get(), x11ForwardHolder.get()) .parallel(channels.values()) .build(); } protected int getNextChannelId() { return nextChannelId.getAndIncrement(); } @Override public int registerChannel(Channel channel) throws IOException { Session session = getSession(); int maxChannels = session.getIntProperty(MAX_CONCURRENT_CHANNELS_PROP, DEFAULT_MAX_CHANNELS); int curSize = channels.size(); if (curSize > maxChannels) { throw new IllegalStateException("Currently active channels (" + curSize + ") at max.: " + maxChannels); } int channelId = getNextChannelId(); channel.init(this, session, channelId); boolean registered = false; synchronized (lock) { if (!isClosing()) { channels.put(channelId, channel); registered = true; } } if (!registered) { handleChannelRegistrationFailure(channel, channelId); } if (log.isDebugEnabled()) { log.debug("registerChannel({})[id={}] {}", this, channelId, channel); } return channelId; } protected void handleChannelRegistrationFailure(Channel channel, int channelId) throws IOException { RuntimeException reason = new IllegalStateException("Channel id=" + channelId + " not registered because session is being closed: " + this); AbstractChannel notifier = ValidateUtils.checkInstanceOf(channel, AbstractChannel.class, "Non abstract channel for id=%d", channelId); notifier.signalChannelClosed(reason); throw reason; } /** * Remove this channel from the list of managed channels * * @param channel the channel */ @Override public void unregisterChannel(Channel channel) { Channel result = channels.remove(channel.getId()); if (log.isDebugEnabled()) { log.debug("unregisterChannel({}) result={}", channel, result); } } @Override public void process(int cmd, Buffer buffer) throws Exception { switch (cmd) { case SshConstants.SSH_MSG_CHANNEL_OPEN: channelOpen(buffer); break; case SshConstants.SSH_MSG_CHANNEL_OPEN_CONFIRMATION: channelOpenConfirmation(buffer); break; case SshConstants.SSH_MSG_CHANNEL_OPEN_FAILURE: channelOpenFailure(buffer); break; case SshConstants.SSH_MSG_CHANNEL_REQUEST: channelRequest(buffer); break; case SshConstants.SSH_MSG_CHANNEL_DATA: channelData(buffer); break; case SshConstants.SSH_MSG_CHANNEL_EXTENDED_DATA: channelExtendedData(buffer); break; case SshConstants.SSH_MSG_CHANNEL_FAILURE: channelFailure(buffer); break; case SshConstants.SSH_MSG_CHANNEL_SUCCESS: channelSuccess(buffer); break; case SshConstants.SSH_MSG_CHANNEL_WINDOW_ADJUST: channelWindowAdjust(buffer); break; case SshConstants.SSH_MSG_CHANNEL_EOF: channelEof(buffer); break; case SshConstants.SSH_MSG_CHANNEL_CLOSE: channelClose(buffer); break; case SshConstants.SSH_MSG_GLOBAL_REQUEST: globalRequest(buffer); break; case SshConstants.SSH_MSG_REQUEST_SUCCESS: requestSuccess(buffer); break; case SshConstants.SSH_MSG_REQUEST_FAILURE: requestFailure(buffer); break; default: throw new IllegalStateException("Unsupported command: " + SshConstants.getCommandMessageName(cmd)); } } @Override public boolean isAllowMoreSessions() { return allowMoreSessions.get(); } @Override public void setAllowMoreSessions(boolean allow) { if (log.isDebugEnabled()) { log.debug("setAllowMoreSessions({}): {}", this, allow); } allowMoreSessions.set(allow); } public void channelOpenConfirmation(Buffer buffer) throws IOException { Channel channel = getChannel(buffer); int sender = buffer.getInt(); long rwsize = buffer.getUInt(); long rmpsize = buffer.getUInt(); if (log.isDebugEnabled()) { log.debug("channelOpenConfirmation({}) SSH_MSG_CHANNEL_OPEN_CONFIRMATION sender={}, window-size={}, packet-size={}", channel, sender, rwsize, rmpsize); } /* * NOTE: the 'sender' of the SSH_MSG_CHANNEL_OPEN_CONFIRMATION is the * recipient on the client side - see rfc4254 section 5.1: * * 'sender channel' is the channel number allocated by the other side * * in our case, the server */ channel.handleOpenSuccess(sender, rwsize, rmpsize, buffer); } public void channelOpenFailure(Buffer buffer) throws IOException { AbstractClientChannel channel = (AbstractClientChannel) getChannel(buffer); int id = channel.getId(); if (log.isDebugEnabled()) { log.debug("channelOpenFailure({}) Received SSH_MSG_CHANNEL_OPEN_FAILURE", channel); } channels.remove(id); channel.handleOpenFailure(buffer); } /** * Process incoming data on a channel * * @param buffer the buffer containing the data * @throws IOException if an error occurs */ public void channelData(Buffer buffer) throws IOException { Channel channel = getChannel(buffer); channel.handleData(buffer); } /** * Process incoming extended data on a channel * * @param buffer the buffer containing the data * @throws IOException if an error occurs */ public void channelExtendedData(Buffer buffer) throws IOException { Channel channel = getChannel(buffer); channel.handleExtendedData(buffer); } /** * Process a window adjust packet on a channel * * @param buffer the buffer containing the window adjustment parameters * @throws IOException if an error occurs */ public void channelWindowAdjust(Buffer buffer) throws IOException { try { Channel channel = getChannel(buffer); channel.handleWindowAdjust(buffer); } catch (SshException e) { if (log.isDebugEnabled()) { log.debug("channelWindowAdjust {} error: {}", e.getClass().getSimpleName(), e.getMessage()); } } } /** * Process end of file on a channel * * @param buffer the buffer containing the packet * @throws IOException if an error occurs */ public void channelEof(Buffer buffer) throws IOException { Channel channel = getChannel(buffer); channel.handleEof(); } /** * Close a channel due to a close packet received * * @param buffer the buffer containing the packet * @throws IOException if an error occurs */ public void channelClose(Buffer buffer) throws IOException { // Do not use getChannel to avoid the session being closed // if receiving the SSH_MSG_CHANNEL_CLOSE on an already closed channel int recipient = buffer.getInt(); Channel channel = channels.get(recipient); if (channel != null) { channel.handleClose(); } else { log.warn("Received SSH_MSG_CHANNEL_CLOSE on unknown channel " + recipient); } } /** * Service a request on a channel * * @param buffer the buffer containing the request * @throws IOException if an error occurs */ public void channelRequest(Buffer buffer) throws IOException { Channel channel = getChannel(buffer); channel.handleRequest(buffer); } /** * Process a failure on a channel * * @param buffer the buffer containing the packet * @throws IOException if an error occurs */ public void channelFailure(Buffer buffer) throws IOException { Channel channel = getChannel(buffer); channel.handleFailure(); } /** * Process a success on a channel * * @param buffer the buffer containing the packet * @throws IOException if an error occurs */ public void channelSuccess(Buffer buffer) throws IOException { Channel channel = getChannel(buffer); channel.handleSuccess(); } /** * Retrieve the channel designated by the given packet * * @param buffer the incoming packet * @return the target channel * @throws IOException if the channel does not exists */ protected Channel getChannel(Buffer buffer) throws IOException { return getChannel(buffer.getInt(), buffer); } protected Channel getChannel(int recipient, Buffer buffer) throws IOException { Channel channel = channels.get(recipient); if (channel == null) { byte[] data = buffer.array(); int curPos = buffer.rpos(); int cmd = (curPos >= 5) ? (data[curPos - 5] & 0xFF) : -1; throw new SshException("Received " + SshConstants.getCommandMessageName(cmd) + " on unknown channel " + recipient); } return channel; } protected void channelOpen(Buffer buffer) throws Exception { String type = buffer.getString(); final int sender = buffer.getInt(); final long rwsize = buffer.getUInt(); final long rmpsize = buffer.getUInt(); /* * NOTE: the 'sender' is the identifier assigned by the remote side - the server in this case */ if (log.isDebugEnabled()) { log.debug("channelOpen({}) SSH_MSG_CHANNEL_OPEN sender={}, type={}, window-size={}, packet-size={}", this, sender, type, rwsize, rmpsize); } if (isClosing()) { // TODO add language tag sendChannelOpenFailure(buffer, sender, SshConstants.SSH_OPEN_CONNECT_FAILED, "Server is shutting down while attempting to open channel type=" + type, ""); return; } if (!isAllowMoreSessions()) { // TODO add language tag sendChannelOpenFailure(buffer, sender, SshConstants.SSH_OPEN_CONNECT_FAILED, "additional sessions disabled", ""); return; } final Session session = getSession(); FactoryManager manager = Objects.requireNonNull(session.getFactoryManager(), "No factory manager"); final Channel channel = NamedFactory.create(manager.getChannelFactories(), type); if (channel == null) { // TODO add language tag sendChannelOpenFailure(buffer, sender, SshConstants.SSH_OPEN_UNKNOWN_CHANNEL_TYPE, "Unsupported channel type: " + type, ""); return; } final int channelId = registerChannel(channel); channel.open(sender, rwsize, rmpsize, buffer).addListener(future -> { try { if (future.isOpened()) { Window window = channel.getLocalWindow(); if (log.isDebugEnabled()) { log.debug("operationComplete({}) send SSH_MSG_CHANNEL_OPEN_CONFIRMATION recipient={}, sender={}, window-size={}, packet-size={}", channel, sender, channelId, window.getSize(), window.getPacketSize()); } Buffer buf = session.createBuffer(SshConstants.SSH_MSG_CHANNEL_OPEN_CONFIRMATION, Integer.SIZE); buf.putInt(sender); // remote (server side) identifier buf.putInt(channelId); // local (client side) identifier buf.putInt(window.getSize()); buf.putInt(window.getPacketSize()); session.writePacket(buf); } else { Throwable exception = future.getException(); if (exception != null) { String message = exception.getMessage(); int reasonCode = 0; if (exception instanceof OpenChannelException) { reasonCode = ((OpenChannelException) exception).getReasonCode(); } else { message = exception.getClass().getSimpleName() + " while opening channel: " + message; } Buffer buf = session.createBuffer(SshConstants.SSH_MSG_CHANNEL_OPEN_FAILURE, message.length() + Long.SIZE); sendChannelOpenFailure(buf, sender, reasonCode, message, ""); } } } catch (IOException e) { if (log.isDebugEnabled()) { log.debug("operationComplete({}) {}: {}", AbstractConnectionService.this, e.getClass().getSimpleName(), e.getMessage()); } session.exceptionCaught(e); } }); } protected IoWriteFuture sendChannelOpenFailure(Buffer buffer, int sender, int reasonCode, String message, String lang) throws IOException { if (log.isDebugEnabled()) { log.debug("sendChannelOpenFailure({}) sender={}, reason={}, lang={}, message='{}'", this, sender, SshConstants.getOpenErrorCodeName(reasonCode), lang, message); } Session session = getSession(); Buffer buf = session.createBuffer(SshConstants.SSH_MSG_CHANNEL_OPEN_FAILURE, Long.SIZE + GenericUtils.length(message) + GenericUtils.length(lang)); buf.putInt(sender); buf.putInt(reasonCode); buf.putString(message); buf.putString(lang); return session.writePacket(buf); } /** * Process global requests * * @param buffer The request {@link Buffer} * @throws Exception If failed to process the request */ protected void globalRequest(Buffer buffer) throws Exception { String req = buffer.getString(); boolean wantReply = buffer.getBoolean(); if (log.isDebugEnabled()) { log.debug("globalRequest({}) received SSH_MSG_GLOBAL_REQUEST {} want-reply={}", this, req, wantReply); } Session session = getSession(); FactoryManager manager = Objects.requireNonNull(session.getFactoryManager(), "No factory manager"); List<RequestHandler<ConnectionService>> handlers = manager.getGlobalRequestHandlers(); if (GenericUtils.size(handlers) > 0) { for (RequestHandler<ConnectionService> handler : handlers) { RequestHandler.Result result; try { result = handler.process(this, req, wantReply, buffer); } catch (Throwable e) { log.warn("globalRequest({})[{}, want-reply={}] failed ({}) to process: {}", this, req, wantReply, e.getClass().getSimpleName(), e.getMessage()); if (log.isDebugEnabled()) { log.debug("globalRequest(" + this + ")[" + req + ", want-reply=" + wantReply + "] failure details", e); } result = RequestHandler.Result.ReplyFailure; } // if Unsupported then check the next handler in line if (RequestHandler.Result.Unsupported.equals(result)) { if (log.isTraceEnabled()) { log.trace("globalRequest({}) {}#process({})[want-reply={}] : {}", this, handler.getClass().getSimpleName(), req, wantReply, result); } } else { sendGlobalResponse(buffer, req, result, wantReply); return; } } } handleUnknownRequest(buffer, req, wantReply); } protected void handleUnknownRequest(Buffer buffer, String req, boolean wantReply) throws IOException { log.warn("handleUnknownRequest({}) unknown global request: {}", this, req); sendGlobalResponse(buffer, req, RequestHandler.Result.Unsupported, wantReply); } protected IoWriteFuture sendGlobalResponse(Buffer buffer, String req, RequestHandler.Result result, boolean wantReply) throws IOException { if (log.isDebugEnabled()) { log.debug("sendGlobalResponse({})[{}] result={}, want-reply={}", this, req, result, wantReply); } if (RequestHandler.Result.Replied.equals(result) || (!wantReply)) { return new AbstractIoWriteFuture(null) { { setValue(Boolean.TRUE); } }; } byte cmd = RequestHandler.Result.ReplySuccess.equals(result) ? SshConstants.SSH_MSG_REQUEST_SUCCESS : SshConstants.SSH_MSG_REQUEST_FAILURE; Session session = getSession(); Buffer rsp = session.createBuffer(cmd, 2); return session.writePacket(rsp); } protected void requestSuccess(Buffer buffer) throws Exception { getSession().requestSuccess(buffer); } protected void requestFailure(Buffer buffer) throws Exception { getSession().requestFailure(buffer); } @Override public String toString() { return getClass().getSimpleName() + "[" + getSession() + "]"; } }