package org.ovirt.engine.core.aaa;

import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toMap;

import java.io.ByteArrayOutputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;

import javax.net.ssl.TrustManagerFactory;
import javax.servlet.http.HttpServletRequest;

import org.apache.commons.codec.binary.Base64;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang.StringUtils;
import org.apache.http.HttpStatus;
import org.apache.http.client.entity.UrlEncodedFormEntity;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpUriRequest;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.message.BasicNameValuePair;
import org.codehaus.jackson.map.DeserializationConfig;
import org.codehaus.jackson.map.ObjectMapper;
import org.ovirt.engine.api.extensions.ExtMap;
import org.ovirt.engine.api.extensions.aaa.Authz;
import org.ovirt.engine.core.aaa.filters.FiltersHelper;
import org.ovirt.engine.core.utils.EngineLocalConfig;
import org.ovirt.engine.core.utils.serialization.json.JsonExtMapMixIn;
import org.ovirt.engine.core.uutils.net.HttpClientBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static helpers that send OAuth-style requests to the SSO service configured by
 * {@code ENGINE_SSO_SERVICE_URL} and return the decoded JSON response as a map.
 *
 * <p>None of the public request methods propagate exceptions: any failure is converted into a map
 * holding the keys {@code "error"} (a description) and {@code "error_code"} ({@code "server_error"}).
 */
public class SsoOAuthServiceUtils {

    private static final Logger log = LoggerFactory.getLogger(SsoOAuthServiceUtils.class);

    private static final String authzSearchScope = "ovirt-ext=token-info:authz-search";
    private static final String publicAuthzSearchScope = "ovirt-ext=token-info:public-authz-search";

    // Reference to the HTTP client used to send the requests to the SSO server. It is created lazily by
    // execute() and is volatile because it is read outside of the lock there.
    private static volatile CloseableHttpClient client;

    private static final ObjectMapper mapper;

    static {
        // Remember to close the client when going down:
        Runtime.getRuntime().addShutdownHook(
                new Thread(() -> IOUtils.closeQuietly(client))
        );

        // NOTE(review): default typing lets the JSON payload name the classes to instantiate, so this
        // mapper should only ever be fed responses from the trusted SSO service.
        mapper = new ObjectMapper()
                .configure(DeserializationConfig.Feature.FAIL_ON_UNKNOWN_PROPERTIES, false)
                .enableDefaultTyping(ObjectMapper.DefaultTyping.OBJECT_AND_NON_CONCRETE);
        mapper.getDeserializationConfig().addMixInAnnotations(ExtMap.class, JsonExtMapMixIn.class);
        mapper.getSerializationConfig().addMixInAnnotations(ExtMap.class, JsonExtMapMixIn.class);
    }

    /**
     * Requests a token from {@code /oauth/token} with the {@code password} grant, using the user name
     * and password found in the {@code Authorization: Basic} header of the given request.
     *
     * @param req request carrying the user's Basic credentials
     * @param scope value sent as the {@code scope} form field
     * @return the decoded response, or an error map if the request failed
     */
    public static Map<String, Object> authenticate(HttpServletRequest req, String scope) {
        try {
            String[] credentials = getUserCredentialsFromHeader(req);
            List<BasicNameValuePair> form = new ArrayList<>(4);
            form.add(new BasicNameValuePair("grant_type", "password"));
            form.add(new BasicNameValuePair("username", credentials[0]));
            form.add(new BasicNameValuePair("password", credentials[1]));
            form.add(new BasicNameValuePair("scope", scope));
            return postForm("/oauth/token", form);
        } catch (Exception ex) {
            return buildServerError("authenticate", ex);
        }
    }

    /**
     * Requests a token for {@code username} with an empty password, additionally sending the serialized
     * {@code authRecord} as the {@code ovirt_auth_record} form field.
     *
     * @return the decoded response, or an error map if the request failed
     */
    public static Map<String, Object> loginOnBehalf(String username, String scope, ExtMap authRecord) {
        return loginWithPasswordImpl(username, "", scope, authRecord);
    }

    /**
     * Requests a token from {@code /oauth/token} with the {@code password} grant for the given
     * credentials.
     *
     * @return the decoded response, or an error map if the request failed
     */
    public static Map<String, Object> loginWithPassword(String username, String password, String scope) {
        return loginWithPasswordImpl(username, password, scope, null);
    }

    /**
     * Shared implementation of the password-grant token request.
     *
     * @param authRecord when not {@code null}, serialized to JSON and sent as {@code ovirt_auth_record}
     */
    private static Map<String, Object> loginWithPasswordImpl(
            String username,
            String password,
            String scope,
            ExtMap authRecord) {
        try {
            List<BasicNameValuePair> form = new ArrayList<>(5);
            form.add(new BasicNameValuePair("grant_type", "password"));
            form.add(new BasicNameValuePair("username", username));
            form.add(new BasicNameValuePair("password", password));
            form.add(new BasicNameValuePair("scope", scope));
            if (authRecord != null) {
                form.add(new BasicNameValuePair("ovirt_auth_record", serialize(authRecord)));
            }
            return postForm("/oauth/token", form);
        } catch (Exception ex) {
            return buildServerError("login", ex);
        }
    }

    /** Serializes {@code obj} to JSON with the shared mapper. */
    private static String serialize(Object obj) throws IOException {
        return mapper.writeValueAsString(obj);
    }

    /** Deserializes {@code json} into an instance of {@code type} with the shared mapper. */
    private static <T> T deserialize(String json, Class<T> type) throws IOException {
        return mapper.readValue(json, type);
    }

    /**
     * Revokes {@code token} using the scope {@code ovirt-ext=revoke:revoke-all}.
     *
     * @return the decoded response, or an error map if the request failed
     */
    public static Map<String, Object> revoke(String token) {
        return revoke(token, "ovirt-ext=revoke:revoke-all");
    }

    /**
     * Posts {@code token} and {@code scope} to {@code /oauth/revoke}.
     *
     * @return the decoded response, or an error map if the request failed
     */
    public static Map<String, Object> revoke(String token, String scope) {
        try {
            List<BasicNameValuePair> form = new ArrayList<>(2);
            form.add(new BasicNameValuePair("token", token));
            form.add(new BasicNameValuePair("scope", scope));
            return postForm("/oauth/revoke", form);
        } catch (Exception ex) {
            return buildServerError("revoke", ex);
        }
    }

    /**
     * Requests a token from {@code /oauth/token} for the given grant type, code and redirect URI.
     *
     * @return the decoded response, or an error map if the request failed
     */
    public static Map<String, Object> getToken(String grantType, String code, String scope, String redirectUri) {
        try {
            List<BasicNameValuePair> form = new ArrayList<>(4);
            form.add(new BasicNameValuePair("grant_type", grantType));
            form.add(new BasicNameValuePair("code", code));
            form.add(new BasicNameValuePair("redirect_uri", redirectUri));
            form.add(new BasicNameValuePair("scope", scope));
            return postForm("/oauth/token", form);
        } catch (Exception ex) {
            return buildServerError("getToken", ex);
        }
    }

    /**
     * Same as {@link #getTokenInfo(String, String)} without a scope.
     */
    public static Map<String, Object> getTokenInfo(String token) {
        return getTokenInfo(token, null);
    }

    /**
     * Posts {@code token} (and {@code scope}, when not empty) to {@code /oauth/token-info}.
     *
     * <p>If the response contains an {@code "ovirt"} map with a {@code "group_ids"} entry, that entry is
     * replaced by the result of {@link #processGroupMembershipsFromJson(Collection)}.
     *
     * @return the decoded response, or an error map if the request failed
     */
    @SuppressWarnings("unchecked") // the JSON payload is untyped; the casts follow the structure read here
    public static Map<String, Object> getTokenInfo(String token, String scope) {
        try {
            List<BasicNameValuePair> form = new ArrayList<>(2);
            form.add(new BasicNameValuePair("token", token));
            if (StringUtils.isNotEmpty(scope)) {
                form.add(new BasicNameValuePair("scope", scope));
            }
            Map<String, Object> jsonData = postForm("/oauth/token-info", form);
            Map<String, Object> ovirtData = (Map<String, Object>) jsonData.get("ovirt");
            if (ovirtData != null) {
                Collection<ExtMap> groupIds = (Collection<ExtMap>) ovirtData.get("group_ids");
                if (groupIds != null) {
                    ovirtData.put("group_ids", SsoOAuthServiceUtils.processGroupMembershipsFromJson(groupIds));
                }
            }
            return jsonData;
        } catch (Exception ex) {
            return buildServerError("getTokenInfo", ex);
        }
    }

    /**
     * Sends a GET request to {@code /status}.
     *
     * @return the decoded response; an error map with the message {@code "oVirt Engine is initializing."}
     *         when the service answers with HTTP 404; or a generic error map on any other failure
     */
    public static Map<String, Object> isSsoDeployed() {
        try {
            HttpGet request = createGet("/status");
            return getResponse(request);
        } catch (FileNotFoundException ex) {
            return buildMapWithError("server_error", "oVirt Engine is initializing.");
        } catch (Exception ex) {
            return buildServerError("isSsoDeployed", ex);
        }
    }

    /** Runs the {@code fetch-principal-record} query with the authz search scope. */
    public static Map<String, Object> fetchPrincipalRecord(
            String token,
            String domain,
            String principal,
            boolean groupsResolving,
            boolean groupsResolvingRecursive) {
        Map<String, Object> params = new HashMap<>();
        params.put("domain", domain);
        params.put("principal", principal);
        params.put("groups_resolving", groupsResolving);
        params.put("groups_resolving_recursive", groupsResolvingRecursive);
        return search(token, params, "fetch-principal-record", authzSearchScope);
    }

    /** Runs the {@code find-principals-by-ids} query with the authz search scope. */
    public static Map<String, Object> findPrincipalsByIds(
            String token,
            String domain,
            String namespace,
            Collection<String> ids,
            boolean groupsResolving,
            boolean groupsResolvingRecursive) {
        Map<String, Object> params = new HashMap<>();
        params.put("domain", domain);
        params.put("namespace", namespace);
        params.put("ids", ids);
        params.put("groups_resolving", groupsResolving);
        params.put("groups_resolving_recursive", groupsResolvingRecursive);
        return search(token, params, "find-principals-by-ids", authzSearchScope);
    }

    /**
     * Runs the {@code find-login-on-behalf-principal-by-id} query with the public authz search scope;
     * no token is sent.
     */
    public static Map<String, Object> findLoginOnBehalfPrincipalById(
            String domain,
            String namespace,
            Collection<String> ids,
            boolean groupsResolving,
            boolean groupsResolvingRecursive) {
        Map<String, Object> params = new HashMap<>();
        params.put("domain", domain);
        params.put("namespace", namespace);
        params.put("ids", ids);
        params.put("groups_resolving", groupsResolving);
        params.put("groups_resolving_recursive", groupsResolvingRecursive);
        return search(null, params, "find-login-on-behalf-principal-by-id", publicAuthzSearchScope);
    }

    /**
     * Runs the {@code find-principal-by-id} query with the authz search scope. A {@code null} or empty
     * namespace is sent as an empty string.
     */
    public static Map<String, Object> findDirectoryUserById(
            String token,
            String domain,
            String namespace,
            String id,
            boolean groupsResolving,
            boolean groupsResolvingRecursive) {
        Map<String, Object> params = new HashMap<>();
        params.put("domain", domain);
        params.put("namespace", StringUtils.defaultIfEmpty(namespace, ""));
        params.put("id", id);
        params.put("groups_resolving", groupsResolving);
        params.put("groups_resolving_recursive", groupsResolvingRecursive);
        return search(token, params, "find-principal-by-id", authzSearchScope);
    }

    /**
     * Runs the {@code find-directory-group-by-id} query with the authz search scope. A {@code null} or
     * empty namespace is sent as an empty string.
     */
    public static Map<String, Object> findDirectoryGroupById(
            String token,
            String domain,
            String namespace,
            String id,
            boolean groupsResolving,
            boolean groupsResolvingRecursive) {
        Map<String, Object> params = new HashMap<>();
        params.put("domain", domain);
        params.put("namespace", StringUtils.defaultIfEmpty(namespace, ""));
        params.put("id", id);
        params.put("groups_resolving", groupsResolving);
        params.put("groups_resolving_recursive", groupsResolvingRecursive);
        return search(token, params, "find-directory-group-by-id", authzSearchScope);
    }

    /** Runs the {@code domain-list} query with the authz search scope. */
    public static Map<String, Object> getDomainList(String token) {
        return search(token, null, "domain-list", authzSearchScope);
    }

    /** Runs the {@code available-namespaces} query with the authz search scope. */
    public static Map<String, Object> getAvailableNamespaces(String token) {
        return search(token, null, "available-namespaces", authzSearchScope);
    }

    /**
     * Runs the {@code session-statuses} query with the public authz search scope, sending
     * {@code entries} as the {@code tokens} parameter; no token form field is sent.
     */
    public static Map<String, Object> getSessionStatues(Set<String> entries) {
        return search(null, Collections.singletonMap("tokens", entries), "session-statuses", publicAuthzSearchScope);
    }

    /** Runs the {@code profile-list} query with the public authz search scope; no token is sent. */
    public static Map<String, Object> getProfileList() {
        return search(null, null, "profile-list", publicAuthzSearchScope);
    }

    /** Runs the {@code users} query with the authz search scope and the caller-supplied parameters. */
    public static Map<String, Object> searchUsers(String token, Map<String, Object> params) {
        return search(token, params, "users", authzSearchScope);
    }

    /** Runs the {@code groups} query with the authz search scope and the caller-supplied parameters. */
    public static Map<String, Object> searchGroups(String token, Map<String, Object> params) {
        return search(token, params, "groups", authzSearchScope);
    }

    /**
     * Posts a query to {@code /oauth/token-info}.
     *
     * @param token sent as the {@code token} field only when not empty
     * @param params serialized to JSON and sent as the {@code params} field only when not {@code null}
     * @param queryType sent as the {@code query_type} field
     * @param scope sent as the {@code scope} field
     * @return the decoded response, or an error map if the request failed
     */
    private static Map<String, Object> search(
            String token,
            Map<String, Object> params,
            String queryType,
            String scope) {
        try {
            List<BasicNameValuePair> form = new ArrayList<>(4);
            form.add(new BasicNameValuePair("query_type", queryType));
            form.add(new BasicNameValuePair("scope", scope));
            if (StringUtils.isNotEmpty(token)) {
                form.add(new BasicNameValuePair("token", token));
            }
            if (params != null) {
                form.add(new BasicNameValuePair("params", serialize(params)));
            }
            return postForm("/oauth/token-info", form);
        } catch (Exception ex) {
            return buildServerError(queryType, ex);
        }
    }

    /**
     * Sends {@code form} as a UTF-8 URL-encoded POST to {@code path} on the SSO service, authenticated
     * with the engine's client id and secret, and returns the decoded JSON response.
     */
    private static Map<String, Object> postForm(String path, List<BasicNameValuePair> form) throws Exception {
        HttpPost request = createPost(path);
        setClientIdSecretBasicAuthHeader(request);
        request.setEntity(new UrlEncodedFormEntity(form, StandardCharsets.UTF_8));
        return getResponse(request);
    }

    /**
     * Extracts the user name and password from the request's {@code Authorization} header.
     *
     * @return a two-element array {@code {userName, password}}; both are empty strings when the header
     *         is missing or does not start with {@code Basic}
     */
    private static String[] getUserCredentialsFromHeader(HttpServletRequest request) {
        String header = request.getHeader("Authorization");
        String userName = "";
        String passwd = "";
        if (StringUtils.isNotEmpty(header) && header.startsWith("Basic")) {
            // Split on the first ':' only, so that passwords may themselves contain ':'.
            String[] creds = new String(
                    Base64.decodeBase64(header.substring("Basic".length())),
                    StandardCharsets.UTF_8
            ).split(":", 2);
            userName = creds.length >= 1 ? creds[0] : "";
            passwd = creds.length >= 2 ? creds[1] : "";
        }
        return new String[] {userName, passwd};
    }

    /**
     * Sets the Basic authorization header of {@code request} from the {@code ENGINE_SSO_CLIENT_ID} and
     * {@code ENGINE_SSO_CLIENT_SECRET} configuration properties.
     */
    private static void setClientIdSecretBasicAuthHeader(HttpUriRequest request) {
        EngineLocalConfig config = EngineLocalConfig.getInstance();
        String idAndSecret = String.format("%s:%s",
                config.getProperty("ENGINE_SSO_CLIENT_ID"),
                config.getProperty("ENGINE_SSO_CLIENT_SECRET"));
        // Encode with an explicit charset so the header does not depend on the platform default.
        byte[] encodedBytes = Base64.encodeBase64(idAndSecret.getBytes(StandardCharsets.UTF_8));
        request.setHeader(FiltersHelper.Constants.HEADER_AUTHORIZATION,
                String.format("Basic %s", new String(encodedBytes, StandardCharsets.UTF_8)));
    }

    /**
     * Builds a {@code server_error} map for a failed request. The exception is logged at debug level so
     * the stack trace is not lost, and the exception's string form is used as the description when it
     * carries no message, so that the {@code "error"} entry is never {@code null}.
     */
    private static Map<String, Object> buildServerError(String operation, Exception ex) {
        log.debug("SSO request '{}' failed", operation, ex);
        String message = ex.getMessage();
        if (StringUtils.isEmpty(message)) {
            message = ex.toString();
        }
        return buildMapWithError("server_error", message);
    }

    /** Builds a mutable map holding {@code error} under {@code "error"} and {@code error_code} under {@code "error_code"}. */
    private static Map<String, Object> buildMapWithError(String error_code, String error) {
        Map<String, Object> response = new HashMap<>();
        response.put("error", error);
        response.put("error_code", error_code);
        return response;
    }

    /** Creates a POST request for {@code path} relative to {@code ENGINE_SSO_SERVICE_URL}. */
    private static HttpPost createPost(String path) throws Exception {
        EngineLocalConfig config = EngineLocalConfig.getInstance();
        String base = config.getProperty("ENGINE_SSO_SERVICE_URL");
        HttpPost request = new HttpPost();
        request.setURI(new URI(base + path));
        request.setHeader("Accept", "application/json");
        request.setHeader("Content-Language", "en-US");
        return request;
    }

    /** Creates a GET request for {@code path} relative to {@code ENGINE_SSO_SERVICE_URL}. */
    private static HttpGet createGet(String path) throws Exception {
        EngineLocalConfig config = EngineLocalConfig.getInstance();
        String base = config.getProperty("ENGINE_SSO_SERVICE_URL");
        HttpGet request = new HttpGet();
        request.setURI(new URI(base + path));
        request.setHeader("Accept", "application/json");
        return request;
    }

    /**
     * Executes {@code request} and decodes the UTF-8 JSON body into a map.
     *
     * @throws FileNotFoundException if the service answers with HTTP 404
     * @throws IOException if the response has no body or cannot be read or parsed
     */
    @SuppressWarnings("unchecked") // Jackson is asked for a raw HashMap; JSON object keys are strings
    private static Map<String, Object> getResponse(HttpUriRequest request) throws Exception {
        try (CloseableHttpResponse response = execute(request)) {
            int status = response.getStatusLine().getStatusCode();
            if (status == HttpStatus.SC_NOT_FOUND) {
                throw new FileNotFoundException("SSO service endpoint not found (HTTP 404)");
            }
            // A response is not required to carry an entity; fail with a clear message instead of an NPE.
            if (response.getEntity() == null) {
                throw new IOException("SSO service returned an empty response (HTTP " + status + ")");
            }
            try (ByteArrayOutputStream os = new ByteArrayOutputStream()) {
                try (InputStream input = response.getEntity().getContent()) {
                    FiltersHelper.copy(input, os);
                }
                // The context class loader is switched to this class's loader for the duration of the
                // deserialization and restored afterwards.
                // NOTE(review): presumably needed so that default typing resolves class names against
                // this module - confirm.
                ClassLoader loader = Thread.currentThread().getContextClassLoader();
                Thread.currentThread().setContextClassLoader(SsoOAuthServiceUtils.class.getClassLoader());
                try {
                    return deserialize(
                            new String(os.toByteArray(), StandardCharsets.UTF_8),
                            HashMap.class);
                } finally {
                    Thread.currentThread().setContextClassLoader(loader);
                }
            }
        }
    }

    /**
     * Executes {@code request} with the shared HTTP client, creating the client on first use.
     */
    private static CloseableHttpResponse execute(HttpUriRequest request) throws IOException, GeneralSecurityException {
        // Make sure the client is created (double-checked locking on the volatile field):
        if (client == null) {
            synchronized (SsoOAuthServiceUtils.class) {
                if (client == null) {
                    client = createClient();
                }
            }
        }

        // Execute the request:
        return client.execute(request);
    }

    /** Builds the HTTP client from the {@code ENGINE_SSO_SERVICE_*} and trust store configuration properties. */
    private static CloseableHttpClient createClient() throws IOException, GeneralSecurityException {
        EngineLocalConfig config = EngineLocalConfig.getInstance();
        return new HttpClientBuilder()
                .setSslProtocol(config.getProperty("ENGINE_SSO_SERVICE_SSL_PROTOCOL"))
                .setPoolSize(config.getInteger("ENGINE_SSO_SERVICE_CLIENT_POOL_SIZE"))
                // NOTE(review): 0 presumably means "no read timeout" - confirm against HttpClientBuilder.
                .setReadTimeout(0)
                .setRetryCount(config.getInteger("ENGINE_SSO_SERVICE_CONNECTION_RETRY_COUNT"))
                .setTrustManagerAlgorithm(TrustManagerFactory.getDefaultAlgorithm())
                .setTrustStore(config.getProperty("ENGINE_HTTPS_PKI_TRUST_STORE"))
                .setTrustStorePassword(config.getProperty("ENGINE_HTTPS_PKI_TRUST_STORE_PASSWORD"))
                .setTrustStoreType(config.getProperty("ENGINE_HTTPS_PKI_TRUST_STORE_TYPE"))
                .setValidateAfterInactivity(config.getInteger("ENGINE_SSO_SERVICE_CONNECTION_VALIDATE_AFTER_INACTIVITY"))
                .setVerifyChain(config.getBoolean("ENGINE_SSO_SERVICE_SSL_VERIFY_CHAIN"))
                .setVerifyHost(config.getBoolean("ENGINE_SSO_SERVICE_SSL_VERIFY_HOST"))
                .build();
    }

    /**
     * Currently Jackson does not provide a way to serialize graphs with cyclic references between nodes,
     * which may happen when cyclic dependencies exist among the nested groups a user is a member of. So
     * in order to deserialize from JSON successfully we have to revert the changes done in
     * {@code org.ovirt.engine.core.sso.utils.SsoUtils.prepareGroupMembershipsForJson()}.
     *
     * <p>Each record's {@code Authz.GroupRecord.GROUPS} entry, read here as a collection of group IDs, is
     * replaced in place by the list of records with those IDs (an ID with no matching record yields a
     * {@code null} element). The records are modified in place.
     *
     * @param jsonGroupMemberships flattened group records as received from JSON
     * @return the records that contain the {@code Authz.PrincipalRecord.PRINCIPAL} key, with that key
     *         removed
     * @throws IllegalStateException if two records share the same {@code Authz.GroupRecord.ID}
     */
    public static List<ExtMap> processGroupMembershipsFromJson(Collection<ExtMap> jsonGroupMemberships) {
        Map<String, ExtMap> groupsCache = jsonGroupMemberships.stream()
                .collect(toMap(item -> item.get(Authz.GroupRecord.ID), Function.identity()));

        for (ExtMap groupRecord : jsonGroupMemberships) {
            groupRecord.put(
                    Authz.GroupRecord.GROUPS,
                    groupRecord.<Collection<String>>get(Authz.GroupRecord.GROUPS, Collections.emptyList()).stream()
                            .map(memberOfId -> groupsCache.get(memberOfId))
                            .collect(toList()));
        }

        // Select first, then mutate in a separate pass, rather than relying on Stream.peek for a
        // required side effect.
        List<ExtMap> principalGroups = groupsCache.values().stream()
                .filter(group -> group.containsKey(Authz.PrincipalRecord.PRINCIPAL))
                .collect(toList());
        principalGroups.forEach(group -> group.remove(Authz.PrincipalRecord.PRINCIPAL));
        return principalGroups;
    }
}