/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.clustering.dirichlet;
import java.io.IOException;
import java.util.Collection;
import java.util.List;
import com.google.common.collect.Lists;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile.Writer;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.mahout.clustering.Cluster;
import org.apache.mahout.clustering.Model;
import org.apache.mahout.clustering.ModelDistribution;
import org.apache.mahout.clustering.WeightedVectorWritable;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
/**
* Performs Bayesian mixture modeling.
* <p/>
* The idea is to explain some observed data using a probabilistic mixture of a number of models: each
* observed data point is assumed to have come from one of the models in the mixture, but we don't know
* which. We deal with that by introducing a so-called latent parameter which specifies which model each
* data point came from.
* <p/>
* In addition, since this is a Bayesian clustering algorithm, we don't want to actually commit to any single
* explanation, but rather to sample from the distribution of models and latent assignments of data points to
* models, given the observed data and the prior distributions of the model parameters.
* <p/>
* This sampling process is initialized by taking models at random from the prior distribution for models.
* <p/>
* Then, we iteratively assign points to the different models using the mixture probabilities and the degree
* of fit between the point and each model expressed as a probability that the point was generated by that
* model.
* <p/>
* After points are assigned, new parameters for each model are sampled from the posterior distribution for
* the model parameters considering all of the observed data points that were assigned to the model. Models
* without any data points are also sampled, but since they have no points assigned, the new samples are
* effectively taken from the prior distribution for model parameters.
* <p/>
* The result is a number of samples that represent mixing probabilities, models and assignment of points to
* models. If the total number of possible models is substantially larger than the number that ever have
* points assigned to them, then this approach behaves as a (nearly) non-parametric clustering algorithm.
* <p/>
* These samples can give us interesting information that is lacking from a conventional clustering that
* consists of a single assignment of points to clusters. Firstly, by examining how many models in each
* sample actually have points assigned to them, we can learn how many models (clusters) the data support.
* <p/>
* Moreover, by examining how often two points are assigned to the same model, we can get an approximate
* measure of how likely these points are to be explained by the same model. Such soft membership information
* is difficult to come by with conventional clustering methods.
* <p/>
* Finally, we can get an idea of how stable the description of the data is. Typically, aspects of the data
* with plenty of supporting observations wind up with stable descriptions, while at the edges there are
* phenomena that we can't really commit to a solid description, even though it is clear that the
* well-supported explanations are insufficient to explain these additional aspects.
* <p/>
* One thing that can be difficult about these samples is that we can't always establish a correspondence
* between the models in the different samples. Probably the best way to do this is to look for overlap in
* the assignments of data observations to the different models.
* <p/>
* The generative model being sampled can be summarized as:
* <pre>
* \theta_i ~ prior()
* \lambda ~ Dirichlet(\alpha_0)
* z_j ~ Multinomial( \lambda )
* x_j ~ model(\theta_{z_j})
* </pre>
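* <p/>
* As a rough in-memory usage sketch (the GaussianClusterDistribution prior and the parameter values below
* are illustrative assumptions only, not requirements of this class):
* <pre>{@code
* List<VectorWritable> points = Lists.newArrayList();
* points.add(new VectorWritable(new DenseVector(new double[] {1.0, 1.0})));
* points.add(new VectorWritable(new DenseVector(new double[] {9.0, 9.0})));
* // any ModelDistribution implementation will do; a Gaussian prior over 2-d points is assumed here
* ModelDistribution<VectorWritable> prior =
*     new GaussianClusterDistribution(new VectorWritable(new DenseVector(2)));
* // alpha0 = 1.0, up to 10 clusters, keep every 2nd sample after a 5-iteration burn-in, 30 iterations
* List<Cluster[]> samples = DirichletClusterer.clusterPoints(points, prior, 1.0, 10, 2, 5, 30);
* Cluster[] lastSample = samples.get(samples.size() - 1);
* }</pre>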
*/
public class DirichletClusterer {
// observed data
private final List<VectorWritable> sampleData;
// the ModelDistribution for the computation
private final ModelDistribution<VectorWritable> modelFactory;
// the state of the clustering process
private final DirichletState state;
private final int thin;
private final int burnin;
private final int numClusters;
private final List<Cluster[]> clusterSamples = Lists.newArrayList();
private boolean emitMostLikely;
private double threshold;
/**
* Convenience method that creates a clusterer over the sample data with the given parameters, runs the
* requested number of iterations and returns the periodic cluster samples.
*
* @param points
* the observed data to be clustered
* @param modelFactory
* the ModelDistribution to use
* @param alpha0
* the double alpha_0 parameter of the Dirichlet prior over the mixture (see \alpha_0 above)
* @param numClusters
* the int number of clusters
* @param thin
* the int thinning interval, used to retain only every n-th iteration's cluster sample
* @param burnin
* the int burn-in interval, used to suppress samples from early iterations
* @param numIterations
* the int number of iterations to be performed
* @return the List<Cluster[]> of cluster samples collected during the run
*/
public static List<Cluster[]> clusterPoints(List<VectorWritable> points,
ModelDistribution<VectorWritable> modelFactory,
double alpha0,
int numClusters,
int thin,
int burnin,
int numIterations) {
DirichletClusterer clusterer = new DirichletClusterer(points, modelFactory, alpha0, numClusters, thin, burnin);
return clusterer.cluster(numIterations);
}
/**
* Create a new instance on the sample data with the given additional parameters
*
* @param sampleData
* the observed data to be clustered
* @param modelFactory
* the ModelDistribution to use
* @param alpha0
* the double alpha_0 parameter of the Dirichlet prior over the mixture (see \alpha_0 above)
* @param numClusters
* the int number of clusters
* @param thin
* the int thinning interval, used to retain only every n-th iteration's cluster sample
* @param burnin
* the int burn-in interval, used to suppress samples from early iterations
*/
public DirichletClusterer(List<VectorWritable> sampleData,
ModelDistribution<VectorWritable> modelFactory,
double alpha0,
int numClusters,
int thin,
int burnin) {
this.sampleData = sampleData;
this.modelFactory = modelFactory;
this.thin = thin;
this.burnin = burnin;
this.numClusters = numClusters;
state = new DirichletState(modelFactory, numClusters, alpha0);
}
/**
* This constructor is only used by DirichletClusterMapper for setting up clustering parameters
* @param emitMostLikely
* true to emit each point only to its most likely cluster
* @param threshold
* the double pdf threshold for emitting a point to a cluster
*/
public DirichletClusterer(boolean emitMostLikely, double threshold) {
this.sampleData = null;
this.modelFactory = null;
this.thin = 0;
this.burnin = 0;
this.numClusters = 0;
this.state = null;
this.emitMostLikely = emitMostLikely;
this.threshold = threshold;
}
/**
* This constructor is used by DirichletMapper and DirichletReducer for setting up their clusterer
* @param state
* the DirichletState for this clusterer
*/
public DirichletClusterer(DirichletState state) {
this.state = state;
this.modelFactory = state.getModelFactory();
this.sampleData = null;
this.numClusters = state.getNumClusters();
this.thin = 0;
this.burnin = 0;
}
/**
* Iterate over the sample data, obtaining cluster samples periodically and returning them.
*
* @param numIterations
* the int number of iterations to perform
* @return the List<Cluster[]> of cluster samples collected after the burn-in period
*/
public List<Cluster[]> cluster(int numIterations) {
for (int iteration = 0; iteration < numIterations; iteration++) {
iterate(iteration);
}
return clusterSamples;
}
/**
* Perform one iteration of the clustering process, iterating over the samples to build a new array of
* models, then updating the state for the next iteration
*/
private void iterate(int iteration) {
// create new posterior models
Cluster[] newModels = (Cluster[]) modelFactory.sampleFromPosterior(state.getModels());
// iterate over the samples, assigning each to a model
for (VectorWritable observation : sampleData) {
observe(newModels, observation);
}
// periodically add models to the cluster samples after the burn-in period
if (iteration >= burnin && iteration % thin == 0) {
clusterSamples.add(newModels);
}
// update the state from the new models
state.update(newModels);
}
/**
* Assign the observation to one of the new models and have that model observe it
*
* @param newModels
* the Model<VectorWritable>[] of newly sampled posterior models
* @param observation
* the VectorWritable observation to assign
*/
protected void observe(Model<VectorWritable>[] newModels, VectorWritable observation) {
int k = assignToModel(observation);
// ask the selected model to observe the datum
newModels[k].observe(observation);
}
/**
* Assign the observation to one of the models based upon probabilities
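* <p/>
* Roughly, assuming {@code state.adjustedProbability(x, k)} combines the mixture weight with model k's pdf,
* the sampled assignment follows
* <pre>
* p(z = k | x) \propto \lambda_k * p(x | \theta_k)
* </pre>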
* @param observation
* the VectorWritable observation to assign
* @return the assigned model's index
*/
protected int assignToModel(VectorWritable observation) {
// compute an unnormalized vector of probabilities that x is described by each model
Vector pi = new DenseVector(numClusters);
for (int k1 = 0; k1 < numClusters; k1++) {
pi.set(k1, state.adjustedProbability(observation, k1));
}
// then pick one cluster by sampling a Multinomial distribution based upon them
// see: http://en.wikipedia.org/wiki/Multinomial_distribution
return UncommonDistributions.rMultinom(pi);
}
protected void updateModels(Cluster[] newModels) {
state.update(newModels);
}
protected Model<VectorWritable>[] samplePosteriorModels() {
return state.getModelFactory().sampleFromPosterior(state.getModels());
}
protected DirichletCluster updateCluster(Cluster model, int k) {
model.computeParameters();
DirichletCluster cluster = state.getClusters().get(k);
cluster.setModel(model);
return cluster;
}
/**
* Emit the point to one or more clusters depending upon clusterer state
*
* @param vector a VectorWritable holding the Vector
* @param clusters a List of DirichletClusters
* @param context a Mapper.Context to emit to
*/
public void emitPointToClusters(VectorWritable vector,
List<DirichletCluster> clusters,
Mapper<?,?,IntWritable,WeightedVectorWritable>.Context context)
throws IOException, InterruptedException {
Vector pi = new DenseVector(clusters.size());
for (int i = 0; i < clusters.size(); i++) {
pi.set(i, clusters.get(i).getModel().pdf(vector));
}
pi = pi.divide(pi.zSum());
if (emitMostLikely) {
emitMostLikelyCluster(vector, clusters, pi, context);
} else {
emitAllClusters(vector, clusters, pi, context);
}
}
/**
* Emit the point to the most likely cluster
*
* @param point a VectorWritable holding the Vector
* @param clusters a List of DirichletClusters
* @param pi the normalized pdf Vector for the point
* @param context a Mapper.Context to emit to
*/
private void emitMostLikelyCluster(VectorWritable point,
Collection<DirichletCluster> clusters,
Vector pi,
Mapper<?,?,IntWritable,WeightedVectorWritable>.Context context)
throws IOException, InterruptedException {
int clusterId = -1;
double clusterPdf = 0;
for (int i = 0; i < clusters.size(); i++) {
double pdf = pi.get(i);
if (pdf > clusterPdf) {
clusterId = i;
clusterPdf = pdf;
}
}
//System.out.println(clusterId + ": " + ClusterBase.formatVector(vector.get(), null));
context.write(new IntWritable(clusterId), new WeightedVectorWritable(clusterPdf, point.get()));
}
/**
* Emit the point to each cluster whose normalized pdf exceeds the threshold and which has observed points
* @param point a VectorWritable holding the Vector
* @param clusters a List of DirichletClusters
* @param pi the normalized pdf Vector for the point
* @param context a Mapper.Context to emit to
*/
private void emitAllClusters(VectorWritable point,
List<DirichletCluster> clusters,
Vector pi,
Mapper<?,?,IntWritable,WeightedVectorWritable>.Context context)
throws IOException, InterruptedException {
for (int i = 0; i < clusters.size(); i++) {
double pdf = pi.get(i);
if (pdf > threshold && clusters.get(i).getTotalCount() > 0) {
//System.out.println(i + ": " + ClusterBase.formatVector(vector.get(), null));
context.write(new IntWritable(i), new WeightedVectorWritable(pdf, point.get()));
}
}
}
/**
* Emit the point to one or more clusters depending upon clusterer state
*
* @param vector a VectorWritable holding the Vector
* @param clusters a List of DirichletClusters
* @param writer a SequenceFile.Writer to emit to
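* <p/>
* A minimal calling sketch (the SequenceFile setup and the fs, conf, outputPath, clusters and points
* variables below are assumed to exist in the caller; they are not provided by this class):
* <pre>{@code
* SequenceFile.Writer writer = SequenceFile.createWriter(fs, conf, outputPath,
*     IntWritable.class, WeightedVectorWritable.class);
* DirichletClusterer clusterer = new DirichletClusterer(true, 0.0); // emit only the most likely cluster
* for (VectorWritable point : points) {
*   clusterer.emitPointToClusters(point, clusters, writer);
* }
* writer.close();
* }</pre>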
*/
public void emitPointToClusters(VectorWritable vector, List<DirichletCluster> clusters, Writer writer)
throws IOException {
Vector pi = new DenseVector(clusters.size());
for (int i = 0; i < clusters.size(); i++) {
double pdf = clusters.get(i).getModel().pdf(vector);
pi.set(i, pdf);
}
pi = pi.divide(pi.zSum());
if (emitMostLikely) {
emitMostLikelyCluster(vector, clusters, pi, writer);
} else {
emitAllClusters(vector, clusters, pi, writer);
}
}
/**
* Emit the point to each cluster whose normalized pdf exceeds the threshold and which has observed points
*
* @param vector a VectorWritable holding the Vector
* @param clusters a List of DirichletClusters
* @param pi the normalized pdf Vector for the point
* @param writer a SequenceFile.Writer to emit to
*/
private void emitAllClusters(VectorWritable vector, List<DirichletCluster> clusters, Vector pi, Writer writer)
throws IOException {
for (int i = 0; i < clusters.size(); i++) {
double pdf = pi.get(i);
if (pdf > threshold && clusters.get(i).getTotalCount() > 0) {
//System.out.println(i + ": " + ClusterBase.formatVector(vector.get(), null));
writer.append(new IntWritable(i), new WeightedVectorWritable(pdf, vector.get()));
}
}
}
/**
* Emit the point to the most likely cluster
*
* @param vector a VectorWritable holding the Vector
* @param clusters a List of DirichletClusters
* @param pi the normalized pdf Vector for the point
* @param writer a SequenceFile.Writer to emit to
*/
private static void emitMostLikelyCluster(VectorWritable vector,
Collection<DirichletCluster> clusters,
Vector pi,
Writer writer) throws IOException {
double maxPdf = 0;
int clusterId = -1;
for (int i = 0; i < clusters.size(); i++) {
double pdf = pi.get(i);
if (pdf > maxPdf) {
maxPdf = pdf;
clusterId = i;
}
}
//System.out.println(i + ": " + ClusterBase.formatVector(vector.get(), null));
writer.append(new IntWritable(clusterId), new WeightedVectorWritable(maxPdf, vector.get()));
}
}