/*
* Copyright (C) 2015 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package keywhiz.service.daos;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.time.Instant;
import java.time.OffsetDateTime;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import javax.inject.Inject;
import keywhiz.api.ApiDate;
import keywhiz.api.GroupDetailResponse;
import keywhiz.api.SecretDetailResponse;
import keywhiz.api.model.Client;
import keywhiz.api.model.Group;
import keywhiz.api.model.SanitizedSecret;
import keywhiz.api.model.Secret;
import keywhiz.api.model.SecretContent;
import keywhiz.api.model.SecretSeries;
import keywhiz.api.model.SecretSeriesAndContent;
import keywhiz.jooq.tables.records.SecretsRecord;
import keywhiz.log.AuditLog;
import keywhiz.log.Event;
import keywhiz.log.EventTag;
import keywhiz.service.config.Readonly;
import keywhiz.service.daos.ClientDAO.ClientDAOFactory;
import keywhiz.service.daos.GroupDAO.GroupDAOFactory;
import keywhiz.service.daos.SecretContentDAO.SecretContentDAOFactory;
import keywhiz.service.daos.SecretSeriesDAO.SecretSeriesDAOFactory;
import org.jooq.Configuration;
import org.jooq.DSLContext;
import org.jooq.Record;
import org.jooq.SelectQuery;
import org.jooq.impl.DSL;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static java.lang.String.format;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toSet;
import static keywhiz.jooq.tables.Accessgrants.ACCESSGRANTS;
import static keywhiz.jooq.tables.Clients.CLIENTS;
import static keywhiz.jooq.tables.Groups.GROUPS;
import static keywhiz.jooq.tables.Memberships.MEMBERSHIPS;
import static keywhiz.jooq.tables.Secrets.SECRETS;
import static keywhiz.jooq.tables.SecretsContent.SECRETS_CONTENT;
public class AclDAO {
private static final Logger logger = LoggerFactory.getLogger(AclDAO.class);
private final DSLContext dslContext;
private final ClientDAOFactory clientDAOFactory;
private final GroupDAOFactory groupDAOFactory;
private final SecretContentDAOFactory secretContentDAOFactory;
private final SecretSeriesDAOFactory secretSeriesDAOFactory;
private final ClientMapper clientMapper;
private final GroupMapper groupMapper;
private final SecretSeriesMapper secretSeriesMapper;
private final SecretContentMapper secretContentMapper;
private AclDAO(DSLContext dslContext, ClientDAOFactory clientDAOFactory, GroupDAOFactory groupDAOFactory,
SecretContentDAOFactory secretContentDAOFactory, SecretSeriesDAOFactory secretSeriesDAOFactory,
ClientMapper clientMapper, GroupMapper groupMapper, SecretSeriesMapper secretSeriesMapper,
SecretContentMapper secretContentMapper) {
this.dslContext = dslContext;
this.clientDAOFactory = clientDAOFactory;
this.groupDAOFactory = groupDAOFactory;
this.secretContentDAOFactory = secretContentDAOFactory;
this.secretSeriesDAOFactory = secretSeriesDAOFactory;
this.clientMapper = clientMapper;
this.groupMapper = groupMapper;
this.secretSeriesMapper = secretSeriesMapper;
this.secretContentMapper = secretContentMapper;
}
public void findAndAllowAccess(long secretId, long groupId, AuditLog auditLog, String user, Map<String, String> extraInfo) {
dslContext.transaction(configuration -> {
GroupDAO groupDAO = groupDAOFactory.using(configuration);
SecretSeriesDAO secretSeriesDAO = secretSeriesDAOFactory.using(configuration);
Optional<Group> group = groupDAO.getGroupById(groupId);
if (!group.isPresent()) {
logger.info("Failure to allow access groupId {}, secretId {}: groupId not found.", groupId,
secretId);
throw new IllegalStateException(format("GroupId %d doesn't exist.", groupId));
}
Optional<SecretSeries> secret = secretSeriesDAO.getSecretSeriesById(secretId);
if (!secret.isPresent()) {
logger.info("Failure to allow access groupId {}, secretId {}: secretId not found.", groupId,
secretId);
throw new IllegalStateException(format("SecretId %d doesn't exist.", secretId));
}
allowAccess(configuration, secretId, groupId);
extraInfo.put("group", group.get().getName());
extraInfo.put("secret added", secret.get().name());
auditLog.recordEvent(new Event(Instant.now(), EventTag.CHANGEACL_GROUP_SECRET, user, group.get().getName(), extraInfo));
});
}
public void findAndRevokeAccess(long secretId, long groupId, AuditLog auditLog, String user, Map<String, String> extraInfo) {
dslContext.transaction(configuration -> {
GroupDAO groupDAO = groupDAOFactory.using(configuration);
SecretSeriesDAO secretSeriesDAO = secretSeriesDAOFactory.using(configuration);
Optional<Group> group = groupDAO.getGroupById(groupId);
if (!group.isPresent()) {
logger.info("Failure to revoke access groupId {}, secretId {}: groupId not found.", groupId,
secretId);
throw new IllegalStateException(format("GroupId %d doesn't exist.", groupId));
}
Optional<SecretSeries> secret = secretSeriesDAO.getSecretSeriesById(secretId);
if (!secret.isPresent()) {
logger.info("Failure to revoke access groupId {}, secretId {}: secretId not found.",
groupId, secretId);
throw new IllegalStateException(format("SecretId %d doesn't exist.", secretId));
}
revokeAccess(configuration, secretId, groupId);
extraInfo.put("group", group.get().getName());
extraInfo.put("secret removed", secret.get().name());
auditLog.recordEvent(new Event(Instant.now(), EventTag.CHANGEACL_GROUP_SECRET, user, group.get().getName(), extraInfo));
});
}
public void findAndEnrollClient(long clientId, long groupId, AuditLog auditLog, String user, Map<String, String> extraInfo) {
dslContext.transaction(configuration -> {
ClientDAO clientDAO = clientDAOFactory.using(configuration);
GroupDAO groupDAO = groupDAOFactory.using(configuration);
Optional<Client> client = clientDAO.getClientById(clientId);
if (!client.isPresent()) {
logger.info("Failure to enroll membership clientId {}, groupId {}: clientId not found.",
clientId, groupId);
throw new IllegalStateException(format("ClientId %d doesn't exist.", clientId));
}
Optional<Group> group = groupDAO.getGroupById(groupId);
if (!group.isPresent()) {
logger.info("Failure to enroll membership clientId {}, groupId {}: groupId not found.",
clientId, groupId);
throw new IllegalStateException(format("GroupId %d doesn't exist.", groupId));
}
enrollClient(configuration, clientId, groupId);
extraInfo.put("group", group.get().getName());
extraInfo.put("client added", client.get().getName());
auditLog.recordEvent(new Event(Instant.now(), EventTag.CHANGEACL_GROUP_CLIENT, user, group.get().getName(), extraInfo));
});
}
public void findAndEvictClient(long clientId, long groupId, AuditLog auditLog, String user, Map<String, String> extraInfo) {
dslContext.transaction(configuration -> {
ClientDAO clientDAO = clientDAOFactory.using(configuration);
GroupDAO groupDAO = groupDAOFactory.using(configuration);
Optional<Client> client = clientDAO.getClientById(clientId);
if (!client.isPresent()) {
logger.info("Failure to evict membership clientId {}, groupId {}: clientId not found.",
clientId, groupId);
throw new IllegalStateException(format("ClientId %d doesn't exist.", clientId));
}
Optional<Group> group = groupDAO.getGroupById(groupId);
if (!group.isPresent()) {
logger.info("Failure to evict membership clientId {}, groupId {}: groupId not found.",
clientId, groupId);
throw new IllegalStateException(format("GroupId %d doesn't exist.", groupId));
}
evictClient(configuration, clientId, groupId);
extraInfo.put("group", group.get().getName());
extraInfo.put("client removed", client.get().getName());
auditLog.recordEvent(new Event(Instant.now(), EventTag.CHANGEACL_GROUP_CLIENT, user, group.get().getName(), extraInfo));
});
}
public ImmutableSet<SanitizedSecret> getSanitizedSecretsFor(Group group) {
checkNotNull(group);
ImmutableSet.Builder<SanitizedSecret> set = ImmutableSet.builder();
return dslContext.transactionResult(configuration -> {
SecretContentDAO secretContentDAO = secretContentDAOFactory.using(configuration);
for (SecretSeries series : getSecretSeriesFor(configuration, group)) {
SecretContent content = secretContentDAO.getSecretContentById(series.currentVersion().get()).get();
SecretSeriesAndContent seriesAndContent = SecretSeriesAndContent.of(series, content);
set.add(SanitizedSecret.fromSecretSeriesAndContent(seriesAndContent));
}
return set.build();
});
}
public Set<Group> getGroupsFor(Secret secret) {
List<Group> r = dslContext
.select(GROUPS.fields())
.from(GROUPS)
.join(ACCESSGRANTS).on(GROUPS.ID.eq(ACCESSGRANTS.GROUPID))
.join(SECRETS).on(ACCESSGRANTS.SECRETID.eq(SECRETS.ID))
.where(SECRETS.NAME.eq(secret.getName()))
.fetchInto(GROUPS)
.map(groupMapper);
return new HashSet<>(r);
}
public Set<Group> getGroupsFor(Client client) {
List<Group> r = dslContext
.select(GROUPS.fields())
.from(GROUPS)
.join(MEMBERSHIPS).on(GROUPS.ID.eq(MEMBERSHIPS.GROUPID))
.join(CLIENTS).on(CLIENTS.ID.eq(MEMBERSHIPS.CLIENTID))
.where(CLIENTS.NAME.eq(client.getName()))
.fetchInto(GROUPS)
.map(groupMapper);
return new HashSet<>(r);
}
public Set<Client> getClientsFor(Group group) {
List<Client> r = dslContext
.select(CLIENTS.fields())
.from(CLIENTS)
.join(MEMBERSHIPS).on(CLIENTS.ID.eq(MEMBERSHIPS.CLIENTID))
.join(GROUPS).on(GROUPS.ID.eq(MEMBERSHIPS.GROUPID))
.where(GROUPS.NAME.eq(group.getName()))
.fetchInto(CLIENTS)
.map(clientMapper);
return new HashSet<>(r);
}
public ImmutableSet<SanitizedSecret> getSanitizedSecretsFor(Client client) {
checkNotNull(client);
ImmutableSet.Builder<SanitizedSecret> sanitizedSet = ImmutableSet.builder();
SelectQuery<Record> query = dslContext.select(SECRETS.fields())
.from(SECRETS)
.join(ACCESSGRANTS).on(SECRETS.ID.eq(ACCESSGRANTS.SECRETID))
.join(MEMBERSHIPS).on(ACCESSGRANTS.GROUPID.eq(MEMBERSHIPS.GROUPID))
.join(CLIENTS).on(CLIENTS.ID.eq(MEMBERSHIPS.CLIENTID))
.join(SECRETS_CONTENT).on(SECRETS_CONTENT.ID.eq(SECRETS.CURRENT))
.where(CLIENTS.NAME.eq(client.getName()).and(SECRETS.CURRENT.isNotNull()))
.getQuery();
query.addSelect(SECRETS_CONTENT.CONTENT_HMAC);
query.addSelect(SECRETS_CONTENT.CREATEDAT);
query.addSelect(SECRETS_CONTENT.CREATEDBY);
query.addSelect(SECRETS_CONTENT.UPDATEDAT);
query.addSelect(SECRETS_CONTENT.UPDATEDBY);
query.addSelect(SECRETS_CONTENT.METADATA);
query.addSelect(SECRETS_CONTENT.EXPIRY);
query.fetch()
.map(row -> {
SecretSeries series = secretSeriesMapper.map(row.into(SECRETS));
return SanitizedSecret.of(
series.id(),
series.name(),
series.description(),
row.getValue(SECRETS_CONTENT.CONTENT_HMAC),
new ApiDate(row.getValue(SECRETS_CONTENT.CREATEDAT)),
row.getValue(SECRETS_CONTENT.CREATEDBY),
new ApiDate(row.getValue(SECRETS_CONTENT.UPDATEDAT)),
row.getValue(SECRETS_CONTENT.UPDATEDBY),
secretContentMapper.tryToReadMapFromMetadata(row.getValue(SECRETS_CONTENT.METADATA)),
series.type().orElse(null),
series.generationOptions(),
row.getValue(SECRETS_CONTENT.EXPIRY),
series.currentVersion().orElse(null));
})
.forEach(row -> sanitizedSet.add(row));
return sanitizedSet.build();
}
public Set<Client> getClientsFor(Secret secret) {
List<Client> r = dslContext
.select(CLIENTS.fields())
.from(CLIENTS)
.join(MEMBERSHIPS).on(CLIENTS.ID.eq(MEMBERSHIPS.CLIENTID))
.join(ACCESSGRANTS).on(MEMBERSHIPS.GROUPID.eq(ACCESSGRANTS.GROUPID))
.join(SECRETS).on(SECRETS.ID.eq(ACCESSGRANTS.SECRETID))
.where(SECRETS.NAME.eq(secret.getName()))
.fetchInto(CLIENTS)
.map(clientMapper);
return new HashSet<>(r);
}
public Optional<SanitizedSecret> getSanitizedSecretFor(Client client, String secretName) {
checkNotNull(client);
checkArgument(!secretName.isEmpty());
SelectQuery<Record> query = dslContext.select(SECRETS.fields())
.from(SECRETS)
.join(ACCESSGRANTS).on(SECRETS.ID.eq(ACCESSGRANTS.SECRETID))
.join(MEMBERSHIPS).on(ACCESSGRANTS.GROUPID.eq(MEMBERSHIPS.GROUPID))
.join(CLIENTS).on(CLIENTS.ID.eq(MEMBERSHIPS.CLIENTID))
.join(SECRETS_CONTENT).on(SECRETS_CONTENT.ID.eq(SECRETS.CURRENT))
.where(CLIENTS.NAME.eq(client.getName())
.and(SECRETS.CURRENT.isNotNull())
.and(SECRETS.NAME.eq(secretName)))
.limit(1)
.getQuery();
query.addSelect(SECRETS_CONTENT.CONTENT_HMAC);
query.addSelect(SECRETS_CONTENT.CREATEDAT);
query.addSelect(SECRETS_CONTENT.CREATEDBY);
query.addSelect(SECRETS_CONTENT.UPDATEDAT);
query.addSelect(SECRETS_CONTENT.UPDATEDBY);
query.addSelect(SECRETS_CONTENT.METADATA);
query.addSelect(SECRETS_CONTENT.EXPIRY);
return Optional.ofNullable(query.fetchOne())
.map(row -> {
SecretSeries series = secretSeriesMapper.map(row.into(SECRETS));
return SanitizedSecret.of(
series.id(),
series.name(),
series.description(),
row.getValue(SECRETS_CONTENT.CONTENT_HMAC),
new ApiDate(row.getValue(SECRETS_CONTENT.CREATEDAT)),
row.getValue(SECRETS_CONTENT.CREATEDBY),
new ApiDate(row.getValue(SECRETS_CONTENT.UPDATEDAT)),
row.getValue(SECRETS_CONTENT.UPDATEDBY),
secretContentMapper.tryToReadMapFromMetadata(row.getValue(SECRETS_CONTENT.METADATA)),
series.type().orElse(null),
series.generationOptions(),
row.getValue(SECRETS_CONTENT.EXPIRY),
series.currentVersion().orElse(null));
});
}
public Map<Long, List<Group>> getGroupsForSecrets(Set<Long> secretIdList) {
Map<Long, Group> groupMap = dslContext.select().from(GROUPS)
.join(ACCESSGRANTS).on(ACCESSGRANTS.GROUPID.eq(GROUPS.ID))
.join(SECRETS).on(ACCESSGRANTS.SECRETID.eq(SECRETS.ID))
.where(SECRETS.ID.in(secretIdList))
.fetchInto(GROUPS).map(groupMapper).stream().collect(Collectors.toMap(Group::getId, g -> g, (g1, g2) -> g1));
Map<Long, List<Long>> secretsIdGroupsIdMap = dslContext.select().from(GROUPS)
.join(ACCESSGRANTS).on(ACCESSGRANTS.GROUPID.eq(GROUPS.ID))
.join(SECRETS).on(ACCESSGRANTS.SECRETID.eq(SECRETS.ID))
.where(SECRETS.ID.in(secretIdList))
.fetch().intoGroups(SECRETS.ID, GROUPS.ID);
ImmutableMap.Builder<Long, List<Group>> builder = ImmutableMap.builder();
for (Map.Entry<Long, List<Long>> entry : secretsIdGroupsIdMap.entrySet()) {
List<Group> groupList = entry.getValue().stream().map(groupMap::get).collect(toList());
builder.put(entry.getKey(), groupList);
}
return builder.build();
}
protected void allowAccess(Configuration configuration, long secretId, long groupId) {
long now = OffsetDateTime.now().toEpochSecond();
boolean assigned = 0 < DSL.using(configuration)
.fetchCount(ACCESSGRANTS,
ACCESSGRANTS.SECRETID.eq(secretId).and(
ACCESSGRANTS.GROUPID.eq(groupId)));
if (assigned) {
return;
}
DSL.using(configuration)
.insertInto(ACCESSGRANTS)
.set(ACCESSGRANTS.SECRETID, secretId)
.set(ACCESSGRANTS.GROUPID, groupId)
.set(ACCESSGRANTS.CREATEDAT, now)
.set(ACCESSGRANTS.UPDATEDAT, now)
.execute();
}
protected void revokeAccess(Configuration configuration, long secretId, long groupId) {
DSL.using(configuration)
.delete(ACCESSGRANTS)
.where(ACCESSGRANTS.SECRETID.eq(secretId)
.and(ACCESSGRANTS.GROUPID.eq(groupId)))
.execute();
}
protected void enrollClient(Configuration configuration, long clientId, long groupId) {
long now = OffsetDateTime.now().toEpochSecond();
boolean enrolled = 0 < DSL.using(configuration)
.fetchCount(MEMBERSHIPS,
MEMBERSHIPS.GROUPID.eq(groupId).and(
MEMBERSHIPS.CLIENTID.eq(clientId)));
if (enrolled) {
return;
}
DSL.using(configuration)
.insertInto(MEMBERSHIPS)
.set(MEMBERSHIPS.GROUPID, groupId)
.set(MEMBERSHIPS.CLIENTID, clientId)
.set(MEMBERSHIPS.CREATEDAT, now)
.set(MEMBERSHIPS.UPDATEDAT, now)
.execute();
}
protected void evictClient(Configuration configuration, long clientId, long groupId) {
DSL.using(configuration)
.delete(MEMBERSHIPS)
.where(MEMBERSHIPS.CLIENTID.eq(clientId)
.and(MEMBERSHIPS.GROUPID.eq(groupId)))
.execute();
}
protected ImmutableSet<SecretSeries> getSecretSeriesFor(Configuration configuration, Group group) {
List<SecretSeries> r = DSL.using(configuration)
.select(SECRETS.fields())
.from(SECRETS)
.join(ACCESSGRANTS).on(SECRETS.ID.eq(ACCESSGRANTS.SECRETID))
.join(GROUPS).on(GROUPS.ID.eq(ACCESSGRANTS.GROUPID))
.where(GROUPS.NAME.eq(group.getName()).and(SECRETS.CURRENT.isNotNull()))
.fetchInto(SECRETS)
.map(secretSeriesMapper);
return ImmutableSet.copyOf(r);
}
/**
* @param client client to access secrets
* @param secretName name of SecretSeries
* @return Optional.absent() when secret unauthorized or not found.
* The query doesn't distinguish between these cases. If result absent, a followup call on clients
* table should be used to determine the exception.
*/
protected Optional<SecretSeries> getSecretSeriesFor(Configuration configuration, Client client, String secretName) {
// TODO: We need to set limit(1) because we are using joins. We should probably change the join type.
SecretsRecord r = DSL.using(configuration)
.select(SECRETS.fields())
.from(SECRETS)
.join(ACCESSGRANTS).on(SECRETS.ID.eq(ACCESSGRANTS.SECRETID))
.join(MEMBERSHIPS).on(ACCESSGRANTS.GROUPID.eq(MEMBERSHIPS.GROUPID))
.join(CLIENTS).on(CLIENTS.ID.eq(MEMBERSHIPS.CLIENTID))
.where(SECRETS.NAME.eq(secretName).and(CLIENTS.NAME.eq(client.getName())).and(SECRETS.CURRENT.isNotNull()))
.limit(1)
.fetchOneInto(SECRETS);
return Optional.ofNullable(r).map(secretSeriesMapper::map);
}
public static class AclDAOFactory implements DAOFactory<AclDAO> {
private final DSLContext jooq;
private final DSLContext readonlyJooq;
private final ClientDAOFactory clientDAOFactory;
private final GroupDAOFactory groupDAOFactory;
private final SecretContentDAOFactory secretContentDAOFactory;
private final SecretSeriesDAOFactory secretSeriesDAOFactory;
private final ClientMapper clientMapper;
private final GroupMapper groupMapper;
private final SecretSeriesMapper secretSeriesMapper;
private final SecretContentMapper secretContentMapper;
@Inject public AclDAOFactory(DSLContext jooq, @Readonly DSLContext readonlyJooq, ClientDAOFactory clientDAOFactory,
GroupDAOFactory groupDAOFactory, SecretContentDAOFactory secretContentDAOFactory,
SecretSeriesDAOFactory secretSeriesDAOFactory, ClientMapper clientMapper,
GroupMapper groupMapper, SecretSeriesMapper secretSeriesMapper,
SecretContentMapper secretContentMapper) {
this.jooq = jooq;
this.readonlyJooq = readonlyJooq;
this.clientDAOFactory = clientDAOFactory;
this.groupDAOFactory = groupDAOFactory;
this.secretContentDAOFactory = secretContentDAOFactory;
this.secretSeriesDAOFactory = secretSeriesDAOFactory;
this.clientMapper = clientMapper;
this.groupMapper = groupMapper;
this.secretSeriesMapper = secretSeriesMapper;
this.secretContentMapper = secretContentMapper;
}
@Override public AclDAO readwrite() {
return new AclDAO(jooq, clientDAOFactory, groupDAOFactory, secretContentDAOFactory,
secretSeriesDAOFactory, clientMapper, groupMapper, secretSeriesMapper, secretContentMapper);
}
@Override public AclDAO readonly() {
return new AclDAO(readonlyJooq, clientDAOFactory, groupDAOFactory, secretContentDAOFactory,
secretSeriesDAOFactory, clientMapper, groupMapper, secretSeriesMapper, secretContentMapper);
}
@Override public AclDAO using(Configuration configuration) {
DSLContext dslContext = DSL.using(checkNotNull(configuration));
return new AclDAO(dslContext, clientDAOFactory, groupDAOFactory, secretContentDAOFactory,
secretSeriesDAOFactory, clientMapper, groupMapper, secretSeriesMapper, secretContentMapper);
}
}
}