package org.elasticsearch.plan.a;
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import org.antlr.v4.runtime.ParserRuleContext;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.List;
import static org.elasticsearch.plan.a.Adapter.*;
import static org.elasticsearch.plan.a.Default.*;
import static org.elasticsearch.plan.a.Definition.*;
class Caster {
/**
 * One strategy in a promotion chain. Given one or two operand types it either
 * names the type the operands should be promoted to, or declines.
 */
private abstract static class Segment {
/**
 * Attempts to find a promoted type for the given operand type(s).
 *
 * @param source parse-tree node the operands belong to; implementations forward it to
 *               the caster, which uses it when building error messages
 * @param from0 the first operand type
 * @param from1 the second operand type; several implementations accept {@code null}
 *              to mean that there is no second operand
 * @return the promoted type, or {@code null} when this segment does not apply
 */
abstract Type promote(final ParserRuleContext source, final Type from0, final Type from1);
}
/**
 * Promotion segment that only succeeds when both operands already have the same type.
 */
private static class SameTypeSegment extends Segment {
    /**
     * @return {@code from0} when a second operand is present and equal to the first,
     *         otherwise {@code null}
     */
    @Override
    Type promote(final ParserRuleContext source, final Type from0, final Type from1) {
        final boolean identical = from1 != null && from0.equals(from1);
        return identical ? from0 : null;
    }
}
/**
 * Promotion segment that yields a fixed target type as soon as at least one operand
 * already has that type and the other operand (if any) can be implicitly cast to it.
 */
private static class AnyTypeSegment extends Segment {
    private final Caster caster;
    private final Type to;

    AnyTypeSegment(final Caster caster, final Type to) {
        this.caster = caster;
        this.to = to;
    }

    /**
     * @return the target type when one operand matches it and the remaining operand is
     *         absent, also matches, or is implicitly castable; otherwise {@code null}
     */
    @Override
    Type promote(final ParserRuleContext source, final Type from0, final Type from1) {
        final boolean firstMatches = from0.equals(to);
        final boolean secondMatches = from1 != null && from1.equals(to);

        // Neither operand is the target type: this segment does not apply.
        if (!firstMatches && !secondMatches) {
            return null;
        }

        // Nothing left to cast: a lone matching operand, or both operands match.
        if (firstMatches && (from1 == null || secondMatches)) {
            return to;
        }

        // Exactly one operand still needs a (non-forced) cast to the target type.
        final Type other = firstMatches ? from1 : from0;

        try {
            caster.getLegalCast(source, other, to, false);
        } catch (ClassCastException ignored) {
            // No implicit cast exists, so let the next segment try.
            return null;
        }

        return to;
    }
}
/**
 * Promotion segment that yields a fixed target type when every operand either already
 * has that type or can be implicitly cast to it.
 */
private static class ToTypeSegment extends Segment {
    private final Caster caster;
    private final Type to;

    ToTypeSegment(final Caster caster, final Type to) {
        this.caster = caster;
        this.to = to;
    }

    /**
     * @return the target type when both operands (a {@code null} second operand counts
     *         as satisfied) equal it or are implicitly castable to it; otherwise {@code null}
     */
    @Override
    Type promote(final ParserRuleContext source, final Type from0, final Type from1) {
        final boolean same0 = from0.equals(to);
        final boolean same1 = from1 == null || from1.equals(to);

        // Both operands are probed independently, even when the first probe already failed.
        final boolean ok0 = same0 || isCastable(source, from0);
        final boolean ok1 = same1 || isCastable(source, from1);

        return ok0 && ok1 ? to : null;
    }

    /** Reports whether a non-forced cast from {@code from} to the target type is legal. */
    private boolean isCastable(final ParserRuleContext source, final Type from) {
        try {
            caster.getLegalCast(source, from, to, false);
            return true;
        } catch (ClassCastException ignored) {
            return false;
        }
    }
}
/**
 * Promotion segment that attempts a numeric promotion, but only when at least one of
 * the operands is itself a numeric type.
 */
private static class AnyNumericSegment extends Segment {
    private final Caster caster;
    private final boolean decimal;

    AnyNumericSegment(final Caster caster, final boolean decimal) {
        this.caster = caster;
        this.decimal = decimal;
    }

    /**
     * @return the numeric promotion of the operands, or {@code null} when neither operand
     *         is numeric or no numeric promotion exists
     */
    @Override
    Type promote(final ParserRuleContext source, final Type from0, final Type from1) {
        final boolean anyNumeric = from0.metadata.numeric || (from1 != null && from1.metadata.numeric);

        if (!anyNumeric) {
            return null;
        }

        try {
            return caster.getNumericPromotion(source, from0, from1, decimal);
        } catch (ClassCastException ignored) {
            // No numeric promotion available, so let the next segment try.
            return null;
        }
    }
}
/**
 * Promotion segment that always attempts a numeric promotion of the operands,
 * regardless of whether either operand is numeric to begin with.
 */
private static class ToNumericSegment extends Segment {
    private final Caster caster;
    private final boolean decimal;

    ToNumericSegment(final Caster caster, final boolean decimal) {
        this.caster = caster;
        this.decimal = decimal;
    }

    /**
     * @return the numeric promotion of the operands, or {@code null} when the caster
     *         reports that none exists
     */
    @Override
    Type promote(final ParserRuleContext source, final Type from0, final Type from1) {
        Type promoted = null;

        try {
            promoted = caster.getNumericPromotion(source, from0, from1, decimal);
        } catch (ClassCastException ignored) {
            // Leave the result as null so the caller can move on.
        }

        return promoted;
    }
}
/**
 * Promotion segment driven by the upcasts registered in the {@link Definition}: when one
 * operand type is registered as upcastable to the other, the upcast target is chosen.
 */
private static class ToSuperClassSegment extends Segment {
    final Definition definition;

    ToSuperClassSegment(final Definition definition) {
        this.definition = definition;
    }

    /**
     * @return {@code from0} when both types are equal; otherwise the operand type that the
     *         other one upcasts to according to {@code definition.upcasts}; {@code null}
     *         when no such upcast is registered in either direction
     */
    @Override
    Type promote(final ParserRuleContext source, final Type from0, final Type from1) {
        if (from0.equals(from1)) {
            return from0;
        }

        final Cast firstToSecond = new Cast(from0, from1);
        final Cast secondToFirst = new Cast(from1, from0);
        final Type promoted;

        if (definition.upcasts.contains(firstToSecond)) {
            promoted = from1;
        } else if (definition.upcasts.contains(secondToFirst)) {
            promoted = from0;
        } else {
            promoted = null;
        }

        return promoted;
    }
}
/**
 * Promotion segment that resolves two types through their Java classes rather than
 * through the upcasts registered in the {@link Definition}.
 *
 * NOTE(review): despite the name, the class-hierarchy branch returns the more general of
 * the two types (the one whose class is assignable from the other) - confirm intent.
 */
private static class ToSubClassSegment extends Segment {
    final Definition definition;
    final Standard standard;

    ToSubClassSegment(final Definition definition, final Standard standard) {
        this.definition = definition;
        this.standard = standard;
    }

    /**
     * @return {@code from0} for equal types; for types sharing one Java class, the operand
     *         whose struct is not generic (or the standard object type when both or neither
     *         are generic); for two object types, the operand whose class is assignable from
     *         the other's, falling back to the standard object type; otherwise {@code null}
     */
    @Override
    Type promote(final ParserRuleContext source, final Type from0, final Type from1) {
        if (from0.equals(from1)) {
            return from0;
        }

        if (from0.clazz.equals(from1.clazz)) {
            final boolean generic0 = from0.struct.generic;
            final boolean generic1 = from1.struct.generic;

            if (generic0 == generic1) {
                return standard.objectType;
            }

            // Exactly one side is generic: pick the non-generic one.
            return generic0 ? from1 : from0;
        }

        if (!(from0.metadata.object && from1.metadata.object)) {
            return null;
        }

        if (from1.clazz.isAssignableFrom(from0.clazz)) {
            return from1;
        }

        if (from0.clazz.isAssignableFrom(from1.clazz)) {
            return from0;
        }

        // Unrelated object types only have the standard object type in common.
        return standard.objectType;
    }
}
/**
 * An ordered, immutable chain of promotion segments. The segments are consulted in list
 * order and the first one that produces a type wins.
 */
static class Promotion {
    private final List<Segment> segments;

    /**
     * @param segments the segments in the order they should be tried; the list is copied,
     *                 so later changes made by the caller do not affect this promotion
     */
    Promotion(final List<Segment> segments) {
        // Copy before wrapping: unmodifiableList alone is only a view over the caller's list.
        this.segments = Collections.unmodifiableList(new ArrayList<>(segments));
    }
}
// Registry consulted for explicit/implicit transforms, upcasts and numeric casts.
private final Definition definition;
// Supplies the standard types (bool, int, long, float, double, object) used by the promotions.
private final Standard standard;
// Chain: same type, then bool, then decimal-capable numeric, then registered upcasts, then class hierarchy.
final Promotion concat;
// Chain: bool, then decimal-capable numeric, then registered upcasts, then class hierarchy.
final Promotion equality;
// Chain: numeric promotion that may resolve to float/double.
final Promotion decimal;
// Chain: numeric promotion restricted to int/long.
final Promotion numeric;
// Chain: promote to int, otherwise to the standard object type.
final Promotion shortcut;
/**
 * Creates a caster and builds the fixed promotion chains it exposes.
 *
 * @param definition registry of known transforms, upcasts and numeric casts
 * @param standard provider of the standard types referenced by the promotion chains
 */
Caster(final Definition definition, final Standard standard) {
    this.definition = definition;
    this.standard = standard;

    final List<Segment> concatChain = new ArrayList<>();
    concatChain.add(new SameTypeSegment());
    concatChain.add(new AnyTypeSegment(this, standard.boolType));
    concatChain.add(new AnyNumericSegment(this, true));
    concatChain.add(new ToSuperClassSegment(definition));
    concatChain.add(new ToSubClassSegment(definition, standard));
    this.concat = new Promotion(concatChain);

    final List<Segment> equalityChain = new ArrayList<>();
    equalityChain.add(new AnyTypeSegment(this, standard.boolType));
    equalityChain.add(new AnyNumericSegment(this, true));
    equalityChain.add(new ToSuperClassSegment(definition));
    equalityChain.add(new ToSubClassSegment(definition, standard));
    this.equality = new Promotion(equalityChain);

    final List<Segment> decimalChain = new ArrayList<>();
    decimalChain.add(new ToNumericSegment(this, true));
    this.decimal = new Promotion(decimalChain);

    final List<Segment> numericChain = new ArrayList<>();
    numericChain.add(new ToNumericSegment(this, false));
    this.numeric = new Promotion(numericChain);

    final List<Segment> shortcutChain = new ArrayList<>();
    shortcutChain.add(new ToTypeSegment(this, standard.intType));
    shortcutChain.add(new ToTypeSegment(this, standard.objectType));
    this.shortcut = new Promotion(shortcutChain);
}
/**
 * Resolves the cast for an expression whose target type is known and stores it on the
 * metadata; when the expression also carries a pre-cast constant and the target type is
 * flagged as constant, the constant is converted too and stored as {@code postConst}.
 *
 * @param emd the expression metadata to update in place
 * @throws IllegalStateException if {@code emd.from} is null, or if neither a target type
 *         nor a promotion has been specified
 * @throws ClassCastException if no legal cast exists from {@code emd.from} to {@code emd.to}
 */
void markCast(final ExpressionMetadata emd) {
    if (emd.from == null) {
        throw new IllegalStateException(error(emd.source) + "From cast type should never be null.");
    }

    if (emd.to == null) {
        // Without a target type there is nothing to resolve here; a promotion must be pending.
        if (emd.promotion == null) {
            throw new IllegalStateException(error(emd.source) + "No cast or promotion specified.");
        }

        return;
    }

    emd.cast = getLegalCast(emd.source, emd.from, emd.to, emd.explicit);

    final boolean foldConstant = emd.preConst != null && emd.to.metadata.constant;

    if (foldConstant) {
        emd.postConst = constCast(emd.source, emd.preConst, emd.cast);
    }
}
/**
 * Finds the cast to use when converting a value of type {@code from} to type {@code to}.
 *
 * Candidates are considered in this order: identical types, a registered explicit
 * transform (forced casts only), a registered implicit transform, a registered upcast,
 * a numeric conversion (any numeric pair when forced, otherwise only registered ones),
 * a Java-level upcast, and finally a Java-level downcast (forced casts only).
 *
 * @param source parse-tree node used to build the error message
 * @param from the type being cast from
 * @param to the type being cast to
 * @param force {@code true} for an explicit cast, which additionally permits explicit
 *              transforms, arbitrary numeric conversions and downcasts
 * @return the {@link Cast} (possibly a {@link Transform}) describing the conversion
 * @throws ClassCastException if no legal cast exists
 */
Cast getLegalCast(final ParserRuleContext source, final Type from, final Type to, final boolean force) {
    final Cast cast = new Cast(from, to);

    if (from.equals(to)) {
        return cast;
    }

    final Transform explicit = definition.explicits.get(cast);

    if (force && explicit != null) {
        return explicit;
    }

    final Transform implicit = definition.implicits.get(cast);

    if (implicit != null) {
        return implicit;
    }

    if (definition.upcasts.contains(cast)) {
        return cast;
    }

    if (from.metadata.numeric && to.metadata.numeric && (force || definition.numerics.contains(cast))) {
        return cast;
    }

    // Java-level upcast: 'from' is the same class as, or a subtype of, 'to'.
    if (to.clazz.isAssignableFrom(from.clazz)) {
        return cast;
    }

    // Java-level downcast is only permitted for forced (explicit) casts.
    if (force && from.clazz.isAssignableFrom(to.clazz)) {
        return cast;
    }

    throw new ClassCastException(
        error(source) + "Cannot cast from [" + from.name + "] to [" + to.name + "].");
}
/**
 * Applies a cast to a compile-time constant and returns the converted value.
 *
 * @param source parse-tree node used to build error messages
 * @param constant the constant value before the cast
 * @param cast the cast to apply; a {@link Transform} is executed through its method,
 *             any other cast must be an identity or a numeric conversion
 * @return the constant after conversion (boxed)
 * @throws IllegalStateException if the cast is neither a transform, an identity, nor a
 *         conversion between numeric types
 */
Object constCast(final ParserRuleContext source, final Object constant, final Cast cast) {
    if (cast instanceof Transform) {
        return invokeTransform(source, (Transform)cast, constant);
    }

    final TypeMetadata fromMeta = cast.from.metadata;
    final TypeMetadata toMeta = cast.to.metadata;

    if (fromMeta == toMeta) {
        return constant;
    }

    if (!(fromMeta.numeric && toMeta.numeric)) {
        throw new IllegalStateException(error(source) + "No valid constant cast from " +
            "[" + cast.from.clazz.getCanonicalName() + "] to " +
            "[" + cast.to.clazz.getCanonicalName() + "].");
    }

    // A char constant is not a Number, so go through its int value first.
    final Number number;

    if (fromMeta == TypeMetadata.CHAR) {
        number = (int)(char)constant;
    } else {
        number = (Number)constant;
    }

    switch (toMeta) {
        case BYTE:
            return number.byteValue();
        case SHORT:
            return number.shortValue();
        case CHAR:
            return (char)number.intValue();
        case INT:
            return number.intValue();
        case LONG:
            return number.longValue();
        case FLOAT:
            return number.floatValue();
        case DOUBLE:
            return number.doubleValue();
        default:
            throw new IllegalStateException(error(source) + "Expected numeric type for cast.");
    }
}
/**
 * Runs a transform's method reflectively on a constant value.
 *
 * A static transform method receives the value as its single argument; an instance
 * method is invoked on the value itself with no arguments.
 *
 * @param source parse-tree node used to build the error message
 * @param transform the transform whose method should be invoked
 * @param object the constant to convert
 * @return whatever the transform method returns
 * @throws IllegalStateException if the reflective call fails; the original failure is
 *         attached as the cause
 */
private Object invokeTransform(final ParserRuleContext source, final Transform transform, final Object object) {
    final Method method = transform.method;
    final java.lang.reflect.Method jmethod = method.method;
    final int modifiers = jmethod.getModifiers();

    try {
        if (java.lang.reflect.Modifier.isStatic(modifiers)) {
            return jmethod.invoke(null, object);
        } else {
            return jmethod.invoke(object);
        }
    } catch (IllegalAccessException | IllegalArgumentException |
        java.lang.reflect.InvocationTargetException | NullPointerException |
        ExceptionInInitializerError exception) {
        // Keep the underlying failure as the cause so the real reason is not lost.
        throw new IllegalStateException(error(source) + "Unable to invoke transform to cast constant from " +
            "[" + transform.from.name + "] to [" + transform.to.name + "].", exception);
    }
}
/**
 * Finds the promoted type for one or two operand types by trying each segment of the
 * given promotion in order and returning the first non-null result.
 *
 * @param source parse-tree node used to build the error message
 * @param from0 the first operand type
 * @param from1 the second operand type, or {@code null} when there is only one operand
 * @param promotion the chain of segments to consult
 * @return the promoted type
 * @throws ClassCastException if no segment yields a promotion
 */
Type getTypePromotion(final ParserRuleContext source, final Type from0, final Type from1, final Promotion promotion) {
    for (final Segment segment : promotion.segments) {
        final Type type = segment.promote(source, from0, from1);

        if (type != null) {
            return type;
        }
    }

    // from1 may legitimately be null (single operand); do not dereference it in the message.
    if (from1 == null) {
        throw new ClassCastException(error(source) + "Cannot find valid promotion for type [" +
            from0.name + "].");
    }

    throw new ClassCastException(error(source) + "Cannot find valid promotion for types [" +
        from0.name + "] and [" + from1.name + "].");
}
/**
 * Finds the smallest standard numeric type that one or two operand types can be promoted
 * to. Candidates are tried in the order int, long and, when {@code decimal} is set,
 * float and double.
 *
 * @param source parse-tree node used to build the error message
 * @param from0 the first operand type
 * @param from1 the second operand type, or {@code null} when there is only one operand
 * @param decimal whether float and double are acceptable promotion targets
 * @return the first candidate type that every operand can be promoted to
 * @throws ClassCastException if no candidate accepts all operands
 */
Type getNumericPromotion(final ParserRuleContext source, final Type from0, final Type from1, boolean decimal) {
    final Deque<Type> candidates = new ArrayDeque<>();
    final Deque<Type> excluded = new ArrayDeque<>();

    // The deques are used as stacks, so the type pushed last is the one tried first.
    final Deque<Type> floating = decimal ? candidates : excluded;
    floating.push(standard.doubleType);
    floating.push(standard.floatType);
    candidates.push(standard.longType);
    candidates.push(standard.intType);

    while (!candidates.isEmpty()) {
        final Type to = candidates.pop();

        if (canPromoteTo(from0, to, candidates, excluded) &&
            (from1 == null || canPromoteTo(from1, to, candidates, excluded))) {
            return to;
        }
    }

    if (from1 != null) {
        throw new ClassCastException(error(source) + "Unable to find numeric promotion for types" +
            " [" + from0.name + "] and [" + from1.name + "].");
    }

    throw new ClassCastException(
        error(source) + "Unable to find numeric promotion for type [" + from0.name + "].");
}

/**
 * Decides whether a single operand type may be promoted to the candidate type.
 *
 * @param from the operand type
 * @param to the candidate promotion target
 * @param remaining candidates not yet tried; an operand that is itself one of them rules
 *                  out the current candidate
 * @param excluded standard types that are not promotion targets in this run
 */
private boolean canPromoteTo(final Type from, final Type to, final Deque<Type> remaining, final Deque<Type> excluded) {
    final Cast cast = new Cast(from, to);

    if (from.metadata.numeric && from.metadata != to.metadata &&
        !definition.numerics.contains(cast)) {
        return false;
    }

    if (remaining.contains(from)) {
        return false;
    }

    if (excluded.contains(from) && !definition.numerics.contains(cast) &&
        !definition.implicits.containsKey(cast)) {
        return false;
    }

    if (!from.metadata.numeric && !definition.implicits.containsKey(cast)) {
        return false;
    }

    return true;
}
/**
 * Emits the bytecode for the cast stored on the given expression metadata.
 *
 * @param visitor the method visitor receiving the instructions
 * @param metadata the expression whose {@code cast} should be written
 * @throws IllegalStateException if the metadata carries no cast
 */
void checkWriteCast(final MethodVisitor visitor, final ExpressionMetadata metadata) {
    final ParserRuleContext source = metadata.source;
    final Cast cast = metadata.cast;

    checkWriteCast(visitor, source, cast);
}

/**
 * Emits the bytecode for a cast, dispatching to the transform writer or the plain cast
 * writer depending on the kind of cast.
 *
 * @param visitor the method visitor receiving the instructions
 * @param source parse-tree node used to build the error message
 * @param cast the cast to write; must not be {@code null}
 * @throws IllegalStateException if {@code cast} is {@code null}
 */
void checkWriteCast(final MethodVisitor visitor, final ParserRuleContext source, final Cast cast) {
    if (cast == null) {
        throw new IllegalStateException(error(source) + "Unexpected cast object.");
    }

    if (cast instanceof Transform) {
        writeTransform(visitor, (Transform)cast);
    } else {
        writeCast(visitor, cast);
    }
}
/**
 * Emits the bytecode for a plain (non-transform) cast.
 *
 * Equal types emit nothing. Numeric pairs emit the matching primitive conversion
 * instruction(s). Any other pair emits a {@code CHECKCAST} unless the source class is
 * already assignable to the target class.
 *
 * @param visitor the method visitor receiving the instructions
 * @param cast the cast to write
 */
void writeCast(final MethodVisitor visitor, final Cast cast) {
    final Type source = cast.from;
    final Type target = cast.to;

    if (source.equals(target)) {
        return;
    }

    if (!(source.metadata.numeric && target.metadata.numeric)) {
        // Reference cast: only needed when the target is not already a supertype.
        if (!target.clazz.isAssignableFrom(source.clazz)) {
            visitor.visitTypeInsn(Opcodes.CHECKCAST, target.internal);
        }

        return;
    }

    final TypeMetadata fromMeta = source.metadata;
    final TypeMetadata toMeta = target.metadata;

    switch (fromMeta) {
        case BYTE:
        case SHORT:
        case CHAR:
        case INT:
            // All of these are handled as an int value; nothing to do for the same kind.
            if (fromMeta != toMeta) {
                writeIntCast(visitor, toMeta);
            }
            break;
        case LONG:
            switch (toMeta) {
                case BYTE:
                case SHORT:
                case CHAR:
                case INT:
                    visitor.visitInsn(Opcodes.L2I);
                    writeIntCast(visitor, toMeta);
                    break;
                case FLOAT:
                    visitor.visitInsn(Opcodes.L2F);
                    break;
                case DOUBLE:
                    visitor.visitInsn(Opcodes.L2D);
                    break;
                default:
                    break;
            }
            break;
        case FLOAT:
            switch (toMeta) {
                case BYTE:
                case SHORT:
                case CHAR:
                case INT:
                    visitor.visitInsn(Opcodes.F2I);
                    writeIntCast(visitor, toMeta);
                    break;
                case LONG:
                    visitor.visitInsn(Opcodes.F2L);
                    break;
                case DOUBLE:
                    visitor.visitInsn(Opcodes.F2D);
                    break;
                default:
                    break;
            }
            break;
        case DOUBLE:
            switch (toMeta) {
                case BYTE:
                case SHORT:
                case CHAR:
                case INT:
                    visitor.visitInsn(Opcodes.D2I);
                    writeIntCast(visitor, toMeta);
                    break;
                case LONG:
                    visitor.visitInsn(Opcodes.D2L);
                    break;
                case FLOAT:
                    visitor.visitInsn(Opcodes.D2F);
                    break;
                default:
                    break;
            }
            break;
        default:
            break;
    }
}

/**
 * Emits the single instruction converting an int value to the given numeric kind;
 * emits nothing when the target kind is int itself (or not one of the kinds listed).
 */
private static void writeIntCast(final MethodVisitor visitor, final TypeMetadata toMeta) {
    switch (toMeta) {
        case BYTE:
            visitor.visitInsn(Opcodes.I2B);
            break;
        case SHORT:
            visitor.visitInsn(Opcodes.I2S);
            break;
        case CHAR:
            visitor.visitInsn(Opcodes.I2C);
            break;
        case LONG:
            visitor.visitInsn(Opcodes.I2L);
            break;
        case FLOAT:
            visitor.visitInsn(Opcodes.I2F);
            break;
        case DOUBLE:
            visitor.visitInsn(Opcodes.I2D);
            break;
        default:
            break;
    }
}
/**
 * Emits the bytecode that applies a transform: an optional {@code CHECKCAST} to the
 * transform's upcast type, the call to the transform method, and an optional
 * {@code CHECKCAST} to its downcast type.
 *
 * @param visitor the method visitor receiving the instructions
 * @param transform the transform to write
 */
void writeTransform(final MethodVisitor visitor, final Transform transform) {
    final Class<?> ownerClass = transform.method.owner.clazz;
    final java.lang.reflect.Method reflected = transform.method.method;
    final String methodName = reflected.getName();
    final String ownerInternal = transform.method.owner.internal;
    final String methodDescriptor = transform.method.descriptor;
    final Type before = transform.upcast;
    final Type after = transform.downcast;

    if (before != null) {
        visitor.visitTypeInsn(Opcodes.CHECKCAST, before.internal);
    }

    // Pick the invocation instruction from how the method and its owner are declared.
    final int opcode;
    final boolean ownerIsInterface;

    if (java.lang.reflect.Modifier.isStatic(reflected.getModifiers())) {
        opcode = Opcodes.INVOKESTATIC;
        ownerIsInterface = false;
    } else if (java.lang.reflect.Modifier.isInterface(ownerClass.getModifiers())) {
        opcode = Opcodes.INVOKEINTERFACE;
        ownerIsInterface = true;
    } else {
        opcode = Opcodes.INVOKEVIRTUAL;
        ownerIsInterface = false;
    }

    visitor.visitMethodInsn(opcode, ownerInternal, methodName, methodDescriptor, ownerIsInterface);

    if (after != null) {
        visitor.visitTypeInsn(Opcodes.CHECKCAST, after.internal);
    }
}
}