/*
* Copyright 2015 Julien Viet
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.termd.core.tty;
import com.jcraft.jsch.ChannelShell;
import com.jcraft.jsch.JSch;
import com.jcraft.jsch.Session;
import com.jcraft.jsch.UserInfo;
import io.termd.core.TestBase;
import io.termd.core.ssh.TtyCommand;
import org.apache.sshd.server.SshServer;
import org.apache.sshd.server.keyprovider.SimpleGeneratorHostKeyProvider;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.reflect.Field;
import java.net.Socket;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
/**
 * Base class for TTY tests that talk to an embedded SSH server.
 * <p>
 * The server is an Apache SSHD {@link SshServer} obtained from {@link #createServer()} and bound to
 * {@link #PORT}; every connection is served by a {@link TtyCommand}. The client is a JSch session that opens an
 * interactive {@code shell} channel. The server accepts any password and the client accepts any host key, so
 * no credentials need to be configured.
 *
 * @author <a href="mailto:julien@julienviet.com">Julien Viet</a>
 */
public abstract class SshTtyTestBase extends TtyTestBase {

  /** TCP port shared by the embedded SSH server and the JSch client; kept in one place so both sides agree. */
  private static final int PORT = 5000;

  JSch jsch = new JSch();
  Session session;
  ChannelShell channel;
  InputStream in;
  OutputStream out;

  /** The embedded server started by {@link #server(Consumer)}, or {@code null} before it is started. */
  private SshServer sshd;

  /**
   * Opens a JSch session to the local server and connects an interactive {@code shell} channel, exposing
   * its streams through {@link #in} and {@link #out}.
   *
   * @param term the pty type to request, or {@code null} to leave the JSch default
   * @throws Exception if a session is already open, or if the session or channel fails to connect
   */
  @Override
  protected void assertConnect(String term) throws Exception {
    if (session != null) {
      throw failure("Already a session");
    }
    session = jsch.getSession("whatever", "localhost", PORT);
    session.setPassword("whocares");
    session.setUserInfo(new UserInfo() {
      @Override
      public String getPassphrase() {
        return null;
      }

      @Override
      public String getPassword() {
        return null;
      }

      @Override
      public boolean promptPassword(String s) {
        return false;
      }

      @Override
      public boolean promptPassphrase(String s) {
        return false;
      }

      @Override
      public boolean promptYesNo(String s) {
        // Accept all server keys.
        return true;
      }

      @Override
      public void showMessage(String s) {
      }
    });
    session.connect();
    channel = (ChannelShell) session.openChannel("shell");
    if (term != null) {
      channel.setPtyType(term);
    }
    channel.connect();
    in = channel.getInputStream();
    out = channel.getOutputStream();
  }

  /**
   * Reports whether the client side has observed end of stream.
   * <p>
   * This performs a single {@code in.read()}. If the connection is still alive, that byte is consumed.
   *
   * @return {@code true} if a connection was made and its input stream returned {@code -1}
   */
  @Override
  public boolean checkDisconnected() {
    try {
      return in != null && in.read() == -1;
    } catch (IOException e) {
      throw TestBase.failure(e);
    }
  }

  /**
   * Disconnects the client from the server.
   *
   * @param clean {@code true} to close the session through JSch; {@code false} to close the session's
   *              underlying socket directly (read reflectively from the private {@code socket} field), so the
   *              server sees an abrupt disconnection
   * @throws Exception if the reflective access fails
   */
  @Override
  protected void assertDisconnect(boolean clean) throws Exception {
    if (clean) {
      session.disconnect();
    } else {
      Field socketField = session.getClass().getDeclaredField("socket");
      socketField.setAccessible(true);
      Socket socket = (Socket) socketField.get(session);
      socket.close();
    }
  }

  /**
   * Sends a window-change request for the pty.
   * <p>
   * Pixel dimensions are derived assuming 8 pixels per character cell in each direction.
   */
  @Override
  protected void resize(int width, int height) {
    channel.setPtySize(width, height, width * 8, height * 8);
  }

  /** Encodes {@code s} with {@code charset}, writes it to the channel and flushes. */
  @Override
  protected void assertWrite(String s) throws Exception {
    out.write(s.getBytes(charset));
    out.flush();
  }

  /**
   * Reads exactly {@code len} bytes from the channel and decodes them as UTF-8.
   *
   * @param len number of bytes (not characters) to read
   * @return the decoded string
   * @throws Exception if the stream ends before {@code len} bytes have been read
   */
  @Override
  protected String assertReadString(int len) throws Exception {
    byte[] buf = new byte[len];
    int offset = 0;
    while (offset < len) {
      int count = in.read(buf, offset, len - offset);
      if (count == -1) {
        throw failure("Could not read enough");
      }
      offset += count;
    }
    // NOTE(review): decoded as UTF-8, while assertWrite encodes with `charset`; confirm the two always agree.
    return new String(buf, StandardCharsets.UTF_8);
  }

  /** Writes {@code s} followed by a carriage return, which is what a terminal sends for Enter. */
  @Override
  protected void assertWriteln(String s) throws Exception {
    assertWrite(s + "\r");
  }

  /**
   * Creates the (not yet started) SSH server. {@link #server(Consumer)} then configures its port, host key,
   * authentication and shell factory.
   */
  protected abstract SshServer createServer();

  /**
   * Creates the shell command that serves one SSH connection.
   *
   * @param onConnect callback invoked with the TTY connection
   * @return a new {@link TtyCommand} using this test's {@code charset}
   */
  protected TtyCommand createConnection(Consumer<TtyConnection> onConnect) {
    return new TtyCommand(charset, onConnect);
  }

  /**
   * Starts the embedded SSH server on {@link #PORT}.
   * <p>
   * The host key is generated into, or loaded from, {@code hostkey.ser} in the working directory, and every
   * password is accepted.
   *
   * @param onConnect callback invoked for each TTY connection opened by a client
   */
  @Override
  protected void server(Consumer<TtyConnection> onConnect) {
    if (sshd != null) {
      throw failure("Already a server");
    }
    try {
      sshd = createServer();
      sshd.setPort(PORT);
      sshd.setKeyPairProvider(new SimpleGeneratorHostKeyProvider(new File("hostkey.ser").toPath()));
      sshd.setPasswordAuthenticator((username, password, serverSession) -> true);
      sshd.setShellFactory(() -> createConnection(onConnect));
      sshd.start();
    } catch (Exception e) {
      throw failure(e);
    }
  }

  /** Resets all per-test client and server state. */
  @Before
  public void before() {
    sshd = null;
    session = null;
    channel = null;
    in = null;
    out = null;
  }

  /**
   * Checks that the status passed to {@code conn.close(int)} on the server reaches the client as the
   * channel's exit status.
   */
  @Test
  public void testExitCode() throws Exception {
    server(conn -> {
      conn.setStdinHandler(bytes -> {
        conn.close(25);
      });
    });
    assertConnect();
    assertWrite("whatever");
    // Poll until the channel closes, failing if that takes longer than 5 seconds.
    long timeout = System.currentTimeMillis() + 5000;
    while (!channel.isClosed()) {
      assertTrue(System.currentTimeMillis() < timeout);
      Thread.sleep(10);
    }
    assertEquals(25, channel.getExitStatus());
  }

  /**
   * Best-effort teardown: closes the client streams, channel and session, then the server.
   * <p>
   * Failures are deliberately ignored so that one failing close does not prevent the remaining resources
   * from being released, and does not mask the test's own outcome.
   */
  @After
  public void after() throws Exception {
    if (out != null) {
      try {
        out.close();
      } catch (Exception ignored) {
        // best-effort teardown
      }
    }
    if (in != null) {
      try {
        in.close();
      } catch (Exception ignored) {
        // best-effort teardown
      }
    }
    if (channel != null) {
      try {
        channel.disconnect();
      } catch (Exception ignored) {
        // best-effort teardown
      }
    }
    if (session != null) {
      try {
        session.disconnect();
      } catch (Exception ignored) {
        // best-effort teardown
      }
    }
    if (sshd != null && !sshd.isClosed()) {
      try {
        sshd.close();
      } catch (Exception ignored) {
        // best-effort teardown
      }
    }
  }
}