/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.pig.test; import static org.junit.Assert.assertEquals; import java.io.IOException; import java.util.Iterator; import java.util.LinkedList; import java.util.List; import java.util.Random; import org.apache.pig.Algebraic; import org.apache.pig.ComparisonFunc; import org.apache.pig.EvalFunc; import org.apache.pig.FuncSpec; import org.apache.pig.backend.executionengine.ExecException; import org.apache.pig.backend.hadoop.executionengine.physicalLayer.POStatus; import org.apache.pig.backend.hadoop.executionengine.physicalLayer.PhysicalOperator; import org.apache.pig.backend.hadoop.executionengine.physicalLayer.Result; import org.apache.pig.backend.hadoop.executionengine.physicalLayer.expressionOperators.POUserComparisonFunc; import org.apache.pig.backend.hadoop.executionengine.physicalLayer.expressionOperators.POUserFunc; import org.apache.pig.data.BagFactory; import org.apache.pig.data.DataBag; import org.apache.pig.data.DataType; import org.apache.pig.data.Tuple; import org.apache.pig.data.TupleFactory; import org.apache.pig.impl.logicalLayer.schema.Schema; import org.apache.pig.impl.plan.OperatorKey; import org.apache.pig.test.utils.GenRandomData; import org.junit.Test; public class TestPOUserFunc { Random r = new Random(42L); int MAX_TUPLES = 10; public static class ARITY extends EvalFunc<Integer> { @Override public Integer exec(Tuple input) throws IOException { try { return new Integer(((Tuple)input.get(0)).size()); } catch (ExecException e) { // TODO Auto-generated catch block e.printStackTrace(); } return 0; } @Override public Schema outputSchema(Schema input) { return new Schema(new Schema.FieldSchema(null, DataType.INTEGER)); } } public static class WeirdComparator extends ComparisonFunc { @Override public int compare(Tuple t1, Tuple t2) { // TODO Auto-generated method stub Object o1 = null; Object o2 = null; try { o1 = t1.get(2); o2 = t2.get(2); } catch (ExecException e) { // TODO Auto-generated catch block e.printStackTrace(); } if ( o1==null || o2==null ){ return -1; } int i1 = (Integer) o1 - 2; int i2 = (Integer) o2 - 2; return (int) (i1 * i1 - i2 * i2); } } /** * Generates the average of the values of the first field of a tuple. This * class is Algebraic in implemenation, so if possible the execution will be * split into a local and global application */ public static class AVG extends EvalFunc<Double> implements Algebraic { private static TupleFactory mTupleFactory = TupleFactory.getInstance(); @Override public Double exec(Tuple input) throws IOException { double sum = 0; double count = 0; try { sum = sum(input); count = count(input); } catch (ExecException e) { e.printStackTrace(); } double avg = 0; if (count > 0) avg = sum / count; return new Double(avg); } public String getInitial() { return Initial.class.getName(); } public String getIntermed() { return Intermed.class.getName(); } public String getFinal() { return Final.class.getName(); } static public class Initial extends EvalFunc<Tuple> { @Override public Tuple exec(Tuple input) throws IOException { try { Tuple t = mTupleFactory.newTuple(2); t.set(0, new Double(sum(input))); t.set(1, new Long(count(input))); return t; } catch (ExecException t) { throw new RuntimeException(t.getMessage() + ": " + input); } } } static public class Intermed extends EvalFunc<Tuple> { @Override public Tuple exec(Tuple input) throws IOException { DataBag b = null; Tuple t = null; try { b = (DataBag) input.get(0); t = combine(b); } catch (ExecException e) { // TODO Auto-generated catch block e.printStackTrace(); } return t; } } static public class Final extends EvalFunc<Double> { @Override public Double exec(Tuple input) throws IOException { double sum = 0; double count = 0; try { DataBag b = (DataBag) input.get(0); Tuple combined = combine(b); sum = (Double) combined.get(0); count = (Long) combined.get(1); } catch (ExecException e) { e.printStackTrace(); } double avg = 0; if (count > 0) { avg = sum / count; } return new Double(avg); } } static protected Tuple combine(DataBag values) throws ExecException { double sum = 0; long count = 0; Tuple output = mTupleFactory.newTuple(2); for (Iterator<Tuple> it = values.iterator(); it.hasNext();) { Tuple t = it.next(); sum += (Double) t.get(0); count += (Long) t.get(1); } output.set(0, new Double(sum)); output.set(1, new Long(count)); return output; } static protected long count(Tuple input) throws ExecException { DataBag values = (DataBag) input.get(0); return values.size(); } static protected double sum(Tuple input) throws ExecException { DataBag values = (DataBag) input.get(0); double sum = 0; for (Iterator<Tuple> it = values.iterator(); it.hasNext();) { Tuple t = it.next(); Double d = DataType.toDouble(t.get(0)); if (d == null) continue; sum += d; } return sum; } @Override public Schema outputSchema(Schema input) { return new Schema(new Schema.FieldSchema(null, DataType.DOUBLE)); } } @Test public void testUserFuncArity() throws ExecException { DataBag input = (DataBag) GenRandomData.genRandSmallTupDataBag(r, MAX_TUPLES, 100); userFuncArity( input ); } @Test public void testUserFuncArityWithNulls() throws ExecException { DataBag input = (DataBag) GenRandomData.genRandSmallTupDataBagWithNulls(r, MAX_TUPLES, 100); userFuncArity( input ); } public void userFuncArity(DataBag input ) throws ExecException { String funcSpec = ARITY.class.getName() + "()"; PORead read = new PORead(new OperatorKey("", r.nextLong()), input); List<PhysicalOperator> inputs = new LinkedList<PhysicalOperator>(); inputs.add(read); POUserFunc userFunc = new POUserFunc(new OperatorKey("", r.nextLong()), -1, inputs, new FuncSpec(funcSpec)); Result res = new Result(); Integer i = null; res = userFunc.getNextInteger(); while (res.returnStatus != POStatus.STATUS_EOP) { // System.out.println(res.result); int result = (Integer) res.result; assertEquals(2, result); res = userFunc.getNextInteger(); } } @Test public void testUDFCompare() throws ExecException { DataBag input = (DataBag) GenRandomData.genRandSmallTupDataBag(r, 2, 100); udfCompare(input); } @Test public void testUDFCompareWithNulls() throws ExecException { DataBag input = (DataBag) GenRandomData.genRandSmallTupDataBagWithNulls(r, 2, 100); udfCompare(input); } public void udfCompare(DataBag input) throws ExecException { String funcSpec = WeirdComparator.class.getName() + "()"; POUserComparisonFunc userFunc = new POUserComparisonFunc(new OperatorKey("", r.nextLong()), -1, null, new FuncSpec(funcSpec)); Iterator<Tuple> it = input.iterator(); Tuple t1 = it.next(); Tuple t2 = it.next(); t1.append(2); t2.append(3); userFunc.attachInput(t1, t2); Integer i = null; // System.out.println(t1 + " " + t2); int result = (Integer) (userFunc.getNextInteger().result); assertEquals(-1, result); } @Test public void testAlgebraicAVG() throws IOException, ExecException { Integer input[] = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }; algebraicAVG( input, 55.0, 10L, 110.0, 20L, 5.5 ); } /* NOTE: for calculating the average * * A pig "count" will include data that had "null",and the sum will * A pig "count" will include data that had "null",and the sum will * treat the null as a 0, impacting the average * A SQL "count" will exclude data that had "null" */ @Test public void testAlgebraicAVGWithNulls() throws IOException, ExecException { Integer input[] = { 1, 2, 3, 4, null, 6, 7, 8, 9, 10 }; algebraicAVG( input, 50.0, 10L, 100.0, 20L, 5.0 ); } public void algebraicAVG( Integer[] input , Double initialExpectedSum, Long initialExpectedCount , Double intermedExpectedSum, Long intermedExpectedCount , Double expectedAvg ) throws IOException, ExecException { // generate data byte INIT = 0; byte INTERMED = 1; byte FINAL = 2; Tuple tup1 = Util.loadNestTuple(TupleFactory.getInstance().newTuple(1), input); Tuple tup2 = Util.loadNestTuple(TupleFactory.getInstance().newTuple(1), input); // System.out.println("Input = " + tup1); String funcSpec = AVG.class.getName() + "()"; POUserFunc po = new POUserFunc(new OperatorKey("", r.nextLong()), -1, null, new FuncSpec(funcSpec)); //************ Initial Calculations ****************** TupleFactory tf = TupleFactory.getInstance(); po.setAlgebraicFunction(INIT); po.attachInput(tup1); Tuple t = null; Result res = po.getNextTuple(); Tuple outputInitial1 = (res.returnStatus == POStatus.STATUS_OK) ? (Tuple) res.result : null; Tuple outputInitial2 = (res.returnStatus == POStatus.STATUS_OK) ? (Tuple) res.result : null; System.out.println(outputInitial1 + " " + outputInitial2); assertEquals(outputInitial1, outputInitial2); Double sum = (Double) outputInitial1.get(0); Long count = (Long) outputInitial1.get(1); assertEquals(initialExpectedSum, sum); assertEquals(initialExpectedCount, count); //************ Intermediate Data and Calculations ****************** DataBag bag = BagFactory.getInstance().newDefaultBag(); bag.add(outputInitial1); bag.add(outputInitial2); Tuple outputInitial = tf.newTuple(); outputInitial.append(bag); // Tuple outputIntermed = intermed.exec(outputInitial); po = new POUserFunc(new OperatorKey("", r.nextLong()), -1, null, new FuncSpec(funcSpec)); po.setAlgebraicFunction(INTERMED); po.attachInput(outputInitial); res = po.getNextTuple(); Tuple outputIntermed = (res.returnStatus == POStatus.STATUS_OK) ? (Tuple) res.result : null; sum = (Double) outputIntermed.get(0); count = (Long) outputIntermed.get(1); assertEquals(intermedExpectedSum, sum); assertEquals(intermedExpectedCount, count); System.out.println(outputIntermed); //************ Final Calculations ****************** po = new POUserFunc(new OperatorKey("", r.nextLong()), -1, null, new FuncSpec(funcSpec)); po.setAlgebraicFunction(FINAL); po.attachInput(outputInitial); res = po.getNextTuple(); Double output = (res.returnStatus == POStatus.STATUS_OK) ? (Double) res.result : null; // Double output = fin.exec(outputInitial); assertEquals((Double)expectedAvg, output); // System.out.println("output = " + output); } }