/* * Copyright 2014-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.integration.samples.sftp; import java.io.File; import java.io.IOException; import java.io.InputStream; import java.math.BigInteger; import java.nio.ByteBuffer; import java.nio.file.Paths; import java.security.KeyFactory; import java.security.PublicKey; import java.security.spec.RSAPublicKeySpec; import java.util.Arrays; import java.util.Collections; import org.apache.sshd.common.file.virtualfs.VirtualFileSystemFactory; import org.apache.sshd.server.SshServer; import org.apache.sshd.server.keyprovider.SimpleGeneratorHostKeyProvider; import org.apache.sshd.server.subsystem.sftp.SftpSubsystemFactory; import org.springframework.beans.factory.InitializingBean; import org.springframework.context.SmartLifecycle; import org.springframework.core.io.ClassPathResource; import org.springframework.integration.sftp.session.DefaultSftpSessionFactory; import org.springframework.util.Base64Utils; import org.springframework.util.StreamUtils; /** * @author Artem Bilan */ public class EmbeddedSftpServer implements InitializingBean, SmartLifecycle { /** * Let OS to obtain the proper port */ public static final int PORT = 0; private final SshServer server = SshServer.setUpDefaultServer(); private volatile int port; private volatile boolean running; private DefaultSftpSessionFactory defaultSftpSessionFactory; public void setPort(int port) { this.port = port; } public void setDefaultSftpSessionFactory(DefaultSftpSessionFactory defaultSftpSessionFactory) { this.defaultSftpSessionFactory = defaultSftpSessionFactory; } @Override public void afterPropertiesSet() throws Exception { final PublicKey allowedKey = decodePublicKey(); this.server.setPublickeyAuthenticator((username, key, session) -> key.equals(allowedKey)); this.server.setPort(this.port); this.server.setKeyPairProvider(new SimpleGeneratorHostKeyProvider(new File("hostkey.ser"))); this.server.setSubsystemFactories(Collections.singletonList(new SftpSubsystemFactory())); final String pathname = System.getProperty("java.io.tmpdir") + File.separator + "sftptest" + File.separator; new File(pathname).mkdirs(); server.setFileSystemFactory(new VirtualFileSystemFactory(Paths.get(pathname))); } private PublicKey decodePublicKey() throws Exception { InputStream stream = new ClassPathResource("META-INF/keys/sftp_rsa.pub").getInputStream(); byte[] keyBytes = StreamUtils.copyToByteArray(stream); // strip any newline chars while (keyBytes[keyBytes.length - 1] == 0x0a || keyBytes[keyBytes.length - 1] == 0x0d) { keyBytes = Arrays.copyOf(keyBytes, keyBytes.length - 1); } byte[] decodeBuffer = Base64Utils.decode(keyBytes); ByteBuffer bb = ByteBuffer.wrap(decodeBuffer); int len = bb.getInt(); byte[] type = new byte[len]; bb.get(type); if ("ssh-rsa".equals(new String(type))) { BigInteger e = decodeBigInt(bb); BigInteger m = decodeBigInt(bb); RSAPublicKeySpec spec = new RSAPublicKeySpec(m, e); return KeyFactory.getInstance("RSA").generatePublic(spec); } else { throw new IllegalArgumentException("Only supports RSA"); } } private BigInteger decodeBigInt(ByteBuffer bb) { int len = bb.getInt(); byte[] bytes = new byte[len]; bb.get(bytes); return new BigInteger(bytes); } @Override public boolean isAutoStartup() { return PORT == this.port; } @Override public int getPhase() { return Integer.MAX_VALUE; } @Override public void start() { try { this.server.start(); this.defaultSftpSessionFactory.setPort(this.server.getPort()); this.running = true; } catch (IOException e) { throw new IllegalStateException(e); } } @Override public void stop(Runnable callback) { stop(); callback.run(); } @Override public void stop() { if (this.running) { try { server.stop(true); } catch (Exception e) { throw new IllegalStateException(e); } finally { this.running = false; } } } @Override public boolean isRunning() { return this.running; } }