package org.rakam.postgresql.analysis; import com.google.common.base.Throwables; import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; import com.google.common.util.concurrent.UncheckedExecutionException; import org.rakam.analysis.ApiKeyService; import org.rakam.analysis.JDBCPoolDataSource; import org.rakam.util.CryptUtil; import org.rakam.util.RakamException; import javax.annotation.PostConstruct; import java.net.URI; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Set; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import static io.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN; import static java.lang.String.format; import static org.rakam.analysis.ApiKeyService.AccessKeyType.*; public class JDBCApiKeyService implements ApiKeyService { private final LoadingCache<String, List<Set<String>>> apiKeyCache; protected final JDBCPoolDataSource connectionPool; private final LoadingCache<ApiKey, String> apiKeyReverseCache; public JDBCApiKeyService(JDBCPoolDataSource connectionPool) { this.connectionPool = connectionPool; apiKeyCache = CacheBuilder.newBuilder().expireAfterWrite(1, TimeUnit.MINUTES).build(new CacheLoader<String, List<Set<String>>>() { @Override public List<Set<String>> load(String project) throws Exception { try (Connection conn = connectionPool.getConnection()) { return getKeys(conn, project); } } }); apiKeyReverseCache = CacheBuilder.newBuilder().build(new CacheLoader<ApiKey, String>() { @Override public String load(ApiKey apiKey) throws Exception { try (Connection conn = connectionPool.getConnection()) { PreparedStatement ps = conn.prepareStatement(format("SELECT lower(project) FROM api_key WHERE %s = ?", apiKey.type.name())); ps.setString(1, apiKey.key); ResultSet resultSet = ps.executeQuery(); if (!resultSet.next()) { throw new RakamException(apiKey.type.getKey() + " is invalid", FORBIDDEN); } return resultSet.getString(1); } catch (SQLException e) { throw Throwables.propagate(e); } } }); } @PostConstruct public void setup() { try (Connection connection = connectionPool.getConnection()) { Statement statement = connection.createStatement(); URI uri = URI.create(connectionPool.getConfig().getUrl().replaceAll("^jdbc:", "")); String primaryKey; if(uri.getScheme().equals("mysql")) { primaryKey = " id MEDIUMINT NOT NULL AUTO_INCREMENT,\n"; } else if(uri.getScheme().equals("postgresql")) { primaryKey = " id SERIAL,\n"; } else { throw new IllegalStateException(); } statement.execute("CREATE TABLE IF NOT EXISTS api_key (" + primaryKey + " project VARCHAR(255) NOT NULL,\n" + " read_key VARCHAR(255) NOT NULL,\n" + " write_key VARCHAR(255) NOT NULL,\n" + " master_key VARCHAR(255) NOT NULL,\n" + " created_at TIMESTAMP default current_timestamp NOT NULL," + "PRIMARY KEY (id)\n" + " )"); } catch (SQLException e) { Throwables.propagate(e); } } @Override public ProjectApiKeys createApiKeys(String project) { String masterKey = CryptUtil.generateRandomKey(64); String readKey = CryptUtil.generateRandomKey(64); String writeKey = CryptUtil.generateRandomKey(64); try (Connection connection = connectionPool.getConnection()) { PreparedStatement ps = connection.prepareStatement("INSERT INTO api_key " + "(master_key, read_key, write_key, project) VALUES (?, ?, ?, ?)", Statement.RETURN_GENERATED_KEYS); ps.setString(1, masterKey); ps.setString(2, readKey); ps.setString(3, writeKey); ps.setString(4, project); ps.executeUpdate(); final ResultSet generatedKeys = ps.getGeneratedKeys(); generatedKeys.next(); } catch (SQLException e) { throw Throwables.propagate(e); } return ProjectApiKeys.create(masterKey, readKey, writeKey); } @Override public String getProjectOfApiKey(String apiKey, AccessKeyType type) { if (type == null) { throw new IllegalStateException(); } if (apiKey == null) { throw new RakamException(type.getKey() + " is missing", FORBIDDEN); } try { return apiKeyReverseCache.getUnchecked(new ApiKey(apiKey, type)); } catch (UncheckedExecutionException e) { throw Throwables.propagate(e.getCause()); } } @Override public void revokeApiKeys(String project, String masterKey) { try (Connection conn = connectionPool.getConnection()) { PreparedStatement ps = conn.prepareStatement("DELETE FROM api_key WHERE project = ? AND master_key = ?"); ps.setString(1, project); ps.setString(2, masterKey); ps.execute(); } catch (SQLException e) { throw Throwables.propagate(e); } } @Override public void revokeAllKeys(String project) { try (Connection conn = connectionPool.getConnection()) { PreparedStatement ps = conn.prepareStatement("DELETE FROM api_key WHERE project = ?"); ps.setString(1, project); ps.execute(); } catch (SQLException e) { throw Throwables.propagate(e); } } private List<Set<String>> getKeys(Connection conn, String project) throws SQLException { Set<String> masterKeyList = new HashSet<>(); Set<String> readKeyList = new HashSet<>(); Set<String> writeKeyList = new HashSet<>(); Set<String>[] keys = Arrays.stream(AccessKeyType.values()).map(key -> new HashSet<String>()).toArray(Set[]::new); PreparedStatement ps = conn.prepareStatement("SELECT master_key, read_key, write_key from api_key WHERE project = ?"); ps.setString(1, project); ResultSet resultSet = ps.executeQuery(); while (resultSet.next()) { String apiKey; apiKey = resultSet.getString(1); if (apiKey != null) { masterKeyList.add(apiKey); } apiKey = resultSet.getString(2); if (apiKey != null) { readKeyList.add(apiKey); } apiKey = resultSet.getString(3); if (apiKey != null) { writeKeyList.add(apiKey); } } keys[MASTER_KEY.ordinal()] = Collections.unmodifiableSet(masterKeyList); keys[READ_KEY.ordinal()] = Collections.unmodifiableSet(readKeyList); keys[WRITE_KEY.ordinal()] = Collections.unmodifiableSet(writeKeyList); return Collections.unmodifiableList(Arrays.asList(keys)); } public void clearCache() { apiKeyCache.cleanUp(); } public static final class ApiKey { public final String key; public final AccessKeyType type; public ApiKey(String key, AccessKeyType type) { this.key = key; this.type = type; } @Override public boolean equals(Object o) { if (this == o) { return true; } if (!(o instanceof ApiKey)) { return false; } ApiKey apiKey = (ApiKey) o; if (!key.equals(apiKey.key)) { return false; } return type == apiKey.type; } @Override public int hashCode() { int result = key.hashCode(); result = 31 * result + type.hashCode(); return result; } } }