package org.ovirt.engine.ssoreg.core;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.attribute.PosixFilePermissions;
import java.security.GeneralSecurityException;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.FileHandler;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.logging.SimpleFormatter;

import org.apache.commons.codec.binary.Base64;
import org.ovirt.engine.core.uutils.cli.parser.ArgumentsParser;
import org.ovirt.engine.core.uutils.crypto.EnvelopePBE;
import org.ovirt.engine.ssoreg.db.DBUtils;
import org.slf4j.LoggerFactory;

/**
 * Command-line entry point of the SSO client registration tool.
 *
 * <p>Interactively asks for a client id, a client CA certificate file and a callback prefix URL,
 * generates a random client secret, stores the encoded secret in the database through
 * {@link DBUtils} (replacing any previous registration of the same client id) and writes the
 * plaintext values to a {@code 99_sso_client_*.conf} file in the configured temp directory.
 */
public class SsoRegistrationToolExecutor {

    private static final String PROGRAM_NAME =
            System.getProperty("org.ovirt.engine.ssoreg.core.programName");
    private static final String PACKAGE_NAME =
            System.getProperty("org.ovirt.engine.ssoreg.core.packageName");
    private static final String PACKAGE_VERSION =
            System.getProperty("org.ovirt.engine.ssoreg.core.packageVersion");
    private static final String PACKAGE_DISPLAY_NAME =
            System.getProperty("org.ovirt.engine.ssoreg.core.packageDisplayName");
    private static final String ENGINE_ETC =
            System.getProperty("org.ovirt.engine.ssoreg.core.engineEtc");

    private static final org.slf4j.Logger log = LoggerFactory.getLogger(SsoRegistrationToolExecutor.class);

    // Held strongly so java.util.logging does not garbage-collect the configured level.
    private static final Logger OVIRT_LOGGER = Logger.getLogger("org.ovirt");

    private static final SecureRandom SECURE_RANDOM = createSecureRandom();

    // A single shared reader: creating a new buffered reader per prompt could swallow input
    // that was already buffered for the next prompt.
    private static final BufferedReader STDIN =
            new BufferedReader(new InputStreamReader(System.in, Charset.defaultCharset()));

    public static void main(String... args) {
        int exitStatus = 1;
        List<String> cmdArgs = new ArrayList<>(Arrays.asList(args));
        try {
            final Map<String, String> substitutions = new HashMap<>();
            substitutions.put("@ENGINE_ETC@", ENGINE_ETC);
            substitutions.put("@PROGRAM_NAME@", PROGRAM_NAME);

            setupLogger();
            ArgumentsParser parser;
            try (InputStream stream =
                    SsoRegistrationToolExecutor.class.getResourceAsStream("arguments.properties")) {
                parser = new ArgumentsParser(stream, "core");
                parser.getSubstitutions().putAll(substitutions);
            }
            parser.parse(cmdArgs);
            Map<String, Object> argMap = parser.getParsedArgs();
            setupLogger(argMap);

            log.debug("Version: {}-{} ({})", PACKAGE_NAME, PACKAGE_VERSION, PACKAGE_DISPLAY_NAME);
            if (!parser.getErrors().isEmpty()) {
                for (Throwable t : parser.getErrors()) {
                    log.error(t.getMessage());
                    log.debug(t.getMessage(), t);
                }
                throw new ExitException("Parsing error", 1);
            }

            DBUtils dbUtils = new DBUtils();
            log.info("=========================================================================");
            log.info("================== oVirt Sso Client Registration Tool ===================");
            log.info("=========================================================================");

            String clientId = promptForClientId();
            String certificateFile = promptForCertificateFile();
            String callbackPrefix = promptForCallbackPrefix();

            String clientSecret = generateClientSecret();
            String encodedClientSecret = encode(argMap, clientSecret);
            // Re-registration replaces any existing entry for the same client id.
            dbUtils.unregisterClient(clientId);
            dbUtils.registerClient(clientId, encodedClientSecret, certificateFile, callbackPrefix);

            String tmpFile = createTmpSsoClientConfFile(clientId, clientSecret, certificateFile, callbackPrefix);
            System.out.println("Client registration completed successfully");
            System.out.format("Client secret has been written to file %s%n", tmpFile);

            log.info("========================================================================");
            log.info("========================= Execution Completed ==========================");
            log.info("========================================================================");
            exitStatus = 0;
        } catch (ExitException e) {
            log.debug(e.getMessage(), e);
            exitStatus = e.getExitCode();
        } catch (Throwable t) {
            // Top-level boundary: report a short message; the full trace is available at debug level.
            log.error(t.getMessage() != null ? t.getMessage() : t.getClass().getName());
            log.debug("Exception:", t);
        }
        log.debug("Exiting with status '{}'", exitStatus);
        System.exit(exitStatus);
    }

    /** Prompts until a non-blank client id is entered. */
    private static String promptForClientId() throws ExitException {
        String clientId = getUserInput("Client Id: ");
        while (clientId.trim().isEmpty()) {
            System.out.println("Client Id must not be empty.");
            clientId = getUserInput("Enter Client Id: ");
        }
        return clientId;
    }

    /** Prompts until the entered path names an existing regular file. */
    private static String promptForCertificateFile() throws ExitException {
        String certificateFile = getUserInput("Client CA Certificate File Location: ");
        while (!new File(certificateFile).isFile()) {
            System.out.format(
                    "%s is not a valid certificate, please enter path to an existing certificate.%n",
                    certificateFile);
            certificateFile = getUserInput("Enter Client CA Certificate File Location: ");
        }
        return certificateFile;
    }

    /** Prompts until an absolute http/https URL is entered. */
    private static String promptForCallbackPrefix() throws ExitException {
        String callbackPrefix = getUserInput("Callback Prefix URL: ");
        while (!isValidCallbackPrefix(callbackPrefix)) {
            System.out.format("%s is not a valid URL, please enter a proper URL.%n", callbackPrefix);
            callbackPrefix = getUserInput("Enter Callback Prefix URL: ");
        }
        return callbackPrefix;
    }

    /**
     * Checks that {@code value} parses as a URI with a lowercase {@code http} or {@code https}
     * scheme and a non-empty authority (host[:port]).
     */
    private static boolean isValidCallbackPrefix(String value) {
        try {
            URI uri = new URI(value);
            String scheme = uri.getScheme();
            return ("http".equals(scheme) || "https".equals(scheme)) && uri.getAuthority() != null;
        } catch (URISyntaxException e) {
            return false;
        }
    }

    /**
     * Prints {@code question} and reads one line from standard input.
     *
     * @return the line without its terminator; an empty string if reading failed with an
     *         {@link IOException} (which is logged)
     * @throws ExitException if standard input has reached end-of-stream, so callers that
     *         re-prompt in a loop cannot spin forever
     */
    private static String getUserInput(String question) throws ExitException {
        System.out.print(question);
        System.out.flush();
        String line;
        try {
            line = STDIN.readLine();
        } catch (IOException exception) {
            log.error(
                    "Error while reading line from standard input. Will "
                            + "consider it the end of the line and continue.",
                    exception);
            return "";
        }
        if (line == null) {
            log.error("Unexpected end of standard input");
            throw new ExitException("Unexpected end of standard input", 1);
        }
        return line;
    }

    /** Returns 32 random bytes encoded as unchunked, URL-safe Base64. */
    private static String generateClientSecret() {
        byte[] secret = new byte[32];
        SECURE_RANDOM.nextBytes(secret);
        return new Base64(0, new byte[0], true).encodeToString(secret);
    }

    /**
     * Encodes {@code clientSecret} with {@link EnvelopePBE} using the parsed
     * {@code encoding-algorithm}, {@code key-size} and {@code iterations} arguments.
     * NOTE(review): the {@code null} argument is passed through unchanged; its meaning is
     * defined by EnvelopePBE.encode — confirm there.
     */
    private static String encode(Map<String, Object> args, String clientSecret)
            throws IOException, GeneralSecurityException {
        return EnvelopePBE.encode(
                (String) args.get("encoding-algorithm"),
                Integer.parseInt((String) args.get("key-size")),
                Integer.parseInt((String) args.get("iterations")),
                null,
                clientSecret);
    }

    /** Sets the initial "org.ovirt" log level from OVIRT_LOGGING_LEVEL, defaulting to INFO. */
    private static void setupLogger() {
        String logLevel = System.getenv("OVIRT_LOGGING_LEVEL");
        OVIRT_LOGGER.setLevel(logLevel != null ? Level.parse(logLevel) : Level.INFO);
    }

    /**
     * Applies the parsed logging arguments: attaches an appending file handler to the root
     * logger when {@code log-file} is given (relative to the configured log dir) and sets the
     * "org.ovirt" level from {@code log-level}.
     */
    private static void setupLogger(Map<String, Object> args) throws IOException {
        Logger rootLogger = Logger.getLogger("");
        String logfile = (String) args.get("log-file");
        if (logfile != null) {
            FileHandler fh = new FileHandler(
                    new File(SsoLocalConfig.getInstance().getLogDir(), logfile).getAbsolutePath(),
                    true);
            fh.setFormatter(new SimpleFormatter());
            rootLogger.addHandler(fh);
        }
        OVIRT_LOGGER.setLevel((Level) args.get("log-level"));
    }

    /**
     * Writes the client settings, including the plaintext client secret, to a new
     * {@code 99_sso_client_<millis>.conf} file in the configured temp directory.
     *
     * @return the absolute path of the written file
     * @throws IOException if the file already exists or cannot be created or written
     */
    private static String createTmpSsoClientConfFile(String clientId,
            String clientSecret,
            String certificateFile,
            String callbackPrefix) throws IOException {
        File tmpDir = SsoLocalConfig.getInstance().getTmpDir();
        if (tmpDir.mkdirs()) {
            log.debug("Created ovirt temp directory: {}", tmpDir.getAbsolutePath());
        }
        Path tmpFile = createOwnerOnlyFile(
                new File(tmpDir, String.format("99_sso_client_%s.conf", System.currentTimeMillis())).toPath());
        // Unlike PrintWriter, BufferedWriter propagates write failures instead of swallowing them.
        try (BufferedWriter writer = Files.newBufferedWriter(tmpFile, StandardCharsets.UTF_8)) {
            writeLine(writer, String.format("SSO_CLIENT_ID=%s", clientId));
            writeLine(writer, String.format("SSO_CLIENT_SECRET=%s", clientSecret));
            writeLine(writer, String.format("SSO_CLIENT_CERTIFICATE_FILE=%s", certificateFile));
            writeLine(writer, String.format("SSO_CLIENT_CALLBACK_PREFIX=%s", callbackPrefix));
        }
        return tmpFile.toAbsolutePath().toString();
    }

    /**
     * Creates a new, empty file that only its owner may read or write when the file system
     * supports POSIX permissions. The file holds a secret, so it must not inherit a permissive
     * default mode. Fails if the file already exists.
     */
    private static Path createOwnerOnlyFile(Path path) throws IOException {
        if (path.getFileSystem().supportedFileAttributeViews().contains("posix")) {
            return Files.createFile(
                    path,
                    PosixFilePermissions.asFileAttribute(PosixFilePermissions.fromString("rw-------")));
        }
        return Files.createFile(path);
    }

    /** Writes {@code line} followed by the platform line separator. */
    private static void writeLine(BufferedWriter writer, String line) throws IOException {
        writer.write(line);
        writer.newLine();
    }

    /** Creates the SHA1PRNG generator used for client secrets. */
    private static SecureRandom createSecureRandom() {
        try {
            return SecureRandom.getInstance("SHA1PRNG");
        } catch (NoSuchAlgorithmException e) {
            throw new RuntimeException(e);
        }
    }
}