/***********************************************************************************************************************
 *
 * Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu)
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
 * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations under the License.
 *
 **********************************************************************************************************************/
package eu.stratosphere.test.broadcastvars;

import java.io.BufferedReader;
import java.util.Collection;
import java.util.Random;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.junit.Assert;

import eu.stratosphere.api.common.operators.util.UserCodeClassWrapper;
import eu.stratosphere.api.common.operators.util.UserCodeObjectWrapper;
import eu.stratosphere.api.common.typeutils.TypeSerializerFactory;
import eu.stratosphere.api.java.record.functions.MapFunction;
import eu.stratosphere.api.java.record.io.CsvInputFormat;
import eu.stratosphere.api.java.record.io.CsvOutputFormat;
import eu.stratosphere.api.java.typeutils.runtime.record.RecordSerializerFactory;
import eu.stratosphere.configuration.Configuration;
import eu.stratosphere.core.fs.Path;
import eu.stratosphere.nephele.jobgraph.DistributionPattern;
import eu.stratosphere.nephele.jobgraph.JobGraph;
import eu.stratosphere.nephele.jobgraph.JobGraphDefinitionException;
import eu.stratosphere.nephele.jobgraph.JobInputVertex;
import eu.stratosphere.nephele.jobgraph.JobOutputVertex;
import eu.stratosphere.nephele.jobgraph.JobTaskVertex;
import eu.stratosphere.pact.runtime.shipping.ShipStrategyType;
import eu.stratosphere.pact.runtime.task.CollectorMapDriver;
import eu.stratosphere.pact.runtime.task.DriverStrategy;
import eu.stratosphere.pact.runtime.task.RegularPactTask;
import eu.stratosphere.pact.runtime.task.util.LocalStrategy;
import eu.stratosphere.pact.runtime.task.util.TaskConfig;
import eu.stratosphere.runtime.io.channels.ChannelType;
import eu.stratosphere.test.iterative.nephele.JobGraphUtils;
import eu.stratosphere.test.util.RecordAPITestBase;
import eu.stratosphere.types.LongValue;
import eu.stratosphere.types.Record;
import eu.stratosphere.util.Collector;

/**
 * Integration test for broadcast variables on a hand-built Nephele job graph.
 *
 * <p>The job computes the dot product of every point against every model: the points input is
 * forwarded pointwise to a mapper, the models input is broadcast to every mapper instance and
 * exposed to the UDF as the broadcast variable {@code "models"}. The mapper emits one record
 * {@code (modelId, pointId, dotProduct)} per (point, model) pair, which is written as a
 * space-separated CSV file and verified in {@link #postSubmit()}.
 *
 * <p>Both inputs are generated deterministically: the features of point/model {@code id} are
 * drawn from a {@link Random} re-seeded with {@code seed + 1000 * id}, so the verification step
 * can recompute every expected dot product without storing the inputs.
 */
public class BroadcastVarsNepheleITCase extends RecordAPITestBase {

    private static final long SEED_POINTS = 0xBADC0FFEEBEEFL;

    private static final long SEED_MODELS = 0x39134230AFF32L;

    private static final int NUM_POINTS = 10000;

    private static final int NUM_MODELS = 42;

    private static final int NUM_FEATURES = 3;

    protected String pointsPath;

    protected String modelsPath;

    protected String resultPath;

    /**
     * Generates the textual points input: one line per point, consisting of the one-based point id
     * followed by {@code numDimensions} features in {@code [0, 1000)}, space-separated (with a
     * trailing space before the newline, matching what the CSV input format expects).
     *
     * <p>The features of point {@code i} come from a {@link Random} seeded with
     * {@code seed + 1000 * i}, which {@link #postSubmit()} relies on to recompute expected results.
     *
     * @param numPoints number of points to generate, must be in {@code [1, 1000000]}
     * @param numDimensions number of feature columns per point
     * @param seed base seed for the per-point random feature values
     * @return the complete input as one newline-separated string
     * @throws IllegalArgumentException if {@code numPoints} is out of range
     */
    public static final String getInputPoints(int numPoints, int numDimensions, long seed) {
        if (numPoints < 1 || numPoints > 1000000) {
            throw new IllegalArgumentException();
        }

        Random r = new Random();
        StringBuilder bld = new StringBuilder(3 * (1 + numDimensions) * numPoints);

        for (int i = 1; i <= numPoints; i++) {
            bld.append(i);
            bld.append(' ');

            // re-seed per point so each point's features are independently reproducible
            r.setSeed(seed + 1000 * i);
            for (int j = 1; j <= numDimensions; j++) {
                bld.append(r.nextInt(1000));
                bld.append(' ');
            }
            bld.append('\n');
        }
        return bld.toString();
    }

    /**
     * Generates the textual models input, analogous to {@link #getInputPoints(int, int, long)} but
     * with feature values in {@code [0, 100)} and at most 100 models.
     *
     * @param numModels number of models to generate, must be in {@code [1, 100]}
     * @param numDimensions number of feature columns per model
     * @param seed base seed for the per-model random feature values
     * @return the complete input as one newline-separated string
     * @throws IllegalArgumentException if {@code numModels} is out of range
     */
    public static final String getInputModels(int numModels, int numDimensions, long seed) {
        if (numModels < 1 || numModels > 100) {
            throw new IllegalArgumentException();
        }

        Random r = new Random();
        StringBuilder bld = new StringBuilder(3 * (1 + numDimensions) * numModels);

        for (int i = 1; i <= numModels; i++) {
            bld.append(i);
            bld.append(' ');

            // re-seed per model so each model's features are independently reproducible
            r.setSeed(seed + 1000 * i);
            for (int j = 1; j <= numDimensions; j++) {
                bld.append(r.nextInt(100));
                bld.append(' ');
            }
            bld.append('\n');
        }
        return bld.toString();
    }

    /** Writes the generated points and models inputs to temp files and reserves the result path. */
    @Override
    protected void preSubmit() throws Exception {
        this.pointsPath = createTempFile("points.txt", getInputPoints(NUM_POINTS, NUM_FEATURES, SEED_POINTS));
        this.modelsPath = createTempFile("models.txt", getInputModels(NUM_MODELS, NUM_FEATURES, SEED_MODELS));
        this.resultPath = getTempFilePath("results");
    }

    @Override
    protected JobGraph getJobGraph() throws Exception {
        return createJobGraphV1(this.pointsPath, this.modelsPath, this.resultPath, 4);
    }

    /**
     * Verifies the job output: recomputes every expected dot product from the deterministic seeds,
     * then checks that each (point, model) pair occurs exactly once in the result files with the
     * correct value.
     */
    @Override
    protected void postSubmit() throws Exception {
        final Random randPoints = new Random();
        final Random randModels = new Random();
        final Pattern p = Pattern.compile("(\\d+) (\\d+) (\\d+)");

        // recompute the expected dot product for every (point, model) pair
        long[][] results = new long[NUM_POINTS][NUM_MODELS];
        boolean[][] occurs = new boolean[NUM_POINTS][NUM_MODELS];
        for (int i = 0; i < NUM_POINTS; i++) {
            for (int j = 0; j < NUM_MODELS; j++) {
                long actDotProd = 0;
                // same per-id seeding scheme as getInputPoints / getInputModels
                randPoints.setSeed(SEED_POINTS + 1000 * (i + 1));
                randModels.setSeed(SEED_MODELS + 1000 * (j + 1));
                for (int z = 1; z <= NUM_FEATURES; z++) {
                    actDotProd += randPoints.nextInt(1000) * randModels.nextInt(100);
                }
                results[i][j] = actDotProd;
                occurs[i][j] = false;
            }
        }

        // every output line is "modelId pointId dotProduct"
        for (BufferedReader reader : getResultReader(this.resultPath)) {
            String line = null;
            while (null != (line = reader.readLine())) {
                final Matcher m = p.matcher(line);
                Assert.assertTrue(m.matches());

                int modelId = Integer.parseInt(m.group(1));
                int pointId = Integer.parseInt(m.group(2));
                long expDotProd = Long.parseLong(m.group(3));

                Assert.assertFalse("Dot product for record (" + pointId + ", " + modelId + ") occurs more than once",
                    occurs[pointId - 1][modelId - 1]);
                Assert.assertEquals(String.format("Bad product for (%04d, %04d)", pointId, modelId),
                    expDotProd, results[pointId - 1][modelId - 1]);
                occurs[pointId - 1][modelId - 1] = true;
            }
        }

        // finally, make sure no pair is missing
        for (int i = 0; i < NUM_POINTS; i++) {
            for (int j = 0; j < NUM_MODELS; j++) {
                Assert.assertTrue("Dot product for record (" + (i + 1) + ", " + (j + 1) + ") does not occur",
                    occurs[i][j]);
            }
        }
    }

    // -------------------------------------------------------------------------------------------------------------
    // UDFs
    // -------------------------------------------------------------------------------------------------------------

    /**
     * Mapper that crosses each input point with every broadcast model and emits the record
     * {@code (modelId, pointId, dotProduct)} for each pair. The models are obtained from the
     * broadcast variable {@code "models"} in {@link #open(Configuration)}.
     */
    public static final class DotProducts extends MapFunction {

        private static final long serialVersionUID = 1L;

        // field-position arrays for Record.copyFrom, hoisted out of the per-record loop
        // (they are read-only and were previously allocated twice per (point, model) pair)
        private static final int[] POS_ZERO = { 0 };

        private static final int[] POS_ONE = { 1 };

        private final Record result = new Record(3);

        private final LongValue lft = new LongValue();

        private final LongValue rgt = new LongValue();

        private final LongValue prd = new LongValue();

        private Collection<Record> models;

        @Override
        public void open(Configuration parameters) throws Exception {
            this.models = this.getRuntimeContext().getBroadcastVariable("models");
        }

        @Override
        public void map(Record record, Collector<Record> out) throws Exception {
            for (Record model : this.models) {
                // compute dot product between model and point over fields 1..NUM_FEATURES
                long product = 0;
                for (int i = 1; i <= NUM_FEATURES; i++) {
                    product += model.getField(i, this.lft).getValue() * record.getField(i, this.rgt).getValue();
                }
                this.prd.setValue(product);

                // construct result: field 0 = model id, field 1 = point id, field 2 = product
                this.result.copyFrom(model, POS_ZERO, POS_ZERO);
                this.result.copyFrom(record, POS_ZERO, POS_ONE);
                this.result.setField(2, this.prd);

                // emit result
                out.collect(this.result);
            }
        }
    }

    // -------------------------------------------------------------------------------------------------------------
    // Job vertex builder methods
    // -------------------------------------------------------------------------------------------------------------

    /** Creates the points input vertex: reads (id, f1, f2, f3) as longs and forwards pointwise. */
    @SuppressWarnings("unchecked")
    private static JobInputVertex createPointsInput(JobGraph jobGraph, String pointsPath, int numSubTasks,
            TypeSerializerFactory<?> serializer) {
        CsvInputFormat pointsInFormat = new CsvInputFormat(' ', LongValue.class, LongValue.class, LongValue.class,
            LongValue.class);
        JobInputVertex pointsInput = JobGraphUtils.createInput(pointsInFormat, pointsPath, "Input[Points]", jobGraph,
            numSubTasks, numSubTasks);
        {
            TaskConfig taskConfig = new TaskConfig(pointsInput.getConfiguration());
            taskConfig.addOutputShipStrategy(ShipStrategyType.FORWARD);
            taskConfig.setOutputSerializer(serializer);
        }
        return pointsInput;
    }

    /** Creates the models input vertex: reads (id, f1, f2, f3) as longs and broadcasts its output. */
    @SuppressWarnings("unchecked")
    private static JobInputVertex createModelsInput(JobGraph jobGraph, String modelsPath, int numSubTasks,
            TypeSerializerFactory<?> serializer) {
        CsvInputFormat modelsInFormat = new CsvInputFormat(' ', LongValue.class, LongValue.class, LongValue.class,
            LongValue.class);
        JobInputVertex modelsInput = JobGraphUtils.createInput(modelsInFormat, modelsPath, "Input[Models]", jobGraph,
            numSubTasks, numSubTasks);
        {
            TaskConfig taskConfig = new TaskConfig(modelsInput.getConfiguration());
            // broadcast: every mapper instance receives the full models data set
            taskConfig.addOutputShipStrategy(ShipStrategyType.BROADCAST);
            taskConfig.setOutputSerializer(serializer);
        }
        return modelsInput;
    }

    /**
     * Creates the mapper vertex running {@link DotProducts}: input 0 carries the points, and the
     * broadcast input (named {@code "models"}) carries the models.
     */
    private static JobTaskVertex createMapper(JobGraph jobGraph, int numSubTasks, TypeSerializerFactory<?> serializer) {
        JobTaskVertex pointsInput = JobGraphUtils.createTask(RegularPactTask.class, "Map[DotProducts]", jobGraph,
            numSubTasks, numSubTasks);
        {
            TaskConfig taskConfig = new TaskConfig(pointsInput.getConfiguration());

            taskConfig.setStubWrapper(new UserCodeClassWrapper<DotProducts>(DotProducts.class));

            taskConfig.addOutputShipStrategy(ShipStrategyType.FORWARD);
            taskConfig.setOutputSerializer(serializer);

            taskConfig.setDriver(CollectorMapDriver.class);
            taskConfig.setDriverStrategy(DriverStrategy.COLLECTOR_MAP);

            // regular input: the points
            taskConfig.addInputToGroup(0);
            taskConfig.setInputLocalStrategy(0, LocalStrategy.NONE);
            taskConfig.setInputSerializer(serializer, 0);

            // broadcast input: the models, exposed to the UDF as broadcast variable "models"
            taskConfig.setBroadcastInputName("models", 0);
            taskConfig.addBroadcastInputToGroup(0);
            taskConfig.setBroadcastInputSerializer(serializer, 0);
        }
        return pointsInput;
    }

    /** Creates the sink vertex writing the three-column result as space-separated CSV lines. */
    private static JobOutputVertex createOutput(JobGraph jobGraph, String resultPath, int numSubTasks,
            TypeSerializerFactory<?> serializer) {
        JobOutputVertex output = JobGraphUtils.createFileOutput(jobGraph, "Output", numSubTasks, numSubTasks);
        {
            TaskConfig taskConfig = new TaskConfig(output.getConfiguration());
            taskConfig.addInputToGroup(0);
            taskConfig.setInputSerializer(serializer, 0);

            @SuppressWarnings("unchecked")
            CsvOutputFormat outFormat = new CsvOutputFormat("\n", " ", LongValue.class, LongValue.class,
                LongValue.class);
            outFormat.setOutputFilePath(new Path(resultPath));

            taskConfig.setStubWrapper(new UserCodeObjectWrapper<CsvOutputFormat>(outFormat));
        }
        return output;
    }

    // -------------------------------------------------------------------------------------------------------------
    // Unified solution set and workset tail update
    // -------------------------------------------------------------------------------------------------------------

    /**
     * Assembles the complete job graph: points and models inputs, the dot-product mapper, and the
     * file sink, wired together with pointwise channels for the points/result flow and a bipartite
     * channel for the broadcast models.
     *
     * @param pointsPath file path of the points input
     * @param centersPath file path of the models input
     * @param resultPath file path for the result output
     * @param numSubTasks degree of parallelism for every vertex
     * @return the assembled job graph
     * @throws JobGraphDefinitionException if the vertices cannot be connected
     */
    private JobGraph createJobGraphV1(String pointsPath, String centersPath, String resultPath, int numSubTasks)
            throws JobGraphDefinitionException {

        // -- init -------------------------------------------------------------------------------------------------
        final TypeSerializerFactory<?> serializer = RecordSerializerFactory.get();

        JobGraph jobGraph = new JobGraph("Distance Builder");

        // -- vertices ---------------------------------------------------------------------------------------------
        JobInputVertex points = createPointsInput(jobGraph, pointsPath, numSubTasks, serializer);
        JobInputVertex models = createModelsInput(jobGraph, centersPath, numSubTasks, serializer);
        JobTaskVertex mapper = createMapper(jobGraph, numSubTasks, serializer);
        JobOutputVertex output = createOutput(jobGraph, resultPath, numSubTasks, serializer);

        // -- edges ------------------------------------------------------------------------------------------------
        JobGraphUtils.connect(points, mapper, ChannelType.NETWORK, DistributionPattern.POINTWISE);
        JobGraphUtils.connect(models, mapper, ChannelType.NETWORK, DistributionPattern.BIPARTITE);
        JobGraphUtils.connect(mapper, output, ChannelType.NETWORK, DistributionPattern.POINTWISE);

        // -- instance sharing -------------------------------------------------------------------------------------
        points.setVertexToShareInstancesWith(output);
        models.setVertexToShareInstancesWith(output);
        mapper.setVertexToShareInstancesWith(output);

        return jobGraph;
    }
}