package com.twitter.elephantbird.mapreduce.output;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import com.google.common.collect.Lists;
import com.twitter.elephantbird.util.HadoopCompat;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hive.serde2.ByteStream;
import org.apache.hadoop.hive.serde2.columnar.BytesRefArrayWritable;
import org.apache.hadoop.hive.serde2.columnar.BytesRefWritable;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.RecordWriter;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.thrift.TBase;
import org.apache.thrift.TException;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.protocol.TField;
import org.apache.thrift.protocol.TProtocolUtil;
import org.apache.thrift.protocol.TType;
import org.apache.thrift.transport.TIOStreamTransport;
import org.apache.thrift.transport.TMemoryInputTransport;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Maps;
import com.twitter.elephantbird.mapreduce.io.ThriftWritable;
import com.twitter.elephantbird.thrift.TStructDescriptor;
import com.twitter.elephantbird.thrift.TStructDescriptor.Field;
import com.twitter.elephantbird.util.ColumnarMetadata;
import com.twitter.elephantbird.util.Protobufs;
import com.twitter.elephantbird.util.ThriftUtils;
import com.twitter.elephantbird.util.TypeRef;
/**
* OutputFormat for storing Thrift objects in RCFile.<p>
*
* Each of the top level fields is stored in a separate column.
* Thrift field ids are stored in RCFile metadata.<p>
*
* The user can write either a {@link ThriftWritable} with the Thrift object
* or a {@link BytesWritable} with serialized Thrift bytes. The latter
* ensures that all the fields are preserved even if the current Thrift
* definition does not match the definition represented in the serialized bytes.
* Any fields not recognized by current Thrift class are stored in the last
* column.
*/
public class RCFileThriftOutputFormat extends RCFileOutputFormat {

  // typeRef is only required for setting metadata for the RCFile
  private TypeRef<? extends TBase<?, ?>> typeRef;
  // descriptor for the Thrift class; provides the field list and value accessors
  private TStructDescriptor tDesc;
  // top-level fields of the Thrift struct, in descriptor order
  private List<Field> tFields;
  // one column per known field, plus a trailing column for unknown fields
  private int numColumns;
  // reused for every row written, to avoid per-row allocation
  private BytesRefArrayWritable rowWritable = new BytesRefArrayWritable();
  // per-column byte spans; each entry is installed into rowWritable by init()
  private BytesRefWritable[] colValRefs;

  /** internal, for MR use only. */
  public RCFileThriftOutputFormat() {
  }

  public RCFileThriftOutputFormat(TypeRef<? extends TBase<?, ?>> typeRef) { // for PigLoader etc.
    this.typeRef = typeRef;
    init();
  }

  /**
   * Initializes the per-field state derived from {@link #typeRef}.
   * Must be called after typeRef is set — either by the public constructor
   * or by {@link #getRecordWriter} reading it from the job configuration.
   */
  private void init() {
    tDesc = TStructDescriptor.getInstance(typeRef.getRawClass());
    tFields = tDesc.getFields();
    numColumns = tFields.size() + 1; // known fields + 1 for unknown fields
    colValRefs = new BytesRefWritable[numColumns];
    for (int i = 0; i < numColumns; i++) {
      colValRefs[i] = new BytesRefWritable();
      // alias each column ref into the shared row; both are reused per row
      rowWritable.set(i, colValRefs[i]);
    }
  }

  /**
   * Builds the columnar metadata stored in the RCFile: the Thrift class name
   * and the Thrift field id of each column. The last column uses id -1 to
   * mark it as the unknown-fields column.
   */
  protected ColumnarMetadata makeColumnarMetadata() {
    List<Integer> fieldIds = Lists.newArrayList();
    for(Field fd : tDesc.getFields()) {
      fieldIds.add((int)fd.getFieldId());
    }
    fieldIds.add(-1); // -1 for unknown fields
    return ColumnarMetadata.newInstance(typeRef.getRawClass().getName(), fieldIds);
  }

  /**
   * RecordWriter that splits each Thrift record across RCFile columns.
   * Accepts either a {@link ThriftWritable} (object path, unknown fields
   * dropped) or a {@link BytesWritable} of serialized Thrift (byte path,
   * unknown fields preserved in the last column).
   */
  private class ThriftWriter extends RCFileOutputFormat.Writer {

    // scratch buffer holding the serialized column bytes for the current row
    private ByteStream.Output byteStream = new ByteStream.Output();
    // protocol used to serialize field values into byteStream (object path)
    private TBinaryProtocol tProto = new TBinaryProtocol(
        new TIOStreamTransport(byteStream));

    // used when deserializing thrift bytes
    // maps Thrift field id -> column index; built lazily on first byte-path write
    private Map<Short, Integer> idMap;
    private TMemoryInputTransport mTransport;
    private TBinaryProtocol skipProto;

    ThriftWriter(TaskAttemptContext job) throws IOException {
      super(RCFileThriftOutputFormat.this, job,
            Protobufs.toText(makeColumnarMetadata().getMessage()));
    }

    /**
     * Writes one record. The key is ignored; the value must be either a
     * BytesWritable with serialized Thrift bytes or a ThriftWritable.
     * Any TException during (de)serialization is rethrown as IOException.
     */
    @Override @SuppressWarnings("unchecked")
    public void write(NullWritable key, Writable value) throws IOException, InterruptedException {
      try {
        if (value instanceof BytesWritable) {
          // TODO: handle errors
          fromBytes((BytesWritable)value);
        } else {
          fromObject((TBase<?, ?>)((ThriftWritable)value).get());
        }
      } catch (TException e) {
        // might need to tolerate a few errors.
        throw new IOException(e);
      }
      // rowWritable's column refs were populated by fromBytes()/fromObject()
      super.write(null, rowWritable);
    }

    /**
     * Serializes each top-level field of the Thrift object into byteStream
     * and points each column ref at that field's byte span. The last
     * (unknown-fields) column is set to an empty span on this path.
     */
    @SuppressWarnings("unchecked")
    private void fromObject(TBase tObj)
                            throws IOException, InterruptedException, TException {
      byteStream.reset(); // reinitialize the byteStream if buffer is too large?
      int startPos = 0;
      // top level fields are split across the columns.
      for (int i=0; i < numColumns; i++) {
        if (i < (numColumns - 1)) {
          Field fd = tFields.get(i);
          // serialize just the field value (no field header) into byteStream
          ThriftUtils.writeFieldNoTag(tProto, fd, tDesc.getFieldValue(i, tObj));
        } // else { } : no 'unknown fields' in thrift object
        // column i covers whatever writeFieldNoTag appended since startPos
        colValRefs[i].set(byteStream.getData(),
                          startPos,
                          byteStream.getCount() - startPos);
        startPos = byteStream.getCount();
      }
    }

    /**
     * extract serialized bytes for each field, including unknown fields and
     * store those bytes in columns.
     * <p>
     * Walks the serialized struct with a skipping protocol, recording buffer
     * offsets around each field: known fields become zero-copy spans into the
     * input buffer; unrecognized fields (unknown id, or a known id whose wire
     * type does not match the current Thrift definition) are copied — header
     * included — into the last column.
     */
    private void fromBytes(BytesWritable bytesWritable)
                            throws IOException, InterruptedException, TException {
      if (mTransport == null) { // lazy init: only byte-path writers need these
        initIdMap();
        mTransport = new TMemoryInputTransport();
        skipProto = new TBinaryProtocol(mTransport);
      }

      byte[] bytes = bytesWritable.getBytes();
      mTransport.reset(bytes, 0, bytesWritable.getLength());
      byteStream.reset();

      // set all the fields to null
      for(BytesRefWritable ref : colValRefs) {
        ref.set(bytes, 0, 0);
      }

      skipProto.readStructBegin();

      while (true) {
        // 'start' includes the field header; 'fieldStart' is just the value
        int start = mTransport.getBufferPosition();
        TField field = skipProto.readFieldBegin();

        if (field.type == TType.STOP) {
          break;
        }

        int fieldStart = mTransport.getBufferPosition();
        // skip still creates and copies primitive objects (String, buffer, etc)
        // skipProto could override readString() and readBuffer() to avoid that.
        TProtocolUtil.skip(skipProto, field.type);
        int end = mTransport.getBufferPosition();

        Integer idx = idMap.get(field.id);

        if (idx != null && field.type == tFields.get(idx).getType()) {
          // known field
          colValRefs[idx].set(bytes, fieldStart, end-fieldStart);
        } else {
          // unknown field, copy the bytes to last column (with field id)
          byteStream.write(bytes, start, end-start);
        }
      }

      if (byteStream.getCount() > 0) {
        // terminate the unknown-fields bytes so they form a readable struct
        // body; presumably readers parse this column as a partial struct —
        // TODO confirm against the corresponding input format
        byteStream.write(TType.STOP);
        colValRefs[colValRefs.length-1].set(byteStream.getData(),
                                            0,
                                            byteStream.getCount());
      }
    }

    /** Builds the immutable Thrift-field-id -> column-index lookup map. */
    private void initIdMap() {
      idMap = Maps.newHashMap();
      for(int i=0; i<tFields.size(); i++) {
        idMap.put(tFields.get(i).getFieldId(), i);
      }
      idMap = ImmutableMap.copyOf(idMap);
    }
  }

  /**
   * Stores supplied class name in configuration. This configuration is
   * read on the remote tasks to initialize the output format correctly.
   */
  public static void setClassConf(Class<? extends TBase<?, ?> > thriftClass, Configuration conf) {
    ThriftUtils.setClassConf(conf, RCFileThriftOutputFormat.class, thriftClass);
  }

  /**
   * Returns a writer for this task. When running on a remote task (created
   * via the no-arg constructor) the Thrift class is recovered from the job
   * configuration set by {@link #setClassConf} before initializing.
   */
  @Override
  public RecordWriter<NullWritable, Writable>
    getRecordWriter(TaskAttemptContext job) throws IOException, InterruptedException {

    if (typeRef == null) {
      typeRef = ThriftUtils.getTypeRef(HadoopCompat.getConfiguration(job), RCFileThriftOutputFormat.class);
      init();
    }
    // RCFile needs the column count in the configuration before writing
    RCFileOutputFormat.setColumnNumber(HadoopCompat.getConfiguration(job), numColumns);
    return new ThriftWriter(job);
  }
}