/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sshd.common.forward;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.nio.charset.StandardCharsets;
import java.util.Collection;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

import com.jcraft.jsch.JSch;
import com.jcraft.jsch.JSchException;
import com.jcraft.jsch.Session;

import org.apache.mina.core.buffer.IoBuffer;
import org.apache.mina.core.service.IoAcceptor;
import org.apache.mina.core.service.IoHandlerAdapter;
import org.apache.mina.core.session.IoSession;
import org.apache.mina.transport.socket.nio.NioSocketAcceptor;
import org.apache.sshd.client.SshClient;
import org.apache.sshd.client.channel.ChannelDirectTcpip;
import org.apache.sshd.client.session.ClientSession;
import org.apache.sshd.client.session.forward.ExplicitPortForwardingTracker;
import org.apache.sshd.common.FactoryManager;
import org.apache.sshd.common.PropertyResolverUtils;
import org.apache.sshd.common.session.ConnectionService;
import org.apache.sshd.common.util.GenericUtils;
import org.apache.sshd.common.util.net.SshdSocketAddress;
import org.apache.sshd.server.SshServer;
import org.apache.sshd.server.forward.AcceptAllForwardingFilter;
import org.apache.sshd.server.global.CancelTcpipForwardHandler;
import org.apache.sshd.server.global.TcpipForwardHandler;
import org.apache.sshd.util.test.BaseTestSupport;
import org.apache.sshd.util.test.JSchLogger;
import org.apache.sshd.util.test.SimpleUserInfo;
import org.apache.sshd.util.test.Utils;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.FixMethodOrder;
import org.junit.Test;
import org.junit.runners.MethodSorters;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Port forwarding tests
*/
@FixMethodOrder(MethodSorters.NAME_ASCENDING)
public class PortForwardingTest extends BaseTestSupport {
@SuppressWarnings("checkstyle:anoninnerlength")
private static final PortForwardingEventListener SERVER_SIDE_LISTENER = new PortForwardingEventListener() {
private final org.slf4j.Logger log = LoggerFactory.getLogger(PortForwardingEventListener.class);
@Override
public void establishingExplicitTunnel(org.apache.sshd.common.session.Session session, SshdSocketAddress local,
SshdSocketAddress remote, boolean localForwarding) throws IOException {
log.info("establishingExplicitTunnel(session={}, local={}, remote={}, localForwarding={})",
session, local, remote, localForwarding);
}
@Override
public void establishedExplicitTunnel(org.apache.sshd.common.session.Session session, SshdSocketAddress local,
SshdSocketAddress remote, boolean localForwarding, SshdSocketAddress boundAddress, Throwable reason)
throws IOException {
log.info("establishedExplicitTunnel(session={}, local={}, remote={}, bound={}, localForwarding={}): {}",
session, local, remote, boundAddress, localForwarding, reason);
}
@Override
public void tearingDownExplicitTunnel(org.apache.sshd.common.session.Session session, SshdSocketAddress address,
boolean localForwarding) throws IOException {
log.info("tearingDownExplicitTunnel(session={}, address={}, localForwarding={})", session, address, localForwarding);
}
@Override
public void tornDownExplicitTunnel(org.apache.sshd.common.session.Session session, SshdSocketAddress address,
boolean localForwarding, Throwable reason) throws IOException {
log.info("tornDownExplicitTunnel(session={}, address={}, localForwarding={}, reason={})",
session, address, localForwarding, reason);
}
@Override
public void establishingDynamicTunnel(org.apache.sshd.common.session.Session session, SshdSocketAddress local)
throws IOException {
log.info("establishingDynamicTunnel(session={}, local={})", session, local);
}
@Override
public void establishedDynamicTunnel(org.apache.sshd.common.session.Session session, SshdSocketAddress local,
SshdSocketAddress boundAddress, Throwable reason) throws IOException {
log.info("establishedDynamicTunnel(session={}, local={}, bound={}, reason={})", session, local, boundAddress, reason);
}
@Override
public void tearingDownDynamicTunnel(org.apache.sshd.common.session.Session session, SshdSocketAddress address)
throws IOException {
log.info("tearingDownDynamicTunnel(session={}, address={})", session, address);
}
@Override
public void tornDownDynamicTunnel(org.apache.sshd.common.session.Session session, SshdSocketAddress address,
Throwable reason) throws IOException {
log.info("tornDownDynamicTunnel(session={}, address={}, reason={})", session, address, reason);
}
};
private static final BlockingQueue<String> REQUESTS_QUEUE = new LinkedBlockingDeque<>();
private static SshServer sshd;
private static int sshPort;
private static int echoPort;
private static IoAcceptor acceptor;
private static SshClient client;
private final Logger log = LoggerFactory.getLogger(getClass());
public PortForwardingTest() {
super();
}
@BeforeClass
public static void setUpTestEnvironment() throws Exception {
JSchLogger.init();
sshd = Utils.setupTestServer(PortForwardingTest.class);
PropertyResolverUtils.updateProperty(sshd, FactoryManager.WINDOW_SIZE, 2048);
PropertyResolverUtils.updateProperty(sshd, FactoryManager.MAX_PACKET_SIZE, 256);
sshd.setTcpipForwardingFilter(AcceptAllForwardingFilter.INSTANCE);
sshd.addPortForwardingEventListener(SERVER_SIDE_LISTENER);
sshd.start();
sshPort = sshd.getPort();
if (!REQUESTS_QUEUE.isEmpty()) {
REQUESTS_QUEUE.clear();
}
TcpipForwarderFactory factory = Objects.requireNonNull(sshd.getTcpipForwarderFactory(), "No TcpipForwarderFactory");
sshd.setTcpipForwarderFactory(new TcpipForwarderFactory() {
private final Class<?>[] interfaces = {TcpipForwarder.class};
private final Map<String, String> method2req =
GenericUtils.<String, String>mapBuilder(String.CASE_INSENSITIVE_ORDER)
.put("localPortForwardingRequested", TcpipForwardHandler.REQUEST)
.put("localPortForwardingCancelled", CancelTcpipForwardHandler.REQUEST)
.build();
@Override
public TcpipForwarder create(ConnectionService service) {
Thread thread = Thread.currentThread();
ClassLoader cl = thread.getContextClassLoader();
TcpipForwarder forwarder = factory.create(service);
return (TcpipForwarder) Proxy.newProxyInstance(cl, interfaces, new InvocationHandler() {
private final org.slf4j.Logger log = LoggerFactory.getLogger(TcpipForwarder.class);
@SuppressWarnings("synthetic-access")
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
Object result = method.invoke(forwarder, args);
String name = method.getName();
String request = method2req.get(name);
if (GenericUtils.length(request) > 0) {
if (REQUESTS_QUEUE.offer(request)) {
log.info("Signal " + request);
} else {
log.error("Failed to offer request=" + request);
}
}
return result;
}
});
}
});
NioSocketAcceptor acceptor = new NioSocketAcceptor();
acceptor.setHandler(new IoHandlerAdapter() {
@Override
public void messageReceived(IoSession session, Object message) throws Exception {
IoBuffer recv = (IoBuffer) message;
IoBuffer sent = IoBuffer.allocate(recv.remaining());
sent.put(recv);
sent.flip();
session.write(sent);
}
});
acceptor.setReuseAddress(true);
acceptor.bind(new InetSocketAddress(0));
echoPort = acceptor.getLocalAddress().getPort();
client = Utils.setupTestClient(PortForwardingTest.class);
client.start();
}
@AfterClass
public static void tearDownTestEnvironment() throws Exception {
if (sshd != null) {
sshd.stop(true);
}
if (acceptor != null) {
acceptor.dispose(true);
}
if (client != null) {
client.stop();
}
}
private void waitForForwardingRequest(String expected, long timeout) throws InterruptedException {
for (long remaining = timeout; remaining > 0L;) {
long waitStart = System.currentTimeMillis();
String actual = REQUESTS_QUEUE.poll(remaining, TimeUnit.MILLISECONDS);
long waitEnd = System.currentTimeMillis();
if (GenericUtils.isEmpty(actual)) {
throw new IllegalStateException("Failed to retrieve request=" + expected);
}
if (expected.equals(actual)) {
return;
}
long waitDuration = waitEnd - waitStart;
remaining -= waitDuration;
}
throw new IllegalStateException("Timeout while waiting to retrieve request=" + expected);
}
@Test
public void testRemoteForwarding() throws Exception {
Session session = createSession();
try {
int forwardedPort = Utils.getFreePort();
session.setPortForwardingR(forwardedPort, TEST_LOCALHOST, echoPort);
waitForForwardingRequest(TcpipForwardHandler.REQUEST, TimeUnit.SECONDS.toMillis(5L));
try (Socket s = new Socket(TEST_LOCALHOST, forwardedPort);
OutputStream output = s.getOutputStream();
InputStream input = s.getInputStream()) {
s.setSoTimeout((int) TimeUnit.SECONDS.toMillis(13L));
String expected = getCurrentTestName();
byte[] bytes = expected.getBytes(StandardCharsets.UTF_8);
output.write(bytes);
output.flush();
byte[] buf = new byte[bytes.length + Long.SIZE];
int n = input.read(buf);
String res = new String(buf, 0, n, StandardCharsets.UTF_8);
assertEquals("Mismatched data", expected, res);
} finally {
session.delPortForwardingR(forwardedPort);
}
} finally {
session.disconnect();
}
}
@Test
public void testRemoteForwardingSecondTimeInSameSession() throws Exception {
Session session = createSession();
try {
int forwardedPort = Utils.getFreePort();
session.setPortForwardingR(forwardedPort, TEST_LOCALHOST, echoPort);
waitForForwardingRequest(TcpipForwardHandler.REQUEST, TimeUnit.SECONDS.toMillis(5L));
session.delPortForwardingR(TEST_LOCALHOST, forwardedPort);
waitForForwardingRequest(CancelTcpipForwardHandler.REQUEST, TimeUnit.SECONDS.toMillis(5L));
session.setPortForwardingR(forwardedPort, TEST_LOCALHOST, echoPort);
waitForForwardingRequest(TcpipForwardHandler.REQUEST, TimeUnit.SECONDS.toMillis(5L));
try (Socket s = new Socket(TEST_LOCALHOST, forwardedPort);
OutputStream output = s.getOutputStream();
InputStream input = s.getInputStream()) {
s.setSoTimeout((int) TimeUnit.SECONDS.toMillis(13L));
String expected = getCurrentTestName();
byte[] bytes = expected.getBytes(StandardCharsets.UTF_8);
output.write(bytes);
output.flush();
byte[] buf = new byte[bytes.length + Long.SIZE];
int n = input.read(buf);
String res = new String(buf, 0, n, StandardCharsets.UTF_8);
assertEquals("Mismatched data", expected, res);
} finally {
session.delPortForwardingR(TEST_LOCALHOST, forwardedPort);
}
} finally {
session.disconnect();
}
}
@Test
public void testRemoteForwardingNative() throws Exception {
try (ClientSession session = createNativeSession(null)) {
SshdSocketAddress remote = new SshdSocketAddress("", 0);
SshdSocketAddress local = new SshdSocketAddress(TEST_LOCALHOST, echoPort);
SshdSocketAddress bound = session.startRemotePortForwarding(remote, local);
try (Socket s = new Socket(bound.getHostName(), bound.getPort());
OutputStream output = s.getOutputStream();
InputStream input = s.getInputStream()) {
s.setSoTimeout((int) TimeUnit.SECONDS.toMillis(13L));
String expected = getCurrentTestName();
byte[] bytes = expected.getBytes(StandardCharsets.UTF_8);
output.write(bytes);
output.flush();
byte[] buf = new byte[bytes.length + Long.SIZE];
int n = input.read(buf);
String res = new String(buf, 0, n);
assertEquals("Mismatched data", expected, res);
} finally {
session.stopRemotePortForwarding(remote);
}
}
}
@Test
public void testRemoteForwardingNativeBigPayload() throws Exception {
AtomicReference<SshdSocketAddress> localAddressHolder = new AtomicReference<>();
AtomicReference<SshdSocketAddress> remoteAddressHolder = new AtomicReference<>();
AtomicReference<SshdSocketAddress> boundAddressHolder = new AtomicReference<>();
AtomicInteger tearDownSignal = new AtomicInteger(0);
@SuppressWarnings("checkstyle:anoninnerlength")
PortForwardingEventListener listener = new PortForwardingEventListener() {
@Override
public void tornDownExplicitTunnel(
org.apache.sshd.common.session.Session session, SshdSocketAddress address, boolean localForwarding, Throwable reason)
throws IOException {
assertFalse("Unexpected local tunnel has been torn down: address=" + address, localForwarding);
assertEquals("Tear down indication not invoked", 1, tearDownSignal.get());
}
@Override
public void tornDownDynamicTunnel(
org.apache.sshd.common.session.Session session, SshdSocketAddress address, Throwable reason) throws IOException {
throw new UnsupportedOperationException("Unexpected dynamic tunnel torn down indication: session=" + session + ", address=" + address);
}
@Override
public void tearingDownExplicitTunnel(
org.apache.sshd.common.session.Session session, SshdSocketAddress address, boolean localForwarding)
throws IOException {
assertFalse("Unexpected local tunnel being torn down: address=" + address, localForwarding);
assertEquals("Duplicate tear down signalling", 1, tearDownSignal.incrementAndGet());
}
@Override
public void tearingDownDynamicTunnel(org.apache.sshd.common.session.Session session, SshdSocketAddress address)
throws IOException {
throw new UnsupportedOperationException("Unexpected dynamic tunnel tearing down indication: session=" + session + ", address=" + address);
}
@Override
public void establishingExplicitTunnel(
org.apache.sshd.common.session.Session session, SshdSocketAddress local, SshdSocketAddress remote, boolean localForwarding)
throws IOException {
assertFalse("Unexpected local tunnel being established: local=" + local + ", remote=" + remote, localForwarding);
assertNull("Duplicate establishment indication call for local address=" + local, localAddressHolder.getAndSet(local));
assertNull("Duplicate establishment indication call for remote address=" + remote, remoteAddressHolder.getAndSet(remote));
}
@Override
public void establishingDynamicTunnel(org.apache.sshd.common.session.Session session, SshdSocketAddress local)
throws IOException {
throw new UnsupportedOperationException("Unexpected dynamic tunnel establishing indication: session=" + session + ", address=" + local);
}
@Override
public void establishedExplicitTunnel(org.apache.sshd.common.session.Session session, SshdSocketAddress local,
SshdSocketAddress remote, boolean localForwarding, SshdSocketAddress boundAddress, Throwable reason)
throws IOException {
assertFalse("Unexpected local tunnel has been established: local=" + local + ", remote=" + remote + ", bound=" + boundAddress, localForwarding);
assertSame("Mismatched established tunnel local address", local, localAddressHolder.get());
assertSame("Mismatched established tunnel remote address", remote, remoteAddressHolder.get());
assertNull("Duplicate establishment indication call for bound address=" + boundAddress, boundAddressHolder.getAndSet(boundAddress));
}
@Override
public void establishedDynamicTunnel(
org.apache.sshd.common.session.Session session, SshdSocketAddress local, SshdSocketAddress boundAddress, Throwable reason)
throws IOException {
throw new UnsupportedOperationException("Unexpected dynamic tunnel established indication: session=" + session + ", address=" + boundAddress);
}
};
try (ClientSession session = createNativeSession(listener);
ExplicitPortForwardingTracker tracker =
session.createRemotePortForwardingTracker(new SshdSocketAddress("", 0), new SshdSocketAddress(TEST_LOCALHOST, echoPort))) {
assertTrue("Tracker not marked as open", tracker.isOpen());
assertFalse("Tracker not marked as remote", tracker.isLocalForwarding());
SshdSocketAddress bound = tracker.getBoundAddress();
try (Socket s = new Socket(bound.getHostName(), bound.getPort());
OutputStream output = s.getOutputStream();
InputStream input = s.getInputStream()) {
s.setSoTimeout((int) TimeUnit.SECONDS.toMillis(13L));
String expected = getCurrentTestName();
byte[] bytes = expected.getBytes(StandardCharsets.UTF_8);
byte[] buf = new byte[bytes.length + Long.SIZE];
for (int i = 0; i < 1000; i++) {
output.write(bytes);
output.flush();
int n = input.read(buf);
String res = new String(buf, 0, n, StandardCharsets.UTF_8);
assertEquals("Mismatched data at iteration #" + i, expected, res);
}
} finally {
tracker.close();
}
assertFalse("Tracker not marked as closed", tracker.isOpen());
} finally {
client.removePortForwardingEventListener(listener);
}
assertNotNull("Local tunnel address not indicated", localAddressHolder.getAndSet(null));
assertNotNull("Remote tunnel address not indicated", remoteAddressHolder.getAndSet(null));
assertNotNull("Bound tunnel address not indicated", boundAddressHolder.getAndSet(null));
}
@Test
public void testLocalForwarding() throws Exception {
Session session = createSession();
try {
int forwardedPort = Utils.getFreePort();
session.setPortForwardingL(forwardedPort, TEST_LOCALHOST, echoPort);
try (Socket s = new Socket(TEST_LOCALHOST, forwardedPort);
OutputStream output = s.getOutputStream();
InputStream input = s.getInputStream()) {
s.setSoTimeout((int) TimeUnit.SECONDS.toMillis(13L));
String expected = getCurrentTestName();
byte[] bytes = expected.getBytes(StandardCharsets.UTF_8);
output.write(bytes);
output.flush();
byte[] buf = new byte[bytes.length + Long.SIZE];
int n = input.read(buf);
String res = new String(buf, 0, n, StandardCharsets.UTF_8);
assertEquals("Mismatched data", expected, res);
} finally {
session.delPortForwardingL(forwardedPort);
}
} finally {
session.disconnect();
}
}
@Test
public void testLocalForwardingNative() throws Exception {
final AtomicReference<SshdSocketAddress> localAddressHolder = new AtomicReference<>();
final AtomicReference<SshdSocketAddress> remoteAddressHolder = new AtomicReference<>();
final AtomicReference<SshdSocketAddress> boundAddressHolder = new AtomicReference<>();
final AtomicInteger tearDownSignal = new AtomicInteger(0);
@SuppressWarnings("checkstyle:anoninnerlength")
PortForwardingEventListener listener = new PortForwardingEventListener() {
@Override
public void tornDownExplicitTunnel(
org.apache.sshd.common.session.Session session, SshdSocketAddress address, boolean localForwarding, Throwable reason)
throws IOException {
assertTrue("Unexpected remote tunnel has been torn down: address=" + address, localForwarding);
assertEquals("Tear down indication not invoked", 1, tearDownSignal.get());
}
@Override
public void tornDownDynamicTunnel(
org.apache.sshd.common.session.Session session, SshdSocketAddress address, Throwable reason) throws IOException {
throw new UnsupportedOperationException("Unexpected dynamic tunnel torn down indication: session=" + session + ", address=" + address);
}
@Override
public void tearingDownExplicitTunnel(
org.apache.sshd.common.session.Session session, SshdSocketAddress address, boolean localForwarding)
throws IOException {
assertTrue("Unexpected remote tunnel being torn down: address=" + address, localForwarding);
assertEquals("Duplicate tear down signalling", 1, tearDownSignal.incrementAndGet());
}
@Override
public void tearingDownDynamicTunnel(org.apache.sshd.common.session.Session session, SshdSocketAddress address)
throws IOException {
throw new UnsupportedOperationException("Unexpected dynamic tunnel tearing down indication: session=" + session + ", address=" + address);
}
@Override
public void establishingExplicitTunnel(
org.apache.sshd.common.session.Session session, SshdSocketAddress local, SshdSocketAddress remote, boolean localForwarding)
throws IOException {
assertTrue("Unexpected remote tunnel being established: local=" + local + ", remote=" + remote, localForwarding);
assertNull("Duplicate establishment indication call for local address=" + local, localAddressHolder.getAndSet(local));
assertNull("Duplicate establishment indication call for remote address=" + remote, remoteAddressHolder.getAndSet(remote));
}
@Override
public void establishingDynamicTunnel(org.apache.sshd.common.session.Session session, SshdSocketAddress local)
throws IOException {
throw new UnsupportedOperationException("Unexpected dynamic tunnel establishing indication: session=" + session + ", address=" + local);
}
@Override
public void establishedExplicitTunnel(org.apache.sshd.common.session.Session session, SshdSocketAddress local,
SshdSocketAddress remote, boolean localForwarding, SshdSocketAddress boundAddress, Throwable reason)
throws IOException {
assertTrue("Unexpected remote tunnel has been established: local=" + local + ", remote=" + remote + ", bound=" + boundAddress, localForwarding);
assertSame("Mismatched established tunnel local address", local, localAddressHolder.get());
assertSame("Mismatched established tunnel remote address", remote, remoteAddressHolder.get());
assertNull("Duplicate establishment indication call for bound address=" + boundAddress, boundAddressHolder.getAndSet(boundAddress));
}
@Override
public void establishedDynamicTunnel(
org.apache.sshd.common.session.Session session, SshdSocketAddress local, SshdSocketAddress boundAddress, Throwable reason)
throws IOException {
throw new UnsupportedOperationException("Unexpected dynamic tunnel established indication: session=" + session + ", address=" + boundAddress);
}
};
try (ClientSession session = createNativeSession(listener);
ExplicitPortForwardingTracker tracker =
session.createLocalPortForwardingTracker(new SshdSocketAddress("", 0), new SshdSocketAddress(TEST_LOCALHOST, echoPort))) {
assertTrue("Tracker not marked as open", tracker.isOpen());
assertTrue("Tracker not marked as local", tracker.isLocalForwarding());
SshdSocketAddress bound = tracker.getBoundAddress();
try (Socket s = new Socket(bound.getHostName(), bound.getPort());
OutputStream output = s.getOutputStream();
InputStream input = s.getInputStream()) {
s.setSoTimeout((int) TimeUnit.SECONDS.toMillis(13L));
String expected = getCurrentTestName();
byte[] bytes = expected.getBytes(StandardCharsets.UTF_8);
output.write(bytes);
output.flush();
byte[] buf = new byte[bytes.length + Long.SIZE];
int n = input.read(buf);
String res = new String(buf, 0, n, StandardCharsets.UTF_8);
assertEquals("Mismatched data", expected, res);
} finally {
tracker.close();
}
assertFalse("Tracker not marked as closed", tracker.isOpen());
} finally {
client.removePortForwardingEventListener(listener);
}
assertNotNull("Local tunnel address not indicated", localAddressHolder.getAndSet(null));
assertNotNull("Remote tunnel address not indicated", remoteAddressHolder.getAndSet(null));
assertNotNull("Bound tunnel address not indicated", boundAddressHolder.getAndSet(null));
}
@Test
public void testLocalForwardingNativeReuse() throws Exception {
try (ClientSession session = createNativeSession(null)) {
SshdSocketAddress local = new SshdSocketAddress("", 0);
SshdSocketAddress remote = new SshdSocketAddress(TEST_LOCALHOST, echoPort);
SshdSocketAddress bound = session.startLocalPortForwarding(local, remote);
session.stopLocalPortForwarding(bound);
SshdSocketAddress bound2 = session.startLocalPortForwarding(local, remote);
session.stopLocalPortForwarding(bound2);
}
}
@Test
public void testLocalForwardingNativeBigPayload() throws Exception {
try (ClientSession session = createNativeSession(null)) {
String expected = getCurrentTestName();
byte[] bytes = expected.getBytes(StandardCharsets.UTF_8);
byte[] buf = new byte[bytes.length + Long.SIZE];
SshdSocketAddress local = new SshdSocketAddress("", 0);
SshdSocketAddress remote = new SshdSocketAddress(TEST_LOCALHOST, echoPort);
SshdSocketAddress bound = session.startLocalPortForwarding(local, remote);
try (Socket s = new Socket(bound.getHostName(), bound.getPort());
OutputStream output = s.getOutputStream();
InputStream input = s.getInputStream()) {
s.setSoTimeout((int) TimeUnit.SECONDS.toMillis(10L));
for (int i = 0; i < 1000; i++) {
output.write(bytes);
output.flush();
int n = input.read(buf);
String res = new String(buf, 0, n, StandardCharsets.UTF_8);
assertEquals("Mismatched data at iteration #" + i, expected, res);
}
} finally {
session.stopLocalPortForwarding(bound);
}
}
}
@Test
public void testForwardingChannel() throws Exception {
try (ClientSession session = createNativeSession(null)) {
SshdSocketAddress local = new SshdSocketAddress("", 0);
SshdSocketAddress remote = new SshdSocketAddress(TEST_LOCALHOST, echoPort);
try (ChannelDirectTcpip channel = session.createDirectTcpipChannel(local, remote)) {
channel.open().verify(9L, TimeUnit.SECONDS);
String expected = getCurrentTestName();
byte[] bytes = expected.getBytes(StandardCharsets.UTF_8);
try (OutputStream output = channel.getInvertedIn();
InputStream input = channel.getInvertedOut()) {
output.write(bytes);
output.flush();
byte[] buf = new byte[bytes.length + Long.SIZE];
int n = input.read(buf);
String res = new String(buf, 0, n, StandardCharsets.UTF_8);
assertEquals("Mismatched data", expected, res);
}
channel.close(false);
}
}
}
@Test(timeout = 45000)
public void testRemoteForwardingWithDisconnect() throws Exception {
Session session = createSession();
try {
// 1. Create a Port Forward
int forwardedPort = Utils.getFreePort();
session.setPortForwardingR(forwardedPort, TEST_LOCALHOST, echoPort);
waitForForwardingRequest(TcpipForwardHandler.REQUEST, TimeUnit.SECONDS.toMillis(5L));
// 2. Establish a connection through it
try (Socket s = new Socket(TEST_LOCALHOST, forwardedPort)) {
s.setSoTimeout((int) TimeUnit.SECONDS.toMillis(10L));
// 3. Simulate the client going away
rudelyDisconnectJschSession(session);
// 4. Make sure the NIOprocessor is not stuck
Thread.sleep(TimeUnit.SECONDS.toMillis(1L));
// from here, we need to check all the threads running and find a
// "NioProcessor-"
// that is stuck on a PortForward.dispose
ThreadGroup root = Thread.currentThread().getThreadGroup().getParent();
while (root.getParent() != null) {
root = root.getParent();
}
for (int index = 0;; index++) {
Collection<Thread> pending = findThreads(root, "NioProcessor-");
if (GenericUtils.size(pending) <= 0) {
log.info("Finished after " + index + " iterations");
break;
}
try {
Thread.sleep(TimeUnit.SECONDS.toMillis(1L));
} catch (InterruptedException e) {
// ignored
}
}
session.delPortForwardingR(forwardedPort);
}
} finally {
session.disconnect();
}
}
/**
* Close the socket inside this JSCH session. Use reflection to find it and
* just close it.
*
* @param session the Session to violate
* @throws Exception
*/
private void rudelyDisconnectJschSession(Session session) throws Exception {
Field fSocket = session.getClass().getDeclaredField("socket");
fSocket.setAccessible(true);
try (Socket socket = (Socket) fSocket.get(session)) {
assertTrue("socket is not connected", socket.isConnected());
assertFalse("socket should not be closed", socket.isClosed());
socket.close();
assertTrue("socket has not closed", socket.isClosed());
}
}
private Set<Thread> findThreads(ThreadGroup group, String name) {
int numThreads = group.activeCount();
Thread[] threads = new Thread[numThreads * 2];
numThreads = group.enumerate(threads, false);
Set<Thread> ret = new HashSet<>();
// Enumerate each thread in `group'
for (int i = 0; i < numThreads; ++i) {
Thread t = threads[i];
// Get thread
// log.debug("Thread name: " + threads[i].getName());
if (checkThreadForPortForward(t, name)) {
ret.add(t);
}
}
// didn't find the thread to check the
int numGroups = group.activeGroupCount();
ThreadGroup[] groups = new ThreadGroup[numGroups * 2];
numGroups = group.enumerate(groups, false);
for (int i = 0; i < numGroups; ++i) {
ThreadGroup g = groups[i];
Collection<Thread> c = findThreads(g, name);
if (GenericUtils.isEmpty(c)) {
continue; // debug breakpoint
}
ret.addAll(c);
}
return ret;
}
private boolean checkThreadForPortForward(Thread thread, String name) {
if (thread == null) {
return false;
}
// does it contain the name we're looking for?
if (thread.getName().contains(name)) {
// look at the stack
StackTraceElement[] stack = thread.getStackTrace();
if (stack.length == 0) {
return false;
}
// does it have 'org.apache.sshd.server.session.TcpipForwardSupport.close'?
for (StackTraceElement aStack : stack) {
String clazzName = aStack.getClassName();
String methodName = aStack.getMethodName();
// log.debug("Class: " + clazzName);
// log.debug("Method: " + methodName);
if (clazzName.equals("org.apache.sshd.server.session.TcpipForwardSupport")
&& (methodName.equals("close") || methodName.equals("sessionCreated"))) {
log.warn(thread.getName() + " stuck at " + clazzName
+ "." + methodName + ": "
+ aStack.getLineNumber());
return true;
}
}
}
return false;
}
protected Session createSession() throws JSchException {
JSch sch = new JSch();
Session session = sch.getSession(getCurrentTestName(), TEST_LOCALHOST, sshPort);
session.setUserInfo(new SimpleUserInfo(getCurrentTestName()));
session.connect();
return session;
}
protected ClientSession createNativeSession(PortForwardingEventListener listener) throws Exception {
PropertyResolverUtils.updateProperty(client, FactoryManager.WINDOW_SIZE, 2048);
PropertyResolverUtils.updateProperty(client, FactoryManager.MAX_PACKET_SIZE, 256);
client.setTcpipForwardingFilter(AcceptAllForwardingFilter.INSTANCE);
if (listener != null) {
client.addPortForwardingEventListener(listener);
}
ClientSession session = client.connect(getCurrentTestName(), TEST_LOCALHOST, sshPort).verify(7L, TimeUnit.SECONDS).getSession();
session.addPasswordIdentity(getCurrentTestName());
session.auth().verify(11L, TimeUnit.SECONDS);
return session;
}
}