/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.mahout.clustering.dirichlet; import java.util.List; import com.google.common.collect.Lists; import org.apache.mahout.clustering.Cluster; import org.apache.mahout.clustering.dirichlet.models.DistanceMeasureClusterDistribution; import org.apache.mahout.clustering.dirichlet.models.GaussianClusterDistribution; import org.apache.mahout.common.MahoutTestCase; import org.apache.mahout.math.DenseVector; import org.apache.mahout.math.VectorWritable; import org.junit.Before; import org.junit.Test; public final class TestDirichletClustering extends MahoutTestCase { private List<VectorWritable> sampleData; @Override @Before public void setUp() throws Exception { super.setUp(); sampleData = Lists.newArrayList(); } /** * Generate random samples and add them to the sampleData * * @param num int number of samples to generate * @param mx double x-value of the sample mean * @param my double y-value of the sample mean * @param sd double standard deviation of the samples * @param card int cardinality of the generated sample vectors */ private void generateSamples(int num, double mx, double my, double sd, int card) { System.out.println("Generating " + num + " samples m=[" + mx + ", " + my + "] sd=" + sd); for (int i = 0; i < num; i++) { DenseVector v = new DenseVector(card); for (int j = 0; j < card; j++) { v.set(j, UncommonDistributions.rNorm(mx, sd)); } sampleData.add(new VectorWritable(v)); } } /** * Generate 2-d samples for backwards compatibility with existing tests * @param num int number of samples to generate * @param mx double x-value of the sample mean * @param my double y-value of the sample mean * @param sd double standard deviation of the samples */ private void generateSamples(int num, double mx, double my, double sd) { generateSamples(num, mx, my, sd, 2); } private static void printResults(Iterable<Cluster[]> result, int significant) { int row = 0; for (Cluster[] r : result) { System.out.print("sample[" + row++ + "]= "); for (Cluster model : r) { if (model.count() > significant) { System.out.print(model.asFormatString(null) + ", "); } } System.out.println(); } System.out.println(); } @Test public void testDirichletCluster100() { System.out.println("testDirichletCluster100"); generateSamples(40, 1, 1, 3); generateSamples(30, 1, 0, 0.1); generateSamples(30, 0, 1, 0.1); DirichletClusterer dc = new DirichletClusterer(sampleData, new GaussianClusterDistribution(new VectorWritable(new DenseVector(2))), 1.0, 10, 1, 0); List<Cluster[]> result = dc.cluster(30); printResults(result, 2); assertNotNull(result); } @Test public void testDirichletGaussianCluster100() { System.out.println("testDirichletGaussianCluster100"); generateSamples(40, 1, 1, 3); generateSamples(30, 1, 0, 0.1); generateSamples(30, 0, 1, 0.1); DirichletClusterer dc = new DirichletClusterer(sampleData, new GaussianClusterDistribution(new VectorWritable(new DenseVector(2))), 1.0, 10, 1, 0); List<Cluster[]> result = dc.cluster(30); printResults(result, 2); assertNotNull(result); } @Test public void testDirichletDMCluster100() { System.out.println("testDirichletDMCluster100"); generateSamples(40, 1, 1, 3); generateSamples(30, 1, 0, 0.1); generateSamples(30, 0, 1, 0.1); DirichletClusterer dc = new DirichletClusterer(sampleData, new DistanceMeasureClusterDistribution(new VectorWritable(new DenseVector(2))), 1.0, 10, 1, 0); List<Cluster[]> result = dc.cluster(30); printResults(result, 2); assertNotNull(result); } }