/*
* Copyright 2016 higherfrequencytrading.com
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package net.openhft.lang.io.serialization.impl;
import net.openhft.lang.io.ByteBufferBytes;
import net.openhft.lang.io.Bytes;
import net.openhft.lang.io.serialization.CompactBytesMarshaller;
import net.openhft.lang.model.constraints.Nullable;
import org.xerial.snappy.Snappy;
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.nio.ByteBuffer;
/**
 * Marshaller which writes a {@link CharSequence} as a series of Snappy compressed chunks
 * and reads it back as a {@link String}.
 *
 * <p>Layout produced by {@link #write(Bytes, CharSequence)}:
 * <ol>
 * <li>the number of chars as a stop-bit encoded value; {@code -1} stands for {@code null} and
 * {@code 0} for the empty string, in both cases nothing else follows;</li>
 * <li>one or more chunks, each made of an unsigned 16-bit compressed length followed by that
 * many bytes of Snappy compressed data; the uncompressed data is the stop-bit encoding of each
 * char of the chunk;</li>
 * <li>a 16-bit zero which terminates the list of chunks.</li>
 * </ol>
 *
 * <p>Scratch buffers are held in a {@link ThreadLocal} and are reset on every acquisition.
 *
 * <p>Created by peter.lawrey on 24/10/14.
 */
public enum SnappyStringMarshaller implements CompactBytesMarshaller<CharSequence> {
    INSTANCE;

    private static final StringFactory STRING_FACTORY = getStringFactory();
    /** Stop-bit value written in place of a length when the value is {@code null}. */
    private static final int NULL_LENGTH = -1;
    private static final ThreadLocal<ThreadLocals> THREAD_LOCALS = new ThreadLocal<ThreadLocals>();

    /**
     * Picks the cheapest available way of turning a {@code char[]} into a {@link String}.
     *
     * <p>The two reflective factories use non-public {@code String} constructors which are not
     * guaranteed to exist or to be accessible on every JVM. Rather than failing class
     * initialisation when neither can be obtained, fall back to the public
     * {@code String(char[])} constructor.
     *
     * @return a usable factory, never {@code null}
     */
    private static StringFactory getStringFactory() {
        try {
            return new StringFactory17();
        } catch (Exception ignored) {
            // (char[], boolean) constructor is missing or inaccessible, try the next option.
        }
        try {
            return new StringFactory16();
        } catch (Exception ignored) {
            // (int, int, char[]) constructor is missing or inaccessible, use the public API.
        }
        return new CopyingStringFactory();
    }

    /**
     * @return {@code STRINGZ_CODE}
     */
    @Override
    public byte code() {
        return STRINGZ_CODE;
    }

    /**
     * Returns the scratch buffers of the calling thread, creating them on first use.
     * The buffers are cleared before they are returned.
     *
     * @return the cleared per-thread scratch buffers
     */
    public ThreadLocals acquireThreadLocals() {
        ThreadLocals threadLocals = THREAD_LOCALS.get();
        if (threadLocals == null) {
            threadLocals = new ThreadLocals();
            THREAD_LOCALS.set(threadLocals);
        }
        threadLocals.clear();
        return threadLocals;
    }

    /**
     * Writes {@code s} to {@code bytes} using the layout described on this class.
     *
     * @param bytes destination; its position is advanced past the written data
     * @param s     text to write, may be {@code null}
     * @throws AssertionError if Snappy reports an {@link IOException} or a compressed chunk
     *                        does not fit in an unsigned 16-bit length
     */
    @Override
    public void write(Bytes bytes, CharSequence s) {
        if (s == null) {
            bytes.writeStopBit(NULL_LENGTH);
            return;
        }
        // write the total length; an empty string has no chunks and no terminator.
        int length = s.length();
        bytes.writeStopBit(length);
        if (length == 0) {
            return;
        }

        ThreadLocals threadLocals = acquireThreadLocals();
        // stream the portions of the string.
        Bytes db = threadLocals.decompressedBytes;
        ByteBuffer dbb = threadLocals.decompressedByteBuffer;
        ByteBuffer cbb = bytes.sliceAsByteBuffer(threadLocals.compressedByteBuffer);

        int position = 0;
        while (position < length) {
            // fill the scratch buffer; 3 bytes is the longest stop-bit encoding of a char.
            while (position < length && db.remaining() >= 3) {
                db.writeStopBit(s.charAt(position++));
            }
            dbb.position(0);
            dbb.limit((int) db.position());

            // reserve room for the chunk length, it is patched in once the size is known.
            int portionLengthPos = cbb.position();
            cbb.putShort((short) 0);

            int compressedLength;
            try {
                Snappy.compress(dbb, cbb);
                // this code relies on compress() leaving cbb spanning exactly the compressed data.
                compressedLength = cbb.remaining();
                if (compressedLength >= 1 << 16) {
                    throw new AssertionError("Compressed chunk of " + compressedLength
                            + " bytes does not fit in a 16-bit length");
                }
                // unflip: continue appending after the compressed data.
                cbb.position(cbb.limit());
                cbb.limit(cbb.capacity());
            } catch (IOException e) {
                throw new AssertionError(e);
            }
            cbb.putShort(portionLengthPos, (short) compressedLength);
            db.clear();
        }
        // the end: a zero length chunk terminates the list.
        cbb.putShort((short) 0);
        bytes.position(bytes.position() + cbb.position());
    }

    /**
     * Reads a string previously written by {@link #write(Bytes, CharSequence)}.
     *
     * @param bytes source; its position is advanced past the consumed data
     * @return the decoded string, or {@code null} if a null was written
     */
    @Override
    public String read(Bytes bytes) {
        return read(bytes, null);
    }

    /**
     * Reads a string previously written by {@link #write(Bytes, CharSequence)}.
     *
     * @param bytes   source; its position is advanced past the consumed data
     * @param ignored not used
     * @return the decoded string, or {@code null} if a null was written
     * @throws IllegalStateException if the declared length is invalid or does not match the
     *                               number of chars actually decoded
     * @throws AssertionError        if Snappy reports an {@link IOException} or the string
     *                               cannot be constructed
     */
    @Override
    public String read(Bytes bytes, @Nullable CharSequence ignored) {
        long size = bytes.readStopBit();
        if (size == NULL_LENGTH) {
            return null;
        }
        if (size < 0 || size > Integer.MAX_VALUE) {
            throw new IllegalStateException("Invalid length: " + size);
        }
        if (size == 0) {
            return "";
        }

        ThreadLocals threadLocals = acquireThreadLocals();
        // stream the portions of the string.
        Bytes db = threadLocals.decompressedBytes;
        ByteBuffer dbb = threadLocals.decompressedByteBuffer;
        ByteBuffer cbb = bytes.sliceAsByteBuffer(threadLocals.compressedByteBuffer);

        char[] chars = new char[(int) size];
        int pos = 0;
        // each chunk is prefixed with its unsigned 16-bit compressed length, 0 ends the list.
        for (int chunkLen; (chunkLen = cbb.getShort() & 0xFFFF) > 0; ) {
            cbb.limit(cbb.position() + chunkLen);
            dbb.clear();
            try {
                Snappy.uncompress(cbb, dbb);
                // step over the chunk and make the rest of the slice readable again.
                cbb.position(cbb.limit());
                cbb.limit(cbb.capacity());
            } catch (IOException e) {
                throw new AssertionError(e);
            }
            db.position(0);
            db.limit(dbb.limit());
            while (db.remaining() > 0) {
                if (pos >= chars.length) {
                    // report corrupt input explicitly instead of an index out of bounds.
                    throw new IllegalStateException("Invalid data: more than the declared "
                            + chars.length + " chars were decoded");
                }
                chars[pos++] = (char) db.readStopBit();
            }
        }
        if (pos != chars.length) {
            // otherwise the result would be silently padded with '\u0000'.
            throw new IllegalStateException("Invalid data: declared " + chars.length
                    + " chars but decoded " + pos);
        }
        bytes.position(bytes.position() + cbb.position());

        try {
            return STRING_FACTORY.fromChars(chars);
        } catch (Exception e) {
            throw new AssertionError(e);
        }
    }

    /**
     * Scratch buffers used while compressing or decompressing one chunk.
     */
    static class ThreadLocals {
        /** Holds one uncompressed chunk of at most 32 KiB. */
        ByteBuffer decompressedByteBuffer = ByteBuffer.allocateDirect(32 * 1024);
        /** {@link Bytes} wrapper created from {@link #decompressedByteBuffer}. */
        Bytes decompressedBytes = ByteBufferBytes.wrap(decompressedByteBuffer);
        /** Zero-capacity buffer handed to {@code Bytes.sliceAsByteBuffer}. */
        ByteBuffer compressedByteBuffer = ByteBuffer.allocateDirect(0);

        /** Clears all three buffers. */
        public void clear() {
            decompressedByteBuffer.clear();
            decompressedBytes.clear();
            compressedByteBuffer.clear();
        }
    }

    /**
     * Strategy for turning a freshly filled {@code char[]} into a {@link String}.
     */
    private abstract static class StringFactory {
        /**
         * @param chars the complete content of the string; callers must not reuse the array
         *              afterwards as an implementation may hand it to the string without copying
         * @return a string with the content of {@code chars}
         */
        abstract String fromChars(char[] chars) throws IllegalAccessException, InvocationTargetException, InstantiationException;
    }

    /**
     * Uses the non-public {@code String(int, int, char[])} constructor via reflection.
     */
    private static final class StringFactory16 extends StringFactory {
        private final Constructor<String> constructor;

        private StringFactory16() throws NoSuchMethodException {
            constructor = String.class.getDeclaredConstructor(int.class,
                    int.class, char[].class);
            constructor.setAccessible(true);
        }

        @Override
        String fromChars(char[] chars) throws IllegalAccessException, InvocationTargetException, InstantiationException {
            return constructor.newInstance(0, chars.length, chars);
        }
    }

    /**
     * Uses the non-public {@code String(char[], boolean)} constructor via reflection.
     */
    private static final class StringFactory17 extends StringFactory {
        private final Constructor<String> constructor;

        private StringFactory17() throws NoSuchMethodException {
            constructor = String.class.getDeclaredConstructor(char[].class, boolean.class);
            constructor.setAccessible(true);
        }

        @Override
        String fromChars(char[] chars) throws IllegalAccessException, InvocationTargetException, InstantiationException {
            return constructor.newInstance(chars, true);
        }
    }

    /**
     * Last resort which only relies on the public {@code String(char[])} constructor.
     */
    private static final class CopyingStringFactory extends StringFactory {
        @Override
        String fromChars(char[] chars) {
            return new String(chars);
        }
    }
}