package com.twitter.elephantbird.mapreduce.input;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.serde2.ColumnProjectionUtils;
import org.apache.hadoop.hive.serde2.columnar.BytesRefArrayWritable;
import org.apache.hadoop.hive.serde2.columnar.BytesRefWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.hadoop.mapreduce.lib.input.FileSplit;
import org.apache.thrift.TBase;
import org.apache.thrift.TException;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.transport.TMemoryInputTransport;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import com.twitter.elephantbird.mapreduce.io.ThriftWritable;
import com.twitter.elephantbird.thrift.TStructDescriptor;
import com.twitter.elephantbird.thrift.TStructDescriptor.Field;
import com.twitter.elephantbird.util.ColumnarMetadata;
import com.twitter.elephantbird.util.HadoopCompat;
import com.twitter.elephantbird.util.RCFileUtil;
import com.twitter.elephantbird.util.ThriftUtils;
import com.twitter.elephantbird.util.TypeRef;

/**
 * Input format for reading Thrift objects stored in RCFiles.
 * Each top-level Thrift field is stored as a separate RCFile column;
 * only the columns matching the requested fields are deserialized.
 */
public class RCFileThriftInputFormat extends RCFileBaseInputFormat {

  private static final Logger LOG =
      LoggerFactory.getLogger(RCFileThriftInputFormat.class);

  /** Thrift class being read; resolved from the job conf when null. */
  private TypeRef<? extends TBase<?, ?>> typeRef;

  /** No-arg constructor required by MapReduce; typeRef is read from the conf. */
  public RCFileThriftInputFormat() {}

  public RCFileThriftInputFormat(TypeRef<? extends TBase<?, ?>> typeRef) {
    this.typeRef = typeRef;
  }

  /**
   * Stores supplied class name in configuration. This configuration is
   * read on the remote tasks to initialize the input format correctly.
   *
   * @param thriftClass Thrift class whose records this input format reads
   * @param conf job configuration to store the class name in
   */
  public static void setClassConf(Class<? extends TBase<?, ?>> thriftClass,
                                  Configuration conf) {
    ThriftUtils.setClassConf(conf, RCFileThriftInputFormat.class, thriftClass);
  }

  @Override
  public RecordReader<LongWritable, Writable> createRecordReader(InputSplit split,
      TaskAttemptContext taskAttempt) throws IOException, InterruptedException {
    if (typeRef == null) {
      // running in a remote task: recover the Thrift class stored by setClassConf()
      typeRef = ThriftUtils.getTypeRef(HadoopCompat.getConfiguration(taskAttempt),
                                       RCFileThriftInputFormat.class);
    }
    return new ThriftReader(createUnwrappedRecordReader(split, taskAttempt));
  }

  /**
   * Wraps the raw columnar RCFile reader and reassembles Thrift objects
   * from the projected columns.
   */
  public class ThriftReader extends FilterRecordReader<LongWritable, Writable> {

    /** Descriptor for the Thrift struct being read; set in initialize(). */
    protected TStructDescriptor tDesc;
    /** True when the stored "unknown fields" column is part of the projection. */
    protected boolean readUnknownsColumn = false;
    /** Fields (in projection order) that map to known Thrift field ids. */
    protected List<Field> knownRequiredFields = Lists.newArrayList();
    /** Column indices to read; ArrayList required by setReadColumnIDs(). */
    protected ArrayList<Integer> columnsBeingRead = Lists.newArrayList();

    // reused for each field deserialization to avoid per-record allocation
    protected TMemoryInputTransport memTransport = new TMemoryInputTransport();
    protected TBinaryProtocol tProto = new TBinaryProtocol(memTransport);
    protected ThriftWritable<TBase<?, ?>> thriftWritable;

    /**
     * The reader is expected to be a
     * <code>RecordReader< LongWritable, BytesRefArrayWritable ></code>
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public ThriftReader(RecordReader reader) {
      super(reader);
    }

    /**
     * Returns whether the "unknown fields" column is being read.
     * Is valid only after initialize() is called.
     */
    public boolean isReadingUnknownsColumn() {
      return readUnknownsColumn;
    }

    /**
     * @deprecated misspelled; use {@link #isReadingUnknownsColumn()} instead.
     */
    @Deprecated
    public boolean isReadingUnknonwsColumn() {
      return isReadingUnknownsColumn();
    }

    @Override
    @SuppressWarnings("unchecked")
    public void initialize(InputSplit split, TaskAttemptContext ctx)
        throws IOException, InterruptedException {
      // set up columns that need to be read from the RCFile
      tDesc = TStructDescriptor.getInstance(typeRef.getRawClass());
      thriftWritable = ThriftWritable.newInstance(
          (Class<TBase<?, ?>>) typeRef.getRawClass());
      final List<Field> tFields = tDesc.getFields();

      FileSplit fsplit = (FileSplit) split;
      Path file = fsplit.getPath();

      // parameterized logging avoids eager String.format work
      LOG.info("reading {} from {}:{}:{}",
               typeRef.getRawClass().getName(), file,
               fsplit.getStart(), fsplit.getStart() + fsplit.getLength());

      Configuration conf = HadoopCompat.getConfiguration(ctx);
      ColumnarMetadata storedInfo = RCFileUtil.readMetadata(conf, file);

      // list of field numbers for the Thrift struct's declared fields
      List<Integer> tFieldIds = Lists.transform(tFields,
          new Function<Field, Integer>() {
            public Integer apply(Field fd) {
              return Integer.valueOf(fd.getFieldId());
            }
          });

      columnsBeingRead = RCFileUtil.findColumnsToRead(conf, tFieldIds, storedInfo);

      // a column with a negative stored field id holds serialized unknown fields
      for (int idx : columnsBeingRead) {
        int fid = storedInfo.getFieldId(idx);
        if (fid >= 0) {
          knownRequiredFields.add(tFields.get(tFieldIds.indexOf(fid)));
        } else {
          readUnknownsColumn = true;
        }
      }

      // tell the underlying RCFile reader which columns to materialize
      ColumnProjectionUtils.setReadColumnIDs(conf, columnsBeingRead);

      // finally!
      super.initialize(split, ctx);
    }

    /**
     * Returns the current row deserialized into a {@link ThriftWritable}.
     *
     * @throws IOException if Thrift deserialization fails
     */
    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public Writable getCurrentValue() throws IOException, InterruptedException {
      try {
        thriftWritable.set(getCurrentThriftValue());
        return thriftWritable;
      } catch (TException e) {
        //TODO : add error tracking
        throw new IOException(e);
      }
    }

    /** returns <code>super.getCurrentValue()</code> */
    public BytesRefArrayWritable getCurrentBytesRefArrayWritable()
        throws IOException, InterruptedException {
      return (BytesRefArrayWritable) super.getCurrentValue();
    }

    /**
     * Builds Thrift object from the raw bytes returned by RCFile reader.
     *
     * @return deserialized Thrift object, or null when the reader has no row
     * @throws TException if a field cannot be deserialized
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public TBase<?, ?> getCurrentThriftValue()
        throws IOException, InterruptedException, TException {

      BytesRefArrayWritable byteRefs = getCurrentBytesRefArrayWritable();

      if (byteRefs == null) {
        return null;
      }

      TBase tObj = tDesc.newThriftObject();

      // NOTE(review): indexing assumes the first knownRequiredFields.size()
      // entries of columnsBeingRead are the known-field columns, i.e. the
      // unknowns column (if any) sorts last — verify against
      // RCFileUtil.findColumnsToRead ordering.
      for (int i = 0; i < knownRequiredFields.size(); i++) {
        BytesRefWritable buf = byteRefs.get(columnsBeingRead.get(i));
        if (buf.getLength() > 0) {
          memTransport.reset(buf.getData(), buf.getStart(), buf.getLength());

          Field field = knownRequiredFields.get(i);
          tObj.setFieldValue(field.getFieldIdEnum(),
                             ThriftUtils.readFieldNoTag(tProto, field));
        }
        // else no need to set default value since any default value
        // would have been serialized when this record was written.
      }

      // parse unknowns column if required
      if (readUnknownsColumn) {
        int last = columnsBeingRead.get(columnsBeingRead.size() - 1);
        BytesRefWritable buf = byteRefs.get(last);
        if (buf.getLength() > 0) {
          memTransport.reset(buf.getData(), buf.getStart(), buf.getLength());
          tObj.read(tProto);
        }
      }

      return tObj;
    }
  }
}