/* * Copyright (C)2009 - SSHJ Contributors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package net.schmizz.sshj.connection; import net.schmizz.concurrent.ErrorDeliveryUtil; import net.schmizz.concurrent.Promise; import net.schmizz.keepalive.KeepAlive; import net.schmizz.keepalive.KeepAliveProvider; import net.schmizz.sshj.AbstractService; import net.schmizz.sshj.common.*; import net.schmizz.sshj.connection.channel.Channel; import net.schmizz.sshj.connection.channel.OpenFailException.Reason; import net.schmizz.sshj.connection.channel.forwarded.ForwardedChannelOpener; import net.schmizz.sshj.transport.Transport; import net.schmizz.sshj.transport.TransportException; import java.util.LinkedList; import java.util.Map; import java.util.Queue; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; /** {@link Connection} implementation. */ public class ConnectionImpl extends AbstractService implements Connection { private final Object internalSynchronizer = new Object(); private final AtomicInteger nextID = new AtomicInteger(); private final Map<Integer, Channel> channels = new ConcurrentHashMap<Integer, Channel>(); private final Map<String, ForwardedChannelOpener> openers = new ConcurrentHashMap<String, ForwardedChannelOpener>(); private final Queue<Promise<SSHPacket, ConnectionException>> globalReqPromises = new LinkedList<Promise<SSHPacket, ConnectionException>>(); /** {@code keep-alive} mechanism */ private final KeepAlive keepAlive; private long windowSize = 2048 * 1024; private int maxPacketSize = 32 * 1024; private volatile int timeoutMs; /** * Create with an associated {@link Transport}. * * @param trans transport layer * @param keepAlive the keep alive provider */ public ConnectionImpl(Transport trans, KeepAliveProvider keepAlive) { super("ssh-connection", trans); timeoutMs = trans.getTimeoutMs(); this.keepAlive = keepAlive.provide(this); } @Override public void attach(Channel chan) { log.debug("Attaching `{}` channel (#{})", chan.getType(), chan.getID()); channels.put(chan.getID(), chan); } @Override public Channel get(int id) { return channels.get(id); } @Override public ForwardedChannelOpener get(String chanType) { return openers.get(chanType); } @Override public void forget(Channel chan) { log.debug("Forgetting `{}` channel (#{})", chan.getType(), chan.getID()); channels.remove(chan.getID()); synchronized (internalSynchronizer) { if (channels.isEmpty()) internalSynchronizer.notifyAll(); } } @Override public void forget(ForwardedChannelOpener opener) { log.debug("Forgetting opener for `{}` channels: {}", opener.getChannelType(), opener); openers.remove(opener.getChannelType()); } @Override public void attach(ForwardedChannelOpener opener) { log.debug("Attaching opener for `{}` channels: {}", opener.getChannelType(), opener); openers.put(opener.getChannelType(), opener); } private Channel getChannel(SSHPacket buffer) throws ConnectionException { try { final int recipient = buffer.readUInt32AsInt(); final Channel channel = get(recipient); if (channel != null) return channel; else { buffer.rpos(buffer.rpos() - 5); throw new ConnectionException(DisconnectReason.PROTOCOL_ERROR, "Received " + buffer.readMessageID() + " on unknown channel #" + recipient); } } catch (Buffer.BufferException be) { throw new ConnectionException(be); } } @Override public void handle(Message msg, SSHPacket buf) throws SSHException { if (msg.in(91, 100)) { getChannel(buf).handle(msg, buf); } else if (msg.in(80, 90)) { switch (msg) { case REQUEST_SUCCESS: gotGlobalReqResponse(buf); break; case REQUEST_FAILURE: gotGlobalReqResponse(null); break; case CHANNEL_OPEN: gotChannelOpen(buf); break; default: super.handle(msg, buf); break; } } else { super.handle(msg, buf); } } @Override public int getMaxPacketSize() { return maxPacketSize; } @Override public Transport getTransport() { return trans; } @Override public void setMaxPacketSize(int maxPacketSize) { this.maxPacketSize = maxPacketSize; } @Override public long getWindowSize() { return windowSize; } @Override public void setWindowSize(long windowSize) { this.windowSize = windowSize; } @Override public void join() throws InterruptedException { synchronized (internalSynchronizer) { while (!channels.isEmpty()) { internalSynchronizer.wait(); } } } @Override public int nextID() { return nextID.getAndIncrement(); } @Override public Promise<SSHPacket, ConnectionException> sendGlobalRequest(String name, boolean wantReply, byte[] specifics) throws TransportException { synchronized (globalReqPromises) { log.debug("Making global request for `{}`", name); trans.write(new SSHPacket(Message.GLOBAL_REQUEST).putString(name) .putBoolean(wantReply) .putRawBytes(specifics)); Promise<SSHPacket, ConnectionException> promise = null; if (wantReply) { promise = new Promise<SSHPacket, ConnectionException>("global req for " + name, ConnectionException.chainer, trans.getConfig().getLoggerFactory()); globalReqPromises.add(promise); } return promise; } } private void gotGlobalReqResponse(SSHPacket response) throws ConnectionException { synchronized (globalReqPromises) { Promise<SSHPacket, ConnectionException> gr = globalReqPromises.poll(); if (gr == null) { throw new ConnectionException(DisconnectReason.PROTOCOL_ERROR, "Got a global request response when none was requested"); } else if (response == null) { gr.deliverError(new ConnectionException("Global request [" + gr + "] failed")); } else { // To prevent a race condition, copy the packet before delivering, as it will be handled in a different thread. gr.deliver(new SSHPacket(response)); } } } private void gotChannelOpen(SSHPacket buf) throws ConnectionException, TransportException { try { final String type = buf.readString(); log.debug("Received CHANNEL_OPEN for `{}` channel", type); if (openers.containsKey(type)) openers.get(type).handleOpen(buf); else { log.warn("No opener found for `{}` CHANNEL_OPEN request -- rejecting", type); sendOpenFailure(buf.readUInt32AsInt(), Reason.UNKNOWN_CHANNEL_TYPE, ""); } } catch (Buffer.BufferException be) { throw new ConnectionException(be); } } @Override public void sendOpenFailure(int recipient, Reason reason, String message) throws TransportException { trans.write(new SSHPacket(Message.CHANNEL_OPEN_FAILURE) .putUInt32(recipient) .putUInt32(reason.getCode()) .putString(message)); } @Override public void notifyError(SSHException error) { super.notifyError(error); synchronized (globalReqPromises) { ErrorDeliveryUtil.alertPromises(error, globalReqPromises); globalReqPromises.clear(); } keepAlive.interrupt(); ErrorNotifiable.Util.alertAll(error, channels.values()); channels.clear(); } @Override public void setTimeoutMs(int timeoutMs) { this.timeoutMs = timeoutMs; } @Override public int getTimeoutMs() { return timeoutMs; } @Override public KeepAlive getKeepAlive() { return keepAlive; } }