/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.clustering.classify;
import java.io.IOException;
import java.util.List;
import java.util.Set;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.FileUtil;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Writable;
import org.apache.mahout.clustering.ClusteringTestUtils;
import org.apache.mahout.clustering.canopy.CanopyDriver;
import org.apache.mahout.clustering.iterator.CanopyClusteringPolicy;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.math.NamedVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
/**
 * Tests for {@link ClusterClassificationDriver}.
 *
 * <p>Each test writes {@link #REFERENCE} as a points file, clusters it with {@link CanopyDriver}
 * (Manhattan distance, T1 = 3.1, T2 = 2.1), writes a {@link CanopyClusteringPolicy} into the
 * {@code clusters-0-final} directory, runs {@link ClusterClassificationDriver} on the points and
 * then checks which vectors were emitted under which cluster id.
 */
public class ClusterClassificationDriverTest extends MahoutTestCase {

  /** Input points; without outlier removal they are expected to split into groups of 3, 4 and 2. */
  private static final double[][] REFERENCE = { {1, 1}, {2, 1}, {1, 2}, {4, 4},
      {5, 4}, {4, 5}, {5, 5}, {9, 9}, {8, 8}};

  private FileSystem fs;
  private Path clusteringOutputPath;
  private Configuration conf;
  private Path pointsPath;
  private Path classifiedOutputPath;
  /** Vectors read back with cluster id "0". */
  private List<Vector> firstCluster;
  /** Vectors read back with cluster id "1". */
  private List<Vector> secondCluster;
  /** Vectors read back with cluster id "2". */
  private List<Vector> thirdCluster;

  @Override
  @Before
  public void setUp() throws Exception {
    super.setUp();
    // Assign the field (not a shadowing local) so conf is never null inside a test.
    conf = getConfiguration();
    fs = FileSystem.get(conf);
    firstCluster = Lists.newArrayList();
    secondCluster = Lists.newArrayList();
    thirdCluster = Lists.newArrayList();
  }

  /**
   * Wraps each row of {@code raw} in a {@link RandomAccessSparseVector} inside a
   * {@link VectorWritable}.
   *
   * @param raw one array of coordinates per point
   * @return the points, in input order
   */
  private static List<VectorWritable> getPointsWritable(double[][] raw) {
    List<VectorWritable> points = Lists.newArrayList();
    for (double[] fr : raw) {
      Vector vec = new RandomAccessSparseVector(fr.length);
      vec.assign(fr);
      points.add(new VectorWritable(vec));
    }
    return points;
  }

  @Test
  public void testVectorClassificationWithOutlierRemovalMR() throws Exception {
    List<VectorWritable> points = getPointsWritable(REFERENCE);
    pointsPath = getTestTempDirPath("points");
    clusteringOutputPath = getTestTempDirPath("output");
    classifiedOutputPath = getTestTempDirPath("classifiedClusters");
    // Obtain the configuration first so the delete below does not receive a null conf.
    conf = getConfiguration();
    // Remove anything already present at the classification output path.
    HadoopUtil.delete(conf, classifiedOutputPath);
    ClusteringTestUtils.writePointsToFile(points, true,
        new Path(pointsPath, "file1"), fs, conf);
    runClustering(pointsPath, conf, false);
    runClassificationWithOutlierRemoval(false);
    collectVectorsForAssertion();
    assertVectorsWithOutlierRemoval();
  }

  @Test
  public void testVectorClassificationWithoutOutlierRemoval() throws Exception {
    List<VectorWritable> points = getPointsWritable(REFERENCE);
    pointsPath = getTestTempDirPath("points");
    clusteringOutputPath = getTestTempDirPath("output");
    classifiedOutputPath = getTestTempDirPath("classify");
    conf = getConfiguration();
    ClusteringTestUtils.writePointsToFile(points,
        new Path(pointsPath, "file1"), fs, conf);
    runClustering(pointsPath, conf, true);
    runClassificationWithoutOutlierRemoval();
    collectVectorsForAssertion();
    assertVectorsWithoutOutlierRemoval();
  }

  @Test
  public void testVectorClassificationWithOutlierRemoval() throws Exception {
    List<VectorWritable> points = getPointsWritable(REFERENCE);
    pointsPath = getTestTempDirPath("points");
    clusteringOutputPath = getTestTempDirPath("output");
    classifiedOutputPath = getTestTempDirPath("classify");
    conf = getConfiguration();
    ClusteringTestUtils.writePointsToFile(points,
        new Path(pointsPath, "file1"), fs, conf);
    runClustering(pointsPath, conf, true);
    runClassificationWithOutlierRemoval(true);
    collectVectorsForAssertion();
    assertVectorsWithOutlierRemoval();
  }

  /**
   * Runs canopy clustering on {@code input} into {@link #clusteringOutputPath} and writes a
   * {@link CanopyClusteringPolicy} into its {@code clusters-0-final} directory.
   *
   * @param input directory holding the points file
   * @param configuration Hadoop configuration passed to {@link CanopyDriver}
   * @param runSequential forwarded to {@link CanopyDriver#run} as its last argument
   */
  private void runClustering(Path input, Configuration configuration, boolean runSequential)
      throws IOException, InterruptedException, ClassNotFoundException {
    CanopyDriver.run(configuration, input, clusteringOutputPath,
        new ManhattanDistanceMeasure(), 3.1, 2.1, false, 0.0, runSequential);
    Path finalClustersPath = new Path(clusteringOutputPath, "clusters-0-final");
    ClusterClassifier.writePolicy(new CanopyClusteringPolicy(), finalClustersPath);
  }

  /** Classifies {@link #pointsPath} sequentially with a threshold argument of 0.0. */
  private void runClassificationWithoutOutlierRemoval()
      throws IOException, InterruptedException, ClassNotFoundException {
    ClusterClassificationDriver.run(getConfiguration(), pointsPath, clusteringOutputPath,
        classifiedOutputPath, 0.0, true, true);
  }

  /**
   * Classifies {@link #pointsPath} with a threshold argument of 0.73.
   *
   * @param runSequential forwarded to {@link ClusterClassificationDriver#run} as its last argument
   */
  private void runClassificationWithOutlierRemoval(boolean runSequential)
      throws IOException, InterruptedException, ClassNotFoundException {
    ClusterClassificationDriver.run(getConfiguration(), pointsPath, clusteringOutputPath,
        classifiedOutputPath, 0.73, true, runSequential);
  }

  /**
   * Reads every part file under {@link #classifiedOutputPath} and distributes the vectors into
   * the per-cluster lists by key.
   */
  private void collectVectorsForAssertion() throws IOException {
    Path[] partFilePaths = FileUtil.stat2Paths(fs.globStatus(classifiedOutputPath));
    FileStatus[] listStatus = fs.listStatus(partFilePaths, PathFilters.partFilter());
    for (FileStatus partFile : listStatus) {
      SequenceFile.Reader classifiedVectors =
          new SequenceFile.Reader(fs, partFile.getPath(), conf);
      try {
        Writable clusterIdAsKey = new IntWritable();
        WeightedPropertyVectorWritable point = new WeightedPropertyVectorWritable();
        while (classifiedVectors.next(clusterIdAsKey, point)) {
          collectVector(clusterIdAsKey.toString(), point.getVector());
        }
      } finally {
        // Release the reader even when reading or an assertion fails part-way through.
        classifiedVectors.close();
      }
    }
  }

  /**
   * Appends {@code vector} to the list for cluster id "0", "1" or "2".
   * Any other id is ignored.
   */
  private void collectVector(String clusterId, Vector vector) {
    if ("0".equals(clusterId)) {
      firstCluster.add(vector);
    } else if ("1".equals(clusterId)) {
      secondCluster.add(vector);
    } else if ("2".equals(clusterId)) {
      thirdCluster.add(vector);
    }
  }

  private void assertVectorsWithOutlierRemoval() {
    checkClustersWithOutlierRemoval();
  }

  private void assertVectorsWithoutOutlierRemoval() {
    assertFirstClusterWithoutOutlierRemoval();
    assertSecondClusterWithoutOutlierRemoval();
    assertThirdClusterWithoutOutlierRemoval();
  }

  private void assertThirdClusterWithoutOutlierRemoval() {
    assertClusterContents(thirdCluster, "{0:9.0,1:9.0}", "{0:8.0,1:8.0}");
  }

  private void assertSecondClusterWithoutOutlierRemoval() {
    assertClusterContents(secondCluster,
        "{0:4.0,1:4.0}", "{0:5.0,1:4.0}", "{0:4.0,1:5.0}", "{0:5.0,1:5.0}");
  }

  private void assertFirstClusterWithoutOutlierRemoval() {
    assertClusterContents(firstCluster, "{0:1.0,1:1.0}", "{0:2.0,1:1.0}", "{0:1.0,1:2.0}");
  }

  /**
   * Asserts that {@code cluster} has exactly {@code expected.length} vectors.
   * It also asserts that each vector's {@link Vector#asFormatString()} is one of {@code expected}.
   */
  private static void assertClusterContents(List<Vector> cluster, String... expected) {
    Assert.assertEquals(expected.length, cluster.size());
    for (Vector vector : cluster) {
      Assert.assertTrue(ArrayUtils.contains(expected, vector.asFormatString()));
    }
  }

  /**
   * With outlier removal the expected outcome is one empty cluster plus two singletons.
   * The two singletons together must be exactly (9,9) and (1,1).
   */
  private void checkClustersWithOutlierRemoval() {
    Set<String> reference = Sets.newHashSet("{0:9.0,1:9.0}", "{0:1.0,1:1.0}");
    List<List<Vector>> clusters = Lists.newArrayList();
    clusters.add(firstCluster);
    clusters.add(secondCluster);
    clusters.add(thirdCluster);
    int singletonCnt = 0;
    int emptyCnt = 0;
    for (List<Vector> vList : clusters) {
      if (vList.isEmpty()) {
        emptyCnt++;
        continue;
      }
      singletonCnt++;
      Assert.assertEquals("expecting only singleton clusters; got size=" + vList.size(),
          1, vList.size());
      String formatted = formatValues(vList.get(0));
      // remove() returns false when the vector is unexpected or was already matched once.
      Assert.assertTrue("not expecting cluster:" + formatted, reference.remove(formatted));
    }
    Assert.assertEquals("Different number of empty clusters than expected!", 1, emptyCnt);
    Assert.assertEquals("Different number of singletons than expected!", 2, singletonCnt);
    Assert.assertEquals("Didn't match all reference clusters!", 0, reference.size());
  }

  /**
   * Formats {@code vector} with {@link Vector#asFormatString()}.
   * A {@link NamedVector} is unwrapped first, so only its delegate's values are formatted.
   */
  private static String formatValues(Vector vector) {
    Vector values = vector instanceof NamedVector ? ((NamedVector) vector).getDelegate() : vector;
    return values.asFormatString();
  }
}