/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.streaming.api.operators; import org.apache.flink.api.common.JobExecutionResult; import org.apache.flink.api.common.functions.FoldFunction; import org.apache.flink.api.common.functions.util.ListCollector; import org.apache.flink.api.common.state.FoldingState; import org.apache.flink.api.common.state.FoldingStateDescriptor; import org.apache.flink.api.common.state.KeyedStateStore; import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.state.MapState; import org.apache.flink.api.common.state.MapStateDescriptor; import org.apache.flink.api.common.state.ReducingState; import org.apache.flink.api.common.state.ReducingStateDescriptor; import org.apache.flink.api.common.state.ValueState; import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.api.common.typeutils.base.ByteSerializer; import org.apache.flink.api.common.typeutils.base.IntSerializer; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.configuration.Configuration; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.streaming.api.functions.source.SourceFunction; import org.apache.flink.streaming.api.functions.windowing.FoldApplyProcessAllWindowFunction; import org.apache.flink.streaming.api.functions.windowing.FoldApplyProcessWindowFunction; import org.apache.flink.streaming.api.functions.windowing.ProcessAllWindowFunction; import org.apache.flink.streaming.api.functions.windowing.ProcessWindowFunction; import org.apache.flink.streaming.api.graph.StreamGraph; import org.apache.flink.streaming.api.graph.StreamGraphGenerator; import org.apache.flink.streaming.api.transformations.OneInputTransformation; import org.apache.flink.streaming.api.transformations.SourceTransformation; import org.apache.flink.streaming.api.transformations.StreamTransformation; import org.apache.flink.streaming.api.windowing.windows.TimeWindow; import org.apache.flink.streaming.runtime.operators.windowing.AccumulatingProcessingTimeWindowOperator; import org.apache.flink.streaming.runtime.operators.windowing.functions.InternalIterableProcessAllWindowFunction; import org.apache.flink.streaming.runtime.operators.windowing.functions.InternalIterableProcessWindowFunction; import org.apache.flink.util.Collector; import org.junit.Assert; import org.junit.Test; import java.util.ArrayList; import java.util.List; public class FoldApplyProcessWindowFunctionTest { /** * Tests that the FoldWindowFunction gets the output type serializer set by the * StreamGraphGenerator and checks that the FoldWindowFunction computes the correct result. */ @Test public void testFoldWindowFunctionOutputTypeConfigurable() throws Exception{ StreamExecutionEnvironment env = new DummyStreamExecutionEnvironment(); List<StreamTransformation<?>> transformations = new ArrayList<>(); int initValue = 1; FoldApplyProcessWindowFunction<Integer, TimeWindow, Integer, Integer, Integer> foldWindowFunction = new FoldApplyProcessWindowFunction<>( initValue, new FoldFunction<Integer, Integer>() { @Override public Integer fold(Integer accumulator, Integer value) throws Exception { return accumulator + value; } }, new ProcessWindowFunction<Integer, Integer, Integer, TimeWindow>() { @Override public void process(Integer integer, Context context, Iterable<Integer> input, Collector<Integer> out) throws Exception { for (Integer in: input) { out.collect(in); } } }, BasicTypeInfo.INT_TYPE_INFO ); AccumulatingProcessingTimeWindowOperator<Integer, Integer, Integer> windowOperator = new AccumulatingProcessingTimeWindowOperator<>( new InternalIterableProcessWindowFunction<>(foldWindowFunction), new KeySelector<Integer, Integer>() { private static final long serialVersionUID = -7951310554369722809L; @Override public Integer getKey(Integer value) throws Exception { return value; } }, IntSerializer.INSTANCE, IntSerializer.INSTANCE, 3000, 3000 ); SourceFunction<Integer> sourceFunction = new SourceFunction<Integer>(){ private static final long serialVersionUID = 8297735565464653028L; @Override public void run(SourceContext<Integer> ctx) throws Exception { } @Override public void cancel() { } }; SourceTransformation<Integer> source = new SourceTransformation<>("", new StreamSource<>(sourceFunction), BasicTypeInfo.INT_TYPE_INFO, 1); transformations.add(new OneInputTransformation<>(source, "test", windowOperator, BasicTypeInfo.INT_TYPE_INFO, 1)); StreamGraph streamGraph = StreamGraphGenerator.generate(env, transformations); List<Integer> result = new ArrayList<>(); List<Integer> input = new ArrayList<>(); List<Integer> expected = new ArrayList<>(); input.add(1); input.add(2); input.add(3); for (int value : input) { initValue += value; } expected.add(initValue); FoldApplyProcessWindowFunction<Integer, TimeWindow, Integer, Integer, Integer>.Context ctx = foldWindowFunction.new Context() { @Override public TimeWindow window() { return new TimeWindow(0, 1); } @Override public long currentProcessingTime() { return 0; } @Override public long currentWatermark() { return 0; } @Override public KeyedStateStore windowState() { return new DummyKeyedStateStore(); } @Override public KeyedStateStore globalState() { return new DummyKeyedStateStore(); } }; foldWindowFunction.open(new Configuration()); foldWindowFunction.process(0, ctx, input, new ListCollector<>(result)); Assert.assertEquals(expected, result); } /** * Tests that the FoldWindowFunction gets the output type serializer set by the * StreamGraphGenerator and checks that the FoldWindowFunction computes the correct result. */ @Test public void testFoldAllWindowFunctionOutputTypeConfigurable() throws Exception{ StreamExecutionEnvironment env = new DummyStreamExecutionEnvironment(); List<StreamTransformation<?>> transformations = new ArrayList<>(); int initValue = 1; FoldApplyProcessAllWindowFunction<TimeWindow, Integer, Integer, Integer> foldWindowFunction = new FoldApplyProcessAllWindowFunction<>( initValue, new FoldFunction<Integer, Integer>() { @Override public Integer fold(Integer accumulator, Integer value) throws Exception { return accumulator + value; } }, new ProcessAllWindowFunction<Integer, Integer, TimeWindow>() { @Override public void process(Context context, Iterable<Integer> input, Collector<Integer> out) throws Exception { for (Integer in: input) { out.collect(in); } } }, BasicTypeInfo.INT_TYPE_INFO ); AccumulatingProcessingTimeWindowOperator<Byte, Integer, Integer> windowOperator = new AccumulatingProcessingTimeWindowOperator<>( new InternalIterableProcessAllWindowFunction<>(foldWindowFunction), new KeySelector<Integer, Byte>() { private static final long serialVersionUID = -7951310554369722809L; @Override public Byte getKey(Integer value) throws Exception { return 0; } }, ByteSerializer.INSTANCE, IntSerializer.INSTANCE, 3000, 3000 ); SourceFunction<Integer> sourceFunction = new SourceFunction<Integer>(){ private static final long serialVersionUID = 8297735565464653028L; @Override public void run(SourceContext<Integer> ctx) throws Exception { } @Override public void cancel() { } }; SourceTransformation<Integer> source = new SourceTransformation<>("", new StreamSource<>(sourceFunction), BasicTypeInfo.INT_TYPE_INFO, 1); transformations.add(new OneInputTransformation<>(source, "test", windowOperator, BasicTypeInfo.INT_TYPE_INFO, 1)); StreamGraph streamGraph = StreamGraphGenerator.generate(env, transformations); List<Integer> result = new ArrayList<>(); List<Integer> input = new ArrayList<>(); List<Integer> expected = new ArrayList<>(); input.add(1); input.add(2); input.add(3); for (int value : input) { initValue += value; } expected.add(initValue); FoldApplyProcessAllWindowFunction<TimeWindow, Integer, Integer, Integer>.Context ctx = foldWindowFunction.new Context() { @Override public TimeWindow window() { return new TimeWindow(0, 1); } @Override public KeyedStateStore windowState() { return new DummyKeyedStateStore(); } @Override public KeyedStateStore globalState() { return new DummyKeyedStateStore(); } }; foldWindowFunction.open(new Configuration()); foldWindowFunction.process(ctx, input, new ListCollector<>(result)); Assert.assertEquals(expected, result); } public static class DummyKeyedStateStore implements KeyedStateStore { @Override public <T> ValueState<T> getState(ValueStateDescriptor<T> stateProperties) { return null; } @Override public <T> ListState<T> getListState(ListStateDescriptor<T> stateProperties) { return null; } @Override public <T> ReducingState<T> getReducingState(ReducingStateDescriptor<T> stateProperties) { return null; } @Override public <T, ACC> FoldingState<T, ACC> getFoldingState(FoldingStateDescriptor<T, ACC> stateProperties) { return null; } @Override public <UK, UV> MapState<UK, UV> getMapState(MapStateDescriptor<UK, UV> stateProperties) { return null; } } public static class DummyStreamExecutionEnvironment extends StreamExecutionEnvironment { @Override public JobExecutionResult execute(String jobName) throws Exception { return null; } } }