/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.mahout.math.hadoop.solver; import java.io.IOException; import java.util.Map; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.SequenceFile; import org.apache.hadoop.util.Tool; import org.apache.hadoop.util.ToolRunner; import org.apache.mahout.common.AbstractJob; import org.apache.mahout.math.Vector; import org.apache.mahout.math.VectorWritable; import org.apache.mahout.math.hadoop.DistributedRowMatrix; import org.apache.mahout.math.solver.ConjugateGradientSolver; import org.apache.mahout.math.solver.Preconditioner; /** * Distributed implementation of the conjugate gradient solver. More or less, this is just the standard solver * but wrapped with some methods that make it easy to run it on a DistributedRowMatrix. */ public class DistributedConjugateGradientSolver extends ConjugateGradientSolver implements Tool { private Configuration conf; private Map<String, String> parsedArgs; /** * * Runs the distributed conjugate gradient solver programmatically to solve the system (A + lambda*I)x = b. * * @param inputPath Path to the matrix A * @param tempPath Path to scratch output path, deleted after the solver completes * @param numRows Number of rows in A * @param numCols Number of columns in A * @param b Vector b * @param preconditioner Optional preconditioner for the system * @param maxIterations Maximum number of iterations to run, defaults to numCols * @param maxError Maximum error tolerated in the result. If the norm of the residual falls below this, then the * algorithm stops and returns. * @return The vector that solves the system. */ public Vector runJob(Path inputPath, Path tempPath, int numRows, int numCols, Vector b, Preconditioner preconditioner, int maxIterations, double maxError) { DistributedRowMatrix matrix = new DistributedRowMatrix(inputPath, tempPath, numRows, numCols); matrix.setConf(conf); return solve(matrix, b, preconditioner, maxIterations, maxError); } @Override public Configuration getConf() { return conf; } @Override public void setConf(Configuration conf) { this.conf = conf; } @Override public int run(String[] strings) throws Exception { Path inputPath = new Path(parsedArgs.get("--input")); Path outputPath = new Path(parsedArgs.get("--output")); Path tempPath = new Path(parsedArgs.get("--tempDir")); Path vectorPath = new Path(parsedArgs.get("--vector")); int numRows = Integer.parseInt(parsedArgs.get("--numRows")); int numCols = Integer.parseInt(parsedArgs.get("--numCols")); int maxIterations = parsedArgs.containsKey("--maxIter") ? Integer.parseInt(parsedArgs.get("--maxIter")) : numCols; double maxError = parsedArgs.containsKey("--maxError") ? Double.parseDouble(parsedArgs.get("--maxError")) : ConjugateGradientSolver.DEFAULT_MAX_ERROR; Vector b = loadInputVector(vectorPath); Vector x = runJob(inputPath, tempPath, numRows, numCols, b, null, maxIterations, maxError); saveOutputVector(outputPath, x); tempPath.getFileSystem(conf).delete(tempPath, true); return 0; } public DistributedConjugateGradientSolverJob job() { return new DistributedConjugateGradientSolverJob(); } private Vector loadInputVector(Path path) throws IOException { FileSystem fs = path.getFileSystem(conf); SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, conf); VectorWritable value = new VectorWritable(); try { if (!reader.next(new IntWritable(), value)) { throw new IOException("Input vector file is empty."); } return value.get(); } finally { reader.close(); } } private void saveOutputVector(Path path, Vector v) throws IOException { FileSystem fs = path.getFileSystem(conf); SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, path, IntWritable.class, VectorWritable.class); try { writer.append(new IntWritable(0), new VectorWritable(v)); } finally { writer.close(); } } public class DistributedConjugateGradientSolverJob extends AbstractJob { @Override public void setConf(Configuration conf) { DistributedConjugateGradientSolver.this.setConf(conf); } @Override public Configuration getConf() { return DistributedConjugateGradientSolver.this.getConf(); } @Override public int run(String[] args) throws Exception { addInputOption(); addOutputOption(); addOption("numRows", "nr", "Number of rows in the input matrix", true); addOption("numCols", "nc", "Number of columns in the input matrix", true); addOption("vector", "b", "Vector to solve against", true); addOption("lambda", "l", "Scalar in A + lambda * I [default = 0]", "0.0"); addOption("symmetric", "sym", "Is the input matrix square and symmetric?", "true"); addOption("maxIter", "x", "Maximum number of iterations to run"); addOption("maxError", "err", "Maximum residual error to allow before stopping"); DistributedConjugateGradientSolver.this.parsedArgs = parseArguments(args); if (DistributedConjugateGradientSolver.this.parsedArgs == null) { return -1; } else { DistributedConjugateGradientSolver.this.setConf(new Configuration()); return DistributedConjugateGradientSolver.this.run(args); } } } public static void main(String[] args) throws Exception { ToolRunner.run(new DistributedConjugateGradientSolver().job(), args); } }