/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.orc.stream;
import com.facebook.presto.orc.OrcCorruptionException;
import com.facebook.presto.orc.OrcDataSourceId;
import com.facebook.presto.orc.memory.AggregatedMemoryContext;
import io.airlift.slice.BasicSliceInput;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import org.testng.annotations.Test;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.math.BigInteger;
import java.util.Optional;
import static com.facebook.presto.spi.type.Decimals.MAX_DECIMAL_UNSCALED_VALUE;
import static com.facebook.presto.spi.type.Decimals.MIN_DECIMAL_UNSCALED_VALUE;
import static com.facebook.presto.spi.type.UnscaledDecimal128Arithmetic.unscaledDecimal;
import static com.facebook.presto.spi.type.UnscaledDecimal128Arithmetic.unscaledDecimalToBigInteger;
import static java.math.BigInteger.ONE;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertThrows;
public class TestDecimalStream
{
private static final BigInteger BIG_INTEGER_127_BIT_SET;
static {
BigInteger b = BigInteger.ZERO;
for (int i = 0; i < 127; ++i) {
b = b.setBit(i);
}
BIG_INTEGER_127_BIT_SET = b;
}
@Test
public void testShortDecimals()
throws IOException
{
assertReadsShortValue(0L);
assertReadsShortValue(1L);
assertReadsShortValue(-1L);
assertReadsShortValue(256L);
assertReadsShortValue(-256L);
assertReadsShortValue(Long.MAX_VALUE);
assertReadsShortValue(Long.MIN_VALUE);
}
@Test
public void testShouldFailWhenShortDecimalDoesNotFit()
throws IOException
{
assertShortValueReadFails(BigInteger.valueOf(Long.MAX_VALUE).add(ONE));
}
@Test
public void testShouldFailWhenExceeds128Bits()
throws IOException
{
assertLongValueReadFails(BigInteger.valueOf(1).shiftLeft(127));
assertLongValueReadFails(BigInteger.valueOf(-2).shiftLeft(127));
}
@Test
public void testLongDecimals()
throws IOException
{
assertReadsLongValue(BigInteger.valueOf(0L));
assertReadsLongValue(BigInteger.valueOf(1L));
assertReadsLongValue(BigInteger.valueOf(-1L));
assertReadsLongValue(BigInteger.valueOf(-1).shiftLeft(126));
assertReadsLongValue(BigInteger.valueOf(1).shiftLeft(126));
assertReadsLongValue(BIG_INTEGER_127_BIT_SET);
assertReadsLongValue(BIG_INTEGER_127_BIT_SET.negate());
assertReadsLongValue(MAX_DECIMAL_UNSCALED_VALUE);
assertReadsLongValue(MIN_DECIMAL_UNSCALED_VALUE);
}
@Test
public void testSkipsValue()
throws IOException
{
ByteArrayOutputStream baos = new ByteArrayOutputStream();
writeBigInteger(baos, BigInteger.valueOf(Long.MAX_VALUE));
writeBigInteger(baos, BigInteger.valueOf(Long.MIN_VALUE));
OrcInputStream inputStream = orcInputStreamFor("skip test", baos.toByteArray());
DecimalInputStream stream = new DecimalInputStream(inputStream);
stream.skip(1);
assertEquals(stream.nextLong(), Long.MIN_VALUE);
}
private static void assertReadsShortValue(long value)
throws IOException
{
DecimalInputStream stream = new DecimalInputStream(decimalInputStream(BigInteger.valueOf(value)));
assertEquals(stream.nextLong(), value);
}
private static void assertReadsLongValue(BigInteger value)
throws IOException
{
DecimalInputStream stream = new DecimalInputStream(decimalInputStream(value));
Slice decimal = unscaledDecimal();
stream.nextLongDecimal(decimal);
assertEquals(unscaledDecimalToBigInteger(decimal), value);
}
private static void assertShortValueReadFails(BigInteger value)
throws IOException
{
assertThrows(OrcCorruptionException.class, () -> {
DecimalInputStream stream = new DecimalInputStream(decimalInputStream(value));
stream.nextLong();
});
}
private static void assertLongValueReadFails(BigInteger value)
throws IOException
{
Slice decimal = unscaledDecimal();
assertThrows(OrcCorruptionException.class, () -> {
DecimalInputStream stream = new DecimalInputStream(decimalInputStream(value));
stream.nextLongDecimal(decimal);
});
}
private static OrcInputStream decimalInputStream(BigInteger value)
throws IOException
{
ByteArrayOutputStream baos = new ByteArrayOutputStream();
writeBigInteger(baos, value);
return orcInputStreamFor(value.toString(), baos.toByteArray());
}
private static OrcInputStream orcInputStreamFor(String source, byte[] bytes)
{
return new OrcInputStream(new OrcDataSourceId(source), new BasicSliceInput(Slices.wrappedBuffer(bytes)), Optional.empty(), new AggregatedMemoryContext());
}
// copied from org.apache.hadoop.hive.ql.io.orc.SerializationUtils.java
private static void writeBigInteger(OutputStream output, BigInteger value)
throws IOException
{
// encode the signed number as a positive integer
value = value.shiftLeft(1);
int sign = value.signum();
if (sign < 0) {
value = value.negate();
value = value.subtract(ONE);
}
int length = value.bitLength();
while (true) {
long lowBits = value.longValue() & 0x7fffffffffffffffL;
length -= 63;
// write out the next 63 bits worth of data
for (int i = 0; i < 9; ++i) {
// if this is the last byte, leave the high bit off
if (length <= 0 && (lowBits & ~0x7f) == 0) {
output.write((byte) lowBits);
return;
}
else {
output.write((byte) (0x80 | (lowBits & 0x7f)));
lowBits >>>= 7;
}
}
value = value.shiftRight(63);
}
}
}