/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.raptor.systemtables; import com.facebook.presto.spi.predicate.Domain; import com.facebook.presto.spi.predicate.Range; import com.facebook.presto.spi.predicate.TupleDomain; import com.facebook.presto.spi.type.Type; import com.google.common.base.Joiner; import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.SQLException; import java.sql.Types; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Set; import static com.facebook.presto.raptor.util.DatabaseUtil.enableStreamingResults; import static com.facebook.presto.raptor.util.UuidUtil.uuidToBytes; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.spi.type.VarbinaryType.VARBINARY; import static com.facebook.presto.spi.type.Varchars.isVarcharType; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Strings.isNullOrEmpty; import static com.google.common.collect.Iterables.getOnlyElement; import static java.lang.String.format; import static java.sql.ResultSet.CONCUR_READ_ONLY; import static 
java.sql.ResultSet.TYPE_FORWARD_ONLY;
import static java.util.Collections.nCopies;
import static java.util.UUID.fromString;

/**
 * Builds JDBC {@link PreparedStatement}s whose WHERE clause is derived from a {@link TupleDomain}
 * keyed by column index, binding every referenced value as a statement parameter.
 */
public final class PreparedStatementBuilder
{
    private PreparedStatementBuilder() {}

    /**
     * Prepares {@code sql} with a WHERE clause generated from {@code tupleDomain} appended to it,
     * and binds all values referenced by that clause.
     * <p>
     * The statement is prepared {@code TYPE_FORWARD_ONLY} / {@code CONCUR_READ_ONLY} and passed to
     * {@code DatabaseUtil.enableStreamingResults}. If anything fails after the statement has been
     * prepared, the statement is closed before the exception propagates.
     *
     * @param connection connection used to prepare the statement
     * @param sql base query; must not be null or empty. A clause starting with {@code WHERE} may be
     *     appended to it, so it must not already end in a WHERE clause
     * @param columnNames SQL column names, indexed by the keys of {@code tupleDomain}
     * @param types column types, indexed the same way as {@code columnNames}
     * @param uuidColumnIndexes indexes of columns whose constraint values are UUID strings that are
     *     converted with {@code uuidToBytes} and bound as bytes
     * @param tupleDomain constraint to translate into the WHERE clause
     * @return the prepared statement with all parameters bound; the caller is responsible for closing it
     * @throws IllegalArgumentException if {@code sql} is null or empty
     * @throws SQLException if preparing the statement or binding a parameter fails
     */
    public static PreparedStatement create(
            Connection connection,
            String sql,
            List<String> columnNames,
            List<Type> types,
            Set<Integer> uuidColumnIndexes,
            TupleDomain<Integer> tupleDomain)
            throws SQLException
    {
        checkArgument(!isNullOrEmpty(sql), "sql is null or empty");

        List<ValueBuffer> bindValues = new ArrayList<>(256);
        sql += getWhereClause(tupleDomain, columnNames, types, uuidColumnIndexes, bindValues);

        PreparedStatement statement = connection.prepareStatement(sql, TYPE_FORWARD_ONLY, CONCUR_READ_ONLY);
        try {
            enableStreamingResults(statement);

            // JDBC parameter indexes are 1-based; bindValues is in the same order as the '?' placeholders
            int bindIndex = 1;
            for (ValueBuffer value : bindValues) {
                bindField(value, statement, bindIndex, uuidColumnIndexes.contains(value.getColumnIndex()));
                bindIndex++;
            }
            return statement;
        }
        catch (SQLException | RuntimeException e) {
            // The caller never receives the statement on failure, so close it here to avoid leaking it
            try {
                statement.close();
            }
            catch (SQLException closeException) {
                e.addSuppressed(closeException);
            }
            throw e;
        }
    }

    /**
     * Translates {@code tupleDomain} into a WHERE clause, appending a {@link ValueBuffer} to
     * {@code bindValues} for every {@code ?} placeholder, in placeholder order.
     *
     * @return {@code ""} when there is nothing to filter on; otherwise a clause starting with a
     *     newline followed by {@code WHERE}
     */
    @SuppressWarnings("OptionalGetWithoutIsPresent")
    private static String getWhereClause(
            TupleDomain<Integer> tupleDomain,
            List<String> columnNames,
            List<Type> types,
            Set<Integer> uuidColumnIndexes,
            List<ValueBuffer> bindValues)
    {
        if (tupleDomain.isNone()) {
            // The constraint is unsatisfiable: no row can match, so don't fetch any
            return "\nWHERE FALSE";
        }

        ImmutableList.Builder<String> conjunctsBuilder = ImmutableList.builder();
        // Present because isNone() was handled above
        Map<Integer, Domain> domainMap = tupleDomain.getDomains().get();
        for (Map.Entry<Integer, Domain> entry : domainMap.entrySet()) {
            int index = entry.getKey();
            String columnName = columnNames.get(index);
            Type type = types.get(index);
            conjunctsBuilder.add(toPredicate(index, columnName, type, entry.getValue(), uuidColumnIndexes, bindValues));
        }
        List<String> conjuncts = conjunctsBuilder.build();

        if (conjuncts.isEmpty()) {
            return "";
        }
        // Leading newline separates the clause from the base query regardless of how it ends
        StringBuilder where = new StringBuilder("\nWHERE ");
        return Joiner.on(" AND\n").appendTo(where, conjuncts).toString();
    }

    /**
     * Builds the SQL predicate for a single column's {@link Domain}.
     * <p>
     * Every {@code ?} placeholder in the returned string has a matching entry appended to
     * {@code bindValues}, in the same left-to-right order.
     */
    private static String toPredicate(
            int columnIndex,
            String columnName,
            Type type,
            Domain domain,
            Set<Integer> uuidColumnIndexes,
            List<ValueBuffer> bindValues)
    {
        if (domain.getValues().isAll()) {
            return domain.isNullAllowed() ? "TRUE" : columnName + " IS NOT NULL";
        }
        if (domain.getValues().isNone()) {
            return domain.isNullAllowed() ? columnName + " IS NULL" : "FALSE";
        }

        return domain.getValues().getValuesProcessor().transform(
                ranges -> {
                    List<String> disjuncts = new ArrayList<>();
                    List<Object> singleValues = new ArrayList<>();

                    // Add disjuncts for ranges; single-value ranges are collected and emitted together below
                    for (Range range : ranges.getOrderedRanges()) {
                        checkState(!range.isAll()); // Already checked
                        if (range.isSingleValue()) {
                            singleValues.add(range.getLow().getValue());
                        }
                        else {
                            List<String> rangeConjuncts = new ArrayList<>();
                            if (!range.getLow().isLowerUnbounded()) {
                                Object bindValue = getBindValue(columnIndex, uuidColumnIndexes, range.getLow().getValue());
                                switch (range.getLow().getBound()) {
                                    case ABOVE:
                                        rangeConjuncts.add(toBindPredicate(columnName, ">"));
                                        bindValues.add(ValueBuffer.create(columnIndex, type, bindValue));
                                        break;
                                    case EXACTLY:
                                        rangeConjuncts.add(toBindPredicate(columnName, ">="));
                                        bindValues.add(ValueBuffer.create(columnIndex, type, bindValue));
                                        break;
                                    case BELOW:
                                        throw new VerifyException("Low Marker should never use BELOW bound");
                                    default:
                                        throw new AssertionError("Unhandled bound: " + range.getLow().getBound());
                                }
                            }
                            if (!range.getHigh().isUpperUnbounded()) {
                                Object bindValue = getBindValue(columnIndex, uuidColumnIndexes, range.getHigh().getValue());
                                switch (range.getHigh().getBound()) {
                                    case ABOVE:
                                        throw new VerifyException("High Marker should never use ABOVE bound");
                                    case EXACTLY:
                                        rangeConjuncts.add(toBindPredicate(columnName, "<="));
                                        bindValues.add(ValueBuffer.create(columnIndex, type, bindValue));
                                        break;
                                    case BELOW:
                                        rangeConjuncts.add(toBindPredicate(columnName, "<"));
                                        bindValues.add(ValueBuffer.create(columnIndex, type, bindValue));
                                        break;
                                    default:
                                        throw new AssertionError("Unhandled bound: " + range.getHigh().getBound());
                                }
                            }
                            // An empty rangeConjuncts means both ends were unbounded (an ALL range), which was rejected above
                            checkState(!rangeConjuncts.isEmpty());
                            disjuncts.add("(" + Joiner.on(" AND ").join(rangeConjuncts) + ")");
                        }
                    }

                    // Add back all of the possible single values either as an equality or an IN predicate
                    if (singleValues.size() == 1) {
                        disjuncts.add(toBindPredicate(columnName, "="));
                        bindValues.add(ValueBuffer.create(columnIndex, type, getBindValue(columnIndex, uuidColumnIndexes, getOnlyElement(singleValues))));
                    }
                    else if (singleValues.size() > 1) {
                        disjuncts.add(columnName + " IN (" + Joiner.on(",").join(nCopies(singleValues.size(), "?")) + ")");
                        for (Object singleValue : singleValues) {
                            bindValues.add(ValueBuffer.create(columnIndex, type, getBindValue(columnIndex, uuidColumnIndexes, singleValue)));
                        }
                    }

                    // Add nullability disjuncts (no placeholders, so no bind values)
                    checkState(!disjuncts.isEmpty());
                    if (domain.isNullAllowed()) {
                        disjuncts.add(columnName + " IS NULL");
                    }

                    return "(" + Joiner.on(" OR ").join(disjuncts) + ")";
                },

                discreteValues -> {
                    String values = Joiner.on(",").join(nCopies(discreteValues.getValues().size(), "?"));
                    String predicate = columnName + (discreteValues.isWhiteList() ? "" : " NOT") + " IN (" + values + ")";
                    for (Object value : discreteValues.getValues()) {
                        bindValues.add(ValueBuffer.create(columnIndex, type, getBindValue(columnIndex, uuidColumnIndexes, value)));
                    }
                    if (domain.isNullAllowed()) {
                        predicate = "(" + predicate + " OR " + columnName + " IS NULL)";
                    }
                    return predicate;
                },

                allOrNone -> {
                    // isAll() and isNone() were both handled before dispatching here
                    throw new IllegalStateException("Case should not be reachable");
                });
    }

    /**
     * Returns the value to bind for a constraint value of the given column.
     * For UUID columns, {@code value} is cast to {@link Slice}, read as a UTF-8 UUID string,
     * parsed and converted with {@code uuidToBytes}. Other values are returned unchanged.
     */
    private static Object getBindValue(int columnIndex, Set<Integer> uuidColumnIndexes, Object value)
    {
        if (uuidColumnIndexes.contains(columnIndex)) {
            return uuidToBytes(fromString(((Slice) value).toStringUtf8()));
        }
        return value;
    }

    /** Returns {@code "<columnName> <operator> ?"}. */
    private static String toBindPredicate(String columnName, String operator)
    {
        return format("%s %s ?", columnName, operator);
    }

    /**
     * Binds one value to {@code preparedStatement} at {@code parameterIndex}, choosing the JDBC
     * setter from the value's Java type.
     *
     * @param isUuid whether the value belongs to a UUID column; such Slice values are bound as raw bytes
     * @throws IllegalArgumentException if the type's Java representation is not supported
     */
    private static void bindField(ValueBuffer valueBuffer, PreparedStatement preparedStatement, int parameterIndex, boolean isUuid)
            throws SQLException
    {
        Type type = valueBuffer.getType();

        if (valueBuffer.isNull()) {
            preparedStatement.setNull(parameterIndex, typeToSqlType(type));
        }
        else if (type.getJavaType() == long.class) {
            preparedStatement.setLong(parameterIndex, valueBuffer.getLong());
        }
        else if (type.getJavaType() == double.class) {
            preparedStatement.setDouble(parameterIndex, valueBuffer.getDouble());
        }
        else if (type.getJavaType() == boolean.class) {
            preparedStatement.setBoolean(parameterIndex, valueBuffer.getBoolean());
        }
        else if (type.getJavaType() == Slice.class && isUuid) {
            preparedStatement.setBytes(parameterIndex, valueBuffer.getSlice().getBytes());
        }
        else if (type.getJavaType() == Slice.class) {
            // Decode explicitly as UTF-8 rather than with the JVM's platform-default charset
            preparedStatement.setString(parameterIndex, valueBuffer.getSlice().toStringUtf8());
        }
        else {
            throw new IllegalArgumentException("Unknown Java type: " + type.getJavaType());
        }
    }

    /**
     * Maps a column type to the {@link Types} code passed to {@code setNull}.
     *
     * @throws IllegalArgumentException if the type is not BIGINT, DOUBLE, BOOLEAN, a varchar type, or VARBINARY
     */
    private static int typeToSqlType(Type type)
    {
        if (type.equals(BIGINT)) {
            return Types.BIGINT;
        }
        if (type.equals(DOUBLE)) {
            return Types.DOUBLE;
        }
        if (type.equals(BOOLEAN)) {
            return Types.BOOLEAN;
        }
        if (isVarcharType(type)) {
            return Types.VARCHAR;
        }
        if (type.equals(VARBINARY)) {
            return Types.VARBINARY;
        }
        throw new IllegalArgumentException("Unknown type: " + type);
    }
}