/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.hadoop.mapreduce; import junit.framework.TestCase; import java.io.IOException; import java.io.DataInput; import java.io.DataOutput; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Random; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; import static org.junit.Assert.*; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.conf.Configurable; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.io.*; import org.apache.hadoop.mapreduce.lib.output.NullOutputFormat; import org.apache.hadoop.mapreduce.MRConfig; import org.apache.hadoop.util.ReflectionUtils; public class TestMapCollection { private static final Log LOG = LogFactory.getLog( TestMapCollection.class.getName()); public static abstract class FillWritable implements Writable, Configurable { private int len; protected boolean disableRead; private byte[] b; private final Random r; protected final byte fillChar; public FillWritable(byte fillChar) { this.fillChar = fillChar; r = new Random(); final long seed = r.nextLong(); LOG.info("seed: " + seed); r.setSeed(seed); } @Override public Configuration getConf() { return null; } public void setLength(int len) { this.len = len; } public int compareTo(FillWritable o) { if (o == this) return 0; return len - o.len; } @Override public int hashCode() { return 37 * len; } @Override public boolean equals(Object o) { if (!(o instanceof FillWritable)) return false; return 0 == compareTo((FillWritable)o); } @Override public void readFields(DataInput in) throws IOException { if (disableRead) { return; } len = WritableUtils.readVInt(in); for (int i = 0; i < len; ++i) { assertEquals("Invalid byte at " + i, fillChar, in.readByte()); } } @Override public void write(DataOutput out) throws IOException { if (0 == len) { return; } int written = 0; if (!disableRead) { WritableUtils.writeVInt(out, len); written -= WritableUtils.getVIntSize(len); } if (len > 1024) { if (null == b || b.length < len) { b = new byte[2 * len]; } Arrays.fill(b, fillChar); do { final int write = Math.min(len - written, r.nextInt(len)); out.write(b, 0, write); written += write; } while (written < len); assertEquals(len, written); } else { for (int i = written; i < len; ++i) { out.write(fillChar); } } } } public static class KeyWritable extends FillWritable implements WritableComparable<FillWritable> { static final byte keyFill = (byte)('K' & 0xFF); public KeyWritable() { super(keyFill); } @Override public void setConf(Configuration conf) { disableRead = conf.getBoolean("test.disable.key.read", false); } } public static class ValWritable extends FillWritable { public ValWritable() { super((byte)('V' & 0xFF)); } @Override public void setConf(Configuration conf) { disableRead = conf.getBoolean("test.disable.val.read", false); } } public static class VariableComparator implements RawComparator<KeyWritable>, Configurable { private boolean readLen; public VariableComparator() { } @Override public void setConf(Configuration conf) { readLen = !conf.getBoolean("test.disable.key.read", false); } @Override public Configuration getConf() { return null; } public int compare(KeyWritable k1, KeyWritable k2) { return k1.compareTo(k2); } @Override public int compare(byte[] b1, int s1, int l1, byte[] b2, int s2, int l2) { final int n1; final int n2; if (readLen) { n1 = WritableUtils.decodeVIntSize(b1[s1]); n2 = WritableUtils.decodeVIntSize(b2[s2]); } else { n1 = 0; n2 = 0; } for (int i = s1 + n1; i < l1 - n1; ++i) { assertEquals("Invalid key at " + s1, (int)KeyWritable.keyFill, b1[i]); } for (int i = s2 + n2; i < l2 - n2; ++i) { assertEquals("Invalid key at " + s2, (int)KeyWritable.keyFill, b2[i]); } return l1 - l2; } } public static class SpillReducer extends Reducer<KeyWritable,ValWritable,NullWritable,NullWritable> { private int numrecs; private int expected; @Override protected void setup(Context job) { numrecs = 0; expected = job.getConfiguration().getInt("test.spillmap.records", 100); } @Override protected void reduce(KeyWritable k, Iterable<ValWritable> values, Context context) throws IOException, InterruptedException { for (ValWritable val : values) { ++numrecs; } } @Override protected void cleanup(Context context) throws IOException, InterruptedException { assertEquals("Unexpected record count", expected, numrecs); } } public static class FakeSplit extends InputSplit implements Writable { @Override public void write(DataOutput out) throws IOException { } @Override public void readFields(DataInput in) throws IOException { } @Override public long getLength() { return 0L; } @Override public String[] getLocations() { return new String[0]; } } public abstract static class RecordFactory implements Configurable { public Configuration getConf() { return null; } public abstract int keyLen(int i); public abstract int valLen(int i); } public static class FixedRecordFactory extends RecordFactory { private int keylen; private int vallen; public FixedRecordFactory() { } public void setConf(Configuration conf) { keylen = conf.getInt("test.fixedrecord.keylen", 0); vallen = conf.getInt("test.fixedrecord.vallen", 0); } public int keyLen(int i) { return keylen; } public int valLen(int i) { return vallen; } public static void setLengths(Configuration conf, int keylen, int vallen) { conf.setInt("test.fixedrecord.keylen", keylen); conf.setInt("test.fixedrecord.vallen", vallen); conf.setBoolean("test.disable.key.read", 0 == keylen); conf.setBoolean("test.disable.val.read", 0 == vallen); } } public static class FakeIF extends InputFormat<KeyWritable,ValWritable> { public FakeIF() { } @Override public List<InputSplit> getSplits(JobContext ctxt) throws IOException { final int numSplits = ctxt.getConfiguration().getInt( "test.mapcollection.num.maps", -1); List<InputSplit> splits = new ArrayList<InputSplit>(numSplits); for (int i = 0; i < numSplits; ++i) { splits.add(i, new FakeSplit()); } return splits; } public RecordReader<KeyWritable,ValWritable> createRecordReader( InputSplit ignored, final TaskAttemptContext taskContext) { return new RecordReader<KeyWritable,ValWritable>() { private RecordFactory factory; private final KeyWritable key = new KeyWritable(); private final ValWritable val = new ValWritable(); private int current; private int records; @Override public void initialize(InputSplit split, TaskAttemptContext context) { final Configuration conf = context.getConfiguration(); key.setConf(conf); val.setConf(conf); factory = ReflectionUtils.newInstance( conf.getClass("test.mapcollection.class", FixedRecordFactory.class, RecordFactory.class), conf); assertNotNull(factory); current = 0; records = conf.getInt("test.spillmap.records", 100); } @Override public boolean nextKeyValue() { key.setLength(factory.keyLen(current)); val.setLength(factory.valLen(current)); return current++ < records; } @Override public KeyWritable getCurrentKey() { return key; } @Override public ValWritable getCurrentValue() { return val; } @Override public float getProgress() { return (float) current / records; } @Override public void close() { assertEquals("Unexpected count", records, current - 1); } }; } } private static void runTest(String name, int keylen, int vallen, int records, int ioSortMB, float spillPer) throws Exception { Configuration conf = new Configuration(); conf.setInt(Job.COMPLETION_POLL_INTERVAL_KEY, 100); Job job = Job.getInstance(conf); conf = job.getConfiguration(); conf.setInt(MRJobConfig.IO_SORT_MB, ioSortMB); conf.set(MRJobConfig.MAP_SORT_SPILL_PERCENT, Float.toString(spillPer)); conf.setClass("test.mapcollection.class", FixedRecordFactory.class, RecordFactory.class); FixedRecordFactory.setLengths(conf, keylen, vallen); conf.setInt("test.spillmap.records", records); runTest(name, job); } private static void runTest(String name, Job job) throws Exception { job.setNumReduceTasks(1); job.getConfiguration().set(MRConfig.FRAMEWORK_NAME, MRConfig.LOCAL_FRAMEWORK_NAME); job.getConfiguration().setInt(MRJobConfig.IO_SORT_FACTOR, 1000); job.getConfiguration().set("fs.defaultFS", "file:///"); job.getConfiguration().setInt("test.mapcollection.num.maps", 1); job.setInputFormatClass(FakeIF.class); job.setOutputFormatClass(NullOutputFormat.class); job.setMapperClass(Mapper.class); job.setReducerClass(SpillReducer.class); job.setMapOutputKeyClass(KeyWritable.class); job.setMapOutputValueClass(ValWritable.class); job.setSortComparatorClass(VariableComparator.class); LOG.info("Running " + name); assertTrue("Job failed!", job.waitForCompletion(false)); } @Test public void testValLastByte() throws Exception { // last byte of record/key is the last/first byte in the spill buffer runTest("vallastbyte", 128, 896, 1344, 1, 0.5f); runTest("keylastbyte", 512, 1024, 896, 1, 0.5f); } @Test public void testLargeRecords() throws Exception { // maps emitting records larger than mapreduce.task.io.sort.mb runTest("largerec", 100, 1024*1024, 5, 1, .8f); runTest("largekeyzeroval", 1024*1024, 0, 5, 1, .8f); } @Test public void testSpillPer2B() throws Exception { // set non-default, 100% speculative spill boundary runTest("fullspill2B", 1, 1, 10000, 1, 1.0f); runTest("fullspill200B", 100, 100, 10000, 1, 1.0f); runTest("fullspillbuf", 10 * 1024, 20 * 1024, 256, 1, 1.0f); runTest("lt50perspill", 100, 100, 10000, 1, 0.3f); } @Test public void testZeroVal() throws Exception { // test key/value at zero-length runTest("zeroval", 1, 0, 10000, 1, .8f); runTest("zerokey", 0, 1, 10000, 1, .8f); runTest("zerokeyval", 0, 0, 10000, 1, .8f); runTest("zerokeyvalfull", 0, 0, 10000, 1, 1.0f); } @Test public void testSingleRecord() throws Exception { runTest("singlerecord", 100, 100, 1, 1, 1.0f); runTest("zerokeyvalsingle", 0, 0, 1, 1, 1.0f); } @Test public void testLowSpill() throws Exception { runTest("lowspill", 4000, 96, 20, 1, 0.00390625f); } @Test public void testSplitMetaSpill() throws Exception { runTest("splitmetaspill", 7, 1, 131072, 1, 0.8f); } public static class StepFactory extends RecordFactory { public int prekey; public int postkey; public int preval; public int postval; public int steprec; public void setConf(Configuration conf) { prekey = conf.getInt("test.stepfactory.prekey", 0); postkey = conf.getInt("test.stepfactory.postkey", 0); preval = conf.getInt("test.stepfactory.preval", 0); postval = conf.getInt("test.stepfactory.postval", 0); steprec = conf.getInt("test.stepfactory.steprec", 0); } public static void setLengths(Configuration conf, int prekey, int postkey, int preval, int postval, int steprec) { conf.setInt("test.stepfactory.prekey", prekey); conf.setInt("test.stepfactory.postkey", postkey); conf.setInt("test.stepfactory.preval", preval); conf.setInt("test.stepfactory.postval", postval); conf.setInt("test.stepfactory.steprec", steprec); } public int keyLen(int i) { return i > steprec ? postkey : prekey; } public int valLen(int i) { return i > steprec ? postval : preval; } } @Test public void testPostSpillMeta() throws Exception { // write larger records until spill, then write records that generate // no writes into the serialization buffer Configuration conf = new Configuration(); conf.setInt(Job.COMPLETION_POLL_INTERVAL_KEY, 100); Job job = Job.getInstance(conf); conf = job.getConfiguration(); conf.setInt(MRJobConfig.IO_SORT_MB, 1); // 2^20 * spill = 14336 bytes available post-spill, at most 896 meta conf.set(MRJobConfig.MAP_SORT_SPILL_PERCENT, Float.toString(.986328125f)); conf.setClass("test.mapcollection.class", StepFactory.class, RecordFactory.class); StepFactory.setLengths(conf, 4000, 0, 96, 0, 252); conf.setInt("test.spillmap.records", 1000); conf.setBoolean("test.disable.key.read", true); conf.setBoolean("test.disable.val.read", true); runTest("postspillmeta", job); } @Test public void testLargeRecConcurrent() throws Exception { Configuration conf = new Configuration(); conf.setInt(Job.COMPLETION_POLL_INTERVAL_KEY, 100); Job job = Job.getInstance(conf); conf = job.getConfiguration(); conf.setInt(MRJobConfig.IO_SORT_MB, 1); conf.set(MRJobConfig.MAP_SORT_SPILL_PERCENT, Float.toString(.986328125f)); conf.setClass("test.mapcollection.class", StepFactory.class, RecordFactory.class); StepFactory.setLengths(conf, 4000, 261120, 96, 1024, 251); conf.setInt("test.spillmap.records", 255); conf.setBoolean("test.disable.key.read", false); conf.setBoolean("test.disable.val.read", false); runTest("largeconcurrent", job); } public static class RandomFactory extends RecordFactory { public int minkey; public int maxkey; public int minval; public int maxval; private final Random r = new Random(); private static int nextRand(Random r, int max) { return (int)Math.exp(r.nextDouble() * Math.log(max)); } public void setConf(Configuration conf) { r.setSeed(conf.getLong("test.randomfactory.seed", 0L)); minkey = conf.getInt("test.randomfactory.minkey", 0); maxkey = conf.getInt("test.randomfactory.maxkey", 0) - minkey; minval = conf.getInt("test.randomfactory.minval", 0); maxval = conf.getInt("test.randomfactory.maxval", 0) - minval; } public static void setLengths(Configuration conf, Random r, int max) { int k1 = nextRand(r, max); int k2 = nextRand(r, max); if (k1 > k2) { final int tmp = k1; k1 = k2; k2 = k1; } int v1 = nextRand(r, max); int v2 = nextRand(r, max); if (v1 > v2) { final int tmp = v1; v1 = v2; v2 = v1; } setLengths(conf, k1, ++k2, v1, ++v2); } public static void setLengths(Configuration conf, int minkey, int maxkey, int minval, int maxval) { assert minkey < maxkey; assert minval < maxval; conf.setInt("test.randomfactory.minkey", minkey); conf.setInt("test.randomfactory.maxkey", maxkey); conf.setInt("test.randomfactory.minval", minval); conf.setInt("test.randomfactory.maxval", maxval); conf.setBoolean("test.disable.key.read", minkey == 0); conf.setBoolean("test.disable.val.read", minval == 0); } public int keyLen(int i) { return minkey + nextRand(r, maxkey - minkey); } public int valLen(int i) { return minval + nextRand(r, maxval - minval); } } @Test public void testRandom() throws Exception { Configuration conf = new Configuration(); conf.setInt(Job.COMPLETION_POLL_INTERVAL_KEY, 100); Job job = Job.getInstance(conf); conf = job.getConfiguration(); conf.setInt(MRJobConfig.IO_SORT_MB, 1); conf.setClass("test.mapcollection.class", RandomFactory.class, RecordFactory.class); final Random r = new Random(); final long seed = r.nextLong(); LOG.info("SEED: " + seed); r.setSeed(seed); conf.set(MRJobConfig.MAP_SORT_SPILL_PERCENT, Float.toString(Math.max(0.1f, r.nextFloat()))); RandomFactory.setLengths(conf, r, 1 << 14); conf.setInt("test.spillmap.records", r.nextInt(500)); conf.setLong("test.randomfactory.seed", r.nextLong()); runTest("random", job); } @Test public void testRandomCompress() throws Exception { Configuration conf = new Configuration(); conf.setInt(Job.COMPLETION_POLL_INTERVAL_KEY, 100); Job job = Job.getInstance(conf); conf = job.getConfiguration(); conf.setInt(MRJobConfig.IO_SORT_MB, 1); conf.setBoolean(MRJobConfig.MAP_OUTPUT_COMPRESS, true); conf.setClass("test.mapcollection.class", RandomFactory.class, RecordFactory.class); final Random r = new Random(); final long seed = r.nextLong(); LOG.info("SEED: " + seed); r.setSeed(seed); conf.set(MRJobConfig.MAP_SORT_SPILL_PERCENT, Float.toString(Math.max(0.1f, r.nextFloat()))); RandomFactory.setLengths(conf, r, 1 << 14); conf.setInt("test.spillmap.records", r.nextInt(500)); conf.setLong("test.randomfactory.seed", r.nextLong()); runTest("randomCompress", job); } }