// Copyright © 2015 HSL <https://www.hsl.fi>
// This program is dual-licensed under the EUPL v1.2 and AGPLv3 licenses.

package fi.hsl.parkandride.back;

import com.google.common.collect.ImmutableMap;
import com.querydsl.core.Tuple;
import com.querydsl.core.dml.StoreClause;
import com.querydsl.core.group.GroupBy;
import com.querydsl.core.types.ConstantImpl;
import com.querydsl.core.types.Expression;
import com.querydsl.core.types.MappingProjection;
import com.querydsl.core.types.dsl.ComparableExpression;
import com.querydsl.core.types.dsl.SimpleExpression;
import com.querydsl.sql.SQLExpressions;
import com.querydsl.sql.SQLQuery;
import com.querydsl.sql.dml.SQLInsertClause;
import com.querydsl.sql.dml.SQLUpdateClause;
import com.querydsl.sql.postgresql.PostgreSQLQuery;
import com.querydsl.sql.postgresql.PostgreSQLQueryFactory;
import fi.hsl.parkandride.back.sql.QHub;
import fi.hsl.parkandride.back.sql.QHubFacility;
import fi.hsl.parkandride.core.back.HubRepository;
import fi.hsl.parkandride.core.domain.*;
import fi.hsl.parkandride.core.service.TransactionalRead;
import fi.hsl.parkandride.core.service.TransactionalWrite;
import fi.hsl.parkandride.core.service.ValidationException;

import java.util.Map;
import java.util.Set;

import static com.google.common.base.MoreObjects.firstNonNull;
import static com.querydsl.core.group.GroupBy.groupBy;
import static com.querydsl.spatial.GeometryExpressions.dwithin;
import static fi.hsl.parkandride.core.domain.Sort.Dir.ASC;
import static fi.hsl.parkandride.core.domain.Sort.Dir.DESC;

/**
 * Querydsl-backed {@link HubRepository} that stores hubs through the {@link QHub} table mapping
 * and the hub-to-facility links through the {@link QHubFacility} table mapping.
 */
public class HubDao implements HubRepository {

    /** Sort applied when a search carries no sort, or a sort without a field name. */
    private static final Sort DEFAULT_SORT = new Sort("name.fi", ASC);

    /** Name of the database sequence that new hub IDs are drawn from. */
    public static final String HUB_ID_SEQ = "hub_id_seq";

    private static final SimpleExpression<Long> nextHubId = SQLExpressions.nextval(HUB_ID_SEQ);

    private static final QHub qHub = QHub.hub;

    private static final QHubFacility qHubFacility = QHubFacility.hubFacility;

    /** Group expression collecting the facility IDs of one hub into a set. */
    private static final Expression<Set<Long>> facilityIdsMapping = GroupBy.set(qHubFacility.facilityId);

    private static final MultilingualStringMapping nameMapping =
            new MultilingualStringMapping(qHub.nameFi, qHub.nameSv, qHub.nameEn);

    private static final AddressMapping addressMapping = new AddressMapping(qHub);

    /**
     * Projection from a hub row to a {@link Hub}. Sets id, location, name and address only;
     * {@code facilityIds} is filled in separately by {@link #fetchFacilityIds(Map)}.
     */
    private static final MappingProjection<Hub> hubMapping = new MappingProjection<Hub>(Hub.class, qHub.all()) {
        @Override
        protected Hub map(Tuple row) {
            Long id = row.get(qHub.id);
            if (id == null) {
                // No hub in this row.
                return null;
            }
            Hub hub = new Hub();
            hub.id = id;
            hub.location = row.get(qHub.location);
            hub.name = nameMapping.map(row);
            hub.address = addressMapping.map(row);
            return hub;
        }
    };

    private final PostgreSQLQueryFactory queryFactory;

    /**
     * @param queryFactory factory used to build every query and DML clause issued by this DAO
     */
    public HubDao(PostgreSQLQueryFactory queryFactory) {
        this.queryFactory = queryFactory;
    }

    /**
     * Inserts the hub under a fresh ID taken from the {@value #HUB_ID_SEQ} sequence.
     *
     * @param hub hub to store
     * @return the ID assigned to the new hub
     */
    @Override
    @TransactionalWrite
    public long insertHub(Hub hub) {
        return insertHub(hub, queryFactory.query().select(nextHubId).fetchOne());
    }

    /**
     * Inserts the hub under the given ID, followed by its facility links.
     *
     * @param hub   hub to store
     * @param hubId ID to store the hub under
     * @return {@code hubId}
     */
    @TransactionalWrite
    public long insertHub(Hub hub, long hubId) {
        SQLInsertClause insert = queryFactory.insert(qHub);
        insert.set(qHub.id, hubId);
        populate(hub, insert);
        insert.execute();
        insertHubFacilities(hubId, hub.facilityIds);
        return hubId;
    }

    /**
     * Overwrites the stored hub and replaces its facility links with {@code hub.facilityIds}.
     *
     * @param hubId ID of the hub to update
     * @param hub   new hub contents
     * @throws HubNotFoundException if the update did not affect exactly one row
     */
    @Override
    @TransactionalWrite
    public void updateHub(long hubId, Hub hub) {
        SQLUpdateClause update = queryFactory.update(qHub);
        update.where(qHub.id.eq(hubId));
        populate(hub, update);
        if (update.execute() != 1) {
            throw new HubNotFoundException(hubId);
        }
        // Links are replaced wholesale: delete all of them, then insert the current set.
        deleteHubFacilities(hubId);
        insertHubFacilities(hubId, hub.facilityIds);
    }

    /** Deletes every facility link of the given hub. */
    private void deleteHubFacilities(long hubId) {
        queryFactory.delete(qHubFacility).where(qHubFacility.hubId.eq(hubId)).execute();
    }

    /**
     * Loads one hub together with its facility IDs.
     *
     * @param hubId ID of the hub to load
     * @return the hub
     * @throws HubNotFoundException if the query returns no hub for {@code hubId}
     */
    @Override
    @TransactionalRead
    public Hub getHub(long hubId) {
        Hub hub = queryFactory.from(qHub).select(hubMapping).where(qHub.id.eq(hubId)).fetchOne();
        if (hub == null) {
            throw new HubNotFoundException(hubId);
        }
        fetchFacilityIds(ImmutableMap.of(hubId, hub));
        return hub;
    }

    /**
     * Finds hubs matching the search, applying its limit, offset, filters and sort.
     *
     * @param search search criteria
     * @return the result of {@code SearchResults.of} over the fetched hubs and the search limit
     * @throws ValidationException if the sort field is not one of the supported names
     */
    @Override
    @TransactionalRead
    public SearchResults<Hub> findHubs(HubSearch search) {
        final PostgreSQLQuery<Hub> qry = queryFactory.from(qHub).select(hubMapping);
        if (search.getLimit() >= 0) {
            qry.limit(search.getLimit() + 1); // find one extra for hasMore
        }
        if (search.getOffset() > 0) {
            qry.offset(search.getOffset());
        }
        buildWhere(search, qry);
        orderBy(search.getSort(), qry);

        Map<Long, Hub> hubs = qry.transform(groupBy(qHub.id).as(hubMapping));
        fetchFacilityIds(hubs);
        return SearchResults.of(hubs.values(), search.getLimit());
    }

    /**
     * Adds a where-condition to {@code qry} for each criterion that is set in {@code search}:
     * geometry (with optional max distance), hub IDs and facility IDs.
     */
    private void buildWhere(HubSearch search, PostgreSQLQuery<?> qry) {
        if (search.getGeometry() != null) {
            if (search.getMaxDistance() != null && search.getMaxDistance() > 0) {
                // A positive max distance turns the geometry filter into a within-distance test.
                qry.where(dwithin(qHub.location, ConstantImpl.create(search.getGeometry()), search.getMaxDistance()));
            } else {
                qry.where(qHub.location.intersects(search.getGeometry()));
            }
        }

        if (search.getIds() != null && !search.getIds().isEmpty()) {
            qry.where(qHub.id.in(search.getIds()));
        }

        if (search.getFacilityIds() != null && !search.getFacilityIds().isEmpty()) {
            // Correlated subquery: the hub must be linked to at least one of the given facilities.
            final SQLQuery<Long> hasFacilityId = SQLExpressions.select(qHubFacility.facilityId)
                    .from(qHubFacility)
                    .where(qHubFacility.hubId.eq(qHub.id), qHubFacility.facilityId.in(search.getFacilityIds()));
            qry.where(hasFacilityId.exists());
        }
    }

    /**
     * Loads the facility IDs of the given hubs with one query and assigns them to each hub's
     * {@code facilityIds}. A hub for which the query returns no rows is left untouched.
     *
     * @param hubs hubs keyed by hub ID; no query is issued when the map is empty
     */
    private void fetchFacilityIds(Map<Long, Hub> hubs) {
        if (!hubs.isEmpty()) {
            // NOTE(review): the groupBy transformer presumably supplies its own projection, making
            // the select(...) below a placeholder - confirm against the Querydsl GroupBy docs.
            final PostgreSQLQuery<Long> qry = queryFactory.from(qHubFacility)
                    .select(qHubFacility.hubId)
                    .where(qHubFacility.hubId.in(hubs.keySet()));
            Map<Long, Set<Long>> hubFacilityIds = qry.transform(groupBy(qHubFacility.hubId).as(facilityIdsMapping));

            for (Map.Entry<Long, Set<Long>> entry : hubFacilityIds.entrySet()) {
                hubs.get(entry.getKey()).facilityIds = entry.getValue();
            }
        }
    }

    /**
     * Sets the hub's location, name and address on an insert or update clause. The ID is not set.
     */
    // NOTE(review): StoreClause is left raw because the populate(...) signatures of the mapping
    // helpers are not visible from this file; tighten it once those are confirmed.
    private void populate(Hub hub, StoreClause store) {
        store.set(qHub.location, hub.location);
        nameMapping.populate(hub.name, store);
        addressMapping.populate(hub.address, store);
    }

    /**
     * Inserts one link row per facility ID as a single batch.
     *
     * @param hubId       hub the links belong to
     * @param facilityIds facilities to link; nothing is executed when null or empty
     */
    private void insertHubFacilities(long hubId, Set<Long> facilityIds) {
        if (facilityIds != null && !facilityIds.isEmpty()) {
            SQLInsertClause insertBatch = queryFactory.insert(qHubFacility);
            for (Long facilityId : facilityIds) {
                insertBatch.set(qHubFacility.hubId, hubId);
                insertBatch.set(qHubFacility.facilityId, facilityId);
                insertBatch.addBatch();
            }
            insertBatch.execute();
        }
    }

    /**
     * Orders {@code qry} by the lower-cased hub name in the requested language, then by hub ID in
     * the same direction. A null sort or null sort field falls back to {@link #DEFAULT_SORT}.
     *
     * @throws ValidationException if the sort field is not 'name.fi', 'name.sv' or 'name.en'
     */
    private void orderBy(Sort sort, PostgreSQLQuery<?> qry) {
        sort = firstNonNull(sort, DEFAULT_SORT);
        ComparableExpression<String> sortField;
        switch (firstNonNull(sort.getBy(), DEFAULT_SORT.getBy())) {
            case "name.fi":
                sortField = qHub.nameFi.lower();
                break;
            case "name.sv":
                sortField = qHub.nameSv.lower();
                break;
            case "name.en":
                sortField = qHub.nameEn.lower();
                break;
            default:
                throw invalidSortBy();
        }
        // Any direction other than DESC (including null) sorts ascending.
        if (DESC.equals(sort.getDir())) {
            qry.orderBy(sortField.desc(), qHub.id.desc());
        } else {
            qry.orderBy(sortField.asc(), qHub.id.asc());
        }
    }

    /** Builds the exception thrown for an unsupported sort field. */
    private ValidationException invalidSortBy() {
        return new ValidationException(new Violation("SortBy", "sort.by", "Expected one of 'name.fi', 'name.sv' or 'name.en'"));
    }
}