/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.streaming.examples.ml; import org.apache.flink.api.java.utils.ParameterTool; import org.apache.flink.streaming.api.TimeCharacteristic; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks; import org.apache.flink.streaming.api.functions.co.CoMapFunction; import org.apache.flink.streaming.api.functions.source.SourceFunction; import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction; import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.api.windowing.time.Time; import org.apache.flink.streaming.api.windowing.windows.TimeWindow; import org.apache.flink.util.Collector; import java.util.concurrent.TimeUnit; /** * Skeleton for incremental machine learning algorithm consisting of a * pre-computed model, which gets updated for the new inputs and new input data * for which the job provides predictions. * * <p> * This may serve as a base of a number of algorithms, e.g. updating an * incremental Alternating Least Squares model while also providing the * predictions. * * <p> * This example shows how to use: * <ul> * <li>Connected streams * <li>CoFunctions * <li>Tuple data types * </ul> */ public class IncrementalLearningSkeleton { // ************************************************************************* // PROGRAM // ************************************************************************* public static void main(String[] args) throws Exception { // Checking input parameters final ParameterTool params = ParameterTool.fromArgs(args); StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime); DataStream<Integer> trainingData = env.addSource(new FiniteTrainingDataSource()); DataStream<Integer> newData = env.addSource(new FiniteNewDataSource()); // build new model on every second of new data DataStream<Double[]> model = trainingData .assignTimestampsAndWatermarks(new LinearTimestamp()) .timeWindowAll(Time.of(5000, TimeUnit.MILLISECONDS)) .apply(new PartialModelBuilder()); // use partial model for newData DataStream<Integer> prediction = newData.connect(model).map(new Predictor()); // emit result if (params.has("output")) { prediction.writeAsText(params.get("output")); } else { System.out.println("Printing result to stdout. Use --output to specify output path."); prediction.print(); } // execute program env.execute("Streaming Incremental Learning"); } // ************************************************************************* // USER FUNCTIONS // ************************************************************************* /** * Feeds new data for newData. By default it is implemented as constantly * emitting the Integer 1 in a loop. */ public static class FiniteNewDataSource implements SourceFunction<Integer> { private static final long serialVersionUID = 1L; private int counter; @Override public void run(SourceContext<Integer> ctx) throws Exception { Thread.sleep(15); while (counter < 50) { ctx.collect(getNewData()); } } @Override public void cancel() { // No cleanup needed } private Integer getNewData() throws InterruptedException { Thread.sleep(5); counter++; return 1; } } /** * Feeds new training data for the partial model builder. By default it is * implemented as constantly emitting the Integer 1 in a loop. */ public static class FiniteTrainingDataSource implements SourceFunction<Integer> { private static final long serialVersionUID = 1L; private int counter = 0; @Override public void run(SourceContext<Integer> collector) throws Exception { while (counter < 8200) { collector.collect(getTrainingData()); } } @Override public void cancel() { // No cleanup needed } private Integer getTrainingData() throws InterruptedException { counter++; return 1; } } public static class LinearTimestamp implements AssignerWithPunctuatedWatermarks<Integer> { private static final long serialVersionUID = 1L; private long counter = 0L; @Override public long extractTimestamp(Integer element, long previousElementTimestamp) { return counter += 10L; } @Override public Watermark checkAndGetNextWatermark(Integer lastElement, long extractedTimestamp) { return new Watermark(counter - 1); } } /** * Builds up-to-date partial models on new training data. */ public static class PartialModelBuilder implements AllWindowFunction<Integer, Double[], TimeWindow> { private static final long serialVersionUID = 1L; protected Double[] buildPartialModel(Iterable<Integer> values) { return new Double[]{1.}; } @Override public void apply(TimeWindow window, Iterable<Integer> values, Collector<Double[]> out) throws Exception { out.collect(buildPartialModel(values)); } } /** * Creates newData using the model produced in batch-processing and the * up-to-date partial model. * <p> * By defaults emits the Integer 0 for every newData and the Integer 1 * for every model update. * </p> */ public static class Predictor implements CoMapFunction<Integer, Double[], Integer> { private static final long serialVersionUID = 1L; Double[] batchModel = null; Double[] partialModel = null; @Override public Integer map1(Integer value) { // Return newData return predict(value); } @Override public Integer map2(Double[] value) { // Update model partialModel = value; batchModel = getBatchModel(); return 1; } // pulls model built with batch-job on the old training data protected Double[] getBatchModel() { return new Double[]{0.}; } // performs newData using the two models protected Integer predict(Integer inTuple) { return 0; } } }