/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.examples.java.clustering;
import java.io.Serializable;
import java.util.Collection;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.java.functions.FunctionAnnotation.ForwardedFields;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.utils.ParameterTool;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.examples.java.clustering.util.KMeansData;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.IterativeDataSet;
/**
* This example implements a basic K-Means clustering algorithm.
*
* <p>
* K-Means is an iterative clustering algorithm and works as follows:<br>
* K-Means is given a set of data points to be clustered and an initial set of <i>K</i> cluster centers.
* In each iteration, the algorithm computes the distance of each data point to each cluster center.
* Each point is assigned to the cluster center which is closest to it.
* Subsequently, each cluster center is moved to the center (<i>mean</i>) of all points that have been assigned to it.
* The moved cluster centers are fed into the next iteration.
* The algorithm terminates after a fixed number of iterations (as in this implementation)
* or if cluster centers do not (significantly) move in an iteration.<br>
* This is the Wikipedia entry for the <a href="http://en.wikipedia.org/wiki/K-means_clustering">K-Means Clustering algorithm</a>.
*
* <p>
* This implementation works on two-dimensional data points. <br>
* It computes an assignment of data points to cluster centers, i.e.,
* each data point is annotated with the id of the final cluster (center) it belongs to.
*
* <p>
* Input files are plain text files and must be formatted as follows:
* <ul>
* <li>Data points are represented as two double values separated by a blank character.
* Data points are separated by newline characters.<br>
* For example <code>"1.2 2.3\n5.3 7.2\n"</code> gives two data points (x=1.2, y=2.3) and (x=5.3, y=7.2).
* <li>Cluster centers are represented by an integer id and a point value.<br>
* For example <code>"1 6.2 3.2\n2 2.9 5.7\n"</code> gives two centers (id=1, x=6.2, y=3.2) and (id=2, x=2.9, y=5.7).
* </ul>
*
* <p>
* Usage: <code>KMeans --points <path> --centroids <path> --output <path> --iterations <n></code><br>
* If no parameters are provided, the program is run with default data from {@link org.apache.flink.examples.java.clustering.util.KMeansData} and 10 iterations.
*
* <p>
* This example shows how to use:
* <ul>
* <li>Bulk iterations
* <li>Broadcast variables in bulk iterations
* <li>Custom Java objects (POJOs)
* </ul>
*/
@SuppressWarnings("serial")
public class KMeans {
public static void main(String[] args) throws Exception {
// Checking input parameters
final ParameterTool params = ParameterTool.fromArgs(args);
// set up execution environment
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
env.getConfig().setGlobalJobParameters(params); // make parameters available in the web interface
// get input data:
// read the points and centroids from the provided paths or fall back to default data
DataSet<Point> points = getPointDataSet(params, env);
DataSet<Centroid> centroids = getCentroidDataSet(params, env);
// set number of bulk iterations for KMeans algorithm
IterativeDataSet<Centroid> loop = centroids.iterate(params.getInt("iterations", 10));
DataSet<Centroid> newCentroids = points
// compute closest centroid for each point
.map(new SelectNearestCenter()).withBroadcastSet(loop, "centroids")
// count and sum point coordinates for each centroid
.map(new CountAppender())
.groupBy(0).reduce(new CentroidAccumulator())
// compute new centroids from point counts and coordinate sums
.map(new CentroidAverager());
// feed new centroids back into next iteration
DataSet<Centroid> finalCentroids = loop.closeWith(newCentroids);
DataSet<Tuple2<Integer, Point>> clusteredPoints = points
// assign points to final clusters
.map(new SelectNearestCenter()).withBroadcastSet(finalCentroids, "centroids");
// emit result
if (params.has("output")) {
clusteredPoints.writeAsCsv(params.get("output"), "\n", " ");
// since file sinks are lazy, we trigger the execution explicitly
env.execute("KMeans Example");
} else {
System.out.println("Printing result to stdout. Use --output to specify output path.");
clusteredPoints.print();
}
}
// *************************************************************************
// DATA SOURCE READING (POINTS AND CENTROIDS)
// *************************************************************************
private static DataSet<Centroid> getCentroidDataSet(ParameterTool params, ExecutionEnvironment env) {
DataSet<Centroid> centroids;
if (params.has("centroids")) {
centroids = env.readCsvFile(params.get("centroids"))
.fieldDelimiter(" ")
.pojoType(Centroid.class, "id", "x", "y");
} else {
System.out.println("Executing K-Means example with default centroid data set.");
System.out.println("Use --centroids to specify file input.");
centroids = KMeansData.getDefaultCentroidDataSet(env);
}
return centroids;
}
private static DataSet<Point> getPointDataSet(ParameterTool params, ExecutionEnvironment env) {
DataSet<Point> points;
if (params.has("points")) {
// read points from CSV file
points = env.readCsvFile(params.get("points"))
.fieldDelimiter(" ")
.pojoType(Point.class, "x", "y");
} else {
System.out.println("Executing K-Means example with default point data set.");
System.out.println("Use --points to specify file input.");
points = KMeansData.getDefaultPointDataSet(env);
}
return points;
}
// *************************************************************************
// DATA TYPES
// *************************************************************************
/**
* A simple two-dimensional point.
*/
public static class Point implements Serializable {
public double x, y;
public Point() {}
public Point(double x, double y) {
this.x = x;
this.y = y;
}
public Point add(Point other) {
x += other.x;
y += other.y;
return this;
}
public Point div(long val) {
x /= val;
y /= val;
return this;
}
public double euclideanDistance(Point other) {
return Math.sqrt((x-other.x)*(x-other.x) + (y-other.y)*(y-other.y));
}
public void clear() {
x = y = 0.0;
}
@Override
public String toString() {
return x + " " + y;
}
}
/**
* A simple two-dimensional centroid, basically a point with an ID.
*/
public static class Centroid extends Point {
public int id;
public Centroid() {}
public Centroid(int id, double x, double y) {
super(x,y);
this.id = id;
}
public Centroid(int id, Point p) {
super(p.x, p.y);
this.id = id;
}
@Override
public String toString() {
return id + " " + super.toString();
}
}
// *************************************************************************
// USER FUNCTIONS
// *************************************************************************
/** Determines the closest cluster center for a data point. */
@ForwardedFields("*->1")
public static final class SelectNearestCenter extends RichMapFunction<Point, Tuple2<Integer, Point>> {
private Collection<Centroid> centroids;
/** Reads the centroid values from a broadcast variable into a collection. */
@Override
public void open(Configuration parameters) throws Exception {
this.centroids = getRuntimeContext().getBroadcastVariable("centroids");
}
@Override
public Tuple2<Integer, Point> map(Point p) throws Exception {
double minDistance = Double.MAX_VALUE;
int closestCentroidId = -1;
// check all cluster centers
for (Centroid centroid : centroids) {
// compute distance
double distance = p.euclideanDistance(centroid);
// update nearest cluster if necessary
if (distance < minDistance) {
minDistance = distance;
closestCentroidId = centroid.id;
}
}
// emit a new record with the center id and the data point.
return new Tuple2<>(closestCentroidId, p);
}
}
/** Appends a count variable to the tuple. */
@ForwardedFields("f0;f1")
public static final class CountAppender implements MapFunction<Tuple2<Integer, Point>, Tuple3<Integer, Point, Long>> {
@Override
public Tuple3<Integer, Point, Long> map(Tuple2<Integer, Point> t) {
return new Tuple3<>(t.f0, t.f1, 1L);
}
}
/** Sums and counts point coordinates. */
@ForwardedFields("0")
public static final class CentroidAccumulator implements ReduceFunction<Tuple3<Integer, Point, Long>> {
@Override
public Tuple3<Integer, Point, Long> reduce(Tuple3<Integer, Point, Long> val1, Tuple3<Integer, Point, Long> val2) {
return new Tuple3<>(val1.f0, val1.f1.add(val2.f1), val1.f2 + val2.f2);
}
}
/** Computes new centroid from coordinate sum and count of points. */
@ForwardedFields("0->id")
public static final class CentroidAverager implements MapFunction<Tuple3<Integer, Point, Long>, Centroid> {
@Override
public Centroid map(Tuple3<Integer, Point, Long> value) {
return new Centroid(value.f0, value.f1.div(value.f2));
}
}
}