/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.runtime.io.network.netty;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.ByteBufInputStream;
import io.netty.buffer.ByteBufOutputStream;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.codec.MessageToMessageDecoder;
import org.apache.flink.runtime.event.TaskEvent;
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
import org.apache.flink.runtime.io.network.partition.consumer.InputChannel;
import org.apache.flink.runtime.io.network.partition.consumer.InputChannelID;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.ByteBuffer;
import java.util.List;
/**
* A simple and generic interface to serialize messages to Netty's buffer space.
*/
abstract class NettyMessage {
// ------------------------------------------------------------------------
// Note: Every NettyMessage subtype needs to have a public 0-argument
// constructor in order to work with the generic deserializer.
// ------------------------------------------------------------------------
static final int HEADER_LENGTH = 4 + 4 + 1; // frame length (4), magic number (4), msg ID (1)
static final int MAGIC_NUMBER = 0xBADC0FFE;
abstract ByteBuf write(ByteBufAllocator allocator) throws Exception;
abstract void readFrom(ByteBuf buffer) throws Exception;
// ------------------------------------------------------------------------
private static ByteBuf allocateBuffer(ByteBufAllocator allocator, byte id) {
return allocateBuffer(allocator, id, 0);
}
private static ByteBuf allocateBuffer(ByteBufAllocator allocator, byte id, int length) {
final ByteBuf buffer = length != 0 ? allocator.directBuffer(HEADER_LENGTH + length) : allocator.directBuffer();
buffer.writeInt(HEADER_LENGTH + length);
buffer.writeInt(MAGIC_NUMBER);
buffer.writeByte(id);
return buffer;
}
// ------------------------------------------------------------------------
// Generic NettyMessage encoder and decoder
// ------------------------------------------------------------------------
@ChannelHandler.Sharable
static class NettyMessageEncoder extends ChannelOutboundHandlerAdapter {
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
if (msg instanceof NettyMessage) {
ByteBuf serialized = null;
try {
serialized = ((NettyMessage) msg).write(ctx.alloc());
}
catch (Throwable t) {
throw new IOException("Error while serializing message: " + msg, t);
}
finally {
if (serialized != null) {
ctx.write(serialized, promise);
}
}
}
else {
ctx.write(msg, promise);
}
}
// Create the frame length decoder here as it depends on the encoder
//
// +------------------+------------------+--------++----------------+
// | FRAME LENGTH (4) | MAGIC NUMBER (4) | ID (1) || CUSTOM MESSAGE |
// +------------------+------------------+--------++----------------+
static LengthFieldBasedFrameDecoder createFrameLengthDecoder() {
return new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, -4, 4);
}
}
@ChannelHandler.Sharable
static class NettyMessageDecoder extends MessageToMessageDecoder<ByteBuf> {
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf msg, List<Object> out) throws Exception {
int magicNumber = msg.readInt();
if (magicNumber != MAGIC_NUMBER) {
throw new IllegalStateException("Network stream corrupted: received incorrect magic number.");
}
byte msgId = msg.readByte();
NettyMessage decodedMsg = null;
if (msgId == BufferResponse.ID) {
decodedMsg = new BufferResponse();
}
else if (msgId == PartitionRequest.ID) {
decodedMsg = new PartitionRequest();
}
else if (msgId == TaskEventRequest.ID) {
decodedMsg = new TaskEventRequest();
}
else if (msgId == ErrorResponse.ID) {
decodedMsg = new ErrorResponse();
}
else if (msgId == CancelPartitionRequest.ID) {
decodedMsg = new CancelPartitionRequest();
}
else if (msgId == CloseRequest.ID) {
decodedMsg = new CloseRequest();
}
else {
throw new IllegalStateException("Received unknown message from producer: " + msg);
}
if (decodedMsg != null) {
decodedMsg.readFrom(msg);
out.add(decodedMsg);
}
}
}
// ------------------------------------------------------------------------
// Server responses
// ------------------------------------------------------------------------
static class BufferResponse extends NettyMessage {
private static final byte ID = 0;
final Buffer buffer;
InputChannelID receiverId;
int sequenceNumber;
// ---- Deserialization -----------------------------------------------
boolean isBuffer;
int size;
ByteBuf retainedSlice;
public BufferResponse() {
// When deserializing we first have to request a buffer from the respective buffer
// provider (at the handler) and copy the buffer from Netty's space to ours.
buffer = null;
}
public BufferResponse(Buffer buffer, int sequenceNumber, InputChannelID receiverId) {
this.buffer = buffer;
this.sequenceNumber = sequenceNumber;
this.receiverId = receiverId;
}
boolean isBuffer() {
return isBuffer;
}
int getSize() {
return size;
}
ByteBuf getNettyBuffer() {
return retainedSlice;
}
void releaseBuffer() {
if (retainedSlice != null) {
retainedSlice.release();
retainedSlice = null;
}
}
// --------------------------------------------------------------------
// Serialization
// --------------------------------------------------------------------
@Override
ByteBuf write(ByteBufAllocator allocator) throws IOException {
int length = 16 + 4 + 1 + 4 + buffer.getSize();
ByteBuf result = null;
try {
result = allocateBuffer(allocator, ID, length);
receiverId.writeTo(result);
result.writeInt(sequenceNumber);
result.writeBoolean(buffer.isBuffer());
result.writeInt(buffer.getSize());
result.writeBytes(buffer.getNioBuffer());
return result;
}
catch (Throwable t) {
if (result != null) {
result.release();
}
throw new IOException(t);
}
finally {
if (buffer != null) {
buffer.recycle();
}
}
}
@Override
void readFrom(ByteBuf buffer) {
receiverId = InputChannelID.fromByteBuf(buffer);
sequenceNumber = buffer.readInt();
isBuffer = buffer.readBoolean();
size = buffer.readInt();
retainedSlice = buffer.readSlice(size);
retainedSlice.retain();
}
}
static class ErrorResponse extends NettyMessage {
private static final byte ID = 1;
Throwable cause;
InputChannelID receiverId;
public ErrorResponse() {
}
ErrorResponse(Throwable cause) {
this.cause = cause;
}
ErrorResponse(Throwable cause, InputChannelID receiverId) {
this.cause = cause;
this.receiverId = receiverId;
}
boolean isFatalError() {
return receiverId == null;
}
@Override
ByteBuf write(ByteBufAllocator allocator) throws IOException {
final ByteBuf result = allocateBuffer(allocator, ID);
try (ObjectOutputStream oos = new ObjectOutputStream(new ByteBufOutputStream(result))) {
oos.writeObject(cause);
if (receiverId != null) {
result.writeBoolean(true);
receiverId.writeTo(result);
} else {
result.writeBoolean(false);
}
// Update frame length...
result.setInt(0, result.readableBytes());
return result;
}
catch (Throwable t) {
result.release();
if (t instanceof IOException) {
throw (IOException) t;
} else {
throw new IOException(t);
}
}
}
@Override
void readFrom(ByteBuf buffer) throws Exception {
try (ObjectInputStream ois = new ObjectInputStream(new ByteBufInputStream(buffer))) {
Object obj = ois.readObject();
if (!(obj instanceof Throwable)) {
throw new ClassCastException("Read object expected to be of type Throwable, " +
"actual type is " + obj.getClass() + ".");
} else {
cause = (Throwable) obj;
if (buffer.readBoolean()) {
receiverId = InputChannelID.fromByteBuf(buffer);
}
}
}
}
}
// ------------------------------------------------------------------------
// Client requests
// ------------------------------------------------------------------------
static class PartitionRequest extends NettyMessage {
final static byte ID = 2;
ResultPartitionID partitionId;
int queueIndex;
InputChannelID receiverId;
public PartitionRequest() {
}
PartitionRequest(ResultPartitionID partitionId, int queueIndex, InputChannelID receiverId) {
this.partitionId = partitionId;
this.queueIndex = queueIndex;
this.receiverId = receiverId;
}
@Override
ByteBuf write(ByteBufAllocator allocator) throws IOException {
ByteBuf result = null;
try {
result = allocateBuffer(allocator, ID, 16 + 16 + 4 + 16);
partitionId.getPartitionId().writeTo(result);
partitionId.getProducerId().writeTo(result);
result.writeInt(queueIndex);
receiverId.writeTo(result);
return result;
}
catch (Throwable t) {
if (result != null) {
result.release();
}
throw new IOException(t);
}
}
@Override
public void readFrom(ByteBuf buffer) {
partitionId = new ResultPartitionID(IntermediateResultPartitionID.fromByteBuf(buffer), ExecutionAttemptID.fromByteBuf(buffer));
queueIndex = buffer.readInt();
receiverId = InputChannelID.fromByteBuf(buffer);
}
@Override
public String toString() {
return String.format("PartitionRequest(%s:%d)", partitionId, queueIndex);
}
}
static class TaskEventRequest extends NettyMessage {
final static byte ID = 3;
TaskEvent event;
InputChannelID receiverId;
ResultPartitionID partitionId;
public TaskEventRequest() {
}
TaskEventRequest(TaskEvent event, ResultPartitionID partitionId, InputChannelID receiverId) {
this.event = event;
this.receiverId = receiverId;
this.partitionId = partitionId;
}
@Override
ByteBuf write(ByteBufAllocator allocator) throws IOException {
ByteBuf result = null;
try {
// TODO Directly serialize to Netty's buffer
ByteBuffer serializedEvent = EventSerializer.toSerializedEvent(event);
result = allocateBuffer(allocator, ID, 4 + serializedEvent.remaining() + 16 + 16 + 16);
result.writeInt(serializedEvent.remaining());
result.writeBytes(serializedEvent);
partitionId.getPartitionId().writeTo(result);
partitionId.getProducerId().writeTo(result);
receiverId.writeTo(result);
return result;
}
catch (Throwable t) {
if (result != null) {
result.release();
}
throw new IOException(t);
}
}
@Override
public void readFrom(ByteBuf buffer) throws IOException {
// TODO Directly deserialize fromNetty's buffer
int length = buffer.readInt();
ByteBuffer serializedEvent = ByteBuffer.allocate(length);
buffer.readBytes(serializedEvent);
serializedEvent.flip();
event = (TaskEvent) EventSerializer.fromSerializedEvent(serializedEvent, getClass().getClassLoader());
partitionId = new ResultPartitionID(IntermediateResultPartitionID.fromByteBuf(buffer), ExecutionAttemptID.fromByteBuf(buffer));
receiverId = InputChannelID.fromByteBuf(buffer);
}
}
/**
* Cancels the partition request of the {@link InputChannel} identified by
* {@link InputChannelID}.
*
* <p> There is a 1:1 mapping between the input channel and partition per physical channel.
* Therefore, the {@link InputChannelID} instance is enough to identify which request to cancel.
*/
static class CancelPartitionRequest extends NettyMessage {
final static byte ID = 4;
InputChannelID receiverId;
public CancelPartitionRequest() {
}
public CancelPartitionRequest(InputChannelID receiverId) {
this.receiverId = receiverId;
}
@Override
ByteBuf write(ByteBufAllocator allocator) throws Exception {
ByteBuf result = null;
try {
result = allocateBuffer(allocator, ID, 16);
receiverId.writeTo(result);
}
catch (Throwable t) {
if (result != null) {
result.release();
}
throw new IOException(t);
}
return result;
}
@Override
void readFrom(ByteBuf buffer) throws Exception {
receiverId = InputChannelID.fromByteBuf(buffer);
}
}
static class CloseRequest extends NettyMessage {
private static final byte ID = 5;
public CloseRequest() {
}
@Override
ByteBuf write(ByteBufAllocator allocator) throws Exception {
return allocateBuffer(allocator, ID, 0);
}
@Override
void readFrom(ByteBuf buffer) throws Exception {
}
}
}