/* (c) 2014 LinkedIn Corp. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use
* this file except in compliance with the License. You may obtain a copy of the
* License at http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed
* under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
* CONDITIONS OF ANY KIND, either express or implied.
*/
package com.linkedin.cubert.io;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import com.linkedin.cubert.utils.print;
/**
 * Encodes and decodes {@code int}, {@code long}, {@code float} and {@code double}
 * values in a compact, variable-length, big-endian byte format.
 * <p>
 * <b>Integers and longs.</b> The magnitude of the number is stored together with a
 * separate sign bit, which is the most significant bit of the first byte. The run
 * of 1-bits that follows the sign bit gives the number of additional bytes to read,
 * and the remaining low bits of the first byte hold the most significant bits of
 * the magnitude. For longs, when the first byte (sign bit forced to 1) is
 * {@code 0xFF}, the length continues in the leading 1-bits of the second byte. The
 * single byte {@code 0x80} ("negative zero") represents {@code null}.
 * <p>
 * <b>Floats and doubles.</b> The first byte selects the representation:
 * {@code 0x00} is followed by the raw IEEE-754 bits (4 or 8 bytes), {@code 0x01}
 * is +0.0, {@code 0x02} is -0.0 and {@code 0x04} is {@code null}. Any other first
 * byte carries a value without fractional part, encoded like an integer/long but
 * with two header bits: {@code 0x80} is always set and {@code 0x40} is the sign.
 * <p>
 * <b>Thread-safety and object reuse.</b> All methods share one static scratch
 * buffer and one static Writable instance per type. The class must therefore not
 * be used from several threads at once, and the object returned by a
 * {@code decode*} method is overwritten by the next decode call of the same type.
 *
 * @author Maneesh Varshney
 *
 */
public class VariableLengthEncoder
{
    /** First byte of an encoded float/double whose value is +0.0. */
    private static final byte CODE_POS_ZERO = 0x1;
    /** First byte of an encoded float/double whose value is -0.0. */
    private static final byte CODE_NEG_ZERO = 0x2;
    /** First byte of an encoded float/double that is null. */
    private static final byte CODE_NULL = 0x4;

    /** Single-byte encoding of a null integer/long: sign bit set, magnitude zero. */
    private static final int NULL_INTEGRAL = 0x80;
    /** Sign bit of the first byte when one header bit is used (integer/long). */
    private static final int SIGN_BIT = 0x80;
    /** Marker bit of the first byte when two header bits are used (float/double). */
    private static final int WHOLE_NUMBER_MARKER = 0x80;
    /** Sign bit of the first byte when two header bits are used (float/double). */
    private static final int WHOLE_NUMBER_SIGN_BIT = 0x40;

    // intLeftBitsMasks[n] has the n most significant bits of a byte set:
    // 0x0 = 00000000
    // 0x80 = 10000000
    // 0xC0 = 11000000
    // 0xE0 = 11100000
    // 0xF0 = 11110000
    // 0xF8 = 11111000
    // 0xFC = 11111100
    // 0xFE = 11111110
    // 0xFF = 11111111
    private static final int[] intLeftBitsMasks = new int[] { 0x0, 0x80, 0xC0, 0xE0,
            0xF0, 0xF8, 0xFC, 0xFE, 0xFF };

    // Shared scratch space: integers use indexes 0..4, longs use indexes 0..9.
    private static final byte[] buffer = new byte[10];

    // Shared result holders returned by the decode methods.
    private static final IntWritable intWritable = new IntWritable();
    private static final LongWritable longWritable = new LongWritable();
    private static final FloatWritable floatWritable = new FloatWritable();
    private static final DoubleWritable doubleWritable = new DoubleWritable();

    /*** Public methods ***/

    /**
     * Writes the encoding of a null integer.
     *
     * @param out stream that receives the single byte {@code 0x80}
     * @throws IOException if the stream cannot be written
     */
    public static final void encodeNullInteger(OutputStream out) throws IOException
    {
        // null is encoded as -0 (sign bit = 1, value = 0).
        out.write(NULL_INTEGRAL);
    }

    /**
     * Writes the variable-length encoding (1 to 5 bytes) of an integer.
     *
     * @param num value to encode
     * @param out stream that receives the encoded bytes
     * @throws IOException if the stream cannot be written
     */
    public static final void encodeInteger(int num, OutputStream out) throws IOException
    {
        encodeInteger(num, out, 1);
    }

    /**
     * Reads one integer written by {@link #encodeInteger(int, OutputStream)} or
     * {@link #encodeNullInteger(OutputStream)}.
     *
     * @param in stream positioned at the first byte of the encoded value
     * @return the shared {@link IntWritable} holding the value, or {@code null} if
     *         a null was encoded
     * @throws IOException if the stream cannot be read, or ends before the value
     *             is complete
     */
    public static final IntWritable decodeInteger(InputStream in) throws IOException
    {
        int firstByte = readByte(in);

        // special case of null: null is encoded as -0 (0x80)
        if (firstByte == NULL_INTEGRAL)
            return null;

        boolean isNegative = ((firstByte & SIGN_BIT) != 0);
        int num = extractInteger(firstByte, in, 1);
        intWritable.set(isNegative ? -num : num);
        return intWritable;
    }

    /**
     * Writes the encoding of a null long.
     *
     * @param out stream that receives the single byte {@code 0x80}
     * @throws IOException if the stream cannot be written
     */
    public static final void encodeNullLong(OutputStream out) throws IOException
    {
        // null is encoded as -0 (sign bit = 1, value = 0).
        out.write(NULL_INTEGRAL);
    }

    /**
     * Writes the variable-length encoding (1 to 10 bytes) of a long.
     *
     * @param num value to encode
     * @param out stream that receives the encoded bytes
     * @throws IOException if the stream cannot be written
     */
    public static final void encodeLong(long num, OutputStream out) throws IOException
    {
        encodeLong(num, out, 1);
    }

    /**
     * Reads one long written by {@link #encodeLong(long, OutputStream)} or
     * {@link #encodeNullLong(OutputStream)}.
     *
     * @param in stream positioned at the first byte of the encoded value
     * @return the shared {@link LongWritable} holding the value, or {@code null} if
     *         a null was encoded
     * @throws IOException if the stream cannot be read, or ends before the value
     *             is complete
     */
    public static final LongWritable decodeLong(InputStream in) throws IOException
    {
        int firstByte = readByte(in);

        // special case of null: null is encoded as -0 (0x80)
        if (firstByte == NULL_INTEGRAL)
            return null;

        boolean isNegative = ((firstByte & SIGN_BIT) != 0);
        long num = extractLong(firstByte, in, 1);
        longWritable.set(isNegative ? -num : num);
        return longWritable;
    }

    /**
     * Writes the encoding of a null float.
     *
     * @param out stream that receives the single null code byte
     * @throws IOException if the stream cannot be written
     */
    public static final void encodeNullFloat(OutputStream out) throws IOException
    {
        out.write(CODE_NULL);
    }

    /**
     * Writes the encoding of a float: a one-byte code for +0.0 / -0.0, the integer
     * form (two header bits) when the value has no fractional part, or a zero byte
     * followed by the four raw IEEE-754 bytes otherwise.
     *
     * @param num value to encode
     * @param out stream that receives the encoded bytes
     * @throws IOException if the stream cannot be written
     */
    public static final void encodeFloat(float num, OutputStream out) throws IOException
    {
        // IEEE specifies floats can be 0.0 or -0.0. Handle this first
        if (num == 0.0)
        {
            out.write(Float.floatToIntBits(num) == 0 ? CODE_POS_ZERO : CODE_NEG_ZERO);
            return;
        }

        // The difference is non-zero when the value has a fractional part; it is
        // NaN or infinite (and therefore also != 0) for NaN and the infinities.
        if ((num - (int) num) != 0)
        {
            fillIntBuffer(Float.floatToIntBits(num));
            // buffer[0] is the 0 marker byte, buffer[1..4] are the raw bits
            out.write(buffer, 0, 5);
        }
        else
        {
            encodeInteger((int) num, out, 2);
        }
    }

    /**
     * Reads one float written by {@link #encodeFloat(float, OutputStream)} or
     * {@link #encodeNullFloat(OutputStream)}.
     *
     * @param in stream positioned at the first byte of the encoded value
     * @return the shared {@link FloatWritable} holding the value, or {@code null}
     *         if a null was encoded
     * @throws IOException if the stream cannot be read, or ends before the value
     *             is complete
     */
    public static final FloatWritable decodeFloat(InputStream in) throws IOException
    {
        int firstByte = readByte(in);
        switch (firstByte)
        {
        case CODE_POS_ZERO:
            floatWritable.set(0.0f);
            break;
        case CODE_NEG_ZERO:
            floatWritable.set(-0.0f);
            break;
        case CODE_NULL:
            return null;
        case 0:
            // raw IEEE-754 bits follow
            floatWritable.set(Float.intBitsToFloat(readInt(0, in, 4)));
            break;
        default:
            // value without fractional part, stored as an integer with two header bits
            boolean isNegative = ((firstByte & WHOLE_NUMBER_SIGN_BIT) != 0);
            int magnitude = extractInteger(firstByte, in, 2);
            floatWritable.set((float) (isNegative ? -magnitude : magnitude));
        }
        return floatWritable;
    }

    /**
     * Writes the encoding of a null double.
     *
     * @param out stream that receives the single null code byte
     * @throws IOException if the stream cannot be written
     */
    public static final void encodeNullDouble(OutputStream out) throws IOException
    {
        out.write(CODE_NULL);
    }

    /**
     * Writes the encoding of a double: a one-byte code for +0.0 / -0.0, the long
     * form (two header bits) when the value has no fractional part, or a zero byte
     * followed by the eight raw IEEE-754 bytes otherwise.
     *
     * @param num value to encode
     * @param out stream that receives the encoded bytes
     * @throws IOException if the stream cannot be written
     */
    public static final void encodeDouble(double num, OutputStream out) throws IOException
    {
        // IEEE specifies doubles can be 0.0 or -0.0. Handle this first
        if (num == 0.0)
        {
            out.write(Double.doubleToLongBits(num) == 0 ? CODE_POS_ZERO : CODE_NEG_ZERO);
            return;
        }

        // The difference is non-zero when the value has a fractional part; it is
        // NaN or infinite (and therefore also != 0) for NaN and the infinities.
        if ((num - (long) num) != 0)
        {
            fillLongBuffer(Double.doubleToLongBits(num));
            // buffer[1] is the 0 marker byte, buffer[2..9] are the raw bits
            out.write(buffer, 1, 9);
        }
        else
        {
            encodeLong((long) num, out, 2);
        }
    }

    /**
     * Reads one double written by {@link #encodeDouble(double, OutputStream)} or
     * {@link #encodeNullDouble(OutputStream)}.
     *
     * @param in stream positioned at the first byte of the encoded value
     * @return the shared {@link DoubleWritable} holding the value, or {@code null}
     *         if a null was encoded
     * @throws IOException if the stream cannot be read, or ends before the value
     *             is complete
     */
    public static final DoubleWritable decodeDouble(InputStream in) throws IOException
    {
        int firstByte = readByte(in);
        switch (firstByte)
        {
        case CODE_POS_ZERO:
            doubleWritable.set(0.0d);
            break;
        case CODE_NEG_ZERO:
            doubleWritable.set(-0.0d);
            break;
        case CODE_NULL:
            return null;
        case 0:
            // raw IEEE-754 bits follow
            doubleWritable.set(Double.longBitsToDouble(readLong(0, in, 8)));
            break;
        default:
            // value without fractional part, stored as a long with two header bits
            boolean isNegative = ((firstByte & WHOLE_NUMBER_SIGN_BIT) != 0);
            long magnitude = extractLong(firstByte, in, 2);
            doubleWritable.set((double) (isNegative ? -magnitude : magnitude));
        }
        return doubleWritable;
    }

    /*** Private helper methods ***/

    /**
     * Reads a single byte and fails loudly at end of stream, so that a truncated
     * input is reported instead of being decoded into a bogus value.
     *
     * @return the byte read, in the range 0..255
     * @throws EOFException if the stream has no more bytes
     */
    private static int readByte(InputStream in) throws IOException
    {
        int value = in.read();
        if (value < 0)
            throw new EOFException("Unexpected end of stream while decoding a variable-length value");
        return value;
    }

    /** Stores {@code num} big-endian in buffer[1..4] and clears buffer[0]. */
    private static final void fillIntBuffer(int num)
    {
        buffer[0] = 0;
        buffer[1] = (byte) (num >>> 24);
        buffer[2] = (byte) (num >>> 16);
        buffer[3] = (byte) (num >>> 8);
        buffer[4] = (byte) (num >>> 0);
    }

    /** Stores {@code num} big-endian in buffer[2..9] and clears buffer[0..1]. */
    private static final void fillLongBuffer(long num)
    {
        buffer[0] = 0;
        buffer[1] = 0;
        buffer[2] = (byte) (num >>> 56);
        buffer[3] = (byte) (num >>> 48);
        buffer[4] = (byte) (num >>> 40);
        buffer[5] = (byte) (num >>> 32);
        buffer[6] = (byte) (num >>> 24);
        buffer[7] = (byte) (num >>> 16);
        buffer[8] = (byte) (num >>> 8);
        buffer[9] = (byte) (num >>> 0);
    }

    /**
     * Appends {@code length} bytes from the stream, big-endian, below the bits of
     * {@code num}. With a length of 4 the bits of {@code num} are shifted out
     * entirely. Lengths outside 1..4 read nothing and return {@code num}.
     */
    private static final int readInt(int num, InputStream in, int length) throws IOException
    {
        if (length < 1 || length > 4)
            return num;

        int result = num;
        for (int i = 0; i < length; i++)
            result = (result << 8) | readByte(in);
        return result;
    }

    /**
     * Appends {@code length} bytes from the stream, big-endian, below the bits of
     * {@code num}. With a length of 8 the bits of {@code num} are shifted out
     * entirely. Lengths outside 1..8 read nothing and return {@code num}.
     */
    private static final long readLong(long num, InputStream in, int length) throws IOException
    {
        if (length < 1 || length > 8)
            return num;

        long result = num;
        for (int i = 0; i < length; i++)
            result = (result << 8) | readByte(in);
        return result;
    }

    /**
     * Lays out a non-negative integer in the shared buffer together with its
     * length bits (header and sign bits are added by the caller).
     *
     * @return index in the buffer of the first byte to write; the encoding runs
     *         from that index through index 4
     */
    private static final int encodePositiveInteger(int num, int numHeaderBits)
    {
        fillIntBuffer(num);

        // default: the value is 0 and fits in the last byte with no length bits
        int offset = 4;
        for (int i = 1; i <= 4; i++)
        {
            if (buffer[i] != 0)
            {
                // buffer[i] is the most significant non-zero byte. If its top bits
                // (header bits + length bits + one terminating 0 bit) are already
                // occupied by the value, the header needs a byte of its own.
                offset =
                        ((buffer[i] & intLeftBitsMasks[5 - i + numHeaderBits]) != 0)
                                ? i - 1 : i;
                // (4 - offset) bytes follow the first byte; record that many 1-bits
                // right after the header bits
                buffer[offset] |= (intLeftBitsMasks[4 - offset] >>> numHeaderBits);
                break;
            }
        }
        return offset;
    }

    /**
     * Lays out a non-negative long in the shared buffer together with its length
     * bits (header and sign bits are added by the caller).
     *
     * @return index in the buffer of the first byte to write; the encoding runs
     *         from that index through index 9
     */
    private static final int encodePositiveLong(long num, int numHeaderBits)
    {
        fillLongBuffer(num);

        // first handle cases that definitely require length encoding of two bytes:
        // the first byte has every bit below the header bits set, and the second
        // byte carries the remaining length bits
        if (buffer[2] != 0)
        {
            buffer[0] = (byte) (0xFF >>> numHeaderBits);
            buffer[1] = (byte) intLeftBitsMasks[1 + numHeaderBits];
            return 0;
        }
        if (buffer[3] != 0)
        {
            buffer[1] = (byte) (0xFF >>> numHeaderBits);
            buffer[2] = (byte) intLeftBitsMasks[numHeaderBits];
            return 1;
        }
        if (buffer[4] != 0 && numHeaderBits == 2)
        {
            buffer[2] = (byte) (0xFF >>> numHeaderBits);
            buffer[3] = (byte) intLeftBitsMasks[1];
            return 2;
        }

        // default: the value is 0 and fits in the last byte with no length bits
        int offset = 9;
        for (int i = 4; i <= 9; i++)
        {
            if (buffer[i] != 0)
            {
                // same scheme as encodePositiveInteger: use the previous byte when
                // the top bits of buffer[i] are needed by the value itself
                offset =
                        ((buffer[i] & intLeftBitsMasks[10 - i + numHeaderBits]) != 0)
                                ? i - 1 : i;
                buffer[offset] |= (intLeftBitsMasks[9 - offset] >>> numHeaderBits);
                break;
            }
        }
        return offset;
    }

    /**
     * Returns the largest {@code length <= maxSetBits} for which the top
     * {@code length + numHeaderBits} bits of {@code firstByte} are all set.
     * Callers force the header bits to 1 beforehand, so a length of 0 always
     * matches.
     */
    private static final int getLength(int firstByte, int numHeaderBits, int maxSetBits)
    {
        int length;
        for (length = maxSetBits; length >= 0; length--)
        {
            int mask = intLeftBitsMasks[length + numHeaderBits];
            if ((firstByte & mask) == mask)
                break;
        }
        return length;
    }

    /**
     * Decodes the magnitude of an integer whose first byte has already been read.
     *
     * @param firstByte first byte of the encoding (0..255)
     * @param in stream positioned right after the first byte
     * @param numHeaderBits number of leading non-length bits in the first byte
     * @return the magnitude; the caller applies the sign
     */
    private final static int extractInteger(int firstByte,
                                            InputStream in,
                                            int numHeaderBits) throws IOException
    {
        int num = firstByte;

        // force the header bits to 1 so that getLength only has to look for a
        // run of leading 1-bits
        firstByte |= intLeftBitsMasks[numHeaderBits];
        int length = getLength(firstByte, numHeaderBits, 4);

        // strip header and length bits, keeping only the payload bits
        num &= ~intLeftBitsMasks[length + numHeaderBits];
        return readInt(num, in, length);
    }

    /**
     * Decodes the magnitude of a long whose first byte has already been read.
     *
     * @param firstByte first byte of the encoding (0..255)
     * @param in stream positioned right after the first byte
     * @param numHeaderBits number of leading non-length bits in the first byte
     * @return the magnitude; the caller applies the sign
     */
    private static final long extractLong(int firstByte, InputStream in, int numHeaderBits) throws IOException
    {
        long num = firstByte;
        firstByte |= intLeftBitsMasks[numHeaderBits];

        int length;
        if (firstByte == 0xFF)
        {
            // need to look into the next byte to read the length
            int secondByte = readByte(in);
            length = getLength(secondByte, 0, 3);
            num = ((long) secondByte) & ~intLeftBitsMasks[length];
            // the first byte contributed (7 - numHeaderBits) length bits; one of
            // the counted bytes (the second byte) has already been consumed
            length += 7 - numHeaderBits;
        }
        else
        {
            length = getLength(firstByte, numHeaderBits, numHeaderBits == 2 ? 6 : 7);
            num &= ~intLeftBitsMasks[length + numHeaderBits];
        }
        return readLong(num, in, length);
    }

    /**
     * Returns the header bits to OR into the first encoded byte: the sign bit
     * for the one-header-bit form, or marker plus sign bit for the two-header-bit
     * form used by floats and doubles.
     */
    private static int headerFlags(boolean negative, int numHeaderBits)
    {
        if (numHeaderBits == 2)
            return WHOLE_NUMBER_MARKER | (negative ? WHOLE_NUMBER_SIGN_BIT : 0);
        return negative ? SIGN_BIT : 0;
    }

    /**
     * Encodes an integer as sign + magnitude with the given number of header bits
     * and writes it to the stream.
     */
    private static final void encodeInteger(int num, OutputStream out, int numHeaderBits) throws IOException
    {
        boolean negative = num < 0;

        // For Integer.MIN_VALUE the negation overflows and leaves the bits
        // unchanged; those bits are written as-is and the decoder's negation
        // maps them back to Integer.MIN_VALUE.
        int magnitude = negative ? -num : num;

        int offset = encodePositiveInteger(magnitude, numHeaderBits);
        buffer[offset] = (byte) (headerFlags(negative, numHeaderBits) | buffer[offset]);
        out.write(buffer, offset, 5 - offset);
    }

    /**
     * Encodes a long as sign + magnitude with the given number of header bits and
     * writes it to the stream.
     */
    private static final void encodeLong(long num, OutputStream out, int numHeaderBits) throws IOException
    {
        boolean negative = num < 0;

        // For Long.MIN_VALUE the negation overflows and leaves the bits unchanged;
        // those bits are written as-is and the decoder's negation maps them back
        // to Long.MIN_VALUE.
        long magnitude = negative ? -num : num;

        int offset = encodePositiveLong(magnitude, numHeaderBits);
        buffer[offset] = (byte) (headerFlags(negative, numHeaderBits) | buffer[offset]);
        out.write(buffer, offset, 10 - offset);
    }

    /*** Self-test helpers ***/

    /** Formats every byte as eight binary digits followed by a space. */
    private static String toBitString(byte[] bytes)
    {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < bytes.length; i++)
        {
            String bits = Integer.toBinaryString(bytes[i] & 0xFF);
            for (int pad = bits.length(); pad < 8; pad++)
                sb.append('0');
            sb.append(bits);
            sb.append(" ");
        }
        return sb.toString();
    }

    /**
     * Round-trips one value through the encoder matching its runtime type and
     * prints whether the decoded value equals the input, the encoded bits, the
     * input and the decoded value.
     *
     * @param num an Integer, Long, Float or Double
     * @throws IOException if encoding or decoding fails
     */
    static final void test(Number num) throws IOException
    {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        if (num instanceof Integer)
            encodeInteger(num.intValue(), bos);
        else if (num instanceof Long)
            encodeLong(num.longValue(), bos);
        else if (num instanceof Float)
            encodeFloat(num.floatValue(), bos);
        else if (num instanceof Double)
            encodeDouble(num.doubleValue(), bos);

        byte[] bytes = bos.toByteArray();
        ByteArrayInputStream bis = new ByteArrayInputStream(bytes);

        Number dec = null;
        if (num instanceof Integer)
            dec = decodeInteger(bis).get();
        else if (num instanceof Long)
            dec = decodeLong(bis).get();
        else if (num instanceof Float)
            dec = decodeFloat(bis).get();
        else if (num instanceof Double)
            dec = decodeDouble(bis).get();

        print.f("%-5s %s %-16s %-16s",
                Boolean.toString(num.equals(dec)),
                toBitString(bytes),
                num.toString(),
                String.valueOf(dec));
    }

    /**
     * Round-trips every element of the array through {@link #test(Number)}.
     *
     * @param array values to test
     * @throws IOException if encoding or decoding fails
     */
    private static void test(Number[] array) throws IOException
    {
        for (Number num : array)
            test(num);
    }

    /**
     * Runs the round-trip self test over all-ones bit patterns of every width
     * (as int, long, float and double, positive and negative) and over a set of
     * boundary values.
     */
    public static void main(String[] args) throws IOException
    {
        test((float) 2147483647);
        test((float) -2147483647);

        ArrayList<Number> list = new ArrayList<Number>();

        // ints with the lowest i bits set, for i = 0..32
        for (int i = 0; i <= 32; i++)
        {
            int num = 0;
            for (int j = 0; j < i; j++)
                num |= 1 << j;
            list.add(num);
            list.add(-num);
            list.add((float) num);
            list.add(-(float) num);
        }

        // longs with the lowest i bits set, for i = 0..64
        for (int i = 0; i <= 64; i++)
        {
            long num = 0;
            for (int j = 0; j < i; j++)
                num |= 1L << j;
            list.add(-num);
            list.add(num);
            list.add((double) num);
            list.add(-(double) num);
        }
        test(list.toArray(new Number[] {}));

        Number[] specials =
                new Number[] { Integer.MAX_VALUE, Integer.MIN_VALUE, Long.MAX_VALUE,
                        Long.MIN_VALUE, Double.NaN, Double.POSITIVE_INFINITY,
                        Double.NEGATIVE_INFINITY, Double.MAX_VALUE, Double.MIN_VALUE,
                        Float.MIN_VALUE, Float.MAX_VALUE, Float.POSITIVE_INFINITY,
                        Float.NEGATIVE_INFINITY };
        test(specials);
    }
}