/*
 * Licensed to CRATE Technology GmbH ("Crate") under one or more contributor
 * license agreements.  See the NOTICE file distributed with this work for
 * additional information regarding copyright ownership.  Crate licenses
 * this file to you under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.  You may
 * obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * However, if you have executed another commercial license agreement
 * with Crate these terms will supersede the license and you may use the
 * software solely pursuant to the terms of the relevant commercial agreement.
 */

package io.crate.analyze;

import io.crate.analyze.symbol.Literal;
import io.crate.analyze.symbol.Symbol;
import io.crate.data.Row;
import io.crate.sql.tree.ParameterExpression;
import io.crate.types.DataType;
import io.crate.types.DataTypes;

import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.function.Function;

/**
 * Holds the arguments supplied with a statement and resolves
 * {@link ParameterExpression}s against them.
 *
 * <p>Arguments come either as a single {@link Row} or as a list of bulk rows.
 * When bulk rows are present, {@link #setBulkIdx(int)} selects which of them
 * {@link #parameters()} (and therefore {@link #getAsSymbol(int)}) reads from.
 */
public class ParameterContext implements Function<ParameterExpression, Symbol> {

    /** Context without any parameters and without bulk parameters. */
    public static final ParameterContext EMPTY = new ParameterContext(Row.EMPTY, Collections.emptyList());

    private final Row parameters;
    private final List<Row> bulkParameters;

    /** Index into {@link #bulkParameters} of the row currently in use. */
    private int currentIdx = 0;

    /** Lazily created by {@link #typeHints()}. */
    private ParamTypeHints typeHints = null;

    /**
     * @param parameters     the arguments used when no bulk arguments are present
     * @param bulkParameters the bulk arguments; may be empty, and if not, all rows
     *                       must have the same number of columns
     * @throws IllegalArgumentException if the bulk rows differ in their number of columns
     */
    public ParameterContext(Row parameters, List<Row> bulkParameters) {
        this.parameters = parameters;
        if (!bulkParameters.isEmpty()) {
            validateBulkParams(bulkParameters);
        }
        this.bulkParameters = bulkParameters;
    }

    /**
     * Ensures every bulk row has as many columns as the first one.
     *
     * @param bulkParams non-empty list of bulk rows
     * @throws IllegalArgumentException if any row has a different number of columns
     */
    private static void validateBulkParams(List<Row> bulkParams) {
        int expectedColumns = bulkParams.get(0).numColumns();
        for (Row bulkParam : bulkParams) {
            if (bulkParam.numColumns() != expectedColumns) {
                throw new IllegalArgumentException("mixed number of arguments inside bulk arguments");
            }
        }
    }

    /**
     * Guesses the type of {@code value} via {@link DataTypes#guessType(Object)}.
     *
     * @throws IllegalArgumentException if no type could be guessed (the guess is {@code null})
     */
    private static DataType guessTypeSafe(Object value) throws IllegalArgumentException {
        DataType guessedType = DataTypes.guessType(value);
        if (guessedType == null) {
            throw new IllegalArgumentException(String.format(Locale.ENGLISH,
                "Got an argument \"%s\" that couldn't be recognized", value));
        }
        return guessedType;
    }

    /** @return {@code true} if at least one bulk row was supplied */
    public boolean hasBulkParams() {
        return !bulkParameters.isEmpty();
    }

    /** @return the number of bulk rows */
    public int numBulkParams() {
        return bulkParameters.size();
    }

    /**
     * Selects the bulk row returned by {@link #parameters()}.
     * The index is stored as-is; it is not range-checked here.
     */
    public void setBulkIdx(int i) {
        this.currentIdx = i;
    }

    /**
     * @return the currently selected bulk row if bulk rows are present,
     *         otherwise the single parameters row
     */
    public Row parameters() {
        if (hasBulkParams()) {
            return bulkParameters.get(currentIdx);
        }
        return parameters;
    }

    /**
     * Resolves the argument at {@code index} of the current {@link #parameters()} row
     * into a {@link Literal} of the guessed type.
     *
     * @param index position of the argument within the current row
     * @return a literal holding the (type-converted) argument value
     * @throws IllegalArgumentException if there is no argument at {@code index}, or if
     *                                  the type of the argument could not be guessed
     */
    public Literal getAsSymbol(int index) {
        try {
            Object value = parameters().get(index);
            DataType type = guessTypeSafe(value);
            // use type.value because some types need conversion (String to BytesRef, List to Array)
            return Literal.of(type, type.value(value));
        } catch (IndexOutOfBoundsException e) {
            // Catch the general IndexOutOfBoundsException (parent of ArrayIndexOutOfBoundsException):
            // a Row that is not backed by an array may report a missing position with the
            // general type, and that should produce the same descriptive error.
            throw new IllegalArgumentException(String.format(Locale.ENGLISH,
                "Tried to resolve a parameter but the arguments provided with the " +
                "SQLRequest don't contain a parameter at position %d", index), e);
        }
    }

    /**
     * Returns the type hints for the parameters, creating and caching them on first use
     * by guessing the type of each value in the single parameters row.
     *
     * <p>NOTE(review): only the single {@code parameters} row is inspected; bulk rows are
     * not consulted here — confirm this is intended for bulk requests.
     *
     * @throws IllegalArgumentException if the type of a parameter could not be guessed
     */
    public ParamTypeHints typeHints() {
        if (typeHints == null) {
            int numColumns = parameters.numColumns();
            List<DataType> types = new ArrayList<>(numColumns);
            for (int i = 0; i < numColumns; i++) {
                types.add(guessTypeSafe(parameters.get(i)));
            }
            typeHints = new ParamTypeHints(types);
        }
        return typeHints;
    }

    /**
     * Resolves a parameter expression to the literal at its index.
     *
     * @return the resolved symbol, or {@code null} if {@code input} is {@code null}
     */
    @Nullable
    @Override
    public Symbol apply(@Nullable ParameterExpression input) {
        if (input == null) {
            return null;
        }
        return getAsSymbol(input.index());
    }
}