// Copyright © 2015 HSL <https://www.hsl.fi> // This program is dual-licensed under the EUPL v1.2 and AGPLv3 licenses. package fi.hsl.parkandride.back; import com.querydsl.core.Tuple; import com.querydsl.core.dml.StoreClause; import com.querydsl.core.types.MappingProjection; import com.querydsl.core.types.dsl.ComparableExpression; import com.querydsl.core.types.dsl.SimpleExpression; import com.querydsl.sql.SQLExpressions; import com.querydsl.sql.dml.SQLInsertClause; import com.querydsl.sql.dml.SQLUpdateClause; import com.querydsl.sql.postgresql.PostgreSQLQuery; import com.querydsl.sql.postgresql.PostgreSQLQueryFactory; import fi.hsl.parkandride.back.sql.QContact; import fi.hsl.parkandride.core.back.ContactRepository; import fi.hsl.parkandride.core.domain.*; import fi.hsl.parkandride.core.service.TransactionalRead; import fi.hsl.parkandride.core.service.TransactionalWrite; import fi.hsl.parkandride.core.service.ValidationException; import static com.google.common.base.MoreObjects.firstNonNull; import static fi.hsl.parkandride.core.domain.Sort.Dir.ASC; import static fi.hsl.parkandride.core.domain.Sort.Dir.DESC; public class ContactDao implements ContactRepository { private static final Sort DEFAULT_SORT = new Sort("name.fi", ASC); public static final String CONTACT_ID_SEQ = "contact_id_seq"; private static final SimpleExpression<Long> contactIdNextval = SQLExpressions.nextval(CONTACT_ID_SEQ); private static QContact qContact = QContact.contact; private static MultilingualStringMapping nameMapping = new MultilingualStringMapping(qContact.nameFi, qContact.nameSv, qContact.nameEn); private static MultilingualStringMapping openingHoursMapping = new MultilingualStringMapping(qContact.openingHoursFi, qContact.openingHoursSv, qContact.openingHoursEn); private static MultilingualStringMapping infoMapping = new MultilingualStringMapping(qContact.infoFi, qContact.infoSv, qContact.infoEn); private static AddressMapping addressMapping = new AddressMapping(qContact); private static MappingProjection<Contact> contactMapping = new MappingProjection<Contact>(Contact.class, qContact.all()) { @Override protected Contact map(Tuple row) { Long id = row.get(qContact.id); if (id == null) { return null; } Contact contact = new Contact(); contact.id = id; contact.operatorId = row.get(qContact.operatorId); contact.email = row.get(qContact.email); contact.phone = row.get(qContact.phone); contact.name = nameMapping.map(row); contact.address = addressMapping.map(row); contact.openingHours = openingHoursMapping.map(row); contact.info = infoMapping.map(row); return contact; } }; private final PostgreSQLQueryFactory queryFactory; public ContactDao(PostgreSQLQueryFactory queryFactory) { this.queryFactory = queryFactory; } @Override @TransactionalWrite public long insertContact(Contact contact) { return insertContact(contact, queryFactory.query().select(contactIdNextval).fetchOne()); } @TransactionalWrite public long insertContact(Contact contact, Long contactId) { SQLInsertClause insert = queryFactory.insert(qContact); insert.set(qContact.id, contactId); populate(contact, insert); insert.execute(); return contactId; } @Override @TransactionalRead public Contact getContact(long contactId) { return getContact(contactId, false); } @Override @TransactionalRead public Contact getContactForUpdate(long contactId) { return getContact(contactId, true); } private Contact getContact(long contactId, boolean forUpdate) { PostgreSQLQuery<Contact> qry = queryFactory.from(qContact).select(contactMapping).where(qContact.id.eq(contactId)); if (forUpdate) { qry.forUpdate(); } return qry.fetchOne(); } @Override @TransactionalWrite public void updateContact(long contactId, Contact contact) { SQLUpdateClause update = queryFactory.update(qContact).where(qContact.id.eq(contactId)); populate(contact, update); if (update.execute() != 1) { notFound(contactId); } } private void notFound(long contactId) { throw new NotFoundException("Contact by id '%s'", contactId); } @Override @TransactionalRead public SearchResults<Contact> findContacts(ContactSearch search) { PostgreSQLQuery<Contact> qry = queryFactory.from(qContact).select(contactMapping); qry.limit(search.getLimit() + 1); qry.offset(search.getOffset()); if (search.getIds() != null && !search.getIds().isEmpty()) { qry.where(qContact.id.in(search.getIds())); } if (search.getOperatorId() != null) { qry.where(qContact.operatorId.isNull().or(qContact.operatorId.eq(search.getOperatorId()))); } orderBy(search.getSort(), qry); return SearchResults.of(qry.fetch(), search.getLimit()); } private void populate(Contact contact, StoreClause<?> store) { store .set(qContact.operatorId, contact.operatorId) .set(qContact.phone, contact.phone) .set(qContact.email, contact.email); nameMapping.populate(contact.name, store); addressMapping.populate(contact.address, store); openingHoursMapping.populate(contact.openingHours, store); infoMapping.populate(contact.info, store); } private void orderBy(Sort sort, PostgreSQLQuery qry) { sort = firstNonNull(sort, DEFAULT_SORT); ComparableExpression<String> sortField; switch (firstNonNull(sort.getBy(), DEFAULT_SORT.getBy())) { case "name.fi": sortField = qContact.nameFi.lower(); break; case "name.sv": sortField = qContact.nameSv.lower(); break; case "name.en": sortField = qContact.nameEn.lower(); break; default: throw invalidSortBy(); } if (DESC.equals(sort.getDir())) { qry.orderBy(sortField.desc(), qContact.id.desc()); } else { qry.orderBy(sortField.asc(), qContact.id.asc()); } } private ValidationException invalidSortBy() { return new ValidationException(new Violation("SortBy", "sort.by", "Expected one of 'name.fi', 'name.sv' or 'name.en'")); } }