package com.dianping.pigeon.remoting.common.codec.protobuf; import com.dianping.pigeon.remoting.common.codec.AbstractSerializer; import com.dianping.pigeon.remoting.common.domain.DefaultRequest; import com.dianping.pigeon.remoting.common.domain.DefaultResponse; import com.dianping.pigeon.remoting.common.domain.InvocationRequest; import com.dianping.pigeon.remoting.common.domain.InvocationResponse; import com.dianping.pigeon.remoting.common.exception.SerializationException; import com.dianping.pigeon.remoting.common.util.ContextUtils; import com.google.protobuf.Any; import com.google.protobuf.Message; import org.apache.commons.lang.StringUtils; import org.apache.commons.lang.exception.ExceptionUtils; import org.objenesis.Objenesis; import org.objenesis.ObjenesisStd; import java.io.InputStream; import java.io.OutputStream; import java.lang.reflect.Field; import java.lang.reflect.Method; import java.util.HashMap; import java.util.List; import java.util.Map; /** * Created by chenchongze on 17/3/7. */ public class Protobuf3Serializer extends AbstractSerializer { private static Objenesis objenesis = new ObjenesisStd(true); public Protobuf3Serializer() { } @Override public Object deserializeRequest(InputStream is) throws SerializationException { try { Entities.Pb3Request pb3Request = Entities.Pb3Request.parseFrom(is); List<Any> parametersList = pb3Request.getParametersList(); Object[] parameters = new Object[parametersList.size()]; for (int i = 0; i < parametersList.size(); i++) { Any paramAny = parametersList.get(i); if (StringUtils.isBlank(paramAny.getTypeUrl())) { parameters[i] = null; } else { parameters[i] = unpack(paramAny); } } DefaultRequest request = new DefaultRequest( pb3Request.getServiceName(), pb3Request.getMethodName(), parameters, (byte)pb3Request.getSerialize(), pb3Request.getMessageType(), pb3Request.getTimeout(), pb3Request.getCallType(), pb3Request.getSeq()); request.setVersion(pb3Request.getVersion()); request.setApp(pb3Request.getApp()); request.setGlobalValues(ContextUtils.convertContext(pb3Request.getGlobalValuesMap())); request.setRequestValues(ContextUtils.convertContext(pb3Request.getRequestValuesMap())); return request; } catch (Throwable t) { throw new SerializationException(t.getMessage(), t); } } @Override public void serializeRequest(OutputStream os, Object obj) throws SerializationException { try { InvocationRequest invocationRequest = (InvocationRequest) obj; Map<String, String> globalValues = new HashMap<String, String>(); ContextUtils.convertContext(invocationRequest.getGlobalValues(), globalValues); Map<String, String> requestValues = new HashMap<String, String>(); ContextUtils.convertContext(invocationRequest.getRequestValues(), requestValues); Entities.Pb3Request.Builder requestBuilder = Entities.Pb3Request.newBuilder() .setSerialize(invocationRequest.getSerialize()) .setSeq(invocationRequest.getSequence()) .setMessageType(invocationRequest.getMessageType()) .setTimeout(invocationRequest.getTimeout()) .setServiceName(invocationRequest.getServiceName()) .setMethodName(invocationRequest.getMethodName()) .setCallType(invocationRequest.getCallType()) .setVersion(invocationRequest.getVersion() == null ? "" : invocationRequest.getVersion()) .setApp(invocationRequest.getApp() == null ? "" : invocationRequest.getApp()) .putAllGlobalValues(globalValues) .putAllRequestValues(requestValues); if (invocationRequest.getParameters() != null) { for (Object param : invocationRequest.getParameters()) { if (param == null) { requestBuilder.addParameters(Any.getDefaultInstance()); } else { requestBuilder.addParameters(pack((Message) param)); } } } requestBuilder.build().writeTo(os); } catch (Throwable t) { throw new SerializationException(t.getMessage(), t); } } @Override public Object deserializeResponse(InputStream is) throws SerializationException { try { Entities.Pb3Response pb3Response = Entities.Pb3Response.parseFrom(is); Object returnVal = null; Any returnValAny = pb3Response.getReturnVal(); if (StringUtils.isNotBlank(returnValAny.getTypeUrl())) { returnVal = unpack(returnValAny); } else if (StringUtils.isNotBlank(pb3Response.getException().getCause())) { returnVal = objenesis.newInstance(Class.forName(pb3Response.getException().getCause())); Field msgField = Throwable.class.getDeclaredField("detailMessage"); msgField.setAccessible(true); msgField.set(returnVal, pb3Response.getException().getDetailMessage()); msgField.setAccessible(false); } DefaultResponse response = new DefaultResponse( (byte)pb3Response.getSerialize(), pb3Response.getSeq(), pb3Response.getMessageType(), returnVal, pb3Response.getCause()); response.setResponseValues(ContextUtils.convertContext(pb3Response.getResponseValuesMap())); return response; } catch (Throwable t) { throw new SerializationException(t.getMessage(), t); } } @Override public void serializeResponse(OutputStream os, Object obj) throws SerializationException { try { InvocationResponse invocationResponse = (InvocationResponse) obj; Map<String, String> responseValues = new HashMap<String, String>(); ContextUtils.convertContext(invocationResponse.getResponseValues(), responseValues); Entities.Pb3Response.Builder responseBuilder = Entities.Pb3Response.newBuilder() .setSerialize(invocationResponse.getSerialize()) .setSeq(invocationResponse.getSequence()) .setMessageType(invocationResponse.getMessageType()) .setCause(invocationResponse.getCause() == null ? "" : invocationResponse.getCause()) .putAllResponseValues(responseValues); // exception or normal Object returnVal = invocationResponse.getReturn(); if (returnVal != null) { if (returnVal instanceof Message) { responseBuilder.setReturnVal(pack((Message) returnVal)); } else if (returnVal instanceof Throwable) { responseBuilder.setException( Entities.Pb3Exception.newBuilder() .setCause(returnVal.getClass().getName()) .setDetailMessage(ExceptionUtils.getStackTrace((Throwable) returnVal)) ); } else { throw new RuntimeException("return val must be a Message or Exception! return class: " + returnVal.getClass().getName()); } } responseBuilder.build().writeTo(os); } catch (Throwable t) { throw new SerializationException(t.getMessage(), t); } } private static Any pack(Message msg) { return Any.newBuilder().setTypeUrl(msg.getClass().getName()) .setValue(msg.toByteString()).build(); } private static Object unpack(Any any) { Object obj; try { Class clazz = Class.forName(any.getTypeUrl()); Method method = clazz.getMethod("getDefaultInstance"); obj = method.invoke(method); return ((Message) obj).getParserForType().parseFrom(any.getValue()); } catch (Throwable t) { throw new RuntimeException("Failed to unpack " + any, t); } } }