/* (c) 2014 LinkedIn Corp. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not use * this file except in compliance with the License. You may obtain a copy of the * License at http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software distributed * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR * CONDITIONS OF ANY KIND, either express or implied. */ package com.linkedin.cubert.functions.builtin; import org.apache.pig.backend.executionengine.ExecException; import org.apache.pig.data.Tuple; import com.linkedin.cubert.block.BlockSchema; import com.linkedin.cubert.block.ColumnType; import com.linkedin.cubert.block.DataType; import com.linkedin.cubert.functions.Function; import com.linkedin.cubert.operator.PreconditionException; import com.linkedin.cubert.operator.PreconditionExceptionType; /** * Suite of built-in arithmetic functions. * <p> * If the inputs are numerical types, these builtin functions will generate output in the * wider data type. If any of the input is null, these functions will output null. * <p> * Following functions are implemented: * <ul> * <li>{@literal +}</li> * <li>{@literal -}</li> * <li>{@literal /}</li> * <li>{@literal *}</li> * <li>{@literal %}</li> * <li>{@literal <<}</li> * </ul> * * @author Maneesh Varshney * */ public class ArithmeticFunction extends Function { private final FunctionType type; private DataType outputType; public ArithmeticFunction(FunctionType type) { this.type = type; } @Override public Object eval(Tuple tuple) throws ExecException { Object o1 = tuple.get(0); Object o2 = tuple.get(1); if (o1 == null || o2 == null) return null; switch (type) { case ADD: return add(o1, o2); case DIVIDE: return divide(o1, o2); case MINUS: return minus(o1, o2); case MOD: return mod(o1, o2); case TIMES: return times(o1, o2); case LSHIFT: return lshift(o1, o2); case RSHIFT: return rshift(o1, o2); default: break; } return null; } private Object add(Object o1, Object o2) { switch (outputType) { case DOUBLE: return ((Number) o1).doubleValue() + ((Number) o2).doubleValue(); case FLOAT: return ((Number) o1).floatValue() + ((Number) o2).floatValue(); case INT: return ((Number) o1).intValue() + ((Number) o2).intValue(); case LONG: return ((Number) o1).longValue() + ((Number) o2).longValue(); default: break; } return null; } private Object minus(Object o1, Object o2) { switch (outputType) { case DOUBLE: return ((Number) o1).doubleValue() - ((Number) o2).doubleValue(); case FLOAT: return ((Number) o1).floatValue() - ((Number) o2).floatValue(); case INT: return ((Number) o1).intValue() - ((Number) o2).intValue(); case LONG: return ((Number) o1).longValue() - ((Number) o2).longValue(); default: break; } return null; } private Object times(Object o1, Object o2) { switch (outputType) { case DOUBLE: return ((Number) o1).doubleValue() * ((Number) o2).doubleValue(); case FLOAT: return ((Number) o1).floatValue() * ((Number) o2).floatValue(); case INT: return ((Number) o1).intValue() * ((Number) o2).intValue(); case LONG: return ((Number) o1).longValue() * ((Number) o2).longValue(); default: break; } return null; } private Object divide(Object o1, Object o2) { switch (outputType) { case DOUBLE: return ((Number) o1).doubleValue() / ((Number) o2).doubleValue(); case FLOAT: return ((Number) o1).floatValue() / ((Number) o2).floatValue(); case INT: return ((Number) o1).intValue() / ((Number) o2).intValue(); case LONG: return ((Number) o1).longValue() / ((Number) o2).longValue(); default: break; } return null; } private Object mod(Object o1, Object o2) { switch (outputType) { case INT: return ((Number) o1).intValue() % ((Number) o2).intValue(); case LONG: return ((Number) o1).longValue() % ((Number) o2).longValue(); default: break; } return null; } private Object rshift(Object o1, Object o2) { switch (outputType) { case INT: return ((Number) o1).intValue() >>> ((Number) o2).intValue(); case LONG: return ((Number) o1).longValue() >>> ((Number) o2).longValue(); default: break; } return null; } private Object lshift(Object o1, Object o2) { switch (outputType) { case INT: return ((Number) o1).intValue() << ((Number) o2).intValue(); case LONG: return ((Number) o1).longValue() << ((Number) o2).longValue(); default: break; } return null; } @Override public ColumnType outputSchema(BlockSchema inputSchema) throws PreconditionException { ColumnType type1 = inputSchema.getColumnType(0); ColumnType type2 = inputSchema.getColumnType(1); if ((type == FunctionType.LSHIFT || type == FunctionType.MOD || type == FunctionType.RSHIFT) && (!type1.getType().isIntOrLong() || !type2.getType().isIntOrLong())) throw new PreconditionException(PreconditionExceptionType.INVALID_SCHEMA, "The LHS and RHS of " + type + " function must be int or long"); if (!type1.getType().isNumerical()) throw new PreconditionException(PreconditionExceptionType.INVALID_SCHEMA, "The LHS of " + type + " function is not numerical"); if (!type2.getType().isNumerical()) throw new PreconditionException(PreconditionExceptionType.INVALID_SCHEMA, "The RHS of " + type + " function is not numerical"); outputType = DataType.getWiderType(type1.getType(), type2.getType()); return new ColumnType(null, outputType); } }