/*
 * Copyright 2017 ThoughtWorks, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.thoughtworks.go.server.persistence;

import java.lang.reflect.Field;
import java.util.*;

import com.thoughtworks.go.domain.PersistentObject;
import com.thoughtworks.go.server.dao.DatabaseAccessHelper;
import com.thoughtworks.go.server.domain.oauth.OauthAuthorization;
import com.thoughtworks.go.server.domain.oauth.OauthClient;
import com.thoughtworks.go.server.domain.oauth.OauthToken;
import com.thoughtworks.go.server.oauth.OauthDataSource;
import org.hamcrest.Matchers;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.orm.hibernate3.HibernateObjectRetrievalFailureException;
import org.springframework.orm.hibernate3.HibernateTemplate;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;

import static com.thoughtworks.go.util.DataStructureUtils.m;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.core.Is.is;
import static org.hamcrest.core.IsNot.not;
import static org.hamcrest.core.IsNull.nullValue;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.fail;
import static org.junit.matchers.JUnitMatchers.hasItems;

/**
 * Integration tests for {@link OauthRepository}: OAuth clients, authorizations (codes) and tokens.
 * <p>
 * Fixtures are written directly through the repository's {@link HibernateTemplate}. The repository
 * API is then exercised and its results are compared against each fixture's DTO. Per-test database
 * setup and teardown is delegated to {@link DatabaseAccessHelper}.
 */
@RunWith(SpringJUnit4ClassRunner.class)
@ContextConfiguration(locations = {
        "classpath:WEB-INF/applicationContext-global.xml",
        "classpath:WEB-INF/applicationContext-dataLocalAccess.xml",
        "classpath:WEB-INF/applicationContext-acegi-security.xml"
})
public class OauthRepositoryTest {
    @Autowired private OauthRepository repo;
    @Autowired private DatabaseAccessHelper dbHelper;

    /** The repository's own template, so fixtures and the code under test share one Hibernate setup. */
    private HibernateTemplate template;

    @Before
    public void setUp() throws Exception {
        template = repo.getHibernateTemplate();
        dbHelper.onSetUp();
    }

    @After
    public void tearDown() throws Exception {
        dbHelper.onTearDown();
    }

    @Test
    public void shouldHonorTransaction() throws Exception {
        final OauthClient client = new OauthClient("mingle", "client_id", "client_secret", "http://something");
        try {
            repo.transaction(new Runnable() {
                @Override
                public void run() {
                    template.save(client);
                    throw new RuntimeException("ouch! it failed.");
                }
            });
            fail("should have bubbled up transaction failing exception");
        } catch (Exception e) {
            assertThat(e.getMessage(), is("ouch! it failed."));
        }

        // The save above must have been rolled back together with the failing transaction.
        try {
            template.load(OauthClient.class, client.getId());
            fail("should have failed because transaction was not successful");
        } catch (Exception e) {
            assertThat(e, is(Matchers.<Object>instanceOf(HibernateObjectRetrievalFailureException.class)));
        }
    }

    @Test
    public void shouldFindOauthClientById() {
        OauthClient client = new OauthClient("mingle", "client_id", "client_secret", "http://something");
        template.save(client);
        OauthDataSource.OauthClientDTO dto = repo.findOauthClientById(client.getId());
        assertThat(dto, is(client.getDTO()));
    }

    @Test
    public void shouldFindOauthClientByName() {
        OauthClient client = new OauthClient("mingle", "client_id", "client_secret", "http://something");
        template.save(client);
        OauthDataSource.OauthClientDTO dto = repo.findOauthClientByName("mingle");
        assertThat(dto, is(client.getDTO()));
    }

    @Test
    public void shouldFindOauthClientByRedirectUri() {
        OauthClient client = new OauthClient("mingle", "client_id", "client_secret", "http://something");
        template.save(client);
        OauthDataSource.OauthClientDTO dto = repo.findOauthClientByRedirectUri("http://something");
        assertThat(dto, is(client.getDTO()));
    }

    @Test
    public void shouldHaveUniqueClientNames() {
        OauthClient client = new OauthClient("mingle", "client_id", "client_secret", "http://something");
        template.save(client);
        try {
            template.save(new OauthClient("mingle", "another_id", "super_secret", "http://something/another"));
            fail("Should not be able to add 2 clients with the same name");
        } catch (Exception expected) {
            // Expected: a second client with an already-used name must be rejected on save.
            // (fail() throws an AssertionError, which is not an Exception, so it is not swallowed here.)
        }
    }

    @Test
    public void shouldFindAllOauthClients() {
        OauthClient mingle1 = new OauthClient("mingle09", "client_id_9", "client_secret_9", "http://something");
        OauthClient mingle2 = new OauthClient("mingle05", "client_id_5", "client_secret_5", "http://something-else");
        OauthClient go01 = new OauthClient("go01", "client_id_1", "client_secret_1", "http://fake-go-server");
        template.save(mingle1);
        template.save(mingle2);
        template.save(go01);
        Collection<OauthDataSource.OauthClientDTO> clients = repo.findAllOauthClient();
        assertThat(clients.size(), is(3));
        assertThat(clients, hasItems(mingle1.getDTO(), mingle2.getDTO(), go01.getDTO()));
    }

    @Test
    public void shouldFindOauthClientByClientId() {
        OauthClient mingle09 = new OauthClient("mingle09", "client_id_9", "client_secret_9", "http://something");
        template.save(mingle09);
        OauthDataSource.OauthClientDTO dto = repo.findOauthClientByClientId("client_id_9");
        assertThat(dto, is(mingle09.getDTO()));

        dto = repo.findOauthClientByClientId("client_id_non_existent");
        assertThat(dto, is(nullValue()));
    }

    @Test
    public void shouldSaveOauthClient() throws IllegalAccessException, NoSuchFieldException {
        OauthClient mingle09 = new OauthClient("mingle09", "client_id", "client_secret", "http://something");
        Map<String, String> map = attrMap(mingle09.getDTO());
        repo.saveOauthClient(map);
        OauthClient client = (OauthClient) template.find("from OauthClient where clientSecret = 'client_secret'").get(0);
        OauthDataSource.OauthClientDTO dto = client.getDTO();
        assertHasIdAndMatches(dto, mingle09.getDTO());
    }

    @Test
    public void shouldUpdateOauthClient() throws IllegalAccessException {
        OauthClient mingle09 = new OauthClient("mingle09", "client_id", "client_secret", "http://something");
        final Map map = attrMap(mingle09.getDTO());
        template.save(mingle09);
        map.put("name", "a different name");
        map.put("id", mingle09.getId());
        repo.transaction(new Runnable() {
            @Override
            public void run() {
                repo.saveOauthClient(map);
            }
        });
        OauthClient client = template.load(OauthClient.class, mingle09.getId());
        assertThat(client.getDTO().getName(), is("a different name"));
    }

    @Test
    public void shouldDeleteOauthClient() {
        OauthClient mingle09 = new OauthClient("mingle09", "client_id", "client_secret", "http://something");
        template.save(mingle09);
        assertThat(template.find("from OauthClient").size(), is(1));
        repo.deleteOauthClient(mingle09.getId());
        assertThat(template.find("from OauthClient").size(), is(0));
    }

    @Test
    public void shouldNotBombWhenNoOauthClientToDelete() {
        assertThat(template.find("from OauthClient").size(), is(0));
        repo.deleteOauthClient(10);
    }

    @Test
    public void shouldDeleteOauthClientAndAssociatedApprovalsAndTokens() {
        OauthClient mingle09 = new OauthClient("mingle09", "client_id", "client_secret", "http://something");
        template.save(mingle09);
        template.save(new OauthAuthorization("foo@bar.com", mingle09, "code", 332333));
        template.save(new OauthToken("foo@bar.com", mingle09, "foo-access", "foo-refresh", 12345));
        assertThat(template.find("from OauthClient").size(), is(1));
        assertThat(template.find("from OauthAuthorization").size(), is(1));
        assertThat(template.find("from OauthToken").size(), is(1));

        repo.deleteOauthClient(mingle09.getId());

        assertThat(template.find("from OauthClient").size(), is(0));
        assertThat(template.find("from OauthAuthorization").size(), is(0));
        assertThat(template.find("from OauthToken").size(), is(0));
    }

    @Test
    public void shouldFindAllAuthorizationsByClientId() {
        OauthClient mingle09 = new OauthClient("mingle09", "client_id", "client_secret", "http://something");
        template.save(mingle09);
        OauthAuthorization authFoo = new OauthAuthorization("foo@bar.com", mingle09, "code", 332333);
        template.save(authFoo);
        OauthAuthorization authBar = new OauthAuthorization("bar@bar.com", mingle09, "code112", 334337);
        template.save(authBar);

        OauthClient mingle05 = new OauthClient("mingle05", "client_id_2", "client_secret_2", "http://something-else");
        template.save(mingle05);
        OauthAuthorization authBaz = new OauthAuthorization("baz@bar.com", mingle05, "code115", 334337);
        template.save(authBaz);

        Collection<OauthDataSource.OauthAuthorizationDTO> authorizations = repo.findAllOauthAuthorizationByOauthClientId("client_id");
        assertThat(authorizations.size(), is(2));
        assertThat(authorizations, hasItems(authFoo.getDTO(), authBar.getDTO()));
    }

    @Test
    public void shouldFindOauthAuthorizationById() {
        OauthClient mingle09 = new OauthClient("mingle09", "client_id", "client_secret", "http://something");
        template.save(mingle09);
        OauthAuthorization authFoo = new OauthAuthorization("foo@bar.com", mingle09, "code", 332333);
        template.save(authFoo);
        OauthDataSource.OauthAuthorizationDTO dto = repo.findOauthAuthorizationById(authFoo.getId());
        assertThat(dto, is(authFoo.getDTO()));
    }

    @Test
    public void shouldFindOauthAuthorizationByCode() {
        OauthClient mingle09 = new OauthClient("mingle09", "client_id", "client_secret", "http://something");
        template.save(mingle09);
        OauthAuthorization authFoo = new OauthAuthorization("foo@bar.com", mingle09, "code", 332333);
        template.save(authFoo);
        OauthDataSource.OauthAuthorizationDTO dto = repo.findOauthAuthorizationByCode("code");
        assertThat(dto, is(authFoo.getDTO()));
    }

    @Test
    public void shouldReturnNullWhenNoAuthFoundByCode() {
        OauthDataSource.OauthAuthorizationDTO dto = repo.findOauthAuthorizationByCode("some-non-existing-code");
        assertThat(dto, is(nullValue()));
    }

    @Test
    public void shouldSaveAuthorization() throws IllegalAccessException, NoSuchFieldException {
        OauthClient mingle09 = new OauthClient("mingle09", "client_id", "client_secret", "http://something");
        template.save(mingle09);
        OauthAuthorization authFoo = new OauthAuthorization("foo@bar.com", mingle09, "code", 332333);
        repo.saveOauthAuthorization(attrMap(authFoo.getDTO()));
        OauthDataSource.OauthAuthorizationDTO dto = ((OauthAuthorization) template.find("from OauthAuthorization where code = ? ", "code").get(0)).getDTO();
        assertHasIdAndMatches(dto, authFoo.getDTO());
    }

    @Test
    public void shouldUpdateAuthorization() throws IllegalAccessException, NoSuchFieldException {
        OauthClient mingle09 = new OauthClient("mingle09", "client_id", "client_secret", "http://something");
        template.save(mingle09);
        OauthAuthorization authFoo = new OauthAuthorization("foo@bar.com", mingle09, "code", 332333);
        template.save(authFoo);
        Map map = attrMap(authFoo.getDTO());
        map.put("id", authFoo.getId());
        map.put("code", "hello_world");
        repo.saveOauthAuthorization(map);
        OauthDataSource.OauthAuthorizationDTO dto = template.load(OauthAuthorization.class, authFoo.getId()).getDTO();
        assertThat(dto.getCode(), is("hello_world"));
    }

    @Test
    public void shouldDeleteAuthorization() {
        OauthClient mingle09 = new OauthClient("mingle09", "client_id", "client_secret", "http://something");
        template.save(mingle09);
        OauthAuthorization authFoo = new OauthAuthorization("foo@bar.com", mingle09, "code", 332333);
        template.save(authFoo);
        assertThat(template.find("from OauthAuthorization").size(), is(1));
        repo.deleteOauthAuthorization(authFoo.getId());
        assertThat(template.find("from OauthAuthorization").size(), is(0));
    }

    @Test
    public void shouldNotBombWhenNoAuthorizationToDelete() {
        assertThat(template.find("from OauthAuthorization").size(), is(0));
        repo.deleteOauthAuthorization(10);
    }

    @Test
    public void shouldFindOauthTokenById() {
        OauthClient mingle09 = new OauthClient("mingle09", "client_id", "client_secret", "http://something");
        template.save(mingle09);
        OauthToken token = new OauthToken("foo@bar.com", mingle09, "access-token", "refresh-token", 23324324);
        template.save(token);
        OauthDataSource.OauthTokenDTO tokenDto = repo.findOauthTokenById(token.getId());
        assertThat(tokenDto, is(token.getDTO()));
    }

    @Test
    public void shouldFindAllOauthTokenByOauthClientId() {
        OauthClient mingle09 = new OauthClient("mingle09", "client_id", "client_secret", "http://something");
        template.save(mingle09);
        OauthToken tokenFoo = new OauthToken("foo@bar.com", mingle09, "access-token", "refresh-token", 23324324);
        template.save(tokenFoo);
        OauthToken tokenBar = new OauthToken("bar@bar.com", mingle09, "access-token-bar", "refresh-token-bar", 23324324);
        template.save(tokenBar);

        OauthClient mingle05 = new OauthClient("mingle05", "client_id_5", "client_secret_5", "http://something");
        template.save(mingle05);
        OauthToken tokenBaz = new OauthToken("baz@bar.com", mingle05, "access-token-baz", "refresh-token-baz", 23324324);
        template.save(tokenBaz);

        Collection<OauthDataSource.OauthTokenDTO> tokens = repo.findAllOauthTokenByOauthClientId("client_id");
        assertThat(tokens.size(), is(2));
        assertThat(tokens, hasItems(tokenFoo.getDTO(), tokenBar.getDTO()));
    }

    @Test
    public void shouldFindAllOauthTokenByUserId() {
        OauthClient mingle09 = new OauthClient("mingle09", "client_id", "client_secret", "http://something");
        template.save(mingle09);
        OauthToken fooTokenFor09 = new OauthToken("foo@bar.com", mingle09, "access-token", "refresh-token", 23324324);
        template.save(fooTokenFor09);
        OauthToken barTokenFor09 = new OauthToken("bar@bar.com", mingle09, "access-token-bar", "refresh-token-bar", 23324324);
        template.save(barTokenFor09);

        OauthClient mingle05 = new OauthClient("mingle05", "client_id_5", "client_secret_5", "http://something");
        template.save(mingle05);
        OauthToken fooTokenFor05 = new OauthToken("foo@bar.com", mingle05, "access-token-baz", "refresh-token-baz", 23324324);
        template.save(fooTokenFor05);

        Collection<OauthDataSource.OauthTokenDTO> tokens = repo.findAllOauthTokenByUserId("foo@bar.com");
        assertThat(tokens.size(), is(2));
        assertThat(tokens, hasItems(fooTokenFor09.getDTO(), fooTokenFor05.getDTO()));
    }

    @Test
    public void shouldfindOauthTokenByAccessToken() {
        OauthClient mingle09 = new OauthClient("mingle09", "client_id", "client_secret", "http://something");
        template.save(mingle09);
        OauthToken fooTokenFor09 = new OauthToken("foo@bar.com", mingle09, "access-token", "refresh-token", 23324324);
        template.save(fooTokenFor09);
        OauthToken barTokenFor09 = new OauthToken("bar@bar.com", mingle09, "access-token-bar", "refresh-token-bar", 23324324);
        template.save(barTokenFor09);

        OauthDataSource.OauthTokenDTO dto = repo.findOauthTokenByAccessToken("access-token");
        assertThat(dto, is(fooTokenFor09.getDTO()));
    }

    @Test
    public void shouldfindOauthTokenByRefreshToken() {
        OauthClient mingle09 = new OauthClient("mingle09", "client_id", "client_secret", "http://something");
        template.save(mingle09);
        OauthToken fooTokenFor09 = new OauthToken("foo@bar.com", mingle09, "access-token", "refresh-token", 23324324);
        template.save(fooTokenFor09);
        OauthToken barTokenFor09 = new OauthToken("bar@bar.com", mingle09, "access-token-bar", "refresh-token-bar", 23324324);
        template.save(barTokenFor09);

        OauthDataSource.OauthTokenDTO dto = repo.findOauthTokenByRefreshToken("refresh-token");
        assertThat(dto, is(fooTokenFor09.getDTO()));
    }

    @Test
    public void shouldSaveTokens() throws IllegalAccessException, NoSuchFieldException {
        OauthClient mingle09 = new OauthClient("mingle09", "client_id", "client_secret", "http://something");
        template.save(mingle09);
        OauthToken fooTokenFor09 = new OauthToken("foo@bar.com", mingle09, "access-token", "refresh-token", 23324324);
        repo.saveOauthToken(attrMap(fooTokenFor09.getDTO()));
        // Kept on one line: a string literal may not span a line break in Java.
        OauthDataSource.OauthTokenDTO dto = ((OauthToken) template.find("from OauthToken where accessToken = ? ", "access-token").get(0)).getDTO();
        assertHasIdAndMatches(dto, fooTokenFor09.getDTO());
    }

    @Test
    public void shouldUpdateTokens() throws IllegalAccessException, NoSuchFieldException {
        OauthClient mingle09 = new OauthClient("mingle09", "client_id", "client_secret", "http://something");
        template.save(mingle09);
        OauthToken fooTokenFor09 = new OauthToken("foo@bar.com", mingle09, "access-token", "refresh-token", 23324324);
        template.save(fooTokenFor09);
        Map map = attrMap(fooTokenFor09.getDTO());
        map.put("id", fooTokenFor09.getId());
        map.put("access_token", "hello_world");
        repo.saveOauthToken(map);
        OauthDataSource.OauthTokenDTO dto = template.load(OauthToken.class, fooTokenFor09.getId()).getDTO();
        assertThat(dto.getAccessToken(), is("hello_world"));
    }

    @Test
    public void shouldDeleteOauthToken() {
        OauthClient mingle09 = new OauthClient("mingle09", "client_id", "client_secret", "http://something");
        template.save(mingle09);
        OauthToken fooTokenFor09 = new OauthToken("foo@bar.com", mingle09, "access-token", "refresh-token", 23324324);
        template.save(fooTokenFor09);
        assertThat(template.find("from OauthToken").size(), is(1));
        repo.deleteOauthToken(fooTokenFor09.getId());
        assertThat(template.find("from OauthToken").size(), is(0));
    }

    @Test
    public void shouldDeleteAllTokensAndCodes() {
        OauthClient mingle09 = new OauthClient("mingle09", "client_id", "client_secret", "http://something");
        template.save(mingle09);
        OauthAuthorization authFoo = new OauthAuthorization("foo@bar.com", mingle09, "code", 332333);
        template.save(authFoo);
        OauthToken fooTokenFor09 = new OauthToken("foo@bar.com", mingle09, "access-token", "refresh-token", 23324324);
        template.save(fooTokenFor09);

        repo.deleteAllOauthGrants();

        assertThat(template.find("from OauthAuthorization").size(), is(0));
        assertThat(template.find("from OauthToken").size(), is(0));
    }

    @Test
    public void shouldNotBombWhenNoTokenToDelete() {
        assertThat(template.find("from OauthToken").size(), is(0));
        repo.deleteOauthToken(10);
    }

    @Test
    public void shouldSaveClientIfIdIsMissing() throws Exception {
        String name = "oauth";
        String clientId = "client_id";
        String clientSecret = "client_secret";
        String redirectUri = "http://something";
        Map<String, String> map = m("id", "", "name", name, "client_id", clientId, "client_secret", clientSecret, "redirect_uri", redirectUri);

        repo.saveClient(map);

        // Bind the secret as a query parameter rather than splicing it into the HQL text.
        OauthClient client = (OauthClient) template.find("from OauthClient where clientSecret = ?", clientSecret).get(0);
        OauthDataSource.OauthClientDTO dto = client.getDTO();
        assertThat(dto.getName(), is(name));
        assertThat(dto.getClientId(), is(clientId));
        assertThat(dto.getClientSecret(), is(clientSecret));
        assertThat(dto.getRedirectUri(), is(redirectUri));
        assertThat(dto.getId(), is(not(PersistentObject.NOT_PERSISTED)));
    }

    @Test
    public void shouldSaveAuthorization_ForEngineFlow() throws Exception {
        OauthClient client = new OauthClient("mingle09", "client_id", "client_secret", "http://something");
        template.save(client);
        long expiresAt = new Date().getTime();
        Map<String, String> attributes = m("authenticity_token", "eJkmGwpHh045A/h5uhme+4Pqdr+E8b+jgRq1+vt/s6M=",
                "authorize", "Yes",
                "client_id", String.valueOf(client.getDTO().getId()),
                "redirect_uri", "https://mingle05.thoughtworks.com/gadgets/oauthcallback",
                "response_type", "code",
                "state", "eyJvYXV0aF9hdXRob3JpemVfdXJsIjoiaHR0cHM6Ly8xOTIuMTY4Ljk5LjU5\nOjgxNTQvZ28vYWRtaW4vb2F1dGgvYXV0aG9yaXplIn0=",
                "code", "ABCD",
                "expires_at", expiresAt,
                "user_id", "1");

        OauthDataSource.OauthAuthorizationDTO dto = repo.saveAuthorization(attributes);

        assertThat(dto.getId(), is(not(Matchers.nullValue())));
        assertThat(dto.getClientId(), is(String.valueOf(client.getDTO().getId())));
        assertThat(dto.getExpiresAt(), is(expiresAt));
        assertThat(dto.getCode(), is("ABCD"));
        assertThat(dto.getOauthClientId(), is(String.valueOf(client.getDTO().getId())));
        assertThat(dto.getUserId(), is("1"));
    }

    @Test
    public void shouldSaveToken_ForEngineFlow() throws Exception {
        OauthClient client = new OauthClient("mingle09", "client_id", "client_secret", "http://something");
        template.save(client);
        String accessToken = UUID.randomUUID().toString();
        String refreshToken = UUID.randomUUID().toString();
        long expiresAt = new Date().getTime();
        Map<String, String> attributes = m("user_id", "1",
                "client_id", String.valueOf(client.getDTO().getId()),
                "access_token", accessToken,
                "refresh_token", refreshToken,
                "expires_at", expiresAt);

        OauthDataSource.OauthTokenDTO dto = repo.saveToken(attributes);

        assertThat(dto.getUserId(), is("1"));
    }

    /**
     * Asserts that {@code loaded} was persisted and otherwise equals {@code unpersistentExpected}.
     * <p>
     * {@code loaded} must be an instance of the expected object's class, and its {@code id} field
     * must be greater than zero. Every other declared field is then compared reflectively.
     *
     * @param loaded               object read back from the database
     * @param unpersistentExpected object built in the test before saving (its {@code id} is ignored)
     * @throws NoSuchFieldException   if {@code loaded}'s class declares no {@code id} field
     * @throws IllegalAccessException if a field cannot be read reflectively
     */
    static void assertHasIdAndMatches(Object loaded, Object unpersistentExpected) throws NoSuchFieldException, IllegalAccessException {
        assertThat(loaded, is(Matchers.<Object>instanceOf(unpersistentExpected.getClass())));
        Field id = loaded.getClass().getDeclaredField("id");
        id.setAccessible(true);
        assertThat((Long) id.get(loaded), is(greaterThan(0L)));
        for (Field field : loaded.getClass().getDeclaredFields()) {
            if (field.getName().equals("id")) {
                continue;
            }
            field.setAccessible(true);
            assertThat(field.get(loaded), is(field.get(unpersistentExpected)));
        }
    }

    /**
     * Builds an attribute map from every field declared directly on {@code dto}'s class.
     * Each field name is converted to snake_case (e.g. {@code clientId} becomes {@code client_id}).
     * These are the keys the repository's map-based save methods are called with in this test.
     *
     * @param dto object whose declared fields are copied
     * @return a new mutable map of snake_case field name to field value
     * @throws IllegalAccessException if a field cannot be read reflectively
     */
    static Map attrMap(Object dto) throws IllegalAccessException {
        Map map = new HashMap();
        for (Field field : dto.getClass().getDeclaredFields()) {
            field.setAccessible(true);
            map.put(camelToSnakeCase(field.getName()), field.get(dto));
        }
        return map;
    }

    /**
     * Converts a camelCase identifier to snake_case by inserting {@code _} before each ASCII
     * upper-case letter and lower-casing every character.
     * <p>
     * {@link Character#toLowerCase(char)} is used instead of {@link String#toLowerCase()}.
     * The latter depends on the default locale; under a Turkish locale it would map {@code 'I'}
     * to a dotless {@code 'ı'} and produce keys such as {@code client_ıd}.
     *
     * @param camelCased identifier such as {@code redirectUri}
     * @return the snake_case form, such as {@code redirect_uri}
     */
    static String camelToSnakeCase(String camelCased) {
        StringBuilder builder = new StringBuilder(camelCased.length() + 4);
        for (char c : camelCased.toCharArray()) {
            if (c >= 'A' && c <= 'Z') {
                builder.append('_');
            }
            builder.append(Character.toLowerCase(c));
        }
        return builder.toString();
    }
}