/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.mahout.classifier.bayes;
import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.TreeSet;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.Reporter;
import org.apache.mahout.classifier.bayes.mapreduce.common.BayesConstants;
import org.apache.mahout.classifier.bayes.mapreduce.common.BayesFeatureMapper;
import org.apache.mahout.classifier.bayes.mapreduce.common.BayesFeatureReducer;
import org.apache.mahout.classifier.bayes.mapreduce.common.FeatureLabelComparator;
import org.apache.mahout.common.DummyOutputCollector;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.StringTuple;
import org.junit.Test;
/**
 * Runs {@link BayesFeatureMapper} followed by {@link BayesFeatureReducer} over a small fixed
 * corpus and checks how many records of each type the reducer emits for different
 * minimum-support / minimum-document-frequency settings.
 *
 * <p>Facts about the corpus (see {@link #TRAINING_DOCS}) that the expected counts rely on:</p>
 * <ul>
 * <li>3 labels, 7 documents;</li>
 * <li>14 distinct terms and 17 distinct term/label pairs;</li>
 * <li>only "big" (3 occurrences, 2 documents, labels foo and bar) and "cool" (3 occurrences,
 * 3 documents, labels foo, bar and baz) occur more than once; together they form
 * 5 term/label pairs. Every other term occurs exactly once.</li>
 * </ul>
 */
public final class BayesFeatureMapReduceTest extends MahoutTestCase {

  /** Corpus fed to the mapper: each row is {label, document text}, in the order it is mapped. */
  private static final String[][] TRAINING_DOCS = {
    {"foo", "big brown shoe"},
    {"foo", "cool chuck taylors"},
    {"bar", "big big dog"},
    {"bar", "cool rain"},
    {"baz", "red giant"},
    {"baz", "white dwarf"},
    {"baz", "cool black hole"},
  };

  /**
   * Maps every document of {@link #TRAINING_DOCS}, then reduces the mapper output key by key.
   *
   * @param bp parameters serialized into the job configuration under "bayes.parameters"
   * @return the collector holding everything the reducer emitted
   * @throws IOException if the mapper or the reducer fails
   */
  private static DummyOutputCollector<StringTuple,DoubleWritable> runMapReduce(BayesParameters bp) throws IOException {
    JobConf conf = new JobConf();
    conf.set("io.serializations",
        "org.apache.hadoop.io.serializer.JavaSerialization,"
        + "org.apache.hadoop.io.serializer.WritableSerialization");
    conf.set("bayes.parameters", bp.toString());

    // map phase
    BayesFeatureMapper mapper = new BayesFeatureMapper();
    mapper.configure(conf);
    DummyOutputCollector<StringTuple,DoubleWritable> mapperOutput =
        new DummyOutputCollector<StringTuple,DoubleWritable>();
    for (String[] doc : TRAINING_DOCS) {
      mapper.map(new Text(doc[0]), new Text(doc[1]), mapperOutput, Reporter.NULL);
    }

    // put the mapper output keys in the order the reducer expects (emulates the shuffle)
    Collection<StringTuple> sortedKeys = new TreeSet<StringTuple>(new FeatureLabelComparator());
    sortedKeys.addAll(mapperOutput.getKeys());

    // reduce phase: one reduce call per key, with all values collected for that key
    BayesFeatureReducer reducer = new BayesFeatureReducer();
    reducer.configure(conf);
    DummyOutputCollector<StringTuple,DoubleWritable> reducerOutput =
        new DummyOutputCollector<StringTuple,DoubleWritable>();
    Map<StringTuple, List<DoubleWritable>> valuesByKey = mapperOutput.getData();
    for (StringTuple key : sortedKeys) {
      reducer.reduce(key, valuesByKey.get(key).iterator(), reducerOutput, Reporter.NULL);
    }
    return reducerOutput;
  }

  /** With a minimum document frequency of 1, every term of the corpus is expected in the output. */
  @Test
  public void testNoFilters() throws Exception {
    BayesParameters bp = new BayesParameters();
    bp.setGramSize(1);
    bp.setMinDF(1);
    DummyOutputCollector<StringTuple,DoubleWritable> reduceOutput = runMapReduce(bp);
    assertCounts(reduceOutput,
        17, /* df: 17 unique term/label pairs */
        14, /* fc: 14 unique features across all labels */
        3,  /* lc: 3 labels */
        17  /* wt: 17 unique term/label pairs */);
  }

  /** With a minimum support of 2, only 'big' and 'cool' (3 occurrences each) are expected to remain. */
  @Test
  public void testMinSupport() throws Exception {
    BayesParameters bp = new BayesParameters();
    bp.setGramSize(1);
    bp.setMinSupport(2);
    DummyOutputCollector<StringTuple,DoubleWritable> reduceOutput = runMapReduce(bp);
    assertCounts(reduceOutput,
        5, /* df: 5 term/label pairs involving 'big' or 'cool' */
        2, /* fc: 'big' and 'cool' each occur 3 times; all other terms occur once */
        3, /* lc: 3 labels */
        5  /* wt: 5 term/label pairs involving 'big' or 'cool' */);
  }

  /**
   * With a minimum document frequency of 2, only 'big' (2 documents) and 'cool' (3 documents)
   * are expected to remain.
   */
  @Test
  public void testMinDf() throws Exception {
    BayesParameters bp = new BayesParameters();
    bp.setGramSize(1);
    bp.setMinDF(2);
    DummyOutputCollector<StringTuple,DoubleWritable> reduceOutput = runMapReduce(bp);
    assertCounts(reduceOutput,
        5, /* df: 5 term/label pairs whose term occurs in at least 2 documents */
        2, /* fc: 'big' and 'cool' */
        3, /* lc: 3 labels */
        5  /* wt: 5 term/label pairs whose term occurs in at least 2 documents */);
  }

  /**
   * With a minimum support of 3 and a minimum document frequency of 2, 'big' and 'cool'
   * (3 occurrences each, in 2 and 3 documents respectively) are still expected to remain.
   */
  @Test
  public void testMinBoth() throws Exception {
    BayesParameters bp = new BayesParameters();
    bp.setGramSize(1);
    bp.setMinSupport(3);
    bp.setMinDF(2);
    DummyOutputCollector<StringTuple,DoubleWritable> reduceOutput = runMapReduce(bp);
    assertCounts(reduceOutput,
        5, /* df: 5 term/label pairs involving 'big' or 'cool' */
        2, /* fc: 'big' and 'cool' each occur 3 times, in at least 2 documents */
        3, /* lc: 3 labels */
        5  /* wt: 5 term/label pairs involving 'big' or 'cool' */);
  }

  /**
   * Tallies the reducer output keys by record type (the first element of each key) and compares
   * the tallies with the expected values. Also asserts that every key, whatever its type,
   * collected exactly one value. Keys whose type is none of the four checked constants are
   * not tallied.
   *
   * @param output     collector holding the reducer output
   * @param dfExpected expected number of {@link BayesConstants#DOCUMENT_FREQUENCY} keys
   * @param fcExpected expected number of {@link BayesConstants#FEATURE_COUNT} keys
   * @param lcExpected expected number of {@link BayesConstants#LABEL_COUNT} keys
   * @param wtExpected expected number of {@link BayesConstants#WEIGHT} keys
   */
  private static void assertCounts(DummyOutputCollector<StringTuple,DoubleWritable> output,
                                   int dfExpected,
                                   int fcExpected,
                                   int lcExpected,
                                   int wtExpected) {
    int dfCount = 0;
    int fcCount = 0;
    int lcCount = 0;
    int wtCount = 0;

    Map<StringTuple, List<DoubleWritable>> outputData = output.getData();
    for (Map.Entry<StringTuple, List<DoubleWritable>> entry : outputData.entrySet()) {
      StringTuple key = entry.getKey();
      String type = key.stringAt(0);
      if (type.equals(BayesConstants.DOCUMENT_FREQUENCY)) {
        dfCount++;
      } else if (type.equals(BayesConstants.FEATURE_COUNT)) {
        fcCount++;
      } else if (type.equals(BayesConstants.LABEL_COUNT)) {
        lcCount++;
      } else if (type.equals(BayesConstants.WEIGHT)) {
        wtCount++;
      }
      // name the key so a failure points at the offending record
      assertEquals("value size for key " + key, 1, entry.getValue().size());
    }

    assertEquals("document frequency count", dfExpected, dfCount);
    assertEquals("feature count", fcExpected, fcCount);
    assertEquals("label count", lcExpected, lcCount);
    assertEquals("feature weight count", wtExpected, wtCount);
  }
}