/*
* Copyright 2013-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package sparklr.common;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import javax.sql.DataSource;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.runner.RunWith;
import org.springframework.aop.framework.Advised;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.autoconfigure.security.SecurityProperties;
import org.springframework.boot.autoconfigure.web.ServerProperties;
import org.springframework.boot.context.embedded.LocalServerPort;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.context.SpringBootTest.WebEnvironment;
import org.springframework.core.io.ClassPathResource;
import org.springframework.http.client.ClientHttpRequestInterceptor;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.security.crypto.codec.Base64;
import org.springframework.security.oauth2.client.resource.BaseOAuth2ProtectedResourceDetails;
import org.springframework.security.oauth2.client.test.BeforeOAuth2Context;
import org.springframework.security.oauth2.client.test.OAuth2ContextSetup;
import org.springframework.security.oauth2.client.token.AccessTokenProvider;
import org.springframework.security.oauth2.client.token.OAuth2AccessTokenSupport;
import org.springframework.security.oauth2.client.token.grant.implicit.ImplicitResourceDetails;
import org.springframework.security.oauth2.client.token.grant.password.ResourceOwnerPasswordResourceDetails;
import org.springframework.security.oauth2.client.token.grant.redirect.AbstractRedirectResourceDetails;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.provider.approval.ApprovalStore;
import org.springframework.security.oauth2.provider.approval.InMemoryApprovalStore;
import org.springframework.security.oauth2.provider.approval.JdbcApprovalStore;
import org.springframework.security.oauth2.provider.token.ConsumerTokenServices;
import org.springframework.security.oauth2.provider.token.TokenStore;
import org.springframework.security.oauth2.provider.token.store.InMemoryTokenStore;
import org.springframework.security.oauth2.provider.token.store.JdbcTokenStore;
import org.springframework.test.context.junit4.SpringRunner;
@RunWith(SpringRunner.class)
@SpringBootTest(webEnvironment = WebEnvironment.RANDOM_PORT)
/**
 * Base class for OAuth2 integration tests that run against an embedded server on a
 * random port.
 * <p>
 * It wires the HTTP test helper and the OAuth2 client context to the running server,
 * rewrites the OAuth2 resource endpoints to point at it, and clears token / approval
 * stores after each test.
 */
public abstract class AbstractIntegrationTests {

	/** System property name for the JSSE trust store password. */
	public static final String JAVAX_NET_SSL_TRUST_STORE_PASSWORD = "javax.net.ssl.trustStorePassword";

	/** System property name for the JSSE trust store location. */
	public static final String JAVAX_NET_SSL_TRUST_STORE = "javax.net.ssl.trustStore";

	// Endpoint paths are injected into each test instance via the @Value setters below.
	// They are stored statically so the static accessors (tokenPath() etc.) can be used
	// without an instance.
	private static String globalTokenPath;

	private static String globalTokenKeyPath;

	private static String globalCheckTokenPath;

	private static String globalAuthorizePath;

	static {
		// If a sample.jks keystore is on the classpath, init() switches the client to
		// HTTPS. Register the same file as the JVM-wide trust store so that the client
		// accepts the server's certificate.
		if (new ClassPathResource("sample.jks").exists()) {
			try {
				System.setProperty(JAVAX_NET_SSL_TRUST_STORE,
						new ClassPathResource("sample.jks").getFile().getAbsolutePath());
			}
			catch (IOException e) {
				throw new IllegalStateException(e);
			}
			System.setProperty(JAVAX_NET_SSL_TRUST_STORE_PASSWORD, "secret");
		}
	}

	@LocalServerPort
	private int port;

	@Rule
	public HttpTestUtils http = HttpTestUtils.standard();

	@Rule
	public OAuth2ContextSetup context = OAuth2ContextSetup.standard(http);

	@Autowired(required = false)
	private TokenStore tokenStore;

	@Autowired(required = false)
	private ApprovalStore approvalStore;

	@Autowired(required = false)
	private DataSource dataSource;

	@Autowired
	private SecurityProperties security;

	@Autowired
	private ServerProperties server;

	@Autowired(required = false)
	@Qualifier("consumerTokenServices")
	private ConsumerTokenServices tokenServices;

	/**
	 * Revokes the access token held by the OAuth2 client context, if there is one.
	 * <p>
	 * This is best-effort cleanup. It does nothing when no {@link ConsumerTokenServices}
	 * bean is present, and any failure during revocation is ignored.
	 */
	@After
	public void cancelToken() {
		if (tokenServices == null) {
			return;
		}
		try {
			OAuth2AccessToken token = context.getOAuth2ClientContext().getAccessToken();
			if (token != null) {
				tokenServices.revokeToken(token.getValue());
			}
		}
		catch (Exception ignored) {
			// Cleanup must not fail the test; the token may already be gone.
		}
	}

	/**
	 * Revokes the given token value.
	 * <p>
	 * This is best-effort: it does nothing when no {@link ConsumerTokenServices} bean is
	 * present, and any failure during revocation is ignored.
	 *
	 * @param value the access token value to revoke
	 */
	protected void cancelToken(String value) {
		if (tokenServices == null) {
			return;
		}
		try {
			tokenServices.revokeToken(value);
		}
		catch (Exception ignored) {
			// Cleanup must not fail the test; the token may already be gone.
		}
	}

	/**
	 * Hook for subclasses that need a custom {@link AccessTokenProvider}.
	 *
	 * @return the provider to install, or {@code null} (the default) to keep the
	 *         context's own provider
	 */
	protected AccessTokenProvider createAccessTokenProvider() {
		return null;
	}

	/**
	 * Points the HTTP helper at the embedded server's port and servlet prefix.
	 * <p>
	 * Switches the protocol to HTTPS when {@code sample.jks} is on the classpath.
	 */
	@Before
	public void init() {
		String prefix = server.getServletPrefix();
		http.setPort(port);
		http.setPrefix(prefix);
		if (new ClassPathResource("sample.jks").exists()) {
			http.setProtocol("https");
		}
	}

	/**
	 * Installs the provider from {@link #createAccessTokenProvider()} into the OAuth2
	 * context.
	 * <p>
	 * The provider is installed only if it is an {@link OAuth2AccessTokenSupport}, and it
	 * is made to share the rest template's request factory. Other provider types
	 * (including {@code null}) are ignored.
	 */
	@BeforeOAuth2Context
	public void setupAccessTokenProvider() {
		AccessTokenProvider accessTokenProvider = createAccessTokenProvider();
		if (accessTokenProvider instanceof OAuth2AccessTokenSupport) {
			((OAuth2AccessTokenSupport) accessTokenProvider)
					.setRequestFactory(context.getRestTemplate().getRequestFactory());
			context.setAccessTokenProvider(accessTokenProvider);
		}
	}

	/**
	 * Adjusts the OAuth2 context to target the running server. It:
	 * <ul>
	 * <li>adds extra message converters and interceptors,</li>
	 * <li>rewrites the token and authorization URIs,</li>
	 * <li>for password-grant resources (unless marked {@link DoNotOverride}), fills in
	 * the configured user credentials.</li>
	 * </ul>
	 */
	@BeforeOAuth2Context
	public void fixPaths() {
		// Runs before @Before methods in the OAuth2 context rule, so set up http here too.
		init();
		BaseOAuth2ProtectedResourceDetails resource = (BaseOAuth2ProtectedResourceDetails) context.getResource();
		List<HttpMessageConverter<?>> converters = new ArrayList<>(context.getRestTemplate().getMessageConverters());
		converters.addAll(getAdditionalConverters());
		context.getRestTemplate().setMessageConverters(converters);
		context.getRestTemplate().setInterceptors(getInterceptors());
		resource.setAccessTokenUri(http.getUrl(tokenPath()));
		if (resource instanceof AbstractRedirectResourceDetails) {
			((AbstractRedirectResourceDetails) resource).setUserAuthorizationUri(http.getUrl(authorizePath()));
		}
		if (resource instanceof ImplicitResourceDetails) {
			// The implicit grant obtains its token from the authorization endpoint.
			resource.setAccessTokenUri(http.getUrl(authorizePath()));
		}
		if (resource instanceof ResourceOwnerPasswordResourceDetails && !(resource instanceof DoNotOverride)) {
			((ResourceOwnerPasswordResourceDetails) resource).setUsername(getUsername());
			((ResourceOwnerPasswordResourceDetails) resource).setPassword(getPassword());
		}
	}

	/**
	 * Returns the interceptors to install on the OAuth2 rest template.
	 *
	 * @return the interceptors; empty by default
	 */
	protected List<ClientHttpRequestInterceptor> getInterceptors() {
		return Collections.emptyList();
	}

	/**
	 * Returns extra message converters, which are appended to the rest template's
	 * existing converters.
	 *
	 * @return the extra converters; empty by default
	 */
	protected Collection<? extends HttpMessageConverter<?>> getAdditionalConverters() {
		return Collections.emptySet();
	}

	/**
	 * Returns the password of the user configured in {@link SecurityProperties}.
	 *
	 * @return the configured user's password
	 */
	protected String getPassword() {
		return security.getUser().getPassword();
	}

	/**
	 * Returns the name of the user configured in {@link SecurityProperties}.
	 *
	 * @return the configured user's name
	 */
	protected String getUsername() {
		return security.getUser().getName();
	}

	/**
	 * Marker interface for resource details whose username and password must not be
	 * replaced by {@link #fixPaths()}.
	 */
	public interface DoNotOverride {
	}

	/**
	 * Clears the token store and approval store after each test.
	 *
	 * @throws Exception if unwrapping an AOP proxy target fails
	 */
	@After
	public void close() throws Exception {
		clear(tokenStore);
		clear(approvalStore);
	}

	/**
	 * Builds an HTTP Basic {@code Authorization} header value from the configured user
	 * credentials.
	 * <p>
	 * The credentials are encoded as UTF-8, as recommended by RFC 7617.
	 *
	 * @return {@code "Basic " + base64(username:password)}
	 */
	protected String getBasicAuthentication() {
		byte[] credentials = (getUsername() + ":" + getPassword()).getBytes(StandardCharsets.UTF_8);
		return "Basic " + new String(Base64.encode(credentials), StandardCharsets.US_ASCII);
	}

	/**
	 * Removes all approvals from the given store.
	 * <p>
	 * If the store is an AOP proxy, it is unwrapped first. Unknown store types, including
	 * {@code null}, are ignored.
	 *
	 * @param approvalStore the store to clear, possibly {@code null}
	 * @throws Exception if obtaining the proxy target fails
	 */
	private void clear(ApprovalStore approvalStore) throws Exception {
		if (approvalStore instanceof Advised) {
			// Unwrap the approval store proxy itself (previously this wrongly cast tokenStore).
			Advised advised = (Advised) approvalStore;
			ApprovalStore target = (ApprovalStore) advised.getTargetSource().getTarget();
			clear(target);
			return;
		}
		if (approvalStore instanceof InMemoryApprovalStore) {
			((InMemoryApprovalStore) approvalStore).clear();
		}
		if (approvalStore instanceof JdbcApprovalStore) {
			JdbcTemplate template = new JdbcTemplate(dataSource);
			template.execute("delete from oauth_approvals");
		}
	}

	/**
	 * Removes all tokens and authorization codes from the given store.
	 * <p>
	 * If the store is an AOP proxy, it is unwrapped first. Unknown store types, including
	 * {@code null}, are ignored.
	 *
	 * @param tokenStore the store to clear, possibly {@code null}
	 * @throws Exception if obtaining the proxy target fails
	 */
	private void clear(TokenStore tokenStore) throws Exception {
		if (tokenStore instanceof Advised) {
			Advised advised = (Advised) tokenStore;
			TokenStore target = (TokenStore) advised.getTargetSource().getTarget();
			clear(target);
			return;
		}
		if (tokenStore instanceof InMemoryTokenStore) {
			((InMemoryTokenStore) tokenStore).clear();
		}
		if (tokenStore instanceof JdbcTokenStore) {
			JdbcTemplate template = new JdbcTemplate(dataSource);
			template.execute("delete from oauth_access_token");
			template.execute("delete from oauth_refresh_token");
			template.execute("delete from oauth_client_token");
			template.execute("delete from oauth_code");
		}
	}

	/**
	 * Injects the token endpoint path (default {@code /oauth/token}).
	 *
	 * @param tokenPath the token endpoint path
	 */
	@Value("${oauth.paths.token:/oauth/token}")
	public void setTokenPath(String tokenPath) {
		globalTokenPath = tokenPath;
	}

	/**
	 * Injects the token-key endpoint path (default {@code /oauth/token_key}).
	 *
	 * @param tokenKeyPath the token-key endpoint path
	 */
	@Value("${oauth.paths.token_key:/oauth/token_key}")
	public void setTokenKeyPath(String tokenKeyPath) {
		globalTokenKeyPath = tokenKeyPath;
	}

	/**
	 * Injects the check-token endpoint path (default {@code /oauth/check_token}).
	 *
	 * @param tokenPath the check-token endpoint path
	 */
	@Value("${oauth.paths.check_token:/oauth/check_token}")
	public void setCheckTokenPath(String tokenPath) {
		globalCheckTokenPath = tokenPath;
	}

	/**
	 * Injects the authorization endpoint path (default {@code /oauth/authorize}).
	 *
	 * @param authorizePath the authorization endpoint path
	 */
	@Value("${oauth.paths.authorize:/oauth/authorize}")
	public void setAuthorizePath(String authorizePath) {
		globalAuthorizePath = authorizePath;
	}

	/**
	 * @return the token endpoint path last injected, or {@code null} before injection
	 */
	public static String tokenPath() {
		return globalTokenPath;
	}

	/**
	 * @return the token-key endpoint path last injected, or {@code null} before injection
	 */
	public static String tokenKeyPath() {
		return globalTokenKeyPath;
	}

	/**
	 * @return the check-token endpoint path last injected, or {@code null} before
	 *         injection
	 */
	public static String checkTokenPath() {
		return globalCheckTokenPath;
	}

	/**
	 * @return the authorization endpoint path last injected, or {@code null} before
	 *         injection
	 */
	public static String authorizePath() {
		return globalAuthorizePath;
	}

}