/** * */ package io.nettythrift.protocol; import java.util.HashMap; import org.apache.thrift.protocol.TBinaryProtocol; import org.apache.thrift.protocol.TCompactProtocol; import org.apache.thrift.protocol.TJSONProtocol; import org.apache.thrift.protocol.TProtocolFactory; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * @author HouKx * */ public class ProtocolFactorySelector { private static Logger logger = LoggerFactory.getLogger(ProtocolFactorySelector.class); private final HashMap<Short, TProtocolFactory> protocolFactoryMap = new HashMap<Short, TProtocolFactory>(8); public ProtocolFactorySelector() { } public ProtocolFactorySelector(@SuppressWarnings("rawtypes") Class interfaceClass) { protocolFactoryMap.put((short) -32767, new TBinaryProtocol.Factory()); protocolFactoryMap.put((short) -32223, new TCompactProtocol.Factory()); protocolFactoryMap.put((short) 23345, new TJSONProtocol.Factory()); if (interfaceClass != null) { protocolFactoryMap.put((short) 23330, new TSimpleJSONProtocol.Factory(interfaceClass)); } } protected void registProtocolFactory(short head, TProtocolFactory factory) { protocolFactoryMap.put(head, factory); } public TProtocolFactory getProtocolFactory(short head) { // SimpleJson的前两个字符为:[" ,而TJSONProtocol的第二个字符为一个数字 TProtocolFactory fac = protocolFactoryMap.get(head); if (logger.isDebugEnabled()) { logger.debug("head:{}, getProtocolFactory:{}", head, fac); } return fac; } }