package org.jcodec.common;
import org.jcodec.codecs.vpx.VPXBooleanDecoder;
import org.junit.Assert;
import org.junit.Test;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.lang.IllegalArgumentException;
import java.lang.StringBuilder;
import java.lang.System;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
* For the theory behind the coder and decoder exercised here, see:
* http://www.youtube.com/playlist?list=PLE125425EC837021F
*/
public class ArithmeticCoderTest {
// Bit width handed to the coder/decoder constructors; they derive whole = 1 << PRECISSION from it.
private static final int PRECISSION = 8;
// Frequencies of symbols 0..4 used by the fixed-vector encoder and decoder tests.
private static final int[] rs = { 10, 25, 11, 15, 10 };
/**
 * Renders a byte array as a brace-delimited, comma-separated list of
 * upper-case hex literals, e.g. <code>{0xD8, 0x44}</code>.
 * Values are not zero-padded, so 0x05 is printed as <code>0x5</code>.
 *
 * @param b bytes to render; each one is treated as unsigned
 * @return the textual form, <code>{}</code> for an empty array
 */
public static String printArrayAsHex(byte[] b){
    StringBuilder out = new StringBuilder();
    out.append("{");
    for (int idx = 0; idx < b.length; idx++) {
        if (idx != 0) {
            out.append(", ");
        }
        // Mask to 0..255 so negative bytes do not print as 32-bit values.
        String hex = Integer.toHexString(b[idx] & 0xff);
        out.append("0x").append(hex.toUpperCase());
    }
    return out.append("}").toString();
}
@Test
public void testPrinting() throws Exception {
    // One negative byte, then the same byte followed by a positive one.
    byte[] single = { (byte) 0xD8 };
    byte[] pair = { (byte) 0xD8, 0x44 };

    Assert.assertEquals("{0xD8}", printArrayAsHex(single));
    Assert.assertEquals("{0xD8, 0x44}", printArrayAsHex(pair));
}
@Test
public void testEncoder() throws IOException {
    // A single coder is reused; every encode() call starts a fresh Emitter.
    ArithmeticCoder coder = new ArithmeticCoder(PRECISSION, rs);

    coder.encode(Arrays.asList(1, 2, 3, 4, 0));
    byte[] firstExpected = { 0x5C, 0x18 };
    Assert.assertArrayEquals(firstExpected, coder.e.getArray());

    coder.encode(Arrays.asList(4, 0));
    byte[] secondExpected = { (byte) 0xDC };
    Assert.assertArrayEquals(secondExpected, coder.e.getArray());

    coder.encode(Arrays.asList(1, 1, 1, 2, 0));
    byte[] thirdExpected = { 0x3A, (byte) 0x80 };
    Assert.assertArrayEquals(thirdExpected, coder.e.getArray());
}
/**
 * Decodes the three byte vectors that testEncoder expects the encoder to
 * produce and checks that each one yields the message it was encoded from.
 * Previously the decoded data was only printed, so a wrong decode could
 * never fail this test.
 */
@Test
public void testDecoder() throws Exception {
    ArithmeticDecoder ad = new ArithmeticDecoder(PRECISSION, rs);

    ad.decode(new byte[] { 0x5C, 0x18 });
    Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 0), ad.data);

    ad.decode(new byte[] { (byte) 0xDC });
    Assert.assertEquals(Arrays.asList(4, 0), ad.data);

    ad.decode(new byte[] { 0x3A, (byte)0x80 });
    Assert.assertEquals(Arrays.asList(1, 1, 1, 2, 0), ad.data);
}
/**
 * Round-trips several messages through one coder/decoder pair that share
 * the same frequency table and checks the decoded symbols equal the input.
 */
@Test
public void testCodingAndDecoding() throws Exception {
    int[] smallRs = new int[]{ 2, 5, 1, 3, 2 };
    ArithmeticCoder ac = new ArithmeticCoder(PRECISSION, smallRs);
    ArithmeticDecoder ad = new ArithmeticDecoder(PRECISSION, smallRs);

    // Every message ends with symbol 0, which is where the decoder stops.
    Integer[][] messages = {
        { 1, 2, 3, 0 },
        { 2, 3, 4, 0 },
        { 1, 2, 4, 0 },
        { 1, 3, 4, 0 },
        { 4, 3, 4, 0 },
        { 1, 2, 3, 4, 0 },
    };

    for (Integer[] message : messages) {
        List<Integer> symbols = Arrays.asList(message);
        ac.encode(symbols);
        ad.decode(ac.e.getArray());
        Assert.assertEquals(symbols, ad.data);
    }
}
/**
 * Integer arithmetic encoder driven by a fixed table of symbol frequencies.
 * The working interval is [a, b) inside [0, whole); bits are written to the
 * {@link Emitter} held in {@link #e}, which is replaced on every
 * {@link #encode(List)} call.
 */
public static class ArithmeticCoder {
    public final long precission;
    public final long whole;  // 1 << precission
    public final long half;   // whole / 2
    public final long quater; // whole / 4
    public final int[] r; // symbol frequencies
    public final int[] c; // lower cumulative borders, c[0]=0, c[i] = r[0]+...+r[i-1]
    public final int[] d; // upper cumulative borders, d[i] = c[i]+r[i]
    public final int R;   // total frequency, r[0]+...+r[r.length-1]
    public Emitter e;

    /**
     * @param precission bit width of the interval; whole = 1 << precission
     * @param r          frequency of each symbol, indexed by symbol value;
     *                   must contain at least one entry
     */
    public ArithmeticCoder(int precission, int[] r){
        this.precission = precission;
        this.whole = (1L << precission);
        this.half = whole >> 1;
        this.quater = whole >> 2;
        this.r = r;
        this.c = new int[r.length];
        this.d = new int[r.length];
        // Single prefix-sum pass; c[i] is the running total before r[i], d[i] after it.
        int running = r[0];
        this.c[0] = 0;
        this.d[0] = running;
        for (int i = 1; i < r.length; i++) {
            c[i] = running;
            running += r[i];
            d[i] = running;
        }
        this.R = running;
    }

    /**
     * Returns a value whose lowest {@code s} bits are set.
     *
     * @param s number of one bits, expected in the range 0..63
     * @return {@code 2^s - 1}
     */
    public static long sOnes(long s) {
        // The literal must be a long: with an int literal the shift distance
        // is taken modulo 32 and the result is wrong for s >= 32.
        return (1L << s) - 1;
    }

    /** Emits a single 0 bit followed by {@code s} 1 bits. */
    public void emitZeroAndSOnes(long s) throws IOException {
        e.emit(0);
        for (long left = s; left > 0; left--) {
            e.emit(1);
        }
    }

    /** Emits a single 1 bit followed by {@code s} 0 bits. */
    private void emitOneAndSZeros(long s) throws IOException {
        e.emit(1);
        for (long left = s; left > 0; left--) {
            e.emit(0);
        }
    }

    /**
     * Encodes the given symbols into a fresh {@link Emitter} stored in
     * {@link #e}; read the result with {@code e.getArray()}.
     *
     * @param input symbols to encode, each a valid index into the frequency table
     * @throws IOException if the emitter fails to write
     */
    public void encode(List<Integer> input) throws IOException {
        e = new Emitter();
        long a = 0L;
        long b = whole;
        long s = 0; // pending middle-rescale count, flushed with the next emitted bit
        for (int index = 0; index < input.size(); index++) {
            int symbol = input.get(index);
            long omega = b - a;
            // b must be computed before a is overwritten: both derive from the old a.
            b = a + Math.round((omega * d[symbol]) / (R*1.0));
            a = a + Math.round((omega * c[symbol]) / (R*1.0));
            while (b < half || a > half) {
                if (b < half) {
                    // Interval is in the lower half: emit 0 and s 1's, then double.
                    emitZeroAndSOnes(s);
                    s = 0;
                    a = 2*a;
                    b = 2*b;
                } else if (a > half) {
                    // Interval is in the upper half: emit 1 and s 0's, then double.
                    emitOneAndSZeros(s);
                    s = 0;
                    a = 2*(a - half);
                    b = 2*(b - half);
                }
            }
            // Interval straddles the middle: expand around it and remember how often.
            while (a > quater && b < 3 * quater) {
                s++;
                a = 2*(a - quater);
                b = 2*(b - quater);
            }
        }
        // Termination: one more pending bit, then pick the side by where a lies.
        s++;
        if (a <= quater) {
            emitZeroAndSOnes(s);
        } else {
            emitOneAndSZeros(s);
        }
    }
}
/**
 * Integer arithmetic decoder matching {@link ArithmeticCoder}: it rebuilds the
 * same cumulative frequency borders and narrows the interval [a, b) while
 * tracking the value {@code z} read from the bit stream. Decoding stops as
 * soon as symbol 0 has been decoded.
 */
public static class ArithmeticDecoder {
    public final long precission;
    public final long whole;  // 1 << precission
    public final long half;   // whole / 2
    public final long quater; // whole / 4
    public final int[] r; // symbol frequencies
    public final int[] c; // lower cumulative borders, c[0]=0, c[i] = r[0]+...+r[i-1]
    public final int[] d; // upper cumulative borders, d[i] = c[i]+r[i]
    public final int R;   // total frequency, r[0]+...+r[r.length-1]
    public List<Integer> data; // symbols produced by the most recent decode() call

    /**
     * @param precission bit width of the interval; whole = 1 << precission
     * @param r          frequency of each symbol, indexed by symbol value;
     *                   must contain at least one entry
     */
    public ArithmeticDecoder(int precission, int[] r){
        this.precission = precission;
        this.whole = (1L << precission);
        this.half = whole >> 1;
        this.quater = whole >> 2;
        this.r = r;
        this.c = new int[r.length];
        this.d = new int[r.length];
        // Single prefix-sum pass; c[i] is the running total before r[i], d[i] after it.
        int running = r[0];
        this.c[0] = 0;
        this.d[0] = running;
        for (int i = 1; i < r.length; i++) {
            c[i] = running;
            running += r[i];
            d[i] = running;
        }
        this.R = running;
    }

    /**
     * Decodes {@code bs} into {@link #data}, replacing any previous result.
     * Returns once symbol 0 has been decoded; bits past the end of
     * {@code bs} are treated as absent, leaving the low bit of z at 0.
     *
     * @param bs encoded bit stream
     * @throws IllegalStateException if no symbol interval contains the
     *         current value, which would otherwise loop forever
     */
    public void decode(byte[] bs) {
        data = new ArrayList<Integer>();
        final long totalBits = bs.length * 8L;
        long a = 0;
        long b = whole;
        long z = 0;
        long i = 0;
        // Prime z with the first `precission` bits, most significant first.
        while (i < precission && i < totalBits) {
            if (VPXBooleanDecoder.getBitInBytes(bs, (int)i) != 0x00) {
                z += (1L << (precission - i - 1));
            }
            i++;
        }
        while (true) {
            boolean matched = false;
            long omega = b - a;
            // Search every symbol in the table rather than a fixed count of 5.
            for (int j = 0; j < r.length; j++) {
                long bzero = a + Math.round((omega * d[j]) / (R*1.0));
                long azero = a + Math.round((omega * c[j]) / (R*1.0));
                if (azero <= z && z < bzero) {
                    data.add(j);
                    a = azero;
                    b = bzero;
                    if (j == 0) {
                        return;
                    }
                    matched = true;
                    break;
                }
            }
            if (!matched) {
                // a, b and z would not change any more, so fail loudly instead of spinning.
                throw new IllegalStateException(
                        "No symbol interval contains z=" + z + " within [" + a + ", " + b + ")");
            }
            while (b < half || a > half) {
                if (b < half) {
                    a = 2*a;
                    b = 2*b;
                    z = 2*z;
                } else if (a > half) {
                    a = 2*(a - half);
                    b = 2*(b - half);
                    z = 2*(z - half);
                }
                if (i < totalBits) {
                    z += bitAt(bs, i);
                    i++;
                }
            }
            while (a > quater && b < 3 * quater) {
                a = 2*(a - quater);
                b = 2*(b - quater);
                z = 2*(z - quater);
                if (i < totalBits) {
                    z += bitAt(bs, i);
                    i++;
                }
            }
        }
    }

    /** Returns 1 when the bit at position {@code pos} of {@code bs} reads as 1, otherwise 0. */
    private static int bitAt(byte[] bs, long pos) {
        return VPXBooleanDecoder.getBitInBytes(bs, (int) pos) == 0x01 ? 1 : 0;
    }
}
/**
 * Checks that a fresh emitter yields no bytes and that partial trailing
 * bytes are padded with zero bits on the right.
 */
@Test
public void testEmiter() throws Exception {
    Emitter emitter = new Emitter();
    Assert.assertArrayEquals(new byte[]{}, emitter.getArray());

    // 110110 -> 1101 1000
    int[] sixBits = { 1, 1, 0, 1, 1, 0 };
    for (int bit : sixBits) {
        emitter.emit(bit);
    }
    Assert.assertArrayEquals(new byte[]{(byte)0xD8}, emitter.getArray());

    // 01000011001 -> 0100 0011, 0010 0000
    emitter = new Emitter();
    int[] elevenBits = { 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1 };
    for (int bit : elevenBits) {
        emitter.emit(bit);
    }
    Assert.assertArrayEquals(new byte[]{0x43, 0x20}, emitter.getArray());
}
/**
 * Collects single bits into bytes, most significant bit first, and stores
 * completed bytes in an in-memory stream.
 */
public static class Emitter{
    private int filled = 0;      // bits already placed in the pending byte
    private byte pending = 0;    // byte currently being assembled
    private ByteArrayOutputStream baos;

    public Emitter() {
        this.baos = new ByteArrayOutputStream();
    }

    /**
     * Appends one bit to the output.
     *
     * @param b the bit, either 0 or 1
     * @throws IllegalArgumentException for any other value
     */
    public void emit(int b) throws IOException {
        boolean isBit = (b == 0 || b == 1);
        if (!isBit) {
            throw new IllegalArgumentException("Only 0's and 1's are accepted");
        }
        int shift = 7 - filled;
        pending = (byte) (pending | (b << shift));
        filled++;
        if (filled > 7) {
            flushPending();
        }
    }

    /**
     * Returns every byte produced so far. A partially filled byte is
     * written out first, padded with zero bits, so later bits start a new byte.
     */
    public byte[] getArray() throws IOException {
        if (filled != 0) {
            flushPending();
        }
        return baos.toByteArray();
    }

    /** Writes the pending byte to the stream and resets the bit state. */
    private void flushPending() {
        filled = 0;
        baos.write(pending);
        pending = 0;
    }
}
}