package com.pinterest.secor.util;

import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;

import org.apache.thrift.TBase;
import org.apache.thrift.TDeserializer;
import org.apache.thrift.TException;
import org.apache.thrift.TSerializer;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.protocol.TProtocolFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.pinterest.secor.common.SecorConfig;

/**
 * Adapted from ProtobufUtil. Various utilities for working with thrift encoded
 * messages. This utility will look for thrift class in the configuration. It
 * can be either per Kafka topic configuration, for example:
 *
 * {@code secor.thrift.message.class.<topic>=<thrift class name>}
 *
 * or, it can be global configuration for all topics (in case all the topics
 * transfer the same message type):
 *
 * {@code secor.thrift.message.class.*=<thrift class name>}
 *
 * When both kinds of entries are present, a per-topic entry takes precedence
 * for its topic and the global entry is used for every other topic.
 *
 * @author jaime sastre (jaime sastre.s@gmail.com)
 */
public class ThriftUtil {

    private static final Logger LOG = LoggerFactory.getLogger(ThriftUtil.class);

    /** Configuration key suffix that selects a message class for every topic. */
    private static final String ALL_TOPICS = "*";

    /** Message classes configured for individual Kafka topics. */
    @SuppressWarnings("rawtypes")
    private final Map<String, Class<? extends TBase>> messageClassByTopic = new HashMap<String, Class<? extends TBase>>();

    /** Message class configured for all topics, or null when there is none. */
    @SuppressWarnings("rawtypes")
    private Class<? extends TBase> messageClassForAll;

    /** Protocol factory used for (de)serialization; null if it failed to load. */
    private TProtocolFactory messageProtocolFactory;

    /**
     * Creates new instance of {@link ThriftUtil}.
     *
     * <p>
     * Message classes and the protocol factory are loaded reflectively. A class
     * that cannot be loaded or instantiated is logged at error level and
     * skipped; the constructor itself does not fail for that reason. When no
     * protocol class is configured, {@link TBinaryProtocol} is used.
     *
     * @param config
     *            Secor configuration instance
     */
    @SuppressWarnings({ "rawtypes", "unchecked" })
    public ThriftUtil(SecorConfig config) {
        Map<String, String> messageClassPerTopic = config.getThriftMessageClassPerTopic();

        for (Entry<String, String> entry : messageClassPerTopic.entrySet()) {
            try {
                String topic = entry.getKey();
                Class<? extends TBase> messageClass = (Class<? extends TBase>) Class.forName(entry.getValue());

                // The wildcard entry is remembered separately instead of toggling
                // a flag on every iteration, so the outcome no longer depends on
                // the iteration order of the configuration map.
                if (ALL_TOPICS.equals(topic)) {
                    messageClassForAll = messageClass;
                    LOG.info("Using thrift message class: {} for all Kafka topics", messageClass.getName());
                } else {
                    messageClassByTopic.put(topic, messageClass);
                    LOG.info("Using thrift message class: {} for Kafka topic: {}", messageClass.getName(), topic);
                }
            } catch (ClassNotFoundException e) {
                LOG.error("Unable to load thrift message class", e);
            }
        }

        try {
            String protocolName = config.getThriftProtocolClass();

            if (protocolName != null) {
                // The factory is looked up as the nested class "Factory" of the
                // configured protocol class.
                String factoryClassName = protocolName.concat("$Factory");
                messageProtocolFactory = ((Class<? extends TProtocolFactory>) Class.forName(factoryClassName)).newInstance();
            } else {
                messageProtocolFactory = new TBinaryProtocol.Factory();
            }
        } catch (ClassNotFoundException e) {
            LOG.error("Unable to load thrift protocol class", e);
        } catch (InstantiationException e) {
            LOG.error("Unable to load thrift protocol class", e);
        } catch (IllegalAccessException e) {
            LOG.error("Unable to load thrift protocol class", e);
        }
    }

    /**
     * Returns configured thrift message class for the given Kafka topic. A
     * class configured specifically for the topic wins; otherwise the class
     * configured for all topics (if any) is returned.
     *
     * @param topic
     *            Kafka topic
     * @return thrift message class used by this utility instance, or
     *         <code>null</code> in case valid class couldn't be found in the
     *         configuration.
     */
    @SuppressWarnings("rawtypes")
    public Class<? extends TBase> getMessageClass(String topic) {
        Class<? extends TBase> topicClass = messageClassByTopic.get(topic);
        if (topicClass != null) {
            return topicClass;
        }
        return messageClassForAll;
    }

    /**
     * Decodes a thrift message of the type configured for the given topic.
     *
     * @param topic
     *            Kafka topic the payload was read from
     * @param payload
     *            serialized thrift message
     * @return a new instance of the configured message class populated from
     *         the payload
     * @throws InstantiationException
     *             when the message class cannot be instantiated
     * @throws IllegalAccessException
     *             when the message class or its no-arg constructor is not
     *             accessible
     * @throws TException
     *             when the payload cannot be deserialized
     * @throws IllegalStateException
     *             when no message class is configured for the topic, or the
     *             protocol factory could not be loaded at construction time
     */
    @SuppressWarnings("rawtypes")
    public TBase decodeMessage(String topic, byte[] payload)
            throws InstantiationException, IllegalAccessException, TException {
        TDeserializer deserializer = new TDeserializer(requireProtocolFactory());

        Class<? extends TBase> messageClass = getMessageClass(topic);
        if (messageClass == null) {
            throw new IllegalStateException("No thrift message class configured for Kafka topic: " + topic);
        }

        TBase result = messageClass.newInstance();
        deserializer.deserialize(result, payload);
        return result;
    }

    /**
     * Encodes a thrift message using the configured protocol.
     *
     * @param object
     *            thrift message to serialize
     * @return serialized form of the message
     * @throws InstantiationException
     *             declared for backward compatibility; not thrown by this
     *             implementation
     * @throws IllegalAccessException
     *             declared for backward compatibility; not thrown by this
     *             implementation
     * @throws TException
     *             when the message cannot be serialized
     * @throws IllegalStateException
     *             when the protocol factory could not be loaded at
     *             construction time
     */
    @SuppressWarnings("rawtypes")
    public byte[] encodeMessage(TBase object) throws InstantiationException, IllegalAccessException, TException {
        TSerializer serializer = new TSerializer(requireProtocolFactory());
        return serializer.serialize(object);
    }

    /**
     * Returns the protocol factory, failing with a descriptive error instead
     * of a bare NullPointerException when it could not be loaded.
     */
    private TProtocolFactory requireProtocolFactory() {
        if (messageProtocolFactory == null) {
            throw new IllegalStateException(
                    "Thrift protocol factory is not available; check the configured thrift protocol class");
        }
        return messageProtocolFactory;
    }
}