/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package hivemall.tools.array; import hivemall.utils.hadoop.HiveUtils; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Set; import javax.annotation.Nonnull; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.udf.UDFType; import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; @Description(name = "array_intersect", value = "_FUNC_(array<ANY> x1, array<ANY> x2, ..) - Returns an intersect of given arrays") @UDFType(deterministic = true, stateful = false) public final class ArrayIntersectUDF extends GenericUDF { private ListObjectInspector[] argListOIs; private List<Object> result; public ArrayIntersectUDF() {} @Override public ObjectInspector initialize(@Nonnull ObjectInspector[] argOIs) throws UDFArgumentException { final int argLength = argOIs.length; if (argLength < 2) { throw new UDFArgumentLengthException("Expecting at least two arrays as arguments: " + argLength); } ListObjectInspector[] argListOIs = new ListObjectInspector[argLength]; ListObjectInspector arg0ListOI = HiveUtils.asListOI(argOIs[0]); ObjectInspector arg0ElemOI = arg0ListOI.getListElementObjectInspector(); argListOIs[0] = arg0ListOI; for (int i = 1; i < argLength; i++) { ListObjectInspector listOI = HiveUtils.asListOI(argOIs[i]); if (!ObjectInspectorUtils.compareTypes(listOI.getListElementObjectInspector(), arg0ElemOI)) { throw new UDFArgumentException("Array types does not match: " + arg0ElemOI.getTypeName() + " != " + listOI.getListElementObjectInspector().getTypeName()); } argListOIs[i] = listOI; } this.argListOIs = argListOIs; this.result = new ArrayList<Object>(); return ObjectInspectorUtils.getStandardObjectInspector(arg0ListOI); } @Override public List<Object> evaluate(@Nonnull DeferredObject[] args) throws HiveException { result.clear(); final Object arg0 = args[0].get(); if (arg0 == null) { return Collections.emptyList(); } Set<InspectableObject> checkSet = new HashSet<ArrayIntersectUDF.InspectableObject>(); final ListObjectInspector arg0ListOI = argListOIs[0]; final ObjectInspector arg0ElemOI = arg0ListOI.getListElementObjectInspector(); final int arg0size = arg0ListOI.getListLength(arg0); for (int i = 0; i < arg0size; i++) { Object o = arg0ListOI.getListElement(arg0, i); if (o == null) { continue; } checkSet.add(new InspectableObject(o, arg0ElemOI)); } final InspectableObject probe = new InspectableObject(); for (int i = 1, numArgs = args.length; i < numArgs; i++) { final Object argI = args[i].get(); if (argI == null) { continue; } final Set<InspectableObject> newSet = new HashSet<ArrayIntersectUDF.InspectableObject>(); final ListObjectInspector argIListOI = argListOIs[i]; final ObjectInspector argIElemOI = argIListOI.getListElementObjectInspector(); for (int j = 0, j_size = argIListOI.getListLength(argI); j < j_size; j++) { Object o = argIListOI.getListElement(argI, j); if (o == null) { continue; } probe.set(o, argIElemOI); if (checkSet.contains(probe)) { newSet.add(probe.copy()); } } checkSet = newSet; } for (InspectableObject inspect : checkSet) { Object obj = ObjectInspectorUtils.copyToStandardObject(inspect.o, inspect.oi); result.add(obj); } return result; } @Override public String getDisplayString(String[] args) { return "array_intersect(" + Arrays.toString(args) + ")"; } private static final class InspectableObject implements Comparable<InspectableObject> { public Object o; public ObjectInspector oi; InspectableObject() {} InspectableObject(@Nonnull Object o, @Nonnull ObjectInspector oi) { this.o = o; this.oi = oi; } void set(@Nonnull Object o, @Nonnull ObjectInspector oi) { this.o = o; this.oi = oi; } InspectableObject copy() { return new InspectableObject(o, oi); } @Override public int hashCode() { return ObjectInspectorUtils.hashCode(o, oi); } @Override public int compareTo(InspectableObject otherOI) { return ObjectInspectorUtils.compare(o, oi, otherOI.o, otherOI.oi); } @Override public boolean equals(Object other) { if (!(other instanceof InspectableObject)) { return false; } return compareTo((InspectableObject) other) == 0; } } }