import mpi.MPI;
import mpi.MPIException;
import mpi.MpiOps;
import org.apache.commons.cli.*;
import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
/**
 * Calculates a pairwise value matrix for every vector file in a folder (or for a single file)
 * and writes it, scaled to the {@code short} range, to a file of the same name in the output folder.
 *
 * <p>The lower triangle of the matrix is split into row ranges ({@link Work} items) that are handed
 * to a fixed set of worker threads through a bounded queue. The submitting thread waits on a
 * {@link CountDownLatch} sized to the number of work items before it reads the results.
 *
 * <p>Thread-safety: an instance is meant to be driven by a single caller thread; only the internal
 * worker threads run concurrently with it.
 */
public class PWeightCalculator {
    /** Maximum number of pending work items; {@code put} blocks when the queue is full. */
    private static final int WORK_QUEUE_CAPACITY = 64;
    /** How long an idle worker waits for work before re-checking the {@link #run} flag. */
    private static final long WORKER_POLL_MILLIS = 100;

    private String vectorFolder;
    private String distFolder;
    private boolean mpi = false;
    private MpiOps mpiOps;
    private int distanceType;
    private boolean sharedInput = false;
    private int threads;
    private BlockingQueue<Work> workQueue = new ArrayBlockingQueue<Work>(WORK_QUEUE_CAPACITY);
    // volatile: written by the submitting thread and read by the workers; without it a worker
    // is not guaranteed to ever observe the shutdown request.
    private volatile boolean run = true;

    /**
     * @param vectorFolder folder containing the input vector files
     * @param distFolder   folder the output files are written to
     * @param mpi          whether MPI is used to obtain rank and size
     * @param distanceType type code passed to {@code VectorPoint.correlation}
     * @param sharedInput  when combined with {@code mpi}, input files are split round-robin by rank
     * @param threads      number of worker threads; values below 1 are treated as 1
     */
    public PWeightCalculator(String vectorFolder, String distFolder, boolean mpi, int distanceType, boolean sharedInput, int threads) {
        this.vectorFolder = vectorFolder;
        this.distFolder = distFolder;
        this.mpi = mpi;
        this.distanceType = distanceType;
        this.sharedInput = sharedInput;
        // A non-positive count previously failed later (division by zero / invalid latch count).
        this.threads = Math.max(1, threads);
    }

    /**
     * Command line entry point. Options: {@code -v} input folder, {@code -d} output folder,
     * {@code -n} normalize (only echoed), {@code -m} MPI, {@code -t} distance type,
     * {@code -s} shared input, {@code -f} single file, {@code -tn} threads (defaults to 1).
     *
     * @throws InterruptedException if the calling thread is interrupted in single-file mode
     */
    public static void main(String[] args) throws InterruptedException {
        Options options = new Options();
        options.addOption("v", true, "Input Vector folder");
        options.addOption("d", true, "Distance matrix folder");
        options.addOption("n", false, "normalize");
        options.addOption("m", false, "mpi");
        options.addOption("t", true, "distance type");
        options.addOption("s", false, "shared input directory");
        options.addOption(Utils.createOption("f", true, "Single calc", false));
        options.addOption(Utils.createOption("tn", true, "Threads", false));
        CommandLineParser commandLineParser = new BasicParser();

        CommandLine cmd;
        int distanceType;
        int threads;
        try {
            cmd = commandLineParser.parse(options, args);
            // parseInt(null) throws NumberFormatException, so a missing -t is reported as usage error.
            distanceType = Integer.parseInt(cmd.getOptionValue("t"));
            String threadsValue = cmd.getOptionValue("tn");
            threads = threadsValue == null ? 1 : Integer.parseInt(threadsValue);
        } catch (ParseException | NumberFormatException e) {
            e.printStackTrace();
            System.out.println(options.toString());
            return;
        }

        String _vectorFile = cmd.getOptionValue("v");
        String _distFile = cmd.getOptionValue("d");
        boolean _normalize = cmd.hasOption("n");
        boolean mpi = cmd.hasOption("m");
        boolean sharedInput = cmd.hasOption("s");
        String singleFile = cmd.getOptionValue("f");

        if (singleFile != null) {
            PWeightCalculator program = new PWeightCalculator(_vectorFile, _distFile, mpi, distanceType, sharedInput, threads);
            // processFile blocks until workers finish its work items, so workers must be running.
            program.startWorkers();
            try {
                program.processFile(new File(singleFile));
            } finally {
                program.stopWorkers();
            }
            return;
        }

        String print = "vector: " + _vectorFile + " ,distance matrix folder: "
                + _distFile + " ,normalize: "
                + _normalize + " ,mpi: " + mpi
                + " ,distance type: " + distanceType
                + " ,shared input: " + sharedInput;
        System.out.println(print);
        try {
            if (mpi) {
                MPI.Init(args);
            }
            PWeightCalculator program = new PWeightCalculator(_vectorFile, _distFile, mpi, distanceType, sharedInput, threads);
            program.process();
            if (mpi) {
                MPI.Finalize();
            }
        } catch (MPIException e) {
            e.printStackTrace();
            System.out.println(options.toString());
        }
    }

    /**
     * Processes every file of {@link #vectorFolder}. With MPI and shared input, the sorted file
     * list is distributed round-robin and only the files matching this rank are processed.
     *
     * @throws RuntimeException if MPI communication fails, the folder cannot be listed, or the
     *                          thread is interrupted
     */
    private void process() {
        System.out.println("Starting Distance calculator...");
        File inFolder = new File(vectorFolder);
        if (!inFolder.isDirectory()) {
            System.out.println("In should be a folder: " + vectorFolder);
            return;
        }
        // create the out directory
        Utils.createDirectory(distFolder);
        int rank = 0;
        int size = 0;
        try {
            if (mpi) {
                mpiOps = new MpiOps();
                rank = mpiOps.getRank();
                size = mpiOps.getSize();
            }
            // listFiles returns null on an I/O error even when the path is a directory.
            File[] entries = inFolder.listFiles();
            if (entries == null) {
                throw new RuntimeException("Failed to list files in: " + vectorFolder);
            }
            List<File> list = new ArrayList<File>();
            Collections.addAll(list, entries);
            Collections.sort(list);

            List<File> files = new ArrayList<>();
            if (mpi && sharedInput) {
                // round-robin: the file at position p belongs to rank (p mod size)
                int slot = 0;
                for (File candidate : list) {
                    if (slot == rank) {
                        files.add(candidate);
                    }
                    slot++;
                    if (slot == size) {
                        slot = 0;
                    }
                }
            } else {
                files.addAll(list);
            }

            startWorkers();
            try {
                for (File file : files) {
                    processFile(file);
                }
            } finally {
                // Always release the (non-daemon) workers, also when a file fails.
                stopWorkers();
            }
            System.out.println("Distance calculator finished...");
        } catch (MPIException e) {
            throw new RuntimeException("Failed to communicate", e);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }

    /** Starts {@link #threads} worker threads that consume {@link #workQueue}. */
    private void startWorkers() {
        run = true;
        for (int i = 0; i < threads; i++) {
            Thread t = new Thread(new PartitionWorker(distanceType));
            t.start();
        }
    }

    /** Asks the workers to exit; each one stops after its current poll or work item. */
    private void stopWorkers() {
        run = false;
    }

    /**
     * Computes the full matrix for one vector file and writes it to the output folder.
     * Directories are skipped. Worker threads must already be running.
     *
     * @throws InterruptedException if interrupted while queueing work or waiting for the workers
     * @throws RuntimeException     if a worker failed or a scaled value falls outside [0, 1]
     */
    private void processFile(File fileEntry) throws InterruptedException {
        long start = System.currentTimeMillis();
        if (fileEntry.isDirectory()) {
            return;
        }
        String outFileName = distFolder + "/" + fileEntry.getName();
        System.out.println("Calculator vector file: " + fileEntry.getAbsolutePath() + " Output: " + outFileName);
        WriterWrapper writer = new WriterWrapper(outFileName, false);
        try {
            List<VectorPoint> vectors = Utils.readVectors(fileEntry, 0, Integer.MAX_VALUE);
            int lineCount = vectors.size();
            double[][] values = new double[lineCount][lineCount];

            List<Work> works = partitionRows(lineCount, vectors, values);
            // blocks until every work item has been processed
            assignWorks(works);

            // Double.MIN_VALUE is the smallest positive double: it acts as a floor for the
            // maximum and keeps the division below from becoming 0/0 when all values are 0.
            double globalMax = Double.MIN_VALUE;
            for (Work work : works) {
                if (work.failure != null) {
                    throw new RuntimeException("Failed to calculate distances for: "
                            + fileEntry.getAbsolutePath(), work.failure);
                }
                if (work.localMax > globalMax) {
                    globalMax = work.localMax;
                }
            }

            // workers only fill the lower triangle; mirror it into the upper one
            for (int i = 0; i < values.length; i++) {
                for (int j = i + 1; j < values[i].length; j++) {
                    values[i][j] = values[j][i];
                }
            }

            // write each row scaled to [0, Short.MAX_VALUE]
            for (int i = 0; i < vectors.size(); i++) {
                for (int j = 0; j < values[i].length; j++) {
                    double doubleValue = values[i][j] / globalMax;
                    if (doubleValue < 0) {
                        throw new RuntimeException("Invalid distance");
                    } else if (doubleValue > 1) {
                        throw new RuntimeException("Invalid distance");
                    }
                    short shortValue = (short) (doubleValue * Short.MAX_VALUE);
                    writer.writeShort(shortValue);
                }
                writer.line();
            }
        } finally {
            writer.close();
        }
        long end = System.currentTimeMillis();
        System.out.println("Time: " + (end - start) / 1000);
    }

    /**
     * Splits rows {@code [0, lineCount)} into consecutive ranges. A range is closed once the
     * lower-triangle cells accumulated in it exceed {@code (lineCount^2 / 2) / threads}; the
     * remaining rows form the last range. The number of ranges may differ from {@link #threads}.
     *
     * @return the work items in row order; empty when {@code lineCount} is 0
     */
    private List<Work> partitionRows(int lineCount, List<VectorPoint> vectorPoints, double[][] values) {
        List<Work> works = new ArrayList<>();
        // long arithmetic: lineCount * lineCount overflows int above 46340 rows
        long cellsPerWorker = ((long) lineCount * lineCount / 2) / threads;
        long currentCellCount = 0;
        int currentStart = 0;
        for (int i = 0; i < lineCount; i++) {
            if (currentCellCount > cellsPerWorker) {
                works.add(new Work(currentStart, i, currentCellCount, vectorPoints, values));
                currentCellCount = 0;
                currentStart = i;
            }
            // row i contributes cells for columns 0..i
            currentCellCount += (i + 1);
        }
        if (currentStart < lineCount) {
            works.add(new Work(currentStart, lineCount, currentCellCount, vectorPoints, values));
        }
        return works;
    }

    /**
     * Queues the given work items and waits until the workers have finished all of them.
     * The latch is sized to the number of items (not the number of threads), because workers
     * count down once per item.
     *
     * @throws InterruptedException if interrupted while queueing or waiting
     */
    private void assignWorks(List<Work> works) throws InterruptedException {
        CountDownLatch signal = new CountDownLatch(works.size());
        for (Work work : works) {
            // set before put(): the queue hand-off makes the field visible to the worker
            work.signal = signal;
            workQueue.put(work);
            System.out.println(work + " cell count: " + work.cellCount);
        }
        // barrier, wait until the workers finish
        signal.await();
    }

    /**
     * A range of rows {@code [startRow, endRow)} to compute, plus the slots for its result.
     * {@code localMax} and {@code failure} are written by the worker before it counts down the
     * latch and read by the submitter only after {@code await()} returns.
     */
    private static final class Work {
        final int startRow;
        final int endRow;
        /** Number of lower-triangle cells in this range; used for logging only. */
        final long cellCount;
        private final List<VectorPoint> vectorPoints;
        private final double[][] values;
        private CountDownLatch signal;
        /** Largest value computed for this range. */
        private double localMax = Double.MIN_VALUE;
        /** Exception raised while computing this range, or null on success. */
        private RuntimeException failure;

        Work(int startRow, int endRow, long cellCount, List<VectorPoint> vectorPoints, double[][] values) {
            this.startRow = startRow;
            this.endRow = endRow;
            this.cellCount = cellCount;
            this.vectorPoints = vectorPoints;
            this.values = values;
        }

        @Override
        public String toString() {
            return "Work{" +
                    "startRow=" + startRow +
                    ", endRow=" + endRow + " diff: " + (endRow - startRow) + '}';
        }
    }

    /** Worker loop: polls the queue and fills the lower-triangle cells of each work item. */
    private class PartitionWorker implements Runnable {
        private int type;

        public PartitionWorker(int type) {
            this.type = type;
        }

        @Override
        public void run() {
            try {
                while (run) {
                    // timed poll so the loop re-checks the run flag while idle
                    Work work = workQueue.poll(WORKER_POLL_MILLIS, TimeUnit.MILLISECONDS);
                    if (work == null) {
                        continue;
                    }
                    System.out.println("Running worker: " + work);
                    try {
                        work.localMax = computeRows(work);
                    } catch (RuntimeException e) {
                        // reported to the submitter instead of killing this thread
                        work.failure = e;
                    } finally {
                        // always release the submitter, otherwise it would wait forever
                        work.signal.countDown();
                    }
                }
                System.out.println("Worker done...");
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                e.printStackTrace();
            }
        }

        /**
         * Fills {@code values[i][j]} for every row of the work item and {@code j <= i}.
         *
         * @return the largest value written, or {@code Double.MIN_VALUE} if none exceeded it
         */
        private double computeRows(Work work) {
            List<VectorPoint> vectorPoints = work.vectorPoints;
            double[][] values = work.values;
            double max = Double.MIN_VALUE;
            for (int i = work.startRow; i < work.endRow; i++) {
                VectorPoint rowVec = vectorPoints.get(i);
                for (int j = 0; j <= i; j++) {
                    VectorPoint colVec = vectorPoints.get(j);
                    double val = rowVec.correlation(colVec, type);
                    if (val > max) {
                        max = val;
                    }
                    values[i][j] = val;
                }
            }
            return max;
        }
    }
}