// Copyright 2016 Google Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // //////////////////////////////////////////////////////////////////////////////// package com.google.pubsub.kafka.sink; import com.google.common.annotations.VisibleForTesting; import com.google.common.util.concurrent.ListenableFuture; import com.google.protobuf.ByteString; import com.google.pubsub.kafka.common.ConnectorUtils; import com.google.pubsub.v1.PublishRequest; import com.google.pubsub.v1.PublishResponse; import com.google.pubsub.v1.PubsubMessage; import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; import org.apache.kafka.clients.consumer.OffsetAndMetadata; import org.apache.kafka.common.TopicPartition; import org.apache.kafka.connect.data.Field; import org.apache.kafka.connect.data.Schema; import org.apache.kafka.connect.data.Schema.Type; import org.apache.kafka.connect.data.Struct; import org.apache.kafka.connect.errors.DataException; import org.apache.kafka.connect.sink.SinkRecord; import org.apache.kafka.connect.sink.SinkTask; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * A {@link SinkTask} used by a {@link CloudPubSubSinkConnector} to write messages to <a * href="https://cloud.google.com/pubsub">Google Cloud Pub/Sub</a>. */ public class CloudPubSubSinkTask extends SinkTask { private static final Logger log = LoggerFactory.getLogger(CloudPubSubSinkTask.class); private static final int NUM_CPS_PUBLISHERS = 10; private static final int CPS_MAX_REQUEST_SIZE = (10 << 20) - 1024; // Leave room for overhead. private static final int CPS_MAX_MESSAGES_PER_REQUEST = 1000; private static final int CPS_MESSAGE_KEY_ATTRIBUTE_SIZE = ConnectorUtils.CPS_MESSAGE_KEY_ATTRIBUTE.length(); // Maps a topic to another map which contains the outstanding futures per partition private Map<String, Map<Integer, OutstandingFuturesForPartition>> allOutstandingFutures = new HashMap<>(); // Maps a topic to another map which contains the unpublished messages per partition private Map<String, Map<Integer, UnpublishedMessagesForPartition>> allUnpublishedMessages = new HashMap<>(); private String cpsTopic; private String messageBodyName; private int maxBufferSize; private CloudPubSubPublisher publisher; /** Holds a list of the publishing futures that have not been processed for a single partition. */ private class OutstandingFuturesForPartition { public List<ListenableFuture<PublishResponse>> futures = new ArrayList<>(); } /** * Holds a list of the unpublished messages for a single partition and the total size in bytes of * the messages in the list. */ private class UnpublishedMessagesForPartition { public List<PubsubMessage> messages = new ArrayList<>(); public int size = 0; } public CloudPubSubSinkTask() {} @VisibleForTesting public CloudPubSubSinkTask(CloudPubSubPublisher publisher) { this.publisher = publisher; } @Override public String version() { return new CloudPubSubSinkConnector().version(); } @Override public void start(Map<String, String> props) { Map<String, Object> validatedProps = new CloudPubSubSinkConnector().config().parse(props); cpsTopic = String.format( ConnectorUtils.CPS_TOPIC_FORMAT, validatedProps.get(ConnectorUtils.CPS_PROJECT_CONFIG), validatedProps.get(ConnectorUtils.CPS_TOPIC_CONFIG)); maxBufferSize = (Integer) validatedProps.get(CloudPubSubSinkConnector.MAX_BUFFER_SIZE_CONFIG); messageBodyName = (String) validatedProps.get(CloudPubSubSinkConnector.CPS_MESSAGE_BODY_NAME); if (publisher == null) { // Only do this if we did not use the constructor. publisher = new CloudPubSubRoundRobinPublisher(NUM_CPS_PUBLISHERS); } log.info("Start CloudPubSubSinkTask"); } @Override public void put(Collection<SinkRecord> sinkRecords) { log.debug("Received " + sinkRecords.size() + " messages to send to CPS."); PubsubMessage.Builder builder = PubsubMessage.newBuilder(); for (SinkRecord record : sinkRecords) { log.trace("Received record: " + record.toString()); Map<String, String> attributes = new HashMap<>(); ByteString value = handleValue(record.valueSchema(), record.value(), attributes); // Get the total number of bytes in this message. int messageSize = value.size(); // Assumes the topic name is in ASCII. if (record.key() != null) { attributes.put(ConnectorUtils.CPS_MESSAGE_KEY_ATTRIBUTE, record.key().toString()); } for (String key : attributes.keySet()) { messageSize+= key.getBytes().length + attributes.get(key).getBytes().length; } PubsubMessage message = builder.setData(value).putAllAttributes(attributes).build(); // Get a map containing all the unpublished messages per partition for this topic. Map<Integer, UnpublishedMessagesForPartition> unpublishedMessagesForTopic = allUnpublishedMessages.get(record.topic()); if (unpublishedMessagesForTopic == null) { unpublishedMessagesForTopic = new HashMap<>(); allUnpublishedMessages.put(record.topic(), unpublishedMessagesForTopic); } // Get the object containing the unpublished messages for the // specific topic and partition this SinkRecord is associated with. UnpublishedMessagesForPartition unpublishedMessages = unpublishedMessagesForTopic.get(record.kafkaPartition()); if (unpublishedMessages == null) { unpublishedMessages = new UnpublishedMessagesForPartition(); unpublishedMessagesForTopic.put(record.kafkaPartition(), unpublishedMessages); } int newUnpublishedSize = unpublishedMessages.size + messageSize; // Publish messages in this partition if the total number of bytes goes over limit. if (newUnpublishedSize > CPS_MAX_REQUEST_SIZE) { publishMessagesForPartition( record.topic(), record.kafkaPartition(), unpublishedMessages.messages); newUnpublishedSize = messageSize; } unpublishedMessages.size = newUnpublishedSize; unpublishedMessages.messages.add(message); // If the number of messages in this partition is greater than the batch size, then publish. if (unpublishedMessages.messages.size() >= maxBufferSize) { publishMessagesForPartition( record.topic(), record.kafkaPartition(), unpublishedMessages.messages); } } } private ByteString handleValue(Schema schema, Object value, Map<String, String> attributes) { if (schema == null) { String str = value.toString(); return ByteString.copyFromUtf8(str); } Schema.Type t = schema.type(); switch (t) { case INT8: byte b = (Byte) value; byte[] arr = {b}; return ByteString.copyFrom(arr); case INT16: ByteBuffer shortBuf = ByteBuffer.allocate(2); shortBuf.putShort((Short) value); return ByteString.copyFrom(shortBuf); case INT32: ByteBuffer intBuf = ByteBuffer.allocate(4); intBuf.putInt((Integer) value); return ByteString.copyFrom(intBuf); case INT64: ByteBuffer longBuf = ByteBuffer.allocate(8); longBuf.putLong((Long) value); return ByteString.copyFrom(longBuf); case FLOAT32: ByteBuffer floatBuf = ByteBuffer.allocate(4); floatBuf.putFloat((Float) value); return ByteString.copyFrom(floatBuf); case FLOAT64: ByteBuffer doubleBuf = ByteBuffer.allocate(8); doubleBuf.putDouble((Double) value); return ByteString.copyFrom(doubleBuf); case BOOLEAN: byte bool = (byte)((Boolean) value?1:0); byte[] boolArr = {bool}; return ByteString.copyFrom(boolArr); case STRING: String str = (String) value; return ByteString.copyFromUtf8(str); case BYTES: if (value instanceof ByteString) { return (ByteString) value; } else if (value instanceof byte[]) { return ByteString.copyFrom((byte[]) value); } else if (value instanceof ByteBuffer) { return ByteString.copyFrom((ByteBuffer) value); } else { throw new DataException("Unexpected value class with BYTES schema type."); } case STRUCT: Struct struct = (Struct) value; ByteString msgBody = null; for (Field f : schema.fields()) { Object val = struct.get(f); if (val == null) { throw new DataException("Struct message body does not support Map or Struct types."); } if (f.name().equals(messageBodyName)) { Schema bodySchema = f.schema(); msgBody = handleValue(bodySchema, val, null); } else { f.name(); attributes.put(f.name(), val.toString()); } } if (msgBody != null) { return msgBody; } else { return ByteString.EMPTY; } case MAP: Map<Object, Object> map = (Map<Object, Object>) value; Set<Object> keys = map.keySet(); ByteString mapBody = null; for (Object key : keys) { if (key.equals(messageBodyName)) { mapBody = ByteString.copyFromUtf8(map.get(key).toString()); } else { attributes.put(key.toString(), map.get(key).toString()); } } if (mapBody != null) { return mapBody; } else { return ByteString.EMPTY; } case ARRAY: Schema.Type arrType = schema.valueSchema().type(); if (arrType == Type.MAP || arrType == Type.STRUCT) { throw new DataException("Array type does not support Map or Struct types."); } ByteString out = ByteString.EMPTY; Object[] objArr = (Object[]) value; for (Object o : objArr) { out = out.concat(handleValue(schema.valueSchema(), o, null)); } return out; } return ByteString.EMPTY; } @Override public void flush(Map<TopicPartition, OffsetAndMetadata> partitionOffsets) { log.debug("Flushing..."); // Publish all messages that have not been published yet. for (Map.Entry<String, Map<Integer, UnpublishedMessagesForPartition>> entry : allUnpublishedMessages.entrySet()) { for (Map.Entry<Integer, UnpublishedMessagesForPartition> innerEntry : entry.getValue().entrySet()) { publishMessagesForPartition( entry.getKey(), innerEntry.getKey(), innerEntry.getValue().messages); } } allUnpublishedMessages.clear(); // Process results of all the outstanding futures specified by each TopicPartition. for (Map.Entry<TopicPartition, OffsetAndMetadata> partitionOffset : partitionOffsets.entrySet()) { log.trace("Received flush for partition " + partitionOffset.getKey().toString()); Map<Integer, OutstandingFuturesForPartition> outstandingFuturesForTopic = allOutstandingFutures.get(partitionOffset.getKey().topic()); if (outstandingFuturesForTopic == null) { continue; } OutstandingFuturesForPartition outstandingFutures = outstandingFuturesForTopic.get(partitionOffset.getKey().partition()); if (outstandingFutures == null) { continue; } try { for (ListenableFuture<PublishResponse> publishFuture : outstandingFutures.futures) { publishFuture.get(); } } catch (Exception e) { throw new RuntimeException(e); } } allOutstandingFutures.clear(); } /** Publish all the messages in a partition and store the Future's for each publish request. */ private void publishMessagesForPartition( String topic, Integer partition, List<PubsubMessage> messages) { // Get a map containing all futures per partition for the passed in topic. Map<Integer, OutstandingFuturesForPartition> outstandingFuturesForTopic = allOutstandingFutures.get(topic); if (outstandingFuturesForTopic == null) { outstandingFuturesForTopic = new HashMap<>(); allOutstandingFutures.put(topic, outstandingFuturesForTopic); } // Get the object containing the outstanding futures for this topic and partition.. OutstandingFuturesForPartition outstandingFutures = outstandingFuturesForTopic.get(partition); if (outstandingFutures == null) { outstandingFutures = new OutstandingFuturesForPartition(); outstandingFuturesForTopic.put(partition, outstandingFutures); } int startIndex = 0; int endIndex = Math.min(CPS_MAX_MESSAGES_PER_REQUEST, messages.size()); PublishRequest.Builder builder = PublishRequest.newBuilder(); // Publish all the messages for this partition in batches. while (startIndex < messages.size()) { PublishRequest request = builder.setTopic(cpsTopic).addAllMessages(messages.subList(startIndex, endIndex)).build(); builder.clear(); log.trace("Publishing: " + (endIndex - startIndex) + " messages"); outstandingFutures.futures.add(publisher.publish(request)); startIndex = endIndex; endIndex = Math.min(endIndex + CPS_MAX_MESSAGES_PER_REQUEST, messages.size()); } messages.clear(); } @Override public void stop() {} }