package com.dianping.pigeon.remoting.common.codec.thrift;
import com.dianping.pigeon.remoting.common.codec.AbstractSerializer;
import com.dianping.pigeon.remoting.common.domain.InvocationRequest;
import com.dianping.pigeon.remoting.common.domain.InvocationResponse;
import com.dianping.pigeon.remoting.common.domain.generic.GenericRequest;
import com.dianping.pigeon.remoting.common.domain.generic.GenericResponse;
import com.dianping.pigeon.remoting.common.domain.generic.MessageType;
import com.dianping.pigeon.remoting.common.domain.generic.ThriftMapper;
import com.dianping.pigeon.remoting.common.domain.generic.thrift.Header;
import com.dianping.pigeon.remoting.common.exception.SerializationException;
import com.dianping.pigeon.remoting.common.util.Constants;
import com.dianping.pigeon.remoting.invoker.domain.InvokerContext;
import com.dianping.pigeon.remoting.invoker.service.ServiceInvocationRepository;
import com.dianping.pigeon.remoting.provider.publish.ServicePublisher;
import com.dianping.pigeon.util.ThriftUtils;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.transport.TIOStreamTransport;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
/**
* @author qi.yin
* 2016/06/27 上午11:21.
*/
public class ThriftSerializer extends AbstractSerializer {
protected static final int HEADER_FIELD_LENGTH = 2;
protected ServiceInvocationRepository repository = ServiceInvocationRepository.getInstance();
private IDLThriftSerializer idlThriftSerializer = new IDLThriftSerializer();
private AnnotationThriftSerializer annotationThriftSerializer = new AnnotationThriftSerializer();
private ConcurrentMap<Class<?>, AbstractThriftSerializer> serializers =
new ConcurrentHashMap<Class<?>, AbstractThriftSerializer>();
public ThriftSerializer() {
validate();
}
private void validate() {
try {
ByteArrayOutputStream os = new ByteArrayOutputStream();
GenericRequest request = new GenericRequest();
request.setServiceName("test");
request.setMethodName("test");
request.setMessageType(Constants.MESSAGE_TYPE_HEART);
request.setTimeout(0);
serializeRequest(os, request);
InputStream is = new ByteArrayInputStream(os.toByteArray());
deserializeRequest(is);
} catch (RuntimeException e) {
throw e;
}
}
@Override
public Object deserializeRequest(InputStream is) throws SerializationException {
GenericRequest request = null;
TIOStreamTransport transport = new TIOStreamTransport(is);
TBinaryProtocol protocol = new TBinaryProtocol(transport);
try {
//headerLength
protocol.readI16();
//header
Header header = new Header();
header.read(protocol);
if (header.getRequestInfo() == null) {
throw new SerializationException("Deserialize requestInfo is no legal. header " + header);
}
request = ThriftMapper.convertHeaderToRequest(header);
if (request.getMessageType() == Constants.MESSAGE_TYPE_SERVICE) {
//body
Class<?> iface = ServicePublisher.getInterface(request.getServiceName());
if (iface == null) {
throw new SerializationException("Deserialize thrift serviceName is invalid.");
}
request.setServiceInterface(iface);
AbstractThriftSerializer serializer = getSerializer(iface);
serializer.doDeserializeRequest(request, protocol);
}
} catch (Exception e) {
throw new SerializationException("Deserialize request failed.", e);
}
return request;
}
@Override
public void serializeRequest(OutputStream os, Object obj) throws SerializationException {
if (!(obj instanceof GenericRequest)) {
throw new SerializationException("Unsupported this request obj serialize.");
} else {
try {
DynamicByteArrayOutputStream bos = new DynamicByteArrayOutputStream(1024);
GenericRequest request = (GenericRequest) obj;
TIOStreamTransport transport = new TIOStreamTransport(bos);
TBinaryProtocol protocol = new TBinaryProtocol(transport);
//headerlength
protocol.writeI16(Short.MAX_VALUE);
//header
Header header = ThriftMapper.convertRequestToHeader(request);
header.write(protocol);
short headerLength = (short) (bos.size() - HEADER_FIELD_LENGTH);
if (header.getMessageType() == MessageType.Normal.getCode()) {
Class<?> iface = request.getServiceInterface();
if (iface == null) {
throw new SerializationException("Serialize thrift interface is null.");
}
AbstractThriftSerializer serializer = getSerializer(iface);
serializer.doSerializeRequest(request, protocol);
}
int messageLength = bos.size();
try {
bos.setWriteIndex(0);
protocol.writeI16(headerLength);
} finally {
bos.setWriteIndex(messageLength);
}
os.write(bos.toByteArray());
} catch (Exception e) {
throw new SerializationException("serialize request failed.", e);
}
}
}
@Override
public Object deserializeResponse(InputStream is) throws SerializationException {
GenericResponse response = null;
TIOStreamTransport transport = new TIOStreamTransport(is);
TBinaryProtocol protocol = new TBinaryProtocol(transport);
try {
//headerLength
protocol.readI16();
//header
Header header = new Header();
header.read(protocol);
if (header.getResponseInfo() == null) {
throw new SerializationException("Deserialize response is no legal. header " + header);
}
response = ThriftMapper.convertHeaderToResponse(header);
if (header.getMessageType() == MessageType.Normal.getCode()) {
GenericRequest request = (GenericRequest) repository.get(
header.getResponseInfo().getSequenceId());
if (request == null) {
throw new SerializationException("Deserialize cannot find related request, May be timeout. sequenceId " + header.getResponseInfo().getSequenceId());
}
Class<?> iface = request.getServiceInterface();
if (iface == null) {
throw new SerializationException("Deserialize interface is null.");
}
AbstractThriftSerializer serializer = getSerializer(iface);
//body
serializer.doDeserializeResponse(response, request, protocol, header);
}
} catch (Exception e) {
throw new SerializationException("Deserialize response failed.", e);
}
return response;
}
@Override
public void serializeResponse(OutputStream os, Object obj) throws SerializationException {
if (!(obj instanceof GenericResponse)) {
throw new SerializationException("Unsupported this response obj serialize.");
} else {
try {
DynamicByteArrayOutputStream bos = new DynamicByteArrayOutputStream(1024);
GenericResponse response = (GenericResponse) obj;
TIOStreamTransport transport = new TIOStreamTransport(bos);
TBinaryProtocol protocol = new TBinaryProtocol(transport);
//headerlength
protocol.writeI16(Short.MAX_VALUE);
//header
Header header = ThriftMapper.convertResponseToHeader(response);
if (header.getMessageType() == MessageType.Normal.getCode()) {
Class<?> iface = ServicePublisher.getInterface(response.getServiceName());
if (iface == null) {
throw new SerializationException("Serialize thrift serviceName is invalid.");
}
response.setServiceInterface(iface);
AbstractThriftSerializer serializer = getSerializer(iface);
//body
serializer.doSerializeResponse(response, protocol, header, bos);
} else {
//header
header.write(protocol);
short headerLength = (short) (bos.size() - HEADER_FIELD_LENGTH);
int messageLength = bos.size();
try {
bos.setWriteIndex(0);
protocol.writeI16(headerLength);
} finally {
bos.setWriteIndex(messageLength);
}
}
os.write(bos.toByteArray());
} catch (Exception e) {
throw new SerializationException("Serialize failed.", e);
}
}
}
@Override
public InvocationResponse newResponse() throws SerializationException {
return new GenericResponse();
}
@Override
public InvocationRequest newRequest(InvokerContext invokerContext) throws SerializationException {
return new GenericRequest(invokerContext);
}
protected AbstractThriftSerializer getSerializer(Class<?> clazz) {
AbstractThriftSerializer serializer = serializers.get(clazz);
if (serializer == null) {
if (ThriftUtils.isAnnotation(clazz)) {
serializer = annotationThriftSerializer;
serializers.putIfAbsent(clazz, annotationThriftSerializer);
} else if (ThriftUtils.isIDL(clazz)) {
serializer = idlThriftSerializer;
serializers.putIfAbsent(clazz, idlThriftSerializer);
} else {
throw new SerializationException("Service interface " + clazz.getName() +
" do not support thrift serialize");
}
}
return serializer;
}
}