/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.graph.pregel; import org.apache.flink.api.common.aggregators.Aggregator; import org.apache.flink.api.common.functions.FlatMapFunction; import org.apache.flink.api.common.functions.GroupCombineFunction; import org.apache.flink.api.common.functions.JoinFunction; import org.apache.flink.api.common.functions.RichCoGroupFunction; import org.apache.flink.api.common.functions.RichGroupReduceFunction; import org.apache.flink.api.common.functions.RichMapFunction; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.functions.FunctionAnnotation.ForwardedFields; import org.apache.flink.api.java.functions.FunctionAnnotation.ForwardedFieldsFirst; import org.apache.flink.api.java.functions.FunctionAnnotation.ForwardedFieldsSecond; import org.apache.flink.api.java.operators.CoGroupOperator; import org.apache.flink.api.java.operators.CustomUnaryOperation; import org.apache.flink.api.java.operators.DeltaIteration; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.typeutils.EitherTypeInfo; import org.apache.flink.api.java.typeutils.ResultTypeQueryable; import org.apache.flink.api.java.typeutils.TupleTypeInfo; import org.apache.flink.api.java.typeutils.TypeExtractor; import org.apache.flink.configuration.Configuration; import org.apache.flink.graph.Edge; import org.apache.flink.graph.Graph; import org.apache.flink.graph.Vertex; import org.apache.flink.types.Either; import org.apache.flink.types.NullValue; import org.apache.flink.util.Collector; import org.apache.flink.util.Preconditions; import java.util.Iterator; import java.util.Map; /** * This class represents iterative graph computations, programmed in a vertex-centric perspective. * It is a special case of <i>Bulk Synchronous Parallel</i> computation. The paradigm has also been * implemented by Google's <i>Pregel</i> system and by <i>Apache Giraph</i>. * <p> * Vertex centric algorithms operate on graphs, which are defined through vertices and edges. The * algorithms send messages along the edges and update the state of vertices based on * the old state and the incoming messages. All vertices have an initial state. * The computation terminates once no vertex receives any message anymore. * Additionally, a maximum number of iterations (supersteps) may be specified. * <p> * The computation is here represented by one function: * <ul> * <li>The {@link ComputeFunction} receives incoming messages, may update the state for * the vertex, and sends messages along the edges of the vertex. * </li> * </ul> * <p> * * Vertex-centric graph iterations are run by calling * {@link Graph#runVertexCentricIteration(ComputeFunction, MessageCombiner, int)}. * * @param <K> The type of the vertex key (the vertex identifier). * @param <VV> The type of the vertex value (the state of the vertex). * @param <Message> The type of the message sent between vertices along the edges. * @param <EV> The type of the values that are associated with the edges. */ public class VertexCentricIteration<K, VV, EV, Message> implements CustomUnaryOperation<Vertex<K, VV>, Vertex<K, VV>> { private final ComputeFunction<K, VV, EV, Message> computeFunction; private final MessageCombiner<K, Message> combineFunction; private final DataSet<Edge<K, EV>> edgesWithValue; private final int maximumNumberOfIterations; private final TypeInformation<Message> messageType; private DataSet<Vertex<K, VV>> initialVertices; private VertexCentricConfiguration configuration; // ---------------------------------------------------------------------------------- private VertexCentricIteration(ComputeFunction<K, VV, EV, Message> cf, DataSet<Edge<K, EV>> edgesWithValue, MessageCombiner<K, Message> mc, int maximumNumberOfIterations) { Preconditions.checkNotNull(cf); Preconditions.checkNotNull(edgesWithValue); Preconditions.checkArgument(maximumNumberOfIterations > 0, "The maximum number of iterations must be at least one."); this.computeFunction = cf; this.edgesWithValue = edgesWithValue; this.combineFunction = mc; this.maximumNumberOfIterations = maximumNumberOfIterations; this.messageType = getMessageType(cf); } private TypeInformation<Message> getMessageType(ComputeFunction<K, VV, EV, Message> cf) { return TypeExtractor.createTypeInfo(cf, ComputeFunction.class, cf.getClass(), 3); } // -------------------------------------------------------------------------------------------- // Custom Operator behavior // -------------------------------------------------------------------------------------------- /** * Sets the input data set for this operator. In the case of this operator this input data set represents * the set of vertices with their initial state. * * @param inputData The input data set, which in the case of this operator represents the set of * vertices with their initial state. * * @see org.apache.flink.api.java.operators.CustomUnaryOperation#setInput(org.apache.flink.api.java.DataSet) */ @Override public void setInput(DataSet<Vertex<K, VV>> inputData) { this.initialVertices = inputData; } /** * Creates the operator that represents this vertex-centric graph computation. * <p> * The Pregel iteration is mapped to delta iteration as follows. * The solution set consists of the set of active vertices and the workset contains the set of messages * send to vertices during the previous superstep. Initially, the workset contains a null message for each vertex. * In the beginning of a superstep, the solution set is joined with the workset to produce * a dataset containing tuples of vertex state and messages (vertex inbox). * The superstep compute UDF is realized with a coGroup between the vertices with inbox and the graph edges. * The output of the compute UDF contains both the new vertex values and the new messages produced. * These are directed to the solution set delta and new workset, respectively, with subsequent flatMaps. * <p/> * * @return The operator that represents this vertex-centric graph computation. */ @Override public DataSet<Vertex<K, VV>> createResult() { if (this.initialVertices == null) { throw new IllegalStateException("The input data set has not been set."); } // prepare the type information TypeInformation<K> keyType = ((TupleTypeInfo<?>) initialVertices.getType()).getTypeAt(0); TypeInformation<Tuple2<K, Message>> messageTypeInfo = new TupleTypeInfo<>(keyType, messageType); TypeInformation<Vertex<K, VV>> vertexType = initialVertices.getType(); TypeInformation<Either<Vertex<K, VV>, Tuple2<K, Message>>> intermediateTypeInfo = new EitherTypeInfo<>(vertexType, messageTypeInfo); TypeInformation<Either<NullValue, Message>> nullableMsgTypeInfo = new EitherTypeInfo<>(TypeExtractor.getForClass(NullValue.class), messageType); TypeInformation<Tuple2<K, Either<NullValue, Message>>> workSetTypeInfo = new TupleTypeInfo<>(keyType, nullableMsgTypeInfo); DataSet<Tuple2<K, Either<NullValue, Message>>> initialWorkSet = initialVertices.map( new InitializeWorkSet<K, VV, Message>()).returns(workSetTypeInfo); final DeltaIteration<Vertex<K, VV>, Tuple2<K, Either<NullValue, Message>>> iteration = initialVertices.iterateDelta(initialWorkSet, this.maximumNumberOfIterations, 0); setUpIteration(iteration); // join with the current state to get vertex values DataSet<Tuple2<Vertex<K, VV>, Either<NullValue, Message>>> verticesWithMsgs = iteration.getSolutionSet().join(iteration.getWorkset()) .where(0).equalTo(0) .with(new AppendVertexState<K, VV, Message>()) .returns(new TupleTypeInfo<Tuple2<Vertex<K, VV>, Either<NullValue, Message>>>( vertexType, nullableMsgTypeInfo)); VertexComputeUdf<K, VV, EV, Message> vertexUdf = new VertexComputeUdf<>(computeFunction, intermediateTypeInfo); CoGroupOperator<?, ?, Either<Vertex<K, VV>, Tuple2<K, Message>>> superstepComputation = verticesWithMsgs.coGroup(edgesWithValue) .where("f0.f0").equalTo(0) .with(vertexUdf); // compute the solution set delta DataSet<Vertex<K, VV>> solutionSetDelta = superstepComputation.flatMap( new ProjectNewVertexValue<K, VV, Message>()).returns(vertexType); // compute the inbox of each vertex for the next superstep (new workset) DataSet<Tuple2<K, Either<NullValue, Message>>> allMessages = superstepComputation.flatMap( new ProjectMessages<K, VV, Message>()).returns(workSetTypeInfo); DataSet<Tuple2<K, Either<NullValue, Message>>> newWorkSet = allMessages; // check if a combiner has been provided if (combineFunction != null) { MessageCombinerUdf<K, Message> combinerUdf = new MessageCombinerUdf<>(combineFunction, workSetTypeInfo); DataSet<Tuple2<K, Either<NullValue, Message>>> combinedMessages = allMessages .groupBy(0).reduceGroup(combinerUdf) .setCombinable(true); newWorkSet = combinedMessages; } // configure the compute function superstepComputation = superstepComputation.name("Compute Function"); if (this.configuration != null) { for (Tuple2<String, DataSet<?>> e : this.configuration.getBcastVars()) { superstepComputation = superstepComputation.withBroadcastSet(e.f1, e.f0); } } return iteration.closeWith(solutionSetDelta, newWorkSet); } /** * Creates a new vertex-centric iteration operator. * * @param edgesWithValue The data set containing edges. * @param cf The compute function * * @param <K> The type of the vertex key (the vertex identifier). * @param <VV> The type of the vertex value (the state of the vertex). * @param <Message> The type of the message sent between vertices along the edges. * @param <EV> The type of the values that are associated with the edges. * * @return An instance of the vertex-centric graph computation operator. */ public static <K, VV, EV, Message> VertexCentricIteration<K, VV, EV, Message> withEdges( DataSet<Edge<K, EV>> edgesWithValue, ComputeFunction<K, VV, EV, Message> cf, int maximumNumberOfIterations) { return new VertexCentricIteration<>(cf, edgesWithValue, null, maximumNumberOfIterations); } /** * Creates a new vertex-centric iteration operator for graphs where the edges are associated with a value (such as * a weight or distance). * * @param edgesWithValue The data set containing edges. * @param cf The compute function. * @param mc The function that combines messages sent to a vertex during a superstep. * * @param <K> The type of the vertex key (the vertex identifier). * @param <VV> The type of the vertex value (the state of the vertex). * @param <Message> The type of the message sent between vertices along the edges. * @param <EV> The type of the values that are associated with the edges. * * @return An instance of the vertex-centric graph computation operator. */ public static <K, VV, EV, Message> VertexCentricIteration<K, VV, EV, Message> withEdges( DataSet<Edge<K, EV>> edgesWithValue, ComputeFunction<K, VV, EV, Message> cf, MessageCombiner<K, Message> mc, int maximumNumberOfIterations) { return new VertexCentricIteration<>(cf, edgesWithValue, mc, maximumNumberOfIterations); } /** * Configures this vertex-centric iteration with the provided parameters. * * @param parameters the configuration parameters */ public void configure(VertexCentricConfiguration parameters) { this.configuration = parameters; } /** * @return the configuration parameters of this vertex-centric iteration */ public VertexCentricConfiguration getIterationConfiguration() { return this.configuration; } // -------------------------------------------------------------------------------------------- // Wrapping UDFs // -------------------------------------------------------------------------------------------- @SuppressWarnings("serial") private static class InitializeWorkSet<K, VV, Message> extends RichMapFunction<Vertex<K, VV>, Tuple2<K, Either<NullValue, Message>>> { private Tuple2<K, Either<NullValue, Message>> outTuple; private Either<NullValue, Message> nullMessage; @Override public void open(Configuration parameters) { outTuple = new Tuple2<>(); nullMessage = Either.Left(NullValue.getInstance()); outTuple.f1 = nullMessage; } public Tuple2<K, Either<NullValue, Message>> map(Vertex<K, VV> vertex) { outTuple.f0 = vertex.getId(); return outTuple; } } /** * This coGroup class wraps the user-defined compute function. * The first input holds a Tuple2 containing the vertex state and its inbox. * The second input is an iterator of the out-going edges of this vertex. */ @SuppressWarnings("serial") private static class VertexComputeUdf<K, VV, EV, Message> extends RichCoGroupFunction< Tuple2<Vertex<K, VV>, Either<NullValue, Message>>, Edge<K, EV>, Either<Vertex<K, VV>, Tuple2<K, Message>>> implements ResultTypeQueryable<Either<Vertex<K, VV>, Tuple2<K, Message>>> { final ComputeFunction<K, VV, EV, Message> computeFunction; private transient TypeInformation<Either<Vertex<K, VV>, Tuple2<K, Message>>> resultType; private VertexComputeUdf(ComputeFunction<K, VV, EV, Message> compute, TypeInformation<Either<Vertex<K, VV>, Tuple2<K, Message>>> typeInfo) { this.computeFunction = compute; this.resultType = typeInfo; } @Override public TypeInformation<Either<Vertex<K, VV>, Tuple2<K, Message>>> getProducedType() { return this.resultType; } @Override public void open(Configuration parameters) throws Exception { if (getIterationRuntimeContext().getSuperstepNumber() == 1) { this.computeFunction.init(getIterationRuntimeContext()); } this.computeFunction.preSuperstep(); } @Override public void close() throws Exception { this.computeFunction.postSuperstep(); } @Override public void coGroup( Iterable<Tuple2<Vertex<K, VV>, Either<NullValue, Message>>> messages, Iterable<Edge<K, EV>> edgesIterator, Collector<Either<Vertex<K, VV>, Tuple2<K, Message>>> out) throws Exception { final Iterator<Tuple2<Vertex<K, VV>, Either<NullValue, Message>>> vertexIter = messages.iterator(); if (vertexIter.hasNext()) { final Tuple2<Vertex<K, VV>, Either<NullValue, Message>> first = vertexIter.next(); final Vertex<K, VV> vertexState = first.f0; final MessageIterator<Message> messageIter = new MessageIterator<>(); if (getIterationRuntimeContext().getSuperstepNumber() == 1) { // there are no messages during the 1st superstep } else { messageIter.setFirst(first.f1.right()); @SuppressWarnings("unchecked") Iterator<Tuple2<?, Either<NullValue, Message>>> downcastIter = (Iterator<Tuple2<?, Either<NullValue, Message>>>) (Iterator<?>) vertexIter; messageIter.setSource(downcastIter); } computeFunction.set(vertexState.getId(), edgesIterator.iterator(), out); computeFunction.compute(vertexState, messageIter); } } } @SuppressWarnings("serial") @ForwardedFields("f0") public static class MessageCombinerUdf<K, Message> extends RichGroupReduceFunction< Tuple2<K, Either<NullValue, Message>>, Tuple2<K, Either<NullValue, Message>>> implements ResultTypeQueryable<Tuple2<K, Either<NullValue, Message>>>, GroupCombineFunction<Tuple2<K, Either<NullValue, Message>>, Tuple2<K, Either<NullValue, Message>>> { final MessageCombiner<K, Message> combinerFunction; private transient TypeInformation<Tuple2<K, Either<NullValue, Message>>> resultType; private MessageCombinerUdf(MessageCombiner<K, Message> combineFunction, TypeInformation<Tuple2<K, Either<NullValue, Message>>> messageTypeInfo) { this.combinerFunction = combineFunction; this.resultType = messageTypeInfo; } @Override public TypeInformation<Tuple2<K, Either<NullValue, Message>>> getProducedType() { return resultType; } @Override public void reduce(Iterable<Tuple2<K, Either<NullValue, Message>>> messages, Collector<Tuple2<K, Either<NullValue, Message>>> out) throws Exception { final Iterator<Tuple2<K, Either<NullValue, Message>>> messageIterator = messages.iterator(); if (messageIterator.hasNext()) { final Tuple2<K, Either<NullValue, Message>> first = messageIterator.next(); final K vertexID = first.f0; final MessageIterator<Message> messageIter = new MessageIterator<>(); messageIter.setFirst(first.f1.right()); @SuppressWarnings("unchecked") Iterator<Tuple2<?, Either<NullValue, Message>>> downcastIter = (Iterator<Tuple2<?, Either<NullValue, Message>>>) (Iterator<?>) messageIterator; messageIter.setSource(downcastIter); combinerFunction.set(vertexID, out); combinerFunction.combineMessages(messageIter); } } @Override public void combine(Iterable<Tuple2<K, Either<NullValue, Message>>> values, Collector<Tuple2<K, Either<NullValue, Message>>> out) throws Exception { this.reduce(values, out); } } // -------------------------------------------------------------------------------------------- // UTIL methods // -------------------------------------------------------------------------------------------- /** * Helper method which sets up an iteration with the given vertex value * * @param iteration */ private void setUpIteration(DeltaIteration<?, ?> iteration) { // set up the iteration operator if (this.configuration != null) { iteration.name(this.configuration.getName("Vertex-centric iteration (" + computeFunction + ")")); iteration.parallelism(this.configuration.getParallelism()); iteration.setSolutionSetUnManaged(this.configuration.isSolutionSetUnmanagedMemory()); // register all aggregators for (Map.Entry<String, Aggregator<?>> entry : this.configuration.getAggregators().entrySet()) { iteration.registerAggregator(entry.getKey(), entry.getValue()); } } else { // no configuration provided; set default name iteration.name("Vertex-centric iteration (" + computeFunction + ")"); } } @SuppressWarnings("serial") @ForwardedFieldsFirst("*->f0") @ForwardedFieldsSecond("f1->f1") private static final class AppendVertexState<K, VV, Message> implements JoinFunction<Vertex<K, VV>, Tuple2<K, Either<NullValue, Message>>, Tuple2<Vertex<K, VV>, Either<NullValue, Message>>> { private Tuple2<Vertex<K, VV>, Either<NullValue, Message>> outTuple = new Tuple2<>(); public Tuple2<Vertex<K, VV>, Either<NullValue, Message>> join( Vertex<K, VV> vertex, Tuple2<K, Either<NullValue, Message>> message) { outTuple.f0 = vertex; outTuple.f1 = message.f1; return outTuple; } } @SuppressWarnings("serial") private static final class ProjectNewVertexValue<K, VV, Message> implements FlatMapFunction<Either<Vertex<K, VV>, Tuple2<K, Message>>, Vertex<K, VV>> { public void flatMap(Either<Vertex<K, VV>, Tuple2<K, Message>> value, Collector<Vertex<K, VV>> out) { if (value.isLeft()) { out.collect(value.left()); } } } @SuppressWarnings("serial") private static final class ProjectMessages<K, VV, Message> implements FlatMapFunction<Either<Vertex<K, VV>, Tuple2<K, Message>>, Tuple2<K, Either<NullValue, Message>>> { private Tuple2<K, Either<NullValue, Message>> outTuple = new Tuple2<>(); public void flatMap(Either<Vertex<K, VV>, Tuple2<K, Message>> value, Collector<Tuple2<K, Either<NullValue, Message>>> out) { if (value.isRight()) { Tuple2<K, Message> message = value.right(); outTuple.f0 = message.f0; outTuple.f1 = Either.Right(message.f1); out.collect(outTuple); } } } }