/*********************************************************************************************************************** * Copyright (C) 2010-2013 by the Stratosphere project (http://stratosphere.eu) * * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. **********************************************************************************************************************/ package eu.stratosphere.test.recordJobs.kmeans; import java.util.ArrayList; import java.util.List; import eu.stratosphere.api.common.Plan; import eu.stratosphere.api.common.Program; import eu.stratosphere.api.common.ProgramDescription; import eu.stratosphere.api.java.record.operators.BulkIteration; import eu.stratosphere.api.java.record.operators.FileDataSink; import eu.stratosphere.api.java.record.operators.FileDataSource; import eu.stratosphere.api.java.record.operators.CrossOperator; import eu.stratosphere.api.java.record.operators.ReduceOperator; import eu.stratosphere.client.LocalExecutor; import eu.stratosphere.test.recordJobs.kmeans.udfs.ComputeDistance; import eu.stratosphere.test.recordJobs.kmeans.udfs.FindNearestCenter; import eu.stratosphere.test.recordJobs.kmeans.udfs.PointInFormat; import eu.stratosphere.test.recordJobs.kmeans.udfs.PointOutFormat; import eu.stratosphere.test.recordJobs.kmeans.udfs.RecomputeClusterCenter; import eu.stratosphere.types.IntValue; public class KMeansCross implements Program, ProgramDescription { private static final long serialVersionUID = 1L; @Override public Plan getPlan(String... args) { // parse job parameters final int numSubTasks = (args.length > 0 ? Integer.parseInt(args[0]) : 1); final String dataPointInput = (args.length > 1 ? args[1] : ""); final String clusterInput = (args.length > 2 ? args[2] : ""); final String output = (args.length > 3 ? args[3] : ""); final int numIterations = (args.length > 4 ? Integer.parseInt(args[4]) : 1); // create DataSourceContract for cluster center input FileDataSource initialClusterPoints = new FileDataSource(new PointInFormat(), clusterInput, "Centers"); initialClusterPoints.setDegreeOfParallelism(1); BulkIteration iteration = new BulkIteration("K-Means Loop"); iteration.setInput(initialClusterPoints); iteration.setMaximumNumberOfIterations(numIterations); // create DataSourceContract for data point input FileDataSource dataPoints = new FileDataSource(new PointInFormat(), dataPointInput, "Data Points"); // create CrossOperator for distance computation CrossOperator computeDistance = CrossOperator.builder(new ComputeDistance()) .input1(dataPoints) .input2(iteration.getPartialSolution()) .name("Compute Distances") .build(); // create ReduceOperator for finding the nearest cluster centers ReduceOperator findNearestClusterCenters = ReduceOperator.builder(new FindNearestCenter(), IntValue.class, 0) .input(computeDistance) .name("Find Nearest Centers") .build(); // create ReduceOperator for computing new cluster positions ReduceOperator recomputeClusterCenter = ReduceOperator.builder(new RecomputeClusterCenter(), IntValue.class, 0) .input(findNearestClusterCenters) .name("Recompute Center Positions") .build(); iteration.setNextPartialSolution(recomputeClusterCenter); // create DataSourceContract for data point input FileDataSource dataPoints2 = new FileDataSource(new PointInFormat(), dataPointInput, "Data Points 2"); // compute distance of points to final clusters CrossOperator computeFinalDistance = CrossOperator.builder(new ComputeDistance()) .input1(dataPoints2) .input2(iteration) .name("Compute Final Distances") .build(); // find nearest final cluster for point ReduceOperator findNearestFinalCluster = ReduceOperator.builder(new FindNearestCenter(), IntValue.class, 0) .input(computeFinalDistance) .name("Find Nearest Final Centers") .build(); // create DataSinkContract for writing the new cluster positions FileDataSink finalClusters = new FileDataSink(new PointOutFormat(), output+"/centers", iteration, "Cluster Positions"); // write assigned clusters FileDataSink clusterAssignments = new FileDataSink(new PointOutFormat(), output+"/points", findNearestFinalCluster, "Cluster Assignments"); List<FileDataSink> sinks = new ArrayList<FileDataSink>(); sinks.add(finalClusters); sinks.add(clusterAssignments); // return the PACT plan Plan plan = new Plan(sinks, "Iterative KMeans"); plan.setDefaultParallelism(numSubTasks); return plan; } @Override public String getDescription() { return "Parameters: <numSubStasks> <dataPoints> <clusterCenters> <output> <numIterations>"; } public static void main(String[] args) throws Exception { KMeansCross kmi = new KMeansCross(); if (args.length < 5) { System.err.println(kmi.getDescription()); System.exit(1); } Plan plan = kmi.getPlan(args); // This will execute the kMeans clustering job embedded in a local context. LocalExecutor.execute(plan); } }