/***********************************************************************************************************************
*
* Copyright (C) 2010-2014 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*
**********************************************************************************************************************/
package hu.sztaki.stratosphere.workshop.batch.als;
import eu.stratosphere.api.java.IterativeDataSet;
import eu.stratosphere.core.fs.FileSystem;
import eu.stratosphere.api.java.operators.DataSink;
import eu.stratosphere.api.java.ExecutionEnvironment;
import eu.stratosphere.api.java.DataSet;
import eu.stratosphere.api.java.tuple.Tuple3;
import eu.stratosphere.api.java.tuple.Tuple2;
/**
 * Stratosphere batch job that factorizes a sparse rating matrix into two factor matrices.
 *
 * <p>p is keyed by row id and q by column id. The job alternates between them: with q fixed,
 * p is solved; with p fixed, q is solved.
 */
public class ALS {

    /**
     * Builds and executes the factorization job.
     *
     * <p>The input is read as a text file in which every line holds three '|'-separated fields:
     * row id (Integer), column id (Integer) and value (Double). Each line is terminated by
     * {@code "|\n"}.
     *
     * <p>q starts out random. It is refined inside a bulk iteration of {@code numIterations}
     * supersteps; no termination criterion is supplied. Each superstep solves p for the current
     * q and then solves a new q for that p. After the iteration, p is solved once more against
     * the final q so that the two written matrices belong together.
     *
     * @param numSubTasks   degree of parallelism for the job; must be at least 1
     * @param matrixInput   path/URI of the rating matrix file; must not be null
     * @param output        output directory prefix; must not be null. p is written to
     *                      {@code output + "/p"} and q to {@code output + "/q"}, overwriting
     *                      existing files
     * @param k             passed to RandomMatrix, PIteration and QIteration; presumably the
     *                      number of latent factors per vector (confirm against those classes);
     *                      must be at least 1
     * @param numIterations number of iteration supersteps; must be at least 1
     * @param lambda        passed to PIteration and QIteration; presumably the regularization
     *                      weight (confirm against those classes)
     * @throws IllegalArgumentException if a path is null or a count is below 1
     * @throws Exception                if building or executing the job fails
     */
    public static void executeALS(int numSubTasks, String matrixInput, String output, int k,
            int numIterations, double lambda) throws Exception {
        // Fail fast, before any plan is built. In particular, a null output would otherwise be
        // concatenated into the relative paths "null/p" and "null/q".
        if (matrixInput == null) {
            throw new IllegalArgumentException("matrixInput must not be null");
        }
        if (output == null) {
            throw new IllegalArgumentException("output must not be null");
        }
        if (numSubTasks < 1) {
            throw new IllegalArgumentException("numSubTasks must be at least 1, was " + numSubTasks);
        }
        if (k < 1) {
            throw new IllegalArgumentException("k must be at least 1, was " + k);
        }
        if (numIterations < 1) {
            throw new IllegalArgumentException(
                    "numIterations must be at least 1, was " + numIterations);
        }

        final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        // Input rating matrix as (row id, column id, value) triples.
        DataSet<Tuple3<Integer, Integer, Double>> ratings = env.readCsvFile(matrixInput)
                .fieldDelimiter('|')
                .lineDelimiter("|\n")
                .includeFields(true, true, true)
                .types(Integer.class, Integer.class, Double.class);

        // Initial q, built from the ratings grouped on column id (field 1). Presumably this is
        // one random vector per column -- confirm against RandomMatrix.
        DataSet<Tuple2<Integer, double[]>> randomQ = ratings
                .groupBy(1)
                .reduceGroup(new RandomMatrix(k))
                .name("Create q as a random matrix");

        // One superstep: fixed q -> optimal p -> fixed p -> next q.
        IterativeDataSet<Tuple2<Integer, double[]>> iterationQ = randomQ.iterate(numIterations);
        DataSet<Tuple2<Integer, double[]>> stepP = solveP(ratings, iterationQ, k, lambda);
        DataSet<Tuple2<Integer, double[]>> stepQ = solveQ(ratings, stepP, k, lambda);
        DataSet<Tuple2<Integer, double[]>> q = iterationQ.closeWith(stepQ);

        // The iteration only carries q forward, so p is solved once more against the final q.
        DataSet<Tuple2<Integer, double[]>> p = solveP(ratings, q, k, lambda);

        writeFactor(p, output + "/p");
        writeFactor(q, output + "/q");

        env.setDegreeOfParallelism(numSubTasks);
        env.execute("ALS");
    }

    /**
     * Adds the operators that compute p while q is held fixed.
     *
     * <p>Each q vector is joined onto the ratings of its column (ratings field 1 = q key). The
     * result is then co-grouped with the ratings per row id (field 0) and fed to PIteration.
     */
    private static DataSet<Tuple2<Integer, double[]>> solveP(
            DataSet<Tuple3<Integer, Integer, Double>> ratings,
            DataSet<Tuple2<Integer, double[]>> q, int k, double lambda) {
        DataSet<Tuple3<Integer, Integer, double[]>> multipliedQ = ratings.join(q)
                .where(1)
                .equalTo(0)
                .with(new MultiplyVector())
                .name("Sends the columns of q with multiple keys");
        return ratings.coGroup(multipliedQ)
                .where(0).equalTo(0)
                .with(new PIteration(k, lambda))
                .name("For fixed q calculates optimal p");
    }

    /**
     * Adds the operators that compute q while p is held fixed.
     *
     * <p>Each p vector is joined onto the ratings of its row (ratings field 0 = p key). The
     * result is then co-grouped with the ratings per column id (field 1) and fed to QIteration.
     */
    private static DataSet<Tuple2<Integer, double[]>> solveQ(
            DataSet<Tuple3<Integer, Integer, Double>> ratings,
            DataSet<Tuple2<Integer, double[]>> p, int k, double lambda) {
        DataSet<Tuple3<Integer, Integer, double[]>> multipliedP = ratings.join(p)
                .where(0)
                .equalTo(0)
                .with(new MultiplyVector())
                .name("Sends the rows of p with multiple keys");
        return ratings.coGroup(multipliedP)
                .where(1).equalTo(1)
                .with(new QIteration(k, lambda))
                .name("For fixed p calculates optimal q");
    }

    /** Registers a sink that writes a factor matrix to {@code path} with ColumnOutputFormat, overwriting existing output. */
    private static void writeFactor(DataSet<Tuple2<Integer, double[]>> factor, String path) {
        ColumnOutputFormat format = new ColumnOutputFormat(path);
        format.setWriteMode(FileSystem.WriteMode.OVERWRITE);
        factor.output(format);
    }

    /**
     * Runs {@link #executeALS} on the bundled sample {@code sampledb2.csv}.
     *
     * <p>Settings: degree of parallelism 1, k = 5, 3 iterations, lambda = 0.1. Results are
     * written to {@code als_output} under the current working directory.
     */
    public static void main(String[] args) throws Exception {
        int numSubTasks = 1;
        // "/testdata/als_batch/sampledb3.csv" is another bundled sample that can be used here instead.
        String matrixInput = resourceUri("/testdata/als_batch/sampledb2.csv");
        String output = "file:///" + System.getProperty("user.dir") + "/als_output";
        int k = 5;
        int numIterations = 3;
        double lambda = 0.1;
        executeALS(numSubTasks, matrixInput, output, k, numIterations, lambda);
    }

    /**
     * Returns a {@code file://} URI string for a classpath resource.
     *
     * @throws IllegalStateException if the resource is not on the classpath (instead of the bare
     *                               NullPointerException that {@code getResource(..).getPath()}
     *                               would throw)
     */
    private static String resourceUri(String resource) {
        java.net.URL url = ALS.class.getResource(resource);
        if (url == null) {
            throw new IllegalStateException("Classpath resource not found: " + resource);
        }
        return "file://" + url.getPath();
    }
}