/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.sdk.io.kafka;

import static org.apache.beam.sdk.metrics.MetricResultsMatchers.attemptedMetricsResult;
import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem;
import static org.hamcrest.Matchers.hasItem;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import javax.annotation.Nullable;
import org.apache.beam.sdk.Pipeline.PipelineExecutionException;
import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.sdk.coders.BigEndianIntegerCoder;
import org.apache.beam.sdk.coders.BigEndianLongCoder;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.coders.InstantCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.io.Read;
import org.apache.beam.sdk.io.UnboundedSource;
import org.apache.beam.sdk.io.UnboundedSource.UnboundedReader;
import org.apache.beam.sdk.io.kafka.serialization.InstantDeserializer;
import org.apache.beam.sdk.metrics.GaugeResult;
import org.apache.beam.sdk.metrics.MetricName;
import org.apache.beam.sdk.metrics.MetricNameFilter;
import org.apache.beam.sdk.metrics.MetricQueryResults;
import org.apache.beam.sdk.metrics.MetricResult;
import org.apache.beam.sdk.metrics.MetricsFilter;
import org.apache.beam.sdk.metrics.SinkMetrics;
import org.apache.beam.sdk.metrics.SourceMetrics;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.Distinct;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Flatten;
import org.apache.beam.sdk.transforms.Max;
import org.apache.beam.sdk.transforms.Min;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.Values;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionList;
import org.apache.kafka.clients.consumer.Consumer;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.clients.consumer.MockConsumer;
import org.apache.kafka.clients.consumer.OffsetResetStrategy;
import org.apache.kafka.clients.producer.MockProducer;
import org.apache.kafka.clients.producer.Producer;
import org.apache.kafka.clients.producer.ProducerConfig;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.common.PartitionInfo;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.serialization.ByteArrayDeserializer;
import org.apache.kafka.common.serialization.Deserializer;
import org.apache.kafka.common.serialization.IntegerDeserializer;
import org.apache.kafka.common.serialization.IntegerSerializer;
import org.apache.kafka.common.serialization.LongDeserializer;
import org.apache.kafka.common.serialization.LongSerializer;
import org.apache.kafka.common.serialization.Serializer;
import org.apache.kafka.common.serialization.StringDeserializer;
import org.apache.kafka.common.utils.Utils;
import org.hamcrest.collection.IsIterableContainingInAnyOrder;
import org.hamcrest.collection.IsIterableWithSize;
import org.joda.time.Instant;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/**
 * Tests of {@link KafkaIO}.
 * Run with 'mvn test -Dkafka.clients.version=0.10.1.1',
 * or 'mvn test -Dkafka.clients.version=0.9.0.1' for either Kafka client version.
 */
@RunWith(JUnit4.class)
public class KafkaIOTest {

  /*
   * The tests below borrow code and structure from CountingSourceTest. In addition, they verify
   * that the reader interleaves records from multiple partitions.
   *
   * Other tests to consider:
   *   - test KafkaRecordCoder
   */

  @Rule
  public final transient TestPipeline p = TestPipeline.create();

  @Rule
  public ExpectedException thrown = ExpectedException.none();

  // Update mock consumer with records distributed among the given topics, each with given number
  // of partitions. Records are assigned in round-robin order among the partitions.
  private static MockConsumer<byte[], byte[]> mkMockConsumer(
      List<String> topics, int partitionsPerTopic, int numElements,
      OffsetResetStrategy offsetResetStrategy) {

    final List<TopicPartition> partitions = new ArrayList<>();
    final Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> records = new HashMap<>();
    Map<String, List<PartitionInfo>> partitionMap = new HashMap<>();

    for (String topic : topics) {
      List<PartitionInfo> partIds = new ArrayList<>(partitionsPerTopic);
      for (int i = 0; i < partitionsPerTopic; i++) {
        TopicPartition tp = new TopicPartition(topic, i);
        partitions.add(tp);
        partIds.add(new PartitionInfo(topic, i, null, null, null));
        records.put(tp, new ArrayList<ConsumerRecord<byte[], byte[]>>());
      }
      partitionMap.put(topic, partIds);
    }

    int numPartitions = partitions.size();
    long[] offsets = new long[numPartitions];

    for (int i = 0; i < numElements; i++) {
      int pIdx = i % numPartitions;
      TopicPartition tp = partitions.get(pIdx);

      records.get(tp).add(
          new ConsumerRecord<>(
              tp.topic(),
              tp.partition(),
              offsets[pIdx]++,
              ByteBuffer.wrap(new byte[4]).putInt(i).array(),    // key is 4 byte record id
              ByteBuffer.wrap(new byte[8]).putLong(i).array())); // value is 8 byte record id
    }

    // This is updated when the reader assigns partitions.
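    // (Note: the overridden assign() below records the assignment here so that the poll task
    // scheduled further down knows which partitions to enqueue records for, and so that
    // beginning/end offsets can be set for each assigned partition.)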
    final AtomicReference<List<TopicPartition>> assignedPartitions =
        new AtomicReference<>(Collections.<TopicPartition>emptyList());

    final MockConsumer<byte[], byte[]> consumer =
        new MockConsumer<byte[], byte[]>(offsetResetStrategy) {
          // override assign() in order to set offset limits & to save assigned partitions.
          // The '@Override' keyword is intentionally removed here so that this works with both
          // Kafka client 0.9 and 0.10:
          //   1. SpEL can find this function whether the input is a List or a Collection;
          //   2. List extends Collection, so super.assign() could resolve to either assign(List)
          //      or assign(Collection).
          public void assign(final List<TopicPartition> assigned) {
            super.assign(assigned);
            assignedPartitions.set(ImmutableList.copyOf(assigned));
            for (TopicPartition tp : assigned) {
              updateBeginningOffsets(ImmutableMap.of(tp, 0L));
              updateEndOffsets(ImmutableMap.of(tp, (long) records.get(tp).size()));
            }
          }
        };

    for (String topic : topics) {
      consumer.updatePartitions(topic, partitionMap.get(topic));
    }

    // MockConsumer does not maintain any relationship between partition seek position and the
    // records added. e.g. if we add 10 records to a partition and then seek to end of the
    // partition, MockConsumer is still going to return the 10 records in next poll. It is
    // our responsibility to make sure currently enqueued records sync with partition offsets.
    // The following task will be called inside each invocation to MockConsumer.poll().
    // We enqueue only the records with the offset >= partition's current position.
    Runnable recordEnqueueTask = new Runnable() {
      @Override
      public void run() {
        // add all the records with offset >= current partition position.
        for (TopicPartition tp : assignedPartitions.get()) {
          long curPos = consumer.position(tp);
          for (ConsumerRecord<byte[], byte[]> r : records.get(tp)) {
            if (r.offset() >= curPos) {
              consumer.addRecord(r);
            }
          }
        }
        consumer.schedulePollTask(this);
      }
    };

    consumer.schedulePollTask(recordEnqueueTask);

    return consumer;
  }

  private static class ConsumerFactoryFn
      implements SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>> {
    private final List<String> topics;
    private final int partitionsPerTopic;
    private final int numElements;
    private final OffsetResetStrategy offsetResetStrategy;

    public ConsumerFactoryFn(List<String> topics,
                             int partitionsPerTopic,
                             int numElements,
                             OffsetResetStrategy offsetResetStrategy) {
      this.topics = topics;
      this.partitionsPerTopic = partitionsPerTopic;
      this.numElements = numElements;
      this.offsetResetStrategy = offsetResetStrategy;
    }

    @Override
    public Consumer<byte[], byte[]> apply(Map<String, Object> config) {
      return mkMockConsumer(topics, partitionsPerTopic, numElements, offsetResetStrategy);
    }
  }

  /**
   * Creates a consumer with two topics, with 10 partitions each.
   * numElements records are (round-robin) assigned across all 20 partitions.
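   *
   * <p>Typical usage in the tests below (mirrors testUnboundedSource):
   * <pre>{@code
   * PCollection<Long> input = p
   *     .apply(mkKafkaReadTransform(numElements, new ValueAsTimestampFn()).withoutMetadata())
   *     .apply(Values.<Long>create());
   * }</pre>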
   */
  private static KafkaIO.Read<Integer, Long> mkKafkaReadTransform(
      int numElements,
      @Nullable SerializableFunction<KV<Integer, Long>, Instant> timestampFn) {

    List<String> topics = ImmutableList.of("topic_a", "topic_b");

    KafkaIO.Read<Integer, Long> reader = KafkaIO.<Integer, Long>read()
        .withBootstrapServers("myServer1:9092,myServer2:9092")
        .withTopics(topics)
        .withConsumerFactoryFn(new ConsumerFactoryFn(
            topics, 10, numElements, OffsetResetStrategy.EARLIEST)) // 20 partitions
        .withKeyDeserializer(IntegerDeserializer.class)
        .withValueDeserializer(LongDeserializer.class)
        .withMaxNumRecords(numElements);

    if (timestampFn != null) {
      return reader.withTimestampFn(timestampFn);
    } else {
      return reader;
    }
  }

  private static class AssertMultipleOf implements SerializableFunction<Iterable<Long>, Void> {
    private final int num;

    public AssertMultipleOf(int num) {
      this.num = num;
    }

    @Override
    public Void apply(Iterable<Long> values) {
      for (Long v : values) {
        assertEquals(0, v % num);
      }
      return null;
    }
  }

  public static void addCountingAsserts(PCollection<Long> input, long numElements) {
    // Count == numElements
    PAssert
        .thatSingleton(input.apply("Count", Count.<Long>globally()))
        .isEqualTo(numElements);
    // Unique count == numElements
    PAssert
        .thatSingleton(input.apply(Distinct.<Long>create())
            .apply("UniqueCount", Count.<Long>globally()))
        .isEqualTo(numElements);
    // Min == 0
    PAssert
        .thatSingleton(input.apply("Min", Min.<Long>globally()))
        .isEqualTo(0L);
    // Max == numElements-1
    PAssert
        .thatSingleton(input.apply("Max", Max.<Long>globally()))
        .isEqualTo(numElements - 1);
  }

  @Test
  public void testUnboundedSource() {
    int numElements = 1000;

    PCollection<Long> input = p
        .apply(mkKafkaReadTransform(numElements, new ValueAsTimestampFn())
            .withoutMetadata())
        .apply(Values.<Long>create());

    addCountingAsserts(input, numElements);
    p.run();
  }

  @Test
  public void testUnboundedSourceWithSingleTopic() {
    // same as testUnboundedSource, but with single topic

    int numElements = 1000;
    String topic = "my_topic";

    KafkaIO.Read<Integer, Long> reader = KafkaIO.<Integer, Long>read()
        .withBootstrapServers("none")
        .withTopic("my_topic")
        .withConsumerFactoryFn(new ConsumerFactoryFn(
            ImmutableList.of(topic), 10, numElements, OffsetResetStrategy.EARLIEST))
        .withMaxNumRecords(numElements)
        .withKeyDeserializer(IntegerDeserializer.class)
        .withValueDeserializer(LongDeserializer.class);

    PCollection<Long> input = p
        .apply(reader.withoutMetadata())
        .apply(Values.<Long>create());

    addCountingAsserts(input, numElements);
    p.run();
  }

  @Test
  public void testUnboundedSourceWithExplicitPartitions() {
    int numElements = 1000;

    List<String> topics = ImmutableList.of("test");

    KafkaIO.Read<byte[], Long> reader = KafkaIO.<byte[], Long>read()
        .withBootstrapServers("none")
        .withTopicPartitions(ImmutableList.of(new TopicPartition("test", 5)))
        .withConsumerFactoryFn(new ConsumerFactoryFn(
            topics, 10, numElements, OffsetResetStrategy.EARLIEST)) // 10 partitions
        .withKeyDeserializer(ByteArrayDeserializer.class)
        .withValueDeserializer(LongDeserializer.class)
        .withMaxNumRecords(numElements / 10);

    PCollection<Long> input = p
        .apply(reader.withoutMetadata())
        .apply(Values.<Long>create());

    // assert that every element is a multiple of 5.
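    // (Only partition 5 of the 10-partition "test" topic is read; round-robin assignment puts
    // record ids 5, 15, 25, ... on that partition, i.e. numElements / 10 values, all of them
    // multiples of 5.)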
    PAssert
        .that(input)
        .satisfies(new AssertMultipleOf(5));

    PAssert
        .thatSingleton(input.apply(Count.<Long>globally()))
        .isEqualTo(numElements / 10L);

    p.run();
  }

  private static class ElementValueDiff extends DoFn<Long, Long> {
    @ProcessElement
    public void processElement(ProcessContext c) throws Exception {
      c.output(c.element() - c.timestamp().getMillis());
    }
  }

  @Test
  public void testUnboundedSourceTimestamps() {
    int numElements = 1000;

    PCollection<Long> input = p
        .apply(mkKafkaReadTransform(numElements, new ValueAsTimestampFn()).withoutMetadata())
        .apply(Values.<Long>create());

    addCountingAsserts(input, numElements);

    PCollection<Long> diffs = input
        .apply("TimestampDiff", ParDo.of(new ElementValueDiff()))
        .apply("DistinctTimestamps", Distinct.<Long>create());
    // This assert also confirms that diffs only has one unique value.
    PAssert.thatSingleton(diffs).isEqualTo(0L);

    p.run();
  }

  private static class RemoveKafkaMetadata<K, V> extends DoFn<KafkaRecord<K, V>, KV<K, V>> {
    @ProcessElement
    public void processElement(ProcessContext ctx) throws Exception {
      ctx.output(ctx.element().getKV());
    }
  }

  @Test
  public void testUnboundedSourceSplits() throws Exception {
    int numElements = 1000;
    int numSplits = 10;

    // Coders must be specified explicitly here due to the way the transform
    // is used in the test.
    UnboundedSource<KafkaRecord<Integer, Long>, ?> initial =
        mkKafkaReadTransform(numElements, null)
            .withKeyDeserializerAndCoder(IntegerDeserializer.class, BigEndianIntegerCoder.of())
            .withValueDeserializerAndCoder(LongDeserializer.class, BigEndianLongCoder.of())
            .makeSource();

    List<? extends UnboundedSource<KafkaRecord<Integer, Long>, ?>> splits =
        initial.split(numSplits, p.getOptions());
    assertEquals("Expected exact splitting", numSplits, splits.size());

    long elementsPerSplit = numElements / numSplits;
    assertEquals("Expected even splits", numElements, elementsPerSplit * numSplits);
    PCollectionList<Long> pcollections = PCollectionList.empty(p);
    for (int i = 0; i < splits.size(); ++i) {
      pcollections = pcollections.and(
          p.apply("split" + i, Read.from(splits.get(i)).withMaxNumRecords(elementsPerSplit))
           .apply("Remove Metadata " + i, ParDo.of(new RemoveKafkaMetadata<Integer, Long>()))
           .apply("collection " + i, Values.<Long>create()));
    }
    PCollection<Long> input = pcollections.apply(Flatten.<Long>pCollections());

    addCountingAsserts(input, numElements);
    p.run();
  }

  /**
   * A timestamp function that uses the given value as the timestamp.
   */
  private static class ValueAsTimestampFn
      implements SerializableFunction<KV<Integer, Long>, Instant> {
    @Override
    public Instant apply(KV<Integer, Long> input) {
      return new Instant(input.getValue());
    }
  }

  // Kafka records are read in a separate thread inside the reader. As a result advance() might not
  // read any records even from the mock consumer, especially for the first record.
  // This is a helper method to loop until we read a record.
  private static void advanceOnce(UnboundedReader<?> reader, boolean isStarted)
      throws IOException {
    if (!isStarted && reader.start()) {
      return;
    }
    while (!reader.advance()) {
      // very rarely will there be more than one attempt.
      // In case of a bug we might end up looping forever, and the test will fail with a timeout.
      // Avoid hard cpu spinning in case of a test failure.
      try {
        Thread.sleep(1);
      } catch (InterruptedException e) {
        throw new RuntimeException(e);
      }
    }
  }

  @Test
  public void testUnboundedSourceCheckpointMark() throws Exception {
    int numElements = 85; // 85 to make sure some partitions have more records than others.
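    // (With 20 partitions and round-robin assignment, the first 5 partitions get 5 records each
    // and the remaining 15 get 4 each.)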
    // create a single split:
    UnboundedSource<KafkaRecord<Integer, Long>, KafkaCheckpointMark> source =
        mkKafkaReadTransform(numElements, new ValueAsTimestampFn())
            .makeSource()
            .split(1, PipelineOptionsFactory.create())
            .get(0);

    UnboundedReader<KafkaRecord<Integer, Long>> reader = source.createReader(null, null);
    final int numToSkip = 20; // one from each partition.

    // advance numToSkip elements
    for (int i = 0; i < numToSkip; ++i) {
      advanceOnce(reader, i > 0);
    }

    // Confirm that we get the expected element in sequence before checkpointing.
    assertEquals(numToSkip - 1, (long) reader.getCurrent().getKV().getValue());
    assertEquals(numToSkip - 1, reader.getCurrentTimestamp().getMillis());

    // Checkpoint and restart, and confirm that the source continues correctly.
    KafkaCheckpointMark mark = CoderUtils.clone(
        source.getCheckpointMarkCoder(), (KafkaCheckpointMark) reader.getCheckpointMark());
    reader = source.createReader(null, mark);

    // Confirm that we get the next elements in sequence.
    // This also confirms that the reader interleaves records from all the partitions.
    for (int i = numToSkip; i < numElements; i++) {
      advanceOnce(reader, i > numToSkip);
      assertEquals(i, (long) reader.getCurrent().getKV().getValue());
      assertEquals(i, reader.getCurrentTimestamp().getMillis());
    }
  }

  @Test
  public void testUnboundedSourceCheckpointMarkWithEmptyPartitions() throws Exception {
    // Similar to testUnboundedSourceCheckpointMark(), but verifies that the source resumes
    // properly from empty partitions, without missing messages added since the checkpoint.

    // Initialize consumer with fewer elements than number of partitions so that some are empty.
    int initialNumElements = 5;
    UnboundedSource<KafkaRecord<Integer, Long>, KafkaCheckpointMark> source =
        mkKafkaReadTransform(initialNumElements, new ValueAsTimestampFn())
            .makeSource()
            .split(1, PipelineOptionsFactory.create())
            .get(0);

    UnboundedReader<KafkaRecord<Integer, Long>> reader = source.createReader(null, null);

    for (int l = 0; l < initialNumElements; ++l) {
      advanceOnce(reader, l > 0);
    }

    // Checkpoint and restart, and confirm that the source continues correctly.
    KafkaCheckpointMark mark = CoderUtils.clone(
        source.getCheckpointMarkCoder(), (KafkaCheckpointMark) reader.getCheckpointMark());

    // Create another source with MockConsumer with OffsetResetStrategy.LATEST. This ensures that
    // the reader needs to explicitly seek to the first offset for partitions that were empty.

    int numElements = 100; // all the 20 partitions will have elements
    List<String> topics = ImmutableList.of("topic_a", "topic_b");

    source = KafkaIO.<Integer, Long>read()
        .withBootstrapServers("none")
        .withTopics(topics)
        .withConsumerFactoryFn(new ConsumerFactoryFn(
            topics, 10, numElements, OffsetResetStrategy.LATEST))
        .withKeyDeserializer(IntegerDeserializer.class)
        .withValueDeserializer(LongDeserializer.class)
        .withMaxNumRecords(numElements)
        .withTimestampFn(new ValueAsTimestampFn())
        .makeSource()
        .split(1, PipelineOptionsFactory.create())
        .get(0);

    reader = source.createReader(null, mark);

    // Verify in any order. As the partitions are unevenly read, the returned records are not in a
    // simple order. Note that testUnboundedSourceCheckpointMark() verifies round-robin order.
    List<Long> expected = new ArrayList<>();
    List<Long> actual = new ArrayList<>();
    for (long i = initialNumElements; i < numElements; i++) {
      advanceOnce(reader, i > initialNumElements);
      expected.add(i);
      actual.add(reader.getCurrent().getKV().getValue());
    }
    assertThat(actual, IsIterableContainingInAnyOrder.containsInAnyOrder(expected.toArray()));
  }

  @Test
  public void testUnboundedSourceMetrics() {
    int numElements = 1000;

    String readStep = "readFromKafka";

    p.apply(readStep,
        mkKafkaReadTransform(numElements, new ValueAsTimestampFn()).withoutMetadata());

    PipelineResult result = p.run();

    String splitId = "0";

    MetricName elementsRead = SourceMetrics.elementsRead().getName();
    MetricName elementsReadBySplit = SourceMetrics.elementsReadBySplit(splitId).getName();
    MetricName bytesRead = SourceMetrics.bytesRead().getName();
    MetricName bytesReadBySplit = SourceMetrics.bytesReadBySplit(splitId).getName();
    MetricName backlogElementsOfSplit = SourceMetrics.backlogElementsOfSplit(splitId).getName();
    MetricName backlogBytesOfSplit = SourceMetrics.backlogBytesOfSplit(splitId).getName();

    MetricQueryResults metrics = result.metrics().queryMetrics(
        MetricsFilter.builder().build());

    Iterable<MetricResult<Long>> counters = metrics.counters();

    assertThat(counters, hasItem(attemptedMetricsResult(
        elementsRead.namespace(), elementsRead.name(), readStep, 1000L)));

    assertThat(counters, hasItem(attemptedMetricsResult(
        elementsReadBySplit.namespace(), elementsReadBySplit.name(), readStep, 1000L)));

    assertThat(counters, hasItem(attemptedMetricsResult(
        bytesRead.namespace(), bytesRead.name(), readStep, 12000L)));

    assertThat(counters, hasItem(attemptedMetricsResult(
        bytesReadBySplit.namespace(), bytesReadBySplit.name(), readStep, 12000L)));

    MetricQueryResults backlogElementsMetrics =
        result.metrics().queryMetrics(
            MetricsFilter.builder()
                .addNameFilter(
                    MetricNameFilter.named(
                        backlogElementsOfSplit.namespace(),
                        backlogElementsOfSplit.name()))
                .build());

    // since gauge values may be inconsistent in some environments, assert only on their existence.
    assertThat(backlogElementsMetrics.gauges(),
        IsIterableWithSize.<MetricResult<GaugeResult>>iterableWithSize(1));

    MetricQueryResults backlogBytesMetrics =
        result.metrics().queryMetrics(
            MetricsFilter.builder()
                .addNameFilter(
                    MetricNameFilter.named(
                        backlogBytesOfSplit.namespace(),
                        backlogBytesOfSplit.name()))
                .build());

    // since gauge values may be inconsistent in some environments, assert only on their existence.
    assertThat(backlogBytesMetrics.gauges(),
        IsIterableWithSize.<MetricResult<GaugeResult>>iterableWithSize(1));
  }

  @Test
  public void testSink() throws Exception {
    // Simply read from kafka source and write to kafka sink. Then verify the records
    // are correctly published to the mock kafka producer.

    int numElements = 1000;

    synchronized (MOCK_PRODUCER_LOCK) {

      MOCK_PRODUCER.clear();

      ProducerSendCompletionThread completionThread = new ProducerSendCompletionThread().start();

      String topic = "test";

      p
          .apply(mkKafkaReadTransform(numElements, new ValueAsTimestampFn())
              .withoutMetadata())
          .apply(KafkaIO.<Integer, Long>write()
              .withBootstrapServers("none")
              .withTopic(topic)
              .withKeySerializer(IntegerSerializer.class)
              .withValueSerializer(LongSerializer.class)
              .withProducerFactoryFn(new ProducerFactoryFn()));

      p.run();

      completionThread.shutdown();

      verifyProducerRecords(topic, numElements, false);
    }
  }

  @Test
  public void testValuesSink() throws Exception {
    // similar to testSink(), but uses the values() interface.
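    // With values(), no key serializer is configured and only values are written: the records
    // published to the mock producer carry null keys, which verifyProducerRecords(...,
    // keyIsAbsent = true) checks below.
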
    int numElements = 1000;

    synchronized (MOCK_PRODUCER_LOCK) {

      MOCK_PRODUCER.clear();

      ProducerSendCompletionThread completionThread = new ProducerSendCompletionThread().start();

      String topic = "test";

      p
          .apply(mkKafkaReadTransform(numElements, new ValueAsTimestampFn())
              .withoutMetadata())
          .apply(Values.<Long>create()) // there are no keys
          .apply(KafkaIO.<Integer, Long>write()
              .withBootstrapServers("none")
              .withTopic(topic)
              .withValueSerializer(LongSerializer.class)
              .withProducerFactoryFn(new ProducerFactoryFn())
              .values());

      p.run();

      completionThread.shutdown();

      verifyProducerRecords(topic, numElements, true);
    }
  }

  @Test
  public void testSinkWithSendErrors() throws Throwable {
    // similar to testSink(), except that up to 10 of the send calls to the producer will fail
    // asynchronously.

    // TODO: Ideally we want the pipeline to run to completion by retrying bundles that fail.
    // We limit the number of errors injected to 10 below. This would reflect a real streaming
    // pipeline. But I am not sure how to achieve that. For now expect an exception:

    thrown.expect(InjectedErrorException.class);
    thrown.expectMessage("Injected Error #1");

    int numElements = 1000;

    synchronized (MOCK_PRODUCER_LOCK) {

      MOCK_PRODUCER.clear();

      String topic = "test";

      ProducerSendCompletionThread completionThreadWithErrors =
          new ProducerSendCompletionThread(10, 100).start();

      p
          .apply(mkKafkaReadTransform(numElements, new ValueAsTimestampFn())
              .withoutMetadata())
          .apply(KafkaIO.<Integer, Long>write()
              .withBootstrapServers("none")
              .withTopic(topic)
              .withKeySerializer(IntegerSerializer.class)
              .withValueSerializer(LongSerializer.class)
              .withProducerFactoryFn(new ProducerFactoryFn()));

      try {
        p.run();
      } catch (PipelineExecutionException e) {
        // throwing the inner exception helps assert that the first exception is thrown from
        // the Sink
        throw e.getCause().getCause();
      } finally {
        completionThreadWithErrors.shutdown();
      }
    }
  }

  @Test
  public void testSourceDisplayData() {
    KafkaIO.Read<Integer, Long> read = mkKafkaReadTransform(10, null);

    DisplayData displayData = DisplayData.from(read);

    assertThat(displayData, hasDisplayItem("topics", "topic_a,topic_b"));
    assertThat(displayData, hasDisplayItem("enable.auto.commit", false));
    assertThat(displayData, hasDisplayItem("bootstrap.servers", "myServer1:9092,myServer2:9092"));
    assertThat(displayData, hasDisplayItem("auto.offset.reset", "latest"));
    assertThat(displayData, hasDisplayItem("receive.buffer.bytes", 524288));
  }

  @Test
  public void testSourceWithExplicitPartitionsDisplayData() {
    KafkaIO.Read<byte[], Long> read = KafkaIO.<byte[], Long>read()
        .withBootstrapServers("myServer1:9092,myServer2:9092")
        .withTopicPartitions(ImmutableList.of(new TopicPartition("test", 5),
            new TopicPartition("test", 6)))
        .withConsumerFactoryFn(new ConsumerFactoryFn(
            Lists.newArrayList("test"), 10, 10, OffsetResetStrategy.EARLIEST)) // 10 partitions
        .withKeyDeserializer(ByteArrayDeserializer.class)
        .withValueDeserializer(LongDeserializer.class);

    DisplayData displayData = DisplayData.from(read);

    assertThat(displayData, hasDisplayItem("topicPartitions", "test-5,test-6"));
    assertThat(displayData, hasDisplayItem("enable.auto.commit", false));
    assertThat(displayData, hasDisplayItem("bootstrap.servers", "myServer1:9092,myServer2:9092"));
    assertThat(displayData, hasDisplayItem("auto.offset.reset", "latest"));
    assertThat(displayData, hasDisplayItem("receive.buffer.bytes", 524288));
  }

  @Test
  public void testSinkDisplayData() {
    KafkaIO.Write<Integer, Long> write = KafkaIO.<Integer, Long>write()
        .withBootstrapServers("myServerA:9092,myServerB:9092")
        .withTopic("myTopic")
        .withValueSerializer(LongSerializer.class)
        .withProducerFactoryFn(new ProducerFactoryFn());

    DisplayData displayData = DisplayData.from(write);

    assertThat(displayData, hasDisplayItem("topic", "myTopic"));
    assertThat(displayData, hasDisplayItem("bootstrap.servers", "myServerA:9092,myServerB:9092"));
    assertThat(displayData, hasDisplayItem("retries", 3));
  }

  // interface for testing coder inference
  private interface DummyInterface<T> {
  }

  // interface for testing coder inference
  private interface DummyNonparametricInterface {
  }

  // class for testing coder inference
  private static class DeserializerWithInterfaces
      implements DummyInterface<String>, DummyNonparametricInterface,
          Deserializer<Long> {

    @Override
    public void configure(Map<String, ?> configs, boolean isKey) {
    }

    @Override
    public Long deserialize(String topic, byte[] bytes) {
      return 0L;
    }

    @Override
    public void close() {
    }
  }

  // class for which a coder cannot be inferred
  private static class NonInferableObject {
  }

  // class for testing coder inference
  private static class NonInferableObjectDeserializer
      implements Deserializer<NonInferableObject> {

    @Override
    public void configure(Map<String, ?> configs, boolean isKey) {
    }

    @Override
    public NonInferableObject deserialize(String topic, byte[] bytes) {
      return new NonInferableObject();
    }

    @Override
    public void close() {
    }
  }

  @Test
  public void testInferKeyCoder() {
    CoderRegistry registry = CoderRegistry.createDefault();

    assertTrue(KafkaIO.inferCoder(registry, LongDeserializer.class).getValueCoder()
        instanceof VarLongCoder);

    assertTrue(KafkaIO.inferCoder(registry, StringDeserializer.class).getValueCoder()
        instanceof StringUtf8Coder);

    assertTrue(KafkaIO.inferCoder(registry, InstantDeserializer.class).getValueCoder()
        instanceof InstantCoder);

    assertTrue(KafkaIO.inferCoder(registry, DeserializerWithInterfaces.class).getValueCoder()
        instanceof VarLongCoder);
  }

  @Rule public ExpectedException cannotInferException = ExpectedException.none();

  @Test
  public void testInferKeyCoderFailure() throws Exception {
    cannotInferException.expect(RuntimeException.class);

    CoderRegistry registry = CoderRegistry.createDefault();
    KafkaIO.inferCoder(registry, NonInferableObjectDeserializer.class);
  }

  @Test
  public void testSinkMetrics() throws Exception {
    // Simply read from kafka source and write to kafka sink. Then verify the metrics are reported.
    int numElements = 1000;

    synchronized (MOCK_PRODUCER_LOCK) {

      MOCK_PRODUCER.clear();

      ProducerSendCompletionThread completionThread = new ProducerSendCompletionThread().start();

      String topic = "test";

      p
          .apply(mkKafkaReadTransform(numElements, new ValueAsTimestampFn())
              .withoutMetadata())
          .apply("writeToKafka", KafkaIO.<Integer, Long>write()
              .withBootstrapServers("none")
              .withTopic(topic)
              .withKeySerializer(IntegerSerializer.class)
              .withValueSerializer(LongSerializer.class)
              .withProducerFactoryFn(new ProducerFactoryFn()));

      PipelineResult result = p.run();

      MetricName elementsWritten = SinkMetrics.elementsWritten().getName();

      MetricQueryResults metrics = result.metrics().queryMetrics(
          MetricsFilter.builder()
              .addNameFilter(MetricNameFilter.inNamespace(elementsWritten.namespace()))
              .build());

      assertThat(metrics.counters(), hasItem(
          attemptedMetricsResult(
              elementsWritten.namespace(), elementsWritten.name(), "writeToKafka", 1000L)));

      completionThread.shutdown();
    }
  }

  private static void verifyProducerRecords(String topic, int numElements, boolean keyIsAbsent) {

    // verify that appropriate messages are written to kafka
    List<ProducerRecord<Integer, Long>> sent = MOCK_PRODUCER.history();

    // sort by values
    Collections.sort(sent, new Comparator<ProducerRecord<Integer, Long>>() {
      @Override
      public int compare(ProducerRecord<Integer, Long> o1, ProducerRecord<Integer, Long> o2) {
        return Long.compare(o1.value(), o2.value());
      }
    });

    for (int i = 0; i < numElements; i++) {
      ProducerRecord<Integer, Long> record = sent.get(i);
      assertEquals(topic, record.topic());
      if (keyIsAbsent) {
        assertNull(record.key());
      } else {
        assertEquals(i, record.key().intValue());
      }
      assertEquals(i, record.value().longValue());
    }
  }

  /**
   * Singleton MockProducer. Using a singleton here since we need access to the object to fetch
   * the actual records published to the producer. This prohibits running the tests using
   * the producer in parallel, but there are only one or two tests.
   */
  private static final MockProducer<Integer, Long> MOCK_PRODUCER =
      new MockProducer<Integer, Long>(
          false, // disable synchronous completion of send. see ProducerSendCompletionThread below.
          new IntegerSerializer(),
          new LongSerializer()) {

        // override flush() so that it does not complete all the waiting sends, giving a chance to
        // ProducerSendCompletionThread to inject errors.
        @Override
        public void flush() {
          while (completeNext()) {
            // there are some uncompleted records. let the completion thread handle them.
            try {
              Thread.sleep(10);
            } catch (InterruptedException e) {
            }
          }
        }
      };

  // use a separate object to serialize tests using MOCK_PRODUCER so that we don't interfere
  // with Kafka MockProducer locking itself.
  private static final Object MOCK_PRODUCER_LOCK = new Object();

  private static class ProducerFactoryFn
      implements SerializableFunction<Map<String, Object>, Producer<Integer, Long>> {

    @SuppressWarnings("unchecked")
    @Override
    public Producer<Integer, Long> apply(Map<String, Object> config) {

      // Make sure the config is correctly set up for serializers.
      // There may not be a key serializer if we're interested only in values.
      if (config.get(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG) != null) {
        Utils.newInstance(
            ((Class<?>) config.get(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG))
                .asSubclass(Serializer.class)
        ).configure(config, true);
      }

      Utils.newInstance(
          ((Class<?>) config.get(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG))
              .asSubclass(Serializer.class)
      ).configure(config, false);

      return MOCK_PRODUCER;
    }
  }

  private static class InjectedErrorException extends RuntimeException {
    public InjectedErrorException(String message) {
      super(message);
    }
  }

  /**
   * We start MockProducer with auto-completion disabled. That implies a record is not marked sent
   * until #completeNext() is called on it. This class starts a thread to asynchronously 'complete'
   * the sends. During completion, we can also make those requests fail. This error injection
   * is used in one of the tests.
   */
  private static class ProducerSendCompletionThread {

    private final int maxErrors;
    private final int errorFrequency;
    private final AtomicBoolean done = new AtomicBoolean(false);
    private final ExecutorService injectorThread;
    private int numCompletions = 0;

    ProducerSendCompletionThread() {
      // complete everything successfully
      this(0, 0);
    }

    ProducerSendCompletionThread(final int maxErrors, final int errorFrequency) {
      this.maxErrors = maxErrors;
      this.errorFrequency = errorFrequency;
      injectorThread = Executors.newSingleThreadExecutor();
    }

    ProducerSendCompletionThread start() {
      injectorThread.submit(new Runnable() {
        @Override
        public void run() {
          int errorsInjected = 0;

          while (!done.get()) {
            boolean successful;

            if (errorsInjected < maxErrors && ((numCompletions + 1) % errorFrequency) == 0) {
              successful = MOCK_PRODUCER.errorNext(
                  new InjectedErrorException("Injected Error #" + (errorsInjected + 1)));

              if (successful) {
                errorsInjected++;
              }
            } else {
              successful = MOCK_PRODUCER.completeNext();
            }

            if (successful) {
              numCompletions++;
            } else {
              // wait a bit since there are no unsent records
              try {
                Thread.sleep(1);
              } catch (InterruptedException e) {
              }
            }
          }
        }
      });

      return this;
    }

    void shutdown() {
      done.set(true);
      injectorThread.shutdown();

      try {
        assertTrue(injectorThread.awaitTermination(10, TimeUnit.SECONDS));
      } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        throw new RuntimeException(e);
      }
    }
  }
}