package com.dianping.pigeon.remoting.common.codec.thrift; import com.dianping.pigeon.remoting.common.domain.generic.GenericRequest; import com.dianping.pigeon.remoting.common.domain.generic.GenericResponse; import com.dianping.pigeon.remoting.common.domain.generic.thrift.Header; import com.dianping.pigeon.remoting.common.domain.generic.ThriftMapper; import com.dianping.pigeon.remoting.common.domain.generic.StatusCode; import com.dianping.pigeon.remoting.common.exception.SerializationException; import com.dianping.pigeon.util.ClassUtils; import com.dianping.pigeon.util.ThriftUtils; import org.apache.commons.lang.StringUtils; import org.apache.thrift.TApplicationException; import org.apache.thrift.TBase; import org.apache.thrift.TException; import org.apache.thrift.TFieldIdEnum; import org.apache.thrift.protocol.TMessage; import org.apache.thrift.protocol.TMessageType; import org.apache.thrift.protocol.TProtocol; import java.lang.reflect.Field; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.List; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; /** * @author qi.yin * 2016/05/16 下午3:10. */ public class IDLThriftSerializer extends AbstractThriftSerializer { private static ConcurrentMap<String, Class<?>> cachedClass = new ConcurrentHashMap<String, Class<?>>(); private static final String BYTE_ARRAY_CLASS_NAME = "[B"; @Override protected void doDeserializeRequest(GenericRequest request, TProtocol protocol) throws Exception { TMessage message = protocol.readMessageBegin(); if (message.type == TMessageType.CALL) { String argsClassName = ThriftClassNameGenerator.generateArgsClassName( request.getServiceInterface().getName(), message.name); if (StringUtils.isEmpty(argsClassName)) { throw new SerializationException("Deserialize thrift argsClassName is empty."); } Class clazz = cachedClass.get(argsClassName); if (clazz == null) { try { clazz = ClassUtils.loadClass(argsClassName); cachedClass.putIfAbsent(argsClassName, clazz); } catch (ClassNotFoundException e) { throw new SerializationException("Deserialize class" + argsClassName + " load failed.", e); } } TBase args; try { args = (TBase) clazz.newInstance(); } catch (InstantiationException e) { throw new SerializationException("Deserialize class" + argsClassName + " new instance failed.", e); } catch (IllegalAccessException e) { throw new SerializationException("Deserialize class" + argsClassName + " new instance failed.", e); } args.read(protocol); protocol.readMessageEnd(); List<Object> parameters = new ArrayList<Object>(); List<Class<?>> parameterTypes = new ArrayList<Class<?>>(); int index = 1; while (true) { TFieldIdEnum fieldIdEnum = args.fieldForId(index++); if (fieldIdEnum == null) { break; } String fieldName = fieldIdEnum.getFieldName(); String getMethodName = ThriftUtils.generateGetMethodName(fieldName); Method getMethod; try { getMethod = clazz.getMethod(getMethodName); } catch (NoSuchMethodException e) { try { getMethod = clazz.getMethod(ThriftUtils.generateBoolMethodName(fieldName)); } catch (NoSuchMethodException e0) { throw new SerializationException("Deserialize failed.", e); } } Object getResult; try { getResult = getMethod.invoke(args); } catch (IllegalAccessException e) { throw new SerializationException("Deserialize failed.", e); } catch (InvocationTargetException e) { throw new SerializationException("Deserialize failed.", e); } if (BYTE_ARRAY_CLASS_NAME.equals(getMethod.getReturnType().getName())) { parameterTypes.add(ByteBuffer.class); parameters.add(ByteBuffer.wrap((byte[]) getResult)); } else { parameterTypes.add(getMethod.getReturnType()); parameters.add(getResult); } } request.setSeqId(message.seqid); request.setMethodName(message.name); request.setParameters(parameters.toArray()); request.setParameterTypes(parameterTypes.toArray(new Class[parameterTypes.size()])); protocol.readMessageEnd(); } } protected void doSerializeRequest(GenericRequest request, TProtocol protocol) throws Exception { TMessage message = new TMessage( request.getMethodName(), TMessageType.CALL, getSequenceId()); String argsClassName = ThriftClassNameGenerator.generateArgsClassName( request.getServiceInterface().getName(), request.getMethodName()); if (StringUtils.isEmpty(argsClassName)) { throw new SerializationException("Serialize thrift argsClassName is empty."); } Class clazz = cachedClass.get(argsClassName); if (clazz == null) { try { clazz = ClassUtils.loadClass(argsClassName); cachedClass.putIfAbsent(argsClassName, clazz); } catch (ClassNotFoundException e) { throw new SerializationException("Serialize class" + argsClassName + " load failed.", e); } } TBase args; try { args = (TBase) clazz.newInstance(); } catch (InstantiationException e) { throw new SerializationException("Serialize class" + argsClassName + " new instance failed.", e); } catch (IllegalAccessException e) { throw new SerializationException("Serialize class" + argsClassName + " new instance failed.", e); } if (request.getParameters() != null) { for (int i = 0; i < request.getParameters().length; i++) { Object paramObj = request.getParameters()[i]; if (paramObj == null) { continue; } TFieldIdEnum field = args.fieldForId(i + 1); String setMethodName = ThriftUtils.generateSetMethodName(field.getFieldName()); Method method; try { method = clazz.getMethod(setMethodName, request.getParameterTypes()[i]); } catch (NoSuchMethodException e) { throw new SerializationException("Serialize class" + setMethodName + " new instance failed.", e); } try { method.invoke(args, paramObj); } catch (IllegalAccessException e) { throw new SerializationException("Serialize set args failed.", e); } catch (InvocationTargetException e) { throw new SerializationException("Serialize set args failed.", e); } } } //body protocol.writeMessageBegin(message); args.write(protocol); protocol.writeMessageEnd(); protocol.getTransport().flush(); } protected void doDeserializeResponse(GenericResponse response, GenericRequest request, TProtocol protocol, Header header) throws Exception { // body TMessage message = protocol.readMessageBegin(); response.setSeqId(message.seqid); if (message.type == TMessageType.REPLY) { String resultClassName = ThriftClassNameGenerator.generateResultClassName( request.getServiceInterface().getName(), message.name); if (StringUtils.isEmpty(resultClassName)) { throw new SerializationException("Deserialize thrift resultClassName is empty."); } Class<?> clazz = cachedClass.get(resultClassName); if (clazz == null) { try { clazz = ClassUtils.loadClass(resultClassName); cachedClass.putIfAbsent(resultClassName, clazz); } catch (ClassNotFoundException e) { throw new SerializationException("Deserialize failed.", e); } } TBase<?, ? extends TFieldIdEnum> result; try { result = (TBase<?, ?>) clazz.newInstance(); } catch (InstantiationException e) { throw new SerializationException("Deserialize failed.", e); } catch (IllegalAccessException e) { throw new SerializationException("Deserialize failed.", e); } try { result.read(protocol); } catch (TException e) { throw new SerializationException("Deserialize failed.", e); } Object realResult = null; int index = 0; while (true) { TFieldIdEnum fieldIdEnum = result.fieldForId(index++); if (fieldIdEnum == null) { if (index == 1) { continue; } break; } Field field; try { field = clazz.getDeclaredField(fieldIdEnum.getFieldName()); field.setAccessible(true); } catch (NoSuchFieldException e) { throw new SerializationException("Deserialize failed.", e); } try { realResult = field.get(result); } catch (IllegalAccessException e) { throw new SerializationException("Deserialize failed.", e); } if (realResult != null) { break; } } response.setReturn(realResult); } else if (message.type == TMessageType.EXCEPTION) { TApplicationException exception = TApplicationException.read(protocol); ThriftMapper.mapException(header, response, exception.getMessage()); } protocol.readMessageEnd(); } protected void doSerializeResponse(GenericResponse response, TProtocol protocol, Header header, DynamicByteArrayOutputStream bos) throws Exception { String resultClassName = ThriftClassNameGenerator.generateResultClassName( response.getServiceInterface().getName(), response.getMethodName()); if (StringUtils.isEmpty(resultClassName)) { throw new SerializationException("Serialize thrift resultClassName is empty."); } Class clazz = cachedClass.get(resultClassName); if (clazz == null) { try { clazz = ClassUtils.loadClass(resultClassName); cachedClass.putIfAbsent(resultClassName, clazz); } catch (ClassNotFoundException e) { throw new SerializationException("Serialize failed.", e); } } TBase resultObj; try { resultObj = (TBase) clazz.newInstance(); } catch (InstantiationException e) { throw new SerializationException("Serialize failed.", e); } catch (IllegalAccessException e) { throw new SerializationException("Serialize failed.", e); } TApplicationException applicationException = null; TMessage message; if (response.hasException()) { Throwable throwable = (Throwable) response.getReturn(); int index = 1; boolean found = false; while (true) { TFieldIdEnum fieldIdEnum = resultObj.fieldForId(index++); if (fieldIdEnum == null) { break; } String fieldName = fieldIdEnum.getFieldName(); String getMethodName = ThriftUtils.generateGetMethodName(fieldName); String setMethodName = ThriftUtils.generateSetMethodName(fieldName); Method getMethod; Method setMethod; try { try { getMethod = clazz.getMethod(getMethodName); } catch (NoSuchMethodException e) { try { getMethod = clazz.getMethod(ThriftUtils.generateBoolMethodName(fieldName)); } catch (NoSuchMethodException e0) { throw new SerializationException("Serialize failed.", e); } } if (getMethod.getReturnType().equals(throwable.getClass())) { header.responseInfo.setStatus(StatusCode.ApplicationException.getCode()); found = true; setMethod = clazz.getMethod(setMethodName, throwable.getClass()); setMethod.invoke(resultObj, throwable); } } catch (NoSuchMethodException e) { throw new SerializationException("Serialize failed.", e); } catch (InvocationTargetException e) { throw new SerializationException("Serialize failed.", e); } catch (IllegalAccessException e) { throw new SerializationException("Serialize failed.", e); } } if (!found) { applicationException = new TApplicationException(throwable.getMessage()); } } else { Object realResult = response.getReturn(); // result field id is 0 String fieldName = resultObj.fieldForId(0).getFieldName(); String setMethodName = ThriftUtils.generateSetMethodName(fieldName); String getMethodName = ThriftUtils.generateGetMethodName(fieldName); Method getMethod; Method setMethod; try { try { getMethod = clazz.getMethod(getMethodName); } catch (NoSuchMethodException e) { try { getMethod = clazz.getMethod(ThriftUtils.generateBoolMethodName(fieldName)); } catch (NoSuchMethodException e0) { throw new SerializationException("Serialize failed.", e); } } Class<?> returnType = getMethod.getReturnType(); if (BYTE_ARRAY_CLASS_NAME.equals(getMethod.getReturnType().getName())) { returnType = ByteBuffer.class; } setMethod = clazz.getMethod(setMethodName, returnType); setMethod.invoke(resultObj, realResult); } catch (NoSuchMethodException e) { throw new SerializationException("Serialize failed.", e); } catch (InvocationTargetException e) { throw new SerializationException("Serialize failed.", e); } catch (IllegalAccessException e) { throw new SerializationException("Serialize failed.", e); } } if (applicationException != null) { message = new TMessage(response.getMethodName(), TMessageType.EXCEPTION, response.getSeqId()); } else { message = new TMessage(response.getMethodName(), TMessageType.REPLY, response.getSeqId()); } //header header.write(protocol); short headerLength = (short) (bos.size() - HEADER_FIELD_LENGTH); protocol.writeMessageBegin(message); switch (message.type) { case TMessageType.EXCEPTION: applicationException.write(protocol); break; case TMessageType.REPLY: resultObj.write(protocol); break; } protocol.writeMessageEnd(); protocol.getTransport().flush(); int messageLength = bos.size(); try { bos.setWriteIndex(0); protocol.writeI16(headerLength); } finally { bos.setWriteIndex(messageLength); } } }