/*
* Copyright 2016 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.services.clientregistration.policy.impl;
import java.net.InetAddress;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.UnknownHostException;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.jboss.logging.Logger;
import org.keycloak.component.ComponentModel;
import org.keycloak.models.ClientModel;
import org.keycloak.models.KeycloakSession;
import org.keycloak.protocol.oidc.utils.PairwiseSubMapperUtils;
import org.keycloak.representations.idm.ClientRepresentation;
import org.keycloak.services.ServicesLogger;
import org.keycloak.services.clientregistration.ClientRegistrationContext;
import org.keycloak.services.clientregistration.ClientRegistrationProvider;
import org.keycloak.services.clientregistration.policy.ClientRegistrationPolicy;
import org.keycloak.services.clientregistration.policy.ClientRegistrationPolicyException;
/**
 * Client registration policy that only accepts requests and client URLs associated with trusted hosts.
 * <p>
 * Trust is configured through the {@code TRUSTED_HOSTS} list of the component model. Entries starting with
 * {@code *.} (for example {@code *.example.com}) are treated as trusted domains, matching the domain itself and
 * any of its subdomains. All other entries are treated as trusted hostnames / IP addresses.
 * <p>
 * Two checks are performed; each is enabled unless its config value is explicitly set to something other than
 * {@code true}:
 * <ul>
 * <li>the remote address sending a register, update, view or delete request must belong to a trusted host
 * (compared by resolved IP address) or to a trusted domain (via reverse DNS lookup of the remote address);</li>
 * <li>on register and update, the client's root URL, base URL, admin URL and resolved redirect URIs must have a
 * host that is a trusted host or lies within a trusted domain.</li>
 * </ul>
 *
 * @author <a href="mailto:mposolda@redhat.com">Marek Posolda</a>
 */
public class TrustedHostClientRegistrationPolicy implements ClientRegistrationPolicy {

    private static final Logger logger = Logger.getLogger(TrustedHostClientRegistrationPolicy.class);

    private final KeycloakSession session;
    private final ComponentModel componentModel;

    public TrustedHostClientRegistrationPolicy(KeycloakSession session, ComponentModel componentModel) {
        this.session = session;
        this.componentModel = componentModel;
    }

    @Override
    public void beforeRegister(ClientRegistrationContext context) throws ClientRegistrationPolicyException {
        verifyHost();
        verifyClientUrls(context);
    }

    @Override
    public void afterRegister(ClientRegistrationContext context, ClientModel clientModel) {
        // Nothing to verify once the client has been created.
    }

    @Override
    public void beforeUpdate(ClientRegistrationContext context, ClientModel clientModel) throws ClientRegistrationPolicyException {
        verifyHost();
        verifyClientUrls(context);
    }

    @Override
    public void afterUpdate(ClientRegistrationContext context, ClientModel clientModel) {
        // Nothing to verify once the client has been updated.
    }

    @Override
    public void beforeView(ClientRegistrationProvider provider, ClientModel clientModel) throws ClientRegistrationPolicyException {
        verifyHost();
    }

    @Override
    public void beforeDelete(ClientRegistrationProvider provider, ClientModel clientModel) throws ClientRegistrationPolicyException {
        verifyHost();
    }

    // IMPL

    /**
     * Verifies that the remote address of the current request belongs to a trusted host or trusted domain.
     * Does nothing when the "host sending registration request must match" option is disabled.
     *
     * @throws ClientRegistrationPolicyException if the remote address matches neither a trusted host nor a
     *                                           trusted domain
     */
    protected void verifyHost() throws ClientRegistrationPolicyException {
        boolean hostMustMatch = isHostMustMatch();
        if (!hostMustMatch) {
            return;
        }

        String hostAddress = session.getContext().getConnection().getRemoteAddr();

        logger.debugf("Verifying remote host : %s", hostAddress);

        List<String> trustedHosts = getTrustedHosts();
        List<String> trustedDomains = getTrustedDomains();

        // Verify trustedHosts by their IP addresses
        String verifiedHost = verifyHostInTrustedHosts(hostAddress, trustedHosts);
        if (verifiedHost != null) {
            return;
        }

        // Verify domains if hostAddress hostname belongs to the domain. This assumes proper DNS setup
        verifiedHost = verifyHostInTrustedDomains(hostAddress, trustedDomains);
        if (verifiedHost != null) {
            return;
        }

        ServicesLogger.LOGGER.failedToVerifyRemoteHost(hostAddress);
        throw new ClientRegistrationPolicyException("Host not trusted.");
    }

    /**
     * Returns the configured trusted hosts, i.e. every non-empty {@code TRUSTED_HOSTS} entry that is not a
     * {@code *.}-prefixed domain wildcard.
     */
    protected List<String> getTrustedHosts() {
        return getTrustedHostsConfig().stream()
                .filter(hostname -> hostname != null && !hostname.isEmpty() && !hostname.startsWith("*."))
                .collect(Collectors.toList());
    }

    /**
     * Returns the configured trusted domains: every {@code *.domain} entry of {@code TRUSTED_HOSTS} with the
     * {@code *.} prefix stripped. Entries that would yield an empty domain are ignored.
     */
    protected List<String> getTrustedDomains() {
        List<String> domains = new LinkedList<>();
        for (String hostname : getTrustedHostsConfig()) {
            if (hostname != null && hostname.startsWith("*.")) {
                String domain = hostname.substring(2);
                // A bare "*." would otherwise produce the empty domain, which every hostname "ends with".
                if (!domain.isEmpty()) {
                    domains.add(domain);
                }
            }
        }
        return domains;
    }

    /**
     * Checks whether {@code hostAddress} equals any IP address that one of the trusted hosts resolves to.
     *
     * @param hostAddress  textual IP address of the remote peer
     * @param trustedHosts configured trusted hostnames / IP addresses
     * @return the matching configured host name, or {@code null} if none matched
     */
    protected String verifyHostInTrustedHosts(String hostAddress, List<String> trustedHosts) {
        for (String confHostName : trustedHosts) {
            try {
                // A trusted host may have several A/AAAA records; the request may come from any of them.
                for (InetAddress resolved : InetAddress.getAllByName(confHostName)) {
                    String hostIPAddress = resolved.getHostAddress();
                    logger.tracef("Trying host '%s' of address '%s'", confHostName, hostIPAddress);
                    if (hostIPAddress.equals(hostAddress)) {
                        logger.debugf("Successfully verified host : %s", confHostName);
                        return confHostName;
                    }
                }
            } catch (UnknownHostException uhe) {
                logger.debugf(uhe, "Unknown host from realm configuration: %s", confHostName);
            }
        }

        return null;
    }

    /**
     * Resolves the hostname of {@code hostAddress} via reverse DNS lookup and checks whether it lies within one
     * of the trusted domains.
     *
     * @param hostAddress    textual IP address of the remote peer
     * @param trustedDomains trusted domains (without the {@code *.} prefix)
     * @return the resolved hostname if it belongs to a trusted domain, otherwise {@code null}
     */
    protected String verifyHostInTrustedDomains(String hostAddress, List<String> trustedDomains) {
        // InetAddress.getByName(null or "") returns the loopback address; a missing remote address must never be
        // verified as if it were a local host.
        if (trustedDomains.isEmpty() || hostAddress == null || hostAddress.isEmpty()) {
            return null;
        }

        try {
            String hostname = InetAddress.getByName(hostAddress).getHostName();

            logger.debugf("Trying verify request from address '%s' of host '%s' by domains", hostAddress, hostname);

            for (String confDomain : trustedDomains) {
                if (isHostInDomain(hostname, confDomain)) {
                    logger.debugf("Successfully verified host '%s' by trusted domain '%s'", hostname, confDomain);
                    return hostname;
                }
            }
        } catch (UnknownHostException uhe) {
            logger.debugf(uhe, "Request of address '%s' came from unknown host. Skip verification by domains", hostAddress);
        }

        return null;
    }

    /**
     * Verifies that the root URL, base URL, admin URL and resolved redirect URIs of the registered client all
     * point to trusted hosts or trusted domains. Does nothing when the "client URIs must match" option is
     * disabled.
     *
     * @throws ClientRegistrationPolicyException if any of the URLs is malformed or untrusted
     */
    protected void verifyClientUrls(ClientRegistrationContext context) throws ClientRegistrationPolicyException {
        boolean redirectUriMustMatch = isClientUrisMustMatch();
        if (!redirectUriMustMatch) {
            return;
        }

        List<String> trustedHosts = getTrustedHosts();
        List<String> trustedDomains = getTrustedDomains();

        ClientRepresentation client = context.getClient();
        String rootUrl = client.getRootUrl();
        String baseUrl = client.getBaseUrl();
        String adminUrl = client.getAdminUrl();
        List<String> redirectUris = client.getRedirectUris();

        // Relative base/admin URLs are interpreted against the root URL before checking their host.
        baseUrl = relativeToAbsoluteURI(rootUrl, baseUrl);
        adminUrl = relativeToAbsoluteURI(rootUrl, adminUrl);
        Set<String> resolvedRedirects = PairwiseSubMapperUtils.resolveValidRedirectUris(rootUrl, redirectUris);

        if (rootUrl != null) {
            checkURLTrusted(rootUrl, trustedHosts, trustedDomains);
        }

        if (baseUrl != null) {
            checkURLTrusted(baseUrl, trustedHosts, trustedDomains);
        }
        if (adminUrl != null) {
            checkURLTrusted(adminUrl, trustedHosts, trustedDomains);
        }
        for (String redirect : resolvedRedirects) {
            checkURLTrusted(redirect, trustedHosts, trustedDomains);
        }
    }

    /**
     * Checks that the host of {@code url} is one of the trusted hosts or lies within one of the trusted domains.
     * Hostnames are compared ignoring ASCII case.
     *
     * @throws ClientRegistrationPolicyException if the URL is malformed or its host is not trusted
     */
    protected void checkURLTrusted(String url, List<String> trustedHosts, List<String> trustedDomains) throws ClientRegistrationPolicyException {
        String host;
        try {
            host = new URL(url).getHost();
        } catch (MalformedURLException mfe) {
            logger.debugf(mfe, "URL '%s' is malformed", url);
            throw new ClientRegistrationPolicyException("URL is malformed");
        }

        for (String trustedHost : trustedHosts) {
            if (host.equalsIgnoreCase(trustedHost)) {
                return;
            }
        }

        for (String trustedDomain : trustedDomains) {
            if (isHostInDomain(host, trustedDomain)) {
                return;
            }
        }

        ServicesLogger.LOGGER.urlDoesntMatch(url);
        throw new ClientRegistrationPolicyException("URL doesn't match any trusted host or trusted domain");
    }

    /**
     * Returns {@code true} if {@code host} is {@code domain} itself or one of its subdomains, ignoring ASCII case.
     * The match must fall on a DNS label boundary, so {@code evilexample.com} is NOT within {@code example.com}.
     */
    private static boolean isHostInDomain(String host, String domain) {
        if (host == null || domain == null || domain.isEmpty()) {
            return false;
        }
        if (host.equalsIgnoreCase(domain)) {
            return true;
        }
        String dottedDomain = "." + domain;
        int offset = host.length() - dottedDomain.length();
        // offset > 0 requires at least one character of subdomain label before the dot.
        return offset > 0 && host.regionMatches(true, offset, dottedDomain, 0, dottedDomain.length());
    }

    /**
     * Reads the raw {@code TRUSTED_HOSTS} list from the component configuration, never returning {@code null}.
     */
    private List<String> getTrustedHostsConfig() {
        List<String> configured = componentModel.getConfig().getList(TrustedHostClientRegistrationPolicyFactory.TRUSTED_HOSTS);
        return configured == null ? new LinkedList<>() : configured;
    }

    /**
     * Resolves {@code relative} against {@code rootUrl} when it starts with {@code /}.
     *
     * @return {@code null} if {@code relative} is {@code null}, or if it is relative but there is no root URL;
     *         {@code relative} unchanged if it does not start with {@code /}; otherwise {@code rootUrl + relative}
     */
    private static String relativeToAbsoluteURI(String rootUrl, String relative) {
        if (relative == null) {
            return null;
        }

        if (!relative.startsWith("/")) {
            return relative;
        } else if (rootUrl == null || rootUrl.isEmpty()) {
            return null;
        }

        return rootUrl + relative;
    }

    boolean isHostMustMatch() {
        return parseBoolean(TrustedHostClientRegistrationPolicyFactory.HOST_SENDING_REGISTRATION_REQUEST_MUST_MATCH);
    }

    boolean isClientUrisMustMatch() {
        return parseBoolean(TrustedHostClientRegistrationPolicyFactory.CLIENT_URIS_MUST_MATCH);
    }

    // True by default: an absent config value enables the check; any value other than "true" disables it.
    private boolean parseBoolean(String propertyKey) {
        String val = componentModel.getConfig().getFirst(propertyKey);
        return val == null || Boolean.parseBoolean(val);
    }
}