/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.cassandra.transport; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.UnknownHostException; import java.nio.ByteBuffer; import java.nio.CharBuffer; import java.nio.charset.CharacterCodingException; import java.nio.charset.CharsetDecoder; import java.nio.charset.CoderResult; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.UUID; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.ByteBufUtil; import io.netty.buffer.PooledByteBufAllocator; import io.netty.buffer.UnpooledByteBufAllocator; import io.netty.util.CharsetUtil; import io.netty.util.concurrent.FastThreadLocal; import org.apache.cassandra.config.Config; import org.apache.cassandra.db.ConsistencyLevel; import org.apache.cassandra.db.TypeSizes; import org.apache.cassandra.utils.ByteBufferUtil; import org.apache.cassandra.utils.Pair; import org.apache.cassandra.utils.UUIDGen; /** * ByteBuf utility methods. * Note that contrarily to ByteBufferUtil, these method do "read" the * ByteBuf advancing its (read) position. They also write by * advancing the write position. Functions are also provided to create * ByteBuf while avoiding copies. */ public abstract class CBUtil { public static final boolean USE_HEAP_ALLOCATOR = Boolean.getBoolean(Config.PROPERTY_PREFIX + "netty_use_heap_allocator"); public static final ByteBufAllocator allocator = USE_HEAP_ALLOCATOR ? new UnpooledByteBufAllocator(false) : new PooledByteBufAllocator(true); private final static FastThreadLocal<CharsetDecoder> TL_UTF8_DECODER = new FastThreadLocal<CharsetDecoder>() { @Override protected CharsetDecoder initialValue() { return StandardCharsets.UTF_8.newDecoder(); } }; private final static FastThreadLocal<CharBuffer> TL_CHAR_BUFFER = new FastThreadLocal<>(); private CBUtil() {} // Taken from Netty's ChannelBuffers.decodeString(). We need to use our own decoder to properly handle invalid // UTF-8 sequences. See CASSANDRA-8101 for more details. This can be removed once https://github.com/netty/netty/pull/2999 // is resolved in a release used by Cassandra. private static String decodeString(ByteBuffer src) throws CharacterCodingException { // the decoder needs to be reset every time we use it, hence the copy per thread CharsetDecoder theDecoder = TL_UTF8_DECODER.get(); theDecoder.reset(); CharBuffer dst = TL_CHAR_BUFFER.get(); int capacity = (int) ((double) src.remaining() * theDecoder.maxCharsPerByte()); if (dst == null) { capacity = Math.max(capacity, 4096); dst = CharBuffer.allocate(capacity); TL_CHAR_BUFFER.set(dst); } else { dst.clear(); if (dst.capacity() < capacity) { dst = CharBuffer.allocate(capacity); TL_CHAR_BUFFER.set(dst); } } CoderResult cr = theDecoder.decode(src, dst, true); if (!cr.isUnderflow()) cr.throwException(); return dst.flip().toString(); } private static String readString(ByteBuf cb, int length) { if (length == 0) return ""; ByteBuffer buffer = cb.nioBuffer(cb.readerIndex(), length); try { String str = decodeString(buffer); cb.readerIndex(cb.readerIndex() + length); return str; } catch (IllegalStateException | CharacterCodingException e) { throw new ProtocolException("Cannot decode string as UTF8: '" + ByteBufferUtil.bytesToHex(buffer) + "'; " + e); } } public static String readString(ByteBuf cb) { try { int length = cb.readUnsignedShort(); return readString(cb, length); } catch (IndexOutOfBoundsException e) { throw new ProtocolException("Not enough bytes to read an UTF8 serialized string preceded by its 2 bytes length"); } } public static void writeString(String str, ByteBuf cb) { int writerIndex = cb.writerIndex(); cb.writeShort(0); int lengthBytes = ByteBufUtil.writeUtf8(cb, str); cb.setShort(writerIndex, lengthBytes); } public static int sizeOfString(String str) { return 2 + TypeSizes.encodedUTF8Length(str); } public static String readLongString(ByteBuf cb) { try { int length = cb.readInt(); return readString(cb, length); } catch (IndexOutOfBoundsException e) { throw new ProtocolException("Not enough bytes to read an UTF8 serialized string preceded by its 4 bytes length"); } } public static void writeLongString(String str, ByteBuf cb) { byte[] bytes = str.getBytes(CharsetUtil.UTF_8); cb.writeInt(bytes.length); cb.writeBytes(bytes); } public static int sizeOfLongString(String str) { return 4 + str.getBytes(CharsetUtil.UTF_8).length; } public static byte[] readBytes(ByteBuf cb) { try { int length = cb.readUnsignedShort(); byte[] bytes = new byte[length]; cb.readBytes(bytes); return bytes; } catch (IndexOutOfBoundsException e) { throw new ProtocolException("Not enough bytes to read a byte array preceded by its 2 bytes length"); } } public static void writeBytes(byte[] bytes, ByteBuf cb) { cb.writeShort(bytes.length); cb.writeBytes(bytes); } public static int sizeOfBytes(byte[] bytes) { return 2 + bytes.length; } public static Map<String, ByteBuffer> readBytesMap(ByteBuf cb) { int length = cb.readUnsignedShort(); Map<String, ByteBuffer> m = new HashMap<>(length); for (int i = 0; i < length; i++) { String k = readString(cb); ByteBuffer v = readValue(cb); m.put(k, v); } return m; } public static void writeBytesMap(Map<String, ByteBuffer> m, ByteBuf cb) { cb.writeShort(m.size()); for (Map.Entry<String, ByteBuffer> entry : m.entrySet()) { writeString(entry.getKey(), cb); writeValue(entry.getValue(), cb); } } public static int sizeOfBytesMap(Map<String, ByteBuffer> m) { int size = 2; for (Map.Entry<String, ByteBuffer> entry : m.entrySet()) { size += sizeOfString(entry.getKey()); size += sizeOfValue(entry.getValue()); } return size; } public static ConsistencyLevel readConsistencyLevel(ByteBuf cb) { return ConsistencyLevel.fromCode(cb.readUnsignedShort()); } public static void writeConsistencyLevel(ConsistencyLevel consistency, ByteBuf cb) { cb.writeShort(consistency.code); } public static int sizeOfConsistencyLevel(ConsistencyLevel consistency) { return 2; } public static <T extends Enum<T>> T readEnumValue(Class<T> enumType, ByteBuf cb) { String value = CBUtil.readString(cb); try { return Enum.valueOf(enumType, value.toUpperCase()); } catch (IllegalArgumentException e) { throw new ProtocolException(String.format("Invalid value '%s' for %s", value, enumType.getSimpleName())); } } public static <T extends Enum<T>> void writeEnumValue(T enumValue, ByteBuf cb) { writeString(enumValue.toString(), cb); } public static <T extends Enum<T>> int sizeOfEnumValue(T enumValue) { return sizeOfString(enumValue.toString()); } public static UUID readUUID(ByteBuf cb) { byte[] bytes = new byte[16]; cb.readBytes(bytes); return UUIDGen.getUUID(ByteBuffer.wrap(bytes)); } public static void writeUUID(UUID uuid, ByteBuf cb) { cb.writeBytes(UUIDGen.decompose(uuid)); } public static int sizeOfUUID(UUID uuid) { return 16; } public static List<String> readStringList(ByteBuf cb) { int length = cb.readUnsignedShort(); List<String> l = new ArrayList<String>(length); for (int i = 0; i < length; i++) l.add(readString(cb)); return l; } public static void writeStringList(List<String> l, ByteBuf cb) { cb.writeShort(l.size()); for (String str : l) writeString(str, cb); } public static int sizeOfStringList(List<String> l) { int size = 2; for (String str : l) size += sizeOfString(str); return size; } public static Map<String, String> readStringMap(ByteBuf cb) { int length = cb.readUnsignedShort(); Map<String, String> m = new HashMap<String, String>(length); for (int i = 0; i < length; i++) { String k = readString(cb); String v = readString(cb); m.put(k, v); } return m; } public static void writeStringMap(Map<String, String> m, ByteBuf cb) { cb.writeShort(m.size()); for (Map.Entry<String, String> entry : m.entrySet()) { writeString(entry.getKey(), cb); writeString(entry.getValue(), cb); } } public static int sizeOfStringMap(Map<String, String> m) { int size = 2; for (Map.Entry<String, String> entry : m.entrySet()) { size += sizeOfString(entry.getKey()); size += sizeOfString(entry.getValue()); } return size; } public static Map<String, List<String>> readStringToStringListMap(ByteBuf cb) { int length = cb.readUnsignedShort(); Map<String, List<String>> m = new HashMap<String, List<String>>(length); for (int i = 0; i < length; i++) { String k = readString(cb).toUpperCase(); List<String> v = readStringList(cb); m.put(k, v); } return m; } public static void writeStringToStringListMap(Map<String, List<String>> m, ByteBuf cb) { cb.writeShort(m.size()); for (Map.Entry<String, List<String>> entry : m.entrySet()) { writeString(entry.getKey(), cb); writeStringList(entry.getValue(), cb); } } public static int sizeOfStringToStringListMap(Map<String, List<String>> m) { int size = 2; for (Map.Entry<String, List<String>> entry : m.entrySet()) { size += sizeOfString(entry.getKey()); size += sizeOfStringList(entry.getValue()); } return size; } public static ByteBuffer readValue(ByteBuf cb) { int length = cb.readInt(); if (length < 0) return null; ByteBuf slice = cb.readSlice(length); return ByteBuffer.wrap(readRawBytes(slice)); } public static ByteBuffer readBoundValue(ByteBuf cb, ProtocolVersion protocolVersion) { int length = cb.readInt(); if (length < 0) { if (protocolVersion.isSmallerThan(ProtocolVersion.V4)) // backward compatibility for pre-version 4 return null; if (length == -1) return null; else if (length == -2) return ByteBufferUtil.UNSET_BYTE_BUFFER; else throw new ProtocolException("Invalid ByteBuf length " + length); } ByteBuf slice = cb.readSlice(length); return ByteBuffer.wrap(readRawBytes(slice)); } public static void writeValue(byte[] bytes, ByteBuf cb) { if (bytes == null) { cb.writeInt(-1); return; } cb.writeInt(bytes.length); cb.writeBytes(bytes); } public static void writeValue(ByteBuffer bytes, ByteBuf cb) { if (bytes == null) { cb.writeInt(-1); return; } int remaining = bytes.remaining(); cb.writeInt(remaining); if (remaining > 0) cb.writeBytes(bytes.duplicate()); } public static int sizeOfValue(byte[] bytes) { return 4 + (bytes == null ? 0 : bytes.length); } public static int sizeOfValue(ByteBuffer bytes) { return 4 + (bytes == null ? 0 : bytes.remaining()); } // The size of serializing a value given the size (in bytes) of said value. The provided size can be negative // to indicate that the value is null. public static int sizeOfValue(int valueSize) { return 4 + (valueSize < 0 ? 0 : valueSize); } public static List<ByteBuffer> readValueList(ByteBuf cb, ProtocolVersion protocolVersion) { int size = cb.readUnsignedShort(); if (size == 0) return Collections.<ByteBuffer>emptyList(); List<ByteBuffer> l = new ArrayList<ByteBuffer>(size); for (int i = 0; i < size; i++) l.add(readBoundValue(cb, protocolVersion)); return l; } public static void writeValueList(List<ByteBuffer> values, ByteBuf cb) { cb.writeShort(values.size()); for (ByteBuffer value : values) CBUtil.writeValue(value, cb); } public static int sizeOfValueList(List<ByteBuffer> values) { int size = 2; for (ByteBuffer value : values) size += CBUtil.sizeOfValue(value); return size; } public static Pair<List<String>, List<ByteBuffer>> readNameAndValueList(ByteBuf cb, ProtocolVersion protocolVersion) { int size = cb.readUnsignedShort(); if (size == 0) return Pair.create(Collections.<String>emptyList(), Collections.<ByteBuffer>emptyList()); List<String> s = new ArrayList<>(size); List<ByteBuffer> l = new ArrayList<>(size); for (int i = 0; i < size; i++) { s.add(readString(cb)); l.add(readBoundValue(cb, protocolVersion)); } return Pair.create(s, l); } public static InetSocketAddress readInet(ByteBuf cb) { int addrSize = cb.readByte() & 0xFF; byte[] address = new byte[addrSize]; cb.readBytes(address); int port = cb.readInt(); try { return new InetSocketAddress(InetAddress.getByAddress(address), port); } catch (UnknownHostException e) { throw new ProtocolException(String.format("Invalid IP address (%d.%d.%d.%d) while deserializing inet address", address[0], address[1], address[2], address[3])); } } public static void writeInet(InetSocketAddress inet, ByteBuf cb) { byte[] address = inet.getAddress().getAddress(); cb.writeByte(address.length); cb.writeBytes(address); cb.writeInt(inet.getPort()); } public static int sizeOfInet(InetSocketAddress inet) { byte[] address = inet.getAddress().getAddress(); return 1 + address.length + 4; } public static InetAddress readInetAddr(ByteBuf cb) { int addressSize = cb.readByte() & 0xFF; byte[] address = new byte[addressSize]; cb.readBytes(address); try { return InetAddress.getByAddress(address); } catch (UnknownHostException e) { throw new ProtocolException("Invalid IP address while deserializing inet address"); } } public static void writeInetAddr(InetAddress inetAddr, ByteBuf cb) { byte[] address = inetAddr.getAddress(); cb.writeByte(address.length); cb.writeBytes(address); } public static int sizeOfInetAddr(InetAddress inetAddr) { return 1 + inetAddr.getAddress().length; } /* * Reads *all* readable bytes from {@code cb} and return them. */ public static byte[] readRawBytes(ByteBuf cb) { byte[] bytes = new byte[cb.readableBytes()]; cb.readBytes(bytes); return bytes; } }