/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.rakam.postgresql.analysis;
import com.facebook.presto.sql.tree.Expression;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import io.netty.handler.codec.http.HttpResponseStatus;
import org.rakam.analysis.metadata.Metastore;
import org.rakam.collection.SchemaField;
import org.rakam.config.ProjectConfig;
import org.rakam.postgresql.report.PostgresqlQueryExecutor;
import org.rakam.report.AbstractRetentionQueryExecutor;
import org.rakam.report.DelegateQueryExecution;
import org.rakam.report.QueryExecution;
import org.rakam.report.QueryResult;
import org.rakam.util.RakamException;
import org.rakam.util.ValidationUtil;
import javax.annotation.PostConstruct;
import javax.inject.Inject;
import java.sql.Connection;
import java.sql.SQLException;
import java.time.LocalDate;
import java.time.ZoneId;
import java.time.temporal.ChronoUnit;
import java.time.temporal.TemporalField;
import java.time.temporal.WeekFields;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import static com.facebook.presto.sql.RakamSqlFormatter.formatExpression;
import static com.google.common.primitives.Ints.checkedCast;
import static java.lang.String.format;
import static java.util.Locale.ENGLISH;
import static org.rakam.analysis.RetentionQueryExecutor.DateUnit.MONTH;
import static org.rakam.analysis.RetentionQueryExecutor.DateUnit.WEEK;
import static org.rakam.collection.FieldType.INTEGER;
import static org.rakam.util.DateTimeUtils.TIMESTAMP_FORMATTER;
import static org.rakam.util.ValidationUtil.checkArgument;
import static org.rakam.util.ValidationUtil.checkCollection;
import static org.rakam.util.ValidationUtil.checkTableColumn;
public class PostgresqlRetentionQueryExecutor
extends AbstractRetentionQueryExecutor
{
private final PostgresqlQueryExecutor executor;
private final Metastore metastore;
private final ProjectConfig projectConfig;
@Inject
public PostgresqlRetentionQueryExecutor(ProjectConfig projectConfig, PostgresqlQueryExecutor executor, Metastore metastore)
{
this.executor = executor;
this.metastore = metastore;
this.projectConfig = projectConfig;
}
@PostConstruct
public void setup()
{
try (Connection conn = executor.getConnection()) {
try {
conn.createStatement().execute(V8_RETENTION_FUNCTIONS);
}
catch (SQLException e) {
// plv8 is not available, fallback to pl/pgsql
conn.createStatement().execute(PL_PGGSQL_RETENTION_TIMELINE_FUNCTION);
conn.createStatement().execute(PL_PGGSQL_RETENTION_AGGREGATE_FUNCTION);
}
try {
conn.createStatement().execute("CREATE AGGREGATE collect_retention(boolean[])\n" +
"(\n" +
" sfunc = analyze_retention_intermediate,\n" +
" stype = integer[],\n" +
" initcond = '{}'\n" +
")");
}
catch (SQLException e) {
if (!e.getSQLState().equals("42723")) {
throw Throwables.propagate(e);
}
}
try {
conn.createStatement().execute(String.format("CREATE TYPE retention_action AS " +
"(is_first boolean, %s timestamp)",
checkTableColumn(projectConfig.getTimeColumn())));
}
catch (SQLException e) {
if (!e.getSQLState().equals("42710")) {
throw Throwables.propagate(e);
}
}
}
catch (SQLException e) {
throw Throwables.propagate(e);
}
}
@Override
public QueryExecution query(String project, Optional<RetentionAction> firstAction,
Optional<RetentionAction> returningAction, DateUnit dateUnit,
Optional<String> dimension, Optional<Integer> period,
LocalDate startDate, LocalDate endDate, ZoneId zoneId, boolean approximate)
{
period.ifPresent(e -> checkArgument(e >= 0, "Period must be 0 or a positive value"));
checkTableColumn(CONNECTOR_FIELD, "connector field", '"');
if (approximate) {
// TODO: should we throw an exception or just show a warning?
// throw new RakamException("Approximation is not supported.", HttpResponseStatus.BAD_REQUEST);
}
String timeColumn = getTimeExpression(dateUnit);
LocalDate start;
LocalDate end;
if (dateUnit == MONTH) {
start = startDate.withDayOfMonth(1);
end = endDate.withDayOfMonth(1).plus(1, ChronoUnit.MONTHS);
}
else if (dateUnit == WEEK) {
TemporalField fieldUS = WeekFields.of(Locale.US).dayOfWeek();
start = startDate.with(fieldUS, 1);
end = endDate.with(fieldUS, 1).plus(1, ChronoUnit.MONTHS);
}
else {
start = startDate;
end = endDate;
}
Optional<Integer> range = period.map(v ->
Math.min(v, checkedCast(dateUnit.getTemporalUnit().between(start, end))));
if (range.isPresent() && range.get() < 0) {
throw new IllegalArgumentException("startDate must be before endDate.");
}
if (range.isPresent() && range.get() < 0) {
return QueryExecution.completedQueryExecution(null, QueryResult.empty());
}
Map<String, List<SchemaField>> collections = metastore.getCollections(project);
String firstActionQuery = generateQuery(
collections, project, firstAction,
testDeviceIdExists(firstAction, collections) ? format("coalesce(cast(%s as varchar), _device_id) as %s", CONNECTOR_FIELD, checkTableColumn(CONNECTOR_FIELD)) : CONNECTOR_FIELD,
dimension, startDate, endDate, zoneId);
String returningActionQuery = generateQuery(
collections, project, returningAction,
testDeviceIdExists(firstAction, collections) ? format("coalesce(cast(%s as varchar), _device_id) as %s", CONNECTOR_FIELD, checkTableColumn(CONNECTOR_FIELD)) : CONNECTOR_FIELD,
dimension, startDate, endDate, zoneId);
String query;
if (firstAction.equals(returningAction)) {
query = format("select %s, collect_retention(bits) from (\n" +
"select %s, (case when (not is_first %s) then \n" +
"generate_timeline(%s, timeline, %d::bigint, 15) else null end) as bits from (\n" +
"select %s %s, true as is_first, %s as timeline from (%s) t group by 1 %s " +
"UNION ALL " +
"select %s %s, false as is_first, array_agg(%s::date order by %s::date) as timeline from (%s) t group by 1 %s) t\n" +
"%s \n" +
") t\n" +
"group by 1 order by 1 asc",
dimension.map(v -> "dimension").orElse("date"),
dimension.map(v -> "dimension").orElse("date"),
// do not check dimension value for first action dates.
dimension.map(v -> "").orElse("and dates.date = any(timeline)"),
dimension.map(v -> "timeline[1]").orElse("dates.date::date"),
dateUnit.getTemporalUnit().getDuration().toMillis(),
dimension.map(v -> "dimension").map(v -> v + ", ").orElse(""),
CONNECTOR_FIELD,
// if we're calculating by dimension, take the first event data for each user
dimension.map(val -> format("array[min(%s)]", format(timeColumn, projectConfig.getTimeColumn())))
.orElseGet(() -> format("array_agg(%s::date order by %s::date)", format(timeColumn, projectConfig.getTimeColumn()), format(timeColumn, projectConfig.getTimeColumn()))),
firstActionQuery,
dimension.map(v -> ", 2").orElse(""),
dimension.map(v -> "dimension").map(v -> v + ", ").orElse(""),
CONNECTOR_FIELD,
format(timeColumn, projectConfig.getTimeColumn()), format(timeColumn, projectConfig.getTimeColumn()),
returningActionQuery,
dimension.map(v -> ", 2").orElse(""),
dimension.map(v -> "").orElseGet(() -> String.format("cross join (select generate_series(date_trunc('%s', date '%s'), date_trunc('%s', date '%s'), interval '1 %s')::date date) dates",
dateUnit.name().toLowerCase(ENGLISH), TIMESTAMP_FORMATTER.format(startDate.atStartOfDay(zoneId)),
dateUnit.name().toLowerCase(ENGLISH), TIMESTAMP_FORMATTER.format(endDate.atStartOfDay(zoneId)),
dateUnit.name().toLowerCase(ENGLISH))));
}
else {
query = format("select %s, collect_retention(bits) from (\n" +
"select %s, (case when (%s) then \n" +
"generate_timeline(%s, ret.timeline, %d::bigint, 15) else null end) as bits from (\n" +
"select %s %s, %s as timeline from (%s) t group by 1 %s " +
") first left join ( " +
"select %s %s, array_agg(%s::date order by %s::date) as timeline from (%s) t group by 1 %s \n" +
") ret on (ret._user = first._user %s)\n" +
"%s \n" +
") t\n" +
"group by 1 order by 1 asc",
dimension.map(v -> "dimension").orElse("date"),
dimension.map(v -> "first.dimension").orElse("date"),
// do not check dimension value for first action dates.
dimension.map(v -> "true").orElse("dates.date = any(ret.timeline)"),
dimension.map(v -> "first.timeline[1]").orElse("dates.date::date"),
dateUnit.getTemporalUnit().getDuration().toMillis(),
dimension.map(v -> "dimension").map(v -> v + ", ").orElse(""),
CONNECTOR_FIELD,
// if we're calculating by dimension, take the first event data for each user
dimension.map(val -> format("array[min(%s)]", format(timeColumn, projectConfig.getTimeColumn())))
.orElseGet(() -> format("array_agg(%s::date order by %s::date)", format(timeColumn, projectConfig.getTimeColumn()), format(timeColumn, projectConfig.getTimeColumn()))),
firstActionQuery,
dimension.map(v -> ", 2").orElse(""),
dimension.map(v -> "dimension").map(v -> v + ", ").orElse(""),
CONNECTOR_FIELD,
format(timeColumn, projectConfig.getTimeColumn()), format(timeColumn, projectConfig.getTimeColumn()),
returningActionQuery,
dimension.map(v -> ", 2").orElse(""),
dimension.map(v -> " and first.dimension = ret.dimension").orElse(""),
dimension.map(v -> "").orElseGet(() -> String.format("cross join (select generate_series(date_trunc('%s', date '%s'), date_trunc('%s', date '%s'), interval '1 %s')::date date) dates",
dateUnit.name().toLowerCase(ENGLISH), TIMESTAMP_FORMATTER.format(startDate.atStartOfDay(zoneId)),
dateUnit.name().toLowerCase(ENGLISH), TIMESTAMP_FORMATTER.format(endDate.atStartOfDay(zoneId)),
dateUnit.name().toLowerCase(ENGLISH))));
}
return new DelegateQueryExecution(executor.executeRawQuery(query), (result) -> {
if (result.isFailed()) {
return result;
}
ArrayList<List<Object>> rows = new ArrayList<>();
for (List<Object> objects : result.getResult()) {
Object date = objects.get(0);
Integer[] days = (Integer[]) objects.get(1);
for (int i = 0; i < days.length; i++) {
if (days[i] != null) {
rows.add(Arrays.asList(date, i == 0 ? null : ((long) i - 1), (long) days[i]));
}
}
}
return new QueryResult(ImmutableList.of(
new SchemaField("dimension", result.getMetadata().get(0).getType()),
new SchemaField("lead", INTEGER),
new SchemaField("value", INTEGER)), rows, result.getProperties());
});
}
private String generateQuery(
Map<String, List<SchemaField>> collections,
String project,
Optional<RetentionAction> retentionAction,
String connectorField,
Optional<String> dimension,
LocalDate startDate,
LocalDate endDate,
ZoneId zoneId)
{
String timePredicate = format("between timestamp '%s' and timestamp '%s' + interval '1' day",
TIMESTAMP_FORMATTER.format(startDate.atStartOfDay(zoneId)),
TIMESTAMP_FORMATTER.format(endDate.atStartOfDay(zoneId)));
if (!retentionAction.isPresent()) {
if (!collections.entrySet().stream().anyMatch(e -> e.getValue().stream().anyMatch(s -> s.getName().equals("_user")))) {
return format("select %s, %s null as %s",
checkTableColumn(projectConfig.getTimeColumn()),
dimension.isPresent() ? checkTableColumn(dimension.get(), "dimension", '"') + " as dimension, " : "",
connectorField);
}
return collections.entrySet().stream()
.filter(entry -> entry.getValue().stream().anyMatch(e -> e.getName().equals("_user")))
.map(collection -> getTableSubQuery(project, collection.getKey(),
connectorField,
dimension, timePredicate, Optional.empty()))
.collect(Collectors.joining(" union all "));
}
else {
String collection = retentionAction.get().collection();
return getTableSubQuery(
project, collection,
connectorField,
dimension, timePredicate,
retentionAction.get().filter());
}
}
protected String getTableSubQuery(String project, String collection,
String connectorField,
Optional<String> dimension,
String timePredicate,
Optional<Expression> filter)
{
return format("select %s, %s %s from %s where %s %s %s",
checkTableColumn(projectConfig.getTimeColumn()),
dimension.isPresent() ? checkTableColumn(dimension.get(), "dimension", '"') + " as dimension, " : "",
connectorField,
project + "." + checkCollection(collection),
checkTableColumn(projectConfig.getTimeColumn()),
timePredicate,
filter.isPresent() ? "and " + formatExpression(filter.get(), reference -> {
throw new UnsupportedOperationException();
}, '"') : "");
}
private static String PL_PGGSQL_RETENTION_AGGREGATE_FUNCTION = "create or replace function analyze_retention_intermediate(arr integer[], ff boolean[]) returns integer[] volatile language plpgsql as $$\n" +
"DECLARE \n" +
" i int;\n" +
"begin\n" +
" if ff is null then\n" +
" return arr;\n" +
" end if;\n" +
" \n" +
"FOR i IN 1 .. array_upper(ff, 1)\n" +
" LOOP\n" +
" if ff[i] = true then \n" +
" if arr[i] is null then \n" +
" arr[i] := 1;\n" +
" ELSE\n" +
" arr[i] := (arr[i]+1);\n" +
" end if;\n" +
" end if;\n" +
"END LOOP;\n" +
"return arr;\n" +
"END\n" +
"$$;";
private static String PL_PGGSQL_RETENTION_TIMELINE_FUNCTION =
"create or replace function generate_timeline(start date, arr date[], durationmillis bigint, max_step integer) returns boolean[] volatile language plpgsql as $$\n" +
"DECLARE \n" +
" steps boolean[];\n" +
" value int;\n" +
" gap int;\n" +
" item date;\n" +
"BEGIN\n" +
" if arr is null then\n" +
" return null;\n" +
" end if;\n" +
"\n" +
"-- substracting dates returns an integer that represents date diff\n" +
"durationmillis := durationmillis / 86400000; \n" +
"FOREACH item IN ARRAY arr\n" +
"LOOP\n" +
" value := (item - start);\n" +
"\n" +
" if value < 0 then \n" +
" continue; \n" +
" end if;\n" +
" \n" +
" gap := value / durationmillis;\n" +
" if gap > max_step then \n" +
" EXIT; \n" +
" end if;\n" +
"\n" +
" if steps is null then\n" +
" steps = cast(ARRAY[] as boolean[]);\n" +
" --steps = new Array(Math.min(max_step, arr.length))\n" +
" end if;\n" +
"\n" +
" steps[gap+1] := true;\n" +
"\n" +
" END LOOP;\n" +
"\n" +
" return steps;\n" +
"END\n" +
"$$;";
private static String V8_RETENTION_FUNCTIONS = "CREATE EXTENSION IF NOT EXISTS plv8;\n" +
"create or replace function analyze_retention_intermediate(arr integer[], ff boolean[]) returns integer[] volatile language plv8 as $$\n" +
"\n" +
" if(ff == null) return arr;\n" +
" for(var i=0; i <= ff.length; i++) {\n" +
" if(ff[i] === true) {\n" +
" arr[i] = arr[i] == null ? 1 : (arr[i]+1)\n" +
" }\n" +
" }\n" +
"\n" +
" return arr;\n" +
"$$;" +
"create or replace function generate_timeline(start date, arr date[], durationmillis bigint, max_step integer) returns boolean[] volatile language plv8 as $$\n" +
"\n" +
" if(arr == null) {\n" +
" return null;\n" +
" }\n" +
" var steps = null;\n" +
"\n" +
" for(var i=0; i <= arr.length; i++) {\n" +
" var value = (arr[i]-start);\n" +
" if(value < 0) continue;\n" +
" \n" +
" var gap = value / durationmillis;\n" +
" if(gap > max_step) { \n" +
"\tbreak; \n" +
" }\n" +
"\n" +
" if(steps == null) {\n" +
" steps = new Array(Math.min(max_step, arr.length))\n" +
" }\n" +
"\n" +
" //plv8.elog(ERROR, gap, value, start, durationmillis);\n" +
" steps[gap] = true;\n" +
" }\n" +
"\n" +
" return steps;\n" +
"$$;";
}