/*********************************************************************************************************************** * * Copyright (C) 2010-2014 by the Stratosphere project (http://stratosphere.eu) * * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. * **********************************************************************************************************************/ package hu.sztaki.stratosphere.workshop.batch.customals; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; import Jama.Matrix; import eu.stratosphere.api.java.functions.CoGroupFunction; import eu.stratosphere.api.java.tuple.Tuple2; import eu.stratosphere.util.Collector; public class Iteration extends CoGroupFunction<Partition<MatrixEntry>, Partition<MatrixLine>, Partition<MatrixLine>> { private int k; private int numOfTasks; private double lambda; private int idx; public Iteration(int numTasks, int k, double lambda, int index) { this.numOfTasks = numTasks; this.k = k; this.idx = index; this.lambda = lambda; } @Override public void coGroup(Iterator<Partition<MatrixEntry>> entries, Iterator<Partition<MatrixLine>> lines, Collector<Partition<MatrixLine>> out)throws Exception { Map<Integer, List<Tuple2<Integer, Double>>> matrixElements = new HashMap<Integer, List<Tuple2<Integer, Double>>>(); Map<Integer, double[]> vectors = new HashMap<Integer, double[]>(); if(!lines.hasNext()){ return; } while (entries.hasNext()) { MatrixEntry entry = entries.next().f1; addToMatrixMap(matrixElements, entry); } while(lines.hasNext()){ MatrixLine line = lines.next().f1; addToVectorMap(vectors, line); } for (int recordIndex : matrixElements.keySet()) { //solve the linear equation system corresponding to this machine Matrix p = compute(recordIndex, matrixElements, vectors); writeOutput(p, recordIndex, out); } } private void writeOutput(Matrix p, int recordIndex, Collector<Partition<MatrixLine>> out) { double[] output_elements = new double[k]; for (int i = 0; i < k; ++i) { output_elements[i] = p.get(i, 0); } //TODO: set the element of the output vector and collect it with all machineIDs //Hint: the output has the following format: (machineID,FALSE,recordIndex,ZERO,output_elements) } private Matrix compute(int recordIndex, Map<Integer, List<Tuple2<Integer, Double>>> matrixElements, Map<Integer, double[]> vectors) { double[][] matrix = new double[k][k]; double element = lambda; //Lambda-regularization if(lambda != 0.0) { for(double[] row : matrix) { Arrays.fill(row, element); } } double[][] column = new double[k][1]; List<Tuple2<Integer, Double>> list = matrixElements.get(recordIndex); for (Tuple2<Integer, Double> pair : list) { double rating = pair.f1; double[] vector = vectors.get(pair.f0); for (int i = 0; i < k; ++i) { column[i][0] += rating * vector[i]; } for (int i = 0; i < k; ++i) { for (int j = 0; j <= i; ++j) { matrix[i][j] += vector[i] * vector[j]; } } } for (int i = 0; i < k; ++i) { for (int j = i + 1; j < k; ++j) { matrix[i][j] = matrix[j][i]; } } Matrix a = new Matrix(matrix); Matrix b = new Matrix(column); Matrix p = a./*chol().*/solve(b); return p; } private void addToVectorMap(Map<Integer, double[]> vectors, MatrixLine line) { //TODO: add the given vector's element to the map //Hint: the map contains (columnID, vectorOfTheElements) pairs } private void addToMatrixMap(Map<Integer, List<Tuple2<Integer, Double>>> map, MatrixEntry entry) throws Exception { int recordIndex = entry.getField(idx); int otherIndex = entry.getField(1-idx); double value = entry.getEntry(); //Hint: each IntDoublePair is a (otherIndex,value) pair. A List of these object corresponds to each vector which marked for update and has the recordIndex identifier. //TODO: store the incoming record's fields in the given map, but make sure there is no duplication of the data } }