/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.examples.java.ml;
import java.io.Serializable;
import java.util.Collection;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.utils.ParameterTool;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.examples.java.ml.util.LinearRegressionData;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.IterativeDataSet;
/**
 * This example implements a basic Linear Regression to solve the y = theta0 + theta1*x problem using the batch
 * gradient descent (BGD) algorithm.
 *
 * <p>
 * Linear Regression with BGD is an iterative optimization algorithm and works as follows:<br>
 * Given a data set and a target set, BGD tries to find the parameters that best fit the data set to the target set.
 * In each iteration, the algorithm computes the gradient of the cost function and uses it to update all the parameters.
 * The algorithm terminates after a fixed number of iterations (as in this implementation).
 * With enough iterations, the algorithm can minimize the cost function and find the best parameters.
 * This is the Wikipedia entry for <a href = "http://en.wikipedia.org/wiki/Linear_regression">Linear regression</a>
 * and for the <a href = "http://en.wikipedia.org/wiki/Gradient_descent">Gradient descent algorithm</a>.
 *
 * <p>
 * This implementation works on one-dimensional data and finds the two-dimensional parameter vector
 * (theta0, theta1) that best fits the target.
 *
 * <p>
 * Input files are plain text files and must be formatted as follows:
 * <ul>
 * <li>Data points are represented as two double values separated by a blank character. The first one represents
 * X (the training data) and the second one represents Y (the target).
 * Data points are separated by newline characters.<br>
 * For example <code>"-0.02 -0.04\n5.3 10.6\n"</code> gives two data points (x=-0.02, y=-0.04) and (x=5.3, y=10.6).
 * </ul>
 *
 * <p>
 * Usage: <code>LinearRegression [--input &lt;path&gt;] [--output &lt;path&gt;] [--iterations &lt;n&gt;]
 * [--learningRate &lt;alpha&gt;]</code><br>
 * Without <code>--input</code> the built-in default data set is used; without <code>--output</code> the result is
 * printed to stdout. <code>--iterations</code> defaults to 10 and <code>--learningRate</code> defaults to 0.01.
 *
 * <p>
 * This example shows how to use:
 * <ul>
 * <li> Bulk iterations
 * <li> Broadcast variables in bulk iterations
 * <li> Custom Java objects (PoJos)
 * </ul>
 */
@SuppressWarnings("serial")
public class LinearRegression {

	// *************************************************************************
	//     PROGRAM
	// *************************************************************************

	public static void main(String[] args) throws Exception {

		final ParameterTool params = ParameterTool.fromArgs(args);

		// set up execution environment
		final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

		final int iterations = params.getInt("iterations", 10);
		// step size applied to every gradient update (validated by the SubUpdate constructor)
		final double learningRate = params.getDouble("learningRate", SubUpdate.DEFAULT_LEARNING_RATE);

		// make parameters available in the web interface
		env.getConfig().setGlobalJobParameters(params);

		// get input x data from elements
		DataSet<Data> data;
		if (params.has("input")) {
			// read data from CSV file; the POJO fields are named explicitly so that the
			// first column is bound to Data.x and the second one to Data.y
			data = env.readCsvFile(params.get("input"))
					.fieldDelimiter(" ")
					.includeFields(true, true)
					.pojoType(Data.class, "x", "y");
		} else {
			System.out.println("Executing LinearRegression example with default input data set.");
			System.out.println("Use --input to specify file input.");
			data = LinearRegressionData.getDefaultDataDataSet(env);
		}

		// get the initial parameters from elements
		DataSet<Params> parameters = LinearRegressionData.getDefaultParamsDataSet(env);

		// set number of bulk iterations for BGD linear regression
		IterativeDataSet<Params> loop = parameters.iterate(iterations);

		DataSet<Params> newParameters = data
				// compute a single step using every sample
				.map(new SubUpdate(learningRate)).withBroadcastSet(loop, "parameters")
				// sum up all the steps
				.reduce(new UpdateAccumulator())
				// average the steps and update all parameters
				.map(new Update());

		// feed new parameters back into next iteration
		DataSet<Params> result = loop.closeWith(newParameters);

		// emit result
		if (params.has("output")) {
			result.writeAsText(params.get("output"));
			// execute program
			env.execute("Linear Regression example");
		} else {
			System.out.println("Printing result to stdout. Use --output to specify output path.");
			result.print();
		}
	}

	// *************************************************************************
	//     DATA TYPES
	// *************************************************************************

	/**
	 * A simple data sample, x means the input, and y means the target.
	 */
	public static class Data implements Serializable {

		public double x, y;

		public Data() {}

		public Data(double x, double y) {
			this.x = x;
			this.y = y;
		}

		@Override
		public String toString() {
			return "(" + x + "|" + y + ")";
		}
	}

	/**
	 * A set of parameters -- theta0, theta1.
	 */
	public static class Params implements Serializable {

		private double theta0, theta1;

		public Params() {}

		public Params(double x0, double x1) {
			this.theta0 = x0;
			this.theta1 = x1;
		}

		@Override
		public String toString() {
			return theta0 + " " + theta1;
		}

		public double getTheta0() {
			return theta0;
		}

		public double getTheta1() {
			return theta1;
		}

		public void setTheta0(double theta0) {
			this.theta0 = theta0;
		}

		public void setTheta1(double theta1) {
			this.theta1 = theta1;
		}

		/**
		 * Divides both parameters by {@code a} IN PLACE.
		 *
		 * @param a the divisor (unboxed, so {@code null} raises a NullPointerException)
		 * @return this instance, after modification
		 */
		public Params div(Integer a) {
			this.theta0 = theta0 / a;
			this.theta1 = theta1 / a;
			return this;
		}
	}

	// *************************************************************************
	//     USER FUNCTIONS
	// *************************************************************************

	/**
	 * Computes a single BGD-type update of every parameter for one data sample.
	 *
	 * <p>For a sample (x, y) and current parameters (theta0, theta1), with
	 * {@code error = theta0 + theta1 * x - y}, the emitted candidate is
	 * {@code (theta0 - alpha * error, theta1 - alpha * error * x)}, paired with a count of 1
	 * so that the candidates can later be summed and averaged.
	 */
	public static class SubUpdate extends RichMapFunction<Data, Tuple2<Params, Integer>> {

		/** Learning rate used by the no-argument constructor. */
		public static final double DEFAULT_LEARNING_RATE = 0.01;

		/** Step size alpha of the gradient update. */
		private final double learningRate;

		private Collection<Params> parameters;

		/** Creates an update function using {@link #DEFAULT_LEARNING_RATE}. */
		public SubUpdate() {
			this(DEFAULT_LEARNING_RATE);
		}

		/**
		 * Creates an update function with the given learning rate.
		 *
		 * @param learningRate the step size; must be a positive number
		 * @throws IllegalArgumentException if {@code learningRate} is not positive (or is NaN)
		 */
		public SubUpdate(double learningRate) {
			if (!(learningRate > 0)) {
				throw new IllegalArgumentException("Learning rate must be positive, but was " + learningRate);
			}
			this.learningRate = learningRate;
		}

		/** Reads the parameters from a broadcast variable into a collection. */
		@Override
		public void open(Configuration parameters) throws Exception {
			this.parameters = getRuntimeContext().getBroadcastVariable("parameters");
		}

		@Override
		public Tuple2<Params, Integer> map(Data in) throws Exception {
			Params current = currentParameters();
			double error = (current.theta0 + (current.theta1 * in.x)) - in.y;
			double theta0 = current.theta0 - learningRate * error;
			double theta1 = current.theta1 - learningRate * (error * in.x);
			return new Tuple2<Params, Integer>(new Params(theta0, theta1), 1);
		}

		/**
		 * Returns the current parameters from the broadcast set (the last element, matching the
		 * original iteration order). The check is done lazily here rather than in open() so that
		 * an empty input data set, which never invokes map(), does not fail.
		 */
		private Params currentParameters() {
			Params current = null;
			for (Params p : parameters) {
				current = p;
			}
			if (current == null) {
				throw new IllegalStateException("Broadcast variable 'parameters' is empty.");
			}
			return current;
		}
	}

	/**
	 * Accumulates all the updates: sums the candidate parameters and their counts.
	 */
	public static class UpdateAccumulator implements ReduceFunction<Tuple2<Params, Integer>> {

		@Override
		public Tuple2<Params, Integer> reduce(Tuple2<Params, Integer> val1, Tuple2<Params, Integer> val2) {
			double newTheta0 = val1.f0.theta0 + val2.f0.theta0;
			double newTheta1 = val1.f0.theta1 + val2.f0.theta1;
			Params result = new Params(newTheta0, newTheta1);
			return new Tuple2<Params, Integer>(result, val1.f1 + val2.f1);
		}
	}

	/**
	 * Computes the final update by averaging the accumulated candidates.
	 *
	 * <p>A new {@link Params} instance is returned; the input record is left untouched.
	 */
	public static class Update implements MapFunction<Tuple2<Params, Integer>, Params> {

		@Override
		public Params map(Tuple2<Params, Integer> arg0) throws Exception {
			Params sum = arg0.f0;
			int count = arg0.f1;
			// double / int -> double division, same arithmetic as Params.div
			return new Params(sum.theta0 / count, sum.theta1 / count);
		}
	}
}