/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.operator.aggregation; import com.facebook.presto.spi.function.AccumulatorStateFactory; import com.facebook.presto.spi.function.AccumulatorStateSerializer; import com.google.common.base.Throwables; import java.lang.reflect.Constructor; import java.util.List; import java.util.Optional; import static java.util.Objects.requireNonNull; public class GenericAccumulatorFactoryBinder implements AccumulatorFactoryBinder { private final AccumulatorStateSerializer<?> stateSerializer; private final AccumulatorStateFactory<?> stateFactory; private final Constructor<? extends Accumulator> accumulatorConstructor; private final Constructor<? extends GroupedAccumulator> groupedAccumulatorConstructor; public GenericAccumulatorFactoryBinder( AccumulatorStateSerializer<?> stateSerializer, AccumulatorStateFactory<?> stateFactory, Class<? extends Accumulator> accumulatorClass, Class<? extends GroupedAccumulator> groupedAccumulatorClass) { this.stateSerializer = requireNonNull(stateSerializer, "stateSerializer is null"); this.stateFactory = requireNonNull(stateFactory, "stateFactory is null"); try { accumulatorConstructor = accumulatorClass.getConstructor( AccumulatorStateSerializer.class, AccumulatorStateFactory.class, List.class, Optional.class); groupedAccumulatorConstructor = groupedAccumulatorClass.getConstructor( AccumulatorStateSerializer.class, AccumulatorStateFactory.class, List.class, Optional.class); } catch (NoSuchMethodException e) { throw Throwables.propagate(e); } } @Override public AccumulatorFactory bind(List<Integer> argumentChannels, Optional<Integer> maskChannel) { return new GenericAccumulatorFactory(stateSerializer, stateFactory, accumulatorConstructor, groupedAccumulatorConstructor, argumentChannels, maskChannel); } }