/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.hadoop.hive.ql.udf.generic; import org.apache.hadoop.hive.ql.exec.ExprNodeEvaluator; import org.apache.hadoop.hive.ql.exec.PTFPartition.PTFPartitionIterator; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.serde2.objectinspector.ConstantObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.io.IntWritable; public abstract class GenericUDFLeadLag extends GenericUDF { transient ExprNodeEvaluator exprEvaluator; transient PTFPartitionIterator<Object> pItr; transient ObjectInspector firstArgOI; transient Converter defaultValueConverter; int amt; @Override public Object evaluate(DeferredObject[] arguments) throws HiveException { Object defaultVal = null; if (arguments.length == 3) { defaultVal = ObjectInspectorUtils.copyToStandardObject( defaultValueConverter.convert(arguments[2].get()), firstArgOI); } int idx = pItr.getIndex() - 1; int start = 0; int end = pItr.getPartition().size(); try { Object ret = null; int newIdx = getIndex(amt); if (newIdx >= end || newIdx < start) { ret = defaultVal; } else { Object row = getRow(amt); ret = exprEvaluator.evaluate(row); ret = ObjectInspectorUtils.copyToStandardObject(ret, firstArgOI, ObjectInspectorCopyOption.WRITABLE); } return ret; } finally { Object currRow = pItr.resetToIndex(idx); // reevaluate expression on current Row, to trigger the Lazy object // caches to be reset to the current row. exprEvaluator.evaluate(currRow); } } @Override public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException { if (!(arguments.length >= 1 && arguments.length <= 3)) { throw new UDFArgumentTypeException(arguments.length - 1, "Incorrect invocation of " + _getFnName() + ": _FUNC_(expr, amt, default)"); } amt = 1; if (arguments.length > 1) { ObjectInspector amtOI = arguments[1]; if (!ObjectInspectorUtils.isConstantObjectInspector(amtOI) || (amtOI.getCategory() != ObjectInspector.Category.PRIMITIVE) || ((PrimitiveObjectInspector) amtOI).getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.INT) { throw new UDFArgumentTypeException(1, _getFnName() + " amount must be a integer value " + amtOI.getTypeName() + " was passed as parameter 1."); } Object o = ((ConstantObjectInspector) amtOI).getWritableConstantValue(); amt = ((IntWritable) o).get(); if (amt < 0) { throw new UDFArgumentTypeException(1, " amount can not be nagative. Specified: " + amt); } } if (arguments.length == 3) { defaultValueConverter = ObjectInspectorConverters.getConverter(arguments[2], arguments[0]); } firstArgOI = arguments[0]; return ObjectInspectorUtils.getStandardObjectInspector(firstArgOI, ObjectInspectorCopyOption.WRITABLE); } public ExprNodeEvaluator getExprEvaluator() { return exprEvaluator; } public void setExprEvaluator(ExprNodeEvaluator exprEvaluator) { this.exprEvaluator = exprEvaluator; } public PTFPartitionIterator<Object> getpItr() { return pItr; } public void setpItr(PTFPartitionIterator<Object> pItr) { this.pItr = pItr; } public ObjectInspector getFirstArgOI() { return firstArgOI; } public void setFirstArgOI(ObjectInspector firstArgOI) { this.firstArgOI = firstArgOI; } public Converter getDefaultValueConverter() { return defaultValueConverter; } public void setDefaultValueConverter(Converter defaultValueConverter) { this.defaultValueConverter = defaultValueConverter; } public int getAmt() { return amt; } public void setAmt(int amt) { this.amt = amt; } @Override public String getDisplayString(String[] children) { assert (children.length == 2); return getStandardDisplayString(_getFnName(), children); } protected abstract String _getFnName(); protected abstract Object getRow(int amt) throws HiveException; protected abstract int getIndex(int amt); }