package org.ovirt.engine.core.uutils.ssh; import static org.junit.Assert.assertEquals; import static org.junit.Assume.assumeTrue; import java.io.BufferedReader; import java.io.ByteArrayInputStream; import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.io.OutputStream; import java.io.OutputStreamWriter; import java.io.PrintWriter; import java.nio.charset.StandardCharsets; import java.security.KeyPair; import java.security.KeyPairGenerator; import java.security.KeyStore; import java.security.KeyStoreException; import java.security.NoSuchAlgorithmException; import java.util.Arrays; import java.util.LinkedList; import java.util.List; import javax.naming.AuthenticationException; import javax.naming.TimeLimitExceededException; import org.apache.commons.lang.SystemUtils; import org.junit.After; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; /* * Test properties * $ mvn -Dssh-host=host1 -Dssh-test-port=22 -Dssh-test-user=root -Dssh-test-password=password -Dssh-test-p12=a.p12 -Dssh-test-p12-alias=alias -Dssh-test-p12-password=password */ public class SSHDialogTest { private static final int BUFFER_SIZE = 10 * 1024; private static class Sink implements Runnable, SSHDialog.Sink { private SSHDialog.Control control; private BufferedReader incoming; private PrintWriter outgoing; private List<String> expect; private List<String> send; private Throwable throwable; private Thread thread; public Sink(String[] expect, String[] send) { this.expect = new LinkedList<>(Arrays.asList(expect)); this.send = new LinkedList<>(Arrays.asList(send)); thread = new Thread(this); } public void exception() throws Throwable { if (throwable != null) { throw throwable; } assertEquals(0, expect.size()); assertEquals(0, send.size()); } @Override public void setControl(SSHDialog.Control control) { this.control = control; } @Override public void setStreams(InputStream incoming, OutputStream outgoing) { this.incoming = incoming == null ? null : new BufferedReader( new InputStreamReader( incoming, StandardCharsets.UTF_8), BUFFER_SIZE); this.outgoing = outgoing == null ? null : new PrintWriter( new OutputStreamWriter( outgoing, StandardCharsets.UTF_8), true); } @Override public void start() { thread.start(); } @Override public void stop() { if (thread != null) { thread.interrupt(); while (true) { try { thread.join(); break; } catch (InterruptedException ignore) { } } thread = null; } } public void run() { try { while (expect.size() > 0) { assertEquals(expect.remove(0), incoming.readLine()); if (send.size() > 0) { String tosend = send.remove(0); if (tosend != null) { for (String s : tosend.split("\n")) { outgoing.println(s); } } } } } catch (Throwable t) { if (throwable == null) { throwable = t; } } finally { try { control.close(); } catch (IOException e) { throw new RuntimeException(e); } } } } private static String sshHost; private static String sshUser; private static String sshPassword; private static KeyPair sshKeyPair; private static int sshPort; private static SSHD sshd; private SSHDialog sshDialog; private static KeyPair getKeyPair(String p12, String alias, String password) throws KeyStoreException { KeyStore.PrivateKeyEntry entry; try (InputStream in = new FileInputStream(p12)) { KeyStore ks = KeyStore.getInstance("PKCS12"); ks.load(in, password.toCharArray()); entry = (KeyStore.PrivateKeyEntry) ks.getEntry( alias, new KeyStore.PasswordProtection( password.toCharArray())); } catch (Exception e) { throw new KeyStoreException( String.format( "Failed to get certificate entry from key store: %1$s/%2$s", p12, alias), e); } if (entry == null) { throw new KeyStoreException( String.format( "Bad key store: %1$s/%2$s", p12, alias)); } return new KeyPair( entry.getCertificate().getPublicKey(), entry.getPrivateKey()); } @BeforeClass public static void init() throws IOException { assumeTrue(SystemUtils.IS_OS_UNIX); sshHost = System.getProperty("ssh-host"); if (sshHost == null) { sshHost = "localhost"; sshUser = "root"; sshPassword = "password"; try { sshKeyPair = KeyPairGenerator.getInstance("RSA").generateKeyPair(); } catch (NoSuchAlgorithmException e) { throw new RuntimeException(e); } sshd = new SSHD(); sshd.setUser( sshUser, sshPassword, sshKeyPair.getPublic()); try { sshd.start(); } catch (IOException e) { throw new RuntimeException(e); } sshPort = sshd.getPort(); } else { sshPort = Integer.parseInt(System.getProperty("ssh-test-port", "22")); sshUser = System.getProperty("ssh-test-user", "root"); sshPassword = System.getProperty("ssh-test-password", "password"); try { sshKeyPair = getKeyPair( System.getProperty("ssh-test-p12", "src/test/resources/key.p12"), System.getProperty("ssh-test-p12-alias", "1"), System.getProperty("ssh-test-p12-password", "NoSoup4U")); } catch (KeyStoreException e) { throw new RuntimeException(e); } } System.out.println("Key fingerprint is: " + OpenSSHUtils.getKeyString(sshKeyPair.getPublic(), "test")); } @AfterClass public static void terminate() throws Exception { if (sshd != null) { sshd.stop(); } } @Before public void setUp() { sshDialog = new SSHDialog(); sshDialog.setHost(sshHost, sshPort); sshDialog.setPassword(sshPassword); sshDialog.setKeyPair(sshKeyPair); sshDialog.setSoftTimeout(10 * 1000); sshDialog.setHardTimeout(30 * 1000); } @After public void tearDown() throws Exception { if (sshDialog != null) { sshDialog.close(); sshDialog = null; } } @Test public void testKeyPair() throws Exception { sshDialog.connect(); sshDialog.authenticate(); } @Test public void testPassword() throws Exception { sshDialog.setKeyPair(null); sshDialog.connect(); sshDialog.authenticate(); } @Test(expected = AuthenticationException.class) public void testWrongKeyPair() throws Exception { sshDialog.setKeyPair( KeyPairGenerator.getInstance("RSA").generateKeyPair()); sshDialog.connect(); sshDialog.authenticate(); } @Test(expected = AuthenticationException.class) public void testWrongPassword() throws Exception { sshDialog.setKeyPair(null); sshDialog.setPassword("bad"); sshDialog.connect(); sshDialog.authenticate(); } @Test public void testSimple() throws Throwable { try (final InputStream start = new ByteArrayInputStream("start\n".getBytes("UTF-8"))) { Sink sink = new Sink( new String[] { "start", "text1", "text2" }, new String[] { "text1", "text2" }); sshDialog.connect(); sshDialog.authenticate(); sshDialog.executeCommand( sink, "cat", new InputStream[] { start }); sink.exception(); } } @Test(expected = TimeLimitExceededException.class) public void testTimeout() throws Throwable { Sink sink = new Sink( new String[] { "start" }, new String[] { }); sshDialog.setSoftTimeout(1 * 1000); sshDialog.connect(); sshDialog.authenticate(); sshDialog.executeCommand( sink, "cat", null); sink.exception(); } @Test(expected = RuntimeException.class) public void testStderr() throws Throwable { try (final InputStream start = new ByteArrayInputStream("start\n".getBytes("UTF-8"))) { Sink sink = new Sink( new String[] { "start", "text1", "text2" }, new String[] { "text1", "text2" }); sshDialog.connect(); sshDialog.authenticate(); sshDialog.executeCommand( sink, "echo message >&2 && cat", new InputStream[] { start }); sink.exception(); } } @Test public void testLong() throws Throwable { final String LINE = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAASDSSSSSSSSSSSSSSSSSSSSSSSDDDDDD"; final int NUM = 10000; final int FACTOR = 5; StringBuilder longText = new StringBuilder(); for (int i = 0; i < NUM / FACTOR; i++) { longText.append(LINE).append("\n"); } List<String> expect = new LinkedList<>(); expect.add("start"); for (int i = 0; i < NUM; i++) { expect.add(LINE); } List<String> send = new LinkedList<>(); for (int i = 0; i < NUM; i++) { if (i % (NUM / FACTOR) == 0) { send.add(longText.toString()); } else { send.add(null); } } Sink sink = new Sink( expect.toArray(new String[0]), send.toArray(new String[0])); sshDialog.connect(); sshDialog.authenticate(); sshDialog.executeCommand( sink, "echo start && sleep 4 && cat", null); sink.exception(); } private static class ReaderSink implements Runnable, SSHDialog.Sink { private SSHDialog.Control control; private BufferedReader incoming; private PrintWriter outgoing; private Throwable throwable; private Thread thread; private int delay; private String last; public ReaderSink(int delay) { thread = new Thread(this); this.delay = delay; } public String getLast() { return last; } public void exception() throws Throwable { if (throwable != null) { throw throwable; } } @Override public void setControl(SSHDialog.Control control) { this.control = control; } @Override public void setStreams(InputStream incoming, OutputStream outgoing) { this.incoming = incoming == null ? null : new BufferedReader( new InputStreamReader( incoming, StandardCharsets.UTF_8), BUFFER_SIZE); this.outgoing = outgoing == null ? null : new PrintWriter( new OutputStreamWriter( outgoing, StandardCharsets.UTF_8), true); } @Override public void start() { thread.start(); } @Override public void stop() { if (thread != null) { while (true) { try { thread.join(); break; } catch (InterruptedException ignore) { } } thread = null; } } public void run() { try { String l; while ((l = incoming.readLine()) != null) { last = l; Thread.sleep(delay); } } catch (Throwable t) { if (throwable == null) { throwable = t; } } finally { try { control.close(); } catch (IOException e) { throw new RuntimeException(e); } } } } @Test public void testDelay() throws Throwable { ReaderSink sink = new ReaderSink(10); sshDialog.setSoftTimeout(60 * 1000); sshDialog.setHardTimeout(60 * 1000); sshDialog.connect(); sshDialog.authenticate(); sshDialog.executeCommand( sink, "x=0;while [ $x -lt 100 ]; do echo line$x; x=$(($x+1)); done", null); sink.exception(); assertEquals("line99", sink.getLast()); } }