/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.aliyun.odps.commons.proto;

import java.io.IOException;
import java.io.OutputStream;
import java.math.BigDecimal;
import java.sql.Timestamp;
import java.util.ArrayList;
import java.util.Date;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.zip.Deflater;
import java.util.zip.DeflaterOutputStream;

import org.apache.commons.io.output.CountingOutputStream;
import org.xerial.snappy.SnappyFramedOutputStream;

import com.aliyun.odps.Column;
import com.aliyun.odps.TableSchema;
import com.aliyun.odps.commons.util.DateUtils;
import com.aliyun.odps.data.AbstractChar;
import com.aliyun.odps.data.Binary;
import com.aliyun.odps.data.IntervalDayTime;
import com.aliyun.odps.data.IntervalYearMonth;
import com.aliyun.odps.data.Record;
import com.aliyun.odps.data.RecordPack;
import com.aliyun.odps.data.RecordReader;
import com.aliyun.odps.data.RecordWriter;
import com.aliyun.odps.data.Struct;
import com.aliyun.odps.tunnel.io.Checksum;
import com.aliyun.odps.tunnel.io.CompressOption;
import com.aliyun.odps.tunnel.io.ProtobufRecordPack;
import com.aliyun.odps.type.ArrayTypeInfo;
import com.aliyun.odps.type.MapTypeInfo;
import com.aliyun.odps.type.StructTypeInfo;
import com.aliyun.odps.type.TypeInfo;
import com.google.protobuf.CodedOutputStream;
import com.google.protobuf.WireFormat;

/**
 * A {@link RecordWriter} that serializes {@link Record}s into the ODPS tunnel
 * protobuf wire format, optionally compressing the stream with zlib or snappy.
 *
 * @author chao.liu
 */
public class ProtobufRecordStreamWriter implements RecordWriter {

  private CountingOutputStream bou;
  private Column[] columns;
  private CodedOutputStream out;
  private long count;                       // number of records written

  private Checksum crc = new Checksum();    // checksum of the current record
  private Checksum crccrc = new Checksum(); // running checksum over all record checksums

  private Deflater def;

  public ProtobufRecordStreamWriter(TableSchema schema, OutputStream out) throws IOException {
    this(schema, out, new CompressOption());
  }

  public ProtobufRecordStreamWriter(TableSchema schema, OutputStream out, CompressOption option)
      throws IOException {
    columns = schema.getColumns().toArray(new Column[0]);
    OutputStream tmpOut;
    if (option != null) {
      if (option.algorithm.equals(CompressOption.CompressAlgorithm.ODPS_ZLIB)) {
        def = new Deflater();
        def.setLevel(option.level);
        def.setStrategy(option.strategy);
        tmpOut = new DeflaterOutputStream(out, def);
      } else if (option.algorithm.equals(CompressOption.CompressAlgorithm.ODPS_SNAPPY)) {
        tmpOut = new SnappyFramedOutputStream(out);
      } else if (option.algorithm.equals(CompressOption.CompressAlgorithm.ODPS_RAW)) {
        tmpOut = out;
      } else {
        throw new IOException("invalid compression option.");
      }
    } else {
      tmpOut = out;
    }
    bou = new CountingOutputStream(tmpOut);
    this.out = CodedOutputStream.newInstance(bou);
  }

  static void writeRawBytes(byte[] value, CodedOutputStream out) throws IOException {
    out.writeRawVarint32(value.length);
    out.writeRawBytes(value);
  }
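  /**
   * Serializes a single record. Each non-null column value is written as one
   * protobuf field whose field number is the column index plus one; the record
   * is terminated by a {@code TUNNEL_END_RECORD} field carrying the CRC over
   * all field indexes and values of the record.
   */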
  @Override
  public void write(Record r) throws IOException {
    int recordValues = r.getColumnCount();
    int columnCount = columns.length;
    if (recordValues > columnCount) {
      throw new IOException("record has more values than schema columns.");
    }

    int i = 0;
    for (; i < columnCount && i < recordValues; i++) {
      Object v = r.get(i);
      if (v == null) {
        continue;
      }
      int pbIdx = i + 1;
      crc.update(pbIdx);
      TypeInfo typeInfo = columns[i].getTypeInfo();
      writeFieldTag(pbIdx, typeInfo);
      writeField(v, typeInfo);
    }

    int checksum = (int) crc.getValue();
    out.writeUInt32(ProtoWireConstant.TUNNEL_END_RECORD, checksum);
    crc.reset();
    crccrc.update(checksum);
    count++;
  }

  private void writeFieldTag(int pbIdx, TypeInfo typeInfo) throws IOException {
    switch (typeInfo.getOdpsType()) {
      case DATETIME:
      case BOOLEAN:
      case BIGINT:
      case TINYINT:
      case SMALLINT:
      case INT:
      case DATE:
      case INTERVAL_YEAR_MONTH: {
        out.writeTag(pbIdx, WireFormat.WIRETYPE_VARINT);
        break;
      }
      case DOUBLE: {
        out.writeTag(pbIdx, WireFormat.WIRETYPE_FIXED64);
        break;
      }
      case FLOAT: {
        out.writeTag(pbIdx, WireFormat.WIRETYPE_FIXED32);
        break;
      }
      case INTERVAL_DAY_TIME:
      case TIMESTAMP:
      case STRING:
      case CHAR:
      case VARCHAR:
      case BINARY:
      case DECIMAL:
      case ARRAY:
      case MAP:
      case STRUCT: {
        out.writeTag(pbIdx, WireFormat.WIRETYPE_LENGTH_DELIMITED);
        break;
      }
      default:
        throw new IOException("Invalid data type: " + typeInfo);
    }
  }

  private void writeField(Object v, TypeInfo typeInfo) throws IOException {
    switch (typeInfo.getOdpsType()) {
      case BOOLEAN: {
        boolean value = (Boolean) v;
        crc.update(value);
        out.writeBoolNoTag(value);
        break;
      }
      case DATETIME: {
        Date value = (Date) v;
        Long longValue = DateUtils.date2ms(value);
        crc.update(longValue);
        out.writeSInt64NoTag(longValue);
        break;
      }
      case DATE: {
        Long longValue = DateUtils.getDayOffset((java.sql.Date) v);
        crc.update(longValue);
        out.writeSInt64NoTag(longValue);
        break;
      }
      case TIMESTAMP: {
        Long value = ((Timestamp) v).getTime() / 1000;
        Integer nano = ((Timestamp) v).getNanos();
        crc.update(value);
        crc.update(nano);
        out.writeSInt64NoTag(value);
        out.writeSInt32NoTag(nano);
        break;
      }
      case INTERVAL_DAY_TIME: {
        Long value = ((IntervalDayTime) v).getTotalSeconds();
        Integer nano = ((IntervalDayTime) v).getNanos();
        crc.update(value);
        crc.update(nano);
        out.writeSInt64NoTag(value);
        out.writeSInt32NoTag(nano);
        break;
      }
      case VARCHAR:
      case CHAR: {
        byte[] bytes = ((AbstractChar) v).getValue().getBytes("UTF-8");
        crc.update(bytes, 0, bytes.length);
        writeRawBytes(bytes, out);
        break;
      }
      case STRING: {
        byte[] bytes;
        if (v instanceof String) {
          bytes = ((String) v).getBytes("UTF-8");
        } else {
          bytes = (byte[]) v;
        }
        crc.update(bytes, 0, bytes.length);
        writeRawBytes(bytes, out);
        break;
      }
      case BINARY: {
        byte[] bytes = ((Binary) v).data();
        crc.update(bytes, 0, bytes.length);
        writeRawBytes(bytes, out);
        break;
      }
      case DOUBLE: {
        double value = (Double) v;
        crc.update(value);
        out.writeDoubleNoTag(value);
        break;
      }
      case FLOAT: {
        float value = (Float) v;
        crc.update(value);
        out.writeFloatNoTag(value);
        break;
      }
      case BIGINT: {
        long value = (Long) v;
        crc.update(value);
        out.writeSInt64NoTag(value);
        break;
      }
      case INTERVAL_YEAR_MONTH: {
        long value = ((IntervalYearMonth) v).getTotalMonths();
        crc.update(value);
        out.writeSInt64NoTag(value);
        break;
      }
      case INT: {
        long value = ((Integer) v).longValue();
        crc.update(value);
        out.writeSInt64NoTag(value);
        break;
      }
      case SMALLINT: {
        long value = ((Short) v).longValue();
        crc.update(value);
        out.writeSInt64NoTag(value);
        break;
      }
      case TINYINT: {
        long value = ((Byte) v).longValue();
        crc.update(value);
        out.writeSInt64NoTag(value);
        break;
      }
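      // The remaining types are written as length-delimited payloads: DECIMAL
      // as its UTF-8 plain-string representation, and ARRAY/MAP/STRUCT as
      // nested sequences in which each element or field is preceded by a
      // boolean null marker.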
      case DECIMAL: {
        String value = ((BigDecimal) v).toPlainString();
        byte[] bytes = value.getBytes("UTF-8");
        crc.update(bytes, 0, bytes.length);
        writeRawBytes(bytes, out);
        break;
      }
      case ARRAY: {
        writeArray((List) v, ((ArrayTypeInfo) typeInfo).getElementTypeInfo());
        break;
      }
      case MAP: {
        MapTypeInfo mapTypeInfo = (MapTypeInfo) typeInfo;
        writeMap((Map) v, mapTypeInfo.getKeyTypeInfo(), mapTypeInfo.getValueTypeInfo());
        break;
      }
      case STRUCT: {
        writeStruct((Struct) v, (StructTypeInfo) typeInfo);
        break;
      }
      default:
        throw new IOException("Invalid data type: " + typeInfo);
    }
  }

  private void writeStruct(Struct object, StructTypeInfo typeInfo) throws IOException {
    List<TypeInfo> fieldTypeInfos = typeInfo.getFieldTypeInfos();
    for (int i = 0; i < fieldTypeInfos.size(); ++i) {
      if (object.getFieldValue(i) == null) {
        out.writeBoolNoTag(true);
      } else {
        out.writeBoolNoTag(false);
        writeField(object.getFieldValue(i), fieldTypeInfos.get(i));
      }
    }
  }

  private void writeArray(List v, TypeInfo type) throws IOException {
    out.writeInt32NoTag(v.size());
    for (int i = 0; i < v.size(); i++) {
      if (v.get(i) == null) {
        out.writeBoolNoTag(true);
      } else {
        out.writeBoolNoTag(false);
        writeField(v.get(i), type);
      }
    }
  }

  private void writeMap(Map v, TypeInfo keyType, TypeInfo valueType) throws IOException {
    // Note: storage checks the validity of keys and values.
    List keyList = new ArrayList();
    List valueList = new ArrayList();
    Iterator iter = v.entrySet().iterator();
    while (iter.hasNext()) {
      Map.Entry entry = (Map.Entry) iter.next();
      keyList.add(entry.getKey());
      valueList.add(entry.getValue());
    }
    writeArray(keyList, keyType);
    writeArray(valueList, valueType);
  }

  @Override
  public void close() throws IOException {
    try {
      out.writeSInt64(ProtoWireConstant.TUNNEL_META_COUNT, count);
      out.writeUInt32(ProtoWireConstant.TUNNEL_META_CHECKSUM, (int) crccrc.getValue());
      out.flush();
      bou.close();
    } finally {
      if (def != null) {
        def.end();
      }
    }
  }

  /**
   * Returns the number of serialized protobuf bytes written out so far.
   *
   * This count excludes bytes that are still in the buffer and have not yet
   * been flushed. To obtain the total number of serialized bytes, call
   * {@link #flush()} before calling this method.
   *
   * @return the byte count
   */
  public long getTotalBytes() {
    return bou.getByteCount();
  }

  @Deprecated
  public void write(RecordPack pack) throws IOException {
    if (pack instanceof ProtobufRecordPack) {
      ProtobufRecordPack pbPack = (ProtobufRecordPack) pack;
      pbPack.getProtobufStream().writeTo(bou);
      count += pbPack.getSize();
      setCheckSum(pbPack.getCheckSum());
    } else {
      RecordReader reader = pack.getRecordReader();
      Record record;
      while ((record = reader.read()) != null) {
        write(record);
      }
    }
  }

  public void flush() throws IOException {
    out.flush();
  }

  /**
   * Returns the checksum of the records written so far.
   */
  public Checksum getCheckSum() {
    return crccrc;
  }

  public void setCheckSum(Checksum checkSum) {
    crccrc = checkSum;
  }
}
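/*
 * Usage sketch (illustrative only; assumes a TableSchema is at hand and uses
 * ArrayRecord from com.aliyun.odps.data as the Record implementation, which
 * this file does not itself reference):
 *
 *   TableSchema schema = ...;                      // schema of the target table
 *   OutputStream os = ...;                         // e.g. a tunnel upload stream
 *   ProtobufRecordStreamWriter writer =
 *       new ProtobufRecordStreamWriter(schema, os, new CompressOption());
 *   Record record = new ArrayRecord(schema.getColumns().toArray(new Column[0]));
 *   record.set(0, 42L);                            // assuming column 0 is BIGINT
 *   writer.write(record);
 *   writer.flush();                                // flush before getTotalBytes()
 *   long bytes = writer.getTotalBytes();
 *   writer.close();                                // appends record count and checksum
 */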