/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.graph.gsa;
import org.apache.flink.api.common.aggregators.Aggregator;
import org.apache.flink.api.common.functions.FlatJoinFunction;
import org.apache.flink.api.common.functions.RichFlatJoinFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichReduceFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.functions.FunctionAnnotation.ForwardedFields;
import org.apache.flink.api.java.functions.FunctionAnnotation.ForwardedFieldsSecond;
import org.apache.flink.api.java.operators.CustomUnaryOperation;
import org.apache.flink.api.java.operators.DeltaIteration;
import org.apache.flink.api.java.operators.JoinOperator;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.operators.ReduceOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.graph.Edge;
import org.apache.flink.graph.EdgeDirection;
import org.apache.flink.graph.Vertex;
import org.apache.flink.graph.utils.GraphUtils;
import org.apache.flink.types.LongValue;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;
import java.util.Collection;
import java.util.Map;
/**
 * This class represents iterative graph computations, programmed in a gather-sum-apply perspective.
 *
 * <p>The computation is expressed as a delta iteration over the vertex data set. In every
 * superstep the current workset of vertices is joined with the edges to produce
 * (key, neighbor) pairs, each pair is mapped through the {@link GatherFunction}, the gathered
 * values are reduced per key with the {@link SumFunction}, and the reduced value is joined
 * with the solution set so that the {@link ApplyFunction} can emit updated vertices.
 *
 * <p>Instances are created with {@link #withEdges}, optionally configured with
 * {@link #configure(GSAConfiguration)}, and require the vertex input to be provided through
 * {@link #setInput(DataSet)} before {@link #createResult()} is called.
 *
 * @param <K> The type of the vertex key in the graph
 * @param <VV> The type of the vertex value in the graph
 * @param <EV> The type of the edge value in the graph
 * @param <M> The intermediate type used by the gather, sum and apply functions
 */
public class GatherSumApplyIteration<K, VV, EV, M> implements CustomUnaryOperation<Vertex<K, VV>,
        Vertex<K, VV>> {

    /** Name under which the vertex count is broadcast to the gather, sum and apply UDFs. */
    private static final String NUMBER_OF_VERTICES_BCAST_NAME = "number of vertices";

    /** Edge direction used when no configuration has been provided. */
    private static final EdgeDirection DEFAULT_DIRECTION = EdgeDirection.OUT;

    private DataSet<Vertex<K, VV>> vertexDataSet;

    private final DataSet<Edge<K, EV>> edgeDataSet;

    private final GatherFunction<VV, EV, M> gather;

    private final SumFunction<VV, EV, M> sum;

    private final ApplyFunction<K, VV, M> apply;

    private final int maximumNumberOfIterations;

    private GSAConfiguration configuration;

    // ----------------------------------------------------------------------------------

    /**
     * Creates the operator. All function and data set arguments must be non-null and the
     * iteration bound must be positive; the arguments are validated with {@link Preconditions}.
     *
     * @param gather The gather function of the GSA iteration
     * @param sum The sum function of the GSA iteration
     * @param apply The apply function of the GSA iteration
     * @param edges The edge DataSet
     * @param maximumNumberOfIterations The maximum number of iterations executed, at least one
     */
    private GatherSumApplyIteration(GatherFunction<VV, EV, M> gather, SumFunction<VV, EV, M> sum,
            ApplyFunction<K, VV, M> apply, DataSet<Edge<K, EV>> edges, int maximumNumberOfIterations) {

        Preconditions.checkNotNull(gather, "The gather function must not be null.");
        Preconditions.checkNotNull(sum, "The sum function must not be null.");
        Preconditions.checkNotNull(apply, "The apply function must not be null.");
        Preconditions.checkNotNull(edges, "The edge data set must not be null.");
        Preconditions.checkArgument(maximumNumberOfIterations > 0, "The maximum number of iterations must be at least one.");

        this.gather = gather;
        this.sum = sum;
        this.apply = apply;
        this.edgeDataSet = edges;
        this.maximumNumberOfIterations = maximumNumberOfIterations;
    }

    // --------------------------------------------------------------------------------------------
    //  Custom Operator behavior
    // --------------------------------------------------------------------------------------------

    /**
     * Sets the input data set for this operator. In the case of this operator this input data set represents
     * the set of vertices with their initial state.
     *
     * @param dataSet The input data set, which in the case of this operator represents the set of
     *                vertices with their initial state.
     */
    @Override
    public void setInput(DataSet<Vertex<K, VV>> dataSet) {
        this.vertexDataSet = dataSet;
    }

    /**
     * Computes the results of the gather-sum-apply iteration.
     *
     * <p>The vertex data set serves both as the initial solution set and as the initial workset
     * of a delta iteration keyed on the vertex id (field 0). The result of the apply step is
     * used as both the solution set delta and the next workset.
     *
     * @return The resulting DataSet
     * @throws IllegalStateException if the input data set has not been set, or if the number of
     *         vertices was requested by the configuration and computing it failed with a
     *         checked exception
     */
    @Override
    public DataSet<Vertex<K, VV>> createResult() {
        if (vertexDataSet == null) {
            throw new IllegalStateException("The input data set has not been set.");
        }

        // Read the configuration once so that every step of the plan sees the same settings.
        final GSAConfiguration config = this.configuration;
        final boolean broadcastNumVertices = config != null && config.isOptNumVertices();

        // Prepare type information
        TypeInformation<K> keyType = ((TupleTypeInfo<?>) vertexDataSet.getType()).getTypeAt(0);
        TypeInformation<M> messageType = TypeExtractor.createTypeInfo(gather, GatherFunction.class, gather.getClass(), 2);
        TypeInformation<Tuple2<K, M>> innerType = new TupleTypeInfo<>(keyType, messageType);
        TypeInformation<Vertex<K, VV>> outputType = vertexDataSet.getType();

        // If the numVertices option is set, compute the total number of vertices so that it can
        // be broadcast to the gather, sum and apply functions.
        final DataSet<LongValue> numberOfVertices = broadcastNumVertices ? countVertices() : null;

        // Prepare UDFs
        GatherUdf<K, VV, EV, M> gatherUdf = new GatherUdf<>(gather, innerType);
        SumUdf<K, VV, EV, M> sumUdf = new SumUdf<>(sum, innerType);
        ApplyUdf<K, VV, EV, M> applyUdf = new ApplyUdf<>(apply, outputType);

        final int[] solutionSetKeyPositions = new int[] {0};
        final DeltaIteration<Vertex<K, VV>, Vertex<K, VV>> iteration =
                vertexDataSet.iterateDelta(vertexDataSet, maximumNumberOfIterations, solutionSetKeyPositions);

        // set up the iteration operator
        configureIteration(iteration, config);

        // Prepare the neighbors
        final EdgeDirection direction = (config != null) ? config.getDirection() : DEFAULT_DIRECTION;
        final DataSet<Tuple2<K, Neighbor<VV, EV>>> neighbors = buildNeighbors(iteration.getWorkset(), direction);

        // Gather: configure the map function with name and broadcast variables
        MapOperator<Tuple2<K, Neighbor<VV, EV>>, Tuple2<K, M>> gatherMapOperator = neighbors.map(gatherUdf);
        gatherMapOperator = gatherMapOperator.name("Gather");

        if (config != null) {
            for (Tuple2<String, DataSet<?>> e : config.getGatherBcastVars()) {
                gatherMapOperator = gatherMapOperator.withBroadcastSet(e.f1, e.f0);
            }
            if (broadcastNumVertices) {
                gatherMapOperator = gatherMapOperator.withBroadcastSet(numberOfVertices, NUMBER_OF_VERTICES_BCAST_NAME);
            }
        }
        DataSet<Tuple2<K, M>> gatheredSet = gatherMapOperator;

        // Sum: configure the reduce function with name and broadcast variables
        ReduceOperator<Tuple2<K, M>> sumReduceOperator = gatheredSet.groupBy(0).reduce(sumUdf);
        sumReduceOperator = sumReduceOperator.name("Sum");

        if (config != null) {
            for (Tuple2<String, DataSet<?>> e : config.getSumBcastVars()) {
                sumReduceOperator = sumReduceOperator.withBroadcastSet(e.f1, e.f0);
            }
            if (broadcastNumVertices) {
                sumReduceOperator = sumReduceOperator.withBroadcastSet(numberOfVertices, NUMBER_OF_VERTICES_BCAST_NAME);
            }
        }
        DataSet<Tuple2<K, M>> summedSet = sumReduceOperator;

        // Apply: join the summed values with the solution set on the vertex key
        JoinOperator<?, ?, Vertex<K, VV>> appliedSet = summedSet
                .join(iteration.getSolutionSet())
                .where(0)
                .equalTo(0)
                .with(applyUdf);

        // configure join apply function with name and broadcast variables
        appliedSet = appliedSet.name("Apply");

        if (config != null) {
            for (Tuple2<String, DataSet<?>> e : config.getApplyBcastVars()) {
                appliedSet = appliedSet.withBroadcastSet(e.f1, e.f0);
            }
            if (broadcastNumVertices) {
                appliedSet = appliedSet.withBroadcastSet(numberOfVertices, NUMBER_OF_VERTICES_BCAST_NAME);
            }
        }

        // let the operator know that we preserve the key field
        appliedSet.withForwardedFieldsFirst("0").withForwardedFieldsSecond("0");

        return iteration.closeWith(appliedSet, appliedSet);
    }

    /**
     * Builds the data set holding the number of vertices of the input.
     *
     * <p>A failure is reported immediately with its cause instead of being printed and ignored;
     * continuing without the count would pass {@code null} as a broadcast set further down
     * in {@link #createResult()}.
     */
    private DataSet<LongValue> countVertices() {
        try {
            return GraphUtils.count(this.vertexDataSet);
        } catch (RuntimeException e) {
            // unchecked failures keep their original type
            throw e;
        } catch (Exception e) {
            throw new IllegalStateException("Could not compute the number of vertices of the input data set.", e);
        }
    }

    /**
     * Applies name, parallelism, solution set memory mode and aggregators from the configuration
     * to the iteration, or only a default name when no configuration is present.
     */
    private void configureIteration(DeltaIteration<Vertex<K, VV>, Vertex<K, VV>> iteration, GSAConfiguration config) {
        final String defaultName = "Gather-sum-apply iteration (" + gather + " | " + sum + " | " + apply + ")";

        if (config == null) {
            // no configuration provided; set default name
            iteration.name(defaultName);
            return;
        }

        iteration.name(config.getName(defaultName));
        iteration.parallelism(config.getParallelism());
        iteration.setSolutionSetUnManaged(config.isSolutionSetUnmanagedMemory());

        // register all aggregators
        for (Map.Entry<String, Aggregator<?>> entry : config.getAggregators().entrySet()) {
            iteration.registerAggregator(entry.getKey(), entry.getValue());
        }
    }

    /**
     * Builds the (key, neighbor) pairs for the given direction. Any direction other than
     * {@code IN} or {@code ALL} is handled like {@code OUT}.
     */
    private DataSet<Tuple2<K, Neighbor<VV, EV>>> buildNeighbors(DataSet<Vertex<K, VV>> workset, EdgeDirection direction) {
        switch (direction) {
            case IN:
                return joinOnEdgeTarget(workset);
            case ALL:
                return joinOnEdgeSource(workset).union(joinOnEdgeTarget(workset));
            case OUT:
            default:
                return joinOnEdgeSource(workset);
        }
    }

    /** Joins vertices with edges on edge field 0; the result is keyed by the edge target. */
    private DataSet<Tuple2<K, Neighbor<VV, EV>>> joinOnEdgeSource(DataSet<Vertex<K, VV>> workset) {
        return workset.join(edgeDataSet)
                .where(0).equalTo(0).with(new ProjectKeyWithNeighborOUT<K, VV, EV>());
    }

    /** Joins vertices with edges on edge field 1; the result is keyed by the edge source. */
    private DataSet<Tuple2<K, Neighbor<VV, EV>>> joinOnEdgeTarget(DataSet<Vertex<K, VV>> workset) {
        return workset.join(edgeDataSet)
                .where(0).equalTo(1).with(new ProjectKeyWithNeighborIN<K, VV, EV>());
    }

    /**
     * Creates a new gather-sum-apply iteration operator for graphs.
     *
     * @param edges The edge DataSet
     *
     * @param gather The gather function of the GSA iteration
     * @param sum The sum function of the GSA iteration
     * @param apply The apply function of the GSA iteration
     *
     * @param maximumNumberOfIterations The maximum number of iterations executed
     *
     * @param <K> The type of the vertex key in the graph
     * @param <VV> The type of the vertex value in the graph
     * @param <EV> The type of the edge value in the graph
     * @param <M> The intermediate type used by the gather, sum and apply functions
     *
     * @return An instance of the gather-sum-apply graph computation operator.
     */
    public static <K, VV, EV, M> GatherSumApplyIteration<K, VV, EV, M>
        withEdges(DataSet<Edge<K, EV>> edges, GatherFunction<VV, EV, M> gather,
            SumFunction<VV, EV, M> sum, ApplyFunction<K, VV, M> apply, int maximumNumberOfIterations) {

        return new GatherSumApplyIteration<>(gather, sum, apply, edges, maximumNumberOfIterations);
    }

    // --------------------------------------------------------------------------------------------
    //  Wrapping UDFs
    // --------------------------------------------------------------------------------------------

    /**
     * Map function wrapping the user's gather function: turns each (key, neighbor) pair into a
     * (key, gathered value) pair, leaving the key untouched.
     */
    @SuppressWarnings("serial")
    @ForwardedFields("f0")
    private static final class GatherUdf<K, VV, EV, M> extends RichMapFunction<Tuple2<K, Neighbor<VV, EV>>,
            Tuple2<K, M>> implements ResultTypeQueryable<Tuple2<K, M>> {

        private final GatherFunction<VV, EV, M> gatherFunction;
        private transient TypeInformation<Tuple2<K, M>> resultType;

        private GatherUdf(GatherFunction<VV, EV, M> gatherFunction, TypeInformation<Tuple2<K, M>> resultType) {
            this.gatherFunction = gatherFunction;
            this.resultType = resultType;
        }

        @Override
        public Tuple2<K, M> map(Tuple2<K, Neighbor<VV, EV>> neighborTuple) {
            M result = this.gatherFunction.gather(neighborTuple.f1);
            return new Tuple2<>(neighborTuple.f0, result);
        }

        /**
         * Hands the broadcast vertex count (if present) to the gather function, initializes the
         * function in the first superstep and runs its pre-superstep hook.
         */
        @Override
        public void open(Configuration parameters) throws Exception {
            if (getRuntimeContext().hasBroadcastVariable(NUMBER_OF_VERTICES_BCAST_NAME)) {
                Collection<LongValue> numberOfVertices = getRuntimeContext().getBroadcastVariable(NUMBER_OF_VERTICES_BCAST_NAME);
                this.gatherFunction.setNumberOfVertices(numberOfVertices.iterator().next().getValue());
            }
            if (getIterationRuntimeContext().getSuperstepNumber() == 1) {
                this.gatherFunction.init(getIterationRuntimeContext());
            }
            this.gatherFunction.preSuperstep();
        }

        @Override
        public void close() throws Exception {
            this.gatherFunction.postSuperstep();
        }

        @Override
        public TypeInformation<Tuple2<K, M>> getProducedType() {
            return this.resultType;
        }
    }

    /**
     * Reduce function wrapping the user's sum function: combines the values of two
     * (key, value) pairs that share the same key.
     */
    @SuppressWarnings("serial")
    private static final class SumUdf<K, VV, EV, M> extends RichReduceFunction<Tuple2<K, M>>
            implements ResultTypeQueryable<Tuple2<K, M>>{

        private final SumFunction<VV, EV, M> sumFunction;
        private transient TypeInformation<Tuple2<K, M>> resultType;

        private SumUdf(SumFunction<VV, EV, M> sumFunction, TypeInformation<Tuple2<K, M>> resultType) {
            this.sumFunction = sumFunction;
            this.resultType = resultType;
        }

        @Override
        public Tuple2<K, M> reduce(Tuple2<K, M> arg0, Tuple2<K, M> arg1) throws Exception {
            M result = this.sumFunction.sum(arg0.f1, arg1.f1);

            // if the user returns value from the right argument then swap as
            // in ReduceDriver.run()
            // (the identity comparison is deliberate: it detects that the very same
            // object was handed back, not merely an equal one)
            if (result == arg1.f1) {
                M tmp = arg1.f1;
                arg1.f1 = arg0.f1;
                arg0.f1 = tmp;
            } else {
                arg0.f1 = result;
            }

            return arg0;
        }

        /**
         * Hands the broadcast vertex count (if present) to the sum function, initializes the
         * function in the first superstep and runs its pre-superstep hook.
         */
        @Override
        public void open(Configuration parameters) throws Exception {
            if (getRuntimeContext().hasBroadcastVariable(NUMBER_OF_VERTICES_BCAST_NAME)) {
                Collection<LongValue> numberOfVertices = getRuntimeContext().getBroadcastVariable(NUMBER_OF_VERTICES_BCAST_NAME);
                this.sumFunction.setNumberOfVertices(numberOfVertices.iterator().next().getValue());
            }
            if (getIterationRuntimeContext().getSuperstepNumber() == 1) {
                this.sumFunction.init(getIterationRuntimeContext());
            }
            this.sumFunction.preSuperstep();
        }

        @Override
        public void close() throws Exception {
            this.sumFunction.postSuperstep();
        }

        @Override
        public TypeInformation<Tuple2<K, M>> getProducedType() {
            return this.resultType;
        }
    }

    /**
     * Join function wrapping the user's apply function: receives the summed value for a key
     * together with the current vertex from the solution set.
     */
    @SuppressWarnings("serial")
    private static final class ApplyUdf<K, VV, EV, M> extends RichFlatJoinFunction<Tuple2<K, M>,
            Vertex<K, VV>, Vertex<K, VV>> implements ResultTypeQueryable<Vertex<K, VV>> {

        private final ApplyFunction<K, VV, M> applyFunction;
        private transient TypeInformation<Vertex<K, VV>> resultType;

        private ApplyUdf(ApplyFunction<K, VV, M> applyFunction, TypeInformation<Vertex<K, VV>> resultType) {
            this.applyFunction = applyFunction;
            this.resultType = resultType;
        }

        @Override
        public void join(Tuple2<K, M> newValue, final Vertex<K, VV> currentValue, final Collector<Vertex<K, VV>> out) throws Exception {
            // The apply function is given the current vertex and the collector before it is
            // invoked; whether it emits anything is decided by the apply function itself.
            this.applyFunction.setOutput(currentValue, out);
            this.applyFunction.apply(newValue.f1, currentValue.getValue());
        }

        /**
         * Hands the broadcast vertex count (if present) to the apply function, initializes the
         * function in the first superstep and runs its pre-superstep hook.
         */
        @Override
        public void open(Configuration parameters) throws Exception {
            if (getRuntimeContext().hasBroadcastVariable(NUMBER_OF_VERTICES_BCAST_NAME)) {
                Collection<LongValue> numberOfVertices = getRuntimeContext().getBroadcastVariable(NUMBER_OF_VERTICES_BCAST_NAME);
                this.applyFunction.setNumberOfVertices(numberOfVertices.iterator().next().getValue());
            }
            if (getIterationRuntimeContext().getSuperstepNumber() == 1) {
                this.applyFunction.init(getIterationRuntimeContext());
            }
            this.applyFunction.preSuperstep();
        }

        @Override
        public void close() throws Exception {
            this.applyFunction.postSuperstep();
        }

        @Override
        public TypeInformation<Vertex<K, VV>> getProducedType() {
            return this.resultType;
        }
    }

    /**
     * Pairs the value of a joined vertex and the edge value with the edge's target id as key.
     */
    @SuppressWarnings("serial")
    @ForwardedFieldsSecond("f1->f0")
    private static final class ProjectKeyWithNeighborOUT<K, VV, EV> implements FlatJoinFunction<
            Vertex<K, VV>, Edge<K, EV>, Tuple2<K, Neighbor<VV, EV>>> {

        @Override
        public void join(Vertex<K, VV> vertex, Edge<K, EV> edge, Collector<Tuple2<K, Neighbor<VV, EV>>> out) {
            out.collect(new Tuple2<>(
                    edge.getTarget(), new Neighbor<>(vertex.getValue(), edge.getValue())));
        }
    }

    /**
     * Pairs the value of a joined vertex and the edge value with the edge's source id as key.
     */
    @SuppressWarnings("serial")
    @ForwardedFieldsSecond({"f0"})
    private static final class ProjectKeyWithNeighborIN<K, VV, EV> implements FlatJoinFunction<
            Vertex<K, VV>, Edge<K, EV>, Tuple2<K, Neighbor<VV, EV>>> {

        @Override
        public void join(Vertex<K, VV> vertex, Edge<K, EV> edge, Collector<Tuple2<K, Neighbor<VV, EV>>> out) {
            out.collect(new Tuple2<>(
                    edge.getSource(), new Neighbor<>(vertex.getValue(), edge.getValue())));
        }
    }

    /**
     * Configures this gather-sum-apply iteration with the provided parameters.
     *
     * <p>The configuration is read when {@link #createResult()} is called; passing
     * {@code null} makes the iteration fall back to its defaults.
     *
     * @param parameters the configuration parameters
     */
    public void configure(GSAConfiguration parameters) {
        this.configuration = parameters;
    }

    /**
     * @return the configuration parameters of this gather-sum-apply iteration, or {@code null}
     *         if none have been set
     */
    public GSAConfiguration getIterationConfiguration() {
        return this.configuration;
    }
}