/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.brooklyn.util.core.internal.ssh.sshj;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertTrue;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.List;

import org.apache.brooklyn.core.BrooklynFeatureEnablement;
import org.apache.brooklyn.util.core.internal.ssh.SshAbstractTool.SshAction;
import org.apache.brooklyn.util.core.internal.ssh.sshj.SshjTool;
import org.apache.brooklyn.util.core.internal.ssh.sshj.SshjTool.ShellAction;
import org.apache.brooklyn.util.exceptions.Exceptions;
import org.apache.brooklyn.util.time.Duration;
import org.testng.annotations.AfterMethod;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

import com.google.common.base.Function;
import com.google.common.base.Predicate;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
/**
 * Tests for async-exec with {@link SshjTool}, where the actual ssh commands are stubbed out
 * to return a controlled sequence of responses.
 *
 * <p>Each test sets {@link #sequence} to the shell actions it expects, in order. The stubbed
 * {@code acquire} checks that the n-th {@link SshjTool.ShellAction} matches the n-th
 * {@link InjectedResult#expected} predicate and answers it with the n-th canned
 * {@link InjectedResult#result}. Non-shell actions are passed to the real implementation.
 * Each test ends by asserting that every injected result was consumed.
 */
public class SshjToolAsyncStubIntegrationTest {

    /** One expected shell action, paired with the canned response to give it. */
    static class InjectedResult {
        /** Must match the shell action the tool executes at this point in the sequence. */
        Predicate<SshjTool.ShellAction> expected;
        /** Returns the exit code for the action, and may write to the action's out/err streams. */
        Function<SshjTool.ShellAction, Integer> result;

        InjectedResult(Predicate<SshjTool.ShellAction> expected, Function<SshjTool.ShellAction, Integer> result) {
            this.expected = expected;
            this.result = result;
        }
    }

    private SshjTool tool;

    /** Shell actions the current test expects, in order; replaced by each test. */
    private List<InjectedResult> sequence;

    /** Index of the next entry in {@link #sequence}, i.e. how many shell actions have been stubbed so far. */
    int counter = 0;

    /** Value returned when enabling {@code FEATURE_SSH_ASYNC_EXEC} in setUp; restored in tearDown. */
    private boolean origFeatureEnablement;

    /**
     * Enables async ssh exec and creates a {@link SshjTool} for "localhost" whose shell actions
     * are answered from {@link #sequence} instead of being run over ssh.
     */
    @BeforeMethod(alwaysRun=true)
    public void setUp() throws Exception {
        origFeatureEnablement = BrooklynFeatureEnablement.enable(BrooklynFeatureEnablement.FEATURE_SSH_ASYNC_EXEC);
        sequence = Lists.newArrayList();
        counter = 0;
        tool = new SshjTool(ImmutableMap.<String,Object>of("host", "localhost")) {
            @Override
            @SuppressWarnings("unchecked")
            protected <T, C extends SshAction<T>> T acquire(C action, int sshTries, Duration sshTriesTimeout) {
                if (!(action instanceof SshjTool.ShellAction)) {
                    // Only shell actions are stubbed; anything else uses the real implementation.
                    return super.acquire(action, sshTries, sshTriesTimeout);
                }
                SshjTool.ShellAction shellAction = (SshjTool.ShellAction) action;
                // Report an unexpected extra shell action with context, rather than letting
                // sequence.get(...) fail with a bare IndexOutOfBoundsException.
                assertTrue(counter < sequence.size(),
                        "more shell actions than expected: counter="+counter+"; expected="+sequence.size()+"; cmds="+shellAction.commands);
                InjectedResult injectedResult = sequence.get(counter);
                assertTrue(injectedResult.expected.apply(shellAction), "counter="+counter+"; cmds="+shellAction.commands);
                counter++;
                // NOTE(review): assumes a ShellAction's result type T is Integer, matching what
                // InjectedResult.result produces; confirm against SshjTool.ShellAction.
                return (T) injectedResult.result.apply(shellAction);
            }
        };
    }

    /** Disconnects the tool and always restores the original async-exec feature setting. */
    @AfterMethod(alwaysRun=true)
    public void tearDown() throws Exception {
        try {
            if (tool != null) tool.disconnect();
        } finally {
            BrooklynFeatureEnablement.setEnablement(BrooklynFeatureEnablement.FEATURE_SSH_ASYNC_EXEC, origFeatureEnablement);
        }
    }

    /**
     * Matches a non-null shell action whose commands, rendered via {@code toString()},
     * contain {@code cmd}.
     */
    private Predicate<SshjTool.ShellAction> containsCmd(final String cmd) {
        return new Predicate<SshjTool.ShellAction>() {
            @Override public boolean apply(ShellAction input) {
                return input != null && input.commands.toString().contains(cmd);
            }
        };
    }

    /**
     * Canned response that writes {@code stdout}/{@code stderr} (UTF-8, when both the text and the
     * action's stream are non-null) and then returns {@code result} as the exit code.
     */
    private Function<SshjTool.ShellAction, Integer> returning(final int result, final String stdout, final String stderr) {
        return new Function<SshjTool.ShellAction, Integer>() {
            @Override public Integer apply(ShellAction input) {
                try {
                    if (stdout != null && input.out != null) input.out.write(stdout.getBytes(StandardCharsets.UTF_8));
                    if (stderr != null && input.err != null) input.err.write(stderr.getBytes(StandardCharsets.UTF_8));
                } catch (IOException e) {
                    throw Exceptions.propagate(e);
                }
                return result;
            }
        };
    }

    /** Launch, then a single long poll that exits 0: its output and exit code are returned. */
    @Test(groups="Integration")
    public void testPolls() throws Exception {
        sequence = ImmutableList.of(
                new InjectedResult(containsCmd("nohup"), returning(0, "", "")),
                new InjectedResult(containsCmd("# Long poll"), returning(0, "mystringToStdout", "mystringToStderr")));
        runTest(0, "mystringToStdout", "mystringToStderr");
        assertEquals(counter, sequence.size());
    }

    /** Long poll exits 123, status retrieval reports "123": the script's exit code is 123. */
    @Test(groups="Integration")
    public void testPollsAndReturnsNonZeroExitCode() throws Exception {
        sequence = ImmutableList.of(
                new InjectedResult(containsCmd("nohup"), returning(0, "", "")),
                new InjectedResult(containsCmd("# Long poll"), returning(123, "mystringToStdout", "mystringToStderr")),
                new InjectedResult(containsCmd("# Retrieve status"), returning(0, "123", "")));
        runTest(123, "mystringToStdout", "mystringToStderr");
        assertEquals(counter, sequence.size());
    }

    /**
     * Long polls exiting 125 are followed by a status retrieval with empty output and another
     * poll. A long poll exiting -1 is followed directly by another long poll. The last poll exits 0.
     * Output from every poll is expected to be concatenated in order.
     */
    @Test(groups="Integration")
    public void testPollsRepeatedly() throws Exception {
        sequence = ImmutableList.of(
                new InjectedResult(containsCmd("nohup"), returning(0, "", "")),
                new InjectedResult(containsCmd("# Long poll"), returning(125, "mystringToStdout", "mystringToStderr")),
                new InjectedResult(containsCmd("# Retrieve status"), returning(0, "", "")),
                new InjectedResult(containsCmd("# Long poll"), returning(125, "mystringToStdout2", "mystringToStderr2")),
                new InjectedResult(containsCmd("# Retrieve status"), returning(0, "", "")),
                new InjectedResult(containsCmd("# Long poll"), returning(-1, "mystringToStdout3", "mystringToStderr3")),
                new InjectedResult(containsCmd("# Long poll"), returning(125, "mystringToStdout4", "mystringToStderr4")),
                new InjectedResult(containsCmd("# Retrieve status"), returning(0, "", "")),
                new InjectedResult(containsCmd("# Long poll"), returning(0, "mystringToStdout5", "mystringToStderr5")));
        runTest(0,
                "mystringToStdout"+"mystringToStdout2"+"mystringToStdout3"+"mystringToStdout4"+"mystringToStdout5",
                "mystringToStderr"+"mystringToStderr2"+"mystringToStderr3"+"mystringToStderr4"+"mystringToStderr5");
        assertEquals(counter, sequence.size());
    }

    /**
     * Runs a one-command script with async exec, no extra output and a 1ms polling timeout.
     * Asserts the exit code and the trimmed, UTF-8-decoded stdout/stderr.
     */
    protected void runTest(int expectedExit, String expectedStdout, String expectedStderr) throws Exception {
        List<String> cmds = ImmutableList.of("abc");
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        ByteArrayOutputStream err = new ByteArrayOutputStream();
        int exitCode = tool.execScript(
                ImmutableMap.of(
                        "out", out,
                        "err", err,
                        SshjTool.PROP_EXEC_ASYNC.getName(), true,
                        SshjTool.PROP_NO_EXTRA_OUTPUT.getName(), true,
                        SshjTool.PROP_EXEC_ASYNC_POLLING_TIMEOUT.getName(), Duration.ONE_MILLISECOND),
                cmds,
                ImmutableMap.<String,String>of());
        String outStr = new String(out.toByteArray(), StandardCharsets.UTF_8);
        String errStr = new String(err.toByteArray(), StandardCharsets.UTF_8);
        assertEquals(exitCode, expectedExit);
        assertEquals(outStr.trim(), expectedStdout);
        assertEquals(errStr.trim(), expectedStderr);
    }
}