/* * Copyright 2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.social.connect.mem; import java.util.Collections; import java.util.List; import java.util.Set; import java.util.Map.Entry; import org.springframework.social.connect.Connection; import org.springframework.social.connect.ConnectionFactoryLocator; import org.springframework.social.connect.ConnectionKey; import org.springframework.social.connect.ConnectionRepository; import org.springframework.social.connect.DuplicateConnectionException; import org.springframework.social.connect.NoSuchConnectionException; import org.springframework.social.connect.NotConnectedException; import org.springframework.util.Assert; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; public class InMemoryConnectionRepository implements ConnectionRepository { // <providerId, Connection<provider API>> private MultiValueMap<String, Connection<?>> connections; private ConnectionFactoryLocator connectionFactoryLocator; public InMemoryConnectionRepository(ConnectionFactoryLocator connectionFactoryLocator) { this.connectionFactoryLocator = connectionFactoryLocator; this.connections = new LinkedMultiValueMap<String, Connection<?>>(); } public MultiValueMap<String, Connection<?>> findAllConnections() { if (connections.isEmpty()) { MultiValueMap<String, Connection<?>> result = new LinkedMultiValueMap<String, Connection<?>>(); Set<String> registeredProviderIds = connectionFactoryLocator.registeredProviderIds(); for (String registeredProviderId : registeredProviderIds) { result.put(registeredProviderId, Collections.<Connection<?>>emptyList()); } return result; } else { return connections; } } public List<Connection<?>> findConnections(String providerId) { List<Connection<?>> emptyConnectionList = Collections.emptyList(); return connections.containsKey(providerId) ? connections.get(providerId) : emptyConnectionList; } @SuppressWarnings("unchecked") public <A> List<Connection<A>> findConnections(Class<A> apiType) { List<?> providerConnections = findConnections(getProviderId(apiType)); return (List<Connection<A>>) providerConnections; } public MultiValueMap<String, Connection<?>> findConnectionsToUsers(MultiValueMap<String, String> providerUserIds) { Assert.notEmpty(providerUserIds, "Provider user IDs cannot be empty."); MultiValueMap<String, Connection<?>> connectionsToUsers = new LinkedMultiValueMap<String, Connection<?>>(providerUserIds.size()); for (Entry<String, List<String>> providerConnectionEntry : providerUserIds.entrySet()) { String providerId = providerConnectionEntry.getKey(); List<String> userIds = providerConnectionEntry.getValue(); if (connections.containsKey(providerId)) { List<Connection<?>> providerConnections = connections.get(providerId); for (Connection<?> connection : providerConnections) { if (userIds.contains(connection.getKey().getProviderUserId())) { connectionsToUsers.add(providerId, connection); } } } } return connectionsToUsers; } public Connection<?> getConnection(ConnectionKey connectionKey) { if (connections.containsKey(connectionKey.getProviderId())) { List<Connection<?>> providerConnections = connections.get(connectionKey.getProviderId()); for (Connection<?> connection : providerConnections) { if (connection.getKey().equals(connectionKey)) { return connection; } } } throw new NoSuchConnectionException(connectionKey); } @SuppressWarnings("unchecked") public <A> Connection<A> getConnection(Class<A> apiType, String providerUserId) { return (Connection<A>) getConnection(new ConnectionKey(getProviderId(apiType), providerUserId)); } public <A> Connection<A> getPrimaryConnection(Class<A> apiType) { Connection<A> primaryConnection = findPrimaryConnection(apiType); if (primaryConnection == null) { throw new NotConnectedException(getProviderId(apiType)); } return primaryConnection; } @SuppressWarnings("unchecked") public <A> Connection<A> findPrimaryConnection(Class<A> apiType) { String providerId = getProviderId(apiType); if (connections.containsKey(providerId)) { return (Connection<A>) connections.get(providerId).get(0); } return null; } public void addConnection(Connection<?> connection) { try { ConnectionKey connectionKey = connection.getKey(); getConnection(connectionKey); throw new DuplicateConnectionException(connectionKey); } catch (NoSuchConnectionException e) { connections.add(connection.createData().getProviderId(), connection); } } public void updateConnection(Connection<?> connection) { connections.add(connection.createData().getProviderId(), connection); } public void removeConnections(String providerId) { connections.remove(providerId); } public void removeConnection(ConnectionKey connectionKey) { String providerId = connectionKey.getProviderId(); if (connections.containsKey(providerId)) { List<Connection<?>> providerConnections = connections.get(providerId); for (Connection<?> connection : providerConnections) { if (connection.getKey().equals(connectionKey)) { providerConnections.remove(connection); } } } } private <A> String getProviderId(Class<A> apiType) { return connectionFactoryLocator.getConnectionFactory(apiType).getProviderId(); } }