/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.github.projectflink.spark; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.api.java.function.Function2; import org.apache.spark.broadcast.Broadcast; import scala.Tuple2; import scala.Tuple3; import java.io.Serializable; import java.util.Arrays; import java.util.List; public class KMeansArbitraryDimension { public static void main(String[] args) { if(!parseParameters(args)) { return; } SparkConf conf = new SparkConf().setAppName("KMeans Multi-Dimension").setMaster(master).set("spark.hadoop.validateOutputSpecs", "false"); conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer"); // conf.set("spark.kryo.registrator", ScalaRegistrator.class.getCanonicalName()); conf.set("spark.kryo.registrator", MyRegistrator.class.getCanonicalName()); JavaSparkContext sc = new JavaSparkContext(conf); // ================================ Standard KMeans ============================= JavaRDD<Point> points = sc .textFile(pointsPath, dop) .map(new ConvertToPoint()).cache(); JavaPairRDD<Integer, Point> kCenters = sc .textFile(centersPath) .mapToPair(new ConvertToCentroid()); for(int i=0; i<numIterations; ++i) { Broadcast<List<Tuple2<Integer, Point>>> brCenters = sc.broadcast(kCenters.collect()); kCenters = points // compute closest centroid for each point .mapToPair(new SelectNearestCentroid(brCenters)) // count and sum point coordinates for each centroid .mapToPair(new CountAppender()) .reduceByKey(new CentroidSum()) // calculate the mean( the new center ) of each cluster .mapToPair(new CentroidAverage()); brCenters.unpersist(); } Broadcast<List<Tuple2<Integer, Point>>> brCenters = sc.broadcast(kCenters.collect()); JavaPairRDD<Integer, Point> clusteredPoints = points.mapToPair(new SelectNearestCentroid(brCenters)); clusteredPoints.saveAsTextFile(outputPath); } /** Convert String value into data point **/ public static final class ConvertToPoint implements Function<String, Point> { @Override public Point call(String s) throws Exception { String [] line = s.split(" "); double [] points = new double[line.length]; for (int i = 0; i < line.length; i++) { points[i] = Double.parseDouble(line[i]); } return new Point(points); } } /** Convert String value into data centroid **/ public static final class ConvertToCentroid implements PairFunction<String, Integer, Point> { @Override public Tuple2<Integer, Point> call(String s) throws Exception { String [] line = s.split(" "); int id = Integer.parseInt(line[0]); double [] points = new double[line.length - 1]; for (int i = 1; i < line.length; i++) { points[i - 1] = Double.parseDouble(line[i]); } return new Tuple2<Integer, Point>(id, new Point(points)); } } /** * Assign each point to its closest center * */ public static final class SelectNearestCentroid implements PairFunction<Point, Integer, Point> { List<Tuple2<Integer, Point>> brCenters; public SelectNearestCentroid(Broadcast<List<Tuple2<Integer, Point>>> brCenters) { this.brCenters = brCenters.getValue(); } public Tuple2<Integer, Point> call(Point v1) throws Exception { double minDistance = Double.MAX_VALUE; int centerId = 0; for(Tuple2<Integer, Point> c : brCenters) { double d = v1.euclideanDistance(c._2()); if(minDistance > d) { minDistance = d; centerId = c._1(); } } return new Tuple2<Integer, Point>(centerId, v1); } } /** * Appends a count variable to the tuple. */ public static final class CountAppender implements PairFunction<Tuple2<Integer, Point>, Integer, Tuple2<Point, Long>> { @Override public Tuple2<Integer, Tuple2<Point, Long>> call(Tuple2<Integer, Point> t) throws Exception { return new Tuple2<Integer, Tuple2<Point, Long>>(t._1(), new Tuple2<Point, Long>(t._2(), 1L)); } } /** * Aggregate(sum) all the points in each cluster for calculating mean * */ public static final class CentroidSum implements Function2<Tuple2<Point, Long>, Tuple2<Point, Long>, Tuple2<Point, Long>> { @Override public Tuple2<Point, Long> call(Tuple2<Point, Long> v1, Tuple2<Point, Long> v2) throws Exception { return new Tuple2<Point, Long>(v1._1().add(v2._1()), v1._2() + v2._2()); } } /** * Calculate the mean(new center) of the cluster ( sum of points / number of points ) * */ public static final class CentroidAverage implements PairFunction<Tuple2<Integer, Tuple2<Point, Long>>, Integer, Point> { @Override public Tuple2<Integer, Point> call(Tuple2<Integer, Tuple2<Point, Long>> t) throws Exception { Point p = t._2()._1(); Long l = t._2()._2(); Point nev = p.div(l); Tuple2<Integer, Point> cen = new Tuple2<Integer, Point>(t._1(), nev); return cen; } } // ************************************************************************* // DATA TYPES // ************************************************************************* public static class Point implements Serializable { private double [] points; public Point() { } public Point(double[] points) { this.points = points.clone(); } public Point add(Point other) { Point ret = new Point(this.points); for (int i = 0; i < points.length; i++) { ret.points[i] = points[i] + other.points[i]; } return ret; } public Point div(long val) { Point ret = new Point(this.points); for (int i = 0; i < points.length; i++) { ret.points[i] = points[i] / val; } return ret; } public double euclideanDistance(Point other) { double sum = 0; for (int i = 0; i < points.length; i++) { sum = sum + (points[i] - other.points[i]) * (points[i] - other.points[i]); } return Math.sqrt(sum); } @Override public String toString() { return Arrays.toString(points); } } // ************************************************************************* // UTIL METHODS // ************************************************************************* private static String master = null; private static String pointsPath = null; private static String centersPath = null; private static String outputPath = null; private static int numIterations = 10; private static int dop = 400; private static boolean parseParameters(String[] programArguments) { // parse input arguments if(programArguments.length == 6) { master = programArguments[0]; pointsPath = programArguments[1]; centersPath = programArguments[2]; outputPath = programArguments[3]; numIterations = Integer.parseInt(programArguments[4]); dop = Integer.parseInt(programArguments[5]); } else { System.err.println("Usage: KMeans <master> <points path> <centers path> <result path> <num iterations> <dop>"); return false; } return true; } }