/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.hive.spark.client.rpc; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.ByteBuffer; import java.util.Arrays; import java.util.List; import org.objenesis.strategy.StdInstantiatorStrategy; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.esotericsoftware.kryo.Kryo; import com.esotericsoftware.kryo.io.ByteBufferInputStream; import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; import com.google.common.base.Preconditions; import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.ByteToMessageCodec; /** * Codec that serializes / deserializes objects using Kryo. Objects are encoded with a 4-byte * header with the length of the serialized data. */ class KryoMessageCodec extends ByteToMessageCodec<Object> { private static final Logger LOG = LoggerFactory.getLogger(KryoMessageCodec.class); // Kryo docs say 0-8 are taken. Strange things happen if you don't set an ID when registering // classes. private static final int REG_ID_BASE = 16; private final int maxMessageSize; private final List<Class<?>> messages; private final ThreadLocal<Kryo> kryos = new ThreadLocal<Kryo>() { @Override protected Kryo initialValue() { Kryo kryo = new Kryo(); int count = 0; for (Class<?> klass : messages) { kryo.register(klass, REG_ID_BASE + count); count++; } kryo.setInstantiatorStrategy(new Kryo.DefaultInstantiatorStrategy(new StdInstantiatorStrategy())); return kryo; } }; private volatile EncryptionHandler encryptionHandler; public KryoMessageCodec(int maxMessageSize, Class<?>... messages) { this.maxMessageSize = maxMessageSize; this.messages = Arrays.asList(messages); this.encryptionHandler = null; } @Override protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception { if (in.readableBytes() < 4) { return; } in.markReaderIndex(); int msgSize = in.readInt(); checkSize(msgSize); if (in.readableBytes() < msgSize) { // Incomplete message in buffer. in.resetReaderIndex(); return; } try { ByteBuffer nioBuffer = maybeDecrypt(in.nioBuffer(in.readerIndex(), msgSize)); Input kryoIn = new Input(new ByteBufferInputStream(nioBuffer)); Object msg = kryos.get().readClassAndObject(kryoIn); LOG.debug("Decoded message of type {} ({} bytes)", msg != null ? msg.getClass().getName() : msg, msgSize); out.add(msg); } finally { in.skipBytes(msgSize); } } @Override protected void encode(ChannelHandlerContext ctx, Object msg, ByteBuf buf) throws Exception { ByteArrayOutputStream bytes = new ByteArrayOutputStream(); Output kryoOut = new Output(bytes); kryos.get().writeClassAndObject(kryoOut, msg); kryoOut.flush(); byte[] msgData = maybeEncrypt(bytes.toByteArray()); LOG.debug("Encoded message of type {} ({} bytes)", msg.getClass().getName(), msgData.length); checkSize(msgData.length); buf.ensureWritable(msgData.length + 4); buf.writeInt(msgData.length); buf.writeBytes(msgData); } @Override public void channelInactive(ChannelHandlerContext ctx) throws Exception { if (encryptionHandler != null) { encryptionHandler.dispose(); } super.channelInactive(ctx); } private void checkSize(int msgSize) { Preconditions.checkArgument(msgSize > 0, "Message size (%s bytes) must be positive.", msgSize); Preconditions.checkArgument(maxMessageSize <= 0 || msgSize <= maxMessageSize, "Message (%s bytes) exceeds maximum allowed size (%s bytes).", msgSize, maxMessageSize); } private byte[] maybeEncrypt(byte[] data) throws Exception { return (encryptionHandler != null) ? encryptionHandler.wrap(data, 0, data.length) : data; } private ByteBuffer maybeDecrypt(ByteBuffer data) throws Exception { if (encryptionHandler != null) { byte[] encrypted; int len = data.limit() - data.position(); int offset; if (data.hasArray()) { encrypted = data.array(); offset = data.position() + data.arrayOffset(); data.position(data.limit()); } else { encrypted = new byte[len]; offset = 0; data.get(encrypted); } return ByteBuffer.wrap(encryptionHandler.unwrap(encrypted, offset, len)); } else { return data; } } void setEncryptionHandler(EncryptionHandler handler) { this.encryptionHandler = handler; } interface EncryptionHandler { byte[] wrap(byte[] data, int offset, int len) throws IOException; byte[] unwrap(byte[] data, int offset, int len) throws IOException; void dispose() throws IOException; } }