/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.common.operators.base;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.ExecutionConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.api.common.aggregators.Aggregator;
import org.apache.flink.api.common.aggregators.AggregatorRegistry;
import org.apache.flink.api.common.aggregators.ConvergenceCriterion;
import org.apache.flink.api.common.functions.AbstractRichFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.operators.IterationOperator;
import org.apache.flink.api.common.operators.Operator;
import org.apache.flink.api.common.operators.OperatorInformation;
import org.apache.flink.api.common.operators.SingleInputOperator;
import org.apache.flink.api.common.operators.UnaryOperatorInformation;
import org.apache.flink.api.common.operators.util.UserCodeClassWrapper;
import org.apache.flink.api.common.operators.util.UserCodeWrapper;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.types.LongValue;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Visitor;
/**
 * Operator that represents a bulk iteration in the operator data flow.
 *
 * <p>The iteration is described by three pieces that are held by this class:
 * <ul>
 *   <li>a place-holder operator ({@link #getPartialSolution()}) standing in for the partial
 *       solution that is fed into the step function,</li>
 *   <li>the operator producing the next partial solution
 *       ({@link #setNextPartialSolution(Operator)}), and</li>
 *   <li>a termination condition, given as a maximum number of iterations
 *       ({@link #setMaximumNumberOfIterations(int)}) and/or a termination criterion operator
 *       ({@link #setTerminationCriterion(Operator)}).</li>
 * </ul>
 *
 * <p>Broadcast inputs are not supported by this operator.
 *
 * @param <T> The type of the partial solution; both input and output type of the iteration.
 */
@Internal
public class BulkIterationBase<T> extends SingleInputOperator<T, T, AbstractRichFunction> implements IterationOperator {

    /** Name used when the iteration is created without an explicit name. */
    private static final String DEFAULT_NAME = "<Unnamed Bulk Iteration>";

    /** Name under which the termination criterion aggregator is registered and looked up. */
    public static final String TERMINATION_CRITERION_AGGREGATOR_NAME = "terminationCriterion.aggregator";

    /** Operator producing the next partial solution; {@code null} until explicitly set. */
    private Operator<T> iterationResult;

    /** Place-holder handed out as the partial solution input of the step function. */
    private final Operator<T> inputPlaceHolder;

    private final AggregatorRegistry aggregators = new AggregatorRegistry();

    /** Maximum number of iterations; values {@code <= 0} are treated as "not set" by {@link #validate()}. */
    private int numberOfIterations = -1;

    /** Wrapped termination criterion operator; {@code null} if no criterion was set. */
    protected Operator<?> terminationCriterion;

    // --------------------------------------------------------------------------------------------

    /**
     * Creates a bulk iteration with the default name.
     *
     * @param operatorInfo The operator information passed on to the super class.
     */
    public BulkIterationBase(UnaryOperatorInformation<T, T> operatorInfo) {
        this(operatorInfo, DEFAULT_NAME);
    }

    /**
     * Creates a bulk iteration with the given name and sets up the partial solution place-holder.
     *
     * @param operatorInfo The operator information passed on to the super class.
     * @param name The name of the iteration operator.
     */
    public BulkIterationBase(UnaryOperatorInformation<T, T> operatorInfo, String name) {
        super(new UserCodeClassWrapper<AbstractRichFunction>(AbstractRichFunction.class), operatorInfo, name);
        inputPlaceHolder = new PartialSolutionPlaceHolder<T>(this, this.getOperatorInfo());
    }

    // --------------------------------------------------------------------------------------------

    /**
     * @return The operator representing the partial solution.
     */
    public Operator<T> getPartialSolution() {
        return this.inputPlaceHolder;
    }

    /**
     * Sets the operator that produces the next partial solution.
     *
     * @param result The operator producing the next partial solution.
     * @throws NullPointerException If {@code result} is {@code null}.
     */
    public void setNextPartialSolution(Operator<T> result) {
        if (result == null) {
            throw new NullPointerException("Operator producing the next partial solution must not be null.");
        }
        this.iterationResult = result;
    }

    /**
     * @return The operator representing the next partial solution.
     */
    public Operator<T> getNextPartialSolution() {
        return this.iterationResult;
    }

    /**
     * @return The operator representing the termination criterion.
     */
    public Operator<?> getTerminationCriterion() {
        return this.terminationCriterion;
    }

    /**
     * Sets the termination criterion of this iteration.
     *
     * <p>The given operator is not stored directly. It becomes the input of a flat-map operator
     * running a {@link TerminationCriterionMapper}, and that wrapper is stored as the termination
     * criterion. In addition, a {@link TerminationCriterionAggregator} together with a
     * {@link TerminationCriterionAggregationConvergence} is registered under
     * {@link #TERMINATION_CRITERION_AGGREGATOR_NAME}.
     *
     * @param criterion The operator whose output drives the termination decision.
     * @param <X> The output type of the criterion operator.
     * @throws NullPointerException If {@code criterion} is {@code null}.
     */
    public <X> void setTerminationCriterion(Operator<X> criterion) {
        // Fail with a descriptive message, consistent with setNextPartialSolution(...).
        if (criterion == null) {
            throw new NullPointerException("Operator representing the termination criterion must not be null.");
        }

        TypeInformation<X> type = criterion.getOperatorInfo().getOutputType();

        FlatMapOperatorBase<X, X, TerminationCriterionMapper<X>> mapper =
                new FlatMapOperatorBase<X, X, TerminationCriterionMapper<X>>(
                        new TerminationCriterionMapper<X>(),
                        new UnaryOperatorInformation<X, X>(type, type),
                        "Termination Criterion Aggregation Wrapper");
        mapper.setInput(criterion);

        this.terminationCriterion = mapper;
        this.getAggregators().registerAggregationConvergenceCriterion(
                TERMINATION_CRITERION_AGGREGATOR_NAME,
                new TerminationCriterionAggregator(),
                new TerminationCriterionAggregationConvergence());
    }

    /**
     * Sets the maximum number of iterations.
     *
     * @param num The maximum number of iterations; must be at least one.
     * @throws IllegalArgumentException If {@code num} is smaller than one.
     */
    public void setMaximumNumberOfIterations(int num) {
        if (num < 1) {
            throw new IllegalArgumentException("The number of iterations must be at least one.");
        }
        this.numberOfIterations = num;
    }

    /**
     * @return The maximum number of iterations, or {@code -1} if it has never been set.
     */
    public int getMaximumNumberOfIterations() {
        return this.numberOfIterations;
    }

    @Override
    public AggregatorRegistry getAggregators() {
        return this.aggregators;
    }

    /**
     * Checks that the iteration is completely specified: the initial input, the operator
     * producing the next partial solution, and at least one termination condition (maximum
     * number of iterations or termination criterion) must be set.
     *
     * @throws InvalidProgramException If any of the required parts is missing.
     */
    public void validate() throws InvalidProgramException {
        if (this.input == null) {
            // Report all validation failures with the declared exception type.
            throw new InvalidProgramException("Operator for initial partial solution is not set.");
        }
        if (this.iterationResult == null) {
            throw new InvalidProgramException("Operator producing the next version of the partial " +
                    "solution (iteration result) is not set.");
        }
        if (this.terminationCriterion == null && this.numberOfIterations <= 0) {
            throw new InvalidProgramException("No termination condition is set " +
                    "(neither fix number of iteration nor termination criterion).");
        }
    }

    /**
     * The BulkIteration meta operator cannot have broadcast inputs.
     *
     * @return An empty map.
     */
    public Map<String, Operator<?>> getBroadcastInputs() {
        return Collections.emptyMap();
    }

    /**
     * The BulkIteration meta operator cannot have broadcast inputs.
     * This method always throws an exception.
     *
     * @param name Ignored.
     * @param root Ignored.
     */
    public void setBroadcastVariable(String name, Operator<?> root) {
        throw new UnsupportedOperationException("The BulkIteration meta operator cannot have broadcast inputs.");
    }

    /**
     * The BulkIteration meta operator cannot have broadcast inputs.
     * This method always throws an exception.
     *
     * @param inputs Ignored
     */
    public <X> void setBroadcastVariables(Map<String, Operator<X>> inputs) {
        throw new UnsupportedOperationException("The BulkIteration meta operator cannot have broadcast inputs.");
    }

    // --------------------------------------------------------------------------------------------

    /**
     * Specialized operator to use as a recognizable place-holder for the input to the
     * step function when composing the nested data flow.
     *
     * @param <OT> The output type of the place-holder, i.e. the type of the partial solution.
     */
    public static class PartialSolutionPlaceHolder<OT> extends Operator<OT> {

        /** The iteration this place-holder belongs to. */
        private final BulkIterationBase<OT> containingIteration;

        /**
         * @param container The bulk iteration that owns this place-holder.
         * @param operatorInfo The operator information passed on to the super class.
         */
        public PartialSolutionPlaceHolder(BulkIterationBase<OT> container, OperatorInformation<OT> operatorInfo) {
            super(operatorInfo, "Partial Solution");
            this.containingIteration = container;
        }

        /**
         * @return The bulk iteration that owns this place-holder.
         */
        public BulkIterationBase<OT> getContainingBulkIteration() {
            return this.containingIteration;
        }

        /**
         * Visits only this operator: the pre-visit is immediately followed by the post-visit,
         * without descending into any other operator.
         */
        @Override
        public void accept(Visitor<Operator<?>> visitor) {
            visitor.preVisit(this);
            visitor.postVisit(this);
        }

        /**
         * @return Always {@code null}; the place-holder carries no user code.
         */
        @Override
        public UserCodeWrapper<?> getUserCodeWrapper() {
            return null;
        }
    }

    /**
     * Special mapper that is added before a termination criterion and is only a container for
     * a special aggregator. It counts its input records in the aggregator and emits nothing.
     *
     * @param <X> The type of the records produced by the termination criterion operator.
     */
    public static class TerminationCriterionMapper<X> extends AbstractRichFunction implements FlatMapFunction<X, X> {

        private static final long serialVersionUID = 1L;

        /** Aggregator looked up in {@link #open(Configuration)}. */
        private TerminationCriterionAggregator aggregator;

        /** Fetches the aggregator registered under {@link #TERMINATION_CRITERION_AGGREGATOR_NAME}. */
        @Override
        public void open(Configuration parameters) {
            aggregator = getIterationRuntimeContext().getIterationAggregator(TERMINATION_CRITERION_AGGREGATOR_NAME);
        }

        /** Adds one to the aggregator for every incoming record; the collector is never used. */
        @Override
        public void flatMap(X in, Collector<X> out) {
            aggregator.aggregate(1L);
        }
    }

    /**
     * Aggregator that basically only adds 1 for every output tuple of the termination criterion
     * branch.
     */
    @SuppressWarnings("serial")
    public static class TerminationCriterionAggregator implements Aggregator<LongValue> {

        /** Running sum of all aggregated values since the last {@link #reset()}. */
        private long count;

        /**
         * @return A new {@link LongValue} holding the current count.
         */
        @Override
        public LongValue getAggregate() {
            return new LongValue(count);
        }

        /**
         * Adds the given primitive value to the current count.
         *
         * @param count The value to add.
         */
        public void aggregate(long count) {
            this.count += count;
        }

        /**
         * Adds the value wrapped by the given {@link LongValue} to the current count.
         *
         * @param count The value to add.
         */
        @Override
        public void aggregate(LongValue count) {
            this.count += count.getValue();
        }

        /** Sets the count back to zero. */
        @Override
        public void reset() {
            count = 0;
        }
    }

    /**
     * Convergence for the termination criterion is reached if no tuple is output at the current
     * iteration for the termination criterion branch.
     */
    public static class TerminationCriterionAggregationConvergence implements ConvergenceCriterion<LongValue> {

        private static final long serialVersionUID = 1L;

        private static final Logger log = LoggerFactory.getLogger(TerminationCriterionAggregationConvergence.class);

        /**
         * Logs the aggregated count at INFO level and reports convergence when it is zero.
         *
         * @param iteration The iteration number, used only for the log message.
         * @param countAggregate The aggregated count for the iteration.
         * @return {@code true} if the count is zero, {@code false} otherwise.
         */
        @Override
        public boolean isConverged(int iteration, LongValue countAggregate) {
            long count = countAggregate.getValue();

            // Parameterized logging: the message is only formatted if INFO is enabled.
            log.info("Termination criterion stats in iteration [{}]: {}", iteration, count);

            return count == 0;
        }
    }

    /**
     * Not supported for the bulk iteration operator.
     *
     * @throws UnsupportedOperationException Always.
     */
    @Override
    protected List<T> executeOnCollections(List<T> inputData, RuntimeContext runtimeContext, ExecutionConfig executionConfig) {
        throw new UnsupportedOperationException();
    }
}