/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.hadoop.hive.ql.exec.spark;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.LinkedBlockingQueue;

import org.apache.hadoop.hive.ql.io.HiveKey;
import org.apache.hadoop.io.BytesWritable;
import org.junit.Test;

import scala.Tuple2;

import com.clearspring.analytics.util.Preconditions;

@SuppressWarnings({"unchecked", "rawtypes"})
public class TestHiveKVResultCache {

  @Test
  public void testSimple() throws Exception {
    // Create a KV result cache object, add one (k,v) pair and retrieve it.
    HiveKVResultCache cache = new HiveKVResultCache();

    HiveKey key = new HiveKey("key".getBytes(), "key".hashCode());
    BytesWritable value = new BytesWritable("value".getBytes());
    cache.add(key, value);

    assertTrue("KV result cache should have at least one element", cache.hasNext());

    Tuple2<HiveKey, BytesWritable> row = cache.next();
    assertTrue("Incorrect key", row._1().equals(key));
    assertTrue("Incorrect value", row._2().equals(value));

    assertTrue("Cache shouldn't have more records", !cache.hasNext());
  }

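  // The spilling test below pushes three times HiveKVResultCache.IN_MEMORY_NUM_ROWS
  // rows through the cache per pass, so rows past the in-memory limit are expected
  // to spill to the cache's backing storage and still come back in insertion order.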
  @Test
  public void testSpilling() throws Exception {
    HiveKVResultCache cache = new HiveKVResultCache();
    final int recordCount = HiveKVResultCache.IN_MEMORY_NUM_ROWS * 3;

    // Test reusing the same cache: insert the first n rows, verify them, and clear
    // the cache; then insert another m rows and verify the cache still stores them
    // correctly. This simulates reusing the same cache over and over again.
    testSpillingHelper(cache, recordCount);
    testSpillingHelper(cache, 1);
    testSpillingHelper(cache, recordCount);
  }

  /** Inserts numRecords rows into the cache, then retrieves and verifies them. */
  private void testSpillingHelper(HiveKVResultCache cache, int numRecords) {
    for (int i = 0; i < numRecords; i++) {
      String key = "key_" + i;
      String value = "value_" + i;
      cache.add(new HiveKey(key.getBytes(), key.hashCode()),
          new BytesWritable(value.getBytes()));
    }

    int recordsSeen = 0;
    while (cache.hasNext()) {
      String key = "key_" + recordsSeen;
      String value = "value_" + recordsSeen;
      Tuple2<HiveKey, BytesWritable> row = cache.next();
      assertTrue("Unexpected key at position: " + recordsSeen,
          new String(row._1().getBytes()).equals(key));
      assertTrue("Unexpected value at position: " + recordsSeen,
          new String(row._2().getBytes()).equals(value));
      recordsSeen++;
    }

    assertTrue("Retrieved record count doesn't match inserted record count",
        numRecords == recordsSeen);

    cache.clear();
  }

  @Test
  public void testResultList() throws Exception {
    scanAndVerify(10000, 0, 0, "a", "b");
    scanAndVerify(10000, 511, 0, "a", "b");
    scanAndVerify(10000, 511 * 2, 0, "a", "b");
    scanAndVerify(10000, 511, 10, "a", "b");
    scanAndVerify(10000, 511 * 2, 10, "a", "b");
    scanAndVerify(10000, 512, 0, "a", "b");
    scanAndVerify(10000, 512 * 2, 0, "a", "b");
    scanAndVerify(10000, 512, 3, "a", "b");
    scanAndVerify(10000, 512 * 6, 10, "a", "b");
    scanAndVerify(10000, 512 * 7, 5, "a", "b");
    scanAndVerify(10000, 512 * 9, 19, "a", "b");
    scanAndVerify(10000, 1, 0, "a", "b");
    scanAndVerify(10000, 1, 1, "a", "b");
  }

  private static void scanAndVerify(
      long rows, int threshold, int separate, String prefix1, String prefix2) {
    ArrayList<Tuple2<HiveKey, BytesWritable>> output =
        new ArrayList<Tuple2<HiveKey, BytesWritable>>((int) rows);
    scanResultList(rows, threshold, separate, output, prefix1, prefix2);
    assertEquals(rows, output.size());

    long primaryRows = rows * (100 - separate) / 100;
    long separateRows = rows - primaryRows;
    HashSet<Long> primaryRowKeys = new HashSet<Long>();
    HashSet<Long> separateRowKeys = new HashSet<Long>();
    for (Tuple2<HiveKey, BytesWritable> item : output) {
      String key = bytesWritableToString(item._1);
      String value = bytesWritableToString(item._2);
      String prefix = key.substring(0, key.indexOf('_'));
      Long id = Long.valueOf(key.substring(5 + prefix.length()));
      if (prefix.equals(prefix1)) {
        assertTrue(id >= 0 && id < primaryRows);
        primaryRowKeys.add(id);
      } else {
        assertEquals(prefix2, prefix);
        assertTrue(id >= 0 && id < separateRows);
        separateRowKeys.add(id);
      }
      assertEquals(prefix + "_value_" + id, value);
    }
    assertEquals(separateRows, separateRowKeys.size());
    assertEquals(primaryRows, primaryRowKeys.size());
  }

  /**
   * Convert a BytesWritable to a string.
   * Don't use {@link BytesWritable#copyBytes()}
   * so as to be compatible with hadoop 1.
   */
  private static String bytesWritableToString(BytesWritable bw) {
    int size = bw.getLength();
    byte[] bytes = new byte[size];
    System.arraycopy(bw.getBytes(), 0, bytes, 0, size);
    return new String(bytes);
  }

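  /**
   * A HiveBaseFunctionResultList subclass for testing: each processNextRecord()
   * call emits a batch of thresholdRows primary rows and, when separate rows are
   * requested, signals a background thread (via the queue) to interleave a
   * proportional batch of rows under a different key prefix, so the tests can
   * check that rows collected from two threads are neither lost nor corrupted.
   */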
  private static class MyHiveFunctionResultList extends HiveBaseFunctionResultList {
    private static final long serialVersionUID = -1L;

    // Total rows to emit during the whole iteration,
    // excluding the rows emitted by the separate thread.
    private long primaryRows;
    // Batch of rows to emit per processNextRecord() call.
    private int thresholdRows;
    // Rows to be emitted with a separate thread per processNextRecord() call.
    private long separateRows;
    // Thread to generate the separate rows beside the normal thread.
    private Thread separateRowGenerator;

    // Counter for rows emitted
    private long rowsEmitted;
    private long separateRowsEmitted;

    // Prefix for primary row keys
    private String prefix1;
    // Prefix for separate row keys
    private String prefix2;

    // A queue to notify separateRowGenerator to generate the next batch of rows.
    private LinkedBlockingQueue<Boolean> queue;

    MyHiveFunctionResultList(Iterator inputIterator) {
      super(inputIterator);
    }

    void init(long rows, int threshold, int separate, String p1, String p2) {
      Preconditions.checkArgument((threshold > 0 || separate == 0)
          && separate < 100 && separate >= 0 && rows > 0);
      primaryRows = rows * (100 - separate) / 100;
      separateRows = rows - primaryRows;
      thresholdRows = threshold;
      prefix1 = p1;
      prefix2 = p2;
      if (separateRows > 0) {
        separateRowGenerator = new Thread(new Runnable() {
          @Override
          public void run() {
            try {
              long separateBatchSize = thresholdRows * separateRows / primaryRows;
              while (!queue.take().booleanValue()) {
                for (int i = 0; i < separateBatchSize; i++) {
                  collect(prefix2, separateRowsEmitted++);
                }
              }
            } catch (InterruptedException e) {
              e.printStackTrace();
            }
            while (separateRowsEmitted < separateRows) {
              collect(prefix2, separateRowsEmitted++);
            }
          }
        });
        queue = new LinkedBlockingQueue<Boolean>();
        separateRowGenerator.start();
      }
    }

    public void collect(String prefix, long id) {
      String k = prefix + "_key_" + id;
      String v = prefix + "_value_" + id;
      HiveKey key = new HiveKey(k.getBytes(), k.hashCode());
      BytesWritable value = new BytesWritable(v.getBytes());
      try {
        collect(key, value);
      } catch (IOException e) {
        e.printStackTrace();
      }
    }

    @Override
    protected void processNextRecord(Object inputRecord) throws IOException {
      for (int i = 0; i < thresholdRows; i++) {
        collect(prefix1, rowsEmitted++);
      }
      if (separateRowGenerator != null) {
        queue.add(Boolean.FALSE);
      }
    }

    @Override
    protected boolean processingDone() {
      return false;
    }

    @Override
    protected void closeRecordProcessor() {
      while (rowsEmitted < primaryRows) {
        collect(prefix1, rowsEmitted++);
      }
      if (separateRowGenerator != null) {
        queue.add(Boolean.TRUE);
        try {
          separateRowGenerator.join();
        } catch (InterruptedException e) {
          e.printStackTrace();
        }
      }
    }
  }

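  /**
   * Drives a MyHiveFunctionResultList over a synthetic input iterator, copies
   * every emitted row into output (when output is non-null), and returns the
   * elapsed iteration time in milliseconds.
   */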
  private static long scanResultList(long rows, int threshold, int separate,
      List<Tuple2<HiveKey, BytesWritable>> output, String prefix1, String prefix2) {
    final long iteratorCount =
        threshold == 0 ? 1 : rows * (100 - separate) / 100 / threshold;
    MyHiveFunctionResultList resultList = new MyHiveFunctionResultList(new Iterator() {
      // Input record iterator, not used
      private int i = 0;

      @Override
      public boolean hasNext() {
        return i++ < iteratorCount;
      }

      @Override
      public Object next() {
        return Integer.valueOf(i);
      }

      @Override
      public void remove() {
      }
    });
    resultList.init(rows, threshold, separate, prefix1, prefix2);

    long startTime = System.currentTimeMillis();
    while (resultList.hasNext()) {
      Object item = resultList.next();
      if (output != null) {
        output.add((Tuple2<HiveKey, BytesWritable>) item);
      }
    }
    long endTime = System.currentTimeMillis();
    return endTime - startTime;
  }

  private static long[] scanResultList(long rows, int threshold, int extra) {
    // 1. Simulate emitting all records in closeRecordProcessor().
    long t1 = scanResultList(rows, 0, 0, null, "a", "b");

    // 2. Simulate emitting records in processNextRecord() with a small memory usage limit.
    long t2 = scanResultList(rows, threshold, 0, null, "c", "d");

    // 3. Simulate emitting records in processNextRecord() with a large memory usage limit.
    long t3 = scanResultList(rows, threshold * 10, 0, null, "e", "f");

    // 4. Same as 2. Also emit extra records from a separate thread.
    long t4 = scanResultList(rows, threshold, extra, null, "g", "h");

    // 5. Same as 3. Also emit extra records from a separate thread.
    long t5 = scanResultList(rows, threshold * 10, extra, null, "i", "j");

    return new long[] {t1, t2, t3, t4, t5};
  }

  public static void main(String[] args) throws Exception {
    long rows = 1000000;  // total rows to generate
    int threshold = 512;  // maximum number of rows to cache
    int extra = 5;        // percentage of rows to generate from a separate thread
    if (args.length > 0) {
      rows = Long.parseLong(args[0]);
    }
    if (args.length > 1) {
      threshold = Integer.parseInt(args[1]);
    }
    if (args.length > 2) {
      extra = Integer.parseInt(args[2]);
    }

    // Warm up a couple of times
    for (int i = 0; i < 2; i++) {
      scanResultList(rows, threshold, extra);
    }

    // Five timed runs; note scanResultList() also returns exactly five timings,
    // so "count" doubles as the length of the result array.
    int count = 5;
    long[] t = new long[count];
    // Run count times and take the average
    for (int i = 0; i < count; i++) {
      long[] tmp = scanResultList(rows, threshold, extra);
      for (int k = 0; k < count; k++) {
        t[k] += tmp[k];
      }
    }
    for (int i = 0; i < count; i++) {
      t[i] /= count;
    }
    System.out.println(t[0] + "\t" + t[1] + "\t" + t[2]
        + "\t" + t[3] + "\t" + t[4]);
  }
}