/* * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.social.connect.sqlite; import java.util.ArrayList; import java.util.Collections; import java.util.Iterator; import java.util.LinkedList; import java.util.List; import java.util.Map.Entry; import java.util.Set; import org.springframework.security.crypto.encrypt.TextEncryptor; import org.springframework.social.connect.Connection; import org.springframework.social.connect.ConnectionData; import org.springframework.social.connect.ConnectionFactory; import org.springframework.social.connect.ConnectionFactoryLocator; import org.springframework.social.connect.ConnectionKey; import org.springframework.social.connect.ConnectionRepository; import org.springframework.social.connect.DuplicateConnectionException; import org.springframework.social.connect.NoSuchConnectionException; import org.springframework.social.connect.NotConnectedException; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import android.content.ContentValues; import android.database.Cursor; import android.database.sqlite.SQLiteConstraintException; import android.database.sqlite.SQLiteDatabase; import android.database.sqlite.SQLiteOpenHelper; /** * {@link ConnectionRepository} that uses SQLite to persist connection data to a relational database. 
* <p>
 * Every repository instance is bound to a single local {@code userId}: each query filters on that
 * {@code userId}, and inserted rows are stamped with it. The {@code accessToken}, {@code secret} and
 * {@code refreshToken} columns are encrypted with the supplied {@link TextEncryptor} when written and
 * decrypted when read. Each public operation obtains a database from the {@link SQLiteOpenHelper} and
 * closes it before returning, including when the operation fails.
 *
 * @author Roy Clarkson
 * @since 1.0
 */
public class SQLiteConnectionRepository implements ConnectionRepository {

	private final String userId;

	private final SQLiteOpenHelper repositoryHelper;

	private final ConnectionFactoryLocator connectionFactoryLocator;

	private final TextEncryptor textEncryptor;

	/**
	 * Creates a repository scoped to the fixed local user id {@code "1"}.
	 * @param repositoryHelper helper used to open the database holding the {@code UserConnection} table
	 * @param connectionFactoryLocator locator used to rebuild {@link Connection}s from stored rows
	 * @param textEncryptor encryptor applied to token columns
	 */
	public SQLiteConnectionRepository(SQLiteOpenHelper repositoryHelper,
			ConnectionFactoryLocator connectionFactoryLocator, TextEncryptor textEncryptor) {
		this("1", repositoryHelper, connectionFactoryLocator, textEncryptor);
	}

	/**
	 * Creates a repository scoped to the given local user id.
	 * @param userId local user whose connections this repository reads and writes
	 * @param repositoryHelper helper used to open the database holding the {@code UserConnection} table
	 * @param connectionFactoryLocator locator used to rebuild {@link Connection}s from stored rows
	 * @param textEncryptor encryptor applied to token columns
	 */
	public SQLiteConnectionRepository(String userId, SQLiteOpenHelper repositoryHelper,
			ConnectionFactoryLocator connectionFactoryLocator, TextEncryptor textEncryptor) {
		this.userId = userId;
		this.repositoryHelper = repositoryHelper;
		this.connectionFactoryLocator = connectionFactoryLocator;
		this.textEncryptor = textEncryptor;
	}

	/**
	 * Returns all connections of this user, grouped by provider id and ordered by rank. Every registered
	 * provider has an entry; providers without connections map to an empty list.
	 */
	public MultiValueMap<String, Connection<?>> findAllConnections() {
		final String sql = selectFromUserConnection() + " where userId = ? order by providerId, rank";
		final String[] selectionArgs = { userId };
		List<Connection<?>> resultList = queryForConnections(sql, selectionArgs);
		MultiValueMap<String, Connection<?>> connections = new LinkedMultiValueMap<String, Connection<?>>();
		Set<String> registeredProviderIds = connectionFactoryLocator.registeredProviderIds();
		for (String registeredProviderId : registeredProviderIds) {
			connections.put(registeredProviderId, Collections.<Connection<?>> emptyList());
		}
		for (Connection<?> connection : resultList) {
			String providerId = connection.getKey().getProviderId();
			List<Connection<?>> existing = connections.get(providerId);
			if (existing == null || existing.isEmpty()) {
				// the placeholder from Collections.emptyList() is immutable; swap in a mutable list before adding
				connections.put(providerId, new LinkedList<Connection<?>>());
			}
			connections.add(providerId, connection);
		}
		return connections;
	}

	/**
	 * Returns this user's connections to the given provider, ordered by rank.
	 * @param providerId the provider id
	 */
	public List<Connection<?>> findConnections(String providerId) {
		final String sql = selectFromUserConnection() + " where userId = ? and providerId = ? order by rank";
		final String[] selectionArgs = { userId, providerId };
		return queryForConnections(sql, selectionArgs);
	}

	/**
	 * Returns this user's connections to the provider of the given API type, ordered by rank.
	 * @param apiType the API binding type used to resolve the provider id
	 */
	@SuppressWarnings("unchecked")
	public <A> List<Connection<A>> findConnections(Class<A> apiType) {
		List<?> connections = findConnections(getProviderId(apiType));
		return (List<Connection<A>>) connections;
	}

	/**
	 * Finds this user's connections to the given provider users.
	 * <p>
	 * For each provider that has at least one match, the result holds a list aligned index-for-index with
	 * the requested provider user ids; unmatched positions are {@code null}. Providers with an empty or
	 * {@code null} id list are ignored.
	 * @param providerUsers provider id mapped to the provider user ids to look up
	 * @throws IllegalArgumentException if {@code providerUsers} is {@code null} or empty
	 */
	public MultiValueMap<String, Connection<?>> findConnectionsToUsers(MultiValueMap<String, String> providerUsers) {
		if (providerUsers == null || providerUsers.isEmpty()) {
			throw new IllegalArgumentException("Unable to execute find: no providerUsers provided");
		}
		StringBuilder providerUsersCriteriaSql = new StringBuilder();
		List<String> args = new ArrayList<String>(1 + providerUsers.size() * 2);
		args.add(userId);
		for (Entry<String, List<String>> entry : providerUsers.entrySet()) {
			List<String> providerUserIds = entry.getValue();
			if (providerUserIds == null || providerUserIds.isEmpty()) {
				// "in ()" is invalid SQL, and emitting a placeholder with no argument would shift every later bind arg
				continue;
			}
			if (providerUsersCriteriaSql.length() > 0) {
				providerUsersCriteriaSql.append(" or ");
			}
			providerUsersCriteriaSql.append("(providerId = ? and providerUserId in (");
			args.add(entry.getKey());
			for (Iterator<String> values = providerUserIds.iterator(); values.hasNext();) {
				args.add(values.next());
				providerUsersCriteriaSql.append(values.hasNext() ? "?, " : "?");
			}
			providerUsersCriteriaSql.append("))");
		}
		MultiValueMap<String, Connection<?>> connectionsForUsers = new LinkedMultiValueMap<String, Connection<?>>();
		if (providerUsersCriteriaSql.length() == 0) {
			return connectionsForUsers;
		}
		// Parenthesize the whole disjunction so "userId = ?" constrains every provider group. AND binds tighter
		// than OR, so without the parentheses groups after the first would match rows of any user.
		final String sql = selectFromUserConnection() + " where userId = ? and (" + providerUsersCriteriaSql
				+ ") order by providerId, rank";
		final String[] selectionArgs = args.toArray(new String[0]);
		List<Connection<?>> resultList = queryForConnections(sql, selectionArgs);
		for (Connection<?> connection : resultList) {
			String providerId = connection.getKey().getProviderId();
			List<String> userIds = providerUsers.get(providerId);
			List<Connection<?>> connections = connectionsForUsers.get(providerId);
			if (connections == null) {
				// pre-fill with nulls so each connection can be placed at the index of its requested provider user id
				connections = new ArrayList<Connection<?>>(userIds.size());
				for (int i = 0; i < userIds.size(); i++) {
					connections.add(null);
				}
				connectionsForUsers.put(providerId, connections);
			}
			String providerUserId = connection.getKey().getProviderUserId();
			int connectionIndex = userIds.indexOf(providerUserId);
			connections.set(connectionIndex, connection);
		}
		return connectionsForUsers;
	}

	/**
	 * Returns this user's connection identified by the given key.
	 * @param connectionKey provider id and provider user id of the connection
	 * @throws NoSuchConnectionException if no such connection is stored
	 */
	public Connection<?> getConnection(ConnectionKey connectionKey) {
		final String sql = selectFromUserConnection()
				+ " where userId = ? and providerId = ? and providerUserId = ? order by rank";
		final String[] selectionArgs = { userId, connectionKey.getProviderId(), connectionKey.getProviderUserId() };
		Connection<?> connection = queryForConnection(sql, selectionArgs);
		if (connection == null) {
			throw new NoSuchConnectionException(connectionKey);
		}
		return connection;
	}

	/**
	 * Returns this user's connection to the given provider user of the API type's provider.
	 * @throws NoSuchConnectionException if no such connection is stored
	 */
	@SuppressWarnings("unchecked")
	public <A> Connection<A> getConnection(Class<A> apiType, String providerUserId) {
		String providerId = getProviderId(apiType);
		return (Connection<A>) getConnection(new ConnectionKey(providerId, providerUserId));
	}

	/**
	 * Returns this user's rank-1 connection to the API type's provider.
	 * @throws NotConnectedException if there is no rank-1 connection
	 */
	@SuppressWarnings("unchecked")
	public <A> Connection<A> getPrimaryConnection(Class<A> apiType) {
		String providerId = getProviderId(apiType);
		Connection<A> connection = (Connection<A>) findPrimaryConnection(providerId);
		if (connection == null) {
			throw new NotConnectedException(providerId);
		}
		return connection;
	}

	/**
	 * Returns this user's rank-1 connection to the API type's provider, or {@code null} if there is none.
	 */
	@SuppressWarnings("unchecked")
	public <A> Connection<A> findPrimaryConnection(Class<A> apiType) {
		String providerId = getProviderId(apiType);
		return (Connection<A>) findPrimaryConnection(providerId);
	}

	/**
	 * Stores a new connection for this user, ranked after the user's existing connections to the same provider.
	 * @throws DuplicateConnectionException if the insert violates a table constraint
	 */
	public void addConnection(Connection<?> connection) {
		ConnectionData data = connection.createData();
		SQLiteDatabase db = repositoryHelper.getWritableDatabase();
		try {
			int rank = nextRank(db, data.getProviderId());
			ContentValues values = new ContentValues();
			values.put("userId", userId);
			values.put("providerId", data.getProviderId());
			values.put("providerUserId", data.getProviderUserId());
			values.put("rank", rank);
			values.put("displayName", data.getDisplayName());
			values.put("profileUrl", data.getProfileUrl());
			values.put("imageUrl", data.getImageUrl());
			values.put("accessToken", encrypt(data.getAccessToken()));
			values.put("secret", encrypt(data.getSecret()));
			values.put("refreshToken", encrypt(data.getRefreshToken()));
			values.put("expireTime", data.getExpireTime());
			db.insertOrThrow("UserConnection", null, values);
		}
		catch (SQLiteConstraintException e) {
			throw new DuplicateConnectionException(connection.getKey());
		}
		finally {
			// close even when the insert fails, so the duplicate path no longer leaks the database
			db.close();
		}
	}

	/**
	 * Overwrites the profile and credential columns of this user's stored row for the connection's key.
	 */
	public void updateConnection(Connection<?> connection) {
		ConnectionData data = connection.createData();
		SQLiteDatabase db = repositoryHelper.getWritableDatabase();
		try {
			ContentValues values = new ContentValues();
			values.put("displayName", data.getDisplayName());
			values.put("profileUrl", data.getProfileUrl());
			values.put("imageUrl", data.getImageUrl());
			values.put("accessToken", encrypt(data.getAccessToken()));
			values.put("secret", encrypt(data.getSecret()));
			values.put("refreshToken", encrypt(data.getRefreshToken()));
			values.put("expireTime", data.getExpireTime());
			final String whereClause = "userId = ? and providerId = ? and providerUserId = ?";
			final String[] whereArgs = { userId, data.getProviderId(), data.getProviderUserId() };
			db.update("UserConnection", values, whereClause, whereArgs);
		}
		finally {
			db.close();
		}
	}

	/**
	 * Deletes all of this user's connections to the given provider.
	 * @param providerId the provider id
	 */
	public void removeConnections(String providerId) {
		SQLiteDatabase db = repositoryHelper.getWritableDatabase();
		try {
			final String whereClause = "userId = ? and providerId = ?";
			final String[] whereArgs = { userId, providerId };
			db.delete("UserConnection", whereClause, whereArgs);
		}
		finally {
			db.close();
		}
	}

	/**
	 * Deletes this user's connection identified by the given key.
	 * @param connectionKey provider id and provider user id of the connection
	 */
	public void removeConnection(ConnectionKey connectionKey) {
		SQLiteDatabase db = repositoryHelper.getWritableDatabase();
		try {
			final String whereClause = "userId = ? and providerId = ? and providerUserId = ?";
			final String[] whereArgs = { userId, connectionKey.getProviderId(), connectionKey.getProviderUserId() };
			db.delete("UserConnection", whereClause, whereArgs);
		}
		finally {
			db.close();
		}
	}

	// internal helpers

	/** Base SELECT listing every column that {@link #mapConnectionData(Cursor)} reads. */
	private String selectFromUserConnection() {
		return "select userId, providerId, providerUserId, displayName, profileUrl, imageUrl, accessToken, secret, refreshToken, expireTime from UserConnection";
	}

	/** Computes the rank for a new connection: one past this user's current max for the provider, or 1. */
	private int nextRank(SQLiteDatabase db, String providerId) {
		final String sql = "select coalesce(max(rank) + 1, 1) as rank from UserConnection where userId = ? and providerId = ?";
		final String[] selectionArgs = { userId, providerId };
		Cursor c = db.rawQuery(sql, selectionArgs);
		try {
			// an aggregate query always yields exactly one row
			c.moveToFirst();
			return c.getInt(c.getColumnIndex("rank"));
		}
		finally {
			c.close();
		}
	}

	/** Returns this user's rank-1 connection to the provider, or {@code null} if there is none. */
	private Connection<?> findPrimaryConnection(String providerId) {
		final String sql = selectFromUserConnection() + " where userId = ? and providerId = ? and rank = 1";
		final String[] selectionArgs = { userId, providerId };
		List<Connection<?>> connections = queryForConnections(sql, selectionArgs);
		return connections.isEmpty() ? null : connections.get(0);
	}

	private <A> String getProviderId(Class<A> apiType) {
		return connectionFactoryLocator.getConnectionFactory(apiType).getProviderId();
	}

	/** Encrypts {@code text}, passing {@code null} through unchanged. */
	private String encrypt(String text) {
		return text != null ? textEncryptor.encrypt(text) : text;
	}

	/** Decrypts {@code encryptedText}, passing {@code null} through unchanged. */
	private String decrypt(String encryptedText) {
		return encryptedText != null ? textEncryptor.decrypt(encryptedText) : encryptedText;
	}

	/** Maps a stored expire time of 0 to {@code null}. */
	private Long expireTime(long expireTime) {
		return expireTime == 0 ? null : expireTime;
	}

	/** Runs {@code sql} and maps only the first row, returning {@code null} when there are no rows. */
	private Connection<?> queryForConnection(final String sql, final String[] selectionArgs) {
		SQLiteDatabase db = repositoryHelper.getReadableDatabase();
		Cursor c = null;
		try {
			c = db.rawQuery(sql, selectionArgs);
			return c.moveToFirst() ? mapConnectionRow(c) : null;
		}
		finally {
			// c is null if rawQuery threw; guarding avoids masking that exception with an NPE
			if (c != null) {
				c.close();
			}
			db.close();
		}
	}

	/** Runs {@code sql} and maps every row, in cursor order. */
	private List<Connection<?>> queryForConnections(final String sql, final String[] selectionArgs) {
		SQLiteDatabase db = repositoryHelper.getReadableDatabase();
		Cursor c = null;
		List<Connection<?>> connections = new ArrayList<Connection<?>>();
		try {
			c = db.rawQuery(sql, selectionArgs);
			// a fresh cursor is positioned before the first row, so moveToNext visits every row
			while (c.moveToNext()) {
				connections.add(mapConnectionRow(c));
			}
		}
		finally {
			if (c != null) {
				c.close();
			}
			db.close();
		}
		return connections;
	}

	/** Rebuilds a {@link Connection} from the cursor's current row via the provider's connection factory. */
	private Connection<?> mapConnectionRow(Cursor c) {
		ConnectionData connectionData = mapConnectionData(c);
		ConnectionFactory<?> connectionFactory = connectionFactoryLocator.getConnectionFactory(connectionData.getProviderId());
		return connectionFactory.createConnection(connectionData);
	}

	/** Reads the cursor's current row into {@link ConnectionData}, decrypting the token columns. */
	private ConnectionData mapConnectionData(Cursor c) {
		return new ConnectionData(c.getString(c.getColumnIndex("providerId")),
				c.getString(c.getColumnIndex("providerUserId")),
				c.getString(c.getColumnIndex("displayName")),
				c.getString(c.getColumnIndex("profileUrl")),
				c.getString(c.getColumnIndex("imageUrl")),
				decrypt(c.getString(c.getColumnIndex("accessToken"))),
				decrypt(c.getString(c.getColumnIndex("secret"))),
				decrypt(c.getString(c.getColumnIndex("refreshToken"))),
				expireTime(c.getLong(c.getColumnIndex("expireTime"))));
	}

}