/*********************************************************************************************************************** * * Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu) * * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. * **********************************************************************************************************************/ package eu.stratosphere.test.broadcastvars; import eu.stratosphere.nephele.jobgraph.DistributionPattern; import org.apache.log4j.Level; import eu.stratosphere.api.common.operators.util.UserCodeObjectWrapper; import eu.stratosphere.api.common.typeutils.TypeComparatorFactory; import eu.stratosphere.api.common.typeutils.TypeSerializerFactory; import eu.stratosphere.api.java.record.io.CsvInputFormat; import eu.stratosphere.configuration.Configuration; import eu.stratosphere.core.fs.Path; import eu.stratosphere.nephele.jobgraph.JobGraph; import eu.stratosphere.nephele.jobgraph.JobGraphDefinitionException; import eu.stratosphere.nephele.jobgraph.JobInputVertex; import eu.stratosphere.nephele.jobgraph.JobOutputVertex; import eu.stratosphere.nephele.jobgraph.JobTaskVertex; import eu.stratosphere.pact.runtime.iterative.task.IterationHeadPactTask; import eu.stratosphere.pact.runtime.iterative.task.IterationIntermediatePactTask; import eu.stratosphere.pact.runtime.iterative.task.IterationTailPactTask; import eu.stratosphere.api.java.typeutils.runtime.record.RecordComparatorFactory; import eu.stratosphere.api.java.typeutils.runtime.record.RecordSerializerFactory; import eu.stratosphere.pact.runtime.shipping.ShipStrategyType; import eu.stratosphere.pact.runtime.task.CollectorMapDriver; import eu.stratosphere.pact.runtime.task.DriverStrategy; import eu.stratosphere.pact.runtime.task.NoOpDriver; import eu.stratosphere.pact.runtime.task.GroupReduceDriver; import eu.stratosphere.pact.runtime.task.chaining.ChainedCollectorMapDriver; import eu.stratosphere.pact.runtime.task.util.LocalStrategy; import eu.stratosphere.pact.runtime.task.util.TaskConfig; import eu.stratosphere.runtime.io.channels.ChannelType; import eu.stratosphere.test.iterative.nephele.JobGraphUtils; import eu.stratosphere.test.recordJobs.kmeans.KMeansBroadcast.PointBuilder; import eu.stratosphere.test.recordJobs.kmeans.KMeansBroadcast.RecomputeClusterCenter; import eu.stratosphere.test.recordJobs.kmeans.KMeansBroadcast.SelectNearestCenter; import eu.stratosphere.test.recordJobs.kmeans.KMeansBroadcast.PointOutFormat; import eu.stratosphere.test.testdata.KMeansData; import eu.stratosphere.test.util.RecordAPITestBase; import eu.stratosphere.types.DoubleValue; import eu.stratosphere.types.IntValue; import eu.stratosphere.util.LogUtils; public class KMeansIterativeNepheleITCase extends RecordAPITestBase { private static final int ITERATION_ID = 42; private static final int MEMORY_PER_CONSUMER = 2; protected String dataPath; protected String clusterPath; protected String resultPath; public KMeansIterativeNepheleITCase() { LogUtils.initializeDefaultConsoleLogger(Level.ERROR); } @Override protected void preSubmit() throws Exception { dataPath = createTempFile("datapoints.txt", KMeansData.DATAPOINTS); clusterPath = createTempFile("initial_centers.txt", KMeansData.INITIAL_CENTERS); resultPath = getTempDirPath("result"); } @Override protected void postSubmit() throws Exception { compareResultsByLinesInMemory(KMeansData.CENTERS_AFTER_20_ITERATIONS_SINGLE_DIGIT, resultPath); } @Override protected JobGraph getJobGraph() throws Exception { return createJobGraph(dataPath, clusterPath, this.resultPath, 4, 20); } // ------------------------------------------------------------------------------------------------------------- // Job vertex builder methods // ------------------------------------------------------------------------------------------------------------- private static JobInputVertex createPointsInput(JobGraph jobGraph, String pointsPath, int numSubTasks, TypeSerializerFactory<?> serializer) { @SuppressWarnings("unchecked") CsvInputFormat pointsInFormat = new CsvInputFormat('|', IntValue.class, DoubleValue.class, DoubleValue.class, DoubleValue.class); JobInputVertex pointsInput = JobGraphUtils.createInput(pointsInFormat, pointsPath, "[Points]", jobGraph, numSubTasks, numSubTasks); { TaskConfig taskConfig = new TaskConfig(pointsInput.getConfiguration()); taskConfig.addOutputShipStrategy(ShipStrategyType.FORWARD); taskConfig.setOutputSerializer(serializer); TaskConfig chainedMapper = new TaskConfig(new Configuration()); chainedMapper.setDriverStrategy(DriverStrategy.COLLECTOR_MAP); chainedMapper.setStubWrapper(new UserCodeObjectWrapper<PointBuilder>(new PointBuilder())); chainedMapper.addOutputShipStrategy(ShipStrategyType.FORWARD); chainedMapper.setOutputSerializer(serializer); taskConfig.addChainedTask(ChainedCollectorMapDriver.class, chainedMapper, "Build points"); } return pointsInput; } private static JobInputVertex createCentersInput(JobGraph jobGraph, String centersPath, int numSubTasks, TypeSerializerFactory<?> serializer) { @SuppressWarnings("unchecked") CsvInputFormat modelsInFormat = new CsvInputFormat('|', IntValue.class, DoubleValue.class, DoubleValue.class, DoubleValue.class); JobInputVertex modelsInput = JobGraphUtils.createInput(modelsInFormat, centersPath, "[Models]", jobGraph, numSubTasks, numSubTasks); { TaskConfig taskConfig = new TaskConfig(modelsInput.getConfiguration()); taskConfig.addOutputShipStrategy(ShipStrategyType.FORWARD); taskConfig.setOutputSerializer(serializer); TaskConfig chainedMapper = new TaskConfig(new Configuration()); chainedMapper.setDriverStrategy(DriverStrategy.COLLECTOR_MAP); chainedMapper.setStubWrapper(new UserCodeObjectWrapper<PointBuilder>(new PointBuilder())); chainedMapper.addOutputShipStrategy(ShipStrategyType.FORWARD); chainedMapper.setOutputSerializer(serializer); taskConfig.addChainedTask(ChainedCollectorMapDriver.class, chainedMapper, "Build centers"); } return modelsInput; } private static JobOutputVertex createOutput(JobGraph jobGraph, String resultPath, int numSubTasks, TypeSerializerFactory<?> serializer) { JobOutputVertex output = JobGraphUtils.createFileOutput(jobGraph, "Output", numSubTasks, numSubTasks); { TaskConfig taskConfig = new TaskConfig(output.getConfiguration()); taskConfig.addInputToGroup(0); taskConfig.setInputSerializer(serializer, 0); PointOutFormat outFormat = new PointOutFormat(); outFormat.setOutputFilePath(new Path(resultPath)); taskConfig.setStubWrapper(new UserCodeObjectWrapper<PointOutFormat>(outFormat)); } return output; } private static JobTaskVertex createIterationHead(JobGraph jobGraph, int numSubTasks, TypeSerializerFactory<?> serializer) { JobTaskVertex head = JobGraphUtils.createTask(IterationHeadPactTask.class, "Iteration Head", jobGraph, numSubTasks, numSubTasks); TaskConfig headConfig = new TaskConfig(head.getConfiguration()); headConfig.setIterationId(ITERATION_ID); // initial input / partial solution headConfig.addInputToGroup(0); headConfig.setIterationHeadPartialSolutionOrWorksetInputIndex(0); headConfig.setInputSerializer(serializer, 0); // back channel / iterations headConfig.setBackChannelMemory(MEMORY_PER_CONSUMER * JobGraphUtils.MEGABYTE); // output into iteration. broadcasting the centers headConfig.setOutputSerializer(serializer); headConfig.addOutputShipStrategy(ShipStrategyType.BROADCAST); // final output TaskConfig headFinalOutConfig = new TaskConfig(new Configuration()); headFinalOutConfig.setOutputSerializer(serializer); headFinalOutConfig.addOutputShipStrategy(ShipStrategyType.FORWARD); headConfig.setIterationHeadFinalOutputConfig(headFinalOutConfig); // the sync headConfig.setIterationHeadIndexOfSyncOutput(2); // the driver headConfig.setDriver(NoOpDriver.class); headConfig.setDriverStrategy(DriverStrategy.UNARY_NO_OP); return head; } private static JobTaskVertex createMapper(JobGraph jobGraph, int numSubTasks, TypeSerializerFactory<?> inputSerializer, TypeSerializerFactory<?> broadcastVarSerializer, TypeSerializerFactory<?> outputSerializer, TypeComparatorFactory<?> outputComparator) { JobTaskVertex mapper = JobGraphUtils.createTask(IterationIntermediatePactTask.class, "Map (Select nearest center)", jobGraph, numSubTasks, numSubTasks); TaskConfig intermediateConfig = new TaskConfig(mapper.getConfiguration()); intermediateConfig.setIterationId(ITERATION_ID); intermediateConfig.setDriver(CollectorMapDriver.class); intermediateConfig.setDriverStrategy(DriverStrategy.COLLECTOR_MAP); intermediateConfig.addInputToGroup(0); intermediateConfig.setInputSerializer(inputSerializer, 0); intermediateConfig.setOutputSerializer(outputSerializer); intermediateConfig.addOutputShipStrategy(ShipStrategyType.PARTITION_HASH); intermediateConfig.setOutputComparator(outputComparator, 0); intermediateConfig.setBroadcastInputName("centers", 0); intermediateConfig.addBroadcastInputToGroup(0); intermediateConfig.setBroadcastInputSerializer(broadcastVarSerializer, 0); // the udf intermediateConfig.setStubWrapper(new UserCodeObjectWrapper<SelectNearestCenter>(new SelectNearestCenter())); return mapper; } private static JobTaskVertex createReducer(JobGraph jobGraph, int numSubTasks, TypeSerializerFactory<?> inputSerializer, TypeComparatorFactory<?> inputComparator, TypeSerializerFactory<?> outputSerializer) { // ---------------- the tail (co group) -------------------- JobTaskVertex tail = JobGraphUtils.createTask(IterationTailPactTask.class, "Reduce / Iteration Tail", jobGraph, numSubTasks, numSubTasks); TaskConfig tailConfig = new TaskConfig(tail.getConfiguration()); tailConfig.setIterationId(ITERATION_ID); tailConfig.setIsWorksetUpdate(); // inputs and driver tailConfig.setDriver(GroupReduceDriver.class); tailConfig.setDriverStrategy(DriverStrategy.SORTED_GROUP_REDUCE); tailConfig.addInputToGroup(0); tailConfig.setInputSerializer(inputSerializer, 0); tailConfig.setDriverComparator(inputComparator, 0); tailConfig.setInputLocalStrategy(0, LocalStrategy.SORT); tailConfig.setInputComparator(inputComparator, 0); tailConfig.setMemoryInput(0, MEMORY_PER_CONSUMER * JobGraphUtils.MEGABYTE); tailConfig.setFilehandlesInput(0, 128); tailConfig.setSpillingThresholdInput(0, 0.9f); // output tailConfig.addOutputShipStrategy(ShipStrategyType.FORWARD); tailConfig.setOutputSerializer(outputSerializer); // the udf tailConfig.setStubWrapper(new UserCodeObjectWrapper<RecomputeClusterCenter>(new RecomputeClusterCenter())); return tail; } private static JobOutputVertex createSync(JobGraph jobGraph, int numIterations, int dop) { JobOutputVertex sync = JobGraphUtils.createSync(jobGraph, dop); TaskConfig syncConfig = new TaskConfig(sync.getConfiguration()); syncConfig.setNumberOfIterations(numIterations); syncConfig.setIterationId(ITERATION_ID); return sync; } // ------------------------------------------------------------------------------------------------------------- // Unified solution set and workset tail update // ------------------------------------------------------------------------------------------------------------- private static JobGraph createJobGraph(String pointsPath, String centersPath, String resultPath, int numSubTasks, int numIterations) throws JobGraphDefinitionException { // -- init ------------------------------------------------------------------------------------------------- final TypeSerializerFactory<?> serializer = RecordSerializerFactory.get(); @SuppressWarnings("unchecked") final TypeComparatorFactory<?> int0Comparator = new RecordComparatorFactory(new int[] { 0 }, new Class[] { IntValue.class }); JobGraph jobGraph = new JobGraph("KMeans Iterative"); // -- vertices --------------------------------------------------------------------------------------------- JobInputVertex points = createPointsInput(jobGraph, pointsPath, numSubTasks, serializer); JobInputVertex centers = createCentersInput(jobGraph, centersPath, numSubTasks, serializer); JobTaskVertex head = createIterationHead(jobGraph, numSubTasks, serializer); JobTaskVertex mapper = createMapper(jobGraph, numSubTasks, serializer, serializer, serializer, int0Comparator); JobTaskVertex reducer = createReducer(jobGraph, numSubTasks, serializer, int0Comparator, serializer); JobOutputVertex fakeTailOutput = JobGraphUtils.createFakeOutput(jobGraph, "FakeTailOutput", numSubTasks, numSubTasks); JobOutputVertex sync = createSync(jobGraph, numIterations, numSubTasks); JobOutputVertex output = createOutput(jobGraph, resultPath, numSubTasks, serializer); // -- edges ------------------------------------------------------------------------------------------------ JobGraphUtils.connect(points, mapper, ChannelType.NETWORK, DistributionPattern.POINTWISE); JobGraphUtils.connect(centers, head, ChannelType.NETWORK, DistributionPattern.POINTWISE); JobGraphUtils.connect(head, mapper, ChannelType.NETWORK, DistributionPattern.BIPARTITE); new TaskConfig(mapper.getConfiguration()).setBroadcastGateIterativeWithNumberOfEventsUntilInterrupt(0, numSubTasks); new TaskConfig(mapper.getConfiguration()).setInputCached(0, true); new TaskConfig(mapper.getConfiguration()).setInputMaterializationMemory(0, MEMORY_PER_CONSUMER * JobGraphUtils.MEGABYTE); JobGraphUtils.connect(mapper, reducer, ChannelType.NETWORK, DistributionPattern.BIPARTITE); new TaskConfig(reducer.getConfiguration()).setGateIterativeWithNumberOfEventsUntilInterrupt(0, numSubTasks); JobGraphUtils.connect(reducer, fakeTailOutput, ChannelType.NETWORK, DistributionPattern.POINTWISE); JobGraphUtils.connect(head, output, ChannelType.NETWORK, DistributionPattern.POINTWISE); JobGraphUtils.connect(head, sync, ChannelType.NETWORK, DistributionPattern.BIPARTITE); // -- instance sharing ------------------------------------------------------------------------------------- points.setVertexToShareInstancesWith(output); centers.setVertexToShareInstancesWith(output); head.setVertexToShareInstancesWith(output); mapper.setVertexToShareInstancesWith(output); reducer.setVertexToShareInstancesWith(output); fakeTailOutput.setVertexToShareInstancesWith(output); sync.setVertexToShareInstancesWith(output); return jobGraph; } }