package io.dstream.tez;
import static io.dstream.utils.KVUtils.kv;
import static org.junit.Assert.assertEquals;
import java.util.List;
import java.util.Map.Entry;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.junit.After;
import org.junit.Test;
import io.dstream.DStream;
import io.dstream.SerializableStreamAssets.SerBinaryOperator;
import io.dstream.utils.KVUtils;
import junit.framework.Assert;
public class MapSideCombineTests extends BaseTezTests {
private final String applicationName = this.getClass().getSimpleName();
private final TestCombiner bo = new TestCombiner(getUnsafe().allocateMemory(4));
@After
public void after(){
clean(applicationName);
this.bo.reset();
}
@Test
public void computeMapSideCombine() throws Exception {
DStream<String> sourcePipeline = DStream.ofType(String.class, "ms");
Future<Stream<Stream<Entry<String, Integer>>>> resultFuture = sourcePipeline.<Entry<String, Integer>>compute(stream -> stream
.flatMap(line -> Stream.of(line.split("\\s+")))
.map(word -> kv(word, 1))
).reduceValues(s -> s.getKey(), s -> s.getValue(), bo)
.executeAs(this.applicationName);
Stream<Stream<Entry<String, Integer>>> result = resultFuture.get(1000000, TimeUnit.MILLISECONDS);
List<Stream<Entry<String, Integer>>> resultStreams = result.collect(Collectors.toList());
Assert.assertEquals(1, resultStreams.size());
Stream<Entry<String, Integer>> firstResultStream = resultStreams.get(0);
List<Entry<String, Integer>> firstResult = firstResultStream.collect(Collectors.toList());
Assert.assertEquals(10, firstResult.size());
Assert.assertEquals(KVUtils.kv("bar", 3), firstResult.get(0));
Assert.assertEquals(KVUtils.kv("dee", 4), firstResult.get(4));
Assert.assertEquals(KVUtils.kv("doo", 5), firstResult.get(7));
result.close();
assertEquals(3, this.bo.getTotalInvocations());
}
@Test
public void computeReduceSideCombineOnly() throws Exception {
DStream<String> sourcePipeline = DStream.ofType(String.class, "rs");
Future<Stream<Stream<Entry<String, Integer>>>> resultFuture = sourcePipeline.<Entry<String, Integer>>compute(stream -> stream
.flatMap(line -> Stream.of(line.split("\\s+")))
.map(word -> kv(word, 1))
).reduceValues(s -> s.getKey(), s -> s.getValue(), bo)
.executeAs(this.applicationName);
Stream<Stream<Entry<String, Integer>>> result = resultFuture.get(10000, TimeUnit.MILLISECONDS);
List<Stream<Entry<String, Integer>>> resultStreams = result.collect(Collectors.toList());
Assert.assertEquals(1, resultStreams.size());
Stream<Entry<String, Integer>> firstResultStream = resultStreams.get(0);
List<Entry<String, Integer>> firstResult = firstResultStream.collect(Collectors.toList());
Assert.assertEquals(10, firstResult.size());
Assert.assertEquals(KVUtils.kv("bar", 3), firstResult.get(0));
Assert.assertEquals(KVUtils.kv("dee", 4), firstResult.get(4));
Assert.assertEquals(KVUtils.kv("doo", 5), firstResult.get(7));
result.close();
assertEquals(1, this.bo.getTotalInvocations());
}
@Test
public void streamMapSideCombine() throws Exception {
DStream<String> sourceStream = DStream.ofType(String.class, "ms");
Future<Stream<Stream<Entry<String, Integer>>>> resultFuture = sourceStream
.flatMap(line -> Stream.of(line.split("\\s+")))
// .reduceValues(s -> s, s -> 1, bo)
.map(word -> kv(word, 1))
.reduceValues(s -> s.getKey(), s -> s.getValue(), bo)
.executeAs(this.applicationName);
Stream<Stream<Entry<String, Integer>>> result = resultFuture.get(1000000, TimeUnit.MILLISECONDS);
List<Stream<Entry<String, Integer>>> resultStreams = result.collect(Collectors.toList());
Assert.assertEquals(1, resultStreams.size());
Stream<Entry<String, Integer>> firstResultStream = resultStreams.get(0);
List<Entry<String, Integer>> firstResult = firstResultStream.collect(Collectors.toList());
Assert.assertEquals(10, firstResult.size());
Assert.assertEquals(KVUtils.kv("bar", 3), firstResult.get(0));
Assert.assertEquals(KVUtils.kv("dee", 4), firstResult.get(4));
Assert.assertEquals(KVUtils.kv("doo", 5), firstResult.get(7));
result.close();
assertEquals(3, this.bo.getTotalInvocations());
}
@Test
public void streamReduceSideCombineOnly() throws Exception {
DStream<String> sourceStream = DStream.ofType(String.class, "rs");
Future<Stream<Stream<Entry<String, Integer>>>> resultFuture = sourceStream
.flatMap(line -> Stream.of(line.split("\\s+")))
.map(word -> kv(word, 1))
.reduceValues(s -> s.getKey(), s -> s.getValue(), bo)
.executeAs(this.applicationName);
Stream<Stream<Entry<String, Integer>>> result = resultFuture.get(10000, TimeUnit.MILLISECONDS);
List<Stream<Entry<String, Integer>>> resultStreams = result.collect(Collectors.toList());
Assert.assertEquals(1, resultStreams.size());
Stream<Entry<String, Integer>> firstResultStream = resultStreams.get(0);
List<Entry<String, Integer>> firstResult = firstResultStream.collect(Collectors.toList());
Assert.assertEquals(10, firstResult.size());
Assert.assertEquals(KVUtils.kv("bar", 3), firstResult.get(0));
Assert.assertEquals(KVUtils.kv("dee", 4), firstResult.get(4));
Assert.assertEquals(KVUtils.kv("doo", 5), firstResult.get(7));
result.close();
assertEquals(1, this.bo.getTotalInvocations());
}
private static class TestCombiner implements SerBinaryOperator<Integer> {
private static final long serialVersionUID = 8366519776101104961L;
private final long pointer;
private boolean invoked;
public TestCombiner(long pointer) {
this.pointer = pointer;
this.reset();
}
public int getTotalInvocations(){
return getUnsafe().getInt(this.pointer);
}
public void reset(){
getUnsafe().putInt(this.pointer, 0);
}
@Override
public Integer apply(Integer t, Integer u) {
if (!this.invoked){
int invocations = getUnsafe().getInt(this.pointer)+1;
getUnsafe().putInt(this.pointer, invocations);
this.invoked = true;
}
return t + u;
}
}
}