/*******************************************************************************
 * Copyright 2017 Capital One Services, LLC and Bitwise, Inc.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License
 *******************************************************************************/
package hydrograph.engine.cascading.scheme.parquet;

import cascading.tuple.TupleEntry;
import jodd.typeconverter.Convert;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import parquet.hadoop.api.WriteSupport;
import parquet.io.api.Binary;
import parquet.io.api.RecordConsumer;
import parquet.schema.MessageType;
import parquet.schema.MessageTypeParser;
import parquet.schema.Type;

import java.math.BigDecimal;
import java.util.Calendar;
import java.util.HashMap;
import java.util.List;
import java.util.TimeZone;
import java.util.concurrent.TimeUnit;

public class ParquetTupleWriter extends WriteSupport<TupleEntry> {

	private static final long MILLIS_PER_DAY = TimeUnit.DAYS.toMillis(1);
	private static final Logger LOG = LoggerFactory
			.getLogger(ParquetTupleWriter.class);
	private static final ThreadLocal<TimeZone> LOCAL_TIMEZONE = new ThreadLocal<TimeZone>() {
		@Override
		protected TimeZone initialValue() {
			return Calendar.getInstance().getTimeZone();
		}
	};

	private RecordConsumer recordConsumer;
	private MessageType rootSchema;
	public static final String PARQUET_CASCADING_SCHEMA = "parquet.cascading.schema";

	@Override
	public WriteContext init(Configuration configuration) {
		String schema = configuration.get(PARQUET_CASCADING_SCHEMA);
		rootSchema = MessageTypeParser.parseMessageType(schema);
		return new WriteContext(rootSchema, new HashMap<String, String>());
	}

	@Override
	public void prepareForWrite(RecordConsumer recordConsumer) {
		this.recordConsumer = recordConsumer;
	}

	@Override
	public void write(TupleEntry record) {
		recordConsumer.startMessage();
		final List<Type> fields = rootSchema.getFields();

		for (int i = 0; i < fields.size(); i++) {
			Type field = fields.get(i);
			if (record == null || record.getObject(field.getName()) == null) {
				continue;
			}
			recordConsumer.startField(field.getName(), i);
			if (field.isPrimitive()) {
				writePrimitive(record, field);
			} else {
				throw new UnsupportedOperationException(
						"Complex type not implemented");
			}
			recordConsumer.endField(field.getName(), i);
		}
		recordConsumer.endMessage();
	}

	private void writePrimitive(TupleEntry record, Type field) {
		String type;
		if (field.asPrimitiveType().getPrimitiveTypeName().name().toUpperCase()
				.equals("INT32")
				&& field.getOriginalType() != null
				&& field.getOriginalType().name().toUpperCase().equals("DATE"))
			type = "DATE";
		else
			type = field.asPrimitiveType().getPrimitiveTypeName().name();

		switch (type) {
		case "BINARY":
			recordConsumer.addBinary(Binary.fromString(record.getString(field
					.getName())));
			break;
		case "BOOLEAN":
			recordConsumer.addBoolean(record.getBoolean(field.getName()));
			break;
		case "INT32":
			recordConsumer.addInteger(record.getInteger(field.getName()));
			break;
		case "INT64":
			recordConsumer.addLong(record.getLong(field.getName()));
			break;
		case "DOUBLE":
			recordConsumer.addDouble(record.getDouble(field.getName()));
			break;
		case "FLOAT":
			recordConsumer.addFloat(record.getFloat(field.getName()));
			break;
		case "FIXED_LEN_BYTE_ARRAY":
			BigDecimal bg = (BigDecimal) record.getObject(field.getName());
			recordConsumer.addBinary(decimalToBinary(bg, field));
			break;
		case "INT96":
			throw new UnsupportedOperationException(
					"Int96 type not implemented");
		case "DATE":
			recordConsumer.addInteger(millisToDays(record.getLong(field
					.getName())));
			break;
		default:
			throw new UnsupportedOperationException(field.getName()
					+ " type not implemented");
		}
	}

	/**
	 * Converts a millisecond timestamp to days since the Unix epoch (the
	 * Parquet DATE representation), shifting by the local time-zone offset so
	 * the day boundary follows local time.
	 */
	public static int millisToDays(long millisLocal) {
		long millisUtc = millisLocal
				+ LOCAL_TIMEZONE.get().getOffset(millisLocal);
		int days;
		if (millisUtc >= 0L) {
			days = (int) (millisUtc / MILLIS_PER_DAY);
		} else {
			// Floor division so timestamps before the epoch map to the
			// previous day.
			days = (int) ((millisUtc - (MILLIS_PER_DAY - 1)) / MILLIS_PER_DAY);
		}
		return days;
	}

	/**
	 * Encodes a BigDecimal as a fixed-length, big-endian two's-complement
	 * binary value sized according to the field's decimal precision.
	 */
	private Binary decimalToBinary(final BigDecimal hiveDecimal, Type field) {
		int prec = field.asPrimitiveType().getDecimalMetadata().getPrecision();
		int scale = field.asPrimitiveType().getDecimalMetadata().getScale();
		byte[] decimalBytes = hiveDecimal.setScale(scale).unscaledValue()
				.toByteArray();

		// Estimated number of bytes needed.
		int precToBytes = ParquetHiveSerDe.PRECISION_TO_BYTE_COUNT[prec - 1];
		if (precToBytes == decimalBytes.length) {
			// No padding needed.
			return Binary.fromByteArray(decimalBytes);
		}

		byte[] tgt = new byte[precToBytes];
		if (hiveDecimal.signum() == -1) {
			// For negative numbers, initialize all bits to 1 (sign extension).
			for (int i = 0; i < precToBytes; i++) {
				tgt[i] |= 0xFF;
			}
		}
		// Pad with leading zeroes (or ones for negative values).
		System.arraycopy(decimalBytes, 0, tgt,
				precToBytes - decimalBytes.length, decimalBytes.length);
		LOG.debug(Convert.toString(tgt.length));
		LOG.debug(hiveDecimal.toString());
		return Binary.fromByteArray(tgt);
	}
}