/** * File Created at 2011-12-05 * $Id$ * * Copyright 2008 Alibaba.com Croporation Limited. * All rights reserved. * * This software is the confidential and proprietary information of * Alibaba Company. ("Confidential Information"). You shall not * disclose such Confidential Information and shall use it only in * accordance with the terms of the license agreement you entered into * with Alibaba.com. */ package com.alibaba.dubbo.rpc.protocol.thrift; import java.io.IOException; import java.lang.reflect.Field; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.util.ArrayList; import java.util.List; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicInteger; import org.apache.commons.lang.StringUtils; import org.apache.thrift.TApplicationException; import org.apache.thrift.TBase; import org.apache.thrift.TException; import org.apache.thrift.TFieldIdEnum; import org.apache.thrift.protocol.TBinaryProtocol; import org.apache.thrift.protocol.TMessage; import org.apache.thrift.protocol.TMessageType; import org.apache.thrift.protocol.TProtocol; import org.apache.thrift.transport.TFramedTransport; import org.apache.thrift.transport.TIOStreamTransport; import com.alibaba.dubbo.common.Constants; import com.alibaba.dubbo.common.extension.ExtensionLoader; import com.alibaba.dubbo.common.utils.ClassHelper; import com.alibaba.dubbo.remoting.Channel; import com.alibaba.dubbo.remoting.Codec2; import com.alibaba.dubbo.remoting.buffer.ChannelBuffer; import com.alibaba.dubbo.remoting.buffer.ChannelBufferInputStream; import com.alibaba.dubbo.remoting.exchange.Request; import com.alibaba.dubbo.remoting.exchange.Response; import com.alibaba.dubbo.rpc.RpcException; import com.alibaba.dubbo.rpc.RpcInvocation; import com.alibaba.dubbo.rpc.RpcResult; import com.alibaba.dubbo.rpc.protocol.thrift.io.RandomAccessByteArrayOutputStream; /** * Thrift framed protocol codec. * * <pre> * |<- message header ->|<- message body ->| * +----------------+----------------------+------------------+---------------------------+------------------+ * | magic (2 bytes)|message size (4 bytes)|head size(2 bytes)| version (1 byte) | header | message body | * +----------------+----------------------+------------------+---------------------------+------------------+ * |<- message size ->| * </pre> * * <p> * <b>header fields in version 1</b> * <ol> * <li>string - service name</li> * <li>long - dubbo request id</li> * </ol> * </p> * * @author <a href="mailto:gang.lvg@alibaba-inc.com">gang.lvg</a> */ public class ThriftCodec implements Codec2 { private static final AtomicInteger THRIFT_SEQ_ID = new AtomicInteger( 0 ); private static final ConcurrentMap<String, Class<?>> cachedClass = new ConcurrentHashMap<String, Class<?>>(); static final ConcurrentMap<Long, RequestData> cachedRequest = new ConcurrentHashMap<Long, RequestData>(); public static final int MESSAGE_LENGTH_INDEX = 2; public static final int MESSAGE_HEADER_LENGTH_INDEX = 6; public static final int MESSAGE_SHORTEST_LENGTH = 10; public static final String NAME = "thrift"; public static final String PARAMETER_CLASS_NAME_GENERATOR = "class.name.generator"; public static final byte VERSION = (byte)1; public static final short MAGIC = (short) 0xdabc; public void encode( Channel channel, ChannelBuffer buffer, Object message ) throws IOException { if ( message instanceof Request ) { encodeRequest( channel, buffer, ( Request ) message ); } else if ( message instanceof Response ) { encodeResponse( channel, buffer, ( Response ) message ); } else { throw new UnsupportedOperationException( new StringBuilder( 32 ) .append( "Thrift codec only support encode " ) .append( Request.class.getName() ) .append( " and " ) .append( Response.class.getName() ) .toString() ); } } public Object decode( Channel channel, ChannelBuffer buffer ) throws IOException { int available = buffer.readableBytes(); if ( available < MESSAGE_SHORTEST_LENGTH ) { return DecodeResult.NEED_MORE_INPUT; } else { TIOStreamTransport transport = new TIOStreamTransport( new ChannelBufferInputStream(buffer)); TBinaryProtocol protocol = new TBinaryProtocol( transport ); short magic; int messageLength; try{ // protocol.readI32(); // skip the first message length byte[] bytes = new byte[4]; transport.read( bytes, 0, 4 ); magic = protocol.readI16(); messageLength = protocol.readI32(); } catch ( TException e ) { throw new IOException( e.getMessage(), e ); } if ( MAGIC != magic ) { throw new IOException( new StringBuilder( 32 ) .append( "Unknown magic code " ) .append( magic ) .toString() ); } if ( available < messageLength ) { return DecodeResult.NEED_MORE_INPUT; } return decode( protocol ); } } private Object decode( TProtocol protocol ) throws IOException { // version String serviceName; long id; TMessage message; try { protocol.readI16(); protocol.readByte(); serviceName = protocol.readString(); id = protocol.readI64(); message = protocol.readMessageBegin(); } catch ( TException e ) { throw new IOException( e.getMessage(), e ); } if ( message.type == TMessageType.CALL ) { RpcInvocation result = new RpcInvocation(); result.setAttachment(Constants.INTERFACE_KEY, serviceName ); result.setMethodName( message.name ); String argsClassName = ExtensionLoader.getExtensionLoader(ClassNameGenerator.class) .getExtension(ThriftClassNameGenerator.NAME).generateArgsClassName( serviceName, message.name ); if ( StringUtils.isEmpty( argsClassName ) ) { throw new RpcException( RpcException.SERIALIZATION_EXCEPTION, "The specified interface name incorrect." ); } Class clazz = cachedClass.get( argsClassName ); if ( clazz == null ) { try { clazz = ClassHelper.forNameWithThreadContextClassLoader( argsClassName ); cachedClass.putIfAbsent( argsClassName, clazz ); } catch ( ClassNotFoundException e ) { throw new RpcException( RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e ); } } TBase args; try { args = ( TBase ) clazz.newInstance(); } catch ( InstantiationException e ) { throw new RpcException( RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e ); } catch ( IllegalAccessException e ) { throw new RpcException( RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e ); } try{ args.read( protocol ); protocol.readMessageEnd(); } catch ( TException e ) { throw new RpcException( RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e ); } List<Object> parameters = new ArrayList<Object>(); List<Class<?>> parameterTypes =new ArrayList<Class<?>>(); int index = 1; while ( true ) { TFieldIdEnum fieldIdEnum = args.fieldForId( index++ ); if ( fieldIdEnum == null ) { break; } String fieldName = fieldIdEnum.getFieldName(); String getMethodName = ThriftUtils.generateGetMethodName( fieldName ); Method getMethod; try { getMethod = clazz.getMethod( getMethodName ); } catch ( NoSuchMethodException e ) { throw new RpcException( RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e ); } parameterTypes.add( getMethod.getReturnType() ); try { parameters.add( getMethod.invoke( args ) ); } catch ( IllegalAccessException e ) { throw new RpcException( RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e ); } catch ( InvocationTargetException e ) { throw new RpcException( RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e ); } } result.setArguments( parameters.toArray() ); result.setParameterTypes(parameterTypes.toArray(new Class[parameterTypes.size()])); Request request = new Request( id ); request.setData( result ); cachedRequest.putIfAbsent( id, RequestData.create( message.seqid, serviceName, message.name ) ); return request; } else if ( message.type == TMessageType.EXCEPTION ) { TApplicationException exception; try { exception = TApplicationException.read( protocol ); protocol.readMessageEnd(); } catch ( TException e ) { throw new IOException( e.getMessage(), e ); } RpcResult result = new RpcResult(); result.setException( new RpcException( exception.getMessage() ) ); Response response = new Response(); response.setResult( result ); response.setId( id ); return response; } else if ( message.type == TMessageType.REPLY ) { String resultClassName = ExtensionLoader.getExtensionLoader( ClassNameGenerator.class ) .getExtension(ThriftClassNameGenerator.NAME).generateResultClassName( serviceName, message.name ); if ( StringUtils.isEmpty( resultClassName ) ) { throw new IllegalArgumentException( new StringBuilder( 32 ) .append( "Could not infer service result class name from service name " ) .append( serviceName ) .append( ", the service name you specified may not generated by thrift idl compiler" ) .toString() ); } Class<?> clazz = cachedClass.get( resultClassName ); if ( clazz == null ) { try { clazz = ClassHelper.forNameWithThreadContextClassLoader( resultClassName ); cachedClass.putIfAbsent( resultClassName, clazz ); } catch ( ClassNotFoundException e ) { throw new RpcException( RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e ); } } TBase<?,? extends TFieldIdEnum> result; try { result = ( TBase<?,?> ) clazz.newInstance(); } catch ( InstantiationException e ) { throw new RpcException( RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e ); } catch ( IllegalAccessException e ) { throw new RpcException( RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e ); } try { result.read( protocol ); protocol.readMessageEnd(); } catch ( TException e ) { throw new RpcException( RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e ); } Object realResult = null; int index = 0; while ( true ) { TFieldIdEnum fieldIdEnum = result.fieldForId( index++ ); if ( fieldIdEnum == null ) { break ; } Field field; try { field = clazz.getDeclaredField( fieldIdEnum.getFieldName() ); field.setAccessible( true ); } catch ( NoSuchFieldException e ) { throw new RpcException( RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e ); } try { realResult = field.get( result ); } catch ( IllegalAccessException e ) { throw new RpcException( RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e ); } if ( realResult != null ) { break ; } } Response response = new Response(); response.setId( id ); RpcResult rpcResult = new RpcResult(); if ( realResult instanceof Throwable ) { rpcResult.setException( ( Throwable ) realResult ); } else { rpcResult.setValue(realResult); } response.setResult( rpcResult ); return response; } else { // Impossible throw new IOException( ); } } private void encodeRequest( Channel channel, ChannelBuffer buffer, Request request ) throws IOException { RpcInvocation inv = ( RpcInvocation ) request.getData(); int seqId = nextSeqId(); String serviceName = inv.getAttachment(Constants.INTERFACE_KEY); if ( StringUtils.isEmpty( serviceName ) ) { throw new IllegalArgumentException( new StringBuilder( 32 ) .append( "Could not find service name in attachment with key " ) .append(Constants.INTERFACE_KEY) .toString() ); } TMessage message = new TMessage( inv.getMethodName(), TMessageType.CALL, seqId ); String methodArgs = ExtensionLoader.getExtensionLoader( ClassNameGenerator.class ) .getExtension(channel.getUrl().getParameter(ThriftConstants.CLASS_NAME_GENERATOR_KEY, ThriftClassNameGenerator.NAME)) .generateArgsClassName(serviceName, inv.getMethodName()); if ( StringUtils.isEmpty( methodArgs ) ) { throw new RpcException( RpcException.SERIALIZATION_EXCEPTION, new StringBuilder(32).append( "Could not encode request, the specified interface may be incorrect." ).toString() ); } Class<?> clazz = cachedClass.get( methodArgs ); if ( clazz == null ) { try { clazz = ClassHelper.forNameWithThreadContextClassLoader( methodArgs ); cachedClass.putIfAbsent( methodArgs, clazz ); } catch ( ClassNotFoundException e ) { throw new RpcException( RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e ); } } TBase args; try { args = (TBase) clazz.newInstance(); } catch ( InstantiationException e ) { throw new RpcException( RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e ); } catch ( IllegalAccessException e ) { throw new RpcException( RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e ); } for( int i = 0; i < inv.getArguments().length; i++ ) { Object obj = inv.getArguments()[i]; if ( obj == null ) { continue; } TFieldIdEnum field = args.fieldForId( i + 1 ); String setMethodName = ThriftUtils.generateSetMethodName( field.getFieldName() ); Method method; try { method = clazz.getMethod( setMethodName, inv.getParameterTypes()[i] ); } catch ( NoSuchMethodException e ) { throw new RpcException( RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e ); } try { method.invoke( args, obj ); } catch ( IllegalAccessException e ) { throw new RpcException( RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e ); } catch ( InvocationTargetException e ) { throw new RpcException( RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e ); } } RandomAccessByteArrayOutputStream bos = new RandomAccessByteArrayOutputStream( 1024 ); TIOStreamTransport transport = new TIOStreamTransport( bos ); TBinaryProtocol protocol = new TBinaryProtocol( transport ); int headerLength, messageLength; byte[] bytes = new byte[4]; try { // magic protocol.writeI16( MAGIC ); // message length placeholder protocol.writeI32( Integer.MAX_VALUE ); // message header length placeholder protocol.writeI16( Short.MAX_VALUE ); // version protocol.writeByte( VERSION ); // service name protocol.writeString( serviceName ); // dubbo request id protocol.writeI64( request.getId() ); protocol.getTransport().flush(); // header size headerLength = bos.size(); // message body protocol.writeMessageBegin( message ); args.write( protocol ); protocol.writeMessageEnd(); protocol.getTransport().flush(); int oldIndex = messageLength = bos.size(); // fill in message length and header length try { TFramedTransport.encodeFrameSize( messageLength, bytes ); bos.setWriteIndex( MESSAGE_LENGTH_INDEX ); protocol.writeI32( messageLength ); bos.setWriteIndex( MESSAGE_HEADER_LENGTH_INDEX ); protocol.writeI16( ( short )( 0xffff & headerLength ) ); } finally { bos.setWriteIndex( oldIndex ); } } catch ( TException e ) { throw new RpcException( RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e ); } buffer.writeBytes(bytes); buffer.writeBytes(bos.toByteArray()); } private void encodeResponse( Channel channel, ChannelBuffer buffer, Response response ) throws IOException { RpcResult result = ( RpcResult ) response.getResult(); RequestData rd = cachedRequest.get( response.getId() ); String resultClassName = ExtensionLoader.getExtensionLoader( ClassNameGenerator.class ).getExtension( channel.getUrl().getParameter(ThriftConstants.CLASS_NAME_GENERATOR_KEY, ThriftClassNameGenerator.NAME)) .generateResultClassName(rd.serviceName, rd.methodName); if ( StringUtils.isEmpty( resultClassName ) ) { throw new RpcException( RpcException.SERIALIZATION_EXCEPTION, new StringBuilder( 32 ).append( "Could not encode response, the specified interface may be incorrect." ).toString() ); } Class clazz = cachedClass.get( resultClassName ); if ( clazz == null ) { try { clazz = ClassHelper.forNameWithThreadContextClassLoader(resultClassName); cachedClass.putIfAbsent( resultClassName, clazz ); } catch ( ClassNotFoundException e ) { throw new RpcException( RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e ); } } TBase resultObj; try { resultObj = ( TBase ) clazz.newInstance(); } catch ( InstantiationException e ) { throw new RpcException( RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e ); } catch ( IllegalAccessException e ) { throw new RpcException( RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e ); } TApplicationException applicationException = null; TMessage message; if ( result.hasException() ) { Throwable throwable = result.getException(); int index = 1; boolean found = false; while ( true ) { TFieldIdEnum fieldIdEnum = resultObj.fieldForId( index++ ); if ( fieldIdEnum == null ) { break; } String fieldName = fieldIdEnum.getFieldName(); String getMethodName = ThriftUtils.generateGetMethodName( fieldName ); String setMethodName = ThriftUtils.generateSetMethodName( fieldName ); Method getMethod; Method setMethod; try { getMethod = clazz.getMethod( getMethodName ); if ( getMethod.getReturnType().equals( throwable.getClass() ) ) { found = true; setMethod = clazz.getMethod( setMethodName, throwable.getClass() ); setMethod.invoke( resultObj, throwable ); } } catch ( NoSuchMethodException e ) { throw new RpcException( RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e ); } catch ( InvocationTargetException e ) { throw new RpcException( RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e ); } catch ( IllegalAccessException e ) { throw new RpcException( RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e ); } } if ( !found ) { applicationException = new TApplicationException( throwable.getMessage() ); } } else { Object realResult = result.getResult(); // result field id is 0 String fieldName = resultObj.fieldForId( 0 ).getFieldName(); String setMethodName = ThriftUtils.generateSetMethodName( fieldName ); String getMethodName = ThriftUtils.generateGetMethodName( fieldName ); Method getMethod; Method setMethod; try { getMethod = clazz.getMethod( getMethodName ); setMethod = clazz.getMethod( setMethodName, getMethod.getReturnType() ); setMethod.invoke( resultObj, realResult ); } catch ( NoSuchMethodException e ) { throw new RpcException( RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e ); } catch ( InvocationTargetException e ) { throw new RpcException( RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e ); } catch ( IllegalAccessException e ) { throw new RpcException( RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e ); } } if ( applicationException != null ) { message = new TMessage( rd.methodName, TMessageType.EXCEPTION, rd.id ); } else { message = new TMessage( rd.methodName, TMessageType.REPLY, rd.id ); } RandomAccessByteArrayOutputStream bos = new RandomAccessByteArrayOutputStream( 1024 ); TIOStreamTransport transport = new TIOStreamTransport( bos ); TBinaryProtocol protocol = new TBinaryProtocol( transport ); int messageLength; int headerLength; byte[] bytes = new byte[4]; try { // magic protocol.writeI16( MAGIC ); // message length protocol.writeI32( Integer.MAX_VALUE ); // message header length protocol.writeI16( Short.MAX_VALUE ); // version protocol.writeByte( VERSION ); // service name protocol.writeString( rd.serviceName ); // id protocol.writeI64( response.getId() ); protocol.getTransport().flush(); headerLength = bos.size(); // message protocol.writeMessageBegin( message ); switch ( message.type ) { case TMessageType.EXCEPTION: applicationException.write( protocol ); break; case TMessageType.REPLY: resultObj.write( protocol ); break; } protocol.writeMessageEnd(); protocol.getTransport().flush(); int oldIndex = messageLength = bos.size(); try{ TFramedTransport.encodeFrameSize( messageLength, bytes ); bos.setWriteIndex( MESSAGE_LENGTH_INDEX ); protocol.writeI32( messageLength ); bos.setWriteIndex( MESSAGE_HEADER_LENGTH_INDEX ); protocol.writeI16( ( short ) ( 0xffff & headerLength ) ); } finally { bos.setWriteIndex( oldIndex ); } } catch ( TException e ) { throw new RpcException( RpcException.SERIALIZATION_EXCEPTION, e.getMessage(), e ); } buffer.writeBytes(bytes); buffer.writeBytes(bos.toByteArray()); } private static int nextSeqId() { return THRIFT_SEQ_ID.incrementAndGet(); } // just for test static int getSeqId() { return THRIFT_SEQ_ID.get(); } static class RequestData { int id; String serviceName; String methodName; static RequestData create( int id, String sn, String mn ) { RequestData result = new RequestData(); result.id = id; result.serviceName = sn; result.methodName = mn; return result; } } }