/* * Copyright (C)2009 - SSHJ Contributors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package net.schmizz.sshj.xfer.scp; import net.schmizz.sshj.common.IOUtils; import net.schmizz.sshj.common.LoggerFactory; import net.schmizz.sshj.common.SSHException; import net.schmizz.sshj.common.StreamCopier; import net.schmizz.sshj.connection.channel.direct.Session.Command; import net.schmizz.sshj.connection.channel.direct.SessionFactory; import net.schmizz.sshj.xfer.TransferListener; import org.slf4j.Logger; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; /** @see <a href="https://blogs.oracle.com/janp/entry/how_the_scp_protocol_works">SCP Protocol</a> */ class SCPEngine { private static final char LF = '\n'; private final LoggerFactory loggerFactory; private final Logger log; private final SessionFactory host; private final TransferListener listener; private Command scp; private int exitStatus; SCPEngine(SessionFactory host, TransferListener listener, LoggerFactory loggerFactory) { this.host = host; this.listener = listener; this.loggerFactory = loggerFactory; log = loggerFactory.getLogger(getClass()); } public int getExitStatus() { return exitStatus; } void check(String what) throws IOException { int code = scp.getInputStream().read(); switch (code) { case -1: String stderr = IOUtils.readFully(scp.getErrorStream(), loggerFactory).toString(); if (!stderr.isEmpty()) stderr = ". Additional info: `" + stderr + "`"; throw new SCPException("EOF while expecting response to protocol message" + stderr); case 0: // OK log.debug(what); return; case 1: // Warning? not case 2: final String remoteMessage = readMessage(); throw new SCPRemoteException("Remote SCP command had error: " + remoteMessage, remoteMessage); default: throw new SCPException("Received unknown response code"); } } void cleanSlate() { exitStatus = -1; } void execSCPWith(ScpCommandLine commandLine) throws SSHException { scp = host.startSession().exec(commandLine.toCommandLine()); } void exit() { if (scp != null) { IOUtils.closeQuietly(scp); if (scp.getExitStatus() != null) { exitStatus = scp.getExitStatus(); if (scp.getExitStatus() != 0) log.warn("SCP exit status: {}", scp.getExitStatus()); } else { exitStatus = -1; } if (scp.getExitSignal() != null) { log.warn("SCP exit signal: {}", scp.getExitSignal()); } } scp = null; } String readMessage() throws IOException { final ByteArrayOutputStream baos = new ByteArrayOutputStream(); int x; while ((x = scp.getInputStream().read()) != LF) { if (x == -1) { if (baos.size() == 0) { return ""; } else { throw new IOException("EOF while reading message"); } } else { baos.write(x); } } final String msg = baos.toString(IOUtils.UTF8.displayName()); log.debug("Read message: `{}`", msg); return msg; } void sendMessage(String msg) throws IOException { log.debug("Sending message: {}", msg); scp.getOutputStream().write((msg + LF).getBytes(scp.getRemoteCharset())); scp.getOutputStream().flush(); check("Message ACK received"); } void signal(String what) throws IOException { log.debug("Signalling: {}", what); scp.getOutputStream().write(0); scp.getOutputStream().flush(); } long transferToRemote(StreamCopier.Listener listener, InputStream src, long length) throws IOException { return new StreamCopier(src, scp.getOutputStream(), loggerFactory) .bufSize(scp.getRemoteMaxPacketSize()).length(length) .keepFlushing(false) .listener(listener) .copy(); } long transferFromRemote(StreamCopier.Listener listener, OutputStream dest, long length) throws IOException { return new StreamCopier(scp.getInputStream(), dest, loggerFactory) .bufSize(scp.getLocalMaxPacketSize()).length(length) .keepFlushing(false) .listener(listener) .copy(); } TransferListener getTransferListener() { return listener; } }