/* * Copyright (C) 2011-2012 Intel Corporation * All rights reserved. */ package com.intel.mtwilson.setup; import com.intel.mtwilson.api.*; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader; import java.net.InetAddress; import java.net.URL; import java.net.UnknownHostException; import java.security.*; import java.security.cert.CertificateException; import java.util.concurrent.TimeUnit; import net.schmizz.sshj.SSHClient; import net.schmizz.sshj.connection.ConnectionException; import net.schmizz.sshj.connection.channel.direct.Session; import net.schmizz.sshj.transport.TransportException; import net.schmizz.sshj.transport.verification.HostKeyVerifier; import org.apache.commons.io.IOUtils; import org.apache.commons.lang.ArrayUtils; import org.apache.commons.lang.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * @since 0.5.3 * @author jbuhacoff */ public class SshUtil { private static Logger log = LoggerFactory.getLogger(SshUtil.class); private static interface Command { void run(String[] args); } /** * Executes a remote command with no timeout * * @param ssh * @param command * @return * @throws IOException which could be a ConnectionException or TransportException */ public static String remote(SSHClient ssh, String command) throws IOException { return remote(ssh, command, null); } /** * The sshj client is designed to permit one command per "session". But * you can start multiple sessions per connection so this is ok. * * @param ssh the ssh client * @param command string to execute on the remote shell * @param timeoutSeconds or null to wait indefinitely for the command to complete * @return * @throws IOException which could be a ConnectionException or TransportException */ public static String remote(SSHClient ssh, String command, Timeout timeout) throws IOException { Session session = ssh.startSession(); try { Session.Command cmd = session.exec(command); // ConnectionException, TransportException if( timeout == null ) { cmd.join(); } else { cmd.join((int)timeout.toSeconds(), TimeUnit.SECONDS); // the parameters are the timeout. if you want to wait indefinitely call join() } log.debug("Command exit status: {}", cmd.getExitStatus()); String output = IOUtils.toString(cmd.getInputStream()); // IOException return output; } finally { session.close(); } } /** * This should not be a @Test method because it requires the root password * of the server and we should not store that. So invoke it by running this * class via the main() method which will prompt for the root password. * * @throws KeyStoreException * @throws IOException * @throws NoSuchAlgorithmException * @throws CertificateException * @throws UnrecoverableEntryException * @throws KeyManagementException * @throws ApiException * @throws SignatureException */ public static void executeRemoteCommand(String ipAddress, String rootPassword, SshRemoteCommand command) throws KeyStoreException, IOException, NoSuchAlgorithmException, CertificateException, UnrecoverableEntryException, KeyManagementException, ApiException, SignatureException { SSHClient ssh = new SSHClient(); //ssh.loadKnownHosts(); // this is only if we have a known_hosts file... //ssh.addHostKeyVerifier("..."); // this is only if we know the fingerprint of the remote host we're connecting to ssh.addHostKeyVerifier(new HostKeyVerifier() {@Override public boolean verify(String arg0, int arg1, PublicKey arg2) { return true; } }); // this accepts all remote public keys ssh.connect(ipAddress); try { ssh.authPassword("root", rootPassword); command.execute(ssh); } finally { ssh.disconnect(); } } public static interface SshRemoteCommand { void execute(SSHClient ssh) throws IOException; } private static class ShowTrustHosts implements SshRemoteCommand { private String[] args; public ShowTrustHosts(String[] args) { this.args = args; } @Override public void execute(SSHClient ssh) throws IOException { // find out what is the previous list of trusted hosts String previousWhitelistString = remote(ssh, "msctl show mtwilson.api.trust"); System.out.println(previousWhitelistString); } } private static class AddLocalHostTrust implements SshRemoteCommand { private String[] args; public AddLocalHostTrust(String[] args) { this.args = args; } @Override public void execute(SSHClient ssh) throws IOException { // find out what is the previous list of trusted hosts String previousWhitelistString = remote(ssh, "msctl show mtwilson.api.trust"); String[] previousWhitelist = previousWhitelistString.trim().split(","); // trim is required to remove the newline at the end of the output log.debug("Previous trusted clients network address list: {}", previousWhitelistString); // get local ip address and add it to the list InetAddress addr = InetAddress.getLocalHost(); String[] updatedWhitelist = (String[]) ArrayUtils.add(previousWhitelist, addr.getHostAddress()); String updatedWhitelistString = StringUtils.join(updatedWhitelist, ","); log.debug("Updated trusted clients network address list: {}", updatedWhitelistString); // set the new list on the server and restart the application remote(ssh, String.format("msctl edit mtwilson.api.trust \"%s\"", updatedWhitelistString)); remote(ssh, "msctl restart"); } } private static class SetTrustHosts implements SshRemoteCommand { private String[] args; public SetTrustHosts(String[] args) { this.args = args; } @Override public void execute(SSHClient ssh) throws IOException { // whitelist should be in args[2] String previousWhitelistString = args[2]; // now restore the original trusted hosts whitelist remote(ssh, String.format("msctl edit mtwilson.api.trust \"%s\"", previousWhitelistString)); log.info("Restored previous trusted clients network address list"); remote(ssh, "msctl restart"); } } /** * Syntax: * java -cp path/to/apiclient.jar com.intel.mtwilson.RemoteCommand <command> [parameters] * Configuration options: * --conf=filename * * @param args */ public static void main(String[] args) throws IOException, KeyManagementException, NoSuchAlgorithmException, GeneralSecurityException, ApiException { for(int i=0; i<args.length ;i++) { System.out.println("RemoteCommand ARG "+i+" = "+args[i]); } if( args.length == 0 ) { printUsage(); System.exit(1); } if( args[0].equals("AddTrustLocalHost") ) { if( args.length < 2 ) { System.err.println("Usage: AddTrustLocalHost ServiceURL"); System.err.println("ServiceURL is the URL to the management service"); System.exit(1); } // args[1] should be an api baseurl (from which we will extract an ip address or hostname and connect to it with ssh) URL url = new URL(args[1]); String host = url.getHost(); BufferedReader in = new BufferedReader(new InputStreamReader(System.in)); System.out.print("Remote password: "); String password = in.readLine(); executeRemoteCommand(host, password, new AddLocalHostTrust(args)); System.exit(0); } if( args[0].equals("SetTrustHosts") ) { if( args.length < 2 ) { System.err.println("Usage: SetTrustHosts ServiceURL value-for-mtwilson.api.trust"); System.err.println("ServiceURL is the URL to the management service"); System.exit(1); } URL url = new URL(args[1]); String host = url.getHost(); BufferedReader in = new BufferedReader(new InputStreamReader(System.in)); System.out.print("Remote password: "); String password = in.readLine(); executeRemoteCommand(host, password, new SetTrustHosts(args)); System.exit(0); } if( args[0].equals("ShowTrustHosts") ) { if( args.length < 2 ) { System.err.println("Usage: ShowTrustHosts ServiceURL"); System.err.println("ServiceURL is the URL to the management service"); System.exit(1); } URL url = new URL(args[1]); String host = url.getHost(); BufferedReader in = new BufferedReader(new InputStreamReader(System.in)); System.out.print("Remote password: "); String password = in.readLine(); executeRemoteCommand(host, password, new ShowTrustHosts(args)); System.exit(0); } } private static void printUsage() { System.err.println("Usage:"); System.err.println("CreateUser /path/to/directory"); System.err.println(" Will prompt for username and password."); System.err.println(" Will create username.jks in directory."); } }