/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.qubole.presto.kinesis;

import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClient;
import com.amazonaws.services.kinesis.model.GetRecordsRequest;
import com.amazonaws.services.kinesis.model.GetRecordsResult;
import com.amazonaws.services.kinesis.model.GetShardIteratorRequest;
import com.amazonaws.services.kinesis.model.GetShardIteratorResult;
import com.amazonaws.services.kinesis.model.Record;
import com.amazonaws.services.kinesis.model.ResourceNotFoundException;
import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.RecordCursor;
import com.facebook.presto.spi.RecordSet;
import com.facebook.presto.spi.type.Type;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.qubole.presto.kinesis.decoder.KinesisFieldDecoder;
import com.qubole.presto.kinesis.decoder.KinesisRowDecoder;
import io.airlift.log.Logger;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Date;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;

/**
 * A {@link RecordSet} over one Kinesis shard split.
 *
 * <p>Rows are fetched lazily through {@link KinesisRecordCursor}, which pulls batches of
 * records from Kinesis using a shard iterator and decodes each record with the configured
 * {@link KinesisRowDecoder}. When checkpointing is enabled, read progress (the last read
 * sequence number) is persisted to DynamoDB via {@link KinesisShardCheckpointer} so a later
 * query can resume after the last checkpointed record.
 */
public class KinesisRecordSet
        implements RecordSet
{
    /** Indicates how close to current we want to be before stopping the fetch of records in a query. */
    public static final int MILLIS_BEHIND_LIMIT = 10000;

    private static final Logger log = Logger.get(KinesisRecordSet.class);

    private static final byte[] EMPTY_BYTE_ARRAY = new byte[0];

    private final KinesisSplit split;
    private final ConnectorSession session;
    private final KinesisClientProvider clientManager;
    private final KinesisConnectorConfig kinesisConnectorConfig;

    private final KinesisRowDecoder messageDecoder;
    private final Map<KinesisColumnHandle, KinesisFieldDecoder<?>> messageFieldDecoders;

    private final List<KinesisColumnHandle> columnHandles;
    private final List<Type> columnTypes;

    // Batch/fetch tuning: batch size and max batches may be overridden per-session,
    // fetch attempts and sleep time come from the connector configuration only.
    private final int batchSize;
    private final int maxBatches;
    private final int fetchAttempts;
    private final long sleepTime;

    // for checkpointing
    private final boolean checkpointEnabled;
    private String lastReadSeqNo;
    private KinesisShardCheckpointer kinesisShardCheckpointer;

    // Internal field values that are constant for every row of this split (shard id, segment start).
    private final Set<KinesisFieldValueProvider> globalInternalFieldValueProviders;

    KinesisRecordSet(KinesisSplit split,
            ConnectorSession session,
            KinesisClientProvider clientManager,
            List<KinesisColumnHandle> columnHandles,
            KinesisRowDecoder messageDecoder,
            Map<KinesisColumnHandle, KinesisFieldDecoder<?>> messageFieldDecoders,
            KinesisConnectorConfig kinesisConnectorConfig)
    {
        this.split = checkNotNull(split, "split is null");
        this.session = checkNotNull(session, "session is null");
        this.kinesisConnectorConfig = checkNotNull(kinesisConnectorConfig, "KinesisConnectorConfig is null");

        // Shard id and segment start are ASCII; pin the charset so the bytes are
        // stable regardless of the JVM's platform default encoding.
        this.globalInternalFieldValueProviders = ImmutableSet.of(
                KinesisInternalFieldDescription.SHARD_ID_FIELD.forByteValue(
                        split.getShardId().getBytes(StandardCharsets.UTF_8)),
                KinesisInternalFieldDescription.SEGMENT_START_FIELD.forByteValue(
                        split.getStart().getBytes(StandardCharsets.UTF_8)));

        this.clientManager = checkNotNull(clientManager, "clientManager is null");

        this.messageDecoder = checkNotNull(messageDecoder, "rowDecoder is null");
        this.messageFieldDecoders = checkNotNull(messageFieldDecoders, "messageFieldDecoders is null");

        this.columnHandles = checkNotNull(columnHandles, "columnHandles is null");

        ImmutableList.Builder<Type> typeBuilder = ImmutableList.builder();
        for (KinesisColumnHandle handle : columnHandles) {
            typeBuilder.add(handle.getType());
        }
        this.columnTypes = typeBuilder.build();

        // Note: these default to what is in the configuration if not given in the session
        this.batchSize = SessionVariables.getBatchSize(this.session);
        this.maxBatches = SessionVariables.getMaxBatches(this.session);

        this.fetchAttempts = kinesisConnectorConfig.getFetchAttempts();
        this.sleepTime = kinesisConnectorConfig.getSleepTime().toMillis();

        this.checkpointEnabled = kinesisConnectorConfig.isCheckpointEnabled();
        this.lastReadSeqNo = null;
        this.kinesisShardCheckpointer = null;
        checkpoint();
    }

    /**
     * Lazily creates the DynamoDB-backed shard checkpointer (when checkpointing is enabled)
     * and loads the last checkpointed sequence number so the cursor resumes after it.
     *
     * <p>The iteration number and logical process name come from the connector configuration
     * but may be overridden through session properties.
     */
    public void checkpoint()
    {
        if (checkpointEnabled) {
            if (kinesisShardCheckpointer == null) {
                AmazonDynamoDBClient dynamoDBClient = clientManager.getDynamoDBClient();
                long dynamoReadCapacity = kinesisConnectorConfig.getDynamoReadCapacity();
                long dynamoWriteCapacity = kinesisConnectorConfig.getDynamoWriteCapacity();
                long checkpointIntervalMs = kinesisConnectorConfig.getCheckpointIntervalMS().toMillis();
                String logicalProcessName = kinesisConnectorConfig.getLogicalProcessName();
                String dynamoDBTable = split.getStreamName();
                int curIterationNumber = kinesisConnectorConfig.getIterationNumber();

                String sessionIterationNo = SessionVariables.getSessionProperty(this.session, SessionVariables.ITERATION_NUMBER);
                String sessionLogicalName = SessionVariables.getSessionProperty(this.session, SessionVariables.CHECKPOINT_LOGICAL_NAME);

                if (sessionIterationNo != null) {
                    curIterationNumber = Integer.parseInt(sessionIterationNo);
                }

                if (sessionLogicalName != null) {
                    logicalProcessName = sessionLogicalName;
                }

                kinesisShardCheckpointer = new KinesisShardCheckpointer(dynamoDBClient,
                        dynamoDBTable,
                        split,
                        logicalProcessName,
                        curIterationNumber,
                        checkpointIntervalMs,
                        dynamoReadCapacity,
                        dynamoWriteCapacity);

                lastReadSeqNo = kinesisShardCheckpointer.getLastReadSeqNo();
            }
        }
    }

    @Override
    public List<Type> getColumnTypes()
    {
        return columnTypes;
    }

    @Override
    public RecordCursor cursor()
    {
        return new KinesisRecordCursor();
    }

    /**
     * Cursor that iterates the records of one Kinesis shard, fetching them in batches.
     */
    public class KinesisRecordCursor
            implements RecordCursor
    {
        // TODO: total bytes here only includes records we iterate through, not total read from Kinesis.
        // This may not be an issue, but if total vs. completed is an important signal to Presto then
        // the implementation below could be a problem.  Need to investigate.
        private long batchesRead = 0;
        private long messagesRead = 0;
        private long totalBytes = 0;
        private long totalMessages = 0;
        private long lastReadTime = 0;

        private String shardIterator;
        private KinesisFieldValueProvider[] fieldValueProviders;
        private List<Record> kinesisRecords;
        private Iterator<Record> listIterator;
        private GetRecordsRequest getRecordsRequest;
        private GetRecordsResult getRecordsResult;

        @Override
        public long getTotalBytes()
        {
            return totalBytes;
        }

        @Override
        public long getCompletedBytes()
        {
            return totalBytes;
        }

        @Override
        public long getReadTimeNanos()
        {
            return 0;
        }

        @Override
        public Type getType(int field)
        {
            checkArgument(field < columnHandles.size(), "Invalid field index");
            return columnHandles.get(field).getType();
        }

        /**
         * Advances the cursor by one position, retrieving more records from Kinesis if needed.
         *
         * We retrieve records from Kinesis in batches, using the getRecordsRequest.  After a
         * getRecordsRequest we keep iterating through that list of records until we run out.  Then
         * we will get another batch unless we've hit the limit or have caught up.
         *
         * @return true if a row is available, false when the shard has been exhausted
         */
        @Override
        public boolean advanceNextPosition()
        {
            if (shardIterator == null && getRecordsRequest == null) {
                getIterator(); // first shard iterator
                log.info("Starting read.  Retrieved first shard iterator from AWS Kinesis.");
            }

            if (getRecordsRequest == null || (!listIterator.hasNext() && shouldGetMoreRecords())) {
                getKinesisRecords();
            }

            if (listIterator.hasNext()) {
                return nextRow();
            }
            else {
                log.info("Read all of the records from the shard:  %d batches and %d messages and %d total bytes.",
                        batchesRead, totalMessages, totalBytes);
                return false;
            }
        }

        /** Determine whether or not to retrieve another batch of records from Kinesis. */
        private boolean shouldGetMoreRecords()
        {
            return shardIterator != null && batchesRead < maxBatches &&
                    getMillisBehindLatest() > MILLIS_BEHIND_LIMIT;
        }

        /**
         * Retrieves the next batch of records from Kinesis using the shard iterator.
         *
         * Most of the time this results in one getRecords call.  However we allow for
         * a call to return an empty list, and we'll try again if we are far enough
         * away from the latest record.
         */
        private void getKinesisRecords()
                throws ResourceNotFoundException
        {
            // Normally this loop will execute once, but we have to allow for the odd Kinesis
            // behavior, per the docs:
            // A single call to getRecords might return an empty record list, even when the shard contains
            // more records at later sequence numbers
            boolean fetchedRecords = false;
            int attempts = 0;
            while (!fetchedRecords && attempts < fetchAttempts) {
                long elapsed = System.currentTimeMillis() - lastReadTime;
                if (elapsed <= sleepTime) {
                    try {
                        // Pace the getRecords calls: sleep out the REMAINDER of the interval.
                        // (Bug fix: previously slept for the elapsed time instead of the remaining time.)
                        Thread.sleep(sleepTime - elapsed);
                    }
                    catch (InterruptedException e) {
                        // Restore the interrupt status so callers can observe the interruption.
                        Thread.currentThread().interrupt();
                        log.error("Sleep interrupted.", e);
                    }
                }
                getRecordsRequest = new GetRecordsRequest();
                getRecordsRequest.setShardIterator(shardIterator);
                getRecordsRequest.setLimit(batchSize);

                getRecordsResult = clientManager.getClient().getRecords(getRecordsRequest);
                lastReadTime = System.currentTimeMillis();

                shardIterator = getRecordsResult.getNextShardIterator();
                kinesisRecords = getRecordsResult.getRecords();
                if (kinesisConnectorConfig.isLogBatches()) {
                    log.info("Fetched %d records from Kinesis.  MillisBehindLatest=%d",
                            kinesisRecords.size(), getRecordsResult.getMillisBehindLatest());
                }

                fetchedRecords = (kinesisRecords.size() > 0 || getMillisBehindLatest() <= MILLIS_BEHIND_LIMIT);
                attempts++;
            }

            listIterator = kinesisRecords.iterator();
            batchesRead++;
            messagesRead += kinesisRecords.size();
        }

        /** Working from the internal list, advance to the next row and decode it. */
        private boolean nextRow()
        {
            Record currentRecord = listIterator.next();
            String partitionKey = currentRecord.getPartitionKey();
            log.debug("Reading record with partition key %s", partitionKey);

            byte[] messageData = EMPTY_BYTE_ARRAY;
            ByteBuffer message = currentRecord.getData();
            if (message != null) {
                messageData = new byte[message.remaining()];
                message.get(messageData);
            }
            totalBytes += messageData.length;
            totalMessages++;

            log.debug("Fetching %d bytes from current record. %d messages read so far", messageData.length, totalMessages);

            // Renamed from "fieldValueProviders" to avoid shadowing the cursor field of the same name.
            Set<KinesisFieldValueProvider> rowValueProviders = new HashSet<>();

            // Guard against a missing arrival timestamp: getApproximateArrivalTimestamp requires a
            // recent SDK and may return null, which previously would have thrown an NPE here.
            Date arrivalTimestamp = currentRecord.getApproximateArrivalTimestamp();
            long arrivalMillis = (arrivalTimestamp == null) ? 0L : arrivalTimestamp.getTime();

            rowValueProviders.addAll(globalInternalFieldValueProviders);
            rowValueProviders.add(KinesisInternalFieldDescription.SEGMENT_COUNT_FIELD.forLongValue(totalMessages));
            rowValueProviders.add(KinesisInternalFieldDescription.SHARD_SEQUENCE_ID_FIELD.forByteValue(
                    currentRecord.getSequenceNumber().getBytes(StandardCharsets.UTF_8)));
            rowValueProviders.add(KinesisInternalFieldDescription.MESSAGE_FIELD.forByteValue(messageData));
            rowValueProviders.add(KinesisInternalFieldDescription.MESSAGE_LENGTH_FIELD.forLongValue(messageData.length));
            rowValueProviders.add(KinesisInternalFieldDescription.MESSAGE_TIMESTAMP.forLongValue(arrivalMillis));
            rowValueProviders.add(KinesisInternalFieldDescription.MESSAGE_VALID_FIELD.forBooleanValue(
                    messageDecoder.decodeRow(messageData, rowValueProviders, columnHandles, messageFieldDecoders)));
            rowValueProviders.add(KinesisInternalFieldDescription.PARTITION_KEY_FIELD.forByteValue(
                    partitionKey.getBytes(StandardCharsets.UTF_8)));

            // Map each column handle to the first provider that accepts it (null = SQL NULL).
            this.fieldValueProviders = new KinesisFieldValueProvider[columnHandles.size()];
            for (int i = 0; i < columnHandles.size(); i++) {
                for (KinesisFieldValueProvider fieldValueProvider : rowValueProviders) {
                    if (fieldValueProvider.accept(columnHandles.get(i))) {
                        this.fieldValueProviders[i] = fieldValueProvider;
                        break;
                    }
                }
            }

            lastReadSeqNo = currentRecord.getSequenceNumber();
            if (checkpointEnabled) {
                kinesisShardCheckpointer.checkpointIfTimeUp(lastReadSeqNo);
            }
            return true;
        }

        /** Protect against possibly null values if this isn't set (not expected). */
        private long getMillisBehindLatest()
        {
            if (getRecordsResult != null && getRecordsResult.getMillisBehindLatest() != null) {
                return getRecordsResult.getMillisBehindLatest();
            }
            else {
                return MILLIS_BEHIND_LIMIT + 1;
            }
        }

        @Override
        public boolean getBoolean(int field)
        {
            checkArgument(field < columnHandles.size(), "Invalid field index");
            checkFieldType(field, boolean.class);
            return !isNull(field) && fieldValueProviders[field].getBoolean();
        }

        @Override
        public long getLong(int field)
        {
            checkArgument(field < columnHandles.size(), "Invalid field index");
            checkFieldType(field, long.class);
            return isNull(field) ? 0L : fieldValueProviders[field].getLong();
        }

        @Override
        public double getDouble(int field)
        {
            checkArgument(field < columnHandles.size(), "Invalid field index");
            checkFieldType(field, double.class);
            return isNull(field) ? 0.0d : fieldValueProviders[field].getDouble();
        }

        @Override
        public Slice getSlice(int field)
        {
            checkArgument(field < columnHandles.size(), "Invalid field index");
            checkFieldType(field, Slice.class);
            return isNull(field) ? Slices.EMPTY_SLICE : fieldValueProviders[field].getSlice();
        }

        @Override
        public Object getObject(int field)
        {
            // TODO: review if we want to support this
            throw new UnsupportedOperationException();
        }

        @Override
        public boolean isNull(int field)
        {
            checkArgument(field < columnHandles.size(), "Invalid field index");
            return fieldValueProviders[field] == null || fieldValueProviders[field].isNull();
        }

        @Override
        public void close()
        {
            log.info("Closing cursor - read complete.  Total read: %d batches %d messages, processed: %d messages and %d bytes.",
                    batchesRead, messagesRead, totalMessages, totalBytes);
            if (checkpointEnabled && lastReadSeqNo != null) {
                kinesisShardCheckpointer.checkpoint(lastReadSeqNo);
            }
        }

        private void checkFieldType(int field, Class<?> expected)
        {
            Class<?> actual = getType(field).getJavaType();
            checkArgument(actual == expected, "Expected field %s to be type %s but is %s", field, expected, actual);
        }

        /**
         * Obtain the initial shard iterator for this split.
         *
         * Explanation: when we have a sequence number from a prior read or checkpoint, always use it.
         * Otherwise, decide if starting at a timestamp or the trim horizon based on configuration.
         * If starting at a timestamp, use the session variable ITER_START_TIMESTAMP when given,
         * otherwise fall back on starting at ITER_OFFSET_SECONDS from the current timestamp.
         */
        private void getIterator()
                throws ResourceNotFoundException
        {
            GetShardIteratorRequest getShardIteratorRequest = new GetShardIteratorRequest();
            getShardIteratorRequest.setStreamName(split.getStreamName());
            getShardIteratorRequest.setShardId(split.getShardId());

            if (lastReadSeqNo == null) {
                // Important: shard iterator type AT_TIMESTAMP requires 1.11.x or above of the AWS SDK.
                if (SessionVariables.getIterFromTimestamp(session)) {
                    getShardIteratorRequest.setShardIteratorType("AT_TIMESTAMP");
                    long iterStartTs = SessionVariables.getIterStartTimestamp(session);
                    if (iterStartTs == 0) {
                        long startTs = System.currentTimeMillis() - (SessionVariables.getIterOffsetSeconds(session) * 1000);
                        getShardIteratorRequest.setTimestamp(new Date(startTs));
                    }
                    else {
                        getShardIteratorRequest.setTimestamp(new Date(iterStartTs));
                    }
                }
                else {
                    getShardIteratorRequest.setShardIteratorType("TRIM_HORIZON");
                }
            }
            else {
                getShardIteratorRequest.setShardIteratorType("AFTER_SEQUENCE_NUMBER");
                getShardIteratorRequest.setStartingSequenceNumber(lastReadSeqNo);
            }

            GetShardIteratorResult getShardIteratorResult = clientManager.getClient().getShardIterator(getShardIteratorRequest);
            shardIterator = getShardIteratorResult.getShardIterator();
        }
    }
}