/* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.math.hadoop.stochasticsvd;

import java.io.IOException;
import java.util.Arrays;
import java.util.Map;

import com.google.common.io.Closeables;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.VectorWritable;

/**
 * Command-line adapter for {@link SSVDSolver}: runs the stochastic SVD
 * decomposition of a distributed matrix as a sequence of Hadoop jobs.
 */
public class SSVDCli extends AbstractJob {
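
  /*
   * Illustrative invocation, assuming the "mahout ssvd" driver alias; the
   * paths and parameter values here are examples only, not taken from this
   * file:
   *
   *   mahout ssvd -i /data/A -o /data/ssvd-out -k 40 -p 15 -q 1 -t 3
   *
   * On success the first k singular values land in <output>/sigma, and the
   * U and V factors (when requested) are moved under <output>.
   */
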
@Override
public int run(String[] args) throws Exception {
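    // Declare the command-line options understood by this job; all values
    // arrive as strings and are converted to typed values after parsing.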
addInputOption();
addOutputOption();
addOption("rank", "k", "decomposition rank", true);
addOption("oversampling", "p", "oversampling", String.valueOf(15));
addOption("blockHeight",
"r",
"Y block height (must be > (k+p))",
String.valueOf(10000));
addOption("outerProdBlockHeight",
"oh",
"block height of outer products during multiplication, increase for sparse inputs",
String.valueOf(30000));
addOption("abtBlockHeight",
"abth",
"block height of Y_i in ABtJob during AB' multiplication, increase for extremely sparse inputs",
String.valueOf(200000));
addOption("minSplitSize", "s", "minimum split size", String.valueOf(-1));
addOption("computeU", "U", "compute U (true/false)", String.valueOf(true));
addOption("uHalfSigma",
"uhs",
"Compute U as UHat=U x pow(Sigma,0.5)",
String.valueOf(false));
addOption("computeV", "V", "compute V (true/false)", String.valueOf(true));
addOption("vHalfSigma",
"vhs",
"compute V as VHat= V x pow(Sigma,0.5)",
String.valueOf(false));
addOption("reduceTasks",
"t",
"number of reduce tasks (where applicable)",
true);
addOption("powerIter",
"q",
"number of additional power iterations (0..2 is good)",
String.valueOf(0));
addOption("broadcast",
"br",
"whether use distributed cache to broadcast matrices wherever possible",
String.valueOf(true));
addOption(DefaultOptionCreator.overwriteOption().create());
Map<String, String> pargs = parseArguments(args);
if (pargs == null) {
return -1;
}
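
    // Pull typed values out of the parsed-argument map; keys are the long
    // option names prefixed with "--".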
int k = Integer.parseInt(pargs.get("--rank"));
int p = Integer.parseInt(pargs.get("--oversampling"));
int r = Integer.parseInt(pargs.get("--blockHeight"));
int h = Integer.parseInt(pargs.get("--outerProdBlockHeight"));
int abh = Integer.parseInt(pargs.get("--abtBlockHeight"));
int q = Integer.parseInt(pargs.get("--powerIter"));
int minSplitSize = Integer.parseInt(pargs.get("--minSplitSize"));
boolean computeU = Boolean.parseBoolean(pargs.get("--computeU"));
boolean computeV = Boolean.parseBoolean(pargs.get("--computeV"));
boolean cUHalfSigma = Boolean.parseBoolean(pargs.get("--uHalfSigma"));
boolean cVHalfSigma = Boolean.parseBoolean(pargs.get("--vHalfSigma"));
int reduceTasks = Integer.parseInt(pargs.get("--reduceTasks"));
boolean broadcast = Boolean.parseBoolean(pargs.get("--broadcast"));
boolean overwrite =
pargs.containsKey(keyFor(DefaultOptionCreator.OVERWRITE_OPTION));
Configuration conf = getConf();
if (conf == null) {
throw new IOException("No Hadoop configuration present");
}
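
    // Configure the solver with the parsed parameters and run the SSVD
    // pipeline of Hadoop jobs, writing intermediate results under the
    // temp path.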
SSVDSolver solver =
new SSVDSolver(conf,
new Path[] { getInputPath() },
getTempPath(),
r,
k,
p,
reduceTasks);
solver.setMinSplitSize(minSplitSize);
solver.setComputeU(computeU);
solver.setComputeV(computeV);
solver.setcUHalfSigma(cUHalfSigma);
solver.setcVHalfSigma(cVHalfSigma);
solver.setOuterBlockHeight(h);
solver.setAbtBlockHeight(abh);
solver.setQ(q);
solver.setBroadcast(broadcast);
solver.setOverwrite(overwrite);
solver.run();
    // Housekeeping: resolve the filesystem that actually hosts the output
    // path (not necessarily the default one), make sure the output directory
    // exists, and persist the first k singular values to <output>/sigma.
    FileSystem fs = getOutputPath().getFileSystem(conf);
    fs.mkdirs(getOutputPath());
    SequenceFile.Writer sigmaW = null;
    try {
      sigmaW = SequenceFile.createWriter(fs,
                                         conf,
                                         getOutputPath("sigma"),
                                         NullWritable.class,
                                         VectorWritable.class);
      // The solver may compute more than k singular values; keep only the
      // first k, wrapped in a single dense vector.
      Writable sValues =
        new VectorWritable(new DenseVector(Arrays.copyOf(solver.getSingularValues(), k),
                                           true));
      sigmaW.append(NullWritable.get(), sValues);
    } finally {
      // Don't swallow close() failures: an unreported error here could leave
      // a truncated sigma file behind.
      Closeables.close(sigmaW, false);
    }
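
    /*
     * A downstream consumer could read the singular values back roughly like
     * this (a sketch, not part of this class; "outPath" stands for this
     * job's output path):
     *
     *   SequenceFile.Reader reader =
     *     new SequenceFile.Reader(fs, new Path(outPath, "sigma"), conf);
     *   VectorWritable sigma = new VectorWritable();
     *   reader.next(NullWritable.get(), sigma);
     *
     * Below, the U and V outputs (when computed) are moved from the solver's
     * temp area into the job output path.
     */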
    if (computeU) {
      FileStatus[] uFiles = fs.globStatus(new Path(solver.getUPath()));
      if (uFiles != null) {
        for (FileStatus uf : uFiles) {
          if (!fs.rename(uf.getPath(), getOutputPath())) {
            throw new IOException("Unable to move " + uf.getPath()
                + " to " + getOutputPath());
          }
        }
      }
    }
    if (computeV) {
      FileStatus[] vFiles = fs.globStatus(new Path(solver.getVPath()));
      if (vFiles != null) {
        for (FileStatus vf : vFiles) {
          if (!fs.rename(vf.getPath(), getOutputPath())) {
            throw new IOException("Unable to move " + vf.getPath()
                + " to " + getOutputPath());
          }
        }
      }
    }
return 0;
}

  public static void main(String[] args) throws Exception {
    // Propagate the job's exit status (run() returns -1 on bad arguments)
    // to the calling shell.
    System.exit(ToolRunner.run(new SSVDCli(), args));
  }
}