/*
* Copyright © 2014 Cask Data, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package co.cask.cdap.security.server;
import co.cask.cdap.common.conf.CConfiguration;
import co.cask.cdap.common.conf.Constants;
import co.cask.cdap.common.conf.SConfiguration;
import co.cask.cdap.common.guice.ConfigModule;
import co.cask.cdap.common.guice.DiscoveryRuntimeModule;
import co.cask.cdap.common.guice.IOModule;
import co.cask.cdap.common.io.Codec;
import co.cask.cdap.common.utils.Networks;
import co.cask.cdap.security.auth.AccessToken;
import co.cask.cdap.security.auth.AccessTokenCodec;
import co.cask.cdap.security.guice.InMemorySecurityModule;
import com.google.common.base.Throwables;
import com.google.common.collect.Sets;
import com.google.common.io.ByteStreams;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import com.google.inject.AbstractModule;
import com.google.inject.Guice;
import com.google.inject.Injector;
import com.google.inject.Module;
import com.google.inject.name.Names;
import com.google.inject.util.Modules;
import com.unboundid.ldap.listener.InMemoryDirectoryServer;
import com.unboundid.ldap.listener.InMemoryDirectoryServerConfig;
import com.unboundid.ldap.listener.InMemoryListenerConfig;
import com.unboundid.ldap.sdk.Entry;
import org.apache.commons.codec.binary.Base64;
import org.apache.hadoop.hbase.HBaseConfiguration;
import org.apache.http.HttpResponse;
import org.apache.http.client.HttpClient;
import org.apache.http.client.methods.HttpGet;
import org.apache.twill.discovery.Discoverable;
import org.apache.twill.discovery.DiscoveryServiceClient;
import org.jboss.netty.handler.codec.http.HttpHeaders;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.SocketAddress;
import java.net.URL;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import javax.security.auth.login.Configuration;
import static org.junit.Assert.assertEquals;
import static org.mockito.Matchers.contains;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.verify;
/**
 * Base test class for ExternalAuthenticationServer.
 *
 * <p>The class starts an {@link ExternalAuthenticationServer} backed by an in-memory LDAP directory
 * and exercises its HTTP endpoints. Derived classes supply the protocol, the HTTP client and the
 * static configuration objects, and are expected to invoke {@link #setup()} once before the tests run.
 */
public abstract class ExternalAuthenticationServerTestBase {
  private static final Logger LOG = LoggerFactory.getLogger(ExternalAuthenticationServerTestBase.class);
  // Mocked audit logger, bound into the server so tests can verify what gets audit-logged.
  private static final Logger TEST_AUDIT_LOGGER = mock(Logger.class);
  // HTTP Basic credentials "admin:realtime" (Base64 encoded), matching the user entry added to the LDAP server.
  private static final String VALID_AUTH_HEADER = "Basic YWRtaW46cmVhbHRpbWU=";

  private static ExternalAuthenticationServer server;
  private static int port;
  private static Codec<AccessToken> tokenCodec;
  private static DiscoveryServiceClient discoveryServiceClient;
  private static InMemoryDirectoryServer ldapServer;

  protected static int ldapPort = Networks.getRandomPort();
  // Needs to be set by derived classes.
  protected static CConfiguration configuration;
  protected static SConfiguration sConfiguration;
  protected static InMemoryListenerConfig ldapListenerConfig;

  /**
   * @return the URI scheme used to reach the authentication server
   */
  protected abstract String getProtocol();

  /**
   * @return an HTTP client able to talk to the authentication server over {@link #getProtocol()}
   * @throws Exception if the client cannot be created
   */
  protected abstract HttpClient getHTTPClient() throws Exception;

  /**
   * Creates the authentication server and the in-memory LDAP server and starts both.
   * Requires {@link #configuration} to have been assigned beforehand.
   *
   * @throws Exception if any of the services fails to start
   */
  protected static void setup() throws Exception {
    Assert.assertNotNull("CConfiguration needs to be set by derived classes", configuration);
    // Replace the audit log handler with one writing to the mocked logger.
    Module securityModule = Modules.override(new InMemorySecurityModule()).with(
      new AbstractModule() {
        @Override
        protected void configure() {
          bind(AuditLogHandler.class)
            .annotatedWith(Names.named(ExternalAuthenticationServer.NAMED_EXTERNAL_AUTH))
            .toInstance(new AuditLogHandler(TEST_AUDIT_LOGGER));
        }
      }
    );
    Injector injector = Guice.createInjector(new IOModule(), securityModule,
                                             new ConfigModule(getConfiguration(configuration),
                                                              HBaseConfiguration.create(), sConfiguration),
                                             new DiscoveryRuntimeModule().getInMemoryModules());
    server = injector.getInstance(ExternalAuthenticationServer.class);
    tokenCodec = injector.getInstance(AccessTokenCodec.class);
    discoveryServiceClient = injector.getInstance(DiscoveryServiceClient.class);

    // The port in use depends on whether SSL is enabled in the configuration.
    if (configuration.getBoolean(Constants.Security.SSL_ENABLED)) {
      port = configuration.getInt(Constants.Security.AuthenticationServer.SSL_PORT);
    } else {
      port = configuration.getInt(Constants.Security.AUTH_SERVER_BIND_PORT);
    }

    try {
      startLDAPServer();
    } catch (Exception e) {
      throw Throwables.propagate(e);
    }

    server.startAndWait();
    LOG.info("Auth server running on port {}", port);
    ldapServer.startListening();
    // NOTE(review): fixed pause, presumably to let both servers settle before tests run - confirm it is needed.
    TimeUnit.SECONDS.sleep(3);
  }

  /**
   * @return the port read from the configuration for the authentication server during {@link #setup()}
   */
  public int getAuthServerPort() {
    return port;
  }

  /**
   * Populates the given configuration with random server ports, the LDAP authentication handler
   * settings and the Kerberos keytab/principal used by the tests.
   *
   * @param cConf the configuration to update; it is modified in place
   * @return the same {@code cConf} instance
   */
  private static CConfiguration getConfiguration(CConfiguration cConf) {
    String configBase = Constants.Security.AUTH_HANDLER_CONFIG_BASE;

    // Use random port for testing
    cConf.setInt(Constants.Security.AUTH_SERVER_BIND_PORT, Networks.getRandomPort());
    cConf.setInt(Constants.Security.AuthenticationServer.SSL_PORT, Networks.getRandomPort());

    cConf.set(Constants.Security.AUTH_HANDLER_CLASS, LDAPAuthenticationHandler.class.getName());
    cConf.set(Constants.Security.LOGIN_MODULE_CLASS_NAME, LDAPLoginModule.class.getName());
    cConf.set(configBase.concat("debug"), "true");
    cConf.set(configBase.concat("hostname"), "localhost");
    cConf.set(configBase.concat("port"), Integer.toString(ldapPort));
    cConf.set(configBase.concat("userBaseDn"), "dc=example,dc=com");
    cConf.set(configBase.concat("userRdnAttribute"), "cn");
    cConf.set(configBase.concat("userObjectClass"), "inetorgperson");

    URL keytabUrl = ExternalAuthenticationServerTestBase.class.getClassLoader().getResource("test.keytab");
    Assert.assertNotNull(keytabUrl);
    cConf.set(Constants.Security.CFG_CDAP_MASTER_KRB_KEYTAB_PATH, keytabUrl.getPath());
    cConf.set(Constants.Security.CFG_CDAP_MASTER_KRB_PRINCIPAL, "test_principal");
    return cConf;
  }

  /**
   * Creates the in-memory LDAP server with a base entry and a single user entry
   * ({@code cn: admin}, {@code userPassword: realtime}). The server is not yet listening afterwards.
   *
   * @throws Exception if the server cannot be created or the entries cannot be added
   */
  private static void startLDAPServer() throws Exception {
    InMemoryDirectoryServerConfig config = new InMemoryDirectoryServerConfig("dc=example,dc=com");
    config.setListenerConfigs(ldapListenerConfig);

    Entry defaultEntry = new Entry(
      "dn: dc=example,dc=com",
      "objectClass: top",
      "objectClass: domain",
      "dc: example");
    Entry userEntry = new Entry(
      "dn: uid=user,dc=example,dc=com",
      "objectClass: inetorgperson",
      "cn: admin",
      "sn: User",
      "uid: user",
      "userPassword: realtime");

    ldapServer = new InMemoryDirectoryServer(config);
    ldapServer.addEntries(defaultEntry, userEntry);
  }

  /**
   * Stops the LDAP server and the authentication server and resets global security state.
   * Each step runs even if an earlier one fails or if {@link #setup()} never completed,
   * so that a setup failure is not masked by a failure during cleanup.
   *
   * @throws Exception if stopping one of the servers fails
   */
  @AfterClass
  public static void afterClass() throws Exception {
    try {
      if (ldapServer != null) {
        ldapServer.shutDown(true);
      }
    } finally {
      try {
        if (server != null) {
          server.stopAndWait();
        }
      } finally {
        // Clear any security properties for zookeeper.
        System.clearProperty(Constants.External.Zookeeper.ENV_AUTH_PROVIDER_1);
        Configuration.setConfiguration(null);
      }
    }
  }

  /**
   * Test an authorized request to server.
   * @throws Exception
   */
  @Test
  public void testValidAuthentication() throws Exception {
    HttpClient client = getHTTPClient();
    HttpGet request = new HttpGet(resolveUri(GrantAccessToken.Paths.GET_TOKEN));
    request.addHeader("Authorization", VALID_AUTH_HEADER);
    HttpResponse response = client.execute(request);

    assertEquals(200, response.getStatusLine().getStatusCode());
    verify(TEST_AUDIT_LOGGER, timeout(10000).atLeastOnce()).trace(contains("admin"));

    // Test correct headers being returned
    String cacheControlHeader = response.getFirstHeader(HttpHeaders.Names.CACHE_CONTROL).getValue();
    String pragmaHeader = response.getFirstHeader(HttpHeaders.Names.PRAGMA).getValue();
    String contentType = response.getFirstHeader(HttpHeaders.Names.CONTENT_TYPE).getValue();
    assertEquals("no-store", cacheControlHeader);
    assertEquals("no-cache", pragmaHeader);
    assertEquals("application/json;charset=UTF-8", contentType);

    // Test correct response body
    JsonObject responseJson = readJsonBody(response);
    String tokenType = responseJson.get(ExternalAuthenticationServer.ResponseFields.TOKEN_TYPE).toString();
    assertEquals(String.format("\"%s\"", ExternalAuthenticationServer.ResponseFields.TOKEN_TYPE_BODY), tokenType);

    assertAdminToken(responseJson, configuration.getInt(Constants.Security.TOKEN_EXPIRATION));
  }

  /**
   * Test an unauthorized request to server.
   * @throws Exception
   */
  @Test
  public void testInvalidAuthentication() throws Exception {
    HttpClient client = getHTTPClient();
    HttpGet request = new HttpGet(resolveUri(GrantAccessToken.Paths.GET_TOKEN));
    request.addHeader("Authorization", "xxxxx");
    HttpResponse response = client.execute(request);

    // Request is Unauthorized
    assertEquals(401, response.getStatusLine().getStatusCode());
    verify(TEST_AUDIT_LOGGER, timeout(10000).atLeastOnce()).trace(contains("401"));
  }

  /**
   * Test that a status request sent without credentials is answered with 200 OK.
   * @throws Exception
   */
  @Test
  public void testStatusResponse() throws Exception {
    HttpClient client = getHTTPClient();
    HttpGet request = new HttpGet(resolveUri(Constants.EndPoints.STATUS));
    HttpResponse response = client.execute(request);

    // Status request is authorized without any extra headers
    assertEquals(200, response.getStatusLine().getStatusCode());
  }

  /**
   * Test getting a long lasting Access Token.
   * @throws Exception
   */
  @Test
  public void testExtendedToken() throws Exception {
    HttpClient client = getHTTPClient();
    HttpGet request = new HttpGet(resolveUri(GrantAccessToken.Paths.GET_EXTENDED_TOKEN));
    request.addHeader("Authorization", VALID_AUTH_HEADER);
    HttpResponse response = client.execute(request);

    assertEquals(200, response.getStatusLine().getStatusCode());

    // Test correct response body
    JsonObject responseJson = readJsonBody(response);
    assertAdminToken(responseJson, configuration.getInt(Constants.Security.EXTENDED_TOKEN_EXPIRATION));
  }

  /**
   * Test that invalid paths return a 404 Not Found.
   * @throws Exception
   */
  @Test
  public void testInvalidPath() throws Exception {
    HttpClient client = getHTTPClient();
    HttpGet request = new HttpGet(resolveUri("invalid"));
    request.addHeader("Authorization", VALID_AUTH_HEADER);
    HttpResponse response = client.execute(request);

    assertEquals(404, response.getStatusLine().getStatusCode());
  }

  /**
   * Test that the service is discoverable.
   * @throws Exception
   */
  @Test
  public void testServiceRegistration() throws Exception {
    Iterable<Discoverable> discoverables = discoveryServiceClient.discover(Constants.Service.EXTERNAL_AUTHENTICATION);
    Set<SocketAddress> addresses = Sets.newHashSet();
    for (Discoverable discoverable : discoverables) {
      addresses.add(discoverable.getSocketAddress());
    }
    Assert.assertTrue(addresses.contains(server.getSocketAddress()));
  }

  /**
   * Builds the URI of an endpoint on the running authentication server.
   *
   * @param path the endpoint path, without a leading slash
   * @return a URI of the form {@code protocol://host:port/path}
   */
  private String resolveUri(String path) {
    return String.format("%s://%s:%d/%s", getProtocol(), server.getSocketAddress().getAddress().getHostAddress(),
                         server.getSocketAddress().getPort(), path);
  }

  /**
   * Reads the whole response entity as UTF-8 text and parses it as a JSON object.
   * The entity's content stream is always closed, even if reading fails.
   *
   * @param response the HTTP response whose body should be parsed
   * @return the parsed JSON object
   * @throws IOException if the body cannot be read
   */
  private static JsonObject readJsonBody(HttpResponse response) throws IOException {
    ByteArrayOutputStream bos = new ByteArrayOutputStream();
    InputStream content = response.getEntity().getContent();
    try {
      ByteStreams.copy(content, bos);
    } finally {
      content.close();
    }
    String responseBody = bos.toString("UTF-8");
    return (JsonObject) new JsonParser().parse(responseBody);
  }

  /**
   * Asserts that a token response reports the expected lifetime and carries an access token
   * that decodes to the "admin" user.
   *
   * @param responseJson the parsed token response
   * @param expectedExpiration the configured expiration; it is divided by 1000 before being
   *                           compared with the response, which reports the value in seconds
   * @throws Exception if the access token cannot be decoded
   */
  private static void assertAdminToken(JsonObject responseJson, long expectedExpiration) throws Exception {
    // Test expiration time in seconds
    long expiration = responseJson.get(ExternalAuthenticationServer.ResponseFields.EXPIRES_IN).getAsLong();
    assertEquals(expectedExpiration / 1000, expiration);

    // Test that the server passes back an AccessToken object which can be decoded correctly.
    String encodedToken = responseJson.get(ExternalAuthenticationServer.ResponseFields.ACCESS_TOKEN).getAsString();
    AccessToken token = tokenCodec.decode(Base64.decodeBase64(encodedToken));
    assertEquals("admin", token.getIdentifier().getUsername());
    LOG.info("AccessToken got from ExternalAuthenticationServer is: {}", encodedToken);
  }
}