/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.mahout.clustering; import java.io.IOException; import java.util.List; import com.google.common.collect.Lists; import com.google.common.io.Closeables; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.SequenceFile; import org.apache.hadoop.io.Text; import org.apache.hadoop.io.Writable; import org.apache.mahout.clustering.canopy.Canopy; import org.apache.mahout.clustering.dirichlet.models.GaussianCluster; import org.apache.mahout.clustering.fuzzykmeans.SoftCluster; import org.apache.mahout.clustering.kmeans.TestKmeansClustering; import org.apache.mahout.clustering.meanshift.MeanShiftCanopy; import org.apache.mahout.common.MahoutTestCase; import org.apache.mahout.common.distance.DistanceMeasure; import org.apache.mahout.common.distance.ManhattanDistanceMeasure; import org.apache.mahout.math.DenseVector; import org.apache.mahout.math.Vector; import org.apache.mahout.math.VectorWritable; import org.junit.Test; public final class TestClusterClassifier extends MahoutTestCase { private static ClusterClassifier newDMClassifier() { List<Cluster> models = Lists.newArrayList(); DistanceMeasure measure = new ManhattanDistanceMeasure(); models.add(new DistanceMeasureCluster(new DenseVector(2).assign(1), 0, measure)); models.add(new DistanceMeasureCluster(new DenseVector(2), 1, measure)); models.add(new DistanceMeasureCluster(new DenseVector(2).assign(-1), 2, measure)); return new ClusterClassifier(models); } private static ClusterClassifier newClusterClassifier() { List<Cluster> models = Lists.newArrayList(); DistanceMeasure measure = new ManhattanDistanceMeasure(); models.add(new org.apache.mahout.clustering.kmeans.Cluster(new DenseVector( 2).assign(1), 0, measure)); models.add(new org.apache.mahout.clustering.kmeans.Cluster(new DenseVector( 2), 1, measure)); models.add(new org.apache.mahout.clustering.kmeans.Cluster(new DenseVector( 2).assign(-1), 2, measure)); return new ClusterClassifier(models); } private static ClusterClassifier newSoftClusterClassifier() { List<Cluster> models = Lists.newArrayList(); DistanceMeasure measure = new ManhattanDistanceMeasure(); models.add(new SoftCluster(new DenseVector(2).assign(1), 0, measure)); models.add(new SoftCluster(new DenseVector(2), 1, measure)); models.add(new SoftCluster(new DenseVector(2).assign(-1), 2, measure)); return new ClusterClassifier(models); } private static ClusterClassifier newGaussianClassifier() { List<Cluster> models = Lists.newArrayList(); models.add(new GaussianCluster(new DenseVector(2).assign(1), new DenseVector(2).assign(1), 0)); models.add(new GaussianCluster(new DenseVector(2), new DenseVector(2) .assign(1), 1)); models.add(new GaussianCluster(new DenseVector(2).assign(-1), new DenseVector(2).assign(1), 2)); return new ClusterClassifier(models); } private ClusterClassifier writeAndRead(ClusterClassifier classifier) throws IOException { Configuration config = new Configuration(); Path path = new Path(getTestTempDirPath(), "output"); FileSystem fs = FileSystem.get(path.toUri(), config); writeClassifier(classifier, config, path, fs); return readClassifier(config, path, fs); } private static void writeClassifier(ClusterClassifier classifier, Configuration config, Path path, FileSystem fs) throws IOException { SequenceFile.Writer writer = new SequenceFile.Writer(fs, config, path, Text.class, ClusterClassifier.class); Writable key = new Text("test"); try { writer.append(key, classifier); } finally { Closeables.closeQuietly(writer); } } private static ClusterClassifier readClassifier(Configuration config, Path path, FileSystem fs) throws IOException { SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, config); Writable key = new Text(); ClusterClassifier classifierOut = new ClusterClassifier(); try { reader.next(key, classifierOut); } finally { Closeables.closeQuietly(reader); } return classifierOut; } @Test public void testDMClusterClassification() { ClusterClassifier classifier = newDMClassifier(); Vector pdf = classifier.classify(new DenseVector(2)); assertEquals("[0,0]", "[0.200, 0.600, 0.200]", AbstractCluster.formatVector(pdf, null)); pdf = classifier.classify(new DenseVector(2).assign(2)); assertEquals("[2,2]", "[0.493, 0.296, 0.211]", AbstractCluster.formatVector(pdf, null)); } @Test public void testCanopyClassification() { List<Cluster> models = Lists.newArrayList(); DistanceMeasure measure = new ManhattanDistanceMeasure(); models.add(new Canopy(new DenseVector(2).assign(1), 0, measure)); models.add(new Canopy(new DenseVector(2), 1, measure)); models.add(new Canopy(new DenseVector(2).assign(-1), 2, measure)); ClusterClassifier classifier = new ClusterClassifier(models); Vector pdf = classifier.classify(new DenseVector(2)); assertEquals("[0,0]", "[0.200, 0.600, 0.200]", AbstractCluster.formatVector(pdf, null)); pdf = classifier.classify(new DenseVector(2).assign(2)); assertEquals("[2,2]", "[0.493, 0.296, 0.211]", AbstractCluster.formatVector(pdf, null)); } @Test public void testClusterClassification() { ClusterClassifier classifier = newClusterClassifier(); Vector pdf = classifier.classify(new DenseVector(2)); assertEquals("[0,0]", "[0.200, 0.600, 0.200]", AbstractCluster.formatVector(pdf, null)); pdf = classifier.classify(new DenseVector(2).assign(2)); assertEquals("[2,2]", "[0.493, 0.296, 0.211]", AbstractCluster.formatVector(pdf, null)); } @Test(expected = UnsupportedOperationException.class) public void testMSCanopyClassification() { List<Cluster> models = Lists.newArrayList(); DistanceMeasure measure = new ManhattanDistanceMeasure(); models.add(new MeanShiftCanopy(new DenseVector(2).assign(1), 0, measure)); models.add(new MeanShiftCanopy(new DenseVector(2), 1, measure)); models.add(new MeanShiftCanopy(new DenseVector(2).assign(-1), 2, measure)); ClusterClassifier classifier = new ClusterClassifier(models); classifier.classify(new DenseVector(2)); } @Test public void testSoftClusterClassification() { ClusterClassifier classifier = newSoftClusterClassifier(); Vector pdf = classifier.classify(new DenseVector(2)); assertEquals("[0,0]", "[0.000, 1.000, 0.000]", AbstractCluster.formatVector(pdf, null)); pdf = classifier.classify(new DenseVector(2).assign(2)); assertEquals("[2,2]", "[0.735, 0.184, 0.082]", AbstractCluster.formatVector(pdf, null)); } @Test public void testGaussianClusterClassification() { ClusterClassifier classifier = newGaussianClassifier(); Vector pdf = classifier.classify(new DenseVector(2)); assertEquals("[0,0]", "[0.212, 0.576, 0.212]", AbstractCluster.formatVector(pdf, null)); pdf = classifier.classify(new DenseVector(2).assign(2)); assertEquals("[2,2]", "[0.952, 0.047, 0.000]", AbstractCluster.formatVector(pdf, null)); } @Test public void testDMClassifierSerialization() throws Exception { ClusterClassifier classifier = newDMClassifier(); ClusterClassifier classifierOut = writeAndRead(classifier); assertEquals(classifier.getModels().size(), classifierOut.getModels() .size()); assertEquals(classifier.getModels().get(0).getClass().getName(), classifierOut.getModels().get(0).getClass().getName()); } @Test public void testClusterClassifierSerialization() throws Exception { ClusterClassifier classifier = newClusterClassifier(); ClusterClassifier classifierOut = writeAndRead(classifier); assertEquals(classifier.getModels().size(), classifierOut.getModels() .size()); assertEquals(classifier.getModels().get(0).getClass().getName(), classifierOut.getModels().get(0).getClass().getName()); } @Test public void testSoftClusterClassifierSerialization() throws Exception { ClusterClassifier classifier = newSoftClusterClassifier(); ClusterClassifier classifierOut = writeAndRead(classifier); assertEquals(classifier.getModels().size(), classifierOut.getModels() .size()); assertEquals(classifier.getModels().get(0).getClass().getName(), classifierOut.getModels().get(0).getClass().getName()); } @Test public void testGaussianClassifierSerialization() throws Exception { ClusterClassifier classifier = newGaussianClassifier(); ClusterClassifier classifierOut = writeAndRead(classifier); assertEquals(classifier.getModels().size(), classifierOut.getModels() .size()); assertEquals(classifier.getModels().get(0).getClass().getName(), classifierOut.getModels().get(0).getClass().getName()); } @Test public void testClusterIteratorKMeans() { List<Vector> data = TestKmeansClustering .getPoints(TestKmeansClustering.REFERENCE); ClusteringPolicy policy = new KMeansClusteringPolicy(); ClusterClassifier prior = newClusterClassifier(); ClusterIterator iterator = new ClusterIterator(policy); ClusterClassifier posterior = iterator.iterate(data, prior, 5); assertEquals(3, posterior.getModels().size()); for (Cluster cluster : posterior.getModels()) { System.out.println(cluster.asFormatString(null)); } } @Test public void testClusterIteratorDirichlet() { List<Vector> data = TestKmeansClustering .getPoints(TestKmeansClustering.REFERENCE); ClusteringPolicy policy = new DirichletClusteringPolicy(3, 1); ClusterClassifier prior = newClusterClassifier(); ClusterIterator iterator = new ClusterIterator(policy); ClusterClassifier posterior = iterator.iterate(data, prior, 5); assertEquals(3, posterior.getModels().size()); for (Cluster cluster : posterior.getModels()) { System.out.println(cluster.asFormatString(null)); } } @Test public void testSeqFileClusterIteratorKMeans() throws IOException { Path pointsPath = getTestTempDirPath("points"); Path priorPath = getTestTempDirPath("prior"); Path outPath = getTestTempDirPath("output"); Configuration conf = new Configuration(); FileSystem fs = FileSystem.get(conf); List<VectorWritable> points = TestKmeansClustering .getPointsWritable(TestKmeansClustering.REFERENCE); ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath, "file1"), fs, conf); Path path = new Path(priorPath, "priorClassifier"); ClusterClassifier prior = newClusterClassifier(); writeClassifier(prior, conf, path, fs); assertEquals(3, prior.getModels().size()); System.out.println("Prior"); for (Cluster cluster : prior.getModels()) { System.out.println(cluster.asFormatString(null)); } ClusteringPolicy policy = new KMeansClusteringPolicy(); ClusterIterator iterator = new ClusterIterator(policy); iterator.iterateSeq(pointsPath, path, outPath, 5); for (int i = 1; i <= 5; i++) { System.out.println("Classifier-" + i); ClusterClassifier posterior = readClassifier(conf, new Path(outPath, "classifier-" + i), fs); assertEquals(3, posterior.getModels().size()); for (Cluster cluster : posterior.getModels()) { System.out.println(cluster.asFormatString(null)); } } } }