/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.cf.taste.hadoop.als;
import com.google.common.io.Closeables;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.cf.taste.impl.common.RunningAverage;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.map.OpenIntObjectHashMap;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.util.Map;
/**
* <p>Measures the root-mean-squared error of a rating matrix factorization against a test set.</p>
*
* <p>Command line arguments specific to this class are:</p>
*
* <ol>
* <li>--output (path): path where output should go</li>
* <li>--input (path): path containing the test ratings, each line must be userID,itemID,rating</li>
* <li>--userFeatures (path): path to the user feature matrix</li>
* <li>--itemFeatures (path): path to the item feature matrix</li>
* </ol>
*/
public class FactorizationEvaluator extends AbstractJob {

  /**
   * Configuration keys under which the locations of the feature matrices are handed from
   * {@link #run(String[])} to {@link PredictRatingsMapper}. They are private to this job, so they are
   * namespaced with this class' name.
   */
  private static final String USER_FEATURES_PATH = FactorizationEvaluator.class.getName() + ".userFeatures";
  private static final String ITEM_FEATURES_PATH = FactorizationEvaluator.class.getName() + ".itemFeatures";

  /**
   * Command line entry point, runs this job through {@link ToolRunner}.
   *
   * @param args the command line arguments
   * @throws Exception if running the job fails
   */
  public static void main(String[] args) throws Exception {
    ToolRunner.run(new FactorizationEvaluator(), args);
  }

  /**
   * Runs a map-only job that emits one prediction error per usable test rating, then aggregates these
   * errors into the root-mean-squared error and writes it as text to the file obtained from
   * {@code getOutputPath("rmse.txt")}.
   *
   * @param args the command line arguments
   * @return 0 on success, -1 if the arguments could not be parsed or the prediction job failed
   * @throws Exception if the job cannot be run or the result cannot be written
   */
  @Override
  public int run(String[] args) throws Exception {

    addInputOption();
    addOption("userFeatures", null, "path to the user feature matrix", true);
    addOption("itemFeatures", null, "path to the item feature matrix", true);
    addOutputOption();

    Map<String,String> parsedArgs = parseArguments(args);
    if (parsedArgs == null) {
      return -1;
    }

    Path errors = getTempPath("errors");

    Job predictRatings = prepareJob(getInputPath(), errors, TextInputFormat.class, PredictRatingsMapper.class,
        DoubleWritable.class, NullWritable.class, SequenceFileOutputFormat.class);
    predictRatings.getConfiguration().set(USER_FEATURES_PATH, parsedArgs.get("--userFeatures"));
    predictRatings.getConfiguration().set(ITEM_FEATURES_PATH, parsedArgs.get("--itemFeatures"));

    // do not derive an RMSE from the missing or partial output of a failed job
    boolean succeeded = predictRatings.waitForCompletion(true);
    if (!succeeded) {
      return -1;
    }

    // aggregate first, so that no output stream is left open if reading the errors fails
    double rmse = computeRmse(errors);

    FileSystem fs = FileSystem.get(getOutputPath().toUri(), getConf());
    // explicit charset, the result file must not depend on the platform default encoding
    BufferedWriter writer =
        new BufferedWriter(new OutputStreamWriter(fs.create(getOutputPath("rmse.txt")), "UTF-8"));
    boolean threw = true;
    try {
      writer.write(String.valueOf(rmse));
      threw = false;
    } finally {
      // closing flushes the buffer: a failure on close is only swallowed when another exception is
      // already propagating, otherwise it is reported to the caller
      Closeables.close(writer, threw);
    }

    return 0;
  }

  /**
   * Computes the root-mean-squared error from the errors emitted by the prediction job.
   *
   * @param errors directory holding the sequence files whose keys are the individual prediction errors
   * @return the square root of the average of the squared errors
   */
  protected double computeRmse(Path errors) {
    RunningAverage average = new FullRunningAverage();
    for (Pair<DoubleWritable,NullWritable> entry :
        new SequenceFileDirIterable<DoubleWritable, NullWritable>(errors, PathType.LIST, PathFilters.logsCRCFilter(),
            getConf())) {
      DoubleWritable error = entry.getFirst();
      average.addDatum(error.get() * error.get());
    }
    return Math.sqrt(average.getAverage());
  }

  /**
   * Emits the difference between the actual and the estimated rating for every line of the test set
   * whose user and item are both contained in the feature matrices. The estimate is the dot product of
   * the user's and the item's feature vector.
   */
  public static class PredictRatingsMapper extends Mapper<LongWritable,Text,DoubleWritable,NullWritable> {

    /** rows of the user feature matrix, keyed by userID */
    private OpenIntObjectHashMap<Vector> U;
    /** rows of the item feature matrix, keyed by itemID */
    private OpenIntObjectHashMap<Vector> M;

    /**
     * Loads both feature matrices from the paths found in the job configuration.
     */
    @Override
    protected void setup(Context ctx) throws IOException, InterruptedException {
      Path pathToU = new Path(ctx.getConfiguration().get(USER_FEATURES_PATH));
      Path pathToM = new Path(ctx.getConfiguration().get(ITEM_FEATURES_PATH));

      U = ALSUtils.readMatrixByRows(pathToU, ctx.getConfiguration());
      M = ALSUtils.readMatrixByRows(pathToM, ctx.getConfiguration());
    }

    /**
     * Parses one test rating (the first three tokens are read as userID, itemID and rating) and writes
     * {@code rating - estimate} as key. Lines whose user or item has no feature vector are skipped
     * without emitting anything.
     */
    @Override
    protected void map(LongWritable key, Text value, Context ctx) throws IOException, InterruptedException {

      String[] tokens = TasteHadoopUtils.splitPrefTokens(value.toString());
      int userID = Integer.parseInt(tokens[0]);
      int itemID = Integer.parseInt(tokens[1]);
      double rating = Double.parseDouble(tokens[2]);

      if (U.containsKey(userID) && M.containsKey(itemID)) {
        double estimate = U.get(userID).dot(M.get(itemID));
        double err = rating - estimate;
        ctx.write(new DoubleWritable(err), NullWritable.get());
      }
    }
  }
}