package com.dianping.pigeon.remoting.common.codec.thrift.annotation;
import com.facebook.swift.codec.ThriftCodec;
import com.facebook.swift.codec.ThriftCodecManager;
import com.facebook.swift.codec.internal.TProtocolReader;
import com.facebook.swift.codec.internal.TProtocolWriter;
import com.facebook.swift.codec.metadata.ThriftFieldMetadata;
import com.facebook.swift.codec.metadata.ThriftType;
import com.google.common.base.Defaults;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Primitives;
import com.google.common.reflect.TypeToken;
import org.apache.thrift.TApplicationException;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.protocol.TProtocolException;
import javax.annotation.concurrent.ThreadSafe;
import java.lang.reflect.Method;
import java.lang.reflect.Type;
import java.util.Map;
/**
 * Reads the arguments of, and writes the result struct for, a single Thrift
 * service method described by a {@code ThriftMethodMetadata}.
 *
 * <p>All state is assigned once in the constructor and held in final fields.
 *
 * @author qi.yin
 * 2016/05/23 6:07 PM.
 */
@ThreadSafe
public class ThriftMethodProcessor {

    private final String name;

    private final String serviceName;

    private final String qualifiedName;

    private final Object service;

    private final Method method;

    private final String resultStructName;

    private final boolean oneway;

    private final ImmutableList<ThriftFieldMetadata> parameters;

    /** Codec for each declared parameter, keyed by Thrift field id. */
    private final Map<Short, ThriftCodec<?>> parameterCodecs;

    /** Thrift field id of each parameter mapped to its index in the Java argument array. */
    private final Map<Short, Short> thriftParameterIdToJavaArgumentListPositionMap;

    private final ThriftCodec<Object> successCodec;

    /** Processors for the declared exceptions, keyed by the exception's raw Java class. */
    private final Map<Class<?>, ExceptionProcessor> exceptionCodecs;

    /**
     * Builds the codec lookup tables for one service method.
     *
     * @param service        the service instance the method belongs to
     * @param serviceName    the service name, used as the prefix of the qualified method name
     * @param methodMetadata metadata supplying the method name, parameters, declared
     *                       exceptions and return type
     * @param codecManager   source of the codecs for every parameter, exception and the return type
     */
    public ThriftMethodProcessor(
            Object service,
            String serviceName,
            ThriftMethodMetadata methodMetadata,
            ThriftCodecManager codecManager
    ) {
        this.service = service;
        this.serviceName = serviceName;

        name = methodMetadata.getName();
        qualifiedName = serviceName + "." + name;
        resultStructName = name + "_result";

        method = methodMetadata.getMethod();
        oneway = methodMetadata.getOneway();

        parameters = ImmutableList.copyOf(methodMetadata.getParameters());

        // A single pass records both the codec and the Java argument position of each parameter.
        ImmutableMap.Builder<Short, ThriftCodec<?>> codecsById = ImmutableMap.builder();
        ImmutableMap.Builder<Short, Short> positionsById = ImmutableMap.builder();
        short javaArgumentPosition = 0;
        for (ThriftFieldMetadata fieldMetadata : parameters) {
            short fieldId = fieldMetadata.getId();
            codecsById.put(fieldId, codecManager.getCodec(fieldMetadata.getThriftType()));
            positionsById.put(fieldId, javaArgumentPosition++);
        }
        parameterCodecs = codecsById.build();
        thriftParameterIdToJavaArgumentListPositionMap = positionsById.build();

        ImmutableMap.Builder<Class<?>, ExceptionProcessor> exceptions = ImmutableMap.builder();
        for (Map.Entry<Short, ThriftType> entry : methodMetadata.getExceptions().entrySet()) {
            ThriftType exceptionType = entry.getValue();
            Class<?> rawType = TypeToken.of(exceptionType.getJavaType()).getRawType();
            exceptions.put(rawType, new ExceptionProcessor(entry.getKey(), codecManager.getCodec(exceptionType)));
        }
        exceptionCodecs = exceptions.build();

        // The codec is only ever handed values of the method's own return type,
        // so widening its type parameter to Object is safe in practice.
        @SuppressWarnings("unchecked")
        ThriftCodec<Object> returnCodec = (ThriftCodec<Object>) codecManager.getCodec(methodMetadata.getReturnType());
        successCodec = returnCodec;
    }

    /**
     * @return the method name taken from the method metadata
     */
    public String getName() {
        return name;
    }

    /**
     * @return the runtime class of the service instance
     */
    public Class<?> getServiceClass() {
        return service.getClass();
    }

    /**
     * @return the service name given at construction
     */
    public String getServiceName() {
        return serviceName;
    }

    /**
     * @return {@code serviceName + "." + name}
     */
    public String getQualifiedName() {
        return qualifiedName;
    }

    /**
     * Reads the argument struct of a call from {@code in}.
     *
     * <p>Fields whose id matches no declared parameter are skipped. Each known field is
     * decoded with its codec and stored at the parameter's Java argument position.
     * Arguments that are still {@code null} afterwards are given a default value
     * (see {@link #fillMissingArguments(Object[])}).
     *
     * @param in the protocol to read from
     * @return the argument array, sized to the Java method's parameter count
     * @throws TApplicationException with type {@code PROTOCOL_ERROR} when the protocol
     *                               reports a {@link TProtocolException}; the original
     *                               exception is attached as the cause
     * @throws Exception             any other failure raised while reading
     */
    public Object[] readArguments(TProtocol in)
            throws Exception {
        try {
            Object[] args = new Object[method.getParameterTypes().length];
            TProtocolReader reader = new TProtocolReader(in);

            reader.readStructBegin();
            while (reader.nextField()) {
                short fieldId = reader.getFieldId();
                ThriftCodec<?> codec = parameterCodecs.get(fieldId);
                if (codec == null) {
                    // unknown field
                    reader.skipFieldData();
                } else {
                    args[thriftParameterIdToJavaArgumentListPositionMap.get(fieldId)] = reader.readField(codec);
                }
            }
            reader.readStructEnd();

            fillMissingArguments(args);
            return args;
        } catch (TProtocolException e) {
            TApplicationException failure =
                    new TApplicationException(TApplicationException.PROTOCOL_ERROR, e.getMessage());
            // Keep the original stack trace available to whoever handles the failure.
            failure.initCause(e);
            throw failure;
        }
    }

    /**
     * Replaces every {@code null} slot in {@code args} with the default value of the
     * corresponding parameter type. Wrapper classes are unwrapped to their primitive
     * first, so they are looked up under the primitive type. Parameters whose Java
     * type is not a plain {@link Class} are left untouched.
     */
    private void fillMissingArguments(Object[] args) {
        int position = 0;
        for (ThriftFieldMetadata parameter : parameters) {
            if (args[position] == null) {
                Type javaType = parameter.getThriftType().getJavaType();
                if (javaType instanceof Class) {
                    Class<?> primitiveOrSelf = Primitives.unwrap((Class<?>) javaType);
                    args[position] = Defaults.defaultValue(primitiveOrSelf);
                }
            }
            position++;
        }
    }

    /**
     * Writes either the success result or a declared exception as the method's result struct.
     *
     * @param out         the protocol to write to
     * @param result      the return value, or the thrown exception when {@code isException} is true
     * @param isException whether {@code result} is an exception
     * @throws Exception any failure raised while writing
     */
    public <T> void writeResponse(TProtocol out, T result, boolean isException) throws Exception {
        if (isException) {
            writeExceptionResponse(out, result);
        } else {
            writeResponse(out, "success", (short) 0, successCodec, result);
        }
    }

    /**
     * Tells whether {@code exception} is one of the exceptions declared for this method.
     * The lookup uses the exact runtime class of the argument.
     *
     * @param exception the object to test; {@code null} yields {@code false}
     * @return {@code true} when a codec is registered for the argument's class
     */
    public boolean isUserException(Object exception) {
        return exception != null && exceptionCodecs.containsKey(exception.getClass());
    }

    /**
     * Writes {@code exception} as the result struct when its exact class is a declared
     * exception of this method.
     *
     * <p>NOTE(review): nothing at all is written when the class is not registered;
     * callers are presumably expected to check {@link #isUserException(Object)} first.
     *
     * @param out       the protocol to write to
     * @param exception the exception to encode
     * @throws Exception any failure raised while writing
     */
    protected <T> void writeExceptionResponse(TProtocol out,
                                              T exception) throws Exception {
        ExceptionProcessor exceptionCodec = exceptionCodecs.get(exception.getClass());
        if (exceptionCodec != null) {
            writeResponse(out, "exception", exceptionCodec.getId(), exceptionCodec.getCodec(), exception);
        }
    }

    /**
     * Writes a result struct named {@code <method name>_result} containing a single field.
     *
     * @param out               the protocol to write to
     * @param responseFieldName the name of the field
     * @param responseFieldId   the Thrift id of the field
     * @param responseCodec     the codec used to encode {@code result}
     * @param result            the value to encode
     * @throws Exception any failure raised while writing
     */
    public <T> void writeResponse(TProtocol out,
                                  String responseFieldName,
                                  short responseFieldId,
                                  ThriftCodec<T> responseCodec,
                                  T result) throws Exception {
        TProtocolWriter writer = new TProtocolWriter(out);
        writer.writeStructBegin(resultStructName);
        writer.writeField(responseFieldName, responseFieldId, responseCodec, result);
        writer.writeStructEnd();
    }

    /** Pairs the Thrift field id of a declared exception with the codec that encodes it. */
    private static final class ExceptionProcessor {

        private final short id;

        private final ThriftCodec<Object> codec;

        // The codec is only used with instances of the exception class it was registered under.
        @SuppressWarnings("unchecked")
        private ExceptionProcessor(short id, ThriftCodec<?> coded) {
            this.id = id;
            this.codec = (ThriftCodec<Object>) coded;
        }

        public short getId() {
            return id;
        }

        public ThriftCodec<Object> getCodec() {
            return codec;
        }
    }
}