/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.hive;

import com.facebook.presto.connector.ConnectorId;
import com.facebook.presto.hive.orc.OrcPageSourceFactory;
import com.facebook.presto.metadata.Split;
import com.facebook.presto.operator.CursorProcessor;
import com.facebook.presto.operator.DriverContext;
import com.facebook.presto.operator.ScanFilterAndProjectOperator.ScanFilterAndProjectOperatorFactory;
import com.facebook.presto.operator.SourceOperator;
import com.facebook.presto.operator.SourceOperatorFactory;
import com.facebook.presto.operator.TableScanOperator.TableScanOperatorFactory;
import com.facebook.presto.operator.project.PageProcessor;
import com.facebook.presto.spi.ColumnHandle;
import com.facebook.presto.spi.ConnectorPageSource;
import com.facebook.presto.spi.Page;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.classloader.ThreadContextClassLoader;
import com.facebook.presto.spi.predicate.TupleDomain;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.gen.ExpressionCompiler;
import com.facebook.presto.sql.planner.plan.PlanNodeId;
import com.facebook.presto.sql.relational.RowExpression;
import com.facebook.presto.testing.TestingSplit;
import com.facebook.presto.testing.TestingTransactionHandle;
import com.google.common.base.Joiner;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.airlift.slice.Slice;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.ql.exec.FileSinkOperator.RecordWriter;
import org.apache.hadoop.hive.ql.io.HiveOutputFormat;
import org.apache.hadoop.hive.ql.io.orc.NullMemoryManager;
import org.apache.hadoop.hive.ql.io.orc.OrcFile.WriterOptions;
import org.apache.hadoop.hive.ql.io.orc.OrcInputFormat;
import org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat;
import org.apache.hadoop.hive.ql.io.orc.OrcSerde;
import org.apache.hadoop.hive.ql.io.orc.OrcWriterOptions;
import org.apache.hadoop.hive.ql.io.orc.Writer;
import org.apache.hadoop.hive.ql.io.orc.WriterImpl;
import org.apache.hadoop.hive.serde2.SerDe;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.SettableStructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.compress.CompressionCodec;
import org.apache.hadoop.io.compress.CompressionCodecFactory;
import org.apache.hadoop.mapred.FileSplit;
import org.apache.hadoop.mapred.JobConf;
import org.joda.time.DateTimeZone;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

import java.io.File;
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.List;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.Properties;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import static com.facebook.presto.hive.HiveColumnHandle.ColumnType.PARTITION_KEY;
import static com.facebook.presto.hive.HiveColumnHandle.ColumnType.REGULAR;
import static com.facebook.presto.hive.HiveTestUtils.HDFS_ENVIRONMENT;
import static com.facebook.presto.hive.HiveTestUtils.SESSION;
import static com.facebook.presto.hive.HiveTestUtils.TYPE_MANAGER;
import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager;
import static com.facebook.presto.spi.type.VarcharType.createUnboundedVarcharType;
import static com.facebook.presto.sql.relational.Expressions.field;
import static com.facebook.presto.testing.TestingSession.testSessionBuilder;
import static com.facebook.presto.testing.TestingTaskContext.createTaskContext;
import static com.google.common.base.Predicates.not;
import static com.google.common.collect.Iterables.filter;
import static com.google.common.collect.Iterables.transform;
import static io.airlift.concurrent.Threads.daemonThreadsNamed;
import static io.airlift.testing.Assertions.assertBetweenInclusive;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.Executors.newCachedThreadPool;
import static java.util.stream.Collectors.toList;
import static org.apache.hadoop.hive.metastore.api.hive_metastoreConstants.FILE_INPUT_FORMAT;
import static org.apache.hadoop.hive.ql.io.orc.CompressionKind.ZLIB;
import static org.apache.hadoop.hive.serde.serdeConstants.SERIALIZATION_LIB;
import static org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.getStandardStructObjectInspector;
import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.javaStringObjectInspector;
import static org.apache.hadoop.mapreduce.lib.output.FileOutputFormat.COMPRESS_CODEC;
import static org.apache.hadoop.mapreduce.lib.output.FileOutputFormat.COMPRESS_TYPE;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertNull;
import static org.testng.Assert.assertTrue;
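
/**
 * Verifies that the ORC page source, and the scan operators built on top of it,
 * report their system memory usage correctly: usage grows once lazy blocks are
 * actually loaded, stays stable across pages of the same stripe, and drops back
 * to zero once the source is exhausted.
 */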
public class TestOrcPageSourceMemoryTracking
{
    private static final String ORC_RECORD_WRITER = OrcOutputFormat.class.getName() + "$OrcRecordWriter";
    private static final Constructor<? extends RecordWriter> WRITER_CONSTRUCTOR = getOrcWriterConstructor();
    private static final Configuration CONFIGURATION = new Configuration();
    private static final int NUM_ROWS = 50000;
    private static final int STRIPE_ROWS = 20000;
    private static final ExpressionCompiler EXPRESSION_COMPILER = new ExpressionCompiler(createTestMetadataManager());

    private final Random random = new Random();
    private final List<TestColumn> testColumns = ImmutableList.<TestColumn>builder()
            .add(new TestColumn("p_empty_string", javaStringObjectInspector, () -> "", true))
            .add(new TestColumn("p_string", javaStringObjectInspector, () -> Long.toHexString(random.nextLong()), false))
            .build();

    private File tempFile;
    private TestPreparer testPreparer;

    @BeforeClass
    public void setUp()
            throws Exception
    {
        tempFile = File.createTempFile("presto_test_orc_page_source_memory_tracking", "orc");
        tempFile.delete();
        testPreparer = new TestPreparer(tempFile.getAbsolutePath());
    }

    @AfterClass
    public void tearDown()
            throws Exception
    {
        tempFile.delete();
    }

    @Test
    public void testPageSource()
            throws Exception
    {
        // Numbers used in assertions in this test may change when implementation is modified,
        // feel free to change them if they break in the future
        ConnectorPageSource pageSource = testPreparer.newPageSource();
        assertEquals(pageSource.getSystemMemoryUsage(), 0);

        long memoryUsage = -1;
        for (int i = 0; i < 20; i++) {
            assertFalse(pageSource.isFinished());
            Page page = pageSource.getNextPage();
            assertNotNull(page);
            Block block = page.getBlock(1);

            if (memoryUsage == -1) {
                assertBetweenInclusive(pageSource.getSystemMemoryUsage(), 180000L, 189999L); // Memory usage before lazy-loading the block
                createUnboundedVarcharType().getSlice(block, block.getPositionCount() - 1); // trigger loading for lazy block
                memoryUsage = pageSource.getSystemMemoryUsage();
                assertBetweenInclusive(memoryUsage, 460000L, 469999L); // Memory usage after lazy-loading the actual block
            }
            else {
                assertEquals(pageSource.getSystemMemoryUsage(), memoryUsage);
                createUnboundedVarcharType().getSlice(block, block.getPositionCount() - 1); // trigger loading for lazy block
                assertEquals(pageSource.getSystemMemoryUsage(), memoryUsage);
            }
        }

        memoryUsage = -1;
        for (int i = 20; i < 40; i++) {
            assertFalse(pageSource.isFinished());
            Page page = pageSource.getNextPage();
            assertNotNull(page);
            Block block = page.getBlock(1);

            if (memoryUsage == -1) {
                assertBetweenInclusive(pageSource.getSystemMemoryUsage(), 180000L, 189999L); // Memory usage before lazy-loading the block
                createUnboundedVarcharType().getSlice(block, block.getPositionCount() - 1); // trigger loading for lazy block
                memoryUsage = pageSource.getSystemMemoryUsage();
                assertBetweenInclusive(memoryUsage, 460000L, 469999L); // Memory usage after lazy-loading the actual block
            }
            else {
                assertEquals(pageSource.getSystemMemoryUsage(), memoryUsage);
                createUnboundedVarcharType().getSlice(block, block.getPositionCount() - 1); // trigger loading for lazy block
                assertEquals(pageSource.getSystemMemoryUsage(), memoryUsage);
            }
        }

        memoryUsage = -1;
        for (int i = 40; i < 50; i++) {
            assertFalse(pageSource.isFinished());
            Page page = pageSource.getNextPage();
            assertNotNull(page);
            Block block = page.getBlock(1);

            if (memoryUsage == -1) {
                assertBetweenInclusive(pageSource.getSystemMemoryUsage(), 90000L, 99999L); // Memory usage before lazy-loading the block
                createUnboundedVarcharType().getSlice(block, block.getPositionCount() - 1); // trigger loading for lazy block
                memoryUsage = pageSource.getSystemMemoryUsage();
                assertBetweenInclusive(memoryUsage, 360000L, 369999L); // Memory usage after lazy-loading the actual block
            }
            else {
                assertEquals(pageSource.getSystemMemoryUsage(), memoryUsage);
                createUnboundedVarcharType().getSlice(block, block.getPositionCount() - 1); // trigger loading for lazy block
                assertEquals(pageSource.getSystemMemoryUsage(), memoryUsage);
            }
        }

        assertFalse(pageSource.isFinished());
        assertNull(pageSource.getNextPage());
        assertTrue(pageSource.isFinished());
        assertEquals(pageSource.getSystemMemoryUsage(), 0);
        pageSource.close();
    }

    @Test
    public void testTableScanOperator()
            throws Exception
    {
        // Numbers used in assertions in this test may change when implementation is modified,
        // feel free to change them if they break in the future
        DriverContext driverContext = testPreparer.newDriverContext();
        SourceOperator operator = testPreparer.newTableScanOperator(driverContext);

        assertEquals(driverContext.getSystemMemoryUsage(), 0);

        long memoryUsage = -1;
        for (int i = 0; i < 20; i++) {
            assertFalse(operator.isFinished());
            Page page = operator.getOutput();
            assertNotNull(page);
            page.getBlock(1);
            if (memoryUsage == -1) {
                memoryUsage = driverContext.getSystemMemoryUsage();
                assertBetweenInclusive(memoryUsage, 460000L, 469999L);
            }
            else {
                assertEquals(driverContext.getSystemMemoryUsage(), memoryUsage);
            }
        }

        memoryUsage = -1;
        for (int i = 20; i < 40; i++) {
            assertFalse(operator.isFinished());
            Page page = operator.getOutput();
            assertNotNull(page);
            page.getBlock(1);
            if (memoryUsage == -1) {
                memoryUsage = driverContext.getSystemMemoryUsage();
                assertBetweenInclusive(memoryUsage, 460000L, 469999L);
            }
            else {
                assertEquals(driverContext.getSystemMemoryUsage(), memoryUsage);
            }
        }

        memoryUsage = -1;
        for (int i = 40; i < 50; i++) {
            assertFalse(operator.isFinished());
            Page page = operator.getOutput();
            assertNotNull(page);
            page.getBlock(1);
            if (memoryUsage == -1) {
                memoryUsage = driverContext.getSystemMemoryUsage();
                assertBetweenInclusive(memoryUsage, 360000L, 369999L);
            }
            else {
                assertEquals(driverContext.getSystemMemoryUsage(), memoryUsage);
            }
        }

        assertFalse(operator.isFinished());
        assertNull(operator.getOutput());
        assertTrue(operator.isFinished());
        assertEquals(driverContext.getSystemMemoryUsage(), 0);
    }

    @Test
    public void testScanFilterAndProjectOperator()
            throws Exception
    {
        // Numbers used in assertions in this test may change when implementation is modified,
        // feel free to change them if they break in the future
        DriverContext driverContext = testPreparer.newDriverContext();
        SourceOperator operator = testPreparer.newScanFilterAndProjectOperator(driverContext);

        assertEquals(driverContext.getSystemMemoryUsage(), 0);

        for (int i = 0; i < 50; i++) {
            assertFalse(operator.isFinished());
            assertNotNull(operator.getOutput());
            assertBetweenInclusive(driverContext.getSystemMemoryUsage(), 90_000L, 499_999L);
        }

        // done... in the current implementation finish is not set until output returns a null page
        assertNull(operator.getOutput());
        assertTrue(operator.isFinished());
        assertEquals(driverContext.getSystemMemoryUsage(), 0);
    }

    private class TestPreparer
    {
        private final FileSplit fileSplit;
        private final Properties schema;
        private final List<HiveColumnHandle> columns;
        private final List<Type> types;
        private final List<HivePartitionKey> partitionKeys;
        private final ExecutorService executor = newCachedThreadPool(daemonThreadsNamed("test-%s"));

        public TestPreparer(String tempFilePath)
                throws Exception
        {
            OrcSerde serde = new OrcSerde();
            schema = new Properties();
            schema.setProperty("columns", testColumns.stream()
                    .map(TestColumn::getName)
                    .collect(Collectors.joining(",")));
            schema.setProperty("columns.types", testColumns.stream()
                    .map(TestColumn::getType)
                    .collect(Collectors.joining(",")));
            schema.setProperty(FILE_INPUT_FORMAT, OrcInputFormat.class.getName());
            schema.setProperty(SERIALIZATION_LIB, serde.getClass().getName());

            partitionKeys = testColumns.stream()
                    .filter(TestColumn::isPartitionKey)
                    .map(input -> new HivePartitionKey(input.getName(), HiveType.valueOf(input.getObjectInspector().getTypeName()), (String) input.getWriteValue()))
                    .collect(toList());

            ImmutableList.Builder<HiveColumnHandle> columnsBuilder = ImmutableList.builder();
            ImmutableList.Builder<Type> typesBuilder = ImmutableList.builder();
            int nextHiveColumnIndex = 0;
            for (int i = 0; i < testColumns.size(); i++) {
                TestColumn testColumn = testColumns.get(i);
                int columnIndex = testColumn.isPartitionKey() ? -1 : nextHiveColumnIndex++;

                ObjectInspector inspector = testColumn.getObjectInspector();
                HiveType hiveType = HiveType.valueOf(inspector.getTypeName());
                Type type = hiveType.getType(TYPE_MANAGER);

                columnsBuilder.add(new HiveColumnHandle(
                        "client_id",
                        testColumn.getName(),
                        hiveType,
                        type.getTypeSignature(),
                        columnIndex,
                        testColumn.isPartitionKey() ? PARTITION_KEY : REGULAR,
                        Optional.empty()));
                typesBuilder.add(type);
            }
            columns = columnsBuilder.build();
            types = typesBuilder.build();

            fileSplit = createTestFile(tempFilePath, new OrcOutputFormat(), serde, null, testColumns, NUM_ROWS);
        }

        public ConnectorPageSource newPageSource()
        {
            OrcPageSourceFactory orcPageSourceFactory = new OrcPageSourceFactory(TYPE_MANAGER, false, HDFS_ENVIRONMENT, new FileFormatDataSourceStats());
            return HivePageSourceProvider.createHivePageSource(
                    ImmutableSet.of(),
                    ImmutableSet.of(orcPageSourceFactory),
                    "test",
                    new Configuration(),
                    SESSION,
                    fileSplit.getPath(),
                    OptionalInt.empty(),
                    fileSplit.getStart(),
                    fileSplit.getLength(),
                    schema,
                    TupleDomain.all(),
                    columns,
                    partitionKeys,
                    DateTimeZone.UTC,
                    TYPE_MANAGER,
                    ImmutableMap.of())
                    .get();
        }

        public SourceOperator newTableScanOperator(DriverContext driverContext)
        {
            ConnectorPageSource pageSource = newPageSource();
            SourceOperatorFactory sourceOperatorFactory = new TableScanOperatorFactory(
                    0,
                    new PlanNodeId("0"),
                    (session, split, columnHandles) -> pageSource,
                    types,
                    columns.stream().map(columnHandle -> (ColumnHandle) columnHandle).collect(toList()));
            SourceOperator operator = sourceOperatorFactory.createOperator(driverContext);
            operator.addSplit(new Split(new ConnectorId("test"), TestingTransactionHandle.create(), TestingSplit.createLocalSplit()));
            return operator;
        }

        public SourceOperator newScanFilterAndProjectOperator(DriverContext driverContext)
        {
            ConnectorPageSource pageSource = newPageSource();
            ImmutableList.Builder<RowExpression> projectionsBuilder = ImmutableList.builder();
            for (int i = 0; i < types.size(); i++) {
                projectionsBuilder.add(field(i, types.get(i)));
            }
            Supplier<CursorProcessor> cursorProcessor = EXPRESSION_COMPILER.compileCursorProcessor(Optional.empty(), projectionsBuilder.build(), "key");
            Supplier<PageProcessor> pageProcessor = EXPRESSION_COMPILER.compilePageProcessor(Optional.empty(), projectionsBuilder.build());
            SourceOperatorFactory sourceOperatorFactory = new ScanFilterAndProjectOperatorFactory(
                    0,
                    new PlanNodeId("test"),
                    new PlanNodeId("0"),
                    (session, split, columnHandles) -> pageSource,
                    cursorProcessor,
                    pageProcessor,
                    columns.stream().map(columnHandle -> (ColumnHandle) columnHandle).collect(toList()),
                    types);
            SourceOperator operator = sourceOperatorFactory.createOperator(driverContext);
            operator.addSplit(new Split(new ConnectorId("test"), TestingTransactionHandle.create(), TestingSplit.createLocalSplit()));
            return operator;
        }

        private DriverContext newDriverContext()
        {
            return createTaskContext(executor, testSessionBuilder().build())
                    .addPipelineContext(0, true, true)
                    .addDriverContext();
        }
    }
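
    // Writes an ORC test file containing the non-partition columns, forcing a stripe flush
    // (via reflection, see flushStripe below) every STRIPE_ROWS rows so the file ends up
    // with multiple stripes.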
    public static FileSplit createTestFile(String filePath,
            HiveOutputFormat<?, ?> outputFormat,
            @SuppressWarnings("deprecation") SerDe serDe,
            String compressionCodec,
            List<TestColumn> testColumns,
            int numRows)
            throws Exception
    {
        // filter out partition keys, which are not written to the file
        testColumns = ImmutableList.copyOf(filter(testColumns, not(TestColumn::isPartitionKey)));

        Properties tableProperties = new Properties();
        tableProperties.setProperty("columns", Joiner.on(',').join(transform(testColumns, TestColumn::getName)));
        tableProperties.setProperty("columns.types", Joiner.on(',').join(transform(testColumns, TestColumn::getType)));
        serDe.initialize(CONFIGURATION, tableProperties);

        JobConf jobConf = new JobConf();
        if (compressionCodec != null) {
            CompressionCodec codec = new CompressionCodecFactory(CONFIGURATION).getCodecByName(compressionCodec);
            jobConf.set(COMPRESS_CODEC, codec.getClass().getName());
            jobConf.set(COMPRESS_TYPE, SequenceFile.CompressionType.BLOCK.toString());
        }

        RecordWriter recordWriter = createRecordWriter(new Path(filePath), CONFIGURATION);
        try {
            SettableStructObjectInspector objectInspector = getStandardStructObjectInspector(
                    ImmutableList.copyOf(transform(testColumns, TestColumn::getName)),
                    ImmutableList.copyOf(transform(testColumns, TestColumn::getObjectInspector)));

            Object row = objectInspector.create();

            List<StructField> fields = ImmutableList.copyOf(objectInspector.getAllStructFieldRefs());

            for (int rowNumber = 0; rowNumber < numRows; rowNumber++) {
                for (int i = 0; i < testColumns.size(); i++) {
                    Object writeValue = testColumns.get(i).getWriteValue();
                    if (writeValue instanceof Slice) {
                        writeValue = ((Slice) writeValue).getBytes();
                    }
                    objectInspector.setStructFieldData(row, fields.get(i), writeValue);
                }

                Writable record = serDe.serialize(row, objectInspector);
                recordWriter.write(record);
                if (rowNumber % STRIPE_ROWS == STRIPE_ROWS - 1) {
                    flushStripe(recordWriter);
                }
            }
        }
        finally {
            recordWriter.close(false);
        }

        Path path = new Path(filePath);
        path.getFileSystem(CONFIGURATION).setVerifyChecksum(true);
        File file = new File(filePath);
        return new FileSplit(path, 0, file.length(), new String[0]);
    }

    private static void flushStripe(RecordWriter recordWriter)
    {
        try {
            Field writerField = OrcOutputFormat.class.getClassLoader()
                    .loadClass(ORC_RECORD_WRITER)
                    .getDeclaredField("writer");
            writerField.setAccessible(true);
            Writer writer = (Writer) writerField.get(recordWriter);
            Method flushStripe = WriterImpl.class.getDeclaredMethod("flushStripe");
            flushStripe.setAccessible(true);
            flushStripe.invoke(writer);
        }
        catch (ReflectiveOperationException e) {
            throw Throwables.propagate(e);
        }
    }

    private static RecordWriter createRecordWriter(Path target, Configuration conf)
            throws IOException
    {
        try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(FileSystem.class.getClassLoader())) {
            WriterOptions options = new OrcWriterOptions(conf)
                    .memory(new NullMemoryManager(conf))
                    .compress(ZLIB);

            try {
                return WRITER_CONSTRUCTOR.newInstance(target, options);
            }
            catch (ReflectiveOperationException e) {
                throw new RuntimeException(e);
            }
        }
    }
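
    // Hive's OrcRecordWriter inner class is not directly accessible, so its constructor
    // is looked up by name and opened through reflection.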
    private static Constructor<? extends RecordWriter> getOrcWriterConstructor()
    {
        try {
            Constructor<? extends RecordWriter> constructor = OrcOutputFormat.class.getClassLoader()
                    .loadClass(ORC_RECORD_WRITER)
                    .asSubclass(RecordWriter.class)
                    .getDeclaredConstructor(Path.class, WriterOptions.class);
            constructor.setAccessible(true);
            return constructor;
        }
        catch (ReflectiveOperationException e) {
            throw Throwables.propagate(e);
        }
    }

    public static final class TestColumn
    {
        private final String name;
        private final ObjectInspector objectInspector;
        private final Supplier<?> writeValue;
        private final boolean partitionKey;

        public TestColumn(String name, ObjectInspector objectInspector, Supplier<?> writeValue, boolean partitionKey)
        {
            this.name = requireNonNull(name, "name is null");
            this.objectInspector = requireNonNull(objectInspector, "objectInspector is null");
            this.writeValue = writeValue;
            this.partitionKey = partitionKey;
        }

        public String getName()
        {
            return name;
        }

        public String getType()
        {
            return objectInspector.getTypeName();
        }

        public ObjectInspector getObjectInspector()
        {
            return objectInspector;
        }

        public Object getWriteValue()
        {
            return writeValue.get();
        }

        public boolean isPartitionKey()
        {
            return partitionKey;
        }

        @Override
        public String toString()
        {
            StringBuilder sb = new StringBuilder("TestColumn{");
            sb.append("name='").append(name).append('\'');
            sb.append(", objectInspector=").append(objectInspector);
            sb.append(", partitionKey=").append(partitionKey);
            sb.append('}');
            return sb.toString();
        }
    }
}