/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.clustering.iterator;
import java.io.IOException;
import java.util.List;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.clustering.AbstractCluster;
import org.apache.mahout.clustering.Cluster;
import org.apache.mahout.clustering.ClusteringTestUtils;
import org.apache.mahout.clustering.classify.ClusterClassifier;
import org.apache.mahout.clustering.fuzzykmeans.SoftCluster;
import org.apache.mahout.clustering.kmeans.TestKmeansClustering;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.distance.CosineDistanceMeasure;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.junit.Test;
import com.google.common.collect.Lists;
public final class TestClusterClassifier extends MahoutTestCase {
private static ClusterClassifier newDMClassifier() {
List<Cluster> models = Lists.newArrayList();
DistanceMeasure measure = new ManhattanDistanceMeasure();
models.add(new DistanceMeasureCluster(new DenseVector(2).assign(1), 0, measure));
models.add(new DistanceMeasureCluster(new DenseVector(2), 1, measure));
models.add(new DistanceMeasureCluster(new DenseVector(2).assign(-1), 2, measure));
return new ClusterClassifier(models, new KMeansClusteringPolicy());
}
private static ClusterClassifier newKlusterClassifier() {
List<Cluster> models = Lists.newArrayList();
DistanceMeasure measure = new ManhattanDistanceMeasure();
models.add(new org.apache.mahout.clustering.kmeans.Kluster(new DenseVector(2).assign(1), 0, measure));
models.add(new org.apache.mahout.clustering.kmeans.Kluster(new DenseVector(2), 1, measure));
models.add(new org.apache.mahout.clustering.kmeans.Kluster(new DenseVector(2).assign(-1), 2, measure));
return new ClusterClassifier(models, new KMeansClusteringPolicy());
}
private static ClusterClassifier newCosineKlusterClassifier() {
List<Cluster> models = Lists.newArrayList();
DistanceMeasure measure = new CosineDistanceMeasure();
models.add(new org.apache.mahout.clustering.kmeans.Kluster(new DenseVector(2).assign(1), 0, measure));
models.add(new org.apache.mahout.clustering.kmeans.Kluster(new DenseVector(2), 1, measure));
models.add(new org.apache.mahout.clustering.kmeans.Kluster(new DenseVector(2).assign(-1), 2, measure));
return new ClusterClassifier(models, new KMeansClusteringPolicy());
}
private static ClusterClassifier newSoftClusterClassifier() {
List<Cluster> models = Lists.newArrayList();
DistanceMeasure measure = new ManhattanDistanceMeasure();
models.add(new SoftCluster(new DenseVector(2).assign(1), 0, measure));
models.add(new SoftCluster(new DenseVector(2), 1, measure));
models.add(new SoftCluster(new DenseVector(2).assign(-1), 2, measure));
return new ClusterClassifier(models, new FuzzyKMeansClusteringPolicy());
}
private ClusterClassifier writeAndRead(ClusterClassifier classifier) throws IOException {
Path path = new Path(getTestTempDirPath(), "output");
classifier.writeToSeqFiles(path);
ClusterClassifier newClassifier = new ClusterClassifier();
newClassifier.readFromSeqFiles(getConfiguration(), path);
return newClassifier;
}
@Test
public void testDMClusterClassification() {
ClusterClassifier classifier = newDMClassifier();
Vector pdf = classifier.classify(new DenseVector(2));
assertEquals("[0,0]", "[0.2,0.6,0.2]", AbstractCluster.formatVector(pdf, null));
pdf = classifier.classify(new DenseVector(2).assign(2));
assertEquals("[2,2]", "[0.493,0.296,0.211]", AbstractCluster.formatVector(pdf, null));
}
@Test
public void testClusterClassification() {
ClusterClassifier classifier = newKlusterClassifier();
Vector pdf = classifier.classify(new DenseVector(2));
assertEquals("[0,0]", "[0.2,0.6,0.2]", AbstractCluster.formatVector(pdf, null));
pdf = classifier.classify(new DenseVector(2).assign(2));
assertEquals("[2,2]", "[0.493,0.296,0.211]", AbstractCluster.formatVector(pdf, null));
}
@Test
public void testSoftClusterClassification() {
ClusterClassifier classifier = newSoftClusterClassifier();
Vector pdf = classifier.classify(new DenseVector(2));
assertEquals("[0,0]", "[0.0,1.0,0.0]", AbstractCluster.formatVector(pdf, null));
pdf = classifier.classify(new DenseVector(2).assign(2));
assertEquals("[2,2]", "[0.735,0.184,0.082]", AbstractCluster.formatVector(pdf, null));
}
@Test
public void testDMClassifierSerialization() throws Exception {
ClusterClassifier classifier = newDMClassifier();
ClusterClassifier classifierOut = writeAndRead(classifier);
assertEquals(classifier.getModels().size(), classifierOut.getModels().size());
assertEquals(classifier.getModels().get(0).getClass().getName(), classifierOut.getModels().get(0).getClass()
.getName());
}
@Test
public void testClusterClassifierSerialization() throws Exception {
ClusterClassifier classifier = newKlusterClassifier();
ClusterClassifier classifierOut = writeAndRead(classifier);
assertEquals(classifier.getModels().size(), classifierOut.getModels().size());
assertEquals(classifier.getModels().get(0).getClass().getName(), classifierOut.getModels().get(0).getClass()
.getName());
}
@Test
public void testSoftClusterClassifierSerialization() throws Exception {
ClusterClassifier classifier = newSoftClusterClassifier();
ClusterClassifier classifierOut = writeAndRead(classifier);
assertEquals(classifier.getModels().size(), classifierOut.getModels().size());
assertEquals(classifier.getModels().get(0).getClass().getName(), classifierOut.getModels().get(0).getClass()
.getName());
}
@Test
public void testClusterIteratorKMeans() {
List<Vector> data = TestKmeansClustering.getPoints(TestKmeansClustering.REFERENCE);
ClusterClassifier prior = newKlusterClassifier();
ClusterClassifier posterior = ClusterIterator.iterate(data, prior, 5);
assertEquals(3, posterior.getModels().size());
for (Cluster cluster : posterior.getModels()) {
System.out.println(cluster.asFormatString(null));
}
}
@Test
public void testClusterIteratorDirichlet() {
List<Vector> data = TestKmeansClustering.getPoints(TestKmeansClustering.REFERENCE);
ClusterClassifier prior = newKlusterClassifier();
ClusterClassifier posterior = ClusterIterator.iterate(data, prior, 5);
assertEquals(3, posterior.getModels().size());
for (Cluster cluster : posterior.getModels()) {
System.out.println(cluster.asFormatString(null));
}
}
@Test
public void testSeqFileClusterIteratorKMeans() throws IOException {
Path pointsPath = getTestTempDirPath("points");
Path priorPath = getTestTempDirPath("prior");
Path outPath = getTestTempDirPath("output");
Configuration conf = getConfiguration();
FileSystem fs = FileSystem.get(pointsPath.toUri(), conf);
List<VectorWritable> points = TestKmeansClustering.getPointsWritable(TestKmeansClustering.REFERENCE);
ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath, "file1"), fs, conf);
Path path = new Path(priorPath, "priorClassifier");
ClusterClassifier prior = newKlusterClassifier();
prior.writeToSeqFiles(path);
assertEquals(3, prior.getModels().size());
System.out.println("Prior");
for (Cluster cluster : prior.getModels()) {
System.out.println(cluster.asFormatString(null));
}
ClusterIterator.iterateSeq(conf, pointsPath, path, outPath, 5);
for (int i = 1; i <= 4; i++) {
System.out.println("Classifier-" + i);
ClusterClassifier posterior = new ClusterClassifier();
String name = i == 4 ? "clusters-4-final" : "clusters-" + i;
posterior.readFromSeqFiles(conf, new Path(outPath, name));
assertEquals(3, posterior.getModels().size());
for (Cluster cluster : posterior.getModels()) {
System.out.println(cluster.asFormatString(null));
}
}
}
@Test
public void testMRFileClusterIteratorKMeans() throws Exception {
Path pointsPath = getTestTempDirPath("points");
Path priorPath = getTestTempDirPath("prior");
Path outPath = getTestTempDirPath("output");
Configuration conf = getConfiguration();
FileSystem fs = FileSystem.get(pointsPath.toUri(), conf);
List<VectorWritable> points = TestKmeansClustering.getPointsWritable(TestKmeansClustering.REFERENCE);
ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath, "file1"), fs, conf);
Path path = new Path(priorPath, "priorClassifier");
ClusterClassifier prior = newKlusterClassifier();
prior.writeToSeqFiles(path);
ClusteringPolicy policy = new KMeansClusteringPolicy();
ClusterClassifier.writePolicy(policy, path);
assertEquals(3, prior.getModels().size());
System.out.println("Prior");
for (Cluster cluster : prior.getModels()) {
System.out.println(cluster.asFormatString(null));
}
ClusterIterator.iterateMR(conf, pointsPath, path, outPath, 5);
for (int i = 1; i <= 4; i++) {
System.out.println("Classifier-" + i);
ClusterClassifier posterior = new ClusterClassifier();
String name = i == 4 ? "clusters-4-final" : "clusters-" + i;
posterior.readFromSeqFiles(conf, new Path(outPath, name));
assertEquals(3, posterior.getModels().size());
for (Cluster cluster : posterior.getModels()) {
System.out.println(cluster.asFormatString(null));
}
}
}
@Test
public void testCosineKlusterClassification() {
ClusterClassifier classifier = newCosineKlusterClassifier();
Vector pdf = classifier.classify(new DenseVector(2));
assertEquals("[0,0]", "[0.333,0.333,0.333]", AbstractCluster.formatVector(pdf, null));
pdf = classifier.classify(new DenseVector(2).assign(2));
assertEquals("[2,2]", "[0.429,0.429,0.143]", AbstractCluster.formatVector(pdf, null));
}
}