/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.test.classloading.jar; import org.apache.flink.api.common.accumulators.Accumulator; import org.apache.flink.api.common.accumulators.SimpleAccumulator; import org.apache.flink.api.common.functions.RichMapFunction; import org.apache.flink.api.common.functions.RichReduceFunction; import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.ExecutionEnvironment; import org.apache.flink.api.java.operators.IterativeDataSet; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.configuration.Configuration; import java.util.Collection; /** * This class belongs to the {@link org.apache.flink.test.classloading.ClassLoaderITCase} test. * * It tests dynamic class loading for: * <ul> * <li>Custom Functions</li> * <li>Custom Data Types</li> * <li>Custom Accumulators</li> * <li>Custom Types in collect()</li> * </ul> * * <p> * It's removed by Maven from classpath, so other tests must not depend on it. */ @SuppressWarnings("serial") public class KMeansForTest { // ************************************************************************* // PROGRAM // ************************************************************************* public static void main(String[] args) throws Exception { if (args.length < 3) { throw new IllegalArgumentException("Missing parameters"); } final String pointsData = args[0]; final String centersData = args[1]; final int numIterations = Integer.parseInt(args[2]); ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); env.getConfig().disableSysoutLogging(); // get input data DataSet<Point> points = env.fromElements(pointsData.split("\n")) .map(new TuplePointConverter()); DataSet<Centroid> centroids = env.fromElements(centersData.split("\n")) .map(new TupleCentroidConverter()); // set number of bulk iterations for KMeans algorithm IterativeDataSet<Centroid> loop = centroids.iterate(numIterations); DataSet<Centroid> newCentroids = points // compute closest centroid for each point .map(new SelectNearestCenter()).withBroadcastSet(loop, "centroids") // count and sum point coordinates for each centroid (test pojo return type) .map(new CountAppender()) // !test if key expressions are working! .groupBy("field0").reduce(new CentroidAccumulator()) // compute new centroids from point counts and coordinate sums .map(new CentroidAverager()); // feed new centroids back into next iteration DataSet<Centroid> finalCentroids = loop.closeWith(newCentroids); // test that custom data type collects are working finalCentroids.collect(); } // ************************************************************************* // DATA TYPES // ************************************************************************* /** * A simple two-dimensional point. */ public static class Point { public double x, y; public Point() {} public Point(double x, double y) { this.x = x; this.y = y; } public Point add(Point other) { x += other.x; y += other.y; return this; } public Point div(long val) { x /= val; y /= val; return this; } public double euclideanDistance(Point other) { return Math.sqrt((x-other.x)*(x-other.x) + (y-other.y)*(y-other.y)); } public void clear() { x = y = 0.0; } @Override public String toString() { return x + " " + y; } } /** * A simple two-dimensional centroid, basically a point with an ID. */ public static class Centroid extends Point { public int id; public Centroid() {} public Centroid(int id, double x, double y) { super(x,y); this.id = id; } public Centroid(int id, Point p) { super(p.x, p.y); this.id = id; } @Override public String toString() { return id + " " + super.toString(); } } // ************************************************************************* // USER FUNCTIONS // ************************************************************************* /** Converts a Tuple2<Double,Double> into a Point. */ public static final class TuplePointConverter extends RichMapFunction<String, Point> { @Override public Point map(String str) { String[] fields = str.split("\\|"); return new Point(Double.parseDouble(fields[1]), Double.parseDouble(fields[2])); } } /** Converts a Tuple3<Integer, Double,Double> into a Centroid. */ public static final class TupleCentroidConverter extends RichMapFunction<String, Centroid> { @Override public Centroid map(String str) { String[] fields = str.split("\\|"); return new Centroid(Integer.parseInt(fields[0]), Double.parseDouble(fields[1]), Double.parseDouble(fields[2])); } } /** Determines the closest cluster center for a data point. */ public static final class SelectNearestCenter extends RichMapFunction<Point, Tuple2<Integer, Point>> { private Collection<Centroid> centroids; private CustomAccumulator acc; /** Reads the centroid values from a broadcast variable into a collection. */ @Override public void open(Configuration parameters) throws Exception { this.centroids = getRuntimeContext().getBroadcastVariable("centroids"); this.acc = new CustomAccumulator(); getRuntimeContext().addAccumulator("myAcc", this.acc); } @Override public Tuple2<Integer, Point> map(Point p) throws Exception { double minDistance = Double.MAX_VALUE; int closestCentroidId = -1; // check all cluster centers for (Centroid centroid : centroids) { // compute distance double distance = p.euclideanDistance(centroid); // update nearest cluster if necessary if (distance < minDistance) { minDistance = distance; closestCentroidId = centroid.id; } } // emit a new record with the center id and the data point. acc.add(1L); return new Tuple2<Integer, Point>(closestCentroidId, p); } } // Use this so that we can check whether POJOs and the POJO comparator also work public static final class DummyTuple3IntPointLong { public Integer field0; public Point field1; public Long field2; public DummyTuple3IntPointLong() {} DummyTuple3IntPointLong(Integer f0, Point f1, Long f2) { this.field0 = f0; this.field1 = f1; this.field2 = f2; } } /** Appends a count variable to the tuple. */ public static final class CountAppender extends RichMapFunction<Tuple2<Integer, Point>, DummyTuple3IntPointLong> { @Override public DummyTuple3IntPointLong map(Tuple2<Integer, Point> t) { return new DummyTuple3IntPointLong(t.f0, t.f1, 1L); } } /** Sums and counts point coordinates. */ public static final class CentroidAccumulator extends RichReduceFunction<DummyTuple3IntPointLong> { @Override public DummyTuple3IntPointLong reduce(DummyTuple3IntPointLong val1, DummyTuple3IntPointLong val2) { return new DummyTuple3IntPointLong(val1.field0, val1.field1.add(val2.field1), val1.field2 + val2.field2); } } /** Computes new centroid from coordinate sum and count of points. */ public static final class CentroidAverager extends RichMapFunction<DummyTuple3IntPointLong, Centroid> { @Override public Centroid map(DummyTuple3IntPointLong value) { return new Centroid(value.field0, value.field1.div(value.field2)); } } public static class CustomAccumulator implements SimpleAccumulator<Long> { private long value; @Override public void add(Long value) { this.value += value; } @Override public Long getLocalValue() { return this.value; } @Override public void resetLocal() { this.value = 0L; } @Override public void merge(Accumulator<Long, Long> other) { this.value += other.getLocalValue(); } @Override public Accumulator<Long, Long> clone() { CustomAccumulator acc = new CustomAccumulator(); acc.value = this.value; return acc; } } }