/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.api.common.aggregators; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.Map; import org.apache.flink.annotation.Internal; import org.apache.flink.types.Value; /** * A registry for iteration {@link Aggregator}s. */ @Internal public class AggregatorRegistry { private final Map<String, Aggregator<?>> registry = new HashMap<String, Aggregator<?>>(); private ConvergenceCriterion<? extends Value> convergenceCriterion; private String convergenceCriterionAggregatorName; // -------------------------------------------------------------------------------------------- public void registerAggregator(String name, Aggregator<?> aggregator) { if (name == null || aggregator == null) { throw new IllegalArgumentException("Name and aggregator must not be null"); } if (this.registry.containsKey(name)) { throw new RuntimeException("An aggregator is already registered under the given name."); } this.registry.put(name, aggregator); } public Collection<AggregatorWithName<?>> getAllRegisteredAggregators() { ArrayList<AggregatorWithName<?>> list = new ArrayList<AggregatorWithName<?>>(this.registry.size()); for (Map.Entry<String, Aggregator<?>> entry : this.registry.entrySet()) { @SuppressWarnings("unchecked") Aggregator<Value> valAgg = (Aggregator<Value>) entry.getValue(); list.add(new AggregatorWithName<>(entry.getKey(), valAgg)); } return list; } public <T extends Value> void registerAggregationConvergenceCriterion( String name, Aggregator<T> aggregator, ConvergenceCriterion<T> convergenceCheck) { if (name == null || aggregator == null || convergenceCheck == null) { throw new IllegalArgumentException("Name, aggregator, or convergence criterion must not be null"); } Aggregator<?> genAgg = aggregator; Aggregator<?> previous = this.registry.get(name); if (previous != null && previous != genAgg) { throw new RuntimeException("An aggregator is already registered under the given name."); } this.registry.put(name, genAgg); this.convergenceCriterion = convergenceCheck; this.convergenceCriterionAggregatorName = name; } public String getConvergenceCriterionAggregatorName() { return this.convergenceCriterionAggregatorName; } public ConvergenceCriterion<?> getConvergenceCriterion() { return this.convergenceCriterion; } public void addAll(AggregatorRegistry registry) { this.registry.putAll(registry.registry); this.convergenceCriterion = registry.convergenceCriterion; this.convergenceCriterionAggregatorName = registry.convergenceCriterionAggregatorName; } }