package org.molgenis.data.postgresql; import com.google.common.base.Stopwatch; import com.google.common.collect.*; import org.molgenis.data.Entity; import org.molgenis.data.Fetch; import org.molgenis.data.Query; import org.molgenis.data.QueryRule.Operator; import org.molgenis.data.RepositoryCapability; import org.molgenis.data.meta.AttributeType; import org.molgenis.data.meta.model.Attribute; import org.molgenis.data.meta.model.EntityType; import org.molgenis.data.support.AbstractRepository; import org.molgenis.data.support.BatchingQueryResult; import org.molgenis.data.support.QueryImpl; import org.molgenis.data.validation.ConstraintViolation; import org.molgenis.data.validation.MolgenisValidationException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.jdbc.core.BatchPreparedStatementSetter; import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.jdbc.core.ResultSetExtractor; import org.springframework.jdbc.core.RowMapper; import javax.sql.DataSource; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.util.*; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; import java.util.stream.Stream; import static com.google.common.base.Stopwatch.createStarted; import static com.google.common.collect.Lists.newArrayList; import static com.google.common.collect.Maps.newHashMap; import static java.lang.String.format; import static java.util.Collections.emptyList; import static java.util.Collections.unmodifiableSet; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.counting; import static java.util.stream.Collectors.toList; import static java.util.stream.StreamSupport.stream; import static org.molgenis.data.QueryRule.Operator.*; import static org.molgenis.data.RepositoryCapability.*; import static org.molgenis.data.meta.AttributeType.ONE_TO_MANY; import static org.molgenis.data.postgresql.PostgreSqlQueryGenerator.*; import static org.molgenis.data.postgresql.PostgreSqlQueryUtils.*; import static org.molgenis.data.postgresql.PostgreSqlUtils.getPostgreSqlValue; import static org.molgenis.data.support.EntityTypeUtils.isMultipleReferenceType; /** * Repository that persists entities in a PostgreSQL database * <ul> * <li>Attributes with expression are not persisted</li> * <li>Cross-backend attribute references are supported</li> * <li>Query operators DIS_MAX, FUZZY_MATCH, FUZZY_MATCH_NGRAM, SEARCH, SHOULD are not supported</li> * </ul> */ class PostgreSqlRepository extends AbstractRepository { private static final Logger LOG = LoggerFactory.getLogger(PostgreSqlRepository.class); /** * JDBC batch operation size */ private static final int BATCH_SIZE = 1000; /** * Repository capabilities */ private static final Set<RepositoryCapability> REPO_CAPABILITIES = unmodifiableSet( EnumSet.of(WRITABLE, MANAGABLE, QUERYABLE, VALIDATE_REFERENCE_CONSTRAINT, VALIDATE_UNIQUE_CONSTRAINT, VALIDATE_NOTNULL_CONSTRAINT, VALIDATE_READONLY_CONSTRAINT, CACHEABLE)); /** * Supported query operators */ private static final Set<Operator> QUERY_OPERATORS = unmodifiableSet( EnumSet.of(EQUALS, IN, LESS, LESS_EQUAL, GREATER, GREATER_EQUAL, RANGE, LIKE, NOT, AND, OR, NESTED)); private final PostgreSqlEntityFactory postgreSqlEntityFactory; private final JdbcTemplate jdbcTemplate; private final DataSource dataSource; private EntityType entityType; PostgreSqlRepository(PostgreSqlEntityFactory postgreSqlEntityFactory, JdbcTemplate jdbcTemplate, DataSource dataSource) { this.postgreSqlEntityFactory = requireNonNull(postgreSqlEntityFactory); this.jdbcTemplate = requireNonNull(jdbcTemplate); this.dataSource = requireNonNull(dataSource); } void setEntityType(EntityType entityType) { this.entityType = entityType; } @Override public Iterator<Entity> iterator() { Query<Entity> q = new QueryImpl<>(); return findAllBatching(q).iterator(); } @Override public Set<RepositoryCapability> getCapabilities() { return REPO_CAPABILITIES; } @Override public Set<Operator> getQueryOperators() { return QUERY_OPERATORS; } @Override public EntityType getEntityType() { return entityType; } @Override public long count(Query<Entity> q) { List<Object> parameters = Lists.newArrayList(); String sql = getSqlCount(entityType, q, parameters); if (LOG.isDebugEnabled()) { LOG.debug("Counting [{}] rows for query [{}]", getName(), q); if (LOG.isTraceEnabled()) { LOG.trace("SQL: {}, parameters: {}", sql, parameters); } } return jdbcTemplate.queryForObject(sql, parameters.toArray(new Object[parameters.size()]), Long.class); } @Override public Stream<Entity> findAll(Query<Entity> q) { return stream(findAllBatching(q).spliterator(), false); } @Override public Entity findOne(Query<Entity> q) { Iterator<Entity> iterator = findAll(q).iterator(); if (iterator.hasNext()) { return iterator.next(); } return null; } @Override public Entity findOneById(Object id) { if (id == null) { return null; } return findOne(new QueryImpl<>().eq(entityType.getIdAttribute().getName(), id)); } @Override public Entity findOneById(Object id, Fetch fetch) { if (id == null) { return null; } return findOne(new QueryImpl<>().eq(entityType.getIdAttribute().getName(), id).fetch(fetch)); } @Override public void update(Entity entity) { update(Stream.of(entity)); } @Override public void update(Stream<Entity> entities) { updateBatching(entities.iterator()); } @Override public void delete(Entity entity) { this.delete(Stream.of(entity)); } @Override public void delete(Stream<Entity> entities) { deleteAll(entities.map(Entity::getIdValue)); } @Override public void deleteById(Object id) { this.deleteAll(Stream.of(id)); } @Override public void deleteAll(Stream<Object> ids) { Iterators.partition(ids.iterator(), BATCH_SIZE).forEachRemaining(idsBatch -> { String sql = getSqlDelete(entityType); if (LOG.isDebugEnabled()) { LOG.debug("Deleting {} [{}] entities", idsBatch.size(), getName()); if (LOG.isTraceEnabled()) { LOG.trace("SQL: {}", sql); } } jdbcTemplate.batchUpdate(sql, new BatchDeletePreparedStatementSetter(idsBatch)); }); } @Override public void deleteAll() { String deleteAllSql = getSqlDeleteAll(entityType); if (LOG.isDebugEnabled()) { LOG.debug("Deleting all [{}] entities", getName()); if (LOG.isTraceEnabled()) { LOG.trace("SQL: {}", deleteAllSql); } } jdbcTemplate.update(deleteAllSql); } @Override public void add(Entity entity) { if (entity == null) { throw new RuntimeException("PostgreSqlRepository.add() failed: entity was null"); } add(Stream.of(entity)); } @Override public Integer add(Stream<Entity> entities) { return addBatching(entities.iterator()); } @Override public void forEachBatched(Fetch fetch, Consumer<List<Entity>> consumer, int batchSize) { final Stopwatch stopwatch = createStarted(); final JdbcTemplate template = new JdbcTemplate(dataSource); template.setFetchSize(batchSize); final Query<Entity> query = new QueryImpl<>(); if (fetch != null) { query.fetch(fetch); } final EntityType entityType = this.entityType; final String allRowsSelect = getSqlSelect(entityType, query, emptyList(), false); LOG.debug("Fetching [{}] data...", getName()); LOG.trace("SQL: {}", allRowsSelect); RowMapper<Entity> rowMapper = postgreSqlEntityFactory.createRowMapper(entityType, fetch); template.query(allRowsSelect, (ResultSetExtractor) resultSet -> processResultSet(consumer, batchSize, entityType, rowMapper, resultSet)); LOG.debug("Streamed entire repository in batches of size {} in {}.", batchSize, stopwatch); } private Object processResultSet(Consumer<List<Entity>> consumer, int batchSize, EntityType entityType, RowMapper<Entity> rowMapper, ResultSet resultSet) throws SQLException { int rowNum = 0; Map<Object, Entity> batch = newHashMap(); while (resultSet.next()) { Entity entity = rowMapper.mapRow(resultSet, rowNum++); batch.put(entity.getIdValue(), entity); if (rowNum % batchSize == 0) { handleBatch(consumer, entityType, batch); batch = newHashMap(); } } if (!batch.isEmpty()) { handleBatch(consumer, entityType, batch); } return null; } /** * Handles a batch of Entities. Looks up the values for MREF ID attributes and sets them as references in the * entities. Then feeds the entities to the {@link Consumer} * * @param consumer {@link Consumer} to feed the batch to after setting the MREF ID values * @param entityType EntityType for the {@link Entity}s in the batch * @param batch {@link Map} mapping entity ID to entity for all {@link Entity}s in the batch */ private void handleBatch(Consumer<List<Entity>> consumer, EntityType entityType, Map<Object, Entity> batch) { AttributeType idAttributeDataType = entityType.getIdAttribute().getDataType(); LOG.debug("Select ID values for a batch of MREF attributes..."); for (Attribute mrefAttr : entityType.getAtomicAttributes()) { if (mrefAttr.getExpression() == null && isMultipleReferenceType(mrefAttr) && !( mrefAttr.getDataType() == ONE_TO_MANY && mrefAttr.isMappedBy())) { EntityType refEntityType = mrefAttr.getRefEntity(); Multimap<Object, Object> mrefIDs = selectMrefIDsForAttribute(entityType, idAttributeDataType, mrefAttr, batch.keySet(), refEntityType.getIdAttribute().getDataType()); for (Map.Entry entry : batch.entrySet()) { batch.get(entry.getKey()).set(mrefAttr.getName(), postgreSqlEntityFactory .getReferences(refEntityType, newArrayList(mrefIDs.get(entry.getKey())))); } } } LOG.trace("Feeding batch of {} rows to consumer.", batch.size()); consumer.accept(batch.values().stream().collect(toList())); } /** * Selects MREF IDs for an MREF attribute from the junction table, in the order of the MREF attribute value. * * @param entityType EntityType for the entities * @param idAttributeDataType {@link AttributeType} of the ID attribute of the entity * @param mrefAttr Attribute of the MREF attribute to select the values for * @param ids {@link Set} of {@link Object}s containing the values for the ID attribute of the entity * @param refIdDataType {@link AttributeType} of the ID attribute of the refEntity of the attribute * @return Multimap mapping entity ID to a list containing the MREF IDs for the values in the attribute */ private Multimap<Object, Object> selectMrefIDsForAttribute(EntityType entityType, AttributeType idAttributeDataType, Attribute mrefAttr, Set<Object> ids, AttributeType refIdDataType) { Stopwatch stopwatch = null; if (LOG.isTraceEnabled()) stopwatch = createStarted(); String junctionTableSelect = getSqlJunctionTableSelect(entityType, mrefAttr, ids.size()); LOG.trace("SQL: {}", junctionTableSelect); Multimap<Object, Object> mrefIDs = ArrayListMultimap.create(); jdbcTemplate.query(junctionTableSelect, row -> { Object id; switch (idAttributeDataType) { case EMAIL: case HYPERLINK: case STRING: id = row.getString(1); break; case INT: id = row.getInt(1); break; case LONG: id = row.getLong(1); break; default: throw new RuntimeException(format("Unexpected id attribute type [%s]", idAttributeDataType)); } Object refId; switch (refIdDataType) { case EMAIL: case HYPERLINK: case STRING: refId = row.getString(3); break; case INT: refId = row.getInt(3); break; case LONG: refId = row.getLong(3); break; default: throw new RuntimeException(format("Unexpected id attribute type [%s]", refIdDataType)); } mrefIDs.put(id, refId); }, ids.toArray()); if (LOG.isTraceEnabled()) LOG.trace("Selected {} ID values for MREF attribute {} in {}", mrefIDs.values().stream().collect(counting()), mrefAttr.getName(), stopwatch); return mrefIDs; } private BatchingQueryResult<Entity> findAllBatching(Query<Entity> q) { return new BatchingQueryResult<Entity>(BATCH_SIZE, q) { @Override protected List<Entity> getBatch(Query<Entity> batchQuery) { List<Object> parameters = new ArrayList<>(); String sql = getSqlSelect(getEntityType(), batchQuery, parameters, true); RowMapper<Entity> entityMapper = postgreSqlEntityFactory .createRowMapper(getEntityType(), batchQuery.getFetch()); LOG.debug("Fetching [{}] data for query [{}]", getName(), batchQuery); LOG.trace("SQL: {}, parameters: {}", sql, parameters); Stopwatch sw = createStarted(); List<Entity> result = jdbcTemplate .query(sql, parameters.toArray(new Object[parameters.size()]), entityMapper); LOG.trace("That took {}", sw); return result; } }; } private Integer addBatching(Iterator<? extends Entity> entities) { AtomicInteger count = new AtomicInteger(); final Attribute idAttr = entityType.getIdAttribute(); final List<Attribute> tableAttrs = getTableAttributes(entityType).collect(toList()); final List<Attribute> junctionTableAttrs = getJunctionTableAttributes(entityType).collect(toList()); final String insertSql = getSqlInsert(entityType); Iterators.partition(entities, BATCH_SIZE).forEachRemaining(entitiesBatch -> { if (LOG.isDebugEnabled()) { LOG.debug("Adding {} [{}] entities", entitiesBatch.size(), getName()); if (LOG.isTraceEnabled()) { LOG.trace("SQL: {}", insertSql); } } // persist values in entity table jdbcTemplate.batchUpdate(insertSql, new BatchAddPreparedStatementSetter(entitiesBatch, tableAttrs)); // persist values in entity junction table if (!junctionTableAttrs.isEmpty()) { Map<String, List<Map<String, Object>>> mrefs = createMrefMap(idAttr, junctionTableAttrs, entitiesBatch); for (Attribute attr : junctionTableAttrs) { List<Map<String, Object>> attrMrefs = mrefs.get(attr.getName()); if (attrMrefs != null && !attrMrefs.isEmpty()) { addMrefs(attrMrefs, attr); } } } count.addAndGet(entitiesBatch.size()); }); return count.get(); } private static Map<String, List<Map<String, Object>>> createMrefMap(Attribute idAttr, List<Attribute> junctionTableAttrs, List<? extends Entity> entitiesBatch) { Map<String, List<Map<String, Object>>> mrefs = Maps.newHashMapWithExpectedSize(junctionTableAttrs.size()); AtomicInteger seqNr = new AtomicInteger(); for (Entity entity : entitiesBatch) { for (Attribute attr : junctionTableAttrs) { Iterable<Entity> refEntities = entity.getEntities(attr.getName()); // Not-Null constraint doesn't exist for MREF attributes since they are stored in junction tables, // so validate manually. if (!attr.isNillable() && Iterables.isEmpty(refEntities)) { throw new MolgenisValidationException(new ConstraintViolation( String.format("The attribute [%s] of entity [%s] with id [%s] can not be null.", attr.getName(), attr.getEntity().getName(), entity.getIdValue().toString()))); } mrefs.putIfAbsent(attr.getName(), new ArrayList<>()); seqNr.set(0); for (Entity val : refEntities) { Map<String, Object> mref = Maps.newHashMapWithExpectedSize(3); mref.put(JUNCTION_TABLE_ORDER_ATTR_NAME, seqNr.getAndIncrement()); mref.put(idAttr.getName(), entity.get(idAttr.getName())); mref.put(attr.getName(), val); mrefs.get(attr.getName()).add(mref); } } } return mrefs; } private void updateBatching(Iterator<? extends Entity> entities) { final Attribute idAttr = entityType.getIdAttribute(); final List<Attribute> tableAttrs = getTableAttributes(entityType).collect(toList()); final List<Attribute> junctionTableAttrs = getJunctionTableAttributes(entityType) .filter(attr -> !attr.isReadOnly()).collect(toList()); final String updateSql = getSqlUpdate(entityType); // update values in entity table Iterators.partition(entities, BATCH_SIZE).forEachRemaining(entitiesBatch -> { if (LOG.isDebugEnabled()) { LOG.debug("Updating {} [{}] entities", entitiesBatch.size(), getName()); if (LOG.isTraceEnabled()) { LOG.trace("SQL: {}", updateSql); } } jdbcTemplate .batchUpdate(updateSql, new BatchUpdatePreparedStatementSetter(entitiesBatch, tableAttrs, idAttr)); // update values in entity junction table if (!junctionTableAttrs.isEmpty()) { Map<String, List<Map<String, Object>>> mrefs = createMrefMap(idAttr, junctionTableAttrs, entitiesBatch); // update mrefs List<Object> ids = entitiesBatch.stream().map(entity -> getPostgreSqlValue(entity, idAttr)) .collect(toList()); for (Attribute attr : junctionTableAttrs) { removeMrefs(ids, attr); addMrefs(mrefs.get(attr.getName()), attr); } } }); } private void addMrefs(final List<Map<String, Object>> mrefs, final Attribute attr) { // database doesn't validate NOT NULL constraint for attribute values referencing multiple entities, // so validate it ourselves if (!attr.isNillable() && mrefs.isEmpty()) { throw new MolgenisValidationException(new ConstraintViolation( format("Entity [%s] attribute [%s] value cannot be null", entityType.getName(), attr.getName()))); } final Attribute idAttr = entityType.getIdAttribute(); String insertMrefSql = getSqlInsertJunction(entityType, attr); if (LOG.isDebugEnabled()) { LOG.debug("Adding junction table entries for entity [{}] attribute [{}]", getName(), attr.getName()); if (LOG.isTraceEnabled()) { LOG.trace("SQL: {}", insertMrefSql); } } jdbcTemplate.batchUpdate(insertMrefSql, new BatchJunctionTableAddPreparedStatementSetter(mrefs, attr, idAttr)); } private void removeMrefs(final List<Object> ids, final Attribute attr) { final Attribute idAttr = attr.isMappedBy() ? attr.getMappedBy() : entityType.getIdAttribute(); String deleteMrefSql = getSqlDelete(getJunctionTableName(entityType, attr), idAttr); if (LOG.isDebugEnabled()) { LOG.debug("Removing junction table entries for entity [{}] attribute [{}]", getName(), attr.getName()); if (LOG.isTraceEnabled()) { LOG.trace("SQL: {}", deleteMrefSql); } } jdbcTemplate.batchUpdate(deleteMrefSql, new BatchJunctionTableDeletePreparedStatementSetter(ids)); } private static class BatchAddPreparedStatementSetter implements BatchPreparedStatementSetter { private final List<? extends Entity> entities; private final List<Attribute> tableAttrs; BatchAddPreparedStatementSetter(List<? extends Entity> entities, List<Attribute> tableAttrs) { this.entities = entities; this.tableAttrs = tableAttrs; } @Override public void setValues(PreparedStatement preparedStatement, int rowIndex) throws SQLException { Entity entity = entities.get(rowIndex); int fieldIndex = 1; for (Attribute attr : tableAttrs) { Object postgreSqlValue = getPostgreSqlValue(entity, attr); preparedStatement.setObject(fieldIndex++, postgreSqlValue); } } @Override public int getBatchSize() { return entities.size(); } } private static class BatchUpdatePreparedStatementSetter implements BatchPreparedStatementSetter { private final List<? extends Entity> entities; private final List<Attribute> tableAttrs; private final Attribute idAttr; BatchUpdatePreparedStatementSetter(List<? extends Entity> entities, List<Attribute> tableAttrs, Attribute idAttr) { this.entities = entities; this.tableAttrs = tableAttrs; this.idAttr = idAttr; } @Override public void setValues(PreparedStatement preparedStatement, int rowIndex) throws SQLException { Entity entity = entities.get(rowIndex); int fieldIndex = 1; for (Attribute attr : tableAttrs) { Object postgreSqlValue = getPostgreSqlValue(entity, attr); preparedStatement.setObject(fieldIndex++, postgreSqlValue); } preparedStatement.setObject(fieldIndex, getPostgreSqlValue(entity, idAttr)); } @Override public int getBatchSize() { return entities.size(); } } private static class BatchDeletePreparedStatementSetter implements BatchPreparedStatementSetter { private final List<Object> entityIds; BatchDeletePreparedStatementSetter(List<Object> entityIds) { this.entityIds = entityIds; } @Override public void setValues(PreparedStatement preparedStatement, int i) throws SQLException { preparedStatement.setObject(1, entityIds.get(i)); } @Override public int getBatchSize() { return entityIds.size(); } } private static class BatchJunctionTableAddPreparedStatementSetter implements BatchPreparedStatementSetter { private final List<Map<String, Object>> mrefs; private final Attribute attr; private final Attribute idAttr; BatchJunctionTableAddPreparedStatementSetter(List<Map<String, Object>> mrefs, Attribute attr, Attribute idAttr) { this.mrefs = mrefs; this.attr = attr; this.idAttr = idAttr; } @Override public void setValues(PreparedStatement preparedStatement, int i) throws SQLException { Map<String, Object> mref = mrefs.get(i); Object idValue0, idValue1; if (attr.isMappedBy()) { Entity mrefEntity = (Entity) mref.get(attr.getName()); idValue0 = getPostgreSqlValue(mrefEntity, attr.getRefEntity().getIdAttribute()); idValue1 = mref.get(idAttr.getName()); } else { idValue0 = mref.get(idAttr.getName()); Entity mrefEntity = (Entity) mref.get(attr.getName()); idValue1 = getPostgreSqlValue(mrefEntity, mrefEntity.getEntityType().getIdAttribute()); } preparedStatement.setInt(1, (int) mref.get(JUNCTION_TABLE_ORDER_ATTR_NAME)); preparedStatement.setObject(2, idValue0); preparedStatement.setObject(3, idValue1); } @Override public int getBatchSize() { return mrefs.size(); } } private static class BatchJunctionTableDeletePreparedStatementSetter implements BatchPreparedStatementSetter { private final List<Object> entityIds; BatchJunctionTableDeletePreparedStatementSetter(List<Object> entityIds) { this.entityIds = entityIds; } @Override public void setValues(PreparedStatement preparedStatement, int i) throws SQLException { preparedStatement.setObject(1, entityIds.get(i)); } @Override public int getBatchSize() { return entityIds.size(); } } }