/* * Copyright 2009-2016 Weibo, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.weibo.api.motan.codec; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.util.HashMap; import java.util.Map; import com.google.protobuf.CodedInputStream; import com.google.protobuf.CodedOutputStream; import com.weibo.api.motan.common.MotanConstants; import com.weibo.api.motan.common.URLParamType; import com.weibo.api.motan.core.extension.ExtensionLoader; import com.weibo.api.motan.core.extension.SpiMeta; import com.weibo.api.motan.exception.MotanErrorMsgConstant; import com.weibo.api.motan.exception.MotanFrameworkException; import com.weibo.api.motan.protocol.rpc.RpcProtocolVersion; import com.weibo.api.motan.rpc.DefaultRequest; import com.weibo.api.motan.rpc.DefaultResponse; import com.weibo.api.motan.rpc.Request; import com.weibo.api.motan.rpc.Response; import com.weibo.api.motan.transport.Channel; import com.weibo.api.motan.util.ByteUtil; import com.weibo.api.motan.util.ExceptionUtil; import com.weibo.api.motan.util.ReflectUtil; /** * protobuf2/3兼容codec,序列化时不允许attachments中有键或值为null * * @author zhouhaocheng * */ @SpiMeta(name = "protobuf") public class ProtobufCodec implements Codec { private static final short MAGIC = (short) 0xF0F0; private static final byte MASK = 0x07; @Override public byte[] encode(Channel channel, Object message) throws IOException { try { if (message instanceof Request) { return encodeRequest(channel, (Request) message); } else if (message instanceof Response) { return encodeResponse(channel, (Response) message); } } catch (Exception e) { if (ExceptionUtil.isMotanException(e)) { throw (RuntimeException) e; } else { throw new MotanFrameworkException("encode error: isResponse=" + (message instanceof Response), e, MotanErrorMsgConstant.FRAMEWORK_ENCODE_ERROR); } } throw new MotanFrameworkException("encode error: message type not support, " + message.getClass(), MotanErrorMsgConstant.FRAMEWORK_ENCODE_ERROR); } /** * request body 数据: * * <pre> * * body: * * byte[] data : * * serialize(interface_name, method_name, method_param_desc, method_param_value, attachments_size, attachments_value) * * method_param_desc: for_each (string.append(method_param_interface_name)) * * method_param_value: for_each (method_param_name, method_param_value) * * attachments_value: for_each (attachment_name, attachment_value) * * </pre> * * @param request * @return * @throws IOException */ private byte[] encodeRequest(Channel channel, Request request) throws IOException { ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); CodedOutputStream output = CodedOutputStream.newInstance(outputStream); output.writeStringNoTag(request.getInterfaceName()); output.writeStringNoTag(request.getMethodName()); output.writeStringNoTag(request.getParamtersDesc()); Serialization serialization = ExtensionLoader.getExtensionLoader(Serialization.class).getExtension( channel.getUrl().getParameter(URLParamType.serialize.getName(), URLParamType.serialize.getValue())); if (request.getArguments() != null && request.getArguments().length > 0) { for (Object obj : request.getArguments()) { output.writeByteArrayNoTag(serialization.serialize(obj)); } } if (request.getAttachments() == null || request.getAttachments().isEmpty()) { // empty attachments output.writeUInt32NoTag(0); } else { output.writeUInt32NoTag(request.getAttachments().size()); for (Map.Entry<String, String> entry : request.getAttachments().entrySet()) { //此处不允许attachement键值为null output.writeStringNoTag(entry.getKey()); output.writeStringNoTag(entry.getValue()); } } output.flush(); byte[] body = outputStream.toByteArray(); byte flag = MotanConstants.FLAG_REQUEST; return encode(body, flag, request.getRequestId()); } /** * response body 数据: * * <pre> * * body: * * byte[] : serialize (result) or serialize (exception) * * </pre> * * @param channel * @param value * @return * @throws IOException */ private byte[] encodeResponse(Channel channel, Response value) throws IOException { ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); CodedOutputStream output = CodedOutputStream.newInstance(outputStream); Serialization serialization = ExtensionLoader.getExtensionLoader(Serialization.class).getExtension( channel.getUrl().getParameter(URLParamType.serialize.getName(), URLParamType.serialize.getValue())); byte flag = 0; output.writeUInt64NoTag(value.getProcessTime()); if (value.getException() != null) { output.writeStringNoTag(value.getException().getClass().getName()); output.writeByteArrayNoTag(serialization.serialize(value.getException())); flag = MotanConstants.FLAG_RESPONSE_EXCEPTION; } else if (value.getValue() == null) { flag = MotanConstants.FLAG_RESPONSE_VOID; } else { output.writeStringNoTag(value.getValue().getClass().getName()); output.writeByteArrayNoTag(serialization.serialize(value.getValue())); flag = MotanConstants.FLAG_RESPONSE; } output.flush(); byte[] body = outputStream.toByteArray(); return encode(body, flag, value.getRequestId()); } @Override public Object decode(Channel channel, String remoteIp, byte[] data) throws IOException { if (data.length <= RpcProtocolVersion.VERSION_1.getHeaderLength()) { throw new MotanFrameworkException("decode error: format problem", MotanErrorMsgConstant.FRAMEWORK_DECODE_ERROR); } short type = ByteUtil.bytes2short(data, 0); if (type != MAGIC) { throw new MotanFrameworkException("decode error: magic error", MotanErrorMsgConstant.FRAMEWORK_DECODE_ERROR); } if (data[2] != RpcProtocolVersion.VERSION_1.getVersion()) { throw new MotanFrameworkException("decode error: version error", MotanErrorMsgConstant.FRAMEWORK_DECODE_ERROR); } int bodyLength = ByteUtil.bytes2int(data, 12); if (RpcProtocolVersion.VERSION_1.getHeaderLength() + bodyLength != data.length) { throw new MotanFrameworkException("decode error: content length error", MotanErrorMsgConstant.FRAMEWORK_DECODE_ERROR); } byte flag = data[3]; byte dataType = (byte) (flag & MASK); boolean isResponse = (dataType != MotanConstants.FLAG_REQUEST); CodedInputStream body = CodedInputStream.newInstance(data, RpcProtocolVersion.VERSION_1.getHeaderLength(), bodyLength); long requestId = ByteUtil.bytes2long(data, 4); Serialization serialization = ExtensionLoader.getExtensionLoader(Serialization.class).getExtension( channel.getUrl().getParameter(URLParamType.serialize.getName(), URLParamType.serialize.getValue())); try { if (isResponse) { // response return decodeResponse(body, dataType, requestId, serialization); } else { return decodeRequest(body, requestId, serialization); } } catch (ClassNotFoundException e) { throw new MotanFrameworkException( "decode " + (isResponse ? "response" : "request") + " error: class not found", e, MotanErrorMsgConstant.FRAMEWORK_DECODE_ERROR); } catch (Exception e) { if (ExceptionUtil.isMotanException(e)) { throw (RuntimeException) e; } else { throw new MotanFrameworkException("decode error: isResponse=" + isResponse, e, MotanErrorMsgConstant.FRAMEWORK_DECODE_ERROR); } } } private Object decodeRequest(CodedInputStream input, long requestId, Serialization serialization) throws IOException, ClassNotFoundException { String interfaceName = input.readString(); String methodName = input.readString(); String paramtersDesc = input.readString(); DefaultRequest rpcRequest = new DefaultRequest(); rpcRequest.setRequestId(requestId); rpcRequest.setInterfaceName(interfaceName); rpcRequest.setMethodName(methodName); rpcRequest.setParamtersDesc(paramtersDesc); rpcRequest.setArguments(decodeRequestParameter(input, paramtersDesc, serialization)); rpcRequest.setAttachments(decodeRequestAttachments(input)); return rpcRequest; } private Object[] decodeRequestParameter(CodedInputStream input, String parameterDesc, Serialization serialization) throws IOException, ClassNotFoundException { if (parameterDesc == null || parameterDesc.equals("")) { return null; } Class<?>[] classTypes = ReflectUtil.forNames(parameterDesc); Object[] paramObjs = new Object[classTypes.length]; for (int i = 0; i < classTypes.length; i++) { paramObjs[i] = serialization.deserialize(input.readByteArray(), classTypes[i]); } return paramObjs; } private Map<String, String> decodeRequestAttachments(CodedInputStream input) throws IOException, ClassNotFoundException { int size = input.readUInt32(); if (size <= 0) { return null; } Map<String, String> attachments = new HashMap<String, String>(); for (int i = 0; i < size; i++) { attachments.put(input.readString(), input.readString()); } return attachments; } private Object decodeResponse(CodedInputStream input, byte dataType, long requestId, Serialization serialization) throws IOException, ClassNotFoundException { long processTime = input.readInt64(); DefaultResponse response = new DefaultResponse(); response.setRequestId(requestId); response.setProcessTime(processTime); if (dataType == MotanConstants.FLAG_RESPONSE_VOID) { return response; } String className = input.readString(); Class<?> clz = ReflectUtil.forName(className); Object result = serialization.deserialize(input.readByteArray(), clz); if (dataType == MotanConstants.FLAG_RESPONSE) { response.setValue(result); } else if (dataType == MotanConstants.FLAG_RESPONSE_EXCEPTION) { response.setException((Exception) result); } else { throw new MotanFrameworkException("decode error: response dataType not support " + dataType, MotanErrorMsgConstant.FRAMEWORK_DECODE_ERROR); } response.setRequestId(requestId); return response; } /** * 数据协议: * * <pre> * * header: 16个字节 * * 0-15 bit : magic * 16-23 bit : version * 24-31 bit : extend flag , 其中: 29-30 bit: event 可支持4种event,比如normal, exception等, 31 bit : 0 is request , 1 is response * 32-95 bit : request id * 96-127 bit : body content length * * </pre> * * @param body * @param flag * @param requestId * @return * @throws IOException */ private byte[] encode(byte[] body, byte flag, long requestId) throws IOException { byte[] header = new byte[RpcProtocolVersion.VERSION_1.getHeaderLength()]; int offset = 0; // 0 - 15 bit : magic ByteUtil.short2bytes(MAGIC, header, offset); offset += 2; // 16 - 23 bit : version header[offset++] = RpcProtocolVersion.VERSION_1.getVersion(); // 24 - 31 bit : extend flag header[offset++] = flag; // 32 - 95 bit : requestId ByteUtil.long2bytes(requestId, header, offset); offset += 8; // 96 - 127 bit : body content length ByteUtil.int2bytes(body.length, header, offset); byte[] data = new byte[header.length + body.length]; System.arraycopy(header, 0, data, 0, header.length); System.arraycopy(body, 0, data, header.length, body.length); return data; } }