/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.raptor.storage;

import com.facebook.presto.raptor.util.SyncingFileSystem;
import com.facebook.presto.spi.Page;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.classloader.ThreadContextClassLoader;
import com.facebook.presto.spi.type.DecimalType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.TypeSignature;
import com.facebook.presto.spi.type.VarbinaryType;
import com.facebook.presto.spi.type.VarcharType;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Joiner;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.airlift.json.JsonCodec;
import io.airlift.slice.Slice;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.ql.io.orc.NullMemoryManager;
import org.apache.hadoop.hive.ql.io.orc.OrcFile;
import org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat;
import org.apache.hadoop.hive.ql.io.orc.OrcSerde;
import org.apache.hadoop.hive.ql.io.orc.OrcWriterOptions;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
import org.apache.hadoop.hive.serde2.objectinspector.SettableStructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;

import java.io.Closeable;
import java.io.File;
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.nio.ByteBuffer;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Properties;

import static com.facebook.presto.raptor.RaptorErrorCode.RAPTOR_ERROR;
import static com.facebook.presto.raptor.storage.Row.extractRow;
import static com.facebook.presto.raptor.storage.StorageType.arrayOf;
import static com.facebook.presto.raptor.storage.StorageType.mapOf;
import static com.facebook.presto.raptor.util.Types.isArrayType;
import static com.facebook.presto.raptor.util.Types.isMapType;
import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED;
import static com.google.common.base.Functions.toStringFunction;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.Iterables.transform;
import static io.airlift.json.JsonCodec.jsonCodec;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toList;
import static org.apache.hadoop.hive.metastore.api.hive_metastoreConstants.META_TABLE_COLUMNS;
import static org.apache.hadoop.hive.metastore.api.hive_metastoreConstants.META_TABLE_COLUMN_TYPES;
import static org.apache.hadoop.hive.ql.exec.FileSinkOperator.RecordWriter;
import static org.apache.hadoop.hive.ql.io.orc.CompressionKind.SNAPPY;
import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category.LIST;
import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category.MAP;
import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category.PRIMITIVE;
import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.getStandardListObjectInspector;
import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.getStandardMapObjectInspector;
import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.getStandardStructObjectInspector;
import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector;
import static org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory.getPrimitiveTypeInfo;
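
/**
 * Writes rows to a local ORC file using the Hive ORC writer. Column ids become
 * the ORC column names, and Presto types are mapped to Hive storage types.
 *
 * <p>A minimal usage sketch; the column ids, types, target file, and
 * {@code pages} list are hypothetical, and {@code BIGINT} / {@code VARCHAR}
 * are assumed to be statically imported from the SPI type classes:
 *
 * <pre>{@code
 * List<Page> pages = ...;
 * try (OrcFileWriter writer = new OrcFileWriter(
 *         ImmutableList.of(1L, 2L),
 *         ImmutableList.of(BIGINT, VARCHAR),
 *         new File("/tmp/shard.orc"))) {
 *     writer.appendPages(pages);
 * }
 * }</pre>
 */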
public class OrcFileWriter
        implements Closeable
{
    private static final Configuration CONFIGURATION = new Configuration();
    private static final Constructor<? extends RecordWriter> WRITER_CONSTRUCTOR = getOrcWriterConstructor();
    private static final JsonCodec<OrcFileMetadata> METADATA_CODEC = jsonCodec(OrcFileMetadata.class);

    private final List<Type> columnTypes;
    private final OrcSerde serializer;
    private final RecordWriter recordWriter;
    private final SettableStructObjectInspector tableInspector;
    private final List<StructField> structFields;
    private final Object orcRow;

    private boolean closed;
    private long rowCount;
    private long uncompressedSize;

    public OrcFileWriter(List<Long> columnIds, List<Type> columnTypes, File target)
    {
        this(columnIds, columnTypes, target, true);
    }

    @VisibleForTesting
    OrcFileWriter(List<Long> columnIds, List<Type> columnTypes, File target, boolean writeMetadata)
    {
        requireNonNull(columnIds, "columnIds is null");
        this.columnTypes = ImmutableList.copyOf(requireNonNull(columnTypes, "columnTypes is null"));
        checkArgument(columnIds.size() == columnTypes.size(), "ids and types mismatch");
        checkArgument(isUnique(columnIds), "ids must be unique");

        List<StorageType> storageTypes = ImmutableList.copyOf(toStorageTypes(columnTypes));
        Iterable<String> hiveTypeNames = storageTypes.stream().map(StorageType::getHiveTypeName).collect(toList());
        List<String> columnNames = ImmutableList.copyOf(transform(columnIds, toStringFunction()));

        Properties properties = new Properties();
        properties.setProperty(META_TABLE_COLUMNS, Joiner.on(',').join(columnNames));
        properties.setProperty(META_TABLE_COLUMN_TYPES, Joiner.on(':').join(hiveTypeNames));

        serializer = createSerializer(properties);
        recordWriter = createRecordWriter(new Path(target.toURI()), columnIds, columnTypes, writeMetadata);

        tableInspector = getStandardStructObjectInspector(columnNames, getJavaObjectInspectors(storageTypes));
        structFields = ImmutableList.copyOf(tableInspector.getAllStructFieldRefs());
        orcRow = tableInspector.create();
    }
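
    /**
     * Appends every position of every page, in order, as one ORC row each.
     */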
    public void appendPages(List<Page> pages)
    {
        for (Page page : pages) {
            for (int position = 0; position < page.getPositionCount(); position++) {
                appendRow(extractRow(page, position, columnTypes));
            }
        }
    }
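
    /**
     * Appends selected rows: row {@code i} is taken from page
     * {@code inputPages.get(pageIndexes[i])} at position {@code positionIndexes[i]}.
     */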
    public void appendPages(List<Page> inputPages, int[] pageIndexes, int[] positionIndexes)
    {
        checkArgument(pageIndexes.length == positionIndexes.length, "pageIndexes and positionIndexes do not match");

        for (int i = 0; i < pageIndexes.length; i++) {
            Page page = inputPages.get(pageIndexes[i]);
            appendRow(extractRow(page, positionIndexes[i], columnTypes));
        }
    }
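
    /**
     * Serializes a single row through the reusable {@code orcRow} struct and
     * writes it to the ORC record writer, tracking the row count and
     * uncompressed size.
     */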
    public void appendRow(Row row)
    {
        List<Object> columns = row.getColumns();
        checkArgument(columns.size() == columnTypes.size(), "row column count does not match writer schema");
        for (int channel = 0; channel < columns.size(); channel++) {
            tableInspector.setStructFieldData(orcRow, structFields.get(channel), columns.get(channel));
        }

        try {
            recordWriter.write(serializer.serialize(orcRow, tableInspector));
        }
        catch (IOException e) {
            throw new PrestoException(RAPTOR_ERROR, "Failed to write record", e);
        }

        rowCount++;
        uncompressedSize += row.getSizeInBytes();
    }

    @Override
    public void close()
    {
        if (closed) {
            return;
        }
        closed = true;

        try {
            recordWriter.close(false);
        }
        catch (IOException e) {
            throw new PrestoException(RAPTOR_ERROR, "Failed to close writer", e);
        }
    }

    public long getRowCount()
    {
        return rowCount;
    }

    public long getUncompressedSize()
    {
        return uncompressedSize;
    }

    private static OrcSerde createSerializer(Properties properties)
    {
        OrcSerde serde = new OrcSerde();
        serde.initialize(CONFIGURATION, properties);
        return serde;
    }
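
    /**
     * Builds the Hive ORC record writer. The thread context class loader is
     * pinned to the Hadoop class loader during construction so Hadoop can
     * resolve its own classes, writes go through {@link SyncingFileSystem} so
     * the data is synced to disk, and {@link NullMemoryManager} opts out of
     * Hive's shared memory-manager flushing.
     */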
    private static RecordWriter createRecordWriter(Path target, List<Long> columnIds, List<Type> columnTypes, boolean writeMetadata)
    {
        try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(FileSystem.class.getClassLoader());
                FileSystem fileSystem = new SyncingFileSystem(CONFIGURATION)) {
            OrcFile.WriterOptions options = new OrcWriterOptions(CONFIGURATION)
                    .memory(new NullMemoryManager(CONFIGURATION))
                    .fileSystem(fileSystem)
                    .compress(SNAPPY);
            if (writeMetadata) {
                options.callback(createFileMetadataCallback(columnIds, columnTypes));
            }
            return WRITER_CONSTRUCTOR.newInstance(target, options);
        }
        catch (ReflectiveOperationException | IOException e) {
            throw new PrestoException(RAPTOR_ERROR, "Failed to create writer", e);
        }
    }
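
    /**
     * Returns a callback that, just before the file footer is written, stores
     * an {@link OrcFileMetadata} entry (column id to type signature) in the
     * ORC user metadata under {@code OrcFileMetadata.KEY}.
     */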
    private static OrcFile.WriterCallback createFileMetadataCallback(List<Long> columnIds, List<Type> columnTypes)
    {
        return new OrcFile.WriterCallback()
        {
            @Override
            public void preStripeWrite(OrcFile.WriterContext context)
                    throws IOException
            {}

            @Override
            public void preFooterWrite(OrcFile.WriterContext context)
                    throws IOException
            {
                ImmutableMap.Builder<Long, TypeSignature> columnTypesMap = ImmutableMap.builder();
                for (int i = 0; i < columnIds.size(); i++) {
                    columnTypesMap.put(columnIds.get(i), columnTypes.get(i).getTypeSignature());
                }
                byte[] bytes = METADATA_CODEC.toJsonBytes(new OrcFileMetadata(columnTypesMap.build()));
                context.getWriter().addUserMetadata(OrcFileMetadata.KEY, ByteBuffer.wrap(bytes));
            }
        };
    }
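
    /**
     * Hive's {@code OrcOutputFormat$OrcRecordWriter} is not publicly
     * constructible, so the constructor taking {@code (Path, WriterOptions)}
     * is loaded reflectively and made accessible.
     */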
    private static Constructor<? extends RecordWriter> getOrcWriterConstructor()
    {
        try {
            String writerClassName = OrcOutputFormat.class.getName() + "$OrcRecordWriter";
            Constructor<? extends RecordWriter> constructor = OrcOutputFormat.class.getClassLoader()
                    .loadClass(writerClassName).asSubclass(RecordWriter.class)
                    .getDeclaredConstructor(Path.class, OrcFile.WriterOptions.class);
            constructor.setAccessible(true);
            return constructor;
        }
        catch (ReflectiveOperationException e) {
            throw Throwables.propagate(e);
        }
    }

    private static List<ObjectInspector> getJavaObjectInspectors(List<StorageType> types)
    {
        return types.stream()
                .map(StorageType::getHiveTypeName)
                .map(TypeInfoUtils::getTypeInfoFromTypeString)
                .map(OrcFileWriter::getJavaObjectInspector)
                .collect(toList());
    }
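
    /**
     * Recursively builds a Java object inspector for a Hive type: primitives
     * map directly, while lists and maps wrap inspectors for their element types.
     */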
    private static ObjectInspector getJavaObjectInspector(TypeInfo typeInfo)
    {
        Category category = typeInfo.getCategory();
        if (category == PRIMITIVE) {
            return getPrimitiveJavaObjectInspector(getPrimitiveTypeInfo(typeInfo.getTypeName()));
        }
        if (category == LIST) {
            ListTypeInfo listTypeInfo = (ListTypeInfo) typeInfo;
            return getStandardListObjectInspector(getJavaObjectInspector(listTypeInfo.getListElementTypeInfo()));
        }
        if (category == MAP) {
            MapTypeInfo mapTypeInfo = (MapTypeInfo) typeInfo;
            return getStandardMapObjectInspector(
                    getJavaObjectInspector(mapTypeInfo.getMapKeyTypeInfo()),
                    getJavaObjectInspector(mapTypeInfo.getMapValueTypeInfo()));
        }
        throw new PrestoException(GENERIC_INTERNAL_ERROR, "Unhandled storage type: " + category);
    }

    private static <T> boolean isUnique(Collection<T> items)
    {
        return new HashSet<>(items).size() == items.size();
    }

    private static List<StorageType> toStorageTypes(List<Type> columnTypes)
    {
        return columnTypes.stream().map(OrcFileWriter::toStorageType).collect(toList());
    }
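
    /**
     * Maps a Presto type to its Raptor storage type: decimals keep precision
     * and scale, fixed-width types map by Java representation, varchar and
     * varbinary map to string and bytes, and arrays and maps are converted
     * recursively.
     */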
    private static StorageType toStorageType(Type type)
    {
        if (type instanceof DecimalType) {
            DecimalType decimalType = (DecimalType) type;
            return StorageType.decimal(decimalType.getPrecision(), decimalType.getScale());
        }
        Class<?> javaType = type.getJavaType();
        if (javaType == boolean.class) {
            return StorageType.BOOLEAN;
        }
        if (javaType == long.class) {
            return StorageType.LONG;
        }
        if (javaType == double.class) {
            return StorageType.DOUBLE;
        }
        if (javaType == Slice.class) {
            if (type instanceof VarcharType) {
                return StorageType.STRING;
            }
            if (type.equals(VarbinaryType.VARBINARY)) {
                return StorageType.BYTES;
            }
        }
        if (isArrayType(type)) {
            return arrayOf(toStorageType(type.getTypeParameters().get(0)));
        }
        if (isMapType(type)) {
            return mapOf(toStorageType(type.getTypeParameters().get(0)), toStorageType(type.getTypeParameters().get(1)));
        }
        throw new PrestoException(NOT_SUPPORTED, "No storage type for type: " + type);
    }
}