package ch.unibe.scg.cells;
import static java.lang.annotation.ElementType.FIELD;
import static java.lang.annotation.ElementType.METHOD;
import static java.lang.annotation.ElementType.PARAMETER;
import static java.lang.annotation.RetentionPolicy.RUNTIME;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.junit.Assert.assertThat;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.lang.annotation.Retention;
import java.lang.annotation.Target;
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.AtomicLong;
import java.util.regex.Pattern;
import javax.inject.Inject;
import javax.inject.Provider;
import javax.inject.Qualifier;
import org.junit.Test;
import com.google.common.base.Charsets;
import com.google.common.primitives.Ints;
import com.google.inject.Guice;
import com.google.inject.Injector;
import com.google.inject.Key;
import com.google.inject.Module;
import com.google.inject.TypeLiteral;
import com.google.inject.util.Providers;
import com.google.protobuf.ByteString;
/** Check {@link LocalCounter}. */
public final class LocalCounterTest {
/** A codec for integers. */
public static class IntegerCodec implements Codec<Integer> {
private static final long serialVersionUID = 1L;
@Override
public Cell<Integer> encode(Integer i) {
return Cell.make(ByteString.copyFrom(Ints.toByteArray(i)), ByteString.copyFromUtf8("t"),
ByteString.copyFromUtf8("t"));
}
@Override
public Integer decode(Cell<Integer> encoded) throws IOException {
return new Integer(Ints.fromByteArray(encoded.getRowKey().toByteArray()));
}
}
/** A counter for io exceptions. */
@Qualifier
@Target({ FIELD, PARAMETER, METHOD })
@Retention(RUNTIME)
public static @interface IOExceptions {}
/** A counter for user exceptions. */
@Qualifier
@Target({ FIELD, PARAMETER, METHOD })
@Retention(RUNTIME)
public static @interface UsrExceptions {}
/** A simple mapper, that maps integers exactly to themselves. */
public static class IdentityMapper implements Mapper<Integer, Integer> {
private static final long serialVersionUID = 1L;
final Counter ioCounter;
final Counter usrCounter;
String finalIoCount;
String finalUsrCount;
@Inject
IdentityMapper(@IOExceptions Counter ioCounter, @UsrExceptions Counter usrCounter) {
this.ioCounter = ioCounter;
this.usrCounter = usrCounter;
}
@Override
public void close() {
finalIoCount = ioCounter.toString();
finalUsrCount = usrCounter.toString();
}
@Override
public void map(Integer first, OneShotIterable<Integer> row, Sink<Integer> sink) throws IOException,
InterruptedException {
for (Integer i : row) {
sink.write(i);
ioCounter.increment(1L);
usrCounter.increment(2L);
}
}
}
/** Checks that counter stays alive after being serialized. */
@Test
public void testCounterSerializationIsLive() throws IOException {
Set<LocalCounter> registry = new LinkedHashSet<>();
LocalCounter cnt = new LocalCounter("cnt1", Providers.of(new AtomicLong()),
Providers.of(registry));
LocalCounter cntCopy = ShallowSerializingCopy.clone(cnt);
cntCopy.increment(1L);
assertThat(cntCopy.toString(), equalTo("cnt1: 1"));
assertThat(cnt.toString(), equalTo("cnt1: 1"));
assertThat(registry.toString(), equalTo("[cnt1: 1]"));
}
/** Checks that counters do not carry their values between pipeline stages. */
@Test
public void testCounterResetsAcrossStages() throws IOException, InterruptedException {
Module m = makeCellsModule();
Injector inj = Guice.createInjector(m, new LocalExecutionModule());
try (InMemoryPipeline<Integer, Integer> pipe
= inj.getInstance(InMemoryPipeline.Builder.class)
.make(Cells.shard(Cells.encode(generateSequence(1000), new IntegerCodec())))) {
Runner r = inj.getInstance(Runner.class);
r.run(pipe);
assertThat(r.mapper.finalIoCount, equalTo("ch.unibe.scg.cells.LocalCounterTest$IOExceptions: 1000"));
assertThat(r.mapper.finalUsrCount, equalTo("ch.unibe.scg.cells.LocalCounterTest$UsrExceptions: 2000"));
}
}
/** Checks counters being printed as pipeline progresses. */
@Test
public void testCounterProgressIsPrinted() throws IOException, InterruptedException {
Module m = makeCellsModule();
Injector inj = Guice.createInjector(m, new LocalExecutionModule());
ByteArrayOutputStream bos = new ByteArrayOutputStream();
PrintStream out = new PrintStream(bos, true); // flush after each write.
try (InMemoryPipeline<Integer, Integer> pipe
= new InMemoryPipeline<>(
Cells.shard(Cells.encode(generateSequence(1000), new IntegerCodec())),
inj.getInstance(PipelineStageScope.class),
inj.getInstance(Key.get(new TypeLiteral<Provider<Set<LocalCounter>>>(){})),
out)) {
inj.getInstance(Runner.class).run(pipe);
String log = bos.toString(Charsets.UTF_8.toString());
// each line in the log should be of following format: <some name>: <number>.
Pattern linePattern = Pattern.compile("\\S+: \\d+");
for (String logLine: log.split(System.lineSeparator())) {
assertThat(linePattern.matcher(logLine).matches(), equalTo(true));
}
// these final lines should occur once per pipeline stage. In the test we have 2 stages.
assertThat(countMatches(log, "ch.unibe.scg.cells.LocalCounterTest$IOExceptions: 1000"), equalTo(2));
assertThat(countMatches(log, "ch.unibe.scg.cells.LocalCounterTest$UsrExceptions: 2000"), equalTo(2));
// TODO: this test will print 0 counters for large total number. The implementation should be fixed.
}
}
@SuppressWarnings("javadoc")
public static class Runner {
final IdentityMapper mapper;
@Inject
Runner(IdentityMapper mapper) {
this.mapper = mapper;
}
/** Runs a pipeline that maps ints to ints. */
public void run(Pipeline<Integer, Integer> pipe) throws IOException, InterruptedException {
pipe.influx(new IntegerCodec())
.map(mapper)
.shuffle(new IntegerCodec())
.mapAndEfflux(mapper, new IntegerCodec());
}
}
/** Generates a sequence of integers of specified length. */
public static Iterable<Integer> generateSequence(int number) {
List<Integer> ret = new ArrayList<>(number);
for (int i = 0; i < number; i++) {
ret.add(i);
}
return ret;
}
private static int countMatches(String str, String sub) {
int lastIndex = 0;
int count = 0;
while (lastIndex != -1) {
lastIndex = str.indexOf(sub, lastIndex);
if (lastIndex != -1) {
count++;
lastIndex += sub.length();
}
}
return count;
}
private static Module makeCellsModule() {
return new CellsModule() {
@Override protected void configure() {
installCounter(IOExceptions.class, new LocalCounterModule());
installCounter(UsrExceptions.class, new LocalCounterModule());
}
};
}
}