// Copyright © 2016 HSL <https://www.hsl.fi>
// This program is dual-licensed under the EUPL v1.2 and AGPLv3 licenses.

package fi.hsl.parkandride.back.prediction;

import com.querydsl.core.QueryException;
import com.querydsl.core.Tuple;
import com.querydsl.core.types.Expression;
import com.querydsl.core.types.MappingProjection;
import com.querydsl.core.types.Path;
import com.querydsl.core.types.Projections;
import com.querydsl.core.types.dsl.BooleanExpression;
import com.querydsl.sql.dml.SQLInsertClause;
import com.querydsl.sql.dml.SQLUpdateClause;
import com.querydsl.sql.postgresql.PostgreSQLQueryFactory;
import fi.hsl.parkandride.back.TimeUtil;
import fi.hsl.parkandride.back.sql.QFacilityPrediction;
import fi.hsl.parkandride.back.sql.QFacilityPredictionHistory;
import fi.hsl.parkandride.core.back.PredictionRepository;
import fi.hsl.parkandride.core.domain.UtilizationKey;
import fi.hsl.parkandride.core.domain.prediction.Prediction;
import fi.hsl.parkandride.core.domain.prediction.PredictionBatch;
import fi.hsl.parkandride.core.service.TransactionalRead;
import fi.hsl.parkandride.core.service.TransactionalWrite;
import fi.hsl.parkandride.core.service.ValidationService;
import org.joda.time.DateTime;
import org.joda.time.DateTimeZone;
import org.joda.time.Duration;
import org.joda.time.format.DateTimeFormat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static java.util.stream.Collectors.toList;
import static org.joda.time.Duration.standardHours;
import static org.joda.time.Duration.standardMinutes;

/**
 * Querydsl-backed {@link PredictionRepository}.
 *
 * <p>Predictions are kept in two tables: a lookup table ({@code qPrediction}) holding one row per
 * utilization key with one {@code spacesAvailableAt<HHmm>} column per time slot, and a history table
 * ({@code qPredictionHistory}) holding selected predictions per predictor.
 *
 * <p>NOTE(review): {@code PREDICTION_WINDOW} and {@code PREDICTION_RESOLUTION} are not declared in
 * this file; presumably they are constants inherited from {@link PredictionRepository} - confirm.
 */
public class PredictionDao implements PredictionRepository {

    private static final Logger log = LoggerFactory.getLogger(PredictionDao.class);

    private static final QFacilityPrediction qPrediction = QFacilityPrediction.facilityPrediction;
    private static final QFacilityPredictionHistory qPredictionHistory = QFacilityPredictionHistory.facilityPredictionHistory;

    /**
     * The {@code spacesAvailableAt*} columns of the lookup table, keyed by the part of the column
     * name that follows the {@code spacesAvailableAt} prefix (looked up by an {@code HHmm} string
     * in {@link #spacesAvailableAt(DateTime)}).
     */
    private static final Map<String, Path<Integer>> spacesAvailableColumnsByHHmm = Collections.unmodifiableMap(
            Stream.of(qPrediction.all())
                    .filter(p -> p.getMetadata().getName().startsWith("spacesAvailableAt"))
                    .map(PredictionDao::castToIntegerPath)
                    .collect(Collectors.toMap(
                            p -> p.getMetadata().getName().substring("spacesAvailableAt".length()),
                            Function.identity())));

    /**
     * Forecast distances (time between the batch start and the predicted timestamp) for which
     * predictions are written to the history table.
     */
    public static final List<Duration> predictionsDistancesToStore = Collections.unmodifiableList(
            Arrays.<Duration>asList(
                    standardMinutes(5),
                    standardMinutes(10),
                    standardMinutes(15),
                    standardMinutes(20),
                    standardMinutes(30),
                    standardMinutes(45),
                    standardHours(1),
                    standardHours(2),
                    standardHours(4),
                    standardHours(8),
                    standardHours(12),
                    standardHours(16),
                    standardHours(20),
                    standardHours(24)));

    private final PostgreSQLQueryFactory queryFactory;
    private final ValidationService validationService;

    public PredictionDao(PostgreSQLQueryFactory queryFactory, ValidationService validationService) {
        this.queryFactory = queryFactory;
        this.validationService = validationService;
    }

    /**
     * Validates the batch, writes its normalized predictions to the lookup table and stores the
     * predictions at the selected forecast distances in the history table.
     *
     * <p>If no lookup row exists yet for the utilization key, the row is inserted and the update is
     * retried exactly once.
     *
     * @param pb          the batch to store
     * @param predictorId id written to the history rows
     * @throws IllegalStateException if the lookup row still cannot be updated after it was inserted
     */
    @TransactionalWrite
    @Override
    public void updatePredictions(PredictionBatch pb, Long predictorId) {
        validationService.validate(pb);
        UtilizationKey utilizationKey = pb.utilizationKey;
        DateTime start = toPredictionResolution(pb.sourceTimestamp);
        List<Prediction> predictions = normalizeToPredictionWindow(start, pb.predictions);

        long updatedRows = maybeUpdatePredictionLookupTable(utilizationKey, start, predictions);
        if (updatedRows == 0) {
            // No row for this key yet: create it and retry the update once. A bounded retry
            // (instead of re-entering this method) avoids re-validating the batch and cannot
            // recurse endlessly if the update keeps matching nothing.
            initializePredictionLookupTable(utilizationKey);
            updatedRows = maybeUpdatePredictionLookupTable(utilizationKey, start, predictions);
            if (updatedRows == 0) {
                throw new IllegalStateException(
                        "Prediction lookup row for " + utilizationKey + " could not be updated even after it was initialized");
            }
        }
        savePredictionHistory(predictorId, start, filterToSelectedPredictionDistances(start, predictions));
    }

    /**
     * Validates the batch and stores its predictions at the selected forecast distances in the
     * history table, without touching the lookup table.
     *
     * @param pb          the batch to store
     * @param predictorId id written to the history rows
     */
    @TransactionalWrite
    @Override
    public void updateOnlyPredictionHistory(PredictionBatch pb, Long predictorId) {
        validationService.validate(pb);
        DateTime start = toPredictionResolution(pb.sourceTimestamp);
        List<Prediction> predictions = normalizeToPredictionWindow(start, pb.predictions);
        savePredictionHistory(predictorId, start, filterToSelectedPredictionDistances(start, predictions));
    }

    /**
     * Keeps only the predictions whose distance from {@code start} is one of
     * {@link #predictionsDistancesToStore}.
     */
    private List<Prediction> filterToSelectedPredictionDistances(DateTime start, List<Prediction> predictions) {
        return predictions.stream()
                .filter(p -> predictionsDistancesToStore.contains(new Duration(start, p.timestamp)))
                .collect(toList());
    }

    /**
     * Brings the predictions to one value per prediction-resolution step and restricts them to the
     * range from {@code start} to {@code start + PREDICTION_WINDOW - PREDICTION_RESOLUTION}
     * (both ends inclusive).
     *
     * @param start       start of the window, already in prediction resolution
     * @param predictions raw predictions in any order and granularity
     * @return the normalized predictions in ascending timestamp order
     */
    private static List<Prediction> normalizeToPredictionWindow(DateTime start, List<Prediction> predictions) {
        DateTime end = start.plus(PREDICTION_WINDOW).minus(PREDICTION_RESOLUTION);
        List<Prediction> sortedByTime = predictions.stream()
                // remove too fine-grained predictions
                .collect(groupByRoundedTimeKeepingNewest()) // -> Map<DateTime, Prediction>
                .values().stream()
                // normalize resolution
                .map(roundTimestampsToPredictionResolution())
                .sorted(Comparator.comparing(p -> p.timestamp))
                .collect(toList());
        // interpolate too coarse-grained predictions
        return interpolateLinearly(sortedByTime).stream()
                // normalize range
                .filter(isWithin(start, end)) // after interpolation because of PredictionDaoTest.does_linear_interpolation_also_between_values_outside_the_prediction_window
                .collect(toList());
    }

    /** Matches predictions whose timestamp lies in {@code [start, end]}, both ends inclusive. */
    private static Predicate<Prediction> isWithin(DateTime start, DateTime end) {
        return p -> !p.timestamp.isBefore(start) && !p.timestamp.isAfter(end);
    }

    /** Maps a prediction to a copy whose timestamp is rounded to the prediction resolution. */
    private static Function<Prediction, Prediction> roundTimestampsToPredictionResolution() {
        return p -> new Prediction(toPredictionResolution(p.timestamp), p.spacesAvailable);
    }

    /**
     * Collects predictions into a map keyed by their rounded timestamp; when several predictions
     * round to the same key, the one with the latest original timestamp is kept.
     */
    private static Collector<Prediction, ?, Map<DateTime, Prediction>> groupByRoundedTimeKeepingNewest() {
        return Collectors.toMap(
                p -> toPredictionResolution(p.timestamp),
                Function.identity(),
                (a, b) -> a.timestamp.isAfter(b.timestamp) ? a : b,
                HashMap::new
        );
    }

    /**
     * Fills the gaps between consecutive predictions with linearly interpolated values, one per
     * prediction-resolution step.
     *
     * @param sortedPredictions predictions in ascending timestamp order
     * @return a new list containing the given predictions plus the interpolated ones, in order
     */
    private static List<Prediction> interpolateLinearly(List<Prediction> sortedPredictions) {
        List<Prediction> interpolated = new ArrayList<>();
        Prediction previous = null;
        for (Prediction next : sortedPredictions) {
            if (previous != null) {
                // invariant for the whole gap between previous and next
                double totalDuration = new Duration(previous.timestamp, next.timestamp).getMillis();
                int totalChange = next.spacesAvailable - previous.spacesAvailable;
                for (DateTime timestamp = previous.timestamp.plus(PREDICTION_RESOLUTION);
                     timestamp.isBefore(next.timestamp);
                     timestamp = timestamp.plus(PREDICTION_RESOLUTION)) {
                    double currentDuration = new Duration(previous.timestamp, timestamp).getMillis();
                    double proportion = currentDuration / totalDuration;
                    int currentChange = (int) Math.round(totalChange * proportion);
                    int spacesAvailable = previous.spacesAvailable + currentChange;
                    interpolated.add(new Prediction(timestamp, spacesAvailable));
                }
            }
            interpolated.add(next);
            previous = next;
        }
        return interpolated;
    }

    /** Inserts a lookup-table row for the given key with only the key columns set. */
    private void initializePredictionLookupTable(UtilizationKey utilizationKey) {
        queryFactory.insert(qPrediction)
                .set(qPrediction.facilityId, utilizationKey.facilityId)
                .set(qPrediction.capacityType, utilizationKey.capacityType)
                .set(qPrediction.usage, utilizationKey.usage)
                .execute();
    }

    /**
     * Updates the lookup-table row of the given key: sets its start time and the
     * {@code spacesAvailableAt*} column of every given prediction.
     *
     * @return the value returned by the update's {@code execute()}; the caller treats 0 as
     *         "no row exists for this key yet"
     */
    private long maybeUpdatePredictionLookupTable(UtilizationKey utilizationKey, DateTime start, List<Prediction> predictions) {
        SQLUpdateClause update = queryFactory.update(qPrediction)
                .where(qPrediction.facilityId.eq(utilizationKey.facilityId),
                        qPrediction.capacityType.eq(utilizationKey.capacityType),
                        qPrediction.usage.eq(utilizationKey.usage))
                .set(qPrediction.start, start);
        predictions.forEach(p -> update.set(spacesAvailableAt(p.timestamp), p.spacesAvailable));
        return update.execute();
    }

    /**
     * Batch-inserts the given predictions into the history table. Does nothing for an empty list.
     * A {@link QueryException} from the insert is logged and not propagated.
     */
    private void savePredictionHistory(Long predictorId, DateTime start, List<Prediction> predictions) {
        if (predictions.isEmpty()) {
            return;
        }
        SQLInsertClause insert = queryFactory.insert(qPredictionHistory);
        predictions.forEach(p -> insert
                .set(qPredictionHistory.predictorId, predictorId)
                .set(qPredictionHistory.forecastDistanceInMinutes, ((int) new Duration(start, p.timestamp).getStandardMinutes()))
                .set(qPredictionHistory.ts, p.timestamp)
                .set(qPredictionHistory.spacesAvailable, p.spacesAvailable)
                .addBatch());
        try {
            insert.execute();
        } catch (QueryException e) {
            // XXX: upsert would be a better way to ignore primary key conflicts, but this shall do for now
            log.error("Failed save prediction history for predictor {}", predictorId, e);
        }
    }

    /**
     * Reads the prediction of one utilization key for the given time.
     *
     * @return the batch, or empty when no row matches or the matching row has no value for the
     *         time slot
     */
    @TransactionalRead
    @Override
    public Optional<PredictionBatch> getPrediction(UtilizationKey utilizationKey, DateTime time) {
        return asOptional(queryFactory
                .from(qPrediction)
                .select(predictionMapping(time))
                .where(qPrediction.facilityId.eq(utilizationKey.facilityId),
                        qPrediction.capacityType.eq(utilizationKey.capacityType),
                        qPrediction.usage.eq(utilizationKey.usage))
                .where(isWithinPredictionWindow(time))
                .fetchOne());
    }

    /**
     * Reads the predictions of all lookup rows of a facility for the given time. Unlike
     * {@link #getPrediction}, batches without a value for the time slot are returned too
     * (with an unchanged, possibly empty, prediction list).
     */
    @TransactionalRead
    @Override
    public List<PredictionBatch> getPredictionsByFacility(Long facilityId, DateTime time) {
        return queryFactory
                .from(qPrediction)
                .select(predictionMapping(time))
                .where(qPrediction.facilityId.eq(facilityId))
                .where(isWithinPredictionWindow(time))
                .fetch();
    }

    /**
     * Reads the history rows of a predictor with the given forecast distance whose timestamp is
     * between {@code start} and {@code end}, ordered by ascending timestamp.
     */
    @TransactionalRead
    @Override
    public List<Prediction> getPredictionHistoryByPredictor(Long predictorId, DateTime start, DateTime end, int forecastDistanceInMinutes) {
        return queryFactory.from(qPredictionHistory)
                .select(historyToPredictionMapping())
                .where(qPredictionHistory.predictorId.eq(predictorId),
                        qPredictionHistory.forecastDistanceInMinutes.eq(forecastDistanceInMinutes),
                        qPredictionHistory.ts.between(start, end))
                .orderBy(qPredictionHistory.ts.asc())
                .fetch();
    }

    /** Projects a history row to a {@link Prediction} built from its timestamp and spaces available. */
    private Expression<Prediction> historyToPredictionMapping() {
        return Projections.constructor(Prediction.class, qPredictionHistory.ts, qPredictionHistory.spacesAvailable);
    }

    /**
     * Condition selecting lookup rows whose start lies between
     * {@code rounded(time) - PREDICTION_WINDOW + PREDICTION_RESOLUTION} and {@code rounded(time)}.
     */
    private static BooleanExpression isWithinPredictionWindow(DateTime time) {
        DateTime rounded = toPredictionResolution(time);
        return qPrediction.start.between(rounded.minus(PREDICTION_WINDOW).plus(PREDICTION_RESOLUTION), rounded);
    }

    /**
     * Projection from a lookup row to a {@link PredictionBatch} that contains at most one
     * prediction: the value of the column for the (rounded) given time, if that value is not null.
     */
    private static MappingProjection<PredictionBatch> predictionMapping(DateTime timeWithFullPrecision) {
        DateTime time = toPredictionResolution(timeWithFullPrecision);
        Path<Integer> spacesAvailableColumn = spacesAvailableAt(time);

        return new MappingProjection<PredictionBatch>(PredictionBatch.class,
                qPrediction.facilityId,
                qPrediction.capacityType,
                qPrediction.usage,
                qPrediction.start,
                spacesAvailableColumn) {
            @Override
            protected PredictionBatch map(Tuple row) {
                PredictionBatch pb = new PredictionBatch();
                pb.utilizationKey = new UtilizationKey(
                        row.get(qPrediction.facilityId),
                        row.get(qPrediction.capacityType),
                        row.get(qPrediction.usage)
                );
                pb.sourceTimestamp = row.get(qPrediction.start);
                Integer spacesAvailable = row.get(spacesAvailableColumn);
                if (spacesAvailable != null) {
                    pb.predictions.add(new Prediction(time, spacesAvailable));
                }
                return pb;
            }
        };
    }

    /**
     * Returns the lookup-table column that holds the prediction for the given timestamp, chosen by
     * the timestamp's UTC time of day formatted as {@code HHmm}.
     *
     * @param timestamp a timestamp already rounded to the prediction resolution
     * @throws IllegalArgumentException if the lookup table has no column for that time of day
     */
    private static Path<Integer> spacesAvailableAt(DateTime timestamp) {
        // Also other parts of this class assume prediction resolution,
        // so we don't do the rounding here, but require the timestamp
        // to already have been properly rounded.
        assert timestamp.equals(toPredictionResolution(timestamp)) : "not in prediction resolution: " + timestamp;

        String hhmm = DateTimeFormat.forPattern("HHmm").print(timestamp.withZone(DateTimeZone.UTC));
        Path<Integer> column = spacesAvailableColumnsByHHmm.get(hhmm);
        if (column == null) {
            // Fail here with context instead of handing a null path to the query builder.
            throw new IllegalArgumentException("no spacesAvailableAt column for " + hhmm + " (timestamp " + timestamp + ")");
        }
        return column;
    }

    /** Rounds the time to the prediction resolution using {@link TimeUtil#roundMinutes}. */
    static DateTime toPredictionResolution(DateTime time) {
        return TimeUtil.roundMinutes(PREDICTION_RESOLUTION.getMinutes(), time);
    }

    /** Wraps the batch in an {@link Optional}; {@code null} or a batch without predictions becomes empty. */
    private static Optional<PredictionBatch> asOptional(PredictionBatch pb) {
        if (pb == null || pb.predictions.isEmpty()) {
            return Optional.empty();
        } else {
            return Optional.of(pb);
        }
    }

    /**
     * Narrows a path to {@code Path<Integer>} after checking its declared type.
     *
     * @throws ClassCastException if the path's type is not {@link Integer}
     */
    @SuppressWarnings("unchecked")
    private static Path<Integer> castToIntegerPath(Path<?> path) {
        if (path.getType().equals(Integer.class)) {
            return (Path<Integer>) path;
        }
        throw new ClassCastException(path + " has type " + path.getType());
    }
}