/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.github.projectflink.als; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Random; import com.github.projectflink.common.als.ALSUtils; import org.apache.flink.api.common.functions.BroadcastVariableInitializer; import org.apache.flink.api.common.functions.GroupReduceFunction; import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.functions.RichMapFunction; import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.ExecutionEnvironment; import org.apache.flink.api.java.io.TypeSerializerInputFormat; import org.apache.flink.api.java.io.TypeSerializerOutputFormat; import org.apache.flink.api.java.operators.IterativeDataSet; import org.apache.flink.api.java.tuple.Tuple1; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.configuration.Configuration; import org.apache.flink.core.fs.FileSystem.WriteMode; import org.apache.flink.core.fs.Path; import org.apache.flink.util.Collector; import org.jblas.FloatMatrix; import org.jblas.SimpleBlas; import org.jblas.Solve; @SuppressWarnings("serial") public class ALSBroadcastJava { private static final String BC_MATRIX_NAME = "matrix"; private static final long RANDOM_SEED = 0xDEADBADC0FFEEL; public static void main(String[] args) throws Exception { final int numLatentFactors; final int numIterations; final double lambda; final String inPath; final String outPath; final String persistencePath; if (args.length < 6) { numLatentFactors = 5; numIterations = 1; lambda = 1.0; persistencePath = null; inPath = null; outPath = null; } else { numLatentFactors = Integer.parseInt(args[0]); numIterations = Integer.parseInt(args[2]); lambda = Double.parseDouble(args[1]); persistencePath = args[3]; inPath = args[4]; outPath = args[5]; } final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); final DataSet<Rating> ratings = (inPath == null) ? env.fromElements(ALSSampleData.RATINGS_TUPLES) : env.readCsvFile(inPath).tupleType(Rating.class); DataSet<Rating> swappedRatings = ratings.map(new MapFunction<Rating, Rating>() { @Override public Rating map(Rating value) { return new Rating(value.getItem(), value.getUser(), value.getRating()); } }); // group the ratings by Item DataSet<Tuple2<Integer, Pair[]>> ratingsPerUser = groupAndCollectAsArray(ratings); // group the ratings by Item DataSet<Tuple2<Integer, Pair[]>> ratingsPerItem = groupAndCollectAsArray(swappedRatings); // use a random ratings matrix DataSet<Factors> initialItemMatrix = generateRandomMatrix( ratings.<Tuple1<Integer>>project(1).distinct(), numLatentFactors, RANDOM_SEED); if(persistencePath != null){ String path; if(!persistencePath.endsWith("/")){ path = persistencePath + "/"; }else{ path = persistencePath; } Path itemMatrixPath = new Path(path + "initialItemMatrix"); Path ratingsPerUserPath = new Path(path + "userRatings"); Path ratingsPerItemPath = new Path(path + "itemRatings"); TypeSerializerOutputFormat<Factors> iiMatrixOF = new TypeSerializerOutputFormat<Factors>(); iiMatrixOF.setOutputFilePath(itemMatrixPath); iiMatrixOF.setWriteMode(WriteMode.OVERWRITE); initialItemMatrix.output(iiMatrixOF); TypeSerializerOutputFormat<Tuple2<Integer, Pair[]>> userRatingsOF = new TypeSerializerOutputFormat<Tuple2<Integer, Pair[]>>(); userRatingsOF.setOutputFilePath(ratingsPerUserPath); userRatingsOF.setWriteMode(WriteMode.OVERWRITE); ratingsPerUser.output(userRatingsOF); TypeSerializerOutputFormat<Tuple2<Integer, Pair[]>> itemRatingsOF = new TypeSerializerOutputFormat<Tuple2<Integer, Pair[]>>(); itemRatingsOF.setOutputFilePath(ratingsPerItemPath); itemRatingsOF.setWriteMode(WriteMode.OVERWRITE); ratingsPerItem.output(itemRatingsOF); env.execute("Preprocessing"); TypeSerializerInputFormat<Factors> iiMatrixIF = new TypeSerializerInputFormat<Factors>(initialItemMatrix.getType()); iiMatrixIF.setFilePath(itemMatrixPath); initialItemMatrix = env.createInput(iiMatrixIF, initialItemMatrix.getType()); TypeSerializerInputFormat<Tuple2<Integer, Pair[]>> userRatingsIF = new TypeSerializerInputFormat<Tuple2<Integer, Pair[]>>(ratingsPerUser.getType()); userRatingsIF.setFilePath(ratingsPerUserPath); ratingsPerUser = env.createInput(userRatingsIF, ratingsPerUser.getType()); TypeSerializerInputFormat<Tuple2<Integer, Pair[]>> itemRatingsIF = new TypeSerializerInputFormat<Tuple2<Integer, Pair[]>>(ratingsPerItem.getType()); itemRatingsIF.setFilePath(ratingsPerItemPath); ratingsPerItem = env.createInput(itemRatingsIF, ratingsPerItem.getType()); } IterativeDataSet<Factors> itemsIteration = initialItemMatrix.iterate(numIterations); DataSet<Factors> userMatrix = ratingsPerUser.map(new Solver(numLatentFactors, lambda, BC_MATRIX_NAME)).withBroadcastSet(itemsIteration, BC_MATRIX_NAME); DataSet<Factors> itemsMatrix = ratingsPerItem.map(new Solver(numLatentFactors, lambda, BC_MATRIX_NAME)).withBroadcastSet(userMatrix, BC_MATRIX_NAME); DataSet<Factors> itemsResult = itemsIteration.closeWith(itemsMatrix); if(persistencePath != null){ String path = persistencePath; if(!persistencePath.endsWith("/")){ path += "/"; } Path itemsResultPath = new Path(path + "itemsMatrix"); TypeSerializerOutputFormat iOF = new TypeSerializerOutputFormat(); iOF.setOutputFilePath(itemsResultPath); iOF.setWriteMode(WriteMode.OVERWRITE); itemsResult.output(iOF); env.execute("Post iteration"); TypeSerializerInputFormat iIF = new TypeSerializerInputFormat(itemsResult.getType()); iIF.setFilePath(itemsResultPath); itemsResult = env.createInput(iIF, itemsResult.getType()); } DataSet<Factors> usersResult = ratingsPerUser.map(new Solver(numLatentFactors, lambda, BC_MATRIX_NAME)).withBroadcastSet(itemsResult, BC_MATRIX_NAME); if (outPath == null) { usersResult.print(); itemsResult.print(); } else { String path = outPath; if(!outPath.endsWith("/")){ path += "/"; } String usersResultOutPath = path + "usersResult"; String itemsResultOutPath = path + "itemsResult"; usersResult.writeAsText(usersResultOutPath, WriteMode.OVERWRITE); itemsResult.writeAsText(itemsResultOutPath, WriteMode.OVERWRITE); // itemsResult.writeAsCsv(outPath, WriteMode.OVERWRITE); } // System.out.println(env.getExecutionPlan()); env.execute("ALS Broadcast"); } // -------------------------------------------------------------------------------------------- // Utility Methods // -------------------------------------------------------------------------------------------- private static DataSet<Factors> generateRandomMatrix(DataSet<Tuple1<Integer>> ids, final int numFactors, final long seed) { return ids.map(new MapFunction<Tuple1<Integer>, Factors>() { private final Random rnd = new Random(seed); @Override public Factors map(Tuple1<Integer> value) { float[] vals = new float[numFactors]; for (int i = 0; i < numFactors; i++) { vals[i] = rnd.nextFloat(); } return new Factors(value.f0, vals); } }); } private static DataSet<Tuple2<Integer, Pair[]>> groupAndCollectAsArray (DataSet<Rating> ratings) { return ratings.groupBy(0).reduceGroup(new GroupReduceFunction<Rating, Tuple2<Integer, Pair[]>>() { private final List<Pair> list = new ArrayList<Pair>(); @Override public void reduce(Iterable<Rating> values, Collector<Tuple2<Integer, Pair[]>> out) throws Exception { int userID = 0; for (Rating t : values) { userID = t.getUser(); list.add(new Pair(t.getItem(), t.getRating())); } out.collect(new Tuple2<Integer, Pair[]>(userID, list.toArray(new Pair[list.size ()]))); list.clear(); } }); } // -------------------------------------------------------------------------------------------- // Custom Data Types // -------------------------------------------------------------------------------------------- public static final class Rating extends Tuple3<Integer, Integer, Float> { public Rating() {} public Rating(Integer user, Integer item, Float rating) { super(user, item, rating); } public Integer getUser() { return f0; } public Integer getItem() { return f1; } public Float getRating() { return f2; } public void setUser(Integer value) { f0 = value; } public void setItem(Integer value) { f1 = value; } public void setRating(Float value) { f2 = value; } } public static final class Pair extends Tuple2<Integer, Float> { public Pair() {} public Pair(Integer item, Float rating) { super(item, rating); } public Integer getItem() { return f0; } public Float getRating() { return f1; } public void setItem(Integer value){ f0 = value; } public void setRating(Float value){ f1 = value; } } public static final class Factors extends Tuple2<Integer, float[]> { public Factors() {} public Factors(Integer id, float[] factors) { super(id, factors); } public Integer getId() { return f0; } public float[] getFactors() { return f1; } @Override public String toString(){ return "(" + f0 + ", " + Arrays.toString(f1) + ")"; } } public static class Solver extends RichMapFunction<Tuple2<Integer, Pair[]>, Factors>{ private final int numFactors; private final double lambda; private final String bcVarName; private final int triangleSize; private final FloatMatrix xtx; private final FloatMatrix vector; private final FloatMatrix fullMatrix; private Map<Integer, FloatMatrix> matrix = null; public Solver(int numFactors, double lambda, final String bcVarName) { this.numFactors = numFactors; this.lambda = lambda; this.bcVarName = bcVarName; triangleSize = (numFactors*numFactors - numFactors)/2 + numFactors; xtx = FloatMatrix.zeros(triangleSize); vector = FloatMatrix.zeros(numFactors); fullMatrix = FloatMatrix.zeros(numFactors, numFactors); } @Override public void open(Configuration parameters){ matrix = getRuntimeContext().getBroadcastVariableWithInitializer(bcVarName, new MatrixBuilder()); } @Override public Factors map(Tuple2<Integer, Pair[]> integerTuple2) throws Exception { xtx.fill(0.0f); vector.fill(0.0f); int n = integerTuple2.f1.length; for(Pair p: integerTuple2.f1){ FloatMatrix v = matrix.get(p.f0); ALSUtils.outerProductInPlace(v, xtx, numFactors); SimpleBlas.axpy(p.f1, v, vector); } ALSUtils.generateFullMatrix(xtx, fullMatrix, numFactors); for(int i =0; i < numFactors; i++){ fullMatrix.data[i*numFactors + i] += (float)(n * lambda); } return new Factors(integerTuple2.f0, Solve.solvePositive(fullMatrix, vector).data); } } public static class MatrixBuilder implements BroadcastVariableInitializer<Factors, Map<Integer, FloatMatrix>>{ @Override public Map<Integer, FloatMatrix> initializeBroadcastVariable(Iterable<Factors> iterable) { Map<Integer, FloatMatrix> matrix = new HashMap<Integer, FloatMatrix>(); for(Factors factors: iterable){ matrix.put(factors.getId(), new FloatMatrix(factors.getFactors())); } return matrix; } } }