/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sshd.client.kex; import java.io.ByteArrayOutputStream; import java.io.InputStream; import java.io.OutputStream; import java.io.PipedInputStream; import java.io.PipedOutputStream; import java.nio.charset.StandardCharsets; import java.util.Collection; import java.util.Collections; import java.util.EnumSet; import java.util.concurrent.TimeUnit; import org.apache.sshd.client.ClientBuilder; import org.apache.sshd.client.SshClient; import org.apache.sshd.client.channel.ClientChannel; import org.apache.sshd.client.channel.ClientChannelEvent; import org.apache.sshd.client.session.ClientSession; import org.apache.sshd.common.NamedFactory; import org.apache.sshd.common.channel.Channel; import org.apache.sshd.common.kex.BuiltinDHFactories; import org.apache.sshd.common.kex.KeyExchange; import org.apache.sshd.common.util.security.SecurityUtils; import org.apache.sshd.server.SshServer; import org.apache.sshd.util.test.BaseTestSupport; import org.apache.sshd.util.test.TeeOutputStream; import org.apache.sshd.util.test.Utils; import org.junit.AfterClass; import org.junit.Assume; import org.junit.BeforeClass; import org.junit.FixMethodOrder; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.MethodSorters; import org.junit.runners.Parameterized; import org.junit.runners.Parameterized.Parameters; /** * Test client key exchange algorithms. * * @author <a href="mailto:dev@mina.apache.org">Apache MINA SSHD Project</a> */ @FixMethodOrder(MethodSorters.NAME_ASCENDING) @RunWith(Parameterized.class) // see https://github.com/junit-team/junit/wiki/Parameterized-tests public class KexTest extends BaseTestSupport { private static SshServer sshd; private static int port; private static SshClient client; private final BuiltinDHFactories factory; public KexTest(BuiltinDHFactories factory) { this.factory = factory; } @Parameters(name = "Factory={0}") public static Collection<Object[]> parameters() { return parameterize(BuiltinDHFactories.VALUES); } @BeforeClass public static void setupClientAndServer() throws Exception { sshd = Utils.setupTestServer(KexTest.class); sshd.start(); port = sshd.getPort(); client = Utils.setupTestClient(KexTest.class); client.start(); } @AfterClass public static void tearDownClientAndServer() throws Exception { if (sshd != null) { try { sshd.stop(true); } finally { sshd = null; } } if (client != null) { try { client.stop(); } finally { client = null; } } } @Test public void testClientKeyExchange() throws Exception { if (factory.isGroupExchange()) { assertEquals(factory.getName() + " not supported even though DH group exchange supported", SecurityUtils.isDHGroupExchangeSupported(), factory.isSupported()); } Assume.assumeTrue(factory.getName() + " not supported", factory.isSupported()); testClient(ClientBuilder.DH2KEX.apply(factory)); } private void testClient(NamedFactory<KeyExchange> kex) throws Exception { try (ByteArrayOutputStream sent = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream()) { client.setKeyExchangeFactories(Collections.singletonList(kex)); try (ClientSession session = client.connect(getCurrentTestName(), TEST_LOCALHOST, port).verify(7L, TimeUnit.SECONDS).getSession()) { session.addPasswordIdentity(getCurrentTestName()); session.auth().verify(5L, TimeUnit.SECONDS); try (ClientChannel channel = session.createChannel(Channel.CHANNEL_SHELL); PipedOutputStream pipedIn = new PipedOutputStream(); InputStream inPipe = new PipedInputStream(pipedIn); ByteArrayOutputStream err = new ByteArrayOutputStream(); OutputStream teeOut = new TeeOutputStream(sent, pipedIn)) { channel.setIn(inPipe); channel.setOut(out); channel.setErr(err); channel.open().verify(9L, TimeUnit.SECONDS); teeOut.write("this is my command\n".getBytes(StandardCharsets.UTF_8)); teeOut.flush(); StringBuilder sb = new StringBuilder(); for (int i = 0; i < 10; i++) { sb.append("0123456789"); } sb.append("\n"); teeOut.write(sb.toString().getBytes(StandardCharsets.UTF_8)); teeOut.write("exit\n".getBytes(StandardCharsets.UTF_8)); teeOut.flush(); Collection<ClientChannelEvent> result = channel.waitFor(EnumSet.of(ClientChannelEvent.CLOSED), TimeUnit.SECONDS.toMillis(15L)); assertFalse("Timeout while waiting for channel closure", result.contains(ClientChannelEvent.TIMEOUT)); } } assertArrayEquals(kex.getName(), sent.toByteArray(), out.toByteArray()); } } }