/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.cassandra.cql3.functions; import java.nio.ByteBuffer; import java.util.*; import com.google.common.base.Objects; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.cassandra.db.marshal.AbstractType; import org.apache.cassandra.exceptions.InvalidRequestException; import org.apache.cassandra.schema.Functions; import org.apache.cassandra.tracing.Tracing; /** * Base class for user-defined-aggregates. */ public class UDAggregate extends AbstractFunction implements AggregateFunction { protected static final Logger logger = LoggerFactory.getLogger(UDAggregate.class); protected final AbstractType<?> stateType; protected final ByteBuffer initcond; private final ScalarFunction stateFunction; private final ScalarFunction finalFunction; public UDAggregate(FunctionName name, List<AbstractType<?>> argTypes, AbstractType<?> returnType, ScalarFunction stateFunc, ScalarFunction finalFunc, ByteBuffer initcond) { super(name, argTypes, returnType); this.stateFunction = stateFunc; this.finalFunction = finalFunc; this.stateType = stateFunc != null ? stateFunc.returnType() : null; this.initcond = initcond; } public static UDAggregate create(Functions functions, FunctionName name, List<AbstractType<?>> argTypes, AbstractType<?> returnType, FunctionName stateFunc, FunctionName finalFunc, AbstractType<?> stateType, ByteBuffer initcond) throws InvalidRequestException { List<AbstractType<?>> stateTypes = new ArrayList<>(argTypes.size() + 1); stateTypes.add(stateType); stateTypes.addAll(argTypes); List<AbstractType<?>> finalTypes = Collections.<AbstractType<?>>singletonList(stateType); return new UDAggregate(name, argTypes, returnType, resolveScalar(functions, name, stateFunc, stateTypes), finalFunc != null ? resolveScalar(functions, name, finalFunc, finalTypes) : null, initcond); } public static UDAggregate createBroken(FunctionName name, List<AbstractType<?>> argTypes, AbstractType<?> returnType, ByteBuffer initcond, final InvalidRequestException reason) { return new UDAggregate(name, argTypes, returnType, null, null, initcond) { public Aggregate newAggregate() throws InvalidRequestException { throw new InvalidRequestException(String.format("Aggregate '%s' exists but hasn't been loaded successfully for the following reason: %s. " + "Please see the server log for more details", this, reason.getMessage())); } }; } public boolean hasReferenceTo(Function function) { return stateFunction == function || finalFunction == function; } @Override public void addFunctionsTo(List<Function> functions) { functions.add(this); if (stateFunction != null) { stateFunction.addFunctionsTo(functions); if (finalFunction != null) finalFunction.addFunctionsTo(functions); } } public boolean isAggregate() { return true; } public boolean isNative() { return false; } public ScalarFunction stateFunction() { return stateFunction; } public ScalarFunction finalFunction() { return finalFunction; } public ByteBuffer initialCondition() { return initcond; } public AbstractType<?> stateType() { return stateType; } public Aggregate newAggregate() throws InvalidRequestException { return new Aggregate() { private long stateFunctionCount; private long stateFunctionDuration; private ByteBuffer state; { reset(); } public void addInput(int protocolVersion, List<ByteBuffer> values) throws InvalidRequestException { long startTime = System.nanoTime(); stateFunctionCount++; List<ByteBuffer> fArgs = new ArrayList<>(values.size() + 1); fArgs.add(state); fArgs.addAll(values); if (stateFunction instanceof UDFunction) { UDFunction udf = (UDFunction)stateFunction; if (udf.isCallableWrtNullable(fArgs)) state = udf.execute(protocolVersion, fArgs); } else { state = stateFunction.execute(protocolVersion, fArgs); } stateFunctionDuration += (System.nanoTime() - startTime) / 1000; } public ByteBuffer compute(int protocolVersion) throws InvalidRequestException { // final function is traced in UDFunction Tracing.trace("Executed UDA {}: {} call(s) to state function {} in {}\u03bcs", name(), stateFunctionCount, stateFunction.name(), stateFunctionDuration); if (finalFunction == null) return state; List<ByteBuffer> fArgs = Collections.singletonList(state); ByteBuffer result = finalFunction.execute(protocolVersion, fArgs); return result; } public void reset() { state = initcond != null ? initcond.duplicate() : null; stateFunctionDuration = 0; stateFunctionCount = 0; } }; } private static ScalarFunction resolveScalar(Functions functions, FunctionName aName, FunctionName fName, List<AbstractType<?>> argTypes) throws InvalidRequestException { Optional<Function> fun = functions.find(fName, argTypes); if (!fun.isPresent()) throw new InvalidRequestException(String.format("Referenced state function '%s %s' for aggregate '%s' does not exist", fName, Arrays.toString(UDHelper.driverTypes(argTypes)), aName)); if (!(fun.get() instanceof ScalarFunction)) throw new InvalidRequestException(String.format("Referenced state function '%s %s' for aggregate '%s' is not a scalar function", fName, Arrays.toString(UDHelper.driverTypes(argTypes)), aName)); return (ScalarFunction) fun.get(); } @Override public boolean equals(Object o) { if (!(o instanceof UDAggregate)) return false; UDAggregate that = (UDAggregate) o; return Objects.equal(name, that.name) && Functions.typesMatch(argTypes, that.argTypes) && Functions.typesMatch(returnType, that.returnType) && Objects.equal(stateFunction, that.stateFunction) && Objects.equal(finalFunction, that.finalFunction) && Objects.equal(stateType, that.stateType) && Objects.equal(initcond, that.initcond); } @Override public int hashCode() { return Objects.hashCode(name, Functions.typeHashCode(argTypes), Functions.typeHashCode(returnType), stateFunction, finalFunction, stateType, initcond); } }