/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sshd.client.subsystem.sftp; import java.io.ByteArrayOutputStream; import java.io.EOFException; import java.io.IOException; import java.io.InputStream; import java.io.InterruptedIOException; import java.io.OutputStream; import java.io.StreamCorruptedException; import java.net.SocketTimeoutException; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.Iterator; import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.TreeMap; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import org.apache.sshd.client.channel.ChannelSubsystem; import org.apache.sshd.client.channel.ClientChannel; import org.apache.sshd.client.session.ClientSession; import org.apache.sshd.common.SshException; import org.apache.sshd.common.subsystem.sftp.SftpConstants; import org.apache.sshd.common.subsystem.sftp.extensions.ParserUtils; import org.apache.sshd.common.subsystem.sftp.extensions.VersionsParser.Versions; import org.apache.sshd.common.util.GenericUtils; import org.apache.sshd.common.util.ValidateUtils; import org.apache.sshd.common.util.buffer.Buffer; import org.apache.sshd.common.util.buffer.BufferUtils; import org.apache.sshd.common.util.buffer.ByteArrayBuffer; /** * @author <a href="mailto:dev@mina.apache.org">Apache MINA SSHD Project</a> */ public class DefaultSftpClient extends AbstractSftpClient { private final ClientSession clientSession; private final ChannelSubsystem channel; private final Map<Integer, Buffer> messages = new HashMap<>(); private final AtomicInteger cmdId = new AtomicInteger(100); private final Buffer receiveBuffer = new ByteArrayBuffer(); private final byte[] workBuf = new byte[Integer.BYTES]; private final AtomicInteger versionHolder = new AtomicInteger(0); private final AtomicBoolean closing = new AtomicBoolean(false); private final Map<String, byte[]> extensions = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); private final Map<String, byte[]> exposedExtensions = Collections.unmodifiableMap(extensions); public DefaultSftpClient(ClientSession clientSession) throws IOException { this.clientSession = Objects.requireNonNull(clientSession, "No client session"); this.channel = clientSession.createSubsystemChannel(SftpConstants.SFTP_SUBSYSTEM_NAME); this.channel.setOut(new OutputStream() { private final byte[] singleByte = new byte[1]; @Override public void write(int b) throws IOException { synchronized (singleByte) { singleByte[0] = (byte) b; write(singleByte); } } @Override public void write(byte[] b, int off, int len) throws IOException { data(b, off, len); } }); this.channel.setErr(new ByteArrayOutputStream(Byte.MAX_VALUE)); long initializationTimeout = clientSession.getLongProperty(SFTP_CHANNEL_OPEN_TIMEOUT, DEFAULT_CHANNEL_OPEN_TIMEOUT); this.channel.open().verify(initializationTimeout); this.channel.onClose(() -> { synchronized (messages) { closing.set(true); messages.notifyAll(); } if (versionHolder.get() <= 0) { log.warn("onClose({}) closed before version negotiated", channel); } }); try { init(initializationTimeout); } catch (IOException | RuntimeException e) { this.channel.close(true); throw e; } } @Override public int getVersion() { return versionHolder.get(); } @Override public ClientSession getClientSession() { return clientSession; } @Override public ClientChannel getClientChannel() { return channel; } @Override public Map<String, byte[]> getServerExtensions() { return exposedExtensions; } @Override public boolean isClosing() { return closing.get(); } @Override public boolean isOpen() { return this.channel.isOpen(); } @Override public void close() throws IOException { if (isOpen()) { this.channel.close(false); } } /** * Receive binary data * @param buf The buffer for the incoming data * @param start Offset in buffer to place the data * @param len Available space in buffer for the data * @return Actual size of received data * @throws IOException If failed to receive incoming data */ protected int data(byte[] buf, int start, int len) throws IOException { Buffer incoming = new ByteArrayBuffer(buf, start, len); // If we already have partial data, we need to append it to the buffer and use it if (receiveBuffer.available() > 0) { receiveBuffer.putBuffer(incoming); incoming = receiveBuffer; } // Process commands int rpos = incoming.rpos(); for (int count = 1; receive(incoming); count++) { if (log.isTraceEnabled()) { log.trace("data({}) Processed {} data messages", getClientChannel(), count); } } int read = incoming.rpos() - rpos; // Compact and add remaining data receiveBuffer.compact(); if (receiveBuffer != incoming && incoming.available() > 0) { receiveBuffer.putBuffer(incoming); } return read; } /** * Read SFTP packets from buffer * * @param incoming The received {@link Buffer} * @return {@code true} if data from incoming buffer was processed * @throws IOException if failed to process the buffer * @see #process(Buffer) */ protected boolean receive(Buffer incoming) throws IOException { int rpos = incoming.rpos(); int wpos = incoming.wpos(); ClientSession session = getClientSession(); session.resetIdleTimeout(); if ((wpos - rpos) > 4) { int length = incoming.getInt(); if (length < 5) { throw new IOException("Illegal sftp packet length: " + length); } if ((wpos - rpos) >= (length + 4)) { incoming.rpos(rpos); incoming.wpos(rpos + 4 + length); process(incoming); incoming.rpos(rpos + 4 + length); incoming.wpos(wpos); return true; } } incoming.rpos(rpos); return false; } /** * Process an SFTP packet * * @param incoming The received {@link Buffer} * @throws IOException if failed to process the buffer */ protected void process(Buffer incoming) throws IOException { // create a copy of the buffer in case it is being re-used Buffer buffer = new ByteArrayBuffer(incoming.available() + Long.SIZE, false); buffer.putBuffer(incoming); int rpos = buffer.rpos(); int length = buffer.getInt(); int type = buffer.getUByte(); Integer id = buffer.getInt(); buffer.rpos(rpos); if (log.isTraceEnabled()) { log.trace("process({}) id={}, type={}, len={}", getClientChannel(), id, SftpConstants.getCommandMessageName(type), length); } synchronized (messages) { messages.put(id, buffer); messages.notifyAll(); } } @Override public int send(int cmd, Buffer buffer) throws IOException { int id = cmdId.incrementAndGet(); int len = buffer.available(); if (log.isTraceEnabled()) { log.trace("send({}) cmd={}, len={}, id={}", getClientChannel(), SftpConstants.getCommandMessageName(cmd), len, id); } OutputStream dos = channel.getInvertedIn(); BufferUtils.writeInt(dos, 1 /* cmd */ + Integer.BYTES /* id */ + len, workBuf); dos.write(cmd & 0xFF); BufferUtils.writeInt(dos, id, workBuf); dos.write(buffer.array(), buffer.rpos(), len); dos.flush(); return id; } @Override public Buffer receive(int id) throws IOException { Integer reqId = id; synchronized (messages) { for (int count = 1;; count++) { if (isClosing() || (!isOpen())) { throw new SshException("Channel is being closed"); } Buffer buffer = messages.remove(reqId); if (buffer != null) { return buffer; } try { messages.wait(); } catch (InterruptedException e) { throw (IOException) new InterruptedIOException("Interrupted while waiting for messages at iteration #" + count).initCause(e); } } } } protected Buffer read() throws IOException { InputStream dis = channel.getInvertedOut(); int length = BufferUtils.readInt(dis, workBuf); // must have at least command + length if (length < (1 + Integer.BYTES)) { throw new IllegalArgumentException("Bad length: " + length); } Buffer buffer = new ByteArrayBuffer(length + Integer.BYTES, false); buffer.putInt(length); int nb = length; while (nb > 0) { int readLen = dis.read(buffer.array(), buffer.wpos(), nb); if (readLen < 0) { throw new IllegalArgumentException("Premature EOF while read " + length + " bytes - remaining=" + nb); } buffer.wpos(buffer.wpos() + readLen); nb -= readLen; } return buffer; } protected void init(long initializationTimeout) throws IOException { ValidateUtils.checkTrue(initializationTimeout > 0L, "Invalid initialization timeout: %d", initializationTimeout); // Send init packet OutputStream dos = channel.getInvertedIn(); BufferUtils.writeInt(dos, 5 /* total length */, workBuf); dos.write(SftpConstants.SSH_FXP_INIT); BufferUtils.writeInt(dos, SftpConstants.SFTP_V6, workBuf); dos.flush(); Buffer buffer; Integer reqId; synchronized (messages) { /* * We need to use a timeout since if the remote server does not support * SFTP, we will not know it immediately. This is due to the fact that the * request for the subsystem does not contain a reply as to its success or * failure. Thus, the SFTP channel is created by the client, but there is * no one on the other side to reply - thus the need for the timeout */ for (long remainingTimeout = initializationTimeout; (remainingTimeout > 0L) && messages.isEmpty() && (!isClosing()) && isOpen();) { try { long sleepStart = System.nanoTime(); messages.wait(remainingTimeout); long sleepEnd = System.nanoTime(); long sleepDuration = sleepEnd - sleepStart; long sleepMillis = TimeUnit.NANOSECONDS.toMillis(sleepDuration); if (sleepMillis < 1L) { remainingTimeout--; } else { remainingTimeout -= sleepMillis; } } catch (InterruptedException e) { throw (IOException) new InterruptedIOException("Interrupted init()").initCause(e); } } if (isClosing() || (!isOpen())) { throw new EOFException("Closing while await init message"); } if (messages.isEmpty()) { throw new SocketTimeoutException("No incoming initialization response received within " + initializationTimeout + " msec."); } Collection<Integer> ids = messages.keySet(); Iterator<Integer> iter = ids.iterator(); reqId = iter.next(); buffer = messages.remove(reqId); } int length = buffer.getInt(); int type = buffer.getUByte(); int id = buffer.getInt(); if (log.isTraceEnabled()) { log.trace("init({}) id={} type={} len={}", getClientChannel(), id, SftpConstants.getCommandMessageName(type), length); } if (type == SftpConstants.SSH_FXP_VERSION) { if (id < SftpConstants.SFTP_V3) { throw new SshException("Unsupported sftp version " + id); } versionHolder.set(id); if (log.isTraceEnabled()) { log.trace("init({}) version={}", getClientChannel(), versionHolder); } while (buffer.available() > 0) { String name = buffer.getString(); byte[] data = buffer.getBytes(); if (log.isTraceEnabled()) { log.trace("init({}) added extension=", getClientChannel(), name); } extensions.put(name, data); } } else if (type == SftpConstants.SSH_FXP_STATUS) { int substatus = buffer.getInt(); String msg = buffer.getString(); String lang = buffer.getString(); if (log.isTraceEnabled()) { log.trace("init({})[id={}] - status: {} [{}] {}", getClientChannel(), id, SftpConstants.getStatusName(substatus), lang, msg); } throwStatusException(SftpConstants.SSH_FXP_INIT, id, substatus, msg, lang); } else { handleUnexpectedPacket(SftpConstants.SSH_FXP_INIT, SftpConstants.SSH_FXP_VERSION, id, type, length, buffer); } } /** * @param selector The {@link SftpVersionSelector} to use - ignored if {@code null} * @return The selected version (may be same as current) * @throws IOException If failed to negotiate */ public int negotiateVersion(SftpVersionSelector selector) throws IOException { int current = getVersion(); if (selector == null) { return current; } Set<Integer> available = GenericUtils.asSortedSet(Collections.singleton(current)); Map<String, ?> parsed = getParsedServerExtensions(); Collection<String> extensions = ParserUtils.supportedExtensions(parsed); if ((GenericUtils.size(extensions) > 0) && extensions.contains(SftpConstants.EXT_VERSION_SELECT)) { Versions vers = GenericUtils.isEmpty(parsed) ? null : (Versions) parsed.get(SftpConstants.EXT_VERSIONS); Collection<String> reported = (vers == null) ? null : vers.getVersions(); if (GenericUtils.size(reported) > 0) { for (String v : reported) { if (!available.add(Integer.valueOf(v))) { continue; // debug breakpoint } } } } int selected = selector.selectVersion(getClientSession(), current, new ArrayList<>(available)); if (log.isDebugEnabled()) { log.debug("negotiateVersion({}) current={} {} -> {}", getClientChannel(), current, available, selected); } if (selected == current) { return current; } if (!available.contains(selected)) { throw new StreamCorruptedException("Selected version (" + selected + ") not part of available: " + available); } String verVal = String.valueOf(selected); Buffer buffer = new ByteArrayBuffer(Integer.BYTES + SftpConstants.EXT_VERSION_SELECT.length() // extension name + Integer.BYTES + verVal.length() + Byte.SIZE, false); buffer.putString(SftpConstants.EXT_VERSION_SELECT); buffer.putString(verVal); checkCommandStatus(SftpConstants.SSH_FXP_EXTENDED, buffer); versionHolder.set(selected); return selected; } }