/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.test.iterative; import org.apache.flink.api.common.aggregators.Aggregator; import org.apache.flink.api.common.aggregators.ConvergenceCriterion; import org.apache.flink.api.common.functions.FilterFunction; import org.apache.flink.api.common.functions.FlatJoinFunction; import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.functions.RichCoGroupFunction; import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.ExecutionEnvironment; import org.apache.flink.api.java.operators.IterativeDataSet; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.configuration.Configuration; import org.apache.flink.core.memory.DataInputView; import org.apache.flink.core.memory.DataOutputView; import org.apache.flink.test.util.MultipleProgramsTestBase; import org.apache.flink.types.Value; import org.apache.flink.util.Collector; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import java.io.IOException; import java.util.List; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @RunWith(Parameterized.class) @SuppressWarnings({"serial", "unchecked"}) public class DanglingPageRankITCase extends MultipleProgramsTestBase { private static final String AGGREGATOR_NAME = "pagerank.aggregator"; public DanglingPageRankITCase(TestExecutionMode mode) { super(mode); } @Test public void testDanglingPageRank() { try { final int NUM_ITERATIONS = 25; final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); DataSet<Tuple2<Long, Boolean>> vertices = env.fromElements( new Tuple2<>(1L, false), new Tuple2<>(2L, false), new Tuple2<>(5L, false), new Tuple2<>(3L, true), new Tuple2<>(4L, false) ); DataSet<PageWithLinks> edges = env.fromElements( new PageWithLinks(2L, new long[] { 1 }), new PageWithLinks(5L, new long[] { 2, 4 }), new PageWithLinks(4L, new long[] { 3, 2 }), new PageWithLinks(1L, new long[] { 4, 2, 3 }) ); final long numVertices = vertices.count(); final long numDanglingVertices = vertices .filter( new FilterFunction<Tuple2<Long, Boolean>>() { @Override public boolean filter(Tuple2<Long, Boolean> value) { return value.f1; } }) .count(); DataSet<PageWithRankAndDangling> verticesWithInitialRank = vertices .map(new MapFunction<Tuple2<Long, Boolean>, PageWithRankAndDangling>() { @Override public PageWithRankAndDangling map(Tuple2<Long, Boolean> value) { return new PageWithRankAndDangling(value.f0, 1.0 / numVertices, value.f1); } }); IterativeDataSet<PageWithRankAndDangling> iteration = verticesWithInitialRank.iterate(NUM_ITERATIONS); iteration.getAggregators().registerAggregationConvergenceCriterion( AGGREGATOR_NAME, new PageRankStatsAggregator(), new DiffL1NormConvergenceCriterion()); DataSet<PageWithRank> partialRanks = iteration.join(edges).where("pageId").equalTo("pageId").with( new FlatJoinFunction<PageWithRankAndDangling, PageWithLinks, PageWithRank>() { @Override public void join(PageWithRankAndDangling page, PageWithLinks links, Collector<PageWithRank> out) { double rankToDistribute = page.rank / (double) links.targets.length; PageWithRank output = new PageWithRank(0L, rankToDistribute); for (long target : links.targets) { output.pageId = target; out.collect(output); } } } ); DataSet<PageWithRankAndDangling> newRanks = iteration.coGroup(partialRanks).where("pageId").equalTo("pageId").with( new RichCoGroupFunction<PageWithRankAndDangling, PageWithRank, PageWithRankAndDangling>() { private static final double BETA = 0.85; private final double randomJump = (1.0 - BETA) / numVertices; private PageRankStatsAggregator aggregator; private double danglingRankFactor; @Override public void open(Configuration parameters) throws Exception { int currentIteration = getIterationRuntimeContext().getSuperstepNumber(); aggregator = getIterationRuntimeContext().getIterationAggregator(AGGREGATOR_NAME); if (currentIteration == 1) { danglingRankFactor = BETA * (double) numDanglingVertices / ((double) numVertices * (double) numVertices); } else { PageRankStats previousAggregate = getIterationRuntimeContext() .getPreviousIterationAggregate(AGGREGATOR_NAME); danglingRankFactor = BETA * previousAggregate.danglingRank() / (double) numVertices; } } @Override public void coGroup(Iterable<PageWithRankAndDangling> currentPages, Iterable<PageWithRank> partialRanks, Collector<PageWithRankAndDangling> out) { // compute the next rank long edges = 0; double summedRank = 0; for (PageWithRank partial : partialRanks) { summedRank += partial.rank; edges++; } double rank = BETA * summedRank + randomJump + danglingRankFactor; // current rank, for stats and convergence PageWithRankAndDangling currentPage = currentPages.iterator().next(); double currentRank = currentPage.rank; boolean isDangling = currentPage.dangling; // maintain statistics to compensate for probability loss on dangling nodes double danglingRankToAggregate = isDangling ? rank : 0; long danglingVerticesToAggregate = isDangling ? 1 : 0; double diff = Math.abs(currentRank - rank); aggregator.aggregate(diff, rank, danglingRankToAggregate, danglingVerticesToAggregate, 1, edges); currentPage.rank = rank; out.collect(currentPage); } }); List<PageWithRankAndDangling> result = iteration.closeWith(newRanks).collect(); double totalRank = 0.0; for (PageWithRankAndDangling r : result) { totalRank += r.rank; assertTrue(r.pageId >= 1 && r.pageId <= 5); assertTrue(r.pageId != 3 || r.dangling); } assertEquals(1.0, totalRank, 0.001); } catch (Exception e) { e.printStackTrace(); fail(e.getMessage()); } } // ------------------------------------------------------------------------ // custom types // ------------------------------------------------------------------------ public static class PageWithRank { public long pageId; public double rank; public PageWithRank() {} public PageWithRank(long pageId, double rank) { this.pageId = pageId; this.rank = rank; } } public static class PageWithRankAndDangling { public long pageId; public double rank; public boolean dangling; public PageWithRankAndDangling() {} public PageWithRankAndDangling(long pageId, double rank, boolean dangling) { this.pageId = pageId; this.rank = rank; this.dangling = dangling; } @Override public String toString() { return "PageWithRankAndDangling{" + "pageId=" + pageId + ", rank=" + rank + ", dangling=" + dangling + '}'; } } public static class PageWithLinks { public long pageId; public long[] targets; public PageWithLinks() {} public PageWithLinks(long pageId, long[] targets) { this.pageId = pageId; this.targets = targets; } } // ------------------------------------------------------------------------ // statistics // ------------------------------------------------------------------------ public static class PageRankStats implements Value { private double diff; private double rank; private double danglingRank; private long numDanglingVertices; private long numVertices; private long edges; public PageRankStats() {} public PageRankStats( double diff, double rank, double danglingRank, long numDanglingVertices, long numVertices, long edges) { this.diff = diff; this.rank = rank; this.danglingRank = danglingRank; this.numDanglingVertices = numDanglingVertices; this.numVertices = numVertices; this.edges = edges; } public double diff() { return diff; } public double rank() { return rank; } public double danglingRank() { return danglingRank; } public long numDanglingVertices() { return numDanglingVertices; } public long numVertices() { return numVertices; } public long edges() { return edges; } @Override public void write(DataOutputView out) throws IOException { out.writeDouble(diff); out.writeDouble(rank); out.writeDouble(danglingRank); out.writeLong(numDanglingVertices); out.writeLong(numVertices); out.writeLong(edges); } @Override public void read(DataInputView in) throws IOException { diff = in.readDouble(); rank = in.readDouble(); danglingRank = in.readDouble(); numDanglingVertices = in.readLong(); numVertices = in.readLong(); edges = in.readLong(); } @Override public String toString() { return "PageRankStats: diff [" + diff + "], rank [" + rank + "], danglingRank [" + danglingRank + "], numDanglingVertices [" + numDanglingVertices + "], numVertices [" + numVertices + "], edges [" + edges + "]"; } } public static class PageRankStatsAggregator implements Aggregator<PageRankStats> { private double diff; private double rank; private double danglingRank; private long numDanglingVertices; private long numVertices; private long edges; @Override public PageRankStats getAggregate() { return new PageRankStats(diff, rank, danglingRank, numDanglingVertices, numVertices, edges); } public void aggregate(double diffDelta, double rankDelta, double danglingRankDelta, long danglingVerticesDelta, long verticesDelta, long edgesDelta) { diff += diffDelta; rank += rankDelta; danglingRank += danglingRankDelta; numDanglingVertices += danglingVerticesDelta; numVertices += verticesDelta; edges += edgesDelta; } @Override public void aggregate(PageRankStats pageRankStats) { diff += pageRankStats.diff(); rank += pageRankStats.rank(); danglingRank += pageRankStats.danglingRank(); numDanglingVertices += pageRankStats.numDanglingVertices(); numVertices += pageRankStats.numVertices(); edges += pageRankStats.edges(); } @Override public void reset() { diff = 0; rank = 0; danglingRank = 0; numDanglingVertices = 0; numVertices = 0; edges = 0; } } public static class DiffL1NormConvergenceCriterion implements ConvergenceCriterion<PageRankStats> { private static final double EPSILON = 0.00005; @Override public boolean isConverged(int iteration, PageRankStats pageRankStats) { return pageRankStats.diff() < EPSILON; } } }