/* * Copyright (C) 2012-2015 DataStax Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.datastax.driver.core; import com.datastax.driver.core.exceptions.DriverInternalError; import com.datastax.driver.core.exceptions.UnsupportedFeatureException; import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.MessageToMessageDecoder; import io.netty.handler.codec.MessageToMessageEncoder; import io.netty.util.AttributeKey; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.nio.ByteBuffer; import java.util.*; /** * A message from the CQL binary protocol. */ abstract class Message { protected static final Logger logger = LoggerFactory.getLogger(Message.class); static AttributeKey<CodecRegistry> CODEC_REGISTRY_ATTRIBUTE_KEY = AttributeKey.valueOf("com.datastax.driver.core.CodecRegistry"); interface Coder<R extends Request> { void encode(R request, ByteBuf dest, ProtocolVersion version); int encodedSize(R request, ProtocolVersion version); } interface Decoder<R extends Response> { R decode(ByteBuf body, ProtocolVersion version, CodecRegistry codecRegistry); } private volatile int streamId = -1; /** * A generic key-value custom payload. Custom payloads are simply * ignored by the default QueryHandler implementation server-side. * * @since Protocol V4 */ private volatile Map<String, ByteBuffer> customPayload; protected Message() { } Message setStreamId(int streamId) { this.streamId = streamId; return this; } int getStreamId() { return streamId; } Map<String, ByteBuffer> getCustomPayload() { return customPayload; } Message setCustomPayload(Map<String, ByteBuffer> customPayload) { this.customPayload = customPayload; return this; } static abstract class Request extends Message { enum Type { STARTUP(1, Requests.Startup.coder), CREDENTIALS(4, Requests.Credentials.coder), OPTIONS(5, Requests.Options.coder), QUERY(7, Requests.Query.coder), PREPARE(9, Requests.Prepare.coder), EXECUTE(10, Requests.Execute.coder), REGISTER(11, Requests.Register.coder), BATCH(13, Requests.Batch.coder), AUTH_RESPONSE(15, Requests.AuthResponse.coder); final int opcode; final Coder<?> coder; Type(int opcode, Coder<?> coder) { this.opcode = opcode; this.coder = coder; } } final Type type; private final boolean tracingRequested; protected Request(Type type) { this(type, false); } protected Request(Type type, boolean tracingRequested) { this.type = type; this.tracingRequested = tracingRequested; } @Override Request setStreamId(int streamId) { // JAVA-1179: defensively guard against reusing the same Request object twice. // If no streamId was ever set we can use this object directly, otherwise make a copy. if (getStreamId() < 0) return (Request) super.setStreamId(streamId); else { Request copy = this.copy(); copy.setStreamId(streamId); return copy; } } boolean isTracingRequested() { return tracingRequested; } ConsistencyLevel consistency() { switch (this.type) { case QUERY: return ((Requests.Query) this).options.consistency; case EXECUTE: return ((Requests.Execute) this).options.consistency; case BATCH: return ((Requests.Batch) this).options.consistency; default: return null; } } ConsistencyLevel serialConsistency() { switch (this.type) { case QUERY: return ((Requests.Query) this).options.serialConsistency; case EXECUTE: return ((Requests.Execute) this).options.serialConsistency; case BATCH: return ((Requests.Batch) this).options.serialConsistency; default: return null; } } long defaultTimestamp() { switch (this.type) { case QUERY: return ((Requests.Query) this).options.defaultTimestamp; case EXECUTE: return ((Requests.Execute) this).options.defaultTimestamp; case BATCH: return ((Requests.Batch) this).options.defaultTimestamp; default: return 0; } } ByteBuffer pagingState() { switch (this.type) { case QUERY: return ((Requests.Query) this).options.pagingState; case EXECUTE: return ((Requests.Execute) this).options.pagingState; default: return null; } } Request copy() { Request request = copyInternal(); request.setCustomPayload(this.getCustomPayload()); return request; } protected abstract Request copyInternal(); Request copy(ConsistencyLevel newConsistencyLevel) { Request request = copyInternal(newConsistencyLevel); request.setCustomPayload(this.getCustomPayload()); return request; } protected Request copyInternal(ConsistencyLevel newConsistencyLevel) { throw new UnsupportedOperationException(); } } static abstract class Response extends Message { enum Type { ERROR(0, Responses.Error.decoder), READY(2, Responses.Ready.decoder), AUTHENTICATE(3, Responses.Authenticate.decoder), SUPPORTED(6, Responses.Supported.decoder), RESULT(8, Responses.Result.decoder), EVENT(12, Responses.Event.decoder), AUTH_CHALLENGE(14, Responses.AuthChallenge.decoder), AUTH_SUCCESS(16, Responses.AuthSuccess.decoder); final int opcode; final Decoder<?> decoder; private static final Type[] opcodeIdx; static { int maxOpcode = -1; for (Type type : Type.values()) maxOpcode = Math.max(maxOpcode, type.opcode); opcodeIdx = new Type[maxOpcode + 1]; for (Type type : Type.values()) { if (opcodeIdx[type.opcode] != null) throw new IllegalStateException("Duplicate opcode"); opcodeIdx[type.opcode] = type; } } Type(int opcode, Decoder<?> decoder) { this.opcode = opcode; this.decoder = decoder; } static Type fromOpcode(int opcode) { if (opcode < 0 || opcode >= opcodeIdx.length) throw new DriverInternalError(String.format("Unknown response opcode %d", opcode)); Type t = opcodeIdx[opcode]; if (t == null) throw new DriverInternalError(String.format("Unknown response opcode %d", opcode)); return t; } } final Type type; protected volatile UUID tracingId; protected volatile List<String> warnings; protected Response(Type type) { this.type = type; } Response setTracingId(UUID tracingId) { this.tracingId = tracingId; return this; } UUID getTracingId() { return tracingId; } Response setWarnings(List<String> warnings) { this.warnings = warnings; return this; } } @ChannelHandler.Sharable static class ProtocolDecoder extends MessageToMessageDecoder<Frame> { @Override protected void decode(ChannelHandlerContext ctx, Frame frame, List<Object> out) throws Exception { boolean isTracing = frame.header.flags.contains(Frame.Header.Flag.TRACING); boolean isCustomPayload = frame.header.flags.contains(Frame.Header.Flag.CUSTOM_PAYLOAD); UUID tracingId = isTracing ? CBUtil.readUUID(frame.body) : null; Map<String, ByteBuffer> customPayload = isCustomPayload ? CBUtil.readBytesMap(frame.body) : null; if (customPayload != null && logger.isTraceEnabled()) { logger.trace("Received payload: {} ({} bytes total)", printPayload(customPayload), CBUtil.sizeOfBytesMap(customPayload)); } boolean hasWarnings = frame.header.flags.contains(Frame.Header.Flag.WARNING); List<String> warnings = hasWarnings ? CBUtil.readStringList(frame.body) : Collections.<String>emptyList(); try { CodecRegistry codecRegistry = ctx.channel().attr(CODEC_REGISTRY_ATTRIBUTE_KEY).get(); assert codecRegistry != null; Response response = Response.Type.fromOpcode(frame.header.opcode).decoder.decode(frame.body, frame.header.version, codecRegistry); response .setTracingId(tracingId) .setWarnings(warnings) .setCustomPayload(customPayload) .setStreamId(frame.header.streamId); out.add(response); } finally { frame.body.release(); } } } @ChannelHandler.Sharable static class ProtocolEncoder extends MessageToMessageEncoder<Request> { private final ProtocolVersion protocolVersion; ProtocolEncoder(ProtocolVersion version) { this.protocolVersion = version; } @Override protected void encode(ChannelHandlerContext ctx, Request request, List<Object> out) throws Exception { EnumSet<Frame.Header.Flag> flags = EnumSet.noneOf(Frame.Header.Flag.class); if (request.isTracingRequested()) flags.add(Frame.Header.Flag.TRACING); if (protocolVersion == ProtocolVersion.NEWEST_BETA) flags.add(Frame.Header.Flag.USE_BETA); Map<String, ByteBuffer> customPayload = request.getCustomPayload(); if (customPayload != null) { if (protocolVersion.compareTo(ProtocolVersion.V4) < 0) throw new UnsupportedFeatureException( protocolVersion, "Custom payloads are only supported since native protocol V4"); flags.add(Frame.Header.Flag.CUSTOM_PAYLOAD); } @SuppressWarnings("unchecked") Coder<Request> coder = (Coder<Request>) request.type.coder; int messageSize = coder.encodedSize(request, protocolVersion); int payloadLength = -1; if (customPayload != null) { payloadLength = CBUtil.sizeOfBytesMap(customPayload); messageSize += payloadLength; } ByteBuf body = ctx.alloc().buffer(messageSize); if (customPayload != null) { CBUtil.writeBytesMap(customPayload, body); if (logger.isTraceEnabled()) { logger.trace("Sending payload: {} ({} bytes total)", printPayload(customPayload), payloadLength); } } coder.encode(request, body, protocolVersion); out.add(Frame.create(protocolVersion, request.type.opcode, request.getStreamId(), flags, body)); } } // private stuff to debug custom payloads private static final char[] hexArray = "0123456789ABCDEF".toCharArray(); static String printPayload(Map<String, ByteBuffer> customPayload) { if (customPayload == null) return "null"; if (customPayload.isEmpty()) return "{}"; StringBuilder sb = new StringBuilder("{"); Iterator<Map.Entry<String, ByteBuffer>> iterator = customPayload.entrySet().iterator(); while (iterator.hasNext()) { Map.Entry<String, ByteBuffer> entry = iterator.next(); sb.append(entry.getKey()); sb.append(":"); if (entry.getValue() == null) sb.append("null"); else bytesToHex(entry.getValue(), sb); if (iterator.hasNext()) sb.append(", "); } sb.append("}"); return sb.toString(); } // this method doesn't modify the given ByteBuffer static void bytesToHex(ByteBuffer bytes, StringBuilder sb) { int length = Math.min(bytes.remaining(), 50); sb.append("0x"); for (int i = 0; i < length; i++) { int v = bytes.get(i) & 0xFF; sb.append(hexArray[v >>> 4]); sb.append(hexArray[v & 0x0F]); } if (bytes.remaining() > 50) sb.append("... [TRUNCATED]"); } }