/** * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.mahout.math.hadoop.similarity; import com.google.common.collect.Lists; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.DoubleWritable; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.io.WritableComparable; import org.apache.hadoop.mapreduce.Mapper; import org.apache.hadoop.util.ToolRunner; import org.apache.mahout.clustering.ClusteringTestUtils; import org.apache.mahout.common.DummyOutputCollector; import org.apache.mahout.common.MahoutTestCase; import org.apache.mahout.common.Pair; import org.apache.mahout.common.StringTuple; import org.apache.mahout.common.commandline.DefaultOptionCreator; import org.apache.mahout.common.distance.EuclideanDistanceMeasure; import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable; import org.apache.mahout.math.DenseVector; import org.apache.mahout.math.NamedVector; import org.apache.mahout.math.RandomAccessSparseVector; import org.apache.mahout.math.Vector; import org.apache.mahout.math.VectorWritable; import org.easymock.EasyMock; import org.junit.Before; import org.junit.Test; import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Map; public class TestVectorDistanceSimilarityJob extends MahoutTestCase { private FileSystem fs; @Override @Before public void setUp() throws Exception { super.setUp(); Configuration conf = new Configuration(); fs = FileSystem.get(conf); } @Test public void testVectorDistanceMapper() throws Exception { Mapper<WritableComparable<?>, VectorWritable, StringTuple, DoubleWritable>.Context context = EasyMock.createMock(Mapper.Context.class); StringTuple tuple = new StringTuple(); tuple.add("foo"); tuple.add("123"); context.write(tuple, new DoubleWritable(Math.sqrt(2.0))); tuple = new StringTuple(); tuple.add("foo2"); tuple.add("123"); context.write(tuple, new DoubleWritable(1)); EasyMock.replay(context); Vector vector = new RandomAccessSparseVector(2); vector.set(0, 2); vector.set(1, 2); VectorDistanceMapper mapper = new VectorDistanceMapper(); setField(mapper, "measure", new EuclideanDistanceMeasure()); Collection<NamedVector> seedVectors = new ArrayList<NamedVector>(); Vector seed1 = new RandomAccessSparseVector(2); seed1.set(0, 1); seed1.set(1, 1); Vector seed2 = new RandomAccessSparseVector(2); seed2.set(0, 2); seed2.set(1, 1); seedVectors.add(new NamedVector(seed1, "foo")); seedVectors.add(new NamedVector(seed2, "foo2")); setField(mapper, "seedVectors", seedVectors); mapper.map(new IntWritable(123), new VectorWritable(vector), context); EasyMock.verify(context); } @Test public void testVectorDistanceInvertedMapper() throws Exception { Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable>.Context context = EasyMock.createMock(Mapper.Context.class); Vector expectVec = new DenseVector(new double[]{Math.sqrt(2.0), 1.0}); context.write(new Text("other"), new VectorWritable(expectVec)); EasyMock.replay(context); Vector vector = new NamedVector(new RandomAccessSparseVector(2), "other"); vector.set(0, 2); vector.set(1, 2); VectorDistanceInvertedMapper mapper = new VectorDistanceInvertedMapper(); setField(mapper, "measure", new EuclideanDistanceMeasure()); Collection<NamedVector> seedVectors = new ArrayList<NamedVector>(); Vector seed1 = new RandomAccessSparseVector(2); seed1.set(0, 1); seed1.set(1, 1); Vector seed2 = new RandomAccessSparseVector(2); seed2.set(0, 2); seed2.set(1, 1); seedVectors.add(new NamedVector(seed1, "foo")); seedVectors.add(new NamedVector(seed2, "foo2")); setField(mapper, "seedVectors", seedVectors); mapper.map(new IntWritable(123), new VectorWritable(vector), context); EasyMock.verify(context); } private static final double[][] REFERENCE = { {1, 1}, {2, 1}, {1, 2}, {2, 2}, {3, 3}, {4, 4}, {5, 4}, {4, 5}, {5, 5} }; private static final double[][] SEEDS = { {1, 1}, {10, 10} }; @Test public void testRun() throws Exception { Path input = getTestTempDirPath("input"); Path output = getTestTempDirPath("output"); Path seedsPath = getTestTempDirPath("seeds"); List<VectorWritable> points = getPointsWritable(REFERENCE); List<VectorWritable> seeds = getPointsWritable(SEEDS); Configuration conf = new Configuration(); ClusteringTestUtils.writePointsToFile(points, true, new Path(input, "file1"), fs, conf); ClusteringTestUtils.writePointsToFile(seeds, true, new Path(seedsPath, "part-seeds"), fs, conf); String[] args = {optKey(DefaultOptionCreator.INPUT_OPTION), input.toString(), optKey(VectorDistanceSimilarityJob.SEEDS), seedsPath.toString(), optKey(DefaultOptionCreator.OUTPUT_OPTION), output.toString(), optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION), EuclideanDistanceMeasure.class.getName() }; ToolRunner.run(new Configuration(), new VectorDistanceSimilarityJob(), args); int expect = SEEDS.length * REFERENCE.length; DummyOutputCollector<StringTuple, DoubleWritable> collector = new DummyOutputCollector<StringTuple, DoubleWritable>(); // for (Pair<StringTuple, DoubleWritable> record : new SequenceFileIterable<StringTuple, DoubleWritable>( new Path(output, "part-m-00000"), conf)) { collector.collect(record.getFirst(), record.getSecond()); } assertEquals(expect, collector.getData().size()); } @Test public void testRunInverted() throws Exception { Path input = getTestTempDirPath("input"); Path output = getTestTempDirPath("output"); Path seedsPath = getTestTempDirPath("seeds"); List<VectorWritable> points = getPointsWritable(REFERENCE); List<VectorWritable> seeds = getPointsWritable(SEEDS); Configuration conf = new Configuration(); ClusteringTestUtils.writePointsToFile(points, true, new Path(input, "file1"), fs, conf); ClusteringTestUtils.writePointsToFile(seeds, true, new Path(seedsPath, "part-seeds"), fs, conf); String[] args = {optKey(DefaultOptionCreator.INPUT_OPTION), input.toString(), optKey(VectorDistanceSimilarityJob.SEEDS), seedsPath.toString(), optKey(DefaultOptionCreator.OUTPUT_OPTION), output.toString(), optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION), EuclideanDistanceMeasure.class.getName(), optKey(VectorDistanceSimilarityJob.OUT_TYPE_KEY), "v" }; ToolRunner.run(new Configuration(), new VectorDistanceSimilarityJob(), args); DummyOutputCollector<Text, VectorWritable> collector = new DummyOutputCollector<Text, VectorWritable>(); // for (Pair<Text, VectorWritable> record : new SequenceFileIterable<Text, VectorWritable>( new Path(output, "part-m-00000"), conf)) { collector.collect(record.getFirst(), record.getSecond()); } assertEquals(REFERENCE.length, collector.getData().size()); for (Map.Entry<Text, List<VectorWritable>> entry : collector.getData().entrySet()) { assertEquals(SEEDS.length, entry.getValue().iterator().next().get().size()); } } public static List<VectorWritable> getPointsWritable(double[][] raw) { List<VectorWritable> points = Lists.newArrayList(); for (double[] fr : raw) { Vector vec = new RandomAccessSparseVector(fr.length); vec.assign(fr); points.add(new VectorWritable(vec)); } return points; } }