/* Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.mahout.clustering.kmeans; import java.io.IOException; import java.util.HashMap; import java.util.List; import java.util.Map; import com.google.common.collect.Lists; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.io.SequenceFile.Writer; import org.apache.hadoop.mapreduce.Mapper; import org.apache.mahout.clustering.AbstractCluster; import org.apache.mahout.clustering.ClusterObservations; import org.apache.mahout.clustering.WeightedPropertyVectorWritable; import org.apache.mahout.clustering.WeightedVectorWritable; import org.apache.mahout.common.distance.DistanceMeasure; import org.apache.mahout.math.Vector; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * This class implements the k-means clustering algorithm. It uses {@link Cluster} as a cluster * representation. The class can be used as part of a clustering job to be started as map/reduce job. * */ public class KMeansClusterer { private static final Logger log = LoggerFactory.getLogger(KMeansClusterer.class); /** Distance to use for point to cluster comparison. */ private final DistanceMeasure measure; /** * Init the k-means clusterer with the distance measure to use for comparison. * * @param measure * The distance measure to use for comparing clusters against points. * */ public KMeansClusterer(DistanceMeasure measure) { this.measure = measure; } /** * Iterates over all clusters and identifies the one closes to the given point. Distance measure used is * configured at creation time. * * @param point * a point to find a cluster for. * @param clusters * a List<Cluster> to test. */ public void emitPointToNearestCluster(Vector point, Iterable<Cluster> clusters, Mapper<?,?,Text,ClusterObservations>.Context context) throws IOException, InterruptedException { Cluster nearestCluster = null; double nearestDistance = Double.MAX_VALUE; for (Cluster cluster : clusters) { Vector clusterCenter = cluster.getCenter(); double distance = this.measure.distance(clusterCenter.getLengthSquared(), clusterCenter, point); if (log.isDebugEnabled()) { log.debug("{} Cluster: {}", distance, cluster.getId()); } if (distance < nearestDistance || nearestCluster == null) { nearestCluster = cluster; nearestDistance = distance; } } context.write(new Text(nearestCluster.getIdentifier()), new ClusterObservations(1, point, point.times(point))); } /** * Sequential implementation to add point to the nearest cluster * @param point * @param clusters */ protected void addPointToNearestCluster(Vector point, Iterable<Cluster> clusters) { Cluster closestCluster = null; double closestDistance = Double.MAX_VALUE; for (Cluster cluster : clusters) { double distance = measure.distance(cluster.getCenter(), point); if (closestCluster == null || closestDistance > distance) { closestCluster = cluster; closestDistance = distance; } } closestCluster.observe(point, 1); } /** * Sequential implementation to test convergence and update cluster centers */ protected boolean testConvergence(Iterable<Cluster> clusters, double distanceThreshold) { boolean converged = true; for (Cluster cluster : clusters) { if (!computeConvergence(cluster, distanceThreshold)) { converged = false; } cluster.computeParameters(); } return converged; } public void outputPointWithClusterInfo(Vector vector, Iterable<Cluster> clusters, Mapper<?,?,IntWritable,WeightedPropertyVectorWritable>.Context context) throws IOException, InterruptedException { AbstractCluster nearestCluster = null; double nearestDistance = Double.MAX_VALUE; for (AbstractCluster cluster : clusters) { Vector clusterCenter = cluster.getCenter(); double distance = measure.distance(clusterCenter.getLengthSquared(), clusterCenter, vector); if (distance < nearestDistance || nearestCluster == null) { nearestCluster = cluster; nearestDistance = distance; } } Map<Text, Text> props = new HashMap<Text, Text>(); props.put(new Text("distance"), new Text(String.valueOf(nearestDistance))); context.write(new IntWritable(nearestCluster.getId()), new WeightedPropertyVectorWritable(1, vector, props)); } /** * Iterates over all clusters and identifies the one closes to the given point. Distance measure used is * configured at creation time. * * @param point * a point to find a cluster for. * @param clusters * a List<Cluster> to test. */ protected void emitPointToNearestCluster(Vector point, Iterable<Cluster> clusters, Writer writer) throws IOException { AbstractCluster nearestCluster = null; double nearestDistance = Double.MAX_VALUE; for (AbstractCluster cluster : clusters) { Vector clusterCenter = cluster.getCenter(); double distance = this.measure.distance(clusterCenter.getLengthSquared(), clusterCenter, point); if (log.isDebugEnabled()) { log.debug("{} Cluster: {}", distance, cluster.getId()); } if (distance < nearestDistance || nearestCluster == null) { nearestCluster = cluster; nearestDistance = distance; } } writer.append(new IntWritable(nearestCluster.getId()), new WeightedVectorWritable(1, point)); } /** * This is the reference k-means implementation. Given its inputs it iterates over the points and clusters * until their centers converge or until the maximum number of iterations is exceeded. * * @param points * the input List<Vector> of points * @param clusters * the List<Cluster> of initial clusters * @param measure * the DistanceMeasure to use * @param maxIter * the maximum number of iterations */ public static List<List<Cluster>> clusterPoints(Iterable<Vector> points, List<Cluster> clusters, DistanceMeasure measure, int maxIter, double distanceThreshold) { List<List<Cluster>> clustersList = Lists.newArrayList(); clustersList.add(clusters); boolean converged = false; int iteration = 0; while (!converged && iteration < maxIter) { log.info("Reference Iteration: {}", iteration); List<Cluster> next = Lists.newArrayList(); for (Cluster c : clustersList.get(iteration)) { next.add(new Cluster(c.getCenter(), c.getId(), measure)); } clustersList.add(next); converged = runKMeansIteration(points, next, measure, distanceThreshold); iteration++; } return clustersList; } /** * Perform a single iteration over the points and clusters, assigning points to clusters and returning if * the iterations are completed. * * @param points * the List<Vector> having the input points * @param clusters * the List<Cluster> clusters * @param measure * a DistanceMeasure to use */ protected static boolean runKMeansIteration(Iterable<Vector> points, Iterable<Cluster> clusters, DistanceMeasure measure, double distanceThreshold) { // iterate through all points, assigning each to the nearest cluster KMeansClusterer clusterer = new KMeansClusterer(measure); for (Vector point : points) { clusterer.addPointToNearestCluster(point, clusters); } return clusterer.testConvergence(clusters, distanceThreshold); } public boolean computeConvergence(Cluster cluster, double distanceThreshold) { return cluster.computeConvergence(measure, distanceThreshold); } }