/*********************************************************************************************************************** * * Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu) * * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. * **********************************************************************************************************************/ package eu.stratosphere.pact.example.datamining; import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; import java.text.DecimalFormat; import java.text.DecimalFormatSymbols; import java.util.ArrayList; import java.util.Iterator; import java.util.List; import eu.stratosphere.pact.common.contract.CrossContract; import eu.stratosphere.pact.common.contract.FileDataSink; import eu.stratosphere.pact.common.contract.FileDataSource; import eu.stratosphere.pact.common.contract.ReduceContract; import eu.stratosphere.pact.common.contract.ReduceContract.Combinable; import eu.stratosphere.pact.common.io.DelimitedInputFormat; import eu.stratosphere.pact.common.io.DelimitedOutputFormat; import eu.stratosphere.pact.common.plan.Plan; import eu.stratosphere.pact.common.plan.PlanAssembler; import eu.stratosphere.pact.common.plan.PlanAssemblerDescription; import eu.stratosphere.pact.common.stubs.Collector; import eu.stratosphere.pact.common.stubs.CrossStub; import eu.stratosphere.pact.common.stubs.ReduceStub; import eu.stratosphere.pact.common.stubs.StubAnnotation.ConstantFields; import eu.stratosphere.pact.common.stubs.StubAnnotation.ConstantFieldsFirstExcept; import eu.stratosphere.pact.common.stubs.StubAnnotation.OutCardBounds; import eu.stratosphere.pact.common.type.Key; import eu.stratosphere.pact.common.type.PactRecord; import eu.stratosphere.pact.common.type.base.PactDouble; import eu.stratosphere.pact.common.type.base.PactInteger; import eu.stratosphere.pact.common.util.FieldSet; /** * The K-Means cluster algorithm is well-known (see * http://en.wikipedia.org/wiki/K-means_clustering). KMeansIteration is a PACT * program that computes a single iteration of the k-means algorithm. The job * has two inputs, a set of data points and a set of cluster centers. A Cross * PACT is used to compute all distances from all centers to all points. A * following Reduce PACT assigns each data point to the cluster center that is * next to it. Finally, a second Reduce PACT compute the new locations of all * cluster centers. * * @author Fabian Hueske */ public class KMeansIteration implements PlanAssembler, PlanAssemblerDescription { /** * Implements a feature vector as a multi-dimensional point. Coordinates of that point * (= the features) are stored as double values. The distance between two feature vectors is * the Euclidian distance between the points. * * @author Fabian Hueske */ public static final class CoordVector implements Key { // coordinate array private double[] coordinates; /** * Initializes a blank coordinate vector. Required for deserialization! */ public CoordVector() { coordinates = null; } /** * Initializes a coordinate vector. * * @param coordinates The coordinate vector of a multi-dimensional point. */ public CoordVector(Double[] coordinates) { this.coordinates = new double[coordinates.length]; for (int i = 0; i < coordinates.length; i++) { this.coordinates[i] = coordinates[i]; } } /** * Initializes a coordinate vector. * * @param coordinates The coordinate vector of a multi-dimensional point. */ public CoordVector(double[] coordinates) { this.coordinates = coordinates; } /** * Returns the coordinate vector of a multi-dimensional point. * * @return The coordinate vector of a multi-dimensional point. */ public double[] getCoordinates() { return this.coordinates; } /** * Sets the coordinate vector of a multi-dimensional point. * * @param point The dimension values of the point. */ public void setCoordinates(double[] coordinates) { this.coordinates = coordinates; } /** * Computes the Euclidian distance between this coordinate vector and a * second coordinate vector. * * @param cv The coordinate vector to which the distance is computed. * @return The Euclidian distance to coordinate vector cv. If cv has a * different length than this coordinate vector, -1 is returned. */ public double computeEuclidianDistance(CoordVector cv) { // check coordinate vector lengths if (cv.coordinates.length != this.coordinates.length) { return -1.0; } double quadSum = 0.0; for (int i = 0; i < this.coordinates.length; i++) { double diff = this.coordinates[i] - cv.coordinates[i]; quadSum += diff*diff; } return Math.sqrt(quadSum); } /** * {@inheritDoc} */ @Override public void read(DataInput in) throws IOException { int length = in.readInt(); this.coordinates = new double[length]; for (int i = 0; i < length; i++) { this.coordinates[i] = in.readDouble(); } } /** * {@inheritDoc} */ @Override public void write(DataOutput out) throws IOException { out.writeInt(this.coordinates.length); for (int i = 0; i < this.coordinates.length; i++) { out.writeDouble(this.coordinates[i]); } } /** * Compares this coordinate vector to another key. * * @return -1 if the other key is not of type CoordVector. If the other * key is also a CoordVector but its length differs from this * coordinates vector, -1 is return if this coordinate vector is * smaller and 1 if it is larger. If both coordinate vectors * have the same length, the coordinates of both are compared. * If a coordinate of this coordinate vector is smaller than the * corresponding coordinate of the other vector -1 is returned * and 1 otherwise. If all coordinates are identical 0 is * returned. */ @Override public int compareTo(Key o) { // check if other key is also of type CoordVector if (!(o instanceof CoordVector)) { return -1; } // cast to CoordVector CoordVector oP = (CoordVector) o; // check if both coordinate vectors have identical lengths if (oP.coordinates.length > this.coordinates.length) { return -1; } else if (oP.coordinates.length < this.coordinates.length) { return 1; } // compare all coordinates for (int i = 0; i < this.coordinates.length; i++) { if (oP.coordinates[i] > this.coordinates[i]) { return -1; } else if (oP.coordinates[i] < this.coordinates[i]) { return 1; } } return 0; } } /** * Generates records with an id and a and CoordVector. * The input format is line-based, i.e. one record is read from one line * which is terminated by '\n'. Within a line the first '|' character separates * the id from the the CoordVector. The vector consists of a vector of decimals. * The decimals are separated by '|' as well. The id is the id of a data point or * cluster center and the CoordVector the corresponding position (coordinate * vector) of the data point or cluster center. Example line: * "42|23.23|52.57|74.43| Id: 42 Coordinate vector: (23.23, 52.57, 74.43) * * @author Fabian Hueske */ public static class PointInFormat extends DelimitedInputFormat { private final PactInteger idInteger = new PactInteger(); private final CoordVector point = new CoordVector(); private final List<Double> dimensionValues = new ArrayList<Double>(); private double[] pointValues = new double[0]; @Override public boolean readRecord(PactRecord record, byte[] line, int offset, int numBytes) { final int limit = offset + numBytes; int id = -1; int value = 0; int fractionValue = 0; int fractionChars = 0; this.dimensionValues.clear(); for (int pos = offset; pos < limit; pos++) { if (line[pos] == '|') { // check if id was already set if (id == -1) { id = value; } else { this.dimensionValues.add(value + ((double) fractionValue) * Math.pow(10, (-1 * (fractionChars - 1)))); } // reset value value = 0; fractionValue = 0; fractionChars = 0; } else if (line[pos] == '.') { fractionChars = 1; } else { if (fractionChars == 0) { value *= 10; value += line[pos] - '0'; } else { fractionValue *= 10; fractionValue += line[pos] - '0'; fractionChars++; } } } // set the ID this.idInteger.setValue(id); record.setField(0, this.idInteger); // set the data points if (this.pointValues.length != this.dimensionValues.size()) { this.pointValues = new double[this.dimensionValues.size()]; } for (int i = 0; i < this.pointValues.length; i++) { this.pointValues[i] = this.dimensionValues.get(i); } this.point.setCoordinates(this.pointValues); record.setField(1, this.point); return true; } } /** * Writes records that contain an id and a CoordVector. * The output format is line-based, i.e. one record is written to * a line and terminated with '\n'. Within a line the first '|' character * separates the id from the CoordVector. The vector consists of a vector of * decimals. The decimals are separated by '|'. The is is the id of a data * point or cluster center and the vector the corresponding position * (coordinate vector) of the data point or cluster center. Example line: * "42|23.23|52.57|74.43| Id: 42 Coordinate vector: (23.23, 52.57, 74.43) * * @author Fabian Hueske */ public static class PointOutFormat extends DelimitedOutputFormat { private final DecimalFormat df = new DecimalFormat("####0.00"); private final StringBuilder line = new StringBuilder(); public PointOutFormat() { DecimalFormatSymbols dfSymbols = new DecimalFormatSymbols(); dfSymbols.setDecimalSeparator('.'); this.df.setDecimalFormatSymbols(dfSymbols); } @Override public int serializeRecord(PactRecord record, byte[] target) { line.setLength(0); PactInteger centerId = record.getField(0, PactInteger.class); CoordVector centerPos = record.getField(1, CoordVector.class); line.append(centerId.getValue()); for (double coord : centerPos.getCoordinates()) { line.append('|'); line.append(df.format(coord)); } line.append('|'); byte[] byteString = line.toString().getBytes(); if (byteString.length <= target.length) { System.arraycopy(byteString, 0, target, 0, byteString.length); return byteString.length; } else { return -byteString.length; } } } /** * Cross PACT computes the distance of all data points to all cluster * centers. * <p> * * @author Fabian Hueske */ @ConstantFieldsFirstExcept(fields={2,3}) @OutCardBounds(lowerBound=1, upperBound=1) public static class ComputeDistance extends CrossStub { private final PactDouble distance = new PactDouble(); /** * Computes the distance of one data point to one cluster center. * * Output Format: * 0: pointID * 1: pointVector * 2: clusterID * 3: distance */ @Override public void cross(PactRecord dataPointRecord, PactRecord clusterCenterRecord, Collector<PactRecord> out) { CoordVector dataPoint = dataPointRecord.getField(1, CoordVector.class); PactInteger clusterCenterId = clusterCenterRecord.getField(0, PactInteger.class); CoordVector clusterPoint = clusterCenterRecord.getField(1, CoordVector.class); this.distance.setValue(dataPoint.computeEuclidianDistance(clusterPoint)); // add cluster center id and distance to the data point record dataPointRecord.setField(2, clusterCenterId); dataPointRecord.setField(3, this.distance); out.collect(dataPointRecord); } } /** * Reduce PACT determines the closes cluster center for a data point. This * is a minimum aggregation. Hence, a Combiner can be easily implemented. * * @author Fabian Hueske */ @ConstantFields(fields={1}) @OutCardBounds(lowerBound=1, upperBound=1) @Combinable public static class FindNearestCenter extends ReduceStub { private final PactInteger centerId = new PactInteger(); private final CoordVector position = new CoordVector(); private final PactInteger one = new PactInteger(1); private final PactRecord result = new PactRecord(3); /** * Computes a minimum aggregation on the distance of a data point to * cluster centers. * * Output Format: * 0: centerID * 1: pointVector * 2: constant(1) (to enable combinable average computation in the following reducer) */ @Override public void reduce(Iterator<PactRecord> pointsWithDistance, Collector<PactRecord> out) { double nearestDistance = Double.MAX_VALUE; int nearestClusterId = 0; // check all cluster centers while (pointsWithDistance.hasNext()) { PactRecord res = pointsWithDistance.next(); double distance = res.getField(3, PactDouble.class).getValue(); // compare distances if (distance < nearestDistance) { // if distance is smaller than smallest till now, update nearest cluster nearestDistance = distance; nearestClusterId = res.getField(2, PactInteger.class).getValue(); res.getFieldInto(1, this.position); } } // emit a new record with the center id and the data point. add a one to ease the // implementation of the average function with a combiner this.centerId.setValue(nearestClusterId); this.result.setField(0, this.centerId); this.result.setField(1, this.position); this.result.setField(2, this.one); out.collect(this.result); } // ---------------------------------------------------------------------------------------- private final PactRecord nearest = new PactRecord(); /** * Computes a minimum aggregation on the distance of a data point to * cluster centers. */ @Override public void combine(Iterator<PactRecord> pointsWithDistance, Collector<PactRecord> out) { double nearestDistance = Double.MAX_VALUE; // check all cluster centers while (pointsWithDistance.hasNext()) { PactRecord res = pointsWithDistance.next(); double distance = res.getField(3, PactDouble.class).getValue(); // compare distances if (distance < nearestDistance) { nearestDistance = distance; res.copyTo(this.nearest); } } // emit nearest one out.collect(this.nearest); } } /** * Reduce PACT computes the new position (coordinate vector) of a cluster * center. This is an average computation. Hence, Combinable is annotated * and the combine method implemented. * * Output Format: * 0: clusterID * 1: clusterVector * * @author Fabian Hueske */ @ConstantFields(fields={0}) @OutCardBounds(lowerBound=1, upperBound=1) @Combinable public static class RecomputeClusterCenter extends ReduceStub { private final PactInteger count = new PactInteger(); /** * Compute the new position (coordinate vector) of a cluster center. */ @Override public void reduce(Iterator<PactRecord> dataPoints, Collector<PactRecord> out) { PactRecord next = null; // initialize coordinate vector sum and count CoordVector coordinates = new CoordVector(); double[] coordinateSum = null; int count = 0; // compute coordinate vector sum and count while (dataPoints.hasNext()) { next = dataPoints.next(); // get the coordinates and the count from the record double[] thisCoords = next.getField(1, CoordVector.class).getCoordinates(); int thisCount = next.getField(2, PactInteger.class).getValue(); if (coordinateSum == null) { if (coordinates.getCoordinates() != null) { coordinateSum = coordinates.getCoordinates(); } else { coordinateSum = new double[thisCoords.length]; } } addToCoordVector(coordinateSum, thisCoords); count += thisCount; } // compute new coordinate vector (position) of cluster center for (int i = 0; i < coordinateSum.length; i++) { coordinateSum[i] /= count; } coordinates.setCoordinates(coordinateSum); next.setField(1, coordinates); next.setNull(2); // emit new position of cluster center out.collect(next); } /** * Computes a pre-aggregated average value of a coordinate vector. */ @Override public void combine(Iterator<PactRecord> dataPoints, Collector<PactRecord> out) { PactRecord next = null; // initialize coordinate vector sum and count CoordVector coordinates = new CoordVector(); double[] coordinateSum = null; int count = 0; // compute coordinate vector sum and count while (dataPoints.hasNext()) { next = dataPoints.next(); // get the coordinates and the count from the record double[] thisCoords = next.getField(1, CoordVector.class).getCoordinates(); int thisCount = next.getField(2, PactInteger.class).getValue(); if (coordinateSum == null) { if (coordinates.getCoordinates() != null) { coordinateSum = coordinates.getCoordinates(); } else { coordinateSum = new double[thisCoords.length]; } } addToCoordVector(coordinateSum, thisCoords); count += thisCount; } coordinates.setCoordinates(coordinateSum); this.count.setValue(count); next.setField(1, coordinates); next.setField(2, this.count); // emit partial sum and partial count for average computation out.collect(next); } /** * Adds two coordinate vectors by summing up each of their coordinates. * * @param cvToAddTo * The coordinate vector to which the other vector is added. * This vector is returned. * @param cvToBeAdded * The coordinate vector which is added to the other vector. * This vector is not modified. */ private void addToCoordVector(double[] cvToAddTo, double[] cvToBeAdded) { // check if both vectors have same length if (cvToAddTo.length != cvToBeAdded.length) { throw new IllegalArgumentException("The given coordinate vectors are not of equal length."); } // sum coordinate vectors coordinate-wise for (int i = 0; i < cvToAddTo.length; i++) { cvToAddTo[i] += cvToBeAdded[i]; } } } /** * {@inheritDoc} */ @Override public Plan getPlan(String... args) { // parse job parameters int noSubTasks = (args.length > 0 ? Integer.parseInt(args[0]) : 1); String dataPointInput = (args.length > 1 ? args[1] : ""); String clusterInput = (args.length > 2 ? args[2] : ""); String output = (args.length > 3 ? args[3] : ""); // create DataSourceContract for data point input FileDataSource dataPoints = new FileDataSource(PointInFormat.class, dataPointInput, "Data Points"); DelimitedInputFormat.configureDelimitedFormat(dataPoints) .recordDelimiter('\n'); dataPoints.getCompilerHints().setUniqueField(new FieldSet(0)); // create DataSourceContract for cluster center input FileDataSource clusterPoints = new FileDataSource(PointInFormat.class, clusterInput, "Centers"); DelimitedInputFormat.configureDelimitedFormat(clusterPoints) .recordDelimiter('\n'); clusterPoints.setDegreeOfParallelism(1); clusterPoints.getCompilerHints().setUniqueField(new FieldSet(0)); // create CrossContract for distance computation CrossContract computeDistance = CrossContract.builder(ComputeDistance.class) .input1(dataPoints) .input2(clusterPoints) .name("Compute Distances") .build(); computeDistance.getCompilerHints().setAvgBytesPerRecord(48); // create ReduceContract for finding the nearest cluster centers ReduceContract findNearestClusterCenters = new ReduceContract.Builder(FindNearestCenter.class, PactInteger.class, 0) .input(computeDistance) .name("Find Nearest Centers") .build(); findNearestClusterCenters.getCompilerHints().setAvgBytesPerRecord(48); // create ReduceContract for computing new cluster positions ReduceContract recomputeClusterCenter = new ReduceContract.Builder(RecomputeClusterCenter.class, PactInteger.class, 0) .input(findNearestClusterCenters) .name("Recompute Center Positions") .build(); recomputeClusterCenter.getCompilerHints().setAvgBytesPerRecord(36); // create DataSinkContract for writing the new cluster positions FileDataSink newClusterPoints = new FileDataSink(PointOutFormat.class, output, recomputeClusterCenter, "New Center Positions"); // return the PACT plan Plan plan = new Plan(newClusterPoints, "KMeans Iteration"); plan.setDefaultParallelism(noSubTasks); return plan; } @Override public String getDescription() { return "Parameters: [noSubStasks] [dataPoints] [clusterCenters] [output]"; } }