package com.facebook.presto.operator.scalar; /* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.block.Block; import com.facebook.presto.spi.type.Type; import com.google.common.base.Throwables; import java.lang.invoke.MethodHandle; import java.util.LinkedHashMap; import java.util.Map; import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.type.TypeUtils.readNativeValue; public final class MapGenericEquality { private MapGenericEquality() {} public interface EqualityPredicate { Boolean equals(int leftMapIndex, int rightMapIndex) throws Throwable; } public static Boolean genericEqual( MethodHandle keyEqualsFunction, MethodHandle keyHashcodeFunction, Type keyType, Block leftMapBlock, Block rightMapBlock, EqualityPredicate predicate) { Map<KeyWrapper, Integer> wrappedLeftMap = new LinkedHashMap<>(); for (int position = 0; position < leftMapBlock.getPositionCount(); position += 2) { wrappedLeftMap.put(new KeyWrapper(readNativeValue(keyType, leftMapBlock, position), keyEqualsFunction, keyHashcodeFunction), position + 1); } Map<KeyWrapper, Integer> wrappedRightMap = new LinkedHashMap<>(); for (int position = 0; position < rightMapBlock.getPositionCount(); position += 2) { wrappedRightMap.put(new KeyWrapper(readNativeValue(keyType, rightMapBlock, position), keyEqualsFunction, keyHashcodeFunction), position + 1); } if (wrappedLeftMap.size() != wrappedRightMap.size()) { return false; } for (Map.Entry<KeyWrapper, Integer> entry : wrappedRightMap.entrySet()) { KeyWrapper key = entry.getKey(); Integer leftValuePosition = wrappedLeftMap.get(key); if (leftValuePosition == null) { return false; } try { Boolean result = predicate.equals(leftValuePosition, entry.getValue()); if (result == null) { return null; } else if (!result) { return false; } } catch (Throwable t) { Throwables.propagateIfInstanceOf(t, Error.class); Throwables.propagateIfInstanceOf(t, PrestoException.class); throw new PrestoException(GENERIC_INTERNAL_ERROR, t); } } return true; } private static final class KeyWrapper { private final Object key; private final MethodHandle hashCode; private final MethodHandle equals; public KeyWrapper(Object key, MethodHandle equals, MethodHandle hashCode) { this.key = key; this.equals = equals; this.hashCode = hashCode; } @Override public int hashCode() { try { return Long.hashCode((long) hashCode.invoke(key)); } catch (Throwable t) { Throwables.propagateIfInstanceOf(t, Error.class); Throwables.propagateIfInstanceOf(t, PrestoException.class); throw new PrestoException(GENERIC_INTERNAL_ERROR, t); } } @Override public boolean equals(Object obj) { if (obj == null || !getClass().equals(obj.getClass())) { return false; } KeyWrapper other = (KeyWrapper) obj; try { return (boolean) equals.invoke(key, other.key); } catch (Throwable t) { Throwables.propagateIfInstanceOf(t, Error.class); Throwables.propagateIfInstanceOf(t, PrestoException.class); throw new PrestoException(GENERIC_INTERNAL_ERROR, t); } } } }