/** * Licensed to the Apache Software Foundation (ASF) under one * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.hadoop.hive.ql.udf.generic; import java.math.BigDecimal; import java.sql.Timestamp; import java.text.ParseException; import java.text.SimpleDateFormat; import java.util.Calendar; import java.util.Date; import org.apache.hadoop.hive.common.type.HiveDecimal; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.serde2.io.ByteWritable; import org.apache.hadoop.hive.serde2.io.DateWritable; import org.apache.hadoop.hive.serde2.io.DoubleWritable; import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; import org.apache.hadoop.hive.serde2.io.ShortWritable; import org.apache.hadoop.hive.serde2.io.TimestampWritable; import org.apache.hadoop.hive.serde2.objectinspector.ConstantObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorConverter.TimestampConverter; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils.PrimitiveGrouping; import org.apache.hadoop.io.FloatWritable; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; /** * GenericUDFTrunc. * * Returns the first day of the month which the date belongs to. The time part of the date will be * ignored. * */ @Description(name = "trunc", value = "_FUNC_(date, fmt) / _FUNC_(N,D) - Returns If input is date returns date with the time portion of the day truncated " + "to the unit specified by the format model fmt. If you omit fmt, then date is truncated to " + "the nearest day. It currently only supports 'MONTH'/'MON'/'MM', 'QUARTER'/'Q' and 'YEAR'/'YYYY'/'YY' as format." + "If input is a number group returns N truncated to D decimal places. If D is omitted, then N is truncated to 0 places." + "D can be negative to truncate (make zero) D digits left of the decimal point." , extended = "date is a string in the format 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'." + " The time part of date is ignored.\n" + "Example:\n " + " > SELECT _FUNC_('2009-02-12', 'MM');\n" + "OK\n" + " '2009-02-01'" + "\n" + " > SELECT _FUNC_('2017-03-15', 'Q');\n" + "OK\n" + " '2017-01-01'" + "\n" + " > SELECT _FUNC_('2015-10-27', 'YEAR');\n" + "OK\n" + " '2015-01-01'" + " > SELECT _FUNC_(1234567891.1234567891,4);\n" + "OK\n" + " 1234567891.1234" + "\n" + " > SELECT _FUNC_(1234567891.1234567891,-4);\n" + "OK\n" + " 1234560000" + " > SELECT _FUNC_(1234567891.1234567891,0);\n" + "OK\n" + " 1234567891" + "\n" + " > SELECT _FUNC_(1234567891.1234567891);\n" + "OK\n" + " 1234567891") public class GenericUDFTrunc extends GenericUDF { private transient SimpleDateFormat formatter = new SimpleDateFormat("yyyy-MM-dd"); private transient TimestampConverter timestampConverter; private transient Converter textConverter1; private transient Converter textConverter2; private transient Converter dateWritableConverter; private transient Converter byteConverter; private transient Converter shortConverter; private transient Converter intConverter; private transient Converter longConverter; private transient PrimitiveCategory inputType1; private transient PrimitiveCategory inputType2; private final Calendar calendar = Calendar.getInstance(); private final Text output = new Text(); private transient String fmtInput; private transient PrimitiveObjectInspector inputOI; private transient PrimitiveObjectInspector inputScaleOI; private int scale = 0; private boolean inputSacleConst; private boolean dateTypeArg; @Override public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException { if (arguments.length == 2) { inputType1 = ((PrimitiveObjectInspector) arguments[0]).getPrimitiveCategory(); inputType2 = ((PrimitiveObjectInspector) arguments[1]).getPrimitiveCategory(); if ((PrimitiveObjectInspectorUtils .getPrimitiveGrouping(inputType1) == PrimitiveGrouping.DATE_GROUP || PrimitiveObjectInspectorUtils .getPrimitiveGrouping(inputType1) == PrimitiveGrouping.STRING_GROUP) && PrimitiveObjectInspectorUtils .getPrimitiveGrouping(inputType2) == PrimitiveGrouping.STRING_GROUP) { dateTypeArg = true; return initializeDate(arguments); } else if (PrimitiveObjectInspectorUtils .getPrimitiveGrouping(inputType1) == PrimitiveGrouping.NUMERIC_GROUP && PrimitiveObjectInspectorUtils .getPrimitiveGrouping(inputType2) == PrimitiveGrouping.NUMERIC_GROUP) { dateTypeArg = false; return initializeNumber(arguments); } throw new UDFArgumentException("Got wrong argument types : first argument type : " + arguments[0].getTypeName() + ", second argument type : " + arguments[1].getTypeName()); } else if (arguments.length == 1) { inputType1 = ((PrimitiveObjectInspector) arguments[0]).getPrimitiveCategory(); if (PrimitiveObjectInspectorUtils .getPrimitiveGrouping(inputType1) == PrimitiveGrouping.NUMERIC_GROUP) { dateTypeArg = false; return initializeNumber(arguments); } else { throw new UDFArgumentException( "Only primitive type arguments are accepted, when arguments length is one, got " + arguments[1].getTypeName()); } } throw new UDFArgumentException("TRUNC requires one or two argument, got " + arguments.length); } private ObjectInspector initializeNumber(ObjectInspector[] arguments) throws UDFArgumentException { if (arguments.length < 1 || arguments.length > 2) { throw new UDFArgumentLengthException( "TRUNC requires one or two argument, got " + arguments.length); } if (arguments[0].getCategory() != Category.PRIMITIVE) { throw new UDFArgumentTypeException(0, "TRUNC input only takes primitive types, got " + arguments[0].getTypeName()); } inputOI = (PrimitiveObjectInspector) arguments[0]; if (arguments.length == 2) { if (arguments[1].getCategory() != Category.PRIMITIVE) { throw new UDFArgumentTypeException(1, "TRUNC second argument only takes primitive types, got " + arguments[1].getTypeName()); } inputScaleOI = (PrimitiveObjectInspector) arguments[1]; inputSacleConst = arguments[1] instanceof ConstantObjectInspector; if (inputSacleConst) { try { Object obj = ((ConstantObjectInspector) arguments[1]).getWritableConstantValue(); fmtInput = obj != null ? obj.toString() : null; scale = Integer.parseInt(fmtInput); } catch (Exception e) { throw new UDFArgumentException("TRUNC input only takes integer values, got " + fmtInput); } } else { switch (inputScaleOI.getPrimitiveCategory()) { case BYTE: byteConverter = ObjectInspectorConverters.getConverter(arguments[1], PrimitiveObjectInspectorFactory.writableByteObjectInspector); break; case SHORT: shortConverter = ObjectInspectorConverters.getConverter(arguments[1], PrimitiveObjectInspectorFactory.writableShortObjectInspector); break; case INT: intConverter = ObjectInspectorConverters.getConverter(arguments[1], PrimitiveObjectInspectorFactory.writableIntObjectInspector); break; case LONG: longConverter = ObjectInspectorConverters.getConverter(arguments[1], PrimitiveObjectInspectorFactory.writableLongObjectInspector); break; default: throw new UDFArgumentTypeException(1, getFuncName().toUpperCase() + " second argument only takes integer values"); } } } inputType1 = inputOI.getPrimitiveCategory(); ObjectInspector outputOI = null; switch (inputType1) { case DECIMAL: outputOI = PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(inputType1); break; case VOID: case BYTE: case SHORT: case INT: case LONG: case FLOAT: case DOUBLE: outputOI = PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(inputType1); break; default: throw new UDFArgumentTypeException(0, "Only numeric or string group data types are allowed for TRUNC function. Got " + inputType1.name()); } return outputOI; } private ObjectInspector initializeDate(ObjectInspector[] arguments) throws UDFArgumentLengthException, UDFArgumentTypeException { if (arguments.length != 2) { throw new UDFArgumentLengthException("trunc() requires 2 argument, got " + arguments.length); } if (arguments[0].getCategory() != ObjectInspector.Category.PRIMITIVE) { throw new UDFArgumentTypeException(0, "Only primitive type arguments are accepted but " + arguments[0].getTypeName() + " is passed. as first arguments"); } if (arguments[1].getCategory() != ObjectInspector.Category.PRIMITIVE) { throw new UDFArgumentTypeException(1, "Only primitive type arguments are accepted but " + arguments[1].getTypeName() + " is passed. as second arguments"); } ObjectInspector outputOI = PrimitiveObjectInspectorFactory.writableStringObjectInspector; inputType1 = ((PrimitiveObjectInspector) arguments[0]).getPrimitiveCategory(); switch (inputType1) { case STRING: case VARCHAR: case CHAR: case VOID: inputType1 = PrimitiveCategory.STRING; textConverter1 = ObjectInspectorConverters.getConverter(arguments[0], PrimitiveObjectInspectorFactory.writableStringObjectInspector); break; case TIMESTAMP: timestampConverter = new TimestampConverter((PrimitiveObjectInspector) arguments[0], PrimitiveObjectInspectorFactory.writableTimestampObjectInspector); break; case DATE: dateWritableConverter = ObjectInspectorConverters.getConverter(arguments[0], PrimitiveObjectInspectorFactory.writableDateObjectInspector); break; default: throw new UDFArgumentTypeException(0, "TRUNC() only takes STRING/TIMESTAMP/DATEWRITABLE types as first argument, got " + inputType1); } inputType2 = ((PrimitiveObjectInspector) arguments[1]).getPrimitiveCategory(); if (PrimitiveObjectInspectorUtils .getPrimitiveGrouping(inputType2) != PrimitiveGrouping.STRING_GROUP && PrimitiveObjectInspectorUtils .getPrimitiveGrouping(inputType2) != PrimitiveGrouping.VOID_GROUP) { throw new UDFArgumentTypeException(1, "trunk() only takes STRING/CHAR/VARCHAR types as second argument, got " + inputType2); } inputType2 = PrimitiveCategory.STRING; if (arguments[1] instanceof ConstantObjectInspector) { Object obj = ((ConstantObjectInspector) arguments[1]).getWritableConstantValue(); fmtInput = obj != null ? obj.toString() : null; } else { textConverter2 = ObjectInspectorConverters.getConverter(arguments[1], PrimitiveObjectInspectorFactory.writableStringObjectInspector); } return outputOI; } @Override public Object evaluate(DeferredObject[] arguments) throws HiveException { if (dateTypeArg) { return evaluateDate(arguments); } else { return evaluateNumber(arguments); } } private Object evaluateDate(DeferredObject[] arguments) throws UDFArgumentLengthException, HiveException, UDFArgumentTypeException, UDFArgumentException { if (arguments.length != 2) { throw new UDFArgumentLengthException("trunc() requires 2 argument, got " + arguments.length); } if (arguments[0].get() == null || arguments[1].get() == null) { return null; } if (textConverter2 != null) { fmtInput = textConverter2.convert(arguments[1].get()).toString(); } Date date; switch (inputType1) { case STRING: String dateString = textConverter1.convert(arguments[0].get()).toString(); try { date = formatter.parse(dateString.toString()); } catch (ParseException e) { return null; } break; case TIMESTAMP: Timestamp ts = ((TimestampWritable) timestampConverter.convert(arguments[0].get())).getTimestamp(); date = ts; break; case DATE: DateWritable dw = (DateWritable) dateWritableConverter.convert(arguments[0].get()); date = dw.get(); break; default: throw new UDFArgumentTypeException(0, "TRUNC() only takes STRING/TIMESTAMP/DATEWRITABLE types, got " + inputType1); } if (evalDate(date) == null) { return null; } Date newDate = calendar.getTime(); output.set(formatter.format(newDate)); return output; } private Object evaluateNumber(DeferredObject[] arguments) throws HiveException, UDFArgumentTypeException { if (arguments[0] == null) { return null; } Object input = arguments[0].get(); if (input == null) { return null; } if (arguments.length == 2 && arguments[1] != null && arguments[1].get() != null && !inputSacleConst) { Object scaleObj = null; switch (inputScaleOI.getPrimitiveCategory()) { case BYTE: scaleObj = byteConverter.convert(arguments[1].get()); scale = ((ByteWritable) scaleObj).get(); break; case SHORT: scaleObj = shortConverter.convert(arguments[1].get()); scale = ((ShortWritable) scaleObj).get(); break; case INT: scaleObj = intConverter.convert(arguments[1].get()); scale = ((IntWritable) scaleObj).get(); break; case LONG: scaleObj = longConverter.convert(arguments[1].get()); long l = ((LongWritable) scaleObj).get(); if (l < Integer.MIN_VALUE || l > Integer.MAX_VALUE) { throw new UDFArgumentException( getFuncName().toUpperCase() + " scale argument out of allowed range"); } scale = (int) l; default: break; } } switch (inputType1) { case VOID: return null; case DECIMAL: HiveDecimalWritable decimalWritable = (HiveDecimalWritable) inputOI.getPrimitiveWritableObject(input); HiveDecimal dec = trunc(decimalWritable.getHiveDecimal(), scale); if (dec == null) { return null; } return new HiveDecimalWritable(dec); case BYTE: ByteWritable byteWritable = (ByteWritable) inputOI.getPrimitiveWritableObject(input); if (scale >= 0) { return byteWritable; } else { return new ByteWritable((byte) trunc(byteWritable.get(), scale)); } case SHORT: ShortWritable shortWritable = (ShortWritable) inputOI.getPrimitiveWritableObject(input); if (scale >= 0) { return shortWritable; } else { return new ShortWritable((short) trunc(shortWritable.get(), scale)); } case INT: IntWritable intWritable = (IntWritable) inputOI.getPrimitiveWritableObject(input); if (scale >= 0) { return intWritable; } else { return new IntWritable((int) trunc(intWritable.get(), scale)); } case LONG: LongWritable longWritable = (LongWritable) inputOI.getPrimitiveWritableObject(input); if (scale >= 0) { return longWritable; } else { return new LongWritable(trunc(longWritable.get(), scale)); } case FLOAT: float f = ((FloatWritable) inputOI.getPrimitiveWritableObject(input)).get(); return new FloatWritable((float) trunc(f, scale)); case DOUBLE: return trunc(((DoubleWritable) inputOI.getPrimitiveWritableObject(input)), scale); default: throw new UDFArgumentTypeException(0, "Only numeric or string group data types are allowed for TRUNC function. Got " + inputType1.name()); } } @Override public String getDisplayString(String[] children) { return getStandardDisplayString("trunc", children); } private Calendar evalDate(Date d) throws UDFArgumentException { calendar.setTime(d); if ("MONTH".equals(fmtInput) || "MON".equals(fmtInput) || "MM".equals(fmtInput)) { calendar.set(Calendar.DAY_OF_MONTH, 1); return calendar; } else if ("QUARTER".equals(fmtInput) || "Q".equals(fmtInput)) { int month = calendar.get(Calendar.MONTH); int quarter = month / 3; int monthToSet = quarter * 3; calendar.set(Calendar.MONTH, monthToSet); calendar.set(Calendar.DAY_OF_MONTH, 1); return calendar; } else if ("YEAR".equals(fmtInput) || "YYYY".equals(fmtInput) || "YY".equals(fmtInput)) { calendar.set(Calendar.MONTH, 0); calendar.set(Calendar.DAY_OF_MONTH, 1); return calendar; } else { return null; } } protected HiveDecimal trunc(HiveDecimal input, int scale) { BigDecimal bigDecimal = trunc(input.bigDecimalValue(), scale); return HiveDecimal.create(bigDecimal); } protected long trunc(long input, int scale) { return trunc(BigDecimal.valueOf(input), scale).longValue(); } protected double trunc(double input, int scale) { return trunc(BigDecimal.valueOf(input), scale).doubleValue(); } protected DoubleWritable trunc(DoubleWritable input, int scale) { BigDecimal bigDecimal = new BigDecimal(input.get()); BigDecimal trunc = trunc(bigDecimal, scale); DoubleWritable doubleWritable = new DoubleWritable(trunc.doubleValue()); return doubleWritable; } protected BigDecimal trunc(BigDecimal input, int scale) { BigDecimal output = new BigDecimal(0); BigDecimal pow = BigDecimal.valueOf(Math.pow(10, Math.abs(scale))); if (scale >= 0) { pow = BigDecimal.valueOf(Math.pow(10, scale)); if (scale != 0) { long longValue = input.multiply(pow).longValue(); output = BigDecimal.valueOf(longValue).divide(pow); } else { output = BigDecimal.valueOf(input.longValue()); } } else { long longValue2 = input.divide(pow).longValue(); output = BigDecimal.valueOf(longValue2).multiply(pow); } return output; } }