/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.utils;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.nio.charset.Charset;
import com.google.common.base.Charsets;
import com.google.common.io.Closeables;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IOUtils;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.mahout.classifier.ClassifierData;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.map.OpenObjectIntHashMap;
import org.junit.Before;
import org.junit.Test;
public final class SplitInputTest extends MahoutTestCase {
private OpenObjectIntHashMap<String> countMap;
private Charset charset;
private FileSystem fs;
private Path tempInputFile;
private Path tempTrainingDirectory;
private Path tempTestDirectory;
private Path tempMapRedOutputDirectory;
private Path tempInputDirectory;
private Path tempSequenceDirectory;
private SplitInput si;
@Override
@Before
public void setUp() throws Exception {
Configuration conf = new Configuration();
fs = FileSystem.get(conf);
super.setUp();
countMap = new OpenObjectIntHashMap<String>();
charset = Charsets.UTF_8;
tempSequenceDirectory = getTestTempFilePath("tmpsequence");
tempInputFile = getTestTempFilePath("bayesinputfile");
tempTrainingDirectory = getTestTempDirPath("bayestrain");
tempTestDirectory = getTestTempDirPath("bayestest");
tempMapRedOutputDirectory = new Path(getTestTempDirPath(), "mapRedOutput");
tempInputDirectory = getTestTempDirPath("bayesinputdir");
si = new SplitInput();
si.setTrainingOutputDirectory(tempTrainingDirectory);
si.setTestOutputDirectory(tempTestDirectory);
si.setInputDirectory(tempInputDirectory);
}
private void writeMultipleInputFiles() throws IOException {
Writer writer = null;
String currentLabel = null;
for (String[] entry : ClassifierData.DATA) {
if (!entry[0].equals(currentLabel)) {
currentLabel = entry[0];
Closeables.closeQuietly(writer);
writer = new BufferedWriter(new OutputStreamWriter(fs.create(new Path(tempInputDirectory, currentLabel)),
Charsets.UTF_8));
}
countMap.adjustOrPutValue(currentLabel, 1, 1);
writer.write(currentLabel + '\t' + entry[1] + '\n');
}
Closeables.closeQuietly(writer);
}
private void writeSingleInputFile() throws IOException {
Writer writer = new BufferedWriter(new OutputStreamWriter(fs.create(tempInputFile), Charsets.UTF_8));
try {
for (String[] entry : ClassifierData.DATA) {
writer.write(entry[0] + '\t' + entry[1] + '\n');
}
} finally {
Closeables.closeQuietly(writer);
}
}
@Test
public void testSplitDirectory() throws Exception {
writeMultipleInputFiles();
final int testSplitSize = 1;
si.setTestSplitSize(testSplitSize);
si.setCallback(new SplitInput.SplitCallback() {
@Override
public void splitComplete(Path inputFile, int lineCount, int trainCount, int testCount, int testSplitStart) {
int trainingLines = countMap.get(inputFile.getName()) - testSplitSize;
assertSplit(fs, inputFile, charset, testSplitSize, trainingLines, tempTrainingDirectory, tempTestDirectory);
}
});
si.splitDirectory(tempInputDirectory);
}
@Test
public void testSplitFile() throws Exception {
writeSingleInputFile();
si.setTestSplitSize(2);
si.setCallback(new TestCallback(2, 10));
si.splitFile(tempInputFile);
}
@Test
public void testSplitFileLocation() throws Exception {
writeSingleInputFile();
si.setTestSplitSize(2);
si.setSplitLocation(50);
si.setCallback(new TestCallback(2, 10));
si.splitFile(tempInputFile);
}
@Test
public void testSplitFilePct() throws Exception {
writeSingleInputFile();
si.setTestSplitPct(25);
si.setCallback(new TestCallback(3, 9));
si.splitFile(tempInputFile);
}
@Test
public void testSplitFilePctLocation() throws Exception {
writeSingleInputFile();
si.setTestSplitPct(25);
si.setSplitLocation(50);
si.setCallback(new TestCallback(3, 9));
si.splitFile(tempInputFile);
}
@Test
public void testSplitFileRandomSelectionSize() throws Exception {
writeSingleInputFile();
si.setTestRandomSelectionSize(5);
si.setCallback(new TestCallback(5, 7));
si.splitFile(tempInputFile);
}
@Test
public void testSplitFileRandomSelectionPct() throws Exception {
writeSingleInputFile();
si.setTestRandomSelectionPct(25);
si.setCallback(new TestCallback(3, 9));
si.splitFile(tempInputFile);
}
/**
* Create a Sequencefile for testing consisting of IntWritable
* keys and VectorWritable values
* @param path path for test SequenceFile
* @param testPoints number of records in test SequenceFile
*/
private void writeVectorSequenceFile(Path path, int testPoints)
throws IOException {
Path tempSequenceFile = new Path(path, "part-00000");
Configuration conf = new Configuration();
IntWritable key = new IntWritable();
VectorWritable value = new VectorWritable();
SequenceFile.Writer writer = null;
try {
writer =
SequenceFile.createWriter(fs, conf, tempSequenceFile,
IntWritable.class, VectorWritable.class);
for (int i = 0; i < testPoints; i++) {
key.set(i);
Vector v = new SequentialAccessSparseVector(4);
v.assign(i);
value.set(v);
writer.append(key, value);
}
} finally {
IOUtils.closeStream(writer);
}
}
/**
* Create a Sequencefile for testing consisting of IntWritable
* keys and Text values
* @param path path for test SequenceFile
* @param testPoints number of records in test SequenceFile
*/
private void writeTextSequenceFile(Path path, int testPoints)
throws IOException {
Path tempSequenceFile = new Path(path, "part-00000");
Configuration conf = new Configuration();
Text key = new Text();
Text value = new Text();
SequenceFile.Writer writer = null;
try {
writer =
SequenceFile.createWriter(fs, conf, tempSequenceFile,
Text.class, Text.class);
for (int i = 0; i < testPoints; i++) {
key.set(Integer.toString(i));
value.set("Line " + i);
writer.append(key, value);
}
} finally {
IOUtils.closeStream(writer);
}
}
/**
* Display contents of a SequenceFile
* @param sequenceFilePath path to SequenceFile
*/
private static void displaySequenceFile(Path sequenceFilePath) {
for (Pair<?,?> record : new SequenceFileIterable<Writable,Writable>(sequenceFilePath, true, new Configuration())) {
System.out.println(record.getFirst() + "\t" + record.getSecond());
}
}
/**
* Determine number of records in a SequenceFile
* @param sequenceFilePath path to SequenceFile
* @return number of records
*/
private static int getNumberRecords(Path sequenceFilePath) {
int numberRecords = 0;
for (Object value : new SequenceFileValueIterable<Writable>(sequenceFilePath, true, new Configuration())) {
numberRecords++;
}
return numberRecords;
}
/**
* Test map reduce version of split input with Text, Text key value
* pairs in input
*/
@Test
public void testSplitInputMapReduceText() throws Exception {
writeTextSequenceFile(tempSequenceDirectory, 1000);
testSplitInputMapReduce(1000);
}
/**
* Test map reduce version of split input with Text, Text key value
* pairs in input called from command line
*/
@Test
public void testSplitInputMapReduceTextCli() throws Exception {
writeTextSequenceFile(tempSequenceDirectory, 1000);
testSplitInputMapReduceCli(1000);
}
/**
* Test map reduce version of split input with IntWritable, Vector key value
* pairs in input
*/
@Test
public void testSplitInputMapReduceVector() throws Exception {
writeVectorSequenceFile(tempSequenceDirectory, 1000);
testSplitInputMapReduce(1000);
}
/**
* Test map reduce version of split input with IntWritable, Vector key value
* pairs in input called from command line
*/
@Test
public void testSplitInputMapReduceVectorCli() throws Exception {
writeVectorSequenceFile(tempSequenceDirectory, 1000);
testSplitInputMapReduceCli(1000);
}
/**
* Test map reduce version of split input through CLI
*/
private void testSplitInputMapReduceCli(int numPoints) throws Exception {
int randomSelectionPct = 25;
int keepPct = 10;
String[] args =
{ "--method", "mapreduce", "--input", tempSequenceDirectory.toString(),
"--mapRedOutputDir", tempMapRedOutputDirectory.toString(),
"--randomSelectionPct", Integer.toString(randomSelectionPct),
"--keepPct", Integer.toString(keepPct), "-ow" };
SplitInput.main(args);
validateSplitInputMapReduce(numPoints, randomSelectionPct, keepPct);
}
/**
* Test map reduce version of split input through method call
*/
private void testSplitInputMapReduce(int numPoints) throws Exception {
int randomSelectionPct = 25;
si.setTestRandomSelectionPct(randomSelectionPct);
int keepPct = 10;
si.setKeepPct(keepPct);
si.setMapRedOutputDirectory(tempMapRedOutputDirectory);
si.setUseMapRed(true);
si.splitDirectory(tempSequenceDirectory);
validateSplitInputMapReduce(numPoints, randomSelectionPct, keepPct);
}
/**
* Validate that number of test records and number of training records
* are consistant with keepPct and randomSelectionPct
*/
private void validateSplitInputMapReduce(int numPoints, int randomSelectionPct, int keepPct) {
Path testPath = new Path(tempMapRedOutputDirectory, "test-r-00000");
Path trainingPath = new Path(tempMapRedOutputDirectory, "training-r-00000");
int numberTestRecords = getNumberRecords(testPath);
int numberTrainingRecords = getNumberRecords(trainingPath);
System.out.printf("Test data: %d records\n", numberTestRecords);
displaySequenceFile(testPath);
System.out.printf("Training data: %d records\n", numberTrainingRecords);
displaySequenceFile(trainingPath);
assertEquals((randomSelectionPct / 100.0) * (keepPct / 100.0) * numPoints,
numberTestRecords, 2);
assertEquals(
(1 - randomSelectionPct / 100.0) * (keepPct / 100.0) * numPoints,
numberTrainingRecords, 2);
}
@Test
public void testValidate() throws Exception {
SplitInput st = new SplitInput();
assertValidateException(st);
st.setTestSplitSize(100);
assertValidateException(st);
st.setTestOutputDirectory(tempTestDirectory);
assertValidateException(st);
st.setTrainingOutputDirectory(tempTrainingDirectory);
st.validate();
st.setTestSplitPct(50);
assertValidateException(st);
st = new SplitInput();
st.setTestRandomSelectionPct(50);
st.setTestOutputDirectory(tempTestDirectory);
st.setTrainingOutputDirectory(tempTrainingDirectory);
st.validate();
st.setTestSplitPct(50);
assertValidateException(st);
st = new SplitInput();
st.setTestRandomSelectionPct(50);
st.setTestOutputDirectory(tempTestDirectory);
st.setTrainingOutputDirectory(tempTrainingDirectory);
st.validate();
st.setTestSplitSize(100);
assertValidateException(st);
}
private class TestCallback implements SplitInput.SplitCallback {
private final int testSplitSize;
private final int trainingLines;
private TestCallback(int testSplitSize, int trainingLines) {
this.testSplitSize = testSplitSize;
this.trainingLines = trainingLines;
}
@Override
public void splitComplete(Path inputFile, int lineCount, int trainCount, int testCount, int testSplitStart) {
assertSplit(fs, tempInputFile, charset, testSplitSize, trainingLines, tempTrainingDirectory, tempTestDirectory);
}
}
private static void assertValidateException(SplitInput st) throws IOException {
try {
st.validate();
fail("Expected IllegalArgumentException");
} catch (IllegalArgumentException iae) {
// good
}
}
private static void assertSplit(FileSystem fs,
Path tempInputFile,
Charset charset,
int testSplitSize,
int trainingLines,
Path tempTrainingDirectory,
Path tempTestDirectory) {
try {
Path testFile = new Path(tempTestDirectory, tempInputFile.getName());
//assertTrue("test file exists", testFile.isFile());
assertEquals("test line count", testSplitSize, SplitInput.countLines(fs, testFile, charset));
Path trainingFile = new Path(tempTrainingDirectory, tempInputFile.getName());
//assertTrue("training file exists", trainingFile.isFile());
assertEquals("training line count", trainingLines, SplitInput.countLines(fs, trainingFile, charset));
} catch (IOException ioe) {
fail(ioe.toString());
}
}
}