/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.vectorizer;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable;
import org.apache.mahout.math.NamedVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.junit.Before;
import org.junit.Test;

import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;

/**
 * Tests that {@link SparseVectorsFromSequenceFiles} prunes high-document-frequency
 * words when the -xs (maximum standard deviation) option is given, and preserves
 * them otherwise.
 */
public class HighDFWordsPrunerTest extends MahoutTestCase {

  private static final int NUM_DOCS = 100;

  private static final String[] HIGH_DF_WORDS = {"has", "which", "what", "srtyui"};

  private Configuration conf;
  private Path inputPath;

  @Override
  @Before
  public void setUp() throws Exception {
    super.setUp();

    conf = new Configuration();
    FileSystem fs = FileSystem.get(conf);

    inputPath = getTestTempFilePath("documents/docs.file");

    SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, inputPath, Text.class, Text.class);
    RandomDocumentGenerator gen = new RandomDocumentGenerator();
    for (int i = 0; i < NUM_DOCS; i++) {
      writer.append(new Text("Document::ID::" + i), new Text(enhanceWithHighDFWords(gen.getRandomDocument())));
    }
    writer.close();
  }

  /** Appends each of the HIGH_DF_WORDS so that every generated document contains all of them. */
  private static String enhanceWithHighDFWords(String initialDoc) {
    StringBuilder sb = new StringBuilder(initialDoc);
    for (String word : HIGH_DF_WORDS) {
      sb.append(' ').append(word);
    }
    return sb.toString();
  }

  @Test
  public void testHighDFWordsPreserving() throws Exception {
    runTest(false);
  }

  @Test
  public void testHighDFWordsPruning() throws Exception {
    runTest(true);
  }

  private void runTest(boolean prune) throws Exception {
    Path outputPath = getTestTempFilePath("output");

    List<String> argList = new LinkedList<String>();
    argList.add("-i");
    argList.add(inputPath.toString());
    argList.add("-o");
    argList.add(outputPath.toString());

    if (prune) {
      argList.add("-xs");
      argList.add("3"); // prune all words whose document frequency falls outside 3*sigma
    } else {
      argList.add("--maxDFPercent");
      argList.add("100"); // if -xs is not specified, the default is to use maxDFPercent, which defaults to 99%
    }

    argList.add("-seq");
    argList.add("-nv");

    String[] args = argList.toArray(new String[argList.size()]);
    SparseVectorsFromSequenceFiles.main(args);

    Path dictionary = new Path(outputPath, "dictionary.file-0");
    Path tfVectors = new Path(outputPath, "tf-vectors");
    Path tfidfVectors = new Path(outputPath, "tfidf-vectors");

    int[] highDFWordsDictionaryIndices = getHighDFWordsDictionaryIndices(dictionary);
    validateVectors(tfVectors, highDFWordsDictionaryIndices, prune);
    validateVectors(tfidfVectors, highDFWordsDictionaryIndices, prune);
  }

  /** Looks up the dictionary index assigned to each of the HIGH_DF_WORDS. */
  private int[] getHighDFWordsDictionaryIndices(Path dictionaryPath) {
    int[] highDFWordsDictionaryIndices = new int[HIGH_DF_WORDS.length];

    List<String> highDFWordsList = Arrays.asList(HIGH_DF_WORDS);

    for (Pair<Text, IntWritable> record
        : new SequenceFileDirIterable<Text, IntWritable>(dictionaryPath, PathType.GLOB, null, null, true, conf)) {
      int index = highDFWordsList.indexOf(record.getFirst().toString());
      if (index > -1) {
        highDFWordsDictionaryIndices[index] = record.getSecond().get();
      }
    }

    return highDFWordsDictionaryIndices;
  }

  /** Asserts that each high-DF word is zeroed out in every output vector when pruning, and nonzero otherwise. */
  private void validateVectors(Path vectorPath, int[] highDFWordsDictionaryIndices, boolean prune) {
    for (VectorWritable value
        : new SequenceFileDirValueIterable<VectorWritable>(vectorPath, PathType.LIST, PathFilters.partFilter(),
                                                           null, true, conf)) {
      Vector v = ((NamedVector) value.get()).getDelegate();
      for (int i = 0; i < highDFWordsDictionaryIndices.length; i++) {
        if (prune) {
          assertEquals("Found vector for which word '" + HIGH_DF_WORDS[i] + "' is not pruned", 0.0,
              v.get(highDFWordsDictionaryIndices[i]), 0.0);
        } else {
          assertTrue("Found vector for which word '" + HIGH_DF_WORDS[i] + "' is pruned, and shouldn't have been",
              v.get(highDFWordsDictionaryIndices[i]) != 0.0);
        }
      }
    }
  }
}