package edu.washington.escience.myria.operator; import static org.junit.Assert.assertEquals; import org.junit.Test; import com.fasterxml.jackson.databind.ObjectReader; import com.fasterxml.jackson.databind.ObjectWriter; import com.google.common.collect.ImmutableList; import edu.washington.escience.myria.Schema; import edu.washington.escience.myria.Type; import edu.washington.escience.myria.api.MyriaJsonMapperProvider; import edu.washington.escience.myria.expression.ConditionalExpression; import edu.washington.escience.myria.expression.ConstantExpression; import edu.washington.escience.myria.expression.Expression; import edu.washington.escience.myria.expression.ExpressionOperator; import edu.washington.escience.myria.expression.GreaterThanExpression; import edu.washington.escience.myria.expression.PlusExpression; import edu.washington.escience.myria.expression.StateExpression; import edu.washington.escience.myria.expression.VariableExpression; import edu.washington.escience.myria.operator.agg.Aggregate; import edu.washington.escience.myria.operator.agg.AggregatorFactory; import edu.washington.escience.myria.operator.agg.UserDefinedAggregatorFactory; import edu.washington.escience.myria.storage.TupleBatch; import edu.washington.escience.myria.storage.TupleBatchBuffer; import edu.washington.escience.myria.storage.TupleUtils; import edu.washington.escience.myria.util.TestEnvVars; public class UserDefinedAggregatorTest { private final ObjectReader reader = MyriaJsonMapperProvider.getReader().withType(AggregatorFactory.class); private final ObjectWriter writer = MyriaJsonMapperProvider.getWriter(); private final int NUM_TUPLES = 2 * TupleUtils.getBatchSize(Type.LONG_TYPE); private final int NUM_TUPLES_20K = 2 * 10000; /** * Tests a re-implementation of the Count aggregate using a user-defined aggregate. Also tests serialization and * deserialization. * * @throws Exception if something goes wrong. */ @Test public void testCount() throws Exception { final Schema schema = new Schema(ImmutableList.of(Type.STRING_TYPE), ImmutableList.of("name")); final TupleBatchBuffer tbb = new TupleBatchBuffer(schema); for (long i = 0; i < NUM_TUPLES; i++) { tbb.putString(0, "Foo" + i); } Expression initializer = new Expression("counter", new ConstantExpression(0L)); Expression increment = new Expression( "counter", new PlusExpression(new StateExpression(0), new ConstantExpression(1L))); Expression emitter = new Expression("index", new StateExpression(0)); ImmutableList.Builder<Expression> Initializers = ImmutableList.builder(); Initializers.add(initializer); ImmutableList.Builder<Expression> Updaters = ImmutableList.builder(); Updaters.add(increment); ImmutableList.Builder<Expression> Emitters = ImmutableList.builder(); Emitters.add(emitter); AggregatorFactory factory = new UserDefinedAggregatorFactory(Initializers.build(), Updaters.build(), Emitters.build()); factory = reader.readValue(writer.writeValueAsString(factory)); Aggregate agg = new Aggregate(new BatchTupleSource(tbb), new int[] {}, factory); agg.open(TestEnvVars.get()); TupleBatch result; int resultSize = 0; while (!agg.eos()) { result = agg.nextReady(); if (result != null) { assertEquals(1, result.numTuples()); assertEquals(1, result.numColumns()); assertEquals(Type.LONG_TYPE, result.getSchema().getColumnType(0)); assertEquals(NUM_TUPLES, result.getLong(0, 0)); resultSize += result.numTuples(); } } assertEquals(1, resultSize); agg.close(); } /** * Tests a re-implementation of the Count aggregate using a user-defined aggregate that also implements a constant * value column. Also tests serialization and deserialization. * * @throws Exception if something goes wrong. */ @Test public void testCountAndConst() throws Exception { final Schema schema = new Schema(ImmutableList.of(Type.STRING_TYPE), ImmutableList.of("name")); final TupleBatchBuffer tbb = new TupleBatchBuffer(schema); for (long i = 0; i < NUM_TUPLES; i++) { tbb.putString(0, "Foo" + i); } Expression initializer = new Expression("counter", new ConstantExpression(0L)); Expression increment = new Expression( "counter", new PlusExpression(new StateExpression(0), new ConstantExpression(1L))); Expression emitter = new Expression("index", new StateExpression(0)); Expression constEmitter = new Expression("const", new ConstantExpression(5L)); ImmutableList.Builder<Expression> Initializers = ImmutableList.builder(); Initializers.add(initializer); ImmutableList.Builder<Expression> Updaters = ImmutableList.builder(); Updaters.add(increment); ImmutableList.Builder<Expression> Emitters = ImmutableList.builder(); Emitters.add(emitter); Emitters.add(constEmitter); AggregatorFactory factory = new UserDefinedAggregatorFactory(Initializers.build(), Updaters.build(), Emitters.build()); factory = reader.readValue(writer.writeValueAsString(factory)); Aggregate agg = new Aggregate(new BatchTupleSource(tbb), new int[] {}, factory); agg.open(TestEnvVars.get()); TupleBatch result; int resultSize = 0; while (!agg.eos()) { result = agg.nextReady(); if (result != null) { assertEquals(1, result.numTuples()); assertEquals(2, result.numColumns()); assertEquals(Type.LONG_TYPE, result.getSchema().getColumnType(0)); assertEquals(NUM_TUPLES, result.getLong(0, 0)); assertEquals(5L, result.getLong(1, 0)); resultSize += result.numTuples(); } } assertEquals(1, resultSize); agg.close(); } /** * Tests an arg-max-like aggregate function. Also tests serialization and deserialization. * * @throws Exception if something goes wrong. */ @Test public void testRowOfMax() throws Exception { final Schema schema = new Schema(ImmutableList.of(Type.STRING_TYPE), ImmutableList.of("name")); final TupleBatchBuffer tbb = new TupleBatchBuffer(schema); for (long i = 0; i < NUM_TUPLES_20K; i++) { tbb.putString(0, "Foo" + i); } ImmutableList.Builder<Expression> Initializers = ImmutableList.builder(); Initializers.add(new Expression("counter", new ConstantExpression(0L))); Initializers.add(new Expression("maxrow", new ConstantExpression(-1L))); Initializers.add(new Expression("maxval", new ConstantExpression(""))); ImmutableList.Builder<Expression> Updaters = ImmutableList.builder(); // State.$0 counts the index of the current row. Updaters.add( new Expression( "counter", new PlusExpression(new StateExpression(0), new ConstantExpression(1L)))); ExpressionOperator newRowIsBigger = new GreaterThanExpression(new VariableExpression(0), new StateExpression(2)); // State.$1 tracks the index of the biggest row. Updaters.add( new Expression( "maxrow", new ConditionalExpression( newRowIsBigger, new StateExpression(0), new StateExpression(1)))); // State.$2 tracks the value of the biggest row. Updaters.add( new Expression( "maxval", new ConditionalExpression( newRowIsBigger, new VariableExpression(0), new StateExpression(2)))); ImmutableList.Builder<Expression> Emitters = ImmutableList.builder(); Emitters.add(new Expression("indexOfMax", new StateExpression(1))); Emitters.add(new Expression("max", new StateExpression(2))); AggregatorFactory factory = new UserDefinedAggregatorFactory(Initializers.build(), Updaters.build(), Emitters.build()); factory = reader.readValue(writer.writeValueAsString(factory)); Aggregate agg = new Aggregate(new BatchTupleSource(tbb), new int[] {}, factory); agg.open(TestEnvVars.get()); TupleBatch result; int resultSize = 0; while (!agg.eos()) { result = agg.nextReady(); if (result != null) { assertEquals(1, result.numTuples()); assertEquals(2, result.numColumns()); assertEquals(Type.LONG_TYPE, result.getSchema().getColumnType(0)); assertEquals(Type.STRING_TYPE, result.getSchema().getColumnType(1)); assertEquals(10000, result.getLong(0, 0)); assertEquals("Foo9999", result.getString(1, 0)); resultSize += result.numTuples(); } } assertEquals(1, resultSize); agg.close(); } }