/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.facebook.presto.orc; import com.facebook.presto.orc.metadata.statistics.ColumnStatistics; import com.facebook.presto.spi.type.CharType; import com.facebook.presto.spi.type.DecimalType; import com.facebook.presto.spi.type.SqlDate; import com.facebook.presto.spi.type.SqlDecimal; import com.facebook.presto.spi.type.SqlTimestamp; import com.facebook.presto.spi.type.Type; import com.facebook.presto.spi.type.VarbinaryType; import com.facebook.presto.spi.type.VarcharType; import com.google.common.collect.Iterables; import com.google.common.collect.Ordering; import io.airlift.slice.Slice; import io.airlift.slice.Slices; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Objects; import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static com.facebook.presto.spi.type.DateType.DATE; import static com.facebook.presto.spi.type.DoubleType.DOUBLE; import static com.facebook.presto.spi.type.IntegerType.INTEGER; import static com.facebook.presto.spi.type.RealType.REAL; import static com.facebook.presto.spi.type.SmallintType.SMALLINT; import static com.facebook.presto.spi.type.StandardTypes.ARRAY; import static com.facebook.presto.spi.type.StandardTypes.MAP; import static com.facebook.presto.spi.type.StandardTypes.ROW; import static com.facebook.presto.spi.type.TimestampType.TIMESTAMP; import static com.facebook.presto.spi.type.TinyintType.TINYINT; import static com.google.common.base.Predicates.equalTo; import static com.google.common.base.Predicates.notNull; import static com.google.common.collect.Iterables.filter; import static com.google.common.collect.Lists.newArrayList; import static java.util.stream.Collectors.toList; import static org.testng.Assert.assertNull; import static org.testng.Assert.assertTrue; import static org.testng.Assert.fail; public final class TestingOrcPredicate { private static final int ORC_ROW_GROUP_SIZE = 10_000; private TestingOrcPredicate() { } public static OrcPredicate createOrcPredicate(Type type, Iterable<?> values, boolean noFileStats) { List<Object> expectedValues = newArrayList(values); if (BOOLEAN.equals(type)) { return new BooleanOrcPredicate(expectedValues, noFileStats); } if (TINYINT.equals(type) || SMALLINT.equals(type) || INTEGER.equals(type) || BIGINT.equals(type)) { return new LongOrcPredicate( expectedValues.stream() .map(value -> value == null ? null : ((Number) value).longValue()) .collect(toList()), noFileStats); } if (TIMESTAMP.equals(type)) { return new LongOrcPredicate( expectedValues.stream() .map(value -> value == null ? null : ((SqlTimestamp) value).getMillisUtc()) .collect(toList()), noFileStats); } if (DATE.equals(type)) { return new DateOrcPredicate( expectedValues.stream() .map(value -> value == null ? null : (long) ((SqlDate) value).getDays()) .collect(toList()), noFileStats); } if (REAL.equals(type) || DOUBLE.equals(type)) { return new DoubleOrcPredicate( expectedValues.stream() .map(value -> value == null ? null : ((Number) value).doubleValue()) .collect(toList()), noFileStats); } if (type instanceof VarbinaryType) { // binary does not have stats return new BasicOrcPredicate<>(expectedValues, Object.class, noFileStats); } if (type instanceof VarcharType) { return new StringOrcPredicate(expectedValues, noFileStats); } if (type instanceof CharType) { return new CharOrcPredicate(expectedValues, noFileStats); } if (type instanceof DecimalType) { return new DecimalOrcPredicate(expectedValues, noFileStats); } String baseType = type.getTypeSignature().getBase(); if (ARRAY.equals(baseType) || MAP.equals(baseType) || ROW.equals(baseType)) { return new BasicOrcPredicate<>(expectedValues, Object.class, noFileStats); } throw new IllegalArgumentException("Unsupported type " + type); } public static class BasicOrcPredicate<T> implements OrcPredicate { private final List<T> expectedValues; private final boolean noFileStats; public BasicOrcPredicate(Iterable<?> expectedValues, Class<T> type, boolean noFileStats) { List<T> values = new ArrayList<>(); for (Object expectedValue : expectedValues) { values.add(type.cast(expectedValue)); } this.expectedValues = Collections.unmodifiableList(values); this.noFileStats = noFileStats; } @Override public boolean matches(long numberOfRows, Map<Integer, ColumnStatistics> statisticsByColumnIndex) { ColumnStatistics columnStatistics = statisticsByColumnIndex.get(0); // todo enable file stats when DWRF team verifies that the stats are correct // assertTrue(columnStatistics.hasNumberOfValues()); if (noFileStats && numberOfRows == expectedValues.size()) { assertNull(columnStatistics); return true; } if (numberOfRows == expectedValues.size()) { // whole file assertChunkStats(expectedValues, columnStatistics); } else if (numberOfRows == ORC_ROW_GROUP_SIZE) { // middle section boolean foundMatch = false; int length; for (int offset = 0; offset < expectedValues.size(); offset += length) { length = Math.min(ORC_ROW_GROUP_SIZE, expectedValues.size() - offset); if (chunkMatchesStats(expectedValues.subList(offset, offset + length), columnStatistics)) { foundMatch = true; break; } } assertTrue(foundMatch); } else if (numberOfRows == expectedValues.size() % ORC_ROW_GROUP_SIZE) { // tail section List<T> chunk = expectedValues.subList((int) (expectedValues.size() - numberOfRows), expectedValues.size()); assertChunkStats(chunk, columnStatistics); } else { fail("Unexpected number of rows: " + numberOfRows); } return true; } private void assertChunkStats(List<T> chunk, ColumnStatistics columnStatistics) { assertTrue(chunkMatchesStats(chunk, columnStatistics)); } protected boolean chunkMatchesStats(List<T> chunk, ColumnStatistics columnStatistics) { // verify non null count if (columnStatistics.getNumberOfValues() != Iterables.size(filter(chunk, notNull()))) { return false; } return true; } } public static class BooleanOrcPredicate extends BasicOrcPredicate<Boolean> { public BooleanOrcPredicate(Iterable<?> expectedValues, boolean noFileStats) { super(expectedValues, Boolean.class, noFileStats); } @Override protected boolean chunkMatchesStats(List<Boolean> chunk, ColumnStatistics columnStatistics) { assertNull(columnStatistics.getIntegerStatistics()); assertNull(columnStatistics.getDoubleStatistics()); assertNull(columnStatistics.getStringStatistics()); assertNull(columnStatistics.getDateStatistics()); // check basic statistics if (!super.chunkMatchesStats(chunk, columnStatistics)) { return false; } // statistics can be missing for any reason if (columnStatistics.getBooleanStatistics() != null) { if (columnStatistics.getBooleanStatistics().getTrueValueCount() != Iterables.size(filter(chunk, equalTo(Boolean.TRUE)))) { return false; } } return true; } } public static class DoubleOrcPredicate extends BasicOrcPredicate<Double> { public DoubleOrcPredicate(Iterable<?> expectedValues, boolean noFileStats) { super(expectedValues, Double.class, noFileStats); } @Override protected boolean chunkMatchesStats(List<Double> chunk, ColumnStatistics columnStatistics) { assertNull(columnStatistics.getBooleanStatistics()); assertNull(columnStatistics.getIntegerStatistics()); assertNull(columnStatistics.getStringStatistics()); assertNull(columnStatistics.getDateStatistics()); // check basic statistics if (!super.chunkMatchesStats(chunk, columnStatistics)) { return false; } // statistics can be missing for any reason if (columnStatistics.getDoubleStatistics() != null) { // verify min if (Math.abs(columnStatistics.getDoubleStatistics().getMin() - Ordering.natural().nullsLast().min(chunk)) > 0.001) { return false; } // verify max if (Math.abs(columnStatistics.getDoubleStatistics().getMax() - Ordering.natural().nullsFirst().max(chunk)) > 0.001) { return false; } } return true; } } private static class DecimalOrcPredicate extends BasicOrcPredicate<SqlDecimal> { public DecimalOrcPredicate(Iterable<?> expectedValues, boolean noFileStats) { super(expectedValues, SqlDecimal.class, noFileStats); } } public static class LongOrcPredicate extends BasicOrcPredicate<Long> { public LongOrcPredicate(Iterable<?> expectedValues, boolean noFileStats) { super(expectedValues, Long.class, noFileStats); } @Override protected boolean chunkMatchesStats(List<Long> chunk, ColumnStatistics columnStatistics) { assertNull(columnStatistics.getBooleanStatistics()); assertNull(columnStatistics.getDoubleStatistics()); assertNull(columnStatistics.getStringStatistics()); assertNull(columnStatistics.getDateStatistics()); // check basic statistics if (!super.chunkMatchesStats(chunk, columnStatistics)) { return false; } // statistics can be missing for any reason if (columnStatistics.getIntegerStatistics() != null) { // verify min if (!columnStatistics.getIntegerStatistics().getMin().equals(Ordering.natural().nullsLast().min(chunk))) { return false; } // verify max if (!columnStatistics.getIntegerStatistics().getMax().equals(Ordering.natural().nullsFirst().max(chunk))) { return false; } } return true; } } public static class StringOrcPredicate extends BasicOrcPredicate<String> { public StringOrcPredicate(Iterable<?> expectedValues, boolean noFileStats) { super(expectedValues, String.class, noFileStats); } @Override protected boolean chunkMatchesStats(List<String> chunk, ColumnStatistics columnStatistics) { assertNull(columnStatistics.getBooleanStatistics()); assertNull(columnStatistics.getIntegerStatistics()); assertNull(columnStatistics.getDoubleStatistics()); assertNull(columnStatistics.getDateStatistics()); // check basic statistics if (!super.chunkMatchesStats(chunk, columnStatistics)) { return false; } List<Slice> slices = chunk.stream() .filter(Objects::nonNull) .map(Slices::utf8Slice) .collect(toList()); // statistics can be missing for any reason if (columnStatistics.getStringStatistics() != null) { // verify min Slice chunkMin = Ordering.natural().nullsLast().min(slices); if (columnStatistics.getStringStatistics().getMin().compareTo(chunkMin) > 0) { return false; } // verify max Slice chunkMax = Ordering.natural().nullsFirst().max(slices); if (columnStatistics.getStringStatistics().getMax().compareTo(chunkMax) < 0) { return false; } } return true; } } public static class CharOrcPredicate extends BasicOrcPredicate<String> { public CharOrcPredicate(Iterable<?> expectedValues, boolean noFileStats) { super(expectedValues, String.class, noFileStats); } @Override protected boolean chunkMatchesStats(List<String> chunk, ColumnStatistics columnStatistics) { assertNull(columnStatistics.getBooleanStatistics()); assertNull(columnStatistics.getIntegerStatistics()); assertNull(columnStatistics.getDoubleStatistics()); assertNull(columnStatistics.getDateStatistics()); // check basic statistics if (!super.chunkMatchesStats(chunk, columnStatistics)) { return false; } List<String> strings = chunk.stream() .filter(Objects::nonNull) .map(String::trim) .collect(toList()); // statistics can be missing for any reason if (columnStatistics.getStringStatistics() != null) { // verify min String chunkMin = Ordering.natural().nullsLast().min(strings); if (columnStatistics.getStringStatistics().getMin().toStringUtf8().trim().compareTo(chunkMin) > 0) { return false; } // verify max String chunkMax = Ordering.natural().nullsFirst().max(strings); if (columnStatistics.getStringStatistics().getMax().toStringUtf8().trim().compareTo(chunkMax) < 0) { return false; } } return true; } } public static class DateOrcPredicate extends BasicOrcPredicate<Long> { public DateOrcPredicate(Iterable<?> expectedValues, boolean noFileStats) { super(expectedValues, Long.class, noFileStats); } @Override protected boolean chunkMatchesStats(List<Long> chunk, ColumnStatistics columnStatistics) { assertNull(columnStatistics.getBooleanStatistics()); assertNull(columnStatistics.getIntegerStatistics()); assertNull(columnStatistics.getDoubleStatistics()); assertNull(columnStatistics.getStringStatistics()); // check basic statistics if (!super.chunkMatchesStats(chunk, columnStatistics)) { return false; } // statistics can be missing for any reason if (columnStatistics.getDateStatistics() != null) { // verify min Long min = columnStatistics.getDateStatistics().getMin().longValue(); if (!min.equals(Ordering.natural().nullsLast().min(chunk))) { return false; } // verify max Long statMax = columnStatistics.getDateStatistics().getMax().longValue(); Long chunkMax = Ordering.natural().nullsFirst().max(chunk); if (!statMax.equals(chunkMax)) { return false; } } return true; } } }