/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.hadoop.hive.ql.udf.generic; import java.util.ArrayList; import java.util.List; import junit.framework.Assert; import org.apache.hadoop.hive.common.type.HiveDecimal; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.parse.TypeCheckProcFactory; import org.apache.hadoop.hive.ql.plan.ExprNodeConstantDesc; import org.apache.hadoop.hive.ql.plan.ExprNodeDesc; import org.apache.hadoop.hive.ql.testutil.BaseScalarUdfTest; import org.apache.hadoop.hive.ql.testutil.DataBuilder; import org.apache.hadoop.hive.ql.testutil.OperatorTestUtils; import org.apache.hadoop.hive.serde2.io.ByteWritable; import org.apache.hadoop.hive.serde2.io.DoubleWritable; import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; import org.apache.hadoop.hive.serde2.io.ShortWritable; import org.apache.hadoop.hive.serde2.objectinspector.InspectableObject; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; import 
org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.junit.Test;

/**
 * Tests for the {@code round} function ({@link GenericUDFRound}).
 *
 * <p>The table-driven part (base table, expression list, expected result) is consumed by
 * {@link BaseScalarUdfTest}; it applies {@code round(column, scale)} to one column of each
 * primitive type. The {@code testDecimalRoundingMetaData*} methods check the decimal
 * precision/scale reported by {@link GenericUDFRound#initialize} for a {@code decimal(7,3)}
 * input and a constant rounding scale.
 */
public class TestGenericUDFRound extends BaseScalarUdfTest {

  /** Input column names: string, int, double, float, byte, short, long, decimal. */
  private static final String[] cols = {"s", "i", "d", "f", "b", "sh", "l", "dec"};

  /**
   * Builds the three input rows, one column per entry of {@link #cols}.
   *
   * @return the input rows with Java-object inspectors; the decimal column is decimal(15,5)
   */
  @Override
  public InspectableObject[] getBaseTable() {
    DataBuilder db = new DataBuilder();
    db.setColumnNames(cols);
    db.setColumnTypes(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
        PrimitiveObjectInspectorFactory.javaIntObjectInspector,
        PrimitiveObjectInspectorFactory.javaDoubleObjectInspector,
        PrimitiveObjectInspectorFactory.javaFloatObjectInspector,
        PrimitiveObjectInspectorFactory.javaByteObjectInspector,
        PrimitiveObjectInspectorFactory.javaShortObjectInspector,
        PrimitiveObjectInspectorFactory.javaLongObjectInspector,
        PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
            TypeInfoFactory.getDecimalTypeInfo(15, 5)));

    // valueOf(String) replaces the deprecated boxed-type constructors; the parsed values
    // are identical.
    db.addRow("one", 170, Double.valueOf("1.1"), Float.valueOf("32.1234"),
        Byte.valueOf("25"), Short.valueOf("1234"), 123456L,
        HiveDecimal.create("983.7235"));
    // Second row carries nulls in the int and double columns.
    db.addRow("-234", null, null, Float.valueOf("0.347232"),
        Byte.valueOf("109"), Short.valueOf("551"), 923L,
        HiveDecimal.create("983723.005"));
    db.addRow("454.45", 22345, Double.valueOf("-23.00009"), Float.valueOf("-3.4"),
        Byte.valueOf("76"), Short.valueOf("2321"), 9232L,
        HiveDecimal.create("-932032.7"));
    return db.createRows();
  }

  /**
   * Builds the rows expected after applying the expressions from {@link #getExpressionList()}
   * to the rows of {@link #getBaseTable()}.
   *
   * <p>NOTE(review): the first output column is declared with the Java string inspector
   * although its expected values are {@code DoubleWritable} or {@code null}; this is kept
   * exactly as in the original fixture - confirm against the comparison done by the base
   * test before changing it.
   *
   * @return the expected output rows, using writable object inspectors for columns 2-8
   */
  @Override
  public InspectableObject[] getExpectedResult() {
    DataBuilder db = new DataBuilder();
    db.setColumnNames("_col1", "_col2", "_col3", "_col4", "_col5", "_col6", "_col7", "_col8");
    db.setColumnTypes(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
        PrimitiveObjectInspectorFactory.writableIntObjectInspector,
        PrimitiveObjectInspectorFactory.writableDoubleObjectInspector,
        PrimitiveObjectInspectorFactory.writableFloatObjectInspector,
        PrimitiveObjectInspectorFactory.writableByteObjectInspector,
        PrimitiveObjectInspectorFactory.writableShortObjectInspector,
        PrimitiveObjectInspectorFactory.writableLongObjectInspector,
        PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector);

    // Row 1: the non-numeric string "one" is expected to yield null.
    db.addRow(null, new IntWritable(170), new DoubleWritable(1.1), new FloatWritable(32f),
        new ByteWritable((byte) 0), new ShortWritable((short) 1234),
        new LongWritable(123500L),
        new HiveDecimalWritable(HiveDecimal.create("983.724")));
    // Row 2: null int/double inputs are expected to stay null.
    db.addRow(new DoubleWritable(-200), null, null, new FloatWritable(0f),
        new ByteWritable((byte) 100), new ShortWritable((short) 551),
        new LongWritable(900L),
        new HiveDecimalWritable(HiveDecimal.create("983723.005")));
    db.addRow(new DoubleWritable(500), new IntWritable(22345), new DoubleWritable(-23.000),
        new FloatWritable(-3f), new ByteWritable((byte) 100),
        new ShortWritable((short) 2321), new LongWritable(9200L),
        new HiveDecimalWritable(HiveDecimal.create("-932032.7")));
    return db.createRows();
  }

  /**
   * Builds one {@code round(column, scale)} expression per entry of {@link #cols}.
   *
   * <p>The scale constants deliberately use different integral types (int, byte, short,
   * long) and are paired with the columns by position.
   *
   * @return the list of function expressions, in column order
   * @throws UDFArgumentException if an expression cannot be constructed
   */
  @Override
  public List<ExprNodeDesc> getExpressionList() throws UDFArgumentException {
    List<ExprNodeDesc> columnExprs = new ArrayList<ExprNodeDesc>(cols.length);
    for (String col : cols) {
      columnExprs.add(OperatorTestUtils.getStringColumn(col));
    }

    // scales[k] is the second argument of round() for cols[k].
    ExprNodeDesc[] scales = {
        new ExprNodeConstantDesc(TypeInfoFactory.intTypeInfo, -2),
        new ExprNodeConstantDesc(TypeInfoFactory.byteTypeInfo, (byte) 0),
        new ExprNodeConstantDesc(TypeInfoFactory.shortTypeInfo, (short) 3),
        new ExprNodeConstantDesc(TypeInfoFactory.intTypeInfo, 0),
        new ExprNodeConstantDesc(TypeInfoFactory.longTypeInfo, -2L),
        new ExprNodeConstantDesc(TypeInfoFactory.intTypeInfo, 0),
        new ExprNodeConstantDesc(TypeInfoFactory.intTypeInfo, -2),
        new ExprNodeConstantDesc(TypeInfoFactory.intTypeInfo, 3)
    };

    List<ExprNodeDesc> roundExprs = new ArrayList<ExprNodeDesc>(cols.length);
    for (int k = 0; k < cols.length; k++) {
      roundExprs.add(TypeCheckProcFactory.DefaultExprProcessor.getFuncExprNodeDesc(
          "round", columnExprs.get(k), scales[k]));
    }
    return roundExprs;
  }

  /** Rounding decimal(7,3) to scale 2 is expected to report decimal(7,2). */
  @Test
  public void testDecimalRoundingMetaData() throws UDFArgumentException {
    assertRoundedDecimalType(2, 7, 2);
  }

  /** Rounding decimal(7,3) to the negative scale -2 is expected to report decimal(5,0). */
  @Test
  public void testDecimalRoundingMetaData1() throws UDFArgumentException {
    assertRoundedDecimalType(-2, 5, 0);
  }

  /** Rounding decimal(7,3) to scale 5 is expected to report decimal(9,5). */
  @Test
  public void testDecimalRoundingMetaData2() throws UDFArgumentException {
    assertRoundedDecimalType(5, 9, 5);
  }

  /**
   * Initializes a fresh {@link GenericUDFRound} with a writable decimal(7,3) input and a
   * constant int rounding scale, then asserts the decimal type of the returned inspector.
   *
   * @param roundScale the constant second argument passed to {@code round}
   * @param expectedPrecision precision the output type is expected to have
   * @param expectedScale scale the output type is expected to have
   * @throws UDFArgumentException if the UDF rejects the argument inspectors
   */
  private static void assertRoundedDecimalType(int roundScale, int expectedPrecision,
      int expectedScale) throws UDFArgumentException {
    GenericUDFRound udf = new GenericUDFRound();
    ObjectInspector[] inputOIs = {
        PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(
            TypeInfoFactory.getDecimalTypeInfo(7, 3)),
        PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
            TypeInfoFactory.intTypeInfo, new IntWritable(roundScale))
    };
    PrimitiveObjectInspector outputOI = (PrimitiveObjectInspector) udf.initialize(inputOIs);
    DecimalTypeInfo outputTypeInfo = (DecimalTypeInfo) outputOI.getTypeInfo();
    Assert.assertEquals(
        TypeInfoFactory.getDecimalTypeInfo(expectedPrecision, expectedScale), outputTypeInfo);
  }
}