/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.graph.test;
import org.apache.flink.api.common.aggregators.LongSumAggregator;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.graph.Edge;
import org.apache.flink.graph.EdgeDirection;
import org.apache.flink.graph.Graph;
import org.apache.flink.graph.Vertex;
import org.apache.flink.graph.gsa.ApplyFunction;
import org.apache.flink.graph.gsa.GSAConfiguration;
import org.apache.flink.graph.gsa.GatherFunction;
import org.apache.flink.graph.gsa.GatherSumApplyIteration;
import org.apache.flink.graph.gsa.Neighbor;
import org.apache.flink.graph.gsa.SumFunction;
import org.apache.flink.test.util.MultipleProgramsTestBase;
import org.apache.flink.types.LongValue;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import java.util.HashSet;
import java.util.List;
@RunWith(Parameterized.class)
public class GatherSumApplyConfigurationITCase extends MultipleProgramsTestBase {

	public GatherSumApplyConfigurationITCase(TestExecutionMode mode) {
		super(mode);
	}

	@Test
	public void testRunWithConfiguration() throws Exception {
		/*
		 * Test Graph's runGatherSumApplyIteration when configuration parameters are provided.
		 * The Gather, Sum and Apply functions used here assert on their broadcast sets,
		 * on the registered aggregator and on the number of vertices.
		 */
		final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

		Graph<Long, Long, Long> graph = Graph.fromCollection(TestGraphUtils.getLongLongVertices(),
				TestGraphUtils.getLongLongEdges(), env).mapVertices(new AssignOneMapper());

		// create the configuration object
		GSAConfiguration parameters = new GSAConfiguration();

		parameters.addBroadcastSetForGatherFunction("gatherBcastSet", env.fromElements(1, 2, 3));
		parameters.addBroadcastSetForSumFunction("sumBcastSet", env.fromElements(4, 5, 6));
		parameters.addBroadcastSetForApplyFunction("applyBcastSet", env.fromElements(7, 8, 9));
		parameters.registerAggregator("superstepAggregator", new LongSumAggregator());
		parameters.setOptNumVertices(true);

		Graph<Long, Long, Long> res = graph.runGatherSumApplyIteration(new Gather(), new Sum(),
				new Apply(), 10, parameters);

		DataSet<Vertex<Long, Long>> data = res.getVertices();
		List<Vertex<Long, Long>> result = data.collect();

		String expectedResult = "1,11\n" +
				"2,11\n" +
				"3,11\n" +
				"4,11\n" +
				"5,11";

		compareResultAsTuples(result, expectedResult);
	}

	@Test
	public void testIterationConfiguration() throws Exception {
		/*
		 * Test name, parallelism and solutionSetUnmanaged parameters.
		 */
		final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

		GatherSumApplyIteration<Long, Long, Long, Long> iteration = GatherSumApplyIteration
				.withEdges(TestGraphUtils.getLongLongEdgeData(env), new DummyGather(),
						new DummySum(), new DummyApply(), 10);

		GSAConfiguration parameters = new GSAConfiguration();
		parameters.setName("gelly iteration");
		parameters.setParallelism(2);
		parameters.setSolutionSetUnmanagedMemory(true);

		iteration.configure(parameters);

		Assert.assertEquals("gelly iteration", iteration.getIterationConfiguration().getName(""));
		Assert.assertEquals(2, iteration.getIterationConfiguration().getParallelism());
		Assert.assertTrue(iteration.getIterationConfiguration().isSolutionSetUnmanagedMemory());

		DataSet<Vertex<Long, Long>> data = TestGraphUtils.getLongLongVertexData(env).runOperation(iteration);
		List<Vertex<Long, Long>> result = data.collect();

		String expectedResult = "1,11\n" +
				"2,12\n" +
				"3,13\n" +
				"4,14\n" +
				"5,15";

		compareResultAsTuples(result, expectedResult);
	}

	@Test
	public void testIterationDefaultDirection() throws Exception {
		/*
		 * Test that if no direction parameter is given, the iteration works as before
		 * (i.e. it gathers information from the IN edges and neighbors and the information
		 * is calculated for an OUT edge). The default direction parameter is OUT for
		 * GatherSumApplyIterations.
		 * When data is gathered from the IN edges, the Gather, Sum and Apply functions
		 * set the set of vertices which have a path to a vertex as the value of that vertex.
		 */
		final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

		List<Edge<Long, Long>> edges = TestGraphUtils.getLongLongEdges();
		edges.remove(0);

		Graph<Long, HashSet<Long>, Long> graph = Graph
				.fromCollection(TestGraphUtils.getLongLongVertices(), edges, env)
				.mapVertices(new InitialiseHashSetMapper());

		DataSet<Vertex<Long, HashSet<Long>>> resultedVertices = graph.runGatherSumApplyIteration(
				new GetReachableVertices(), new FindAllReachableVertices(), new UpdateReachableVertices(), 4)
				.getVertices();

		List<Vertex<Long, HashSet<Long>>> result = resultedVertices.collect();

		String expectedResult = "1,[1, 2, 3, 4, 5]\n"
				+ "2,[2]\n"
				+ "3,[1, 2, 3, 4, 5]\n"
				+ "4,[1, 2, 3, 4, 5]\n"
				+ "5,[1, 2, 3, 4, 5]\n";

		compareResultAsTuples(result, expectedResult);
	}

	@Test
	public void testIterationDirectionIN() throws Exception {
		/*
		 * Test that if the direction parameter IN is given, the iteration works as expected
		 * (i.e. it gathers information from the OUT edges and neighbors and the information
		 * is calculated for an IN edge).
		 * When data is gathered from the OUT edges, the Gather, Sum and Apply functions
		 * set the set of vertices which have a path from a vertex as the value of that vertex.
		 */
		final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

		GSAConfiguration parameters = new GSAConfiguration();
		parameters.setDirection(EdgeDirection.IN);

		List<Edge<Long, Long>> edges = TestGraphUtils.getLongLongEdges();
		edges.remove(0);

		Graph<Long, HashSet<Long>, Long> graph = Graph
				.fromCollection(TestGraphUtils.getLongLongVertices(), edges, env)
				.mapVertices(new InitialiseHashSetMapper());

		DataSet<Vertex<Long, HashSet<Long>>> resultedVertices = graph.runGatherSumApplyIteration(
				new GetReachableVertices(), new FindAllReachableVertices(), new UpdateReachableVertices(), 4,
				parameters)
				.getVertices();

		List<Vertex<Long, HashSet<Long>>> result = resultedVertices.collect();

		String expectedResult = "1,[1, 3, 4, 5]\n"
				+ "2,[1, 2, 3, 4, 5]\n"
				+ "3,[1, 3, 4, 5]\n"
				+ "4,[1, 3, 4, 5]\n"
				+ "5,[1, 3, 4, 5]\n";

		compareResultAsTuples(result, expectedResult);
	}

	@Test
	public void testIterationDirectionALL() throws Exception {
		/*
		 * Test that if the direction parameter ALL is given, the iteration works as expected
		 * (i.e. it gathers information from both IN and OUT edges and neighbors).
		 * When data is gathered from ALL edges, the Gather, Sum and Apply functions
		 * set the set of vertices which are connected to a vertex through some path
		 * as the value of that vertex.
		 */
		final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

		GSAConfiguration parameters = new GSAConfiguration();
		parameters.setDirection(EdgeDirection.ALL);

		List<Edge<Long, Long>> edges = TestGraphUtils.getLongLongEdges();
		edges.remove(0);

		Graph<Long, HashSet<Long>, Long> graph = Graph
				.fromCollection(TestGraphUtils.getLongLongVertices(), edges, env)
				.mapVertices(new InitialiseHashSetMapper());

		DataSet<Vertex<Long, HashSet<Long>>> resultedVertices = graph.runGatherSumApplyIteration(
				new GetReachableVertices(), new FindAllReachableVertices(), new UpdateReachableVertices(), 4,
				parameters)
				.getVertices();

		List<Vertex<Long, HashSet<Long>>> result = resultedVertices.collect();

		String expectedResult = "1,[1, 2, 3, 4, 5]\n"
				+ "2,[1, 2, 3, 4, 5]\n"
				+ "3,[1, 2, 3, 4, 5]\n"
				+ "4,[1, 2, 3, 4, 5]\n"
				+ "5,[1, 2, 3, 4, 5]\n";

		compareResultAsTuples(result, expectedResult);
	}

	/**
	 * Gather function that forwards the neighbor value. Its preSuperstep checks:
	 * <ul>
	 * <li>the "gatherBcastSet" broadcast set;</li>
	 * <li>in superstep 2, the previous "superstepAggregator" aggregate;</li>
	 * <li>the vertex count.</li>
	 * </ul>
	 */
	@SuppressWarnings("serial")
	private static final class Gather extends GatherFunction<Long, Long, Long> {

		@Override
		public void preSuperstep() {

			// test bcast variable
			@SuppressWarnings("unchecked")
			List<Integer> bcastSet = (List<Integer>) (List<?>) getBroadcastSet("gatherBcastSet");
			Assert.assertEquals(1, bcastSet.get(0).intValue());
			Assert.assertEquals(2, bcastSet.get(1).intValue());
			Assert.assertEquals(3, bcastSet.get(2).intValue());

			// test aggregator: the value aggregated during superstep 1 must be 7
			if (getSuperstepNumber() == 2) {
				long aggrValue = ((LongValue) getPreviousIterationAggregate("superstepAggregator")).getValue();

				Assert.assertEquals(7, aggrValue);
			}

			// test number of vertices (setOptNumVertices(true) is set by the calling test)
			Assert.assertEquals(5, getNumberOfVertices());
		}

		@Override
		public Long gather(Neighbor<Long, Long> neighbor) {
			return neighbor.getNeighborValue();
		}
	}

	/**
	 * Sum function that always returns 0. Each call adds the current superstep
	 * number to the "superstepAggregator" aggregator. Its preSuperstep checks the
	 * "sumBcastSet" broadcast set and the vertex count.
	 */
	@SuppressWarnings("serial")
	private static final class Sum extends SumFunction<Long, Long, Long> {

		LongSumAggregator aggregator = new LongSumAggregator();

		@Override
		public void preSuperstep() {

			// test bcast variable
			@SuppressWarnings("unchecked")
			List<Integer> bcastSet = (List<Integer>) (List<?>) getBroadcastSet("sumBcastSet");
			Assert.assertEquals(4, bcastSet.get(0).intValue());
			Assert.assertEquals(5, bcastSet.get(1).intValue());
			Assert.assertEquals(6, bcastSet.get(2).intValue());

			// test aggregator
			aggregator = getIterationAggregator("superstepAggregator");

			// test number of vertices
			Assert.assertEquals(5, getNumberOfVertices());
		}

		@Override
		public Long sum(Long newValue, Long currentValue) {
			long superstep = getSuperstepNumber();
			aggregator.aggregate(superstep);
			return 0L;
		}
	}

	/**
	 * Apply function that sets the vertex value to the original value plus one.
	 * Each call adds the current superstep number to the "superstepAggregator"
	 * aggregator. Its preSuperstep checks the "applyBcastSet" broadcast set and
	 * the vertex count.
	 */
	@SuppressWarnings("serial")
	private static final class Apply extends ApplyFunction<Long, Long, Long> {

		LongSumAggregator aggregator = new LongSumAggregator();

		@Override
		public void preSuperstep() {

			// test bcast variable
			@SuppressWarnings("unchecked")
			List<Integer> bcastSet = (List<Integer>) (List<?>) getBroadcastSet("applyBcastSet");
			Assert.assertEquals(7, bcastSet.get(0).intValue());
			Assert.assertEquals(8, bcastSet.get(1).intValue());
			Assert.assertEquals(9, bcastSet.get(2).intValue());

			// test aggregator
			aggregator = getIterationAggregator("superstepAggregator");

			// test number of vertices
			Assert.assertEquals(5, getNumberOfVertices());
		}

		@Override
		public void apply(Long summedValue, Long origValue) {
			long superstep = getSuperstepNumber();
			aggregator.aggregate(superstep);
			setResult(origValue + 1);
		}
	}

	/**
	 * Gather function that forwards the neighbor value. It asserts that the vertex
	 * count is -1 when the numVertices option was not enabled.
	 */
	@SuppressWarnings("serial")
	private static final class DummyGather extends GatherFunction<Long, Long, Long> {

		@Override
		public void preSuperstep() {
			// test number of vertices
			// when the numVertices option is not set, -1 is returned
			Assert.assertEquals(-1, getNumberOfVertices());
		}

		@Override
		public Long gather(Neighbor<Long, Long> neighbor) {
			return neighbor.getNeighborValue();
		}
	}

	/** Sum function that always returns 0. */
	@SuppressWarnings("serial")
	private static final class DummySum extends SumFunction<Long, Long, Long> {

		@Override
		public Long sum(Long newValue, Long currentValue) {
			return 0L;
		}
	}

	/** Apply function that sets the vertex value to the original value plus one. */
	@SuppressWarnings("serial")
	private static final class DummyApply extends ApplyFunction<Long, Long, Long> {

		@Override
		public void apply(Long summedValue, Long origValue) {
			setResult(origValue + 1);
		}
	}

	/** Maps every vertex value to 1. */
	@SuppressWarnings("serial")
	public static final class AssignOneMapper implements MapFunction<Vertex<Long, Long>, Long> {

		@Override
		public Long map(Vertex<Long, Long> value) {
			return 1L;
		}
	}

	/** Maps every vertex value to a new set containing only that vertex's own id. */
	@SuppressWarnings("serial")
	public static final class InitialiseHashSetMapper implements MapFunction<Vertex<Long, Long>, HashSet<Long>> {

		@Override
		public HashSet<Long> map(Vertex<Long, Long> value) throws Exception {
			HashSet<Long> h = new HashSet<>();
			h.add(value.getId());
			return h;
		}
	}

	/** Gather function that forwards the neighbor's set of reachable vertex ids. */
	@SuppressWarnings("serial")
	private static final class GetReachableVertices extends GatherFunction<HashSet<Long>, Long, HashSet<Long>> {

		@Override
		public HashSet<Long> gather(Neighbor<HashSet<Long>, Long> neighbor) {
			return neighbor.getNeighborValue();
		}
	}

	/**
	 * Sum function that merges two sets of vertex ids. It adds the elements of
	 * {@code newSet} into {@code currentSet} and returns {@code currentSet}.
	 */
	@SuppressWarnings("serial")
	private static final class FindAllReachableVertices extends SumFunction<HashSet<Long>, Long, HashSet<Long>> {

		@Override
		public HashSet<Long> sum(HashSet<Long> newSet, HashSet<Long> currentSet) {
			for (Long l : newSet) {
				currentSet.add(l);
			}
			return currentSet;
		}
	}

	/**
	 * Apply function that unions the gathered ids with the vertex's current set.
	 * It calls setResult only when the union is larger than the current set.
	 */
	@SuppressWarnings("serial")
	private static final class UpdateReachableVertices extends ApplyFunction<Long, HashSet<Long>, HashSet<Long>> {

		@Override
		public void apply(HashSet<Long> newValue, HashSet<Long> currentValue) {
			newValue.addAll(currentValue);
			// only emit an update when new vertex ids were discovered
			if (newValue.size() > currentValue.size()) {
				setResult(newValue);
			}
		}
	}
}