package skywriting.examples.skyhout.linalg;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.Writer;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.serializer.Serialization;
import org.apache.hadoop.io.serializer.WritableSerialization;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.function.PlusMult;
import skywriting.examples.skyhout.common.SkywritingTaskFileSystem;
import uk.co.mrry.mercator.task.JarTaskLoader;
import uk.co.mrry.mercator.task.Task;
public class ConjugateGradientReduceTask implements Task {
@Override
public void invoke(InputStream[] fis, OutputStream[] fos, String[] args) {
try {
Configuration conf = new Configuration();
conf.setClassLoader(JarTaskLoader.CLASSLOADER);
conf.setClass("io.serializations", WritableSerialization.class, Serialization.class);
SkywritingTaskFileSystem fs = new SkywritingTaskFileSystem(fis, fos, conf);
assert args.length == 2;
// args[0] is the rank of the vector.
int rank = Integer.parseInt(args[0]);
// args[1] is the "first time" flag.
boolean firstRound = Boolean.parseBoolean(args[1]);
// First round, m + 1 inputs: [Ap], b.
// Subsequent rounds, m + 3 inputs: [Ap], x, r and p.
// args[2] is epsilon.
double epsilon = Double.parseDouble(args[2]);
DenseVector aTimesOldP;
DenseVector oldX;
DenseVector oldR;
DenseVector oldP;
if (firstRound) {
aTimesOldP = VectorMerger.mergeInputs(fs, fs.numInputs() - 1, rank);
oldX = new DenseVector(new double[rank]);
oldR = VectorMerger.readSingleVectorFile(fs, new Path("/in/" + (fis.length - 1)));
oldP = oldR.clone();
} else {
aTimesOldP = VectorMerger.mergeInputs(fs, fs.numInputs() - 3, rank);
oldX = VectorMerger.readSingleVectorFile(fs, new Path("/in/" + (fis.length - 3)));
oldR = VectorMerger.readSingleVectorFile(fs, new Path("/in/" + (fis.length - 2)));
oldP = VectorMerger.readSingleVectorFile(fs, new Path("/in/" + (fis.length - 1)));
}
// Computational phase.
double oldRdotOldR = oldR.dotSelf();
double oldPdotAtimesOldP = oldP.dot(aTimesOldP);
double alpha = oldRdotOldR / oldPdotAtimesOldP;
DenseVector newX = (DenseVector) oldX.assign(oldP, new PlusMult(alpha));
DenseVector newR = (DenseVector) oldR.assign(oldP, new PlusMult(-alpha));
boolean converged = newR.norm(2.0) < epsilon;
// 4 outputs: converged?, newX, newR, newP. We don't write newR or newP if converged.
Writer convergedOutput = new OutputStreamWriter(fos[0]);
convergedOutput.write(Boolean.toString(converged));
VectorMerger.writeResultFile(fs, new Path("/out/1"), newX);
if (!converged) {
double beta = oldRdotOldR / newR.dotSelf();
DenseVector newP = (DenseVector) newR.assign(oldP, new PlusMult(beta));
VectorMerger.writeResultFile(fs, new Path("/out/2"), newR);
VectorMerger.writeResultFile(fs, new Path("/out/3"), newP);
}
} catch (IOException ioe) {
throw new RuntimeException(ioe);
}
}
}