/*
* Copyright 2016 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.wildfly.adduser;
import com.fasterxml.jackson.core.type.TypeReference;
import org.jboss.aesh.cl.CommandDefinition;
import org.jboss.aesh.cl.Option;
import org.jboss.aesh.cl.parser.ParserGenerator;
import org.jboss.aesh.console.command.Command;
import org.jboss.aesh.console.command.CommandNotFoundException;
import org.jboss.aesh.console.command.CommandResult;
import org.jboss.aesh.console.command.container.CommandContainer;
import org.jboss.aesh.console.command.invocation.CommandInvocation;
import org.jboss.aesh.console.command.registry.AeshCommandRegistryBuilder;
import org.jboss.aesh.console.command.registry.CommandRegistry;
import org.keycloak.common.util.Base64;
import org.keycloak.credential.CredentialModel;
import org.keycloak.credential.hash.PasswordHashProvider;
import org.keycloak.credential.hash.PasswordHashProviderFactory;
import org.keycloak.credential.hash.Pbkdf2PasswordHashProviderFactory;
import org.keycloak.models.PasswordPolicy;
import org.keycloak.representations.idm.CredentialRepresentation;
import org.keycloak.representations.idm.RealmRepresentation;
import org.keycloak.representations.idm.UserRepresentation;
import org.keycloak.util.JsonSerialization;
import java.io.Console;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.ServiceLoader;
/**
 * Command-line tool that appends a user to {@code keycloak-add-user.json} in the standalone or
 * domain configuration directory. The password is hashed with the default password hash provider
 * before it is written, and the tool tells the operator to restart the server to load the user.
 *
 * @author <a href="mailto:sthorger@redhat.com">Stian Thorgersen</a>
 */
public class AddUser {

    private static final String COMMAND_NAME = "add-user";
    /** Iteration count used when {@code --iterations} is not given or is not positive. */
    private static final int DEFAULT_HASH_ITERATIONS = 100000;
    private static final String DEFAULT_HASH_ALGORITHM = PasswordPolicy.HASH_ALGORITHM_DEFAULT;
    /** File, inside the configuration directory, that the user is written to. */
    private static final String ADD_USER_FILE_NAME = "keycloak-add-user.json";
    /** Realm used when {@code --realm} is not given. */
    private static final String DEFAULT_REALM = "master";

    /**
     * Entry point: parses the command line, then either prints help or adds the user.
     * Any failure is printed to stderr and the JVM exits with status 1.
     *
     * @param args command-line arguments, see {@link AddUserCommand}
     */
    public static void main(String[] args) throws Exception {
        AddUserCommand command = new AddUserCommand();
        try {
            ParserGenerator.parseAndPopulate(command, COMMAND_NAME, args);
        } catch (Exception e) {
            System.err.println(e.getMessage());
            System.exit(1);
        }

        if (command.isHelp()) {
            printHelp(command);
            return;
        }

        try {
            String password = command.getPassword();
            checkRequired(command, "user");
            if (isEmpty(command, "password")) {
                // No password option given: ask for it on the console instead.
                password = promptForInput();
            }
            File addUserFile = getAddUserFile(command);
            createUser(addUserFile, command.getRealm(), command.getUser(), password, command.getRoles(), command.getIterations());
        } catch (Exception e) {
            System.err.println(e.getMessage());
            System.exit(1);
        }
    }

    /**
     * Resolves the add-user file for the selected mode (domain or standalone).
     *
     * @param command parsed options
     * @return {@code <configDir>/keycloak-add-user.json} (the file itself may not exist yet)
     * @throws Exception if no configuration directory can be determined, or it is not a directory
     */
    private static File getAddUserFile(AddUserCommand command) throws Exception {
        File configDir = command.isDomain()
                ? resolveConfigDir(command.getDc(), "jboss.domain.config.user.dir", "domain")
                : resolveConfigDir(command.getSc(), "jboss.server.config.user.dir", "standalone");
        if (!configDir.isDirectory()) {
            throw new Exception("'" + configDir + "' does not exist or is not a directory");
        }
        return new File(configDir, ADD_USER_FILE_NAME);
    }

    /**
     * Picks the configuration directory, in order of precedence:
     * <ol>
     *   <li>the explicit command-line value,</li>
     *   <li>the given system property,</li>
     *   <li>{@code $JBOSS_HOME/<mode>/configuration}.</li>
     * </ol>
     *
     * @param explicitDir    value of {@code --dc}/{@code --sc}, may be null
     * @param systemProperty name of the system property to consult next
     * @param mode           {@code "domain"} or {@code "standalone"}; also used in the error message
     * @throws Exception if none of the sources yields a directory
     */
    private static File resolveConfigDir(String explicitDir, String systemProperty, String mode) throws Exception {
        if (explicitDir != null) {
            return new File(explicitDir);
        }
        String fromProperty = System.getProperty(systemProperty);
        if (fromProperty != null) {
            return new File(fromProperty);
        }
        String jbossHome = System.getenv("JBOSS_HOME");
        if (jbossHome != null) {
            return new File(jbossHome + File.separator + mode + File.separator + "configuration");
        }
        throw new Exception("Could not find " + mode + " configuration directory");
    }

    /**
     * Adds a user with a hashed password and roles to {@code addUserFile}, creating the file and/or
     * realm entry if needed.
     *
     * @param addUserFile file to read (if present) and rewrite
     * @param realmName   target realm; {@code null} means {@code master}
     * @param userName    username; must not already exist in the realm entry
     * @param password    clear-text password to hash
     * @param rolesString comma-separated roles, {@code client/role} for client roles; {@code null}
     *                    selects the default admin role for the realm
     * @param iterations  hash iterations; values {@code <= 0} use {@link #DEFAULT_HASH_ITERATIONS}
     * @throws Exception if the user already exists, a role is malformed, hashing is unavailable,
     *                   or the file cannot be read/written
     */
    private static void createUser(File addUserFile, String realmName, String userName, String password, String rolesString, int iterations) throws Exception {
        List<RealmRepresentation> realms = readRealms(addUserFile);

        if (realmName == null) {
            realmName = DEFAULT_REALM;
        }
        RealmRepresentation realm = findOrAddRealm(realms, realmName);

        for (UserRepresentation u : realm.getUsers()) {
            // userName is the receiver: entries read from the file may lack a username.
            if (userName.equals(u.getUsername())) {
                throw new Exception("User with username '" + userName + "' already added to '" + addUserFile + "'");
            }
        }

        UserRepresentation user = new UserRepresentation();
        user.setEnabled(true);
        user.setUsername(userName);
        user.setCredentials(new LinkedList<CredentialRepresentation>());
        user.getCredentials().add(hashPassword(password, iterations));

        String[] roles;
        if (rolesString != null) {
            roles = rolesString.split(",");
        } else if (realmName.equals(DEFAULT_REALM)) {
            roles = new String[] { "admin" };
        } else {
            roles = new String[] { "realm-management/realm-admin" };
        }
        addRoles(user, roles);

        realm.getUsers().add(user);
        writeRealms(addUserFile, realms);
        System.out.println("Added '" + userName + "' to '" + addUserFile + "', restart server to load user");
    }

    /** Reads the realm list from {@code addUserFile}, or returns a new empty list if it is not a file. */
    private static List<RealmRepresentation> readRealms(File addUserFile) throws Exception {
        if (!addUserFile.isFile()) {
            return new LinkedList<>();
        }
        try (FileInputStream in = new FileInputStream(addUserFile)) {
            List<RealmRepresentation> realms = JsonSerialization.readValue(in, new TypeReference<List<RealmRepresentation>>() {});
            // Guard against a file whose content is a JSON null.
            return realms != null ? realms : new LinkedList<RealmRepresentation>();
        }
    }

    /**
     * Returns the entry for {@code realmName}, appending a new one to {@code realms} if absent.
     * The returned realm always has a non-null, mutable-by-this-code users list.
     */
    private static RealmRepresentation findOrAddRealm(List<RealmRepresentation> realms, String realmName) {
        for (RealmRepresentation candidate : realms) {
            if (realmName.equals(candidate.getRealm())) {
                if (candidate.getUsers() == null) {
                    candidate.setUsers(new LinkedList<UserRepresentation>());
                }
                return candidate;
            }
        }
        RealmRepresentation realm = new RealmRepresentation();
        realm.setRealm(realmName);
        realm.setUsers(new LinkedList<UserRepresentation>());
        realms.add(realm);
        return realm;
    }

    /** Hashes {@code password} with the default hash provider and wraps it as a credential representation. */
    private static CredentialRepresentation hashPassword(String password, int iterations) throws Exception {
        PasswordHashProviderFactory hashProviderFactory = getHashProviderFactory(DEFAULT_HASH_ALGORITHM);
        if (hashProviderFactory == null) {
            throw new Exception("Password hash provider '" + DEFAULT_HASH_ALGORITHM + "' not found");
        }
        PasswordHashProvider hashProvider = hashProviderFactory.create(null);

        CredentialModel credentialModel = new CredentialModel();
        hashProvider.encode(password, iterations > 0 ? iterations : DEFAULT_HASH_ITERATIONS, credentialModel);

        CredentialRepresentation credentials = new CredentialRepresentation();
        credentials.setType(credentialModel.getType());
        credentials.setAlgorithm(credentialModel.getAlgorithm());
        credentials.setHashIterations(credentialModel.getHashIterations());
        credentials.setSalt(Base64.encodeBytes(credentialModel.getSalt()));
        credentials.setHashedSaltedValue(credentialModel.getValue());
        return credentials;
    }

    /**
     * Assigns roles to {@code user}.
     * <ul>
     *   <li>Each entry is trimmed, and blank entries are skipped.</li>
     *   <li>An entry containing {@code '/'} is split at its first slash into {@code client/role}
     *       and becomes a client role.</li>
     *   <li>Any other entry becomes a realm role.</li>
     * </ul>
     *
     * @throws Exception if a client-role entry has an empty client or role part
     */
    private static void addRoles(UserRepresentation user, String[] roles) throws Exception {
        for (String entry : roles) {
            String role = entry.trim();
            if (role.isEmpty()) {
                continue;
            }
            int slash = role.indexOf('/');
            if (slash == -1) {
                if (user.getRealmRoles() == null) {
                    user.setRealmRoles(new LinkedList<String>());
                }
                user.getRealmRoles().add(role);
                continue;
            }

            String client = role.substring(0, slash);
            String clientRole = role.substring(slash + 1);
            if (client.isEmpty() || clientRole.isEmpty()) {
                throw new Exception("Invalid client role '" + role + "', expected <client>/<role>");
            }
            if (user.getClientRoles() == null) {
                user.setClientRoles(new HashMap<String, List<String>>());
            }
            if (user.getClientRoles().get(client) == null) {
                user.getClientRoles().put(client, new LinkedList<String>());
            }
            user.getClientRoles().get(client).add(clientRole);
        }
    }

    /** Overwrites {@code addUserFile} with the pretty-printed realm list, always closing the stream. */
    private static void writeRealms(File addUserFile, List<RealmRepresentation> realms) throws Exception {
        try (FileOutputStream out = new FileOutputStream(addUserFile)) {
            JsonSerialization.writeValuePrettyToStream(out, realms);
        }
    }

    /**
     * Finds a {@link PasswordHashProviderFactory} via {@link ServiceLoader} by id.
     *
     * @return the matching factory, or {@code null} if none is registered
     */
    private static PasswordHashProviderFactory getHashProviderFactory(String providerId) {
        ServiceLoader<PasswordHashProviderFactory> providerFactories = ServiceLoader.load(PasswordHashProviderFactory.class);
        for (PasswordHashProviderFactory f : providerFactories) {
            if (providerId.equals(f.getId())) {
                return f;
            }
        }
        return null;
    }

    /**
     * Throws if the option backing {@code field} is unset. The message names the option's short
     * form (when it has one) and its long form.
     */
    private static void checkRequired(Command command, String field) throws Exception {
        if (isEmpty(command, field)) {
            Option option = command.getClass().getDeclaredField(field).getAnnotation(Option.class);
            String optionName;
            if (option != null && option.shortName() != '\u0000') {
                optionName = "-" + option.shortName() + ", --" + field;
            } else {
                optionName = "--" + field;
            }
            throw new Exception("Option: " + optionName + " is required");
        }
    }

    /**
     * Returns {@code true} when the getter {@code get<Field>()} on {@code command} returns
     * {@code null}. Only null is checked; an empty string counts as set.
     */
    private static boolean isEmpty(Command command, String field) throws Exception {
        Method getter = command.getClass().getMethod("get" + Character.toUpperCase(field.charAt(0)) + field.substring(1));
        return getter.invoke(command) == null;
    }

    /**
     * Reads a password from the system console with echo disabled. If end of input is reached
     * (ctrl-d / ctrl-z), the JVM exits with status 0.
     *
     * @throws Exception if no console is attached (e.g. input is redirected)
     */
    private static String promptForInput() throws Exception {
        Console console = System.console();
        if (console == null) {
            throw new Exception("Couldn't get Console instance");
        }
        console.printf("Press ctrl-d (Unix) or ctrl-z (Windows) to exit\n");
        char[] passwordArray = console.readPassword("Password: ");
        if (passwordArray == null) {
            System.exit(0);
        }
        return new String(passwordArray);
    }

    /** Prints the aesh-generated usage text for {@code command} to stdout. */
    private static void printHelp(Command command) throws CommandNotFoundException {
        CommandRegistry registry = new AeshCommandRegistryBuilder().command(command).create();
        CommandContainer commandContainer = registry.getCommand(command.getClass().getAnnotation(CommandDefinition.class).name(), null);
        String help = commandContainer.printHelp(null);
        System.out.println(help);
    }

    /**
     * Option holder populated by {@link ParserGenerator#parseAndPopulate}. {@link #execute} is a
     * no-op; {@link AddUser#main} reads the parsed values through the getters.
     */
    @CommandDefinition(name= COMMAND_NAME, description = "[options...]")
    public static class AddUserCommand implements Command {

        @Option(shortName = 'r', hasValue = true, description = "Name of realm to add user to")
        private String realm;

        @Option(shortName = 'u', hasValue = true, description = "Name of the user")
        private String user;

        @Option(shortName = 'p', hasValue = true, description = "Password of the user")
        private String password;

        @Option(hasValue = true, description = "Roles to add to the user")
        private String roles;

        @Option(hasValue = true, description = "Hash iterations")
        private int iterations;

        @Option(hasValue = false, description = "Enable domain mode")
        private boolean domain;

        @Option(hasValue = true, description = "Define the location of the server config directory")
        private String sc;

        @Option(hasValue = true, description = "Define the location of the domain config directory")
        private String dc;

        @Option(shortName = 'h', hasValue = false, description = "Display this help and exit")
        private boolean help;

        @Override
        public CommandResult execute(CommandInvocation commandInvocation) throws InterruptedException {
            return CommandResult.SUCCESS;
        }

        public String getRealm() {
            return realm;
        }

        public String getUser() {
            return user;
        }

        public String getPassword() {
            return password;
        }

        public String getRoles() {
            return roles;
        }

        public int getIterations() {
            return iterations;
        }

        public boolean isDomain() {
            return domain;
        }

        public String getSc() {
            return sc;
        }

        public String getDc() {
            return dc;
        }

        public boolean isHelp() {
            return help;
        }
    }
}