/***********************************************************************************************************************
*
* Copyright (C) 2010-2014 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*
**********************************************************************************************************************/
package hu.sztaki.stratosphere.workshop.batch.customals;
import eu.stratosphere.api.java.IterativeDataSet;
import eu.stratosphere.core.fs.FileSystem;
import eu.stratosphere.api.java.DataSet;
import eu.stratosphere.api.java.ExecutionEnvironment;
import eu.stratosphere.api.java.operators.DataSink;
import eu.stratosphere.api.java.tuple.Tuple3;
// Parameters: [numSubTasks] [matrix] [output] [rank] [numberOfIterations] [lambda]
/**
 * Driver for the custom ALS batch job: builds a Stratosphere data-flow plan that repeatedly
 * recomputes two factor data sets (called P and Q here) from an input matrix and writes both
 * of them below an output directory.
 *
 * <p>Command line: {@code <numSubTasks> <matrix> <output> <rank> <numberOfIterations> <lambda>}.
 * When started without arguments, a bundled sample data set and default settings are used.
 */
public class CustomALS {

    /** Classpath resource used as input when no command-line arguments are given. */
    // An alternative sample, "/testdata/als_batch/sampledb3.csv", is referenced by the original code.
    private static final String DEFAULT_INPUT_RESOURCE = "/testdata/als_batch/sampledb2.csv";

    /** Defaults used by {@link #main(String[])} when no command-line arguments are given. */
    private static final int DEFAULT_NUM_SUB_TASKS = 1;
    private static final int DEFAULT_RANK = 5;
    private static final int DEFAULT_NUM_ITERATIONS = 3;
    private static final double DEFAULT_LAMBDA = 0.1;

    /** Number of command-line arguments needed to override the defaults. */
    private static final int EXPECTED_ARG_COUNT = 6;

    private static final String USAGE =
            "Usage: CustomALS <numSubTasks> <matrix> <output> <rank> <numberOfIterations> <lambda>";

    /**
     * Assembles and runs the job.
     *
     * <p>The plan reads the matrix, derives a row-wise and a column-wise partitioned copy of it,
     * creates a random initial Q, and then runs a bulk iteration in which every step computes P
     * from Q and the next Q from that P. After the iteration P is computed once more from the
     * final Q, so that the two written results belong together.
     *
     * @param numSubTasks   degree of parallelism of the job; also handed to the partitioning,
     *                      initialization, iteration and output-formatting functions
     * @param matrixInput   path of the input file; records have three fields separated by
     *                      {@code '|'} and each line ends with {@code "|\n"}
     * @param output        base output path; results are written to {@code output + "/p"} and
     *                      {@code output + "/q"}, overwriting existing files
     * @param k             handed to {@code RandomMatrix} and {@code Iteration}; the file header
     *                      calls this parameter the rank
     * @param numIterations number of bulk iteration steps
     * @param lambda        handed to {@code Iteration}; presumably the regularization weight
     *                      (TODO confirm against {@code Iteration})
     * @throws Exception if the execution of the plan fails
     */
    public static void executeALS(int numSubTasks, String matrixInput, String output, int k,
            int numIterations, double lambda) throws Exception {
        final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        // input rating matrix
        DataSet<MatrixEntry> matrixSource = env.readCsvFile(matrixInput)
                .fieldDelimiter('|').lineDelimiter("|\n").includeFields(true, true, true)
                .tupleType(MatrixEntry.class);

        // create the rowwise partition of A for machines
        DataSet<Partition<MatrixEntry>> rowPartitionA = matrixSource
                .map(new MultiplyMatrix(numSubTasks, 0)).name("MultiplyingMatrixRows");

        // create the columnwise partition of A for machines
        DataSet<Partition<MatrixEntry>> colPartitionA = matrixSource
                .map(new MultiplyMatrix(numSubTasks, 1)).name("MultiplyingMatrixColumns");

        // random initial Q, opened as the head of the bulk iteration
        IterativeDataSet<Partition<MatrixLine>> initialQ = matrixSource
                .groupBy(1)
                .reduceGroup(new RandomMatrix(numSubTasks, k))
                .name("Create q as a random matrix")
                .iterate(numIterations);

        // one iteration step: P from the current Q, then the next Q from that P
        DataSet<Partition<MatrixLine>> p = rowPartitionA.coGroup(initialQ)
                .where(0)
                .equalTo(0)
                .with(new Iteration(numSubTasks, k, lambda, 0))
                .name("P iteration");

        DataSet<Partition<MatrixLine>> nextQ = colPartitionA.coGroup(p)
                .where(0)
                .equalTo(0)
                .with(new Iteration(numSubTasks, k, lambda, 1))
                .name("Q iteration");

        DataSet<Partition<MatrixLine>> q = initialQ.closeWith(nextQ);

        // The P computed inside the loop is not an iteration result, so P is derived
        // once more from the final Q outside of the loop.
        p = rowPartitionA.coGroup(q)
                .where(0)
                .equalTo(0)
                .with(new Iteration(numSubTasks, k, lambda, 0))
                .name("P iteration");

        // delete marker fields
        DataSet<MatrixLine> pOutFormat = p.groupBy(0)
                .reduceGroup(new OutputFormatter(numSubTasks)).name("P output format");
        DataSet<MatrixLine> qOutFormat = q.groupBy(0)
                .reduceGroup(new OutputFormatter(numSubTasks)).name("Q output format");

        // output
        ColumnOutputFormat pFormat = new ColumnOutputFormat(output + "/p");
        pFormat.setWriteMode(FileSystem.WriteMode.OVERWRITE);
        pOutFormat.output(pFormat);

        ColumnOutputFormat qFormat = new ColumnOutputFormat(output + "/q");
        qFormat.setWriteMode(FileSystem.WriteMode.OVERWRITE);
        qOutFormat.output(qFormat);

        env.setDegreeOfParallelism(numSubTasks);
        env.execute("CustomALS");
    }

    /**
     * Command-line entry point.
     *
     * <p>With at least six arguments the job is configured from them, in the order
     * {@code numSubTasks, matrix, output, rank, numberOfIterations, lambda}. With fewer
     * arguments the bundled sample input and the default settings are used; if some (but too
     * few) arguments were given, a usage hint is printed to standard error first.
     *
     * @param args optional job parameters, see above
     * @throws NumberFormatException if a numeric argument cannot be parsed
     * @throws IllegalStateException if the bundled sample input is needed but is not on the
     *                               classpath
     * @throws Exception             if the job itself fails
     */
    public static void main(String[] args) throws Exception {
        if (args.length >= EXPECTED_ARG_COUNT) {
            executeALS(
                    Integer.parseInt(args[0]),
                    args[1],
                    args[2],
                    Integer.parseInt(args[3]),
                    Integer.parseInt(args[4]),
                    Double.parseDouble(args[5]));
            return;
        }

        if (args.length > 0) {
            System.err.println(USAGE);
            System.err.println("Expected " + EXPECTED_ARG_COUNT + " arguments but got " + args.length
                    + "; running with the built-in sample configuration.");
        }

        // getResource returns null for a missing resource; fail here with a clear message
        // instead of handing the path "file://null" to the job.
        java.net.URL sampleInput = CustomALS.class.getResource(DEFAULT_INPUT_RESOURCE);
        if (sampleInput == null) {
            throw new IllegalStateException(
                    "Sample input " + DEFAULT_INPUT_RESOURCE + " was not found on the classpath");
        }

        // NOTE(review): the URL's string form normally already carries a scheme, so this yields
        // a path of the form "file://file:/..."; kept unchanged from the original - confirm that
        // the file system layer is meant to receive it this way.
        String matrixInput = "file://" + sampleInput;
        String output = "file:///" + System.getProperty("user.dir") + "/als_custom_output";

        executeALS(DEFAULT_NUM_SUB_TASKS, matrixInput, output, DEFAULT_RANK,
                DEFAULT_NUM_ITERATIONS, DEFAULT_LAMBDA);
    }
}