/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.clustering.fuzzykmeans;

import java.io.IOException;
import java.util.Collection;
import java.util.List;

import com.google.common.collect.Lists;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.SequenceFile.Writer;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.mahout.clustering.ClusterObservations;
import org.apache.mahout.clustering.WeightedVectorWritable;
import org.apache.mahout.clustering.kmeans.Cluster;
import org.apache.mahout.common.ClassUtils;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;

/**
 * Core fuzzy k-means (fuzzy c-means) clustering logic: computes soft membership
 * probabilities of points to clusters and drives the iterative refinement of
 * cluster centers, both in-memory ({@link #clusterPoints}) and from within
 * Hadoop mapper/driver code (the {@code emit*} methods).
 */
public class FuzzyKMeansClusterer {

  /** Substitute for a zero distance so membership weights never divide by zero. */
  private static final double MINIMAL_VALUE = 0.0000000001;

  private DistanceMeasure measure;
  private double convergenceDelta;
  private double m = 2.0; // default fuzziness exponent; must be > 1
  private boolean emitMostLikely = true;
  private double threshold;

  /**
   * Init the fuzzy k-means clusterer with the distance measure to use for comparison.
   *
   * @param measure          the DistanceMeasure used to compare points and centers
   * @param convergenceDelta the center-movement threshold below which a cluster converges
   * @param m                the fuzziness exponent (must be &gt; 1)
   */
  public FuzzyKMeansClusterer(DistanceMeasure measure, double convergenceDelta, double m) {
    this.measure = measure;
    this.convergenceDelta = convergenceDelta;
    this.m = m;
  }

  /**
   * Init the clusterer from a Hadoop job configuration (see {@link FuzzyKMeansConfigKeys}).
   */
  public FuzzyKMeansClusterer(Configuration conf) {
    this.configure(conf);
  }

  public FuzzyKMeansClusterer() {
  }

  /**
   * This is the reference k-means implementation. Given its inputs it iterates over the points and clusters
   * until their centers converge or until the maximum number of iterations is exceeded.
   *
   * @param points
   *          the input List<Vector> of points
   * @param clusters
   *          the initial List<SoftCluster> of clusters
   * @param measure
   *          the DistanceMeasure to use
   * @param threshold
   *          the double convergence threshold
   * @param m
   *          the double "fuzzyness" argument (>1)
   * @param numIter
   *          the maximum number of iterations
   * @return a List<List<SoftCluster>> of clusters produced per iteration
   */
  public static List<List<SoftCluster>> clusterPoints(Iterable<Vector> points,
                                                      List<SoftCluster> clusters,
                                                      DistanceMeasure measure,
                                                      double threshold,
                                                      double m,
                                                      int numIter) {
    List<List<SoftCluster>> clustersList = Lists.newArrayList();
    clustersList.add(clusters);
    FuzzyKMeansClusterer clusterer = new FuzzyKMeansClusterer(measure, threshold, m);

    boolean converged = false;
    for (int iter = 0; !converged && iter < numIter; iter++) {
      // Snapshot the most recent generation of clusters and copy their centers
      // into a fresh generation, so each iteration's state is preserved.
      List<SoftCluster> previous = clustersList.get(clustersList.size() - 1);
      List<SoftCluster> next = Lists.newArrayList();
      for (SoftCluster c : previous) {
        next.add(new SoftCluster(c.getCenter(), c.getId(), measure));
      }
      clustersList.add(next);
      converged = runFuzzyKMeansIteration(points, next, clusterer);
    }
    return clustersList;
  }

  /**
   * Perform a single iteration over the points and clusters, assigning points to clusters and returning if
   * the iterations are completed.
   *
   * @param points
   *          the List<Vector> having the input points
   * @param clusterList
   *          the List<Cluster> clusters
   * @return true if every cluster converged in this iteration
   */
  protected static boolean runFuzzyKMeansIteration(Iterable<Vector> points,
                                                   List<SoftCluster> clusterList,
                                                   FuzzyKMeansClusterer clusterer) {
    for (Vector point : points) {
      clusterer.addPointToClusters(clusterList, point);
    }
    return clusterer.testConvergence(clusterList);
  }

  /**
   * Configure the distance measure from the job.
   */
  private void configure(Configuration job) {
    measure = ClassUtils.instantiateAs(job.get(FuzzyKMeansConfigKeys.DISTANCE_MEASURE_KEY), DistanceMeasure.class);
    measure.configure(job);
    convergenceDelta = Double.parseDouble(job.get(FuzzyKMeansConfigKeys.CLUSTER_CONVERGENCE_KEY));
    m = Double.parseDouble(job.get(FuzzyKMeansConfigKeys.M_KEY));
    emitMostLikely = Boolean.parseBoolean(job.get(FuzzyKMeansConfigKeys.EMIT_MOST_LIKELY_KEY));
    threshold = Double.parseDouble(job.get(FuzzyKMeansConfigKeys.THRESHOLD_KEY));
  }

  /**
   * Emit the point and its probability of belongingness to each cluster
   *
   * @param point
   *          a point
   * @param clusters
   *          a List<SoftCluster>
   * @param context
   *          the Context to emit into
   */
  public void emitPointProbToCluster(Vector point,
                                     List<SoftCluster> clusters,
                                     Mapper<?,?,Text,ClusterObservations>.Context context)
    throws IOException, InterruptedException {
    List<Double> clusterDistanceList = Lists.newArrayList();
    for (SoftCluster cluster : clusters) {
      clusterDistanceList.add(measure.distance(cluster.getCenter(), point));
    }
    for (int i = 0; i < clusters.size(); i++) {
      SoftCluster cluster = clusters.get(i);
      Text key = new Text(cluster.getIdentifier());
      ClusterObservations value =
          new ClusterObservations(computeProbWeight(clusterDistanceList.get(i), clusterDistanceList),
                                  point,
                                  point.times(point));
      context.write(key, value);
    }
  }

  /**
   * Computes the probability of a point belonging to a cluster.
   * Zero distances are clamped to {@link #MINIMAL_VALUE} to avoid division by zero.
   *
   * @param clusterDistance     distance from the point to this cluster's center
   * @param clusterDistanceList distances from the point to every cluster's center
   * @return the fuzzy membership weight in (0, 1]
   */
  public double computeProbWeight(double clusterDistance, Iterable<Double> clusterDistanceList) {
    if (clusterDistance == 0) {
      clusterDistance = MINIMAL_VALUE;
    }
    double denom = 0.0;
    for (double eachCDist : clusterDistanceList) {
      if (eachCDist == 0.0) {
        eachCDist = MINIMAL_VALUE;
      }
      // Standard fuzzy c-means membership: sum over clusters of (d_i / d_j)^(2/(m-1))
      denom += Math.pow(clusterDistance / eachCDist, 2.0 / (m - 1));
    }
    return 1.0 / denom;
  }

  /**
   * Return if the cluster is converged by comparing its center and centroid.
   *
   * @return if the cluster is converged
   */
  public boolean computeConvergence(Cluster cluster) {
    return cluster.computeConvergence(measure, convergenceDelta);
  }

  public double getM() {
    return m;
  }

  public DistanceMeasure getMeasure() {
    return this.measure;
  }

  /**
   * Emit the point to either the most likely cluster or all clusters above the
   * threshold, according to the configured {@code emitMostLikely} flag.
   */
  public void emitPointToClusters(VectorWritable point,
                                  List<SoftCluster> clusters,
                                  Mapper<?,?,IntWritable,WeightedVectorWritable>.Context context)
    throws IOException, InterruptedException {
    // calculate point distances for all clusters
    List<Double> clusterDistanceList = Lists.newArrayList();
    for (SoftCluster cluster : clusters) {
      clusterDistanceList.add(getMeasure().distance(cluster.getCenter(), point.get()));
    }
    // calculate point pdf for all clusters
    Vector pi = computePi(clusters, clusterDistanceList);
    if (emitMostLikely) {
      emitMostLikelyCluster(point.get(), clusters, pi, context);
    } else {
      emitAllClusters(point.get(), clusters, pi, context);
    }
  }

  /**
   * Compute the membership-probability vector (pdf) of a point over all clusters.
   *
   * @param clusters            the clusters, in the same order as the distances
   * @param clusterDistanceList point-to-center distance per cluster
   * @return a dense vector of membership weights, one entry per cluster
   */
  public Vector computePi(Collection<SoftCluster> clusters, List<Double> clusterDistanceList) {
    Vector pi = new DenseVector(clusters.size());
    for (int i = 0; i < clusters.size(); i++) {
      double probWeight = computeProbWeight(clusterDistanceList.get(i), clusterDistanceList);
      pi.set(i, probWeight);
    }
    return pi;
  }

  /**
   * Emit the point to the cluster with the highest pdf
   */
  private void emitMostLikelyCluster(Vector point,
                                     List<SoftCluster> clusters,
                                     Vector pi,
                                     Mapper<?,?,IntWritable,WeightedVectorWritable>.Context context)
    throws IOException, InterruptedException {
    int clusterId = -1;
    double clusterPdf = 0.0;
    for (int i = 0; i < clusters.size(); i++) {
      double pdf = pi.get(i);
      if (pdf > clusterPdf) {
        clusterId = clusters.get(i).getId();
        clusterPdf = pdf;
      }
    }
    context.write(new IntWritable(clusterId), new WeightedVectorWritable(clusterPdf, point));
  }

  /**
   * Emit the point to all clusters
   */
  private void emitAllClusters(Vector point,
                               Collection<SoftCluster> clusters,
                               Vector pi,
                               Mapper<?,?,IntWritable,WeightedVectorWritable>.Context context)
    throws IOException, InterruptedException {
    for (int i = 0; i < clusters.size(); i++) {
      double pdf = pi.get(i);
      if (pdf > threshold) {
        // NOTE(review): this writes the list index i as the key, while
        // emitMostLikelyCluster writes clusters.get(i).getId() — confirm whether
        // downstream consumers rely on the index before unifying the two.
        context.write(new IntWritable(i), new WeightedVectorWritable(pdf, point));
      }
    }
  }

  /**
   * Observe the point in every cluster, weighted by its fuzzy membership raised to m.
   */
  protected void addPointToClusters(List<SoftCluster> clusterList, Vector point) {
    List<Double> clusterDistanceList = Lists.newArrayList();
    for (SoftCluster cluster : clusterList) {
      clusterDistanceList.add(getMeasure().distance(point, cluster.getCenter()));
    }
    for (int i = 0; i < clusterList.size(); i++) {
      double probWeight = computeProbWeight(clusterDistanceList.get(i), clusterDistanceList);
      clusterList.get(i).observe(point, Math.pow(probWeight, getM()));
    }
  }

  /**
   * Test every cluster for convergence and recompute its parameters for the next pass.
   *
   * @return true only if all clusters converged
   */
  protected boolean testConvergence(Iterable<SoftCluster> clusters) {
    boolean converged = true;
    for (SoftCluster cluster : clusters) {
      if (!cluster.computeConvergence(measure, convergenceDelta)) {
        converged = false;
      }
      // computeParameters must run for every cluster regardless of convergence
      cluster.computeParameters();
    }
    return converged;
  }

  /**
   * SequenceFile variant of {@link #emitPointToClusters(VectorWritable, List, Mapper.Context)}:
   * writes to a raw Writer instead of a mapper context.
   */
  public void emitPointToClusters(VectorWritable point, List<SoftCluster> clusters, Writer writer)
    throws IOException {
    // calculate point distances for all clusters
    List<Double> clusterDistanceList = Lists.newArrayList();
    for (SoftCluster cluster : clusters) {
      clusterDistanceList.add(getMeasure().distance(cluster.getCenter(), point.get()));
    }
    Vector pi = computePi(clusters, clusterDistanceList);
    if (emitMostLikely) {
      emitMostLikelyCluster(point.get(), clusters, pi, writer);
    } else {
      emitAllClusters(point.get(), clusters, pi, writer);
    }
  }

  /**
   * Append the point to the writer once per cluster whose pdf exceeds the threshold.
   */
  private void emitAllClusters(Vector point, Collection<SoftCluster> clusters, Vector pi, Writer writer)
    throws IOException {
    for (int i = 0; i < clusters.size(); i++) {
      double pdf = pi.get(i);
      if (pdf > threshold) {
        writer.append(new IntWritable(i), new WeightedVectorWritable(pdf, point));
      }
    }
  }

  /**
   * Append the point to the writer keyed by the cluster with the highest pdf.
   */
  private static void emitMostLikelyCluster(Vector point, List<SoftCluster> clusters, Vector pi, Writer writer)
    throws IOException {
    int clusterId = -1;
    double clusterPdf = 0.0;
    for (int i = 0; i < clusters.size(); i++) {
      double pdf = pi.get(i);
      if (pdf > clusterPdf) {
        clusterId = clusters.get(i).getId();
        clusterPdf = pdf;
      }
    }
    writer.append(new IntWritable(clusterId), new WeightedVectorWritable(clusterPdf, point));
  }
}