/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.clustering.dirichlet;

import java.io.IOException;
import java.util.Collection;
import java.util.List;

import com.google.common.collect.Lists;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile.Writer;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.mahout.clustering.Cluster;
import org.apache.mahout.clustering.Model;
import org.apache.mahout.clustering.ModelDistribution;
import org.apache.mahout.clustering.WeightedVectorWritable;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;

/**
 * Performs Bayesian mixture modeling.
 * <p/>
 * The idea is to use a probabilistic mixture of a number of models to explain some observed data. Each
 * observed data point is assumed to have come from one of the models in the mixture, but we don't know
 * which. We deal with that by introducing a so-called latent parameter which specifies which model each
 * data point came from.
 * <p/>
 * In addition, since this is a Bayesian clustering algorithm, we don't want to commit to any single
 * explanation, but rather to sample from the distribution of models and of latent assignments of data
 * points to models, given the observed data and the prior distributions of model parameters.
 * <p/>
 * This sampling process is initialized by taking models at random from the prior distribution for models.
 * <p/>
 * Then, we iteratively assign points to the different models using the mixture probabilities and the
 * degree of fit between the point and each model, expressed as a probability that the point was generated
 * by that model.
 * <p/>
 * After points are assigned, new parameters for each model are sampled from the posterior distribution for
 * the model parameters, considering all of the observed data points that were assigned to the model. Models
 * without any data points are also sampled, but since they have no points assigned, the new samples are
 * effectively taken from the prior distribution for model parameters.
 * <p/>
 * The result is a number of samples that represent mixing probabilities, models and assignments of points
 * to models. If the total number of possible models is substantially larger than the number that ever have
 * points assigned to them, then this algorithm provides a (nearly) non-parametric clustering algorithm.
 * <p/>
 * These samples can give us interesting information that is lacking from a normal clustering that consists
 * of a single assignment of points to clusters. Firstly, by examining the number of models in each sample
 * that actually have points assigned to them, we can learn how many models (clusters) the data support.
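 * For example (an illustrative figure, not drawn from any particular run), if 40 candidate models are
 * available but only about six ever receive points across the retained samples, the data arguably support
 * roughly six clusters.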
 * <p/>
 * Moreover, by examining how often two points are assigned to the same model, we can get an approximate
 * measure of how likely these points are to be explained by the same model. Such soft membership
 * information is difficult to come by with conventional clustering methods.
 * <p/>
 * Finally, we can get an idea of the stability of how the data can be described. Typically, aspects of the
 * data with lots of support wind up with stable descriptions, while at the edges there are phenomena that
 * we can't really commit to a solid description of, even though it is clear that the well-supported
 * explanations are insufficient to explain these additional aspects.
 * <p/>
 * One thing that can be difficult about these samples is that we can't always establish a correspondence
 * between the models in the different samples. Probably the best way to do this is to look for overlap in
 * the assignments of data observations to the different models.
 * <p/>
 * The generative model being sampled is:
 *
 * <pre>
 * \theta_i ~ prior()
 * \lambda ~ Dirichlet(\alpha_0)
 * z_j ~ Multinomial(\lambda)
 * x_j ~ model(\theta_{z_j})
 * </pre>
 */
public class DirichletClusterer {

  // observed data
  private final List<VectorWritable> sampleData;

  // the ModelDistribution for the computation
  private final ModelDistribution<VectorWritable> modelFactory;

  // the state of the clustering process
  private final DirichletState state;

  private final int thin;

  private final int burnin;

  private final int numClusters;

  private final List<Cluster[]> clusterSamples = Lists.newArrayList();

  private boolean emitMostLikely;

  private double threshold;

  /**
   * Convenience method that clusters the sample data with the given parameters and returns the
   * periodically-collected cluster samples.
   *
   * @param points
   *          the observed data to be clustered
   * @param modelFactory
   *          the ModelDistribution to use
   * @param alpha0
   *          the double value for the beta distributions
   * @param numClusters
   *          the int number of clusters
   * @param thin
   *          the int thinning interval, used to report every n iterations
   * @param burnin
   *          the int burn-in interval, used to suppress early iterations
   * @param numIterations
   *          number of iterations to be performed
   * @return a List of Cluster[] samples collected after burn-in, every thin iterations
   */
  public static List<Cluster[]> clusterPoints(List<VectorWritable> points,
                                              ModelDistribution<VectorWritable> modelFactory,
                                              double alpha0,
                                              int numClusters,
                                              int thin,
                                              int burnin,
                                              int numIterations) {
    DirichletClusterer clusterer = new DirichletClusterer(points, modelFactory, alpha0, numClusters, thin, burnin);
    return clusterer.cluster(numIterations);
  }

  /**
   * Create a new instance on the sample data with the given additional parameters
   *
   * @param sampleData
   *          the observed data to be clustered
   * @param modelFactory
   *          the ModelDistribution to use
   * @param alpha0
   *          the double value for the beta distributions
   * @param numClusters
   *          the int number of clusters
   * @param thin
   *          the int thinning interval, used to report every n iterations
   * @param burnin
   *          the int burn-in interval, used to suppress early iterations
   */
  public DirichletClusterer(List<VectorWritable> sampleData,
                            ModelDistribution<VectorWritable> modelFactory,
                            double alpha0,
                            int numClusters,
                            int thin,
                            int burnin) {
    this.sampleData = sampleData;
    this.modelFactory = modelFactory;
    this.thin = thin;
    this.burnin = burnin;
    this.numClusters = numClusters;
    state = new DirichletState(modelFactory, numClusters, alpha0);
  }

  /**
   * This constructor is only used by DirichletClusterMapper for setting up clustering parameters.
   *
   * @param emitMostLikely
   *          true to emit only the most likely cluster for each point
   * @param threshold
   *          the pdf threshold for emitting a point to a cluster when emitting all clusters
   */
  public DirichletClusterer(boolean emitMostLikely, double threshold) {
    this.sampleData = null;
    this.modelFactory = null;
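    // this instance only scores points against already-computed clusters when emitting,
    // so the sampling-related fields are simply zeroed or nulled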
    this.thin = 0;
    this.burnin = 0;
    this.numClusters = 0;
    this.state = null;
    this.emitMostLikely = emitMostLikely;
    this.threshold = threshold;
  }

  /**
   * This constructor is used by DirichletMapper and DirichletReducer for setting up their clusterer.
   *
   * @param state
   *          the DirichletState to use for the computation
   */
  public DirichletClusterer(DirichletState state) {
    this.state = state;
    this.modelFactory = state.getModelFactory();
    this.sampleData = null;
    this.numClusters = state.getNumClusters();
    this.thin = 0;
    this.burnin = 0;
  }

  /**
   * Iterate over the sample data, obtaining cluster samples periodically and returning them.
   *
   * @param numIterations
   *          the int number of iterations to perform
   * @return a List of Cluster[] samples of the posterior models
   */
  public List<Cluster[]> cluster(int numIterations) {
    for (int iteration = 0; iteration < numIterations; iteration++) {
      iterate(iteration);
    }
    return clusterSamples;
  }

  /**
   * Perform one iteration of the clustering process, iterating over the samples to build a new array of
   * models, then updating the state for the next iteration.
   *
   * @param iteration
   *          the current iteration number, used for burn-in and thinning
   */
  private void iterate(int iteration) {
    // create new posterior models
    Cluster[] newModels = (Cluster[]) modelFactory.sampleFromPosterior(state.getModels());
    // iterate over the samples, assigning each to a model
    for (VectorWritable observation : sampleData) {
      observe(newModels, observation);
    }
    // periodically add models to the cluster samples after the burn-in period
    if (iteration >= burnin && iteration % thin == 0) {
      clusterSamples.add(newModels);
    }
    // update the state from the new models
    state.update(newModels);
  }

  /**
   * Assign the observation to a model sampled from the current mixture and have that model observe it.
   *
   * @param newModels
   *          the array of models being trained in this iteration
   * @param observation
   *          the VectorWritable observation to assign
   */
  protected void observe(Model<VectorWritable>[] newModels, VectorWritable observation) {
    int k = assignToModel(observation);
    // ask the selected model to observe the datum
    newModels[k].observe(observation);
  }

  /**
   * Assign the observation to one of the models based upon probabilities.
   *
   * @param observation
   *          the VectorWritable observation to assign
   * @return the assigned model's index
   */
  protected int assignToModel(VectorWritable observation) {
    // compute an unnormalized vector of probabilities that x is described by each model
    Vector pi = new DenseVector(numClusters);
    for (int k = 0; k < numClusters; k++) {
      pi.set(k, state.adjustedProbability(observation, k));
    }
    // then pick one cluster by sampling a Multinomial distribution based upon them
    // see: http://en.wikipedia.org/wiki/Multinomial_distribution
    return UncommonDistributions.rMultinom(pi);
  }

  protected void updateModels(Cluster[] newModels) {
    state.update(newModels);
  }

  protected Model<VectorWritable>[] samplePosteriorModels() {
    return state.getModelFactory().sampleFromPosterior(state.getModels());
  }

  protected DirichletCluster updateCluster(Cluster model, int k) {
    model.computeParameters();
    DirichletCluster cluster = state.getClusters().get(k);
    cluster.setModel(model);
    return cluster;
  }

  /**
   * Emit the point to one or more clusters depending upon clusterer state
   *
   * @param vector a VectorWritable holding the Vector
   * @param clusters a List of DirichletClusters
   * @param context a Mapper.Context to emit to
   */
  public void emitPointToClusters(VectorWritable vector,
                                  List<DirichletCluster> clusters,
                                  Mapper<?,?,IntWritable,WeightedVectorWritable>.Context context)
    throws IOException, InterruptedException {
    // normalize the per-cluster pdfs into a probability vector for this point
    Vector pi = new DenseVector(clusters.size());
    for (int i = 0; i < clusters.size(); i++) {
      pi.set(i, clusters.get(i).getModel().pdf(vector));
    }
    pi = pi.divide(pi.zSum());
    if (emitMostLikely) {
      emitMostLikelyCluster(vector, clusters, pi, context);
    } else {
      emitAllClusters(vector, clusters, pi, context);
    }
  }
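
  /*
   * Note: the emit methods come in two parallel families. The emitPointToClusters(..., Context) family
   * above writes through a Mapper.Context during a MapReduce pass, while the
   * emitPointToClusters(..., Writer) family below appends directly to a SequenceFile.Writer. Each family
   * offers a most-likely-cluster variant and an all-clusters-above-threshold variant.
   */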
  /**
   * Emit the point to the most likely cluster
   *
   * @param point a VectorWritable holding the Vector
   * @param clusters a List of DirichletClusters
   * @param pi the normalized pdf Vector for the point
   * @param context a Mapper.Context to emit to
   */
  private void emitMostLikelyCluster(VectorWritable point,
                                     Collection<DirichletCluster> clusters,
                                     Vector pi,
                                     Mapper<?,?,IntWritable,WeightedVectorWritable>.Context context)
    throws IOException, InterruptedException {
    int clusterId = -1;
    double clusterPdf = 0;
    for (int i = 0; i < clusters.size(); i++) {
      double pdf = pi.get(i);
      if (pdf > clusterPdf) {
        clusterId = i;
        clusterPdf = pdf;
      }
    }
    context.write(new IntWritable(clusterId), new WeightedVectorWritable(clusterPdf, point.get()));
  }

  /**
   * Emit the point to all clusters whose pdf exceeds the threshold
   *
   * @param point a VectorWritable holding the Vector
   * @param clusters a List of DirichletClusters
   * @param pi the normalized pdf Vector for the point
   * @param context a Mapper.Context to emit to
   */
  private void emitAllClusters(VectorWritable point,
                               List<DirichletCluster> clusters,
                               Vector pi,
                               Mapper<?,?,IntWritable,WeightedVectorWritable>.Context context)
    throws IOException, InterruptedException {
    for (int i = 0; i < clusters.size(); i++) {
      double pdf = pi.get(i);
      // only emit to clusters that have actually observed points
      if (pdf > threshold && clusters.get(i).getTotalCount() > 0) {
        context.write(new IntWritable(i), new WeightedVectorWritable(pdf, point.get()));
      }
    }
  }

  /**
   * Emit the point to one or more clusters depending upon clusterer state
   *
   * @param vector a VectorWritable holding the Vector
   * @param clusters a List of DirichletClusters
   * @param writer a SequenceFile.Writer to emit to
   */
  public void emitPointToClusters(VectorWritable vector, List<DirichletCluster> clusters, Writer writer)
    throws IOException {
    Vector pi = new DenseVector(clusters.size());
    for (int i = 0; i < clusters.size(); i++) {
      double pdf = clusters.get(i).getModel().pdf(vector);
      pi.set(i, pdf);
    }
    pi = pi.divide(pi.zSum());
    if (emitMostLikely) {
      emitMostLikelyCluster(vector, clusters, pi, writer);
    } else {
      emitAllClusters(vector, clusters, pi, writer);
    }
  }

  /**
   * Emit the point to all clusters whose pdf exceeds the threshold
   *
   * @param vector a VectorWritable holding the Vector
   * @param clusters a List of DirichletClusters
   * @param pi the normalized pdf Vector for the point
   * @param writer a SequenceFile.Writer to emit to
   */
  private void emitAllClusters(VectorWritable vector, List<DirichletCluster> clusters, Vector pi, Writer writer)
    throws IOException {
    for (int i = 0; i < clusters.size(); i++) {
      double pdf = pi.get(i);
      if (pdf > threshold && clusters.get(i).getTotalCount() > 0) {
        writer.append(new IntWritable(i), new WeightedVectorWritable(pdf, vector.get()));
      }
    }
  }

  /**
   * Emit the point to the most likely cluster
   *
   * @param vector a VectorWritable holding the Vector
   * @param clusters a List of DirichletClusters
   * @param pi the normalized pdf Vector for the point
   * @param writer a SequenceFile.Writer to emit to
   */
  private static void emitMostLikelyCluster(VectorWritable vector,
                                            Collection<DirichletCluster> clusters,
                                            Vector pi,
                                            Writer writer) throws IOException {
    double maxPdf = 0;
    int clusterId = -1;
    for (int i = 0; i < clusters.size(); i++) {
      double pdf = pi.get(i);
      if (pdf > maxPdf) {
        maxPdf = pdf;
        clusterId = i;
      }
    }
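    // ties are broken in favor of the lowest-indexed cluster; clusterId remains -1 only if every
    // normalized pdf is zero (or NaN), which should not occur for well-formed models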
    writer.append(new IntWritable(clusterId), new WeightedVectorWritable(maxPdf, vector.get()));
  }

}