package com.dianping.pigeon.remoting.common.codec.thrift; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import org.apache.thrift.TApplicationException; import org.apache.thrift.protocol.TMessage; import org.apache.thrift.protocol.TMessageType; import org.apache.thrift.protocol.TProtocol; import com.dianping.pigeon.remoting.common.codec.thrift.annotation.ThriftClientMetadata; import com.dianping.pigeon.remoting.common.codec.thrift.annotation.ThriftMethodHandler; import com.dianping.pigeon.remoting.common.codec.thrift.annotation.ThriftMethodProcessor; import com.dianping.pigeon.remoting.common.codec.thrift.annotation.ThriftServerMetadata; import com.dianping.pigeon.remoting.common.domain.generic.GenericRequest; import com.dianping.pigeon.remoting.common.domain.generic.GenericResponse; import com.dianping.pigeon.remoting.common.domain.generic.StatusCode; import com.dianping.pigeon.remoting.common.domain.generic.ThriftMapper; import com.dianping.pigeon.remoting.common.domain.generic.thrift.Header; import com.dianping.pigeon.remoting.common.exception.SerializationException; import com.dianping.pigeon.util.ClassUtils; /** * @author qi.yin 2016/05/23 下午4:28. */ public class AnnotationThriftSerializer extends AbstractThriftSerializer { private ConcurrentMap<String, ThriftClientMetadata> clientMetadatas = new ConcurrentHashMap<String, ThriftClientMetadata>(); private ConcurrentMap<String, ThriftServerMetadata> serverMetadatas = new ConcurrentHashMap<String, ThriftServerMetadata>(); @Override protected void doDeserializeRequest(GenericRequest request, TProtocol protocol) throws Exception { // body TMessage message = protocol.readMessageBegin(); ThriftMethodProcessor methodProcessor = getMethodProcessor(request.getServiceInterface().getName(), message.name); if (methodProcessor == null) { throw new SerializationException("@ThriftMethod annotation is required for " + request.getServiceInterface().getName() + "#" + message.name); } Object[] parameters = methodProcessor.readArguments(protocol); request.setSeqId(message.seqid); request.setMethodName(message.name); request.setParameters(parameters); protocol.readMessageEnd(); } protected void doSerializeRequest(GenericRequest request, TProtocol protocol) throws Exception { ThriftMethodHandler methodHandler = getMethodHandler(request.getServiceInterface().getName(), request.getMethodName()); if (methodHandler == null) { throw new SerializationException("@ThriftMethod annotation is required for " + request.getServiceInterface().getName() + "#" + request.getMethodName()); } // body methodHandler.writeArguments(protocol, getSequenceId(), request.getParameters()); } public void doDeserializeResponse(GenericResponse response, GenericRequest request, TProtocol protocol, Header header) throws Exception { // body TMessage message = protocol.readMessageBegin(); ThriftMethodHandler methodHandler = getMethodHandler(request.getServiceInterface().getName(), message.name); response.setSeqId(message.seqid); // body if (message.type == TMessageType.REPLY) { Object result = null; try { result = methodHandler.readResponse(protocol); } catch (TApplicationException e) { if (e.getType() == TApplicationException.MISSING_RESULT) { result = null; } } response.setReturn(result); } else if (message.type == TMessageType.EXCEPTION) { TApplicationException exception = TApplicationException.read(protocol); ThriftMapper.mapException(header, response, exception.getMessage()); } protocol.readMessageEnd(); } protected void doSerializeResponse(GenericResponse response, TProtocol protocol, Header header, DynamicByteArrayOutputStream bos) throws Exception { ThriftMethodProcessor methodProcessor = getMethodProcessor(response.getServiceInterface().getName(), response.getMethodName()); TApplicationException applicationException = null; TMessage message; boolean isUserException = false; if (response.hasException()) { if (methodProcessor.isUserException(response.getReturn())) { header.responseInfo.setStatus(StatusCode.ApplicationException.getCode()); isUserException = true; } else { applicationException = new TApplicationException(((Throwable) response.getReturn()).getMessage()); } } if (applicationException != null) { message = new TMessage(response.getMethodName(), TMessageType.EXCEPTION, response.getSeqId()); } else { message = new TMessage(response.getMethodName(), TMessageType.REPLY, response.getSeqId()); } // header header.write(protocol); short headerLength = (short) (bos.size() - HEADER_FIELD_LENGTH); protocol.writeMessageBegin(message); switch (message.type) { case TMessageType.EXCEPTION: applicationException.write(protocol); break; case TMessageType.REPLY: methodProcessor.writeResponse(protocol, response.getReturn(), isUserException); break; } protocol.writeMessageEnd(); protocol.getTransport().flush(); int messageLength = bos.size(); try { bos.setWriteIndex(0); protocol.writeI16(headerLength); } finally { bos.setWriteIndex(messageLength); } } private ThriftMethodProcessor getMethodProcessor(String serviceName, String methodName) throws ClassNotFoundException { ThriftServerMetadata serverMetadata = serverMetadatas.get(serviceName); if (serverMetadata == null) { Class<?> serverType = ClassUtils.loadClass(serviceName); serverMetadata = new ThriftServerMetadata(serverType, serviceName); serverMetadatas.putIfAbsent(serviceName, serverMetadata); } return serverMetadata.getMethodProcessor(methodName); } private ThriftMethodHandler getMethodHandler(String serviceName, String methodName) throws ClassNotFoundException { ThriftClientMetadata clientMetadata = clientMetadatas.get(serviceName); if (clientMetadata == null) { Class<?> serverType = ClassUtils.loadClass(serviceName); clientMetadata = new ThriftClientMetadata(serverType, serviceName); clientMetadatas.putIfAbsent(serviceName, clientMetadata); } return clientMetadata.getMethodHandler(methodName); } }