/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.aliyun.odps.commons.proto;
import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.math.BigDecimal;
import java.sql.Timestamp;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.zip.InflaterInputStream;
import org.xerial.snappy.SnappyFramedInputStream;
import com.aliyun.odps.Column;
import com.aliyun.odps.OdpsType;
import com.aliyun.odps.Survey;
import com.aliyun.odps.TableSchema;
import com.aliyun.odps.commons.util.DateUtils;
import com.aliyun.odps.data.ArrayRecord;
import com.aliyun.odps.data.Binary;
import com.aliyun.odps.data.Char;
import com.aliyun.odps.data.IntervalDayTime;
import com.aliyun.odps.data.IntervalYearMonth;
import com.aliyun.odps.data.Record;
import com.aliyun.odps.data.RecordReader;
import com.aliyun.odps.data.SimpleStruct;
import com.aliyun.odps.data.Struct;
import com.aliyun.odps.data.Varchar;
import com.aliyun.odps.tunnel.io.Checksum;
import com.aliyun.odps.tunnel.io.CompressOption;
import com.aliyun.odps.type.ArrayTypeInfo;
import com.aliyun.odps.type.MapTypeInfo;
import com.aliyun.odps.type.StructTypeInfo;
import com.aliyun.odps.type.TypeInfo;
import com.google.protobuf.CodedInputStream;
import com.google.protobuf.WireFormat;
/**
 * Reads ODPS tunnel records from a protobuf-encoded stream.
 *
 * <p>Each record is a sequence of (field tag, value) pairs, where field tag {@code i} addresses
 * column {@code i - 1}, terminated by a {@code TUNNEL_END_RECORD} tag that carries the record's
 * CRC. The stream ends with a {@code TUNNEL_META_COUNT} tag (number of records) followed by a
 * {@code TUNNEL_META_CHECKSUM} tag (CRC over all per-record CRCs). The payload may be wrapped in
 * a zlib or snappy decompressor, depending on the {@link CompressOption} given at construction.
 *
 * @author chao.liu
 */
public class ProtobufRecordStreamReader implements RecordReader {

  /** Buffered view of the caller-supplied input stream. */
  private BufferedInputStream bin;

  /**
   * Outermost stream handed to the protobuf decoder: either a decompressing wrapper around
   * {@link #bin} or {@link #bin} itself. Closing it also closes {@link #bin}.
   */
  private InputStream source;

  /** Protobuf decoder reading from {@link #source}. */
  private CodedInputStream in;

  /** Columns present in the stream; field tag {@code i} maps to {@code columns[i - 1]}. */
  private Column[] columns;

  /** Number of records fully read so far; compared against the count in the stream trailer. */
  private long count;

  /** Bytes consumed by the protobuf decoder so far (measured after decompression). */
  private long bytesRead = 0;

  /** CRC of the record currently being read; reset after each record. */
  private Checksum crc = new Checksum();

  /** CRC accumulated over every per-record CRC; verified against the stream trailer. */
  private Checksum crccrc = new Checksum();

  public ProtobufRecordStreamReader() {
  }

  /**
   * Creates a reader over all columns of {@code schema} using a default {@link CompressOption}.
   */
  public ProtobufRecordStreamReader(TableSchema schema, InputStream in)
      throws IOException {
    this(schema, null, in, new CompressOption());
  }

  /**
   * Creates a reader over all columns of {@code schema} using the given compression option.
   */
  public ProtobufRecordStreamReader(TableSchema schema, InputStream in, CompressOption option)
      throws IOException {
    this(schema, null, in, option);
  }

  /**
   * Creates a reader.
   *
   * @param schema  table schema the stream's columns are resolved against (by name)
   * @param columns columns present in the stream, in field-tag order; {@code null} means all
   *                columns of {@code schema}
   * @param in      raw input stream
   * @param option  compression applied to the stream; {@code null} means uncompressed
   * @throws IOException if {@code option} names an unsupported algorithm, or the decompressor
   *                     fails to initialize
   */
  public ProtobufRecordStreamReader(TableSchema schema, List<Column> columns, InputStream in,
      CompressOption option) throws IOException {
    this.columns = resolveColumns(schema, columns);
    bin = new BufferedInputStream(in);
    source = openDecompressedStream(bin, option);
    this.in = CodedInputStream.newInstance(source);
    // Effectively disable the decoder's size limit; the counter is also reset per record.
    this.in.setSizeLimit(Integer.MAX_VALUE);
  }

  /** Returns the schema's Column definitions for {@code columns}, matched by name. */
  private static Column[] resolveColumns(TableSchema schema, List<Column> columns) {
    if (columns == null) {
      return schema.getColumns().toArray(new Column[0]);
    }
    Column[] resolved = new Column[columns.size()];
    for (int i = 0; i < resolved.length; ++i) {
      resolved[i] = schema.getColumn(columns.get(i).getName());
    }
    return resolved;
  }

  /**
   * Wraps {@code raw} in the decompressor selected by {@code option}.
   *
   * @throws IOException if {@code option} names an unsupported algorithm
   */
  private static InputStream openDecompressedStream(InputStream raw, CompressOption option)
      throws IOException {
    if (option == null) {
      return raw;
    }
    if (option.algorithm.equals(CompressOption.CompressAlgorithm.ODPS_ZLIB)) {
      return new InflaterInputStream(raw);
    } else if (option.algorithm.equals(CompressOption.CompressAlgorithm.ODPS_SNAPPY)) {
      return new SnappyFramedInputStream(raw);
    } else if (option.algorithm.equals(CompressOption.CompressAlgorithm.ODPS_RAW)) {
      return raw;
    }
    throw new IOException("invalid compression option.");
  }

  /**
   * Reads the next record, optionally reusing a caller-supplied Record.
   *
   * <ul>
   *   <li>If {@code reuseRecord} is {@code null}, a new Record is created and returned.</li>
   *   <li>Otherwise every column of {@code reuseRecord} is reset to {@code null}, refilled, and
   *       {@code reuseRecord} itself is returned.</li>
   *   <li>When the end-of-stream trailer is reached, {@code null} is returned.</li>
   * </ul>
   *
   * @param reuseRecord record to fill, or {@code null}
   * @return the filled record, or {@code null} once the trailer has been read
   * @throws IOException if the stream ends without a trailer, a field tag is out of range, or a
   *                     record count or checksum does not match
   */
  public Record read(Record reuseRecord) throws IOException {
    if (reuseRecord == null) {
      reuseRecord = new ArrayRecord(columns);
    } else {
      for (int i = 0; i < reuseRecord.getColumnCount(); ++i) {
        reuseRecord.set(i, null);
      }
    }

    while (true) {
      if (in.isAtEnd()) {
        throw new IOException("No more record");
      }
      int tag = getTagFieldNumber(in);

      if (tag == ProtoWireConstant.TUNNEL_END_RECORD) {
        // End of one record: verify its CRC, then fold it into the stream-level CRC.
        int checkSum = (int) crc.getValue();
        if (in.readUInt32() != checkSum) {
          throw new IOException("Checksum invalid.");
        }
        crc.reset();
        crccrc.update(checkSum);
        break;
      }

      if (tag == ProtoWireConstant.TUNNEL_META_COUNT) {
        // Stream trailer: record count, then the CRC-of-CRCs, then nothing else.
        if (count != in.readSInt64()) {
          throw new IOException("count does not match.");
        }
        if (ProtoWireConstant.TUNNEL_META_CHECKSUM != getTagFieldNumber(in)) {
          throw new IOException("Invalid stream.");
        }
        if ((int) crccrc.getValue() != in.readUInt32()) {
          throw new IOException("Checksum invalid.");
        }
        if (!in.isAtEnd()) {
          throw new IOException("Expect at the end of stream, but not.");
        }
        return null;
      }

      // Field tags are 1-based column positions.
      if (tag > columns.length) {
        throw new IOException(
            "Invalid protobuf tag. Perhaps the datastream from server is crushed.");
      }
      crc.update(tag);
      reuseRecord.set(tag - 1, readField(columns[tag - 1].getTypeInfo()));
    }

    bytesRead += in.getTotalBytesRead();
    in.resetSizeCounter();
    count++;
    return reuseRecord;
  }

  /**
   * Decodes one value of {@code type} from the stream and folds it into the record CRC.
   *
   * @throws IOException if {@code type} is unsupported or the stream is malformed
   */
  private Object readField(TypeInfo type) throws IOException {
    switch (type.getOdpsType()) {
      case DOUBLE: {
        double v = in.readDouble();
        crc.update(v);
        return v;
      }
      case FLOAT: {
        float v = in.readFloat();
        crc.update(v);
        return v;
      }
      case BOOLEAN: {
        boolean v = in.readBool();
        crc.update(v);
        return v;
      }
      case BIGINT: {
        long v = in.readSInt64();
        crc.update(v);
        return v;
      }
      case INTERVAL_YEAR_MONTH: {
        long v = in.readSInt64();
        crc.update(v);
        return new IntervalYearMonth((int) v);
      }
      case INT: {
        long v = in.readSInt64();
        crc.update(v);
        return (int) v;
      }
      case SMALLINT: {
        long v = in.readSInt64();
        crc.update(v);
        return (short) v;
      }
      case TINYINT: {
        long v = in.readSInt64();
        crc.update(v);
        return (byte) v;
      }
      case STRING: {
        return readString();
      }
      case VARCHAR: {
        return new Varchar(readString());
      }
      case CHAR: {
        return new Char(readString());
      }
      case BINARY: {
        return new Binary(readBytes());
      }
      case DATETIME: {
        long v = in.readSInt64();
        crc.update(v);
        return DateUtils.ms2date(v);
      }
      case DATE: {
        long v = in.readSInt64();
        crc.update(v);
        // Day offset converted to a java.sql.Date by DateUtils.
        return DateUtils.fromDayOffset(v);
      }
      case INTERVAL_DAY_TIME: {
        long time = in.readSInt64();
        int nano = in.readSInt32();
        crc.update(time);
        crc.update(nano);
        return new IntervalDayTime(time, nano);
      }
      case TIMESTAMP: {
        long time = in.readSInt64();
        int nano = in.readSInt32();
        crc.update(time);
        crc.update(nano);
        // `time` is multiplied by 1000 for Timestamp's millisecond constructor, i.e. it is
        // encoded in seconds; the nanosecond part is set separately.
        Timestamp t = new Timestamp(time * 1000);
        t.setNanos(nano);
        return t;
      }
      case DECIMAL: {
        // Decimals are length-prefixed UTF-8 text, encoded exactly like STRING; reuse
        // readString so CRC and byte-counter bookkeeping stay in one place.
        return new BigDecimal(readString());
      }
      case ARRAY: {
        return readArray(((ArrayTypeInfo) type).getElementTypeInfo());
      }
      case MAP: {
        MapTypeInfo mapTypeInfo = (MapTypeInfo) type;
        return readMap(mapTypeInfo.getKeyTypeInfo(), mapTypeInfo.getValueTypeInfo());
      }
      case STRUCT: {
        return readStruct(type);
      }
      default:
        throw new IOException("Unsupported type " + type.getTypeName());
    }
  }

  /** Reads a length-prefixed byte field and decodes it as UTF-8. */
  private String readString() throws IOException {
    byte[] bytes = readBytes();
    return new String(bytes, "utf-8");
  }

  /**
   * Reads a varint length followed by that many raw bytes, updates the record CRC, and moves
   * the decoder's byte count into {@link #bytesRead}.
   */
  private byte[] readBytes() throws IOException {
    int size = in.readRawVarint32();
    byte[] bytes = in.readRawBytes(size);
    crc.update(bytes, 0, bytes.length);
    bytesRead += in.getTotalBytesRead();
    in.resetSizeCounter();
    return bytes;
  }

  /** Reads the next tag from {@code in} and returns its field number. */
  static int getTagFieldNumber(CodedInputStream in) throws IOException {
    return WireFormat.getTagFieldNumber(in.readTag());
  }

  /** Reads the next record into a fresh Record; returns {@code null} at end of data. */
  @Override
  public Record read() throws IOException {
    return read(null);
  }

  /** Returns a new, empty Record laid out for this reader's columns. */
  public Record createEmptyRecord() throws IOException {
    return new ArrayRecord(columns);
  }

  /**
   * Closes the input. The outermost (possibly decompressing) stream is closed rather than only
   * the buffered one, so decompressor resources are released as well; closing it also closes the
   * buffered stream beneath it.
   */
  @Override
  public void close() throws IOException {
    if (source != null) {
      source.close();
    } else if (bin != null) {
      bin.close();
    }
  }

  /**
   * Returns the number of bytes the protobuf decoder has consumed so far, measured after
   * decompression. The value is updated after each record and each string/binary field.
   */
  public long getTotalBytes() {
    return bytesRead;
  }

  /**
   * Reads a struct value: for each field in declaration order, a boolean null flag followed by
   * the field's value when the flag is false.
   */
  public Struct readStruct(TypeInfo type) throws IOException {
    StructTypeInfo typeInfo = (StructTypeInfo) type;
    List<Object> values = new ArrayList<Object>();
    List<TypeInfo> fieldTypeInfos = typeInfo.getFieldTypeInfos();
    for (int i = 0; i < typeInfo.getFieldCount(); ++i) {
      if (in.readBool()) {
        values.add(null);
      } else {
        values.add(readField(fieldTypeInfos.get(i)));
      }
    }
    return new SimpleStruct(typeInfo, values);
  }

  /**
   * Reads an array: a uint32 element count, then for each element a boolean null flag followed
   * by the value when the flag is false.
   *
   * @param type element type; nested ARRAY, MAP and STRUCT element types are rejected
   * @throws IOException if the element type is a nested complex type or the stream is malformed
   */
  public List readArray(TypeInfo type) throws IOException {
    OdpsType t = type.getOdpsType();
    if ((t == OdpsType.ARRAY) || (t == OdpsType.MAP) || (t == OdpsType.STRUCT)) {
      throw new IOException("Unsupported array type: " + t);
    }
    int arraySize = in.readUInt32();
    List list = new ArrayList();
    for (int i = 0; i < arraySize; i++) {
      if (in.readBool()) {
        list.add(null);
      } else {
        list.add(readField(type));
      }
    }
    return list;
  }

  /**
   * Reads a map, encoded as an array of keys followed by an array of values of equal length.
   *
   * @throws IOException if the key and value arrays differ in length
   */
  public Map readMap(TypeInfo keyType, TypeInfo valueType) throws IOException {
    List keyArray = readArray(keyType);
    List valueArray = readArray(valueType);
    if (keyArray.size() != valueArray.size()) {
      throw new IOException("Read Map error: key value does not match.");
    }
    Map map = new HashMap();
    for (int i = 0; i < keyArray.size(); i++) {
      map.put(keyArray.get(i), valueArray.get(i));
    }
    return map;
  }

  /**
   * Legacy array reader kept for compatibility.
   *
   * <p>It takes an OdpsType and therefore supports only STRING, BIGINT, DOUBLE and BOOLEAN
   * elements. STRING elements are returned as raw {@code byte[]}, not decoded Strings.
   *
   * @see #readArray(TypeInfo), which supports every non-nested element type
   */
  @Survey
  public List readArray(OdpsType type) throws IOException {
    int arraySize = in.readUInt32();
    List list = null;
    switch (type) {
      case STRING: {
        list = new ArrayList<byte []>();
        for (int i = 0; i < arraySize; i++) {
          if (in.readBool()) {
            list.add(null);
          } else {
            int size = in.readRawVarint32();
            byte[] bytes = in.readRawBytes(size);
            crc.update(bytes, 0, bytes.length);
            list.add(bytes);
          }
        }
        break;
      }
      case BIGINT: {
        list = new ArrayList<Long>();
        for (int i = 0; i < arraySize; i++) {
          if (in.readBool()) {
            list.add(null);
          } else {
            Long value = in.readSInt64();
            crc.update(value);
            list.add(value);
          }
        }
        break;
      }
      case DOUBLE: {
        list = new ArrayList<Double>();
        for (int i = 0; i < arraySize; i++) {
          if (in.readBool()) {
            list.add(null);
          } else {
            Double value = in.readDouble();
            crc.update(value);
            list.add(value);
          }
        }
        break;
      }
      case BOOLEAN: {
        list = new ArrayList<Boolean>();
        for (int i = 0; i < arraySize; i++) {
          if (in.readBool()) {
            list.add(null);
          } else {
            Boolean value = in.readBool();
            crc.update(value);
            list.add(value);
          }
        }
        break;
      }
      default:
        throw new IOException("Unsupport array type. type :" + type);
    }
    return list;
  }

  /**
   * Legacy map reader kept for compatibility.
   *
   * <p>It takes OdpsTypes and is therefore limited to the element types supported by
   * {@link #readArray(OdpsType)}.
   *
   * @see #readMap(TypeInfo, TypeInfo), which supports every non-nested key/value type
   */
  @Survey
  public Map readMap(OdpsType keyType, OdpsType valueType) throws IOException {
    List keyArray = readArray(keyType);
    List valueArray = readArray(valueType);
    if (keyArray.size() != valueArray.size()) {
      throw new IOException("Read Map error: key value does not match.");
    }
    Map map = new HashMap();
    for (int i = 0; i < keyArray.size(); i++) {
      map.put(keyArray.get(i), valueArray.get(i));
    }
    return map;
  }
}