/* * Copyright 2010 netling project <http://netling.org> * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.netling.ssh.connection; import java.util.HashSet; import java.util.LinkedList; import java.util.Map; import java.util.Queue; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; import org.netling.concurrent.Future; import org.netling.concurrent.FutureUtils; import org.netling.ssh.AbstractService; import org.netling.ssh.common.DisconnectReason; import org.netling.ssh.common.ErrorNotifiable; import org.netling.ssh.common.Message; import org.netling.ssh.common.SSHException; import org.netling.ssh.common.SSHPacket; import org.netling.ssh.connection.channel.Channel; import org.netling.ssh.connection.channel.OpenFailException; import org.netling.ssh.connection.channel.OpenFailException.Reason; import org.netling.ssh.connection.channel.forwarded.ForwardedChannelOpener; import org.netling.ssh.transport.Transport; import org.netling.ssh.transport.TransportException; /** {@link Connection} implementation. */ public class ConnectionImpl extends AbstractService implements Connection { public static final class GlobalRequestResult extends Future<SSHPacket, ConnectionException> { public GlobalRequestResult(String name) { super(name, ConnectionException.chainer); } } private final Object internalSynchronizer = new Object(); private final AtomicInteger nextID = new AtomicInteger(); private final Map<Integer, Channel> channels = new ConcurrentHashMap<Integer, Channel>(); private final Map<String, ForwardedChannelOpener> openers = new ConcurrentHashMap<String, ForwardedChannelOpener>(); private final Queue<GlobalRequestResult> globalReqFutures = new LinkedList<GlobalRequestResult>(); private int windowSize = 2048 * 1024; private int maxPacketSize = 32 * 1024; /** * Create with an associated {@link Transport}. * * @param trans transport layer */ public ConnectionImpl(Transport trans) { super("ssh-connection", trans); } @Override public void attach(Channel chan) { log.info("Attaching `{}` channel (#{})", chan.getType(), chan.getID()); channels.put(chan.getID(), chan); } @Override public Channel get(int id) { return channels.get(id); } @Override public ForwardedChannelOpener get(String chanType) { return openers.get(chanType); } @Override public void forget(Channel chan) { log.info("Forgetting `{}` channel (#{})", chan.getType(), chan.getID()); channels.remove(chan.getID()); synchronized (internalSynchronizer) { if (channels.isEmpty()) internalSynchronizer.notifyAll(); } } @Override public void forget(ForwardedChannelOpener opener) { log.info("Forgetting opener for `{}` channels: {}", opener.getChannelType(), opener); openers.remove(opener.getChannelType()); } @Override public void attach(ForwardedChannelOpener opener) { log.info("Attaching opener for `{}` channels: {}", opener.getChannelType(), opener); openers.put(opener.getChannelType(), opener); } private Channel getChannel(SSHPacket buffer) throws ConnectionException { int recipient = buffer.readInt(); Channel channel = get(recipient); if (channel != null) return channel; else { buffer.rpos(buffer.rpos() - 5); throw new ConnectionException(DisconnectReason.PROTOCOL_ERROR, "Received " + buffer.readMessageID() + " on unknown channel #" + recipient); } } @Override public void handle(Message msg, SSHPacket buf) throws SSHException { if (msg.in(91, 100)) getChannel(buf).handle(msg, buf); else if (msg.in(80, 90)) switch (msg) { case REQUEST_SUCCESS: gotGlobalReqResponse(buf); break; case REQUEST_FAILURE: gotGlobalReqResponse(null); break; case CHANNEL_OPEN: gotChannelOpen(buf); break; default: super.handle(msg, buf); } else super.handle(msg, buf); } @Override public void notifyError(SSHException error) { super.notifyError(error); synchronized (globalReqFutures) { FutureUtils.alertAll(error, globalReqFutures); globalReqFutures.clear(); } ErrorNotifiable.Util.alertAll(error, channels.values()); channels.clear(); } @Override public int getMaxPacketSize() { return maxPacketSize; } @Override public Transport getTransport() { return trans; } @Override public void setMaxPacketSize(int maxPacketSize) { this.maxPacketSize = maxPacketSize; } @Override public int getWindowSize() { return windowSize; } @Override public void setWindowSize(int windowSize) { this.windowSize = windowSize; } @Override public void join() throws InterruptedException { synchronized (internalSynchronizer) { while (!channels.isEmpty()) internalSynchronizer.wait(); } } @Override public int nextID() { return nextID.getAndIncrement(); } @Override public GlobalRequestResult sendGlobalRequest(String name, boolean wantReply, byte[] specifics) throws TransportException { synchronized (globalReqFutures) { log.info("Making global request for `{}`", name); trans.write(new SSHPacket(Message.GLOBAL_REQUEST).putString(name) .putBoolean(wantReply).putRawBytes(specifics)); GlobalRequestResult future = null; if (wantReply) { future = new GlobalRequestResult("global req for " + name); globalReqFutures.add(future); } return future; } } private void gotGlobalReqResponse(SSHPacket response) throws ConnectionException { synchronized (globalReqFutures) { final GlobalRequestResult gr = globalReqFutures.poll(); if (gr == null) throw new ConnectionException(DisconnectReason.PROTOCOL_ERROR, "Got a global request response when none was requested"); else if (response == null) gr.error(new ConnectionException("Global request [" + gr + "] failed")); else gr.set(response); } } private void gotChannelOpen(SSHPacket buf) throws ConnectionException, TransportException { final String type = buf.readString(); log.debug("Received CHANNEL_OPEN for `{}` channel", type); if (openers.containsKey(type)) openers.get(type).handleOpen(buf); else { log.warn("No opener found for `{}` CHANNEL_OPEN request -- rejecting", type); sendOpenFailure(buf.readInt(), OpenFailException.Reason.UNKNOWN_CHANNEL_TYPE, ""); } } @Override public void sendOpenFailure(int recipient, Reason reason, String message) throws TransportException { trans.write(new SSHPacket(Message.CHANNEL_OPEN_FAILURE) .putInt(recipient) .putInt(reason.getCode()) .putString(message)); } @Override public void notifyDisconnect() throws SSHException { super.notifyDisconnect(); final ConnectionException ex = new ConnectionException("Disconnected."); FutureUtils.alertAll(ex, globalReqFutures); ErrorNotifiable.Util.alertAll(ex, new HashSet<Channel>(channels.values())); } }