/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.clustering.dirichlet.models;
import java.util.Iterator;
import org.apache.mahout.clustering.AbstractCluster;
import org.apache.mahout.clustering.Model;
import org.apache.mahout.clustering.dirichlet.UncommonDistributions;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.Vector.Element;
import org.apache.mahout.math.VectorWritable;
/**
 * A cluster whose density is an axis-aligned Gaussian: the cluster center is
 * used as the mean and the cluster radius as the per-dimension standard
 * deviation.
 */
public class GaussianCluster extends AbstractCluster {

  public GaussianCluster() {}

  public GaussianCluster(Vector point, int id2) {
    super(point, id2);
  }

  public GaussianCluster(Vector center, Vector radius, int id) {
    super(center, radius, id);
  }

  /**
   * @return the string "GC:" followed by this cluster's id
   */
  @Override
  public String getIdentifier() {
    return "GC:" + getId();
  }

  /**
   * @return a new GaussianCluster built from this cluster's current center,
   *         radius and id
   */
  @Override
  public Model<VectorWritable> sampleFromPosterior() {
    return new GaussianCluster(getCenter(), getRadius(), getId());
  }

  /**
   * Delegates to the superclass. This class keeps no state derived from the
   * radius, so nothing needs refreshing here; the override is retained so the
   * setter remains declared in this class with the same visibility.
   *
   * @param s2
   *          the new radius Vector
   */
  @Override
  protected void setRadius(Vector s2) {
    super.setRadius(s2);
  }

  /**
   * Evaluate the Gaussian density at the given point:
   * exp(-sum(((x[i]-c[i])/r[i])^2) / 2) / product(r[i]*SQRT2PI), where the sum
   * and the product both run over the indices i at which the radius has a
   * non-zero entry. Dimensions with a zero radius contribute to neither term,
   * so a radius with no non-zero entries yields 1.0.
   *
   * @param vw
   *          a VectorWritable wrapping the point x
   * @return the density of x under this cluster
   */
  @Override
  public double pdf(VectorWritable vw) {
    Vector x = vw.get();
    Vector center = getCenter();
    double sumSquares = 0.0;
    // The normalizing product is derived from the current radius on every
    // call rather than cached in a field, so it cannot be out of date when
    // the radius was assigned or modified without going through setRadius.
    double normalizer = 1.0;
    for (Iterator<Element> it = getRadius().iterateNonZero(); it.hasNext();) {
      Element radiusElem = it.next();
      int index = radiusElem.index();
      double stdDev = radiusElem.get();
      double quotient = (x.get(index) - center.get(index)) / stdDev;
      sumSquares += quotient * quotient;
      normalizer *= stdDev * UncommonDistributions.SQRT2PI;
    }
    return Math.exp(-(sumSquares / 2)) / normalizer;
  }
}