/* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.clustering;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Collection;
import java.util.List;
import com.google.common.collect.Lists;
import org.apache.hadoop.io.Writable;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.classifier.OnlineLearner;
import org.apache.mahout.clustering.fuzzykmeans.FuzzyKMeansClusterer;
import org.apache.mahout.clustering.fuzzykmeans.SoftCluster;
import org.apache.mahout.common.ClassUtils;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.function.TimesFunction;
/**
 * This classifier works with any clustering Cluster. It is initialized with a
 * list of compatible clusters and thereafter it can classify any new Vector
 * into one or more of the clusters based upon the pdf() function which each
 * cluster supports.
 * <p>
 * In addition, it is an OnlineLearner and can be trained. Training amounts to
 * asking the actual model to observe the vector and closing the classifier
 * causes all the models to computeParameters.
 * <p>
 * The classifier is also a {@link Writable}: it serializes the number of
 * models, the class name of the first model and then each model in list
 * order. Because only one class name is recorded, every model is
 * re-instantiated as that class on deserialization.
 */
public class ClusterClassifier extends AbstractVectorClassifier implements OnlineLearner, Writable {

  /** The cluster models; the index in this list is the category index. */
  private List<Cluster> models;

  /** Fully qualified class name of the models, used to re-instantiate them in readFields. */
  private String modelClass;

  /**
   * The public constructor accepts a list of clusters to become the models.
   * The list is stored by reference, not copied.
   *
   * @param models
   *          a non-empty {@code List<Cluster>}; the class of its first element
   *          is recorded as the model class used for serialization
   * @throws IllegalArgumentException
   *           if {@code models} is null or empty
   */
  public ClusterClassifier(List<Cluster> models) {
    // Fail early with a clear message instead of an obscure exception from get(0).
    if (models == null || models.isEmpty()) {
      throw new IllegalArgumentException("models must contain at least one Cluster");
    }
    this.models = models;
    this.modelClass = models.get(0).getClass().getName();
  }

  /**
   * No-argument constructor needed for serialization/deserialization. The
   * instance has no models until {@link #readFields(DataInput)} is called.
   */
  public ClusterClassifier() {}

  /**
   * Classifies the instance against every model.
   * <p>
   * If the first model is a {@link SoftCluster}, all models are treated as
   * SoftClusters and the result is delegated to
   * {@link FuzzyKMeansClusterer#computePi} using each cluster's own distance
   * measure. Otherwise each model's pdf() is evaluated and the resulting
   * vector is scaled by the reciprocal of its zSum().
   *
   * @param instance
   *          the Vector to classify
   * @return a Vector with one entry per model, in model order
   */
  @Override
  public Vector classify(Vector instance) {
    if (models.get(0) instanceof SoftCluster) {
      Collection<SoftCluster> clusters = Lists.newArrayList();
      List<Double> distances = Lists.newArrayList();
      for (Cluster model : models) {
        SoftCluster sc = (SoftCluster) model;
        clusters.add(sc);
        distances.add(sc.getMeasure().distance(instance, sc.getCenter()));
      }
      return new FuzzyKMeansClusterer().computePi(clusters, distances);
    }
    // One wrapper is shared by all models; assumes pdf() only reads its argument.
    VectorWritable observation = new VectorWritable(instance);
    Vector pdfs = new DenseVector(models.size());
    int i = 0;
    for (Cluster model : models) {
      pdfs.set(i++, model.pdf(observation));
    }
    // NOTE(review): if zSum() is zero the scale factor is not finite, so the
    // returned entries are presumably not meaningful -- confirm desired behavior.
    return pdfs.assign(new TimesFunction(), 1.0 / pdfs.zSum());
  }

  /**
   * Two-model classification: returns the pdf of model 0 divided by the sum of
   * the pdfs of models 0 and 1.
   *
   * @param instance
   *          the Vector to classify
   * @return pdf0 / (pdf0 + pdf1)
   * @throws IllegalStateException
   *           if this classifier does not have exactly two models
   */
  @Override
  public double classifyScalar(Vector instance) {
    if (models.size() != 2) {
      throw new IllegalStateException(
          "classifyScalar requires exactly 2 models but this classifier has " + models.size());
    }
    // One wrapper is shared by both models; assumes pdf() only reads its argument.
    VectorWritable observation = new VectorWritable(instance);
    double pdf0 = models.get(0).pdf(observation);
    double pdf1 = models.get(1).pdf(observation);
    return pdf0 / (pdf0 + pdf1);
  }

  /**
   * @return the number of models held by this classifier
   */
  @Override
  public int numCategories() {
    return models.size();
  }

  /**
   * Writes the model count, the model class name and then each model in order.
   * The order must stay in sync with {@link #readFields(DataInput)}.
   *
   * @param out
   *          the DataOutput to write to
   * @throws IOException
   *           if writing fails
   */
  @Override
  public void write(DataOutput out) throws IOException {
    out.writeInt(models.size());
    out.writeUTF(modelClass);
    for (Cluster cluster : models) {
      cluster.write(out);
    }
  }

  /**
   * Replaces the current models with those read from the input: reads the
   * model count and model class name, then instantiates and reads that many
   * models of the named class.
   *
   * @param in
   *          the DataInput to read from
   * @throws IOException
   *           if reading fails
   */
  @Override
  public void readFields(DataInput in) throws IOException {
    int size = in.readInt();
    modelClass = in.readUTF();
    models = Lists.newArrayList();
    for (int i = 0; i < size; i++) {
      Cluster element = ClassUtils.instantiateAs(modelClass, Cluster.class);
      element.readFields(in);
      models.add(element);
    }
  }

  /**
   * Trains the selected model by asking it to observe the instance.
   *
   * @param actual
   *          the int index of a model
   * @param instance
   *          a data Vector
   */
  @Override
  public void train(int actual, Vector instance) {
    models.get(actual).observe(new VectorWritable(instance));
  }

  /**
   * Train the models given an additional weight. Unique to ClusterClassifier
   *
   * @param actual
   *          the int index of a model
   * @param data
   *          a data Vector
   * @param weight
   *          a double weighting factor
   */
  public void train(int actual, Vector data, double weight) {
    models.get(actual).observe(new VectorWritable(data), weight);
  }

  /**
   * Trains the selected model; {@code trackingKey} and {@code groupKey} are
   * not used. Delegates to {@link #train(int, Vector)}.
   *
   * @param trackingKey
   *          ignored
   * @param groupKey
   *          ignored
   * @param actual
   *          the int index of a model
   * @param instance
   *          a data Vector
   */
  @Override
  public void train(long trackingKey, String groupKey, int actual, Vector instance) {
    train(actual, instance);
  }

  /**
   * Trains the selected model; {@code trackingKey} is not used. Delegates to
   * {@link #train(int, Vector)}.
   *
   * @param trackingKey
   *          ignored
   * @param actual
   *          the int index of a model
   * @param instance
   *          a data Vector
   */
  @Override
  public void train(long trackingKey, int actual, Vector instance) {
    train(actual, instance);
  }

  /**
   * Closes the classifier by asking every model to computeParameters.
   */
  @Override
  public void close() {
    for (Cluster cluster : models) {
      cluster.computeParameters();
    }
  }

  /**
   * @return the internal list of models (not a copy)
   */
  public List<Cluster> getModels() {
    return models;
  }
}