package com.dianping.pigeon.remoting.common.codec.thrift.annotation;

import com.facebook.swift.codec.ThriftCodec;
import com.facebook.swift.codec.ThriftCodecManager;
import com.facebook.swift.codec.internal.TProtocolReader;
import com.facebook.swift.codec.internal.TProtocolWriter;
import com.facebook.swift.codec.metadata.ThriftFieldMetadata;
import com.facebook.swift.codec.metadata.ThriftType;
import com.google.common.base.Defaults;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Primitives;
import com.google.common.reflect.TypeToken;
import org.apache.thrift.TApplicationException;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.protocol.TProtocolException;

import javax.annotation.concurrent.ThreadSafe;
import java.lang.reflect.Method;
import java.lang.reflect.Type;
import java.util.Map;

/**
 * Processor for a single Thrift service method. It decodes the method's arguments from an
 * incoming protocol. It encodes the method's {@code <name>_result} struct onto an outgoing
 * protocol, carrying either the success value or a declared exception.
 *
 * <p>All fields are final and assigned in the constructor. The collections are immutable Guava
 * copies. Each read/write call creates its own {@link TProtocolReader}/{@link TProtocolWriter}.
 *
 * @author qi.yin
 * 2016/05/23 6:07 PM.
 */
@ThreadSafe
public class ThriftMethodProcessor {

    /** Thrift method name as reported by the method metadata. */
    private final String name;
    private final String serviceName;
    /** {@code serviceName + "." + name}. */
    private final String qualifiedName;
    /** Service implementation instance this method belongs to. */
    private final Object service;
    /** Java method backing this Thrift method; used to size the argument array. */
    private final Method method;
    /** Name of the result struct written by {@link #writeResponse}: {@code name + "_result"}. */
    private final String resultStructName;
    private final boolean oneway;
    /** Declared parameters, in Java argument order. */
    private final ImmutableList<ThriftFieldMetadata> parameters;
    /** Thrift field id -> codec for that parameter. */
    private final Map<Short, ThriftCodec<?>> parameterCodecs;
    /** Thrift field id -> zero-based index into the Java argument array. */
    private final Map<Short, Short> thriftParameterIdToJavaArgumentListPositionMap;
    /** Codec for the return value, written as field 0 ("success"). */
    private final ThriftCodec<Object> successCodec;
    /** Declared exception class -> field id and codec used to write it. */
    private final Map<Class<?>, ExceptionProcessor> exceptionCodecs;

    /**
     * Builds the codec tables for one service method.
     *
     * @param service        the service implementation instance
     * @param serviceName    service name, used to build {@link #getQualifiedName()}
     * @param methodMetadata metadata describing the method's parameters, return type and exceptions
     * @param codecManager   source of the codecs for every parameter, return and exception type
     */
    // codecManager.getCodec returns ThriftCodec<?>; values are handed to the codecs as Object.
    @SuppressWarnings("unchecked")
    public ThriftMethodProcessor(
            Object service,
            String serviceName,
            ThriftMethodMetadata methodMetadata,
            ThriftCodecManager codecManager) {
        this.service = service;
        this.serviceName = serviceName;

        name = methodMetadata.getName();
        qualifiedName = serviceName + "." + name;
        resultStructName = name + "_result";

        method = methodMetadata.getMethod();
        oneway = methodMetadata.getOneway();

        parameters = ImmutableList.copyOf(methodMetadata.getParameters());

        // For each parameter, record its codec and the Java argument slot it fills. Slots are
        // assigned in declaration order.
        ImmutableMap.Builder<Short, ThriftCodec<?>> codecBuilder = ImmutableMap.builder();
        ImmutableMap.Builder<Short, Short> positionBuilder = ImmutableMap.builder();
        short javaArgumentPosition = 0;
        for (ThriftFieldMetadata fieldMetadata : parameters) {
            codecBuilder.put(fieldMetadata.getId(), codecManager.getCodec(fieldMetadata.getThriftType()));
            positionBuilder.put(fieldMetadata.getId(), javaArgumentPosition++);
        }
        parameterCodecs = codecBuilder.build();
        thriftParameterIdToJavaArgumentListPositionMap = positionBuilder.build();

        ImmutableMap.Builder<Class<?>, ExceptionProcessor> exceptions = ImmutableMap.builder();
        for (Map.Entry<Short, ThriftType> entry : methodMetadata.getExceptions().entrySet()) {
            Class<?> type = TypeToken.of(entry.getValue().getJavaType()).getRawType();
            ExceptionProcessor processor =
                    new ExceptionProcessor(entry.getKey(), codecManager.getCodec(entry.getValue()));
            exceptions.put(type, processor);
        }
        exceptionCodecs = exceptions.build();

        successCodec = (ThriftCodec<Object>) codecManager.getCodec(methodMetadata.getReturnType());
    }

    public String getName() {
        return name;
    }

    public Class<?> getServiceClass() {
        return service.getClass();
    }

    public String getServiceName() {
        return serviceName;
    }

    public String getQualifiedName() {
        return qualifiedName;
    }

    /**
     * Reads the method's argument struct from {@code in}.
     *
     * <p>Unknown field ids are skipped. A parameter that was not present in the struct is filled
     * with Guava's {@link Defaults#defaultValue} for its type. Wrapper classes are unwrapped first,
     * so a missing boxed numeric/boolean parameter becomes its primitive default rather than null.
     * A parameter whose Java type is not a {@link Class} (e.g. a parameterized type) stays null.
     *
     * @param in protocol positioned at the start of the argument struct
     * @return arguments in Java parameter order
     * @throws TApplicationException with {@code PROTOCOL_ERROR} if the struct cannot be decoded;
     *                               the underlying {@link TProtocolException} is kept as its cause
     * @throws Exception             anything else thrown by the reader or codecs
     */
    public Object[] readArguments(TProtocol in) throws Exception {
        try {
            Object[] args = new Object[method.getParameterTypes().length];
            TProtocolReader reader = new TProtocolReader(in);

            reader.readStructBegin();
            while (reader.nextField()) {
                short fieldId = reader.getFieldId();
                ThriftCodec<?> codec = parameterCodecs.get(fieldId);
                if (codec == null) {
                    // Field id not declared for this method: skip its data.
                    reader.skipFieldData();
                } else {
                    args[thriftParameterIdToJavaArgumentListPositionMap.get(fieldId)] = reader.readField(codec);
                }
            }
            reader.readStructEnd();

            int argumentPosition = 0;
            for (ThriftFieldMetadata argument : parameters) {
                if (args[argumentPosition] == null) {
                    Type argumentType = argument.getThriftType().getJavaType();
                    if (argumentType instanceof Class) {
                        Class<?> argumentClass = Primitives.unwrap((Class<?>) argumentType);
                        args[argumentPosition] = Defaults.defaultValue(argumentClass);
                    }
                }
                argumentPosition++;
            }

            return args;
        } catch (TProtocolException e) {
            // TApplicationException has no (type, message, cause) constructor, so chain the cause
            // explicitly to keep the original stack trace.
            TApplicationException applicationException =
                    new TApplicationException(TApplicationException.PROTOCOL_ERROR, e.getMessage());
            applicationException.initCause(e);
            throw applicationException;
        }
    }

    /**
     * Writes the result struct for this method.
     *
     * @param out         protocol to write to
     * @param result      the return value, or the thrown exception when {@code isException} is true
     * @param isException whether {@code result} is an exception to be written as a declared exception
     * @throws Exception if the writer or codec fails
     */
    public <T> void writeResponse(TProtocol out, T result, boolean isException) throws Exception {
        if (isException) {
            writeExceptionResponse(out, result);
        } else {
            writeResponse(out, "success", (short) 0, successCodec, result);
        }
    }

    /**
     * Returns whether {@code exception}'s exact runtime class is one of the exceptions declared by
     * this method. Subclasses of a declared exception are not matched. Returns false for null.
     */
    public boolean isUserException(Object exception) {
        return exception != null && exceptionCodecs.containsKey(exception.getClass());
    }

    /**
     * Writes {@code exception} as its declared exception field of the result struct.
     *
     * <p>NOTE(review): if the exception's exact class is not declared, nothing at all is written.
     * Callers presumably check {@link #isUserException} first; confirm that no caller relies on
     * this method for undeclared exceptions.
     */
    protected <T> void writeExceptionResponse(TProtocol out, T exception) throws Exception {
        ExceptionProcessor exceptionCodec = exceptionCodecs.get(exception.getClass());
        if (exceptionCodec != null) {
            writeResponse(out, "exception", exceptionCodec.getId(), exceptionCodec.getCodec(), exception);
        }
    }

    /**
     * Writes the result struct ({@code <name>_result}) containing a single field.
     *
     * @param out               protocol to write to
     * @param responseFieldName field name to write
     * @param responseFieldId   field id to write
     * @param responseCodec     codec used to encode {@code result}
     * @param result            value of the field
     * @throws Exception if the writer or codec fails
     */
    public <T> void writeResponse(TProtocol out,
                                  String responseFieldName,
                                  short responseFieldId,
                                  ThriftCodec<T> responseCodec,
                                  T result) throws Exception {
        TProtocolWriter writer = new TProtocolWriter(out);
        writer.writeStructBegin(resultStructName);
        writer.writeField(responseFieldName, responseFieldId, responseCodec, result);
        writer.writeStructEnd();
    }

    /** Field id and codec for one declared exception type. */
    private static final class ExceptionProcessor {
        private final short id;
        private final ThriftCodec<Object> codec;

        // The exception instance is passed to the codec as Object.
        @SuppressWarnings("unchecked")
        private ExceptionProcessor(short id, ThriftCodec<?> codec) {
            this.id = id;
            this.codec = (ThriftCodec<Object>) codec;
        }

        public short getId() {
            return id;
        }

        public ThriftCodec<Object> getCodec() {
            return codec;
        }
    }
}