/* * Copyright (C) 2015 SoftIndex LLC. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package io.datakernel.serializer.asm; import io.datakernel.bytebuf.SerializationUtils; import io.datakernel.codegen.Expression; import io.datakernel.codegen.Variable; import io.datakernel.serializer.CompatibilityLevel; import io.datakernel.serializer.NullableOptimization; import io.datakernel.serializer.SerializerBuilder; import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.List; import static io.datakernel.codegen.Expressions.*; import static io.datakernel.codegen.utils.Preconditions.checkNotNull; import static org.objectweb.asm.Type.getType; @SuppressWarnings("PointlessArithmeticExpression") public class SerializerGenSubclass implements SerializerGen, NullableOptimization { @Override public SerializerGen setNullable() { return new SerializerGenSubclass(dataType, subclassSerializers, true, startIndex); } private final Class<?> dataType; private final LinkedHashMap<Class<?>, SerializerGen> subclassSerializers; private final boolean nullable; private final int startIndex; public SerializerGenSubclass(Class<?> dataType, LinkedHashMap<Class<?>, SerializerGen> subclassSerializers, int startIndex) { this.startIndex = startIndex; this.dataType = checkNotNull(dataType); this.subclassSerializers = new LinkedHashMap<>(subclassSerializers); this.nullable = false; } public SerializerGenSubclass(Class<?> dataType, LinkedHashMap<Class<?>, SerializerGen> subclassSerializers, boolean nullable, int startIndex) { this.startIndex = startIndex; this.dataType = checkNotNull(dataType); this.subclassSerializers = new LinkedHashMap<>(subclassSerializers); this.nullable = nullable; } @Override public void getVersions(VersionsCollector versions) { for (SerializerGen serializer : subclassSerializers.values()) { versions.addRecursive(serializer); } } @Override public boolean isInline() { return false; } @Override public Class<?> getRawType() { return dataType; } @Override public void prepareSerializeStaticMethods(int version, SerializerBuilder.StaticMethods staticMethods, CompatibilityLevel compatibilityLevel) { if (staticMethods.startSerializeStaticMethod(this, version)) { return; } byte subClassIndex = (byte) (nullable && startIndex == 0 ? 1 : startIndex); List<Expression> listKey = new ArrayList<>(); List<Expression> listValue = new ArrayList<>(); for (Class<?> subclass : subclassSerializers.keySet()) { SerializerGen subclassSerializer = subclassSerializers.get(subclass); subclassSerializer.prepareSerializeStaticMethods(version, staticMethods, compatibilityLevel); listKey.add(cast(value(getType(subclass)), Object.class)); listValue.add(sequence( set(arg(1), callStatic(SerializationUtils.class, "writeByte", arg(0), arg(1), value(subClassIndex))), subclassSerializer.serialize(arg(0), arg(1), cast(arg(2), subclassSerializer.getRawType()), version, staticMethods, compatibilityLevel) )); subClassIndex++; if (nullable && subClassIndex == 0) { subClassIndex++; } } if (nullable) { staticMethods.registerStaticSerializeMethod(this, version, ifThenElse(isNotNull(arg(2)), switchForKey(cast(call(cast(arg(2), Object.class), "getClass"), Object.class), listKey, listValue), callStatic(SerializationUtils.class, "writeByte", arg(0), arg(1), value((byte) 0))) ); } else { staticMethods.registerStaticSerializeMethod(this, version, switchForKey(cast(call(cast(arg(2), Object.class), "getClass"), Object.class), listKey, listValue) ); } } @Override public Expression serialize(Expression byteArray, Variable off, Expression value, int version, SerializerBuilder.StaticMethods staticMethods, CompatibilityLevel compatibilityLevel) { return staticMethods.callStaticSerializeMethod(this, version, byteArray, off, value); } @Override public void prepareDeserializeStaticMethods(int version, SerializerBuilder.StaticMethods staticMethods, CompatibilityLevel compatibilityLevel) { if (staticMethods.startDeserializeStaticMethod(this, version)) { return; } List<Expression> list = new ArrayList<>(); for (Class<?> subclass : subclassSerializers.keySet()) { SerializerGen subclassSerializer = subclassSerializers.get(subclass); subclassSerializer.prepareDeserializeStaticMethods(version, staticMethods, compatibilityLevel); list.add(cast(subclassSerializer.deserialize(subclassSerializer.getRawType(), version, staticMethods, compatibilityLevel), this.getRawType())); } if (nullable) list.add(-startIndex, nullRef(getRawType())); Variable subClassIndex = let(sub(call(arg(0), "readByte"), value(startIndex))); staticMethods.registerStaticDeserializeMethod(this, version, cast(switchForPosition(subClassIndex, list), this.getRawType())); } @Override public Expression deserialize(Class<?> targetType, int version, SerializerBuilder.StaticMethods staticMethods, CompatibilityLevel compatibilityLevel) { return staticMethods.callStaticDeserializeMethod(this, version, arg(0)); } @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; SerializerGenSubclass that = (SerializerGenSubclass) o; if (!dataType.equals(that.dataType)) return false; if (!subclassSerializers.equals(that.subclassSerializers)) return false; return true; } @Override public int hashCode() { int result = dataType.hashCode(); result = 31 * result + subclassSerializers.hashCode(); return result; } }