/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.clustering.minhash;
import com.google.common.collect.Lists;
import com.google.common.io.Closeables;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.clustering.minhash.HashFactory.HashType;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.common.commandline.MinhashOptionCreator;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.junit.Test;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
public class TestMinHashClustering extends MahoutTestCase {
private static final double[][] REFERENCE = { {1, 2, 3, 4, 5}, {2, 1, 3, 6, 7}, {3, 7, 6, 11, 8, 9},
{4, 7, 8, 9, 6, 1}, {5, 8, 10, 4, 1}, {6, 17, 14, 15},
{8, 9, 11, 6, 12, 1, 7}, {10, 13, 9, 7, 4, 6, 3},
{3, 5, 7, 9, 2, 11}, {13, 7, 6, 8, 5}};
private Path input;
private Path output;
public static List<VectorWritable> getPointsWritable(double[][] raw) {
List<VectorWritable> points = Lists.newArrayList();
for (double[] fr : raw) {
Vector vec = new SequentialAccessSparseVector(fr.length);
vec.assign(fr);
points.add(new VectorWritable(vec));
}
return points;
}
@Override
public void setUp() throws Exception {
super.setUp();
Configuration conf = new Configuration();
FileSystem fs = FileSystem.get(conf);
List<VectorWritable> points = getPointsWritable(REFERENCE);
input = getTestTempDirPath("points");
output = new Path(getTestTempDirPath(), "output");
Path pointFile = new Path(input, "file1");
SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, pointFile, Text.class, VectorWritable.class);
try {
int id = 0;
for (VectorWritable point : points) {
writer.append(new Text("Id-" + id++), point);
}
} finally {
Closeables.closeQuietly(writer);
}
}
private String[] makeArguments(int minClusterSize,
int minVectorSize,
int numHashFunctions,
int keyGroups,
String hashType) {
return new String[] {optKey(DefaultOptionCreator.INPUT_OPTION), input.toString(),
optKey(DefaultOptionCreator.OUTPUT_OPTION), output.toString(),
optKey(MinhashOptionCreator.MIN_CLUSTER_SIZE), String.valueOf(minClusterSize),
optKey(MinhashOptionCreator.MIN_VECTOR_SIZE), String.valueOf(minVectorSize),
optKey(MinhashOptionCreator.HASH_TYPE), hashType,
optKey(MinhashOptionCreator.NUM_HASH_FUNCTIONS), String.valueOf(numHashFunctions),
optKey(MinhashOptionCreator.KEY_GROUPS), String.valueOf(keyGroups),
optKey(MinhashOptionCreator.NUM_REDUCERS), "1",
optKey(MinhashOptionCreator.DEBUG_OUTPUT)};
}
private static Set<Integer> getValues(Vector vector) {
Iterator<Vector.Element> itr = vector.iterator();
Set<Integer> values = new HashSet<Integer>();
while (itr.hasNext()) {
values.add((int) itr.next().get());
}
return values;
}
private static void runPairwiseSimilarity(List<Vector> clusteredItems, double simThreshold, String msg) {
if (clusteredItems.size() > 1) {
for (int i = 0; i < clusteredItems.size(); i++) {
Set<Integer> itemSet1 = getValues(clusteredItems.get(i));
for (int j = i + 1; j < clusteredItems.size(); j++) {
Set<Integer> itemSet2 = getValues(clusteredItems.get(j));
Collection<Integer> union = new HashSet<Integer>();
union.addAll(itemSet1);
union.addAll(itemSet2);
Collection<Integer> intersect = new HashSet<Integer>();
intersect.addAll(itemSet1);
intersect.retainAll(itemSet2);
double similarity = intersect.size() / (double) union.size();
assertTrue(msg + " - Sets failed min similarity test, Set1: " + itemSet1 + " Set2: " + itemSet2
+ ", similarity:" + similarity, similarity >= simThreshold);
}
}
}
}
private static void verify(Path output, double simThreshold, String msg) {
Configuration conf = new Configuration();
Path outputFile = new Path(output, "part-r-00000");
List<Vector> clusteredItems = Lists.newArrayList();
String prevClusterId = "";
for (Pair<Writable,VectorWritable> record : new SequenceFileIterable<Writable,VectorWritable>(outputFile, conf)) {
Writable clusterId = record.getFirst();
VectorWritable point = record.getSecond();
if (prevClusterId.equals(clusterId.toString())) {
clusteredItems.add(point.get());
} else {
runPairwiseSimilarity(clusteredItems, simThreshold, msg);
clusteredItems.clear();
prevClusterId = clusterId.toString();
clusteredItems.add(point.get());
}
}
runPairwiseSimilarity(clusteredItems, simThreshold, msg);
}
@Test
public void testLinearMinHashMRJob() throws Exception {
String[] args = makeArguments(2, 3, 20, 3, HashType.LINEAR.toString());
int ret = ToolRunner.run(new Configuration(), new MinHashDriver(), args);
assertEquals("Minhash MR Job failed for " + HashType.LINEAR, 0, ret);
verify(output, 0.2, "Hash Type: LINEAR");
}
@Test
public void testPolynomialMinHashMRJob() throws Exception {
String[] args = makeArguments(2, 3, 20, 3, HashType.POLYNOMIAL.toString());
int ret = ToolRunner.run(new Configuration(), new MinHashDriver(), args);
assertEquals("Minhash MR Job failed for " + HashType.POLYNOMIAL, 0, ret);
verify(output, 0.3, "Hash Type: POLYNOMIAL");
}
@Test
public void testMurmurMinHashMRJob() throws Exception {
String[] args = makeArguments(2, 3, 20, 4, HashType.MURMUR.toString());
int ret = ToolRunner.run(new Configuration(), new MinHashDriver(), args);
assertEquals("Minhash MR Job failed for " + HashType.MURMUR, 0, ret);
verify(output, 0.3, "Hash Type: MURMUR");
}
@Test
public void testMurmur3MinHashMRJob() throws Exception {
String[] args = makeArguments(2, 3, 20, 4, HashType.MURMUR3.toString());
int ret = ToolRunner.run(new Configuration(), new MinHashDriver(), args);
assertEquals("Minhash MR Job failed for " + HashType.MURMUR3, 0, ret);
verify(output, 0.3, "Hash Type: MURMUR");
}
}