/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.giraph.examples; import org.apache.giraph.aggregators.LongSumAggregator; import org.apache.giraph.bsp.BspInputSplit; import org.apache.giraph.edge.Edge; import org.apache.giraph.edge.EdgeFactory; import org.apache.giraph.graph.BasicComputation; import org.apache.giraph.master.DefaultMasterCompute; import org.apache.giraph.graph.Vertex; import org.apache.giraph.io.EdgeInputFormat; import org.apache.giraph.io.EdgeReader; import org.apache.giraph.io.VertexReader; import org.apache.giraph.io.formats.GeneratedVertexInputFormat; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.io.DoubleWritable; import org.apache.hadoop.io.FloatWritable; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.mapreduce.InputSplit; import org.apache.hadoop.mapreduce.JobContext; import org.apache.hadoop.mapreduce.TaskAttemptContext; import org.apache.log4j.Logger; import com.google.common.collect.Lists; import java.io.IOException; import java.util.ArrayList; import java.util.List; /** Computation which uses aggrergators. To be used for testing. */ public class AggregatorsTestComputation extends BasicComputation<LongWritable, DoubleWritable, FloatWritable, DoubleWritable> { /** Name of regular aggregator */ private static final String REGULAR_AGG = "regular"; /** Name of persistent aggregator */ private static final String PERSISTENT_AGG = "persistent"; /** Name of input super step persistent aggregator */ private static final String INPUT_VERTEX_PERSISTENT_AGG = "input_super_step_vertex_agg"; /** Name of input super step persistent aggregator */ private static final String INPUT_EDGE_PERSISTENT_AGG = "input_super_step_edge_agg"; /** Name of master overwriting aggregator */ private static final String MASTER_WRITE_AGG = "master"; /** Value which master compute will use */ private static final long MASTER_VALUE = 12345; /** Prefix for name of aggregators in array */ private static final String ARRAY_PREFIX_AGG = "array"; /** Number of aggregators to use in array */ private static final int NUM_OF_AGGREGATORS_IN_ARRAY = 100; @Override public void compute( Vertex<LongWritable, DoubleWritable, FloatWritable> vertex, Iterable<DoubleWritable> messages) throws IOException { long superstep = getSuperstep(); LongWritable myValue = new LongWritable(1L << superstep); aggregate(REGULAR_AGG, myValue); aggregate(PERSISTENT_AGG, myValue); long nv = getTotalNumVertices(); if (superstep > 0) { assertEquals(nv * (1L << (superstep - 1)), ((LongWritable) getAggregatedValue(REGULAR_AGG)).get()); } else { assertEquals(0, ((LongWritable) getAggregatedValue(REGULAR_AGG)).get()); } assertEquals(nv * ((1L << superstep) - 1), ((LongWritable) getAggregatedValue(PERSISTENT_AGG)).get()); assertEquals(MASTER_VALUE * (1L << superstep), ((LongWritable) getAggregatedValue(MASTER_WRITE_AGG)).get()); for (int i = 0; i < NUM_OF_AGGREGATORS_IN_ARRAY; i++) { aggregate(ARRAY_PREFIX_AGG + i, new LongWritable((superstep + 1) * i)); assertEquals(superstep * getTotalNumVertices() * i, ((LongWritable) getAggregatedValue(ARRAY_PREFIX_AGG + i)).get()); } if (getSuperstep() == 10) { vertex.voteToHalt(); } } /** Master compute which uses aggregators. To be used for testing. */ public static class AggregatorsTestMasterCompute extends DefaultMasterCompute { @Override public void compute() { long superstep = getSuperstep(); LongWritable myValue = new LongWritable(MASTER_VALUE * (1L << superstep)); setAggregatedValue(MASTER_WRITE_AGG, myValue); long nv = getTotalNumVertices(); if (superstep >= 0) { assertEquals(100, ((LongWritable) getAggregatedValue(INPUT_VERTEX_PERSISTENT_AGG)).get()); } if (superstep >= 0) { assertEquals(4500, ((LongWritable) getAggregatedValue(INPUT_EDGE_PERSISTENT_AGG)).get()); } if (superstep > 0) { assertEquals(nv * (1L << (superstep - 1)), ((LongWritable) getAggregatedValue(REGULAR_AGG)).get()); } else { assertEquals(0, ((LongWritable) getAggregatedValue(REGULAR_AGG)).get()); } assertEquals(nv * ((1L << superstep) - 1), ((LongWritable) getAggregatedValue(PERSISTENT_AGG)).get()); for (int i = 0; i < NUM_OF_AGGREGATORS_IN_ARRAY; i++) { assertEquals(superstep * getTotalNumVertices() * i, ((LongWritable) getAggregatedValue(ARRAY_PREFIX_AGG + i)).get()); } } @Override public void initialize() throws InstantiationException, IllegalAccessException { registerPersistentAggregator( INPUT_VERTEX_PERSISTENT_AGG, LongSumAggregator.class); registerPersistentAggregator( INPUT_EDGE_PERSISTENT_AGG, LongSumAggregator.class); registerAggregator(REGULAR_AGG, LongSumAggregator.class); registerPersistentAggregator(PERSISTENT_AGG, LongSumAggregator.class); registerAggregator(MASTER_WRITE_AGG, LongSumAggregator.class); for (int i = 0; i < NUM_OF_AGGREGATORS_IN_ARRAY; i++) { registerAggregator(ARRAY_PREFIX_AGG + i, LongSumAggregator.class); } } } /** * Throws exception if values are not equal. * * @param expected Expected value * @param actual Actual value */ private static void assertEquals(long expected, long actual) { if (expected != actual) { throw new RuntimeException("expected: " + expected + ", actual: " + actual); } } /** * Simple VertexReader */ public static class SimpleVertexReader extends GeneratedVertexReader<LongWritable, DoubleWritable, FloatWritable> { /** Class logger */ private static final Logger LOG = Logger.getLogger(SimpleVertexReader.class); @Override public boolean nextVertex() { return totalRecords > recordsRead; } @Override public Vertex<LongWritable, DoubleWritable, FloatWritable> getCurrentVertex() throws IOException { Vertex<LongWritable, DoubleWritable, FloatWritable> vertex = getConf().createVertex(); LongWritable vertexId = new LongWritable( (inputSplit.getSplitIndex() * totalRecords) + recordsRead); DoubleWritable vertexValue = new DoubleWritable(vertexId.get() * 10d); long targetVertexId = (vertexId.get() + 1) % (inputSplit.getNumSplits() * totalRecords); float edgeValue = vertexId.get() * 100f; List<Edge<LongWritable, FloatWritable>> edges = Lists.newLinkedList(); edges.add(EdgeFactory.create(new LongWritable(targetVertexId), new FloatWritable(edgeValue))); vertex.initialize(vertexId, vertexValue, edges); ++recordsRead; if (LOG.isInfoEnabled()) { LOG.info("next vertex: Return vertexId=" + vertex.getId().get() + ", vertexValue=" + vertex.getValue() + ", targetVertexId=" + targetVertexId + ", edgeValue=" + edgeValue); } aggregate(INPUT_VERTEX_PERSISTENT_AGG, new LongWritable((long) vertex.getValue().get())); return vertex; } } /** * Simple VertexInputFormat */ public static class SimpleVertexInputFormat extends GeneratedVertexInputFormat<LongWritable, DoubleWritable, FloatWritable> { @Override public VertexReader<LongWritable, DoubleWritable, FloatWritable> createVertexReader(InputSplit split, TaskAttemptContext context) throws IOException { return new SimpleVertexReader(); } } /** * Simple Edge Reader */ public static class SimpleEdgeReader extends GeneratedEdgeReader<LongWritable, FloatWritable> { /** Class logger */ private static final Logger LOG = Logger.getLogger(SimpleEdgeReader.class); @Override public boolean nextEdge() { return totalRecords > recordsRead; } @Override public Edge<LongWritable, FloatWritable> getCurrentEdge() throws IOException { LongWritable vertexId = new LongWritable( (inputSplit.getSplitIndex() * totalRecords) + recordsRead); long targetVertexId = (vertexId.get() + 1) % (inputSplit.getNumSplits() * totalRecords); float edgeValue = vertexId.get() * 100f; Edge<LongWritable, FloatWritable> edge = EdgeFactory.create( new LongWritable(targetVertexId), new FloatWritable(edgeValue)); ++recordsRead; if (LOG.isInfoEnabled()) { LOG.info("next edge: Return targetVertexId=" + targetVertexId + ", edgeValue=" + edgeValue); } aggregate(INPUT_EDGE_PERSISTENT_AGG, new LongWritable((long) edge .getValue().get())); return edge; } @Override public LongWritable getCurrentSourceId() throws IOException, InterruptedException { LongWritable vertexId = new LongWritable( (inputSplit.getSplitIndex() * totalRecords) + recordsRead); return vertexId; } } /** * Simple VertexInputFormat */ public static class SimpleEdgeInputFormat extends EdgeInputFormat<LongWritable, FloatWritable> { @Override public void checkInputSpecs(Configuration conf) { } @Override public EdgeReader<LongWritable, FloatWritable> createEdgeReader( InputSplit split, TaskAttemptContext context) throws IOException { return new SimpleEdgeReader(); } @Override public List<InputSplit> getSplits(JobContext context, int minSplitCountHint) throws IOException, InterruptedException { List<InputSplit> inputSplitList = new ArrayList<InputSplit>(); for (int i = 0; i < minSplitCountHint; ++i) { inputSplitList.add(new BspInputSplit(i, minSplitCountHint)); } return inputSplitList; } } }