/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.cassandra.cql3.functions; import java.math.BigDecimal; import java.math.BigInteger; import java.math.RoundingMode; import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Collection; import java.util.List; import org.apache.cassandra.cql3.CQL3Type; import org.apache.cassandra.db.marshal.*; import org.apache.cassandra.exceptions.InvalidRequestException; import org.apache.cassandra.transport.ProtocolVersion; /** * Factory methods for aggregate functions. */ public abstract class AggregateFcts { public static Collection<AggregateFunction> all() { Collection<AggregateFunction> functions = new ArrayList<>(); functions.add(countRowsFunction); // sum for primitives functions.add(sumFunctionForByte); functions.add(sumFunctionForShort); functions.add(sumFunctionForInt32); functions.add(sumFunctionForLong); functions.add(sumFunctionForFloat); functions.add(sumFunctionForDouble); functions.add(sumFunctionForDecimal); functions.add(sumFunctionForVarint); functions.add(sumFunctionForCounter); // avg for primitives functions.add(avgFunctionForByte); functions.add(avgFunctionForShort); functions.add(avgFunctionForInt32); functions.add(avgFunctionForLong); functions.add(avgFunctionForFloat); functions.add(avgFunctionForDouble); functions.add(avgFunctionForDecimal); functions.add(avgFunctionForVarint); functions.add(avgFunctionForCounter); // count, max, and min for all standard types for (CQL3Type type : CQL3Type.Native.values()) { if (type != CQL3Type.Native.VARCHAR) // varchar and text both mapping to UTF8Type { functions.add(AggregateFcts.makeCountFunction(type.getType())); if (type != CQL3Type.Native.COUNTER) { functions.add(AggregateFcts.makeMaxFunction(type.getType())); functions.add(AggregateFcts.makeMinFunction(type.getType())); } else { functions.add(AggregateFcts.maxFunctionForCounter); functions.add(AggregateFcts.minFunctionForCounter); } } } return functions; } /** * The function used to count the number of rows of a result set. This function is called when COUNT(*) or COUNT(1) * is specified. */ public static final AggregateFunction countRowsFunction = new NativeAggregateFunction("countRows", LongType.instance) { public Aggregate newAggregate() { return new Aggregate() { private long count; public void reset() { count = 0; } public ByteBuffer compute(ProtocolVersion protocolVersion) { return LongType.instance.decompose(count); } public void addInput(ProtocolVersion protocolVersion, List<ByteBuffer> values) { count++; } }; } @Override public String columnName(List<String> columnNames) { return "count"; } }; /** * The SUM function for decimal values. */ public static final AggregateFunction sumFunctionForDecimal = new NativeAggregateFunction("sum", DecimalType.instance, DecimalType.instance) { @Override public Aggregate newAggregate() { return new Aggregate() { private BigDecimal sum = BigDecimal.ZERO; public void reset() { sum = BigDecimal.ZERO; } public ByteBuffer compute(ProtocolVersion protocolVersion) { return ((DecimalType) returnType()).decompose(sum); } public void addInput(ProtocolVersion protocolVersion, List<ByteBuffer> values) { ByteBuffer value = values.get(0); if (value == null) return; BigDecimal number = DecimalType.instance.compose(value); sum = sum.add(number); } }; } }; /** * The AVG function for decimal values. */ public static final AggregateFunction avgFunctionForDecimal = new NativeAggregateFunction("avg", DecimalType.instance, DecimalType.instance) { public Aggregate newAggregate() { return new Aggregate() { private BigDecimal avg = BigDecimal.ZERO; private int count; public void reset() { count = 0; avg = BigDecimal.ZERO; } public ByteBuffer compute(ProtocolVersion protocolVersion) { return DecimalType.instance.decompose(avg); } public void addInput(ProtocolVersion protocolVersion, List<ByteBuffer> values) { ByteBuffer value = values.get(0); if (value == null) return; count++; BigDecimal number = DecimalType.instance.compose(value); // avg = avg + (value - sum) / count. avg = avg.add(number.subtract(avg).divide(BigDecimal.valueOf(count), RoundingMode.HALF_EVEN)); } }; } }; /** * The SUM function for varint values. */ public static final AggregateFunction sumFunctionForVarint = new NativeAggregateFunction("sum", IntegerType.instance, IntegerType.instance) { public Aggregate newAggregate() { return new Aggregate() { private BigInteger sum = BigInteger.ZERO; public void reset() { sum = BigInteger.ZERO; } public ByteBuffer compute(ProtocolVersion protocolVersion) { return ((IntegerType) returnType()).decompose(sum); } public void addInput(ProtocolVersion protocolVersion, List<ByteBuffer> values) { ByteBuffer value = values.get(0); if (value == null) return; BigInteger number = IntegerType.instance.compose(value); sum = sum.add(number); } }; } }; /** * The AVG function for varint values. */ public static final AggregateFunction avgFunctionForVarint = new NativeAggregateFunction("avg", IntegerType.instance, IntegerType.instance) { public Aggregate newAggregate() { return new Aggregate() { private BigInteger sum = BigInteger.ZERO; private int count; public void reset() { count = 0; sum = BigInteger.ZERO; } public ByteBuffer compute(ProtocolVersion protocolVersion) { if (count == 0) return IntegerType.instance.decompose(BigInteger.ZERO); return IntegerType.instance.decompose(sum.divide(BigInteger.valueOf(count))); } public void addInput(ProtocolVersion protocolVersion, List<ByteBuffer> values) { ByteBuffer value = values.get(0); if (value == null) return; count++; BigInteger number = ((BigInteger) argTypes().get(0).compose(value)); sum = sum.add(number); } }; } }; /** * The SUM function for byte values (tinyint). */ public static final AggregateFunction sumFunctionForByte = new NativeAggregateFunction("sum", ByteType.instance, ByteType.instance) { public Aggregate newAggregate() { return new Aggregate() { private byte sum; public void reset() { sum = 0; } public ByteBuffer compute(ProtocolVersion protocolVersion) { return ((ByteType) returnType()).decompose(sum); } public void addInput(ProtocolVersion protocolVersion, List<ByteBuffer> values) { ByteBuffer value = values.get(0); if (value == null) return; Number number = ((Number) argTypes().get(0).compose(value)); sum += number.byteValue(); } }; } }; /** * AVG function for byte values (tinyint). */ public static final AggregateFunction avgFunctionForByte = new NativeAggregateFunction("avg", ByteType.instance, ByteType.instance) { public Aggregate newAggregate() { return new AvgAggregate(ByteType.instance) { public ByteBuffer compute(ProtocolVersion protocolVersion) throws InvalidRequestException { return ByteType.instance.decompose((byte) computeInternal()); } }; } }; /** * The SUM function for short values (smallint). */ public static final AggregateFunction sumFunctionForShort = new NativeAggregateFunction("sum", ShortType.instance, ShortType.instance) { public Aggregate newAggregate() { return new Aggregate() { private short sum; public void reset() { sum = 0; } public ByteBuffer compute(ProtocolVersion protocolVersion) { return ((ShortType) returnType()).decompose(sum); } public void addInput(ProtocolVersion protocolVersion, List<ByteBuffer> values) { ByteBuffer value = values.get(0); if (value == null) return; Number number = ((Number) argTypes().get(0).compose(value)); sum += number.shortValue(); } }; } }; /** * AVG function for for short values (smallint). */ public static final AggregateFunction avgFunctionForShort = new NativeAggregateFunction("avg", ShortType.instance, ShortType.instance) { public Aggregate newAggregate() { return new AvgAggregate(ShortType.instance) { public ByteBuffer compute(ProtocolVersion protocolVersion) { return ShortType.instance.decompose((short) computeInternal()); } }; } }; /** * The SUM function for int32 values. */ public static final AggregateFunction sumFunctionForInt32 = new NativeAggregateFunction("sum", Int32Type.instance, Int32Type.instance) { public Aggregate newAggregate() { return new Aggregate() { private int sum; public void reset() { sum = 0; } public ByteBuffer compute(ProtocolVersion protocolVersion) { return ((Int32Type) returnType()).decompose(sum); } public void addInput(ProtocolVersion protocolVersion, List<ByteBuffer> values) { ByteBuffer value = values.get(0); if (value == null) return; Number number = ((Number) argTypes().get(0).compose(value)); sum += number.intValue(); } }; } }; /** * AVG function for int32 values. */ public static final AggregateFunction avgFunctionForInt32 = new NativeAggregateFunction("avg", Int32Type.instance, Int32Type.instance) { public Aggregate newAggregate() { return new AvgAggregate(Int32Type.instance) { public ByteBuffer compute(ProtocolVersion protocolVersion) { return Int32Type.instance.decompose((int) computeInternal()); } }; } }; /** * The SUM function for long values. */ public static final AggregateFunction sumFunctionForLong = new NativeAggregateFunction("sum", LongType.instance, LongType.instance) { public Aggregate newAggregate() { return new LongSumAggregate(); } }; /** * AVG function for long values. */ public static final AggregateFunction avgFunctionForLong = new NativeAggregateFunction("avg", LongType.instance, LongType.instance) { public Aggregate newAggregate() { return new AvgAggregate(LongType.instance) { public ByteBuffer compute(ProtocolVersion protocolVersion) { return LongType.instance.decompose(computeInternal()); } }; } }; /** * The SUM function for float values. */ public static final AggregateFunction sumFunctionForFloat = new NativeAggregateFunction("sum", FloatType.instance, FloatType.instance) { public Aggregate newAggregate() { return new FloatSumAggregate(FloatType.instance) { public ByteBuffer compute(ProtocolVersion protocolVersion) throws InvalidRequestException { return FloatType.instance.decompose((float) computeInternal()); } }; } }; /** * AVG function for float values. */ public static final AggregateFunction avgFunctionForFloat = new NativeAggregateFunction("avg", FloatType.instance, FloatType.instance) { public Aggregate newAggregate() { return new FloatAvgAggregate(FloatType.instance) { public ByteBuffer compute(ProtocolVersion protocolVersion) throws InvalidRequestException { return FloatType.instance.decompose((float) computeInternal()); } }; } }; /** * The SUM function for double values. */ public static final AggregateFunction sumFunctionForDouble = new NativeAggregateFunction("sum", DoubleType.instance, DoubleType.instance) { public Aggregate newAggregate() { return new FloatSumAggregate(DoubleType.instance) { public ByteBuffer compute(ProtocolVersion protocolVersion) throws InvalidRequestException { return DoubleType.instance.decompose(computeInternal()); } }; } }; /** * Sum aggregate function for floating point numbers, using double arithmetics and * Kahan's algorithm to improve result precision. */ private static abstract class FloatSumAggregate implements AggregateFunction.Aggregate { private double sum; private double compensation; private double simpleSum; private final AbstractType numberType; public FloatSumAggregate(AbstractType numberType) { this.numberType = numberType; } public void reset() { sum = 0; compensation = 0; simpleSum = 0; } public void addInput(ProtocolVersion protocolVersion, List<ByteBuffer> values) { ByteBuffer value = values.get(0); if (value == null) return; double number = ((Number) numberType.compose(value)).doubleValue(); simpleSum += number; double tmp = number - compensation; double rounded = sum + tmp; compensation = (rounded - sum) - tmp; sum = rounded; } public double computeInternal() { // correctly compute final sum if it's NaN from consequently // adding same-signed infinite values. double tmp = sum + compensation; if (Double.isNaN(tmp) && Double.isInfinite(simpleSum)) return simpleSum; else return tmp; } } /** * Average aggregate for floating point umbers, using double arithmetics and Kahan's algorithm * to calculate sum by default, switching to BigDecimal on sum overflow. Resulting number is * converted to corresponding representation by concrete implementations. */ private static abstract class FloatAvgAggregate implements AggregateFunction.Aggregate { private double sum; private double compensation; private double simpleSum; private int count; private BigDecimal bigSum = null; private boolean overflow = false; private final AbstractType numberType; public FloatAvgAggregate(AbstractType numberType) { this.numberType = numberType; } public void reset() { sum = 0; compensation = 0; simpleSum = 0; count = 0; bigSum = null; overflow = false; } public double computeInternal() { if (count == 0) return 0d; if (overflow) { return bigSum.divide(BigDecimal.valueOf(count), RoundingMode.HALF_EVEN).doubleValue(); } else { // correctly compute final sum if it's NaN from consequently // adding same-signed infinite values. double tmp = sum + compensation; if (Double.isNaN(tmp) && Double.isInfinite(simpleSum)) sum = simpleSum; else sum = tmp; return sum / count; } } public void addInput(ProtocolVersion protocolVersion, List<ByteBuffer> values) { ByteBuffer value = values.get(0); if (value == null) return; count++; double number = ((Number) numberType.compose(value)).doubleValue(); if (overflow) { bigSum = bigSum.add(BigDecimal.valueOf(number)); } else { simpleSum += number; double prev = sum; double tmp = number - compensation; double rounded = sum + tmp; compensation = (rounded - sum) - tmp; sum = rounded; if (Double.isInfinite(sum) && !Double.isInfinite(number)) { overflow = true; bigSum = BigDecimal.valueOf(prev).add(BigDecimal.valueOf(number)); } } } } /** * AVG function for double values. */ public static final AggregateFunction avgFunctionForDouble = new NativeAggregateFunction("avg", DoubleType.instance, DoubleType.instance) { public Aggregate newAggregate() { return new FloatAvgAggregate(DoubleType.instance) { public ByteBuffer compute(ProtocolVersion protocolVersion) throws InvalidRequestException { return DoubleType.instance.decompose(computeInternal()); } }; } }; /** * The SUM function for counter column values. */ public static final AggregateFunction sumFunctionForCounter = new NativeAggregateFunction("sum", CounterColumnType.instance, CounterColumnType.instance) { public Aggregate newAggregate() { return new LongSumAggregate(); } }; /** * AVG function for counter column values. */ public static final AggregateFunction avgFunctionForCounter = new NativeAggregateFunction("avg", CounterColumnType.instance, CounterColumnType.instance) { public Aggregate newAggregate() { return new AvgAggregate(LongType.instance) { public ByteBuffer compute(ProtocolVersion protocolVersion) throws InvalidRequestException { return CounterColumnType.instance.decompose(computeInternal()); } }; } }; /** * The MIN function for counter column values. */ public static final AggregateFunction minFunctionForCounter = new NativeAggregateFunction("min", CounterColumnType.instance, CounterColumnType.instance) { public Aggregate newAggregate() { return new Aggregate() { private Long min; public void reset() { min = null; } public ByteBuffer compute(ProtocolVersion protocolVersion) { return min != null ? LongType.instance.decompose(min) : null; } public void addInput(ProtocolVersion protocolVersion, List<ByteBuffer> values) { ByteBuffer value = values.get(0); if (value == null) return; long lval = LongType.instance.compose(value); if (min == null || lval < min) min = lval; } }; } }; /** * MAX function for counter column values. */ public static final AggregateFunction maxFunctionForCounter = new NativeAggregateFunction("max", CounterColumnType.instance, CounterColumnType.instance) { public Aggregate newAggregate() { return new Aggregate() { private Long max; public void reset() { max = null; } public ByteBuffer compute(ProtocolVersion protocolVersion) { return max != null ? LongType.instance.decompose(max) : null; } public void addInput(ProtocolVersion protocolVersion, List<ByteBuffer> values) { ByteBuffer value = values.get(0); if (value == null) return; long lval = LongType.instance.compose(value); if (max == null || lval > max) max = lval; } }; } }; /** * Creates a MAX function for the specified type. * * @param inputType the function input and output type * @return a MAX function for the specified type. */ public static AggregateFunction makeMaxFunction(final AbstractType<?> inputType) { return new NativeAggregateFunction("max", inputType, inputType) { public Aggregate newAggregate() { return new Aggregate() { private ByteBuffer max; public void reset() { max = null; } public ByteBuffer compute(ProtocolVersion protocolVersion) { return max; } public void addInput(ProtocolVersion protocolVersion, List<ByteBuffer> values) { ByteBuffer value = values.get(0); if (value == null) return; if (max == null || returnType().compare(max, value) < 0) max = value; } }; } }; } /** * Creates a MIN function for the specified type. * * @param inputType the function input and output type * @return a MIN function for the specified type. */ public static AggregateFunction makeMinFunction(final AbstractType<?> inputType) { return new NativeAggregateFunction("min", inputType, inputType) { public Aggregate newAggregate() { return new Aggregate() { private ByteBuffer min; public void reset() { min = null; } public ByteBuffer compute(ProtocolVersion protocolVersion) { return min; } public void addInput(ProtocolVersion protocolVersion, List<ByteBuffer> values) { ByteBuffer value = values.get(0); if (value == null) return; if (min == null || returnType().compare(min, value) > 0) min = value; } }; } }; } /** * Creates a COUNT function for the specified type. * * @param inputType the function input type * @return a COUNT function for the specified type. */ public static AggregateFunction makeCountFunction(AbstractType<?> inputType) { return new NativeAggregateFunction("count", LongType.instance, inputType) { public Aggregate newAggregate() { return new Aggregate() { private long count; public void reset() { count = 0; } public ByteBuffer compute(ProtocolVersion protocolVersion) { return ((LongType) returnType()).decompose(count); } public void addInput(ProtocolVersion protocolVersion, List<ByteBuffer> values) { ByteBuffer value = values.get(0); if (value == null) return; count++; } }; } }; } private static class LongSumAggregate implements AggregateFunction.Aggregate { private long sum; public void reset() { sum = 0; } public ByteBuffer compute(ProtocolVersion protocolVersion) { return LongType.instance.decompose(sum); } public void addInput(ProtocolVersion protocolVersion, List<ByteBuffer> values) { ByteBuffer value = values.get(0); if (value == null) return; Number number = LongType.instance.compose(value); sum += number.longValue(); } } /** * Average aggregate class, collecting the sum using long arithmetics, falling back * to BigInteger on long overflow. Resulting number is converted to corresponding * representation by concrete implementations. */ private static abstract class AvgAggregate implements AggregateFunction.Aggregate { private long sum; private int count; private BigInteger bigSum = null; private boolean overflow = false; private final AbstractType numberType; public AvgAggregate(AbstractType type) { this.numberType = type; } public void reset() { count = 0; sum = 0L; overflow = false; bigSum = null; } long computeInternal() { if (overflow) { return bigSum.divide(BigInteger.valueOf(count)).longValue(); } else { return count == 0 ? 0 : (sum / count); } } public void addInput(ProtocolVersion protocolVersion, List<ByteBuffer> values) { ByteBuffer value = values.get(0); if (value == null) return; count++; long number = ((Number) numberType.compose(value)).longValue(); if (overflow) { bigSum = bigSum.add(BigInteger.valueOf(number)); } else { long prev = sum; sum += number; if (((prev ^ sum) & (number ^ sum)) < 0) { overflow = true; bigSum = BigInteger.valueOf(prev).add(BigInteger.valueOf(number)); } } } } }