package mpi;
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.IntBuffer;
import java.util.Arrays;
import java.util.stream.IntStream;
import mpi.*;
/**
 * Convenience wrapper around an {@link Intracomm} offering scalar, array, string and
 * packet flavours of common MPI operations (all-reduce, all-gather, broadcast,
 * send/receive, barrier).
 *
 * <p>Most operations come in two overloads: one that takes an explicit communicator and
 * one that uses the communicator this instance was constructed with.
 *
 * <p>Scalar overloads stage their value in small buffers held on this instance, so those
 * buffers are overwritten on every call.
 * NOTE(review): because the scratch buffers are shared between calls, concurrent use of a
 * single instance from several threads looks unsafe - confirm before sharing an instance.
 */
public class MpiOps {
    static final String ERR_RECV_UNSUPPORTED_MPIPACKET = "Unsupported receive operation on MPIPacket - Check receive type";

    /** One-element scratch buffer for scalar int operations. */
    private final IntBuffer intBuff;
    /** Second one-element int scratch buffer; receives the peer's extent in sendReceive. */
    private final IntBuffer intBuff2;
    /** One-element scratch buffer for scalar double operations. */
    private final DoubleBuffer doubleBuff;
    /** One-element scratch array for scalar boolean operations. */
    private final boolean[] booleanBuff;
    /** Allocated in the constructor but not used by any method of this class. */
    private final Object[] objectSendBuff, objectRecvBuff;
    /** Communicator used by the overloads that do not take one explicitly. */
    private final Intracomm comm;
    /** Size of {@link #comm}, read once at construction. */
    private final int size;
    /** Rank of this process in {@link #comm}, read once at construction. */
    private final int rank;

    /**
     * Creates a helper bound to the given communicator and caches its size and this
     * process's rank.
     *
     * @param comm communicator used by the overloads that do not take one explicitly
     * @throws MPIException if querying the communicator's size or rank fails
     */
    public MpiOps(Intracomm comm) throws MPIException {
        intBuff = MPI.newIntBuffer(1);
        intBuff2 = MPI.newIntBuffer(1); // temporary int buffer used in sendRecv
        doubleBuff = MPI.newDoubleBuffer(1);
        booleanBuff = new boolean[1];
        objectSendBuff = new Object[1];
        objectRecvBuff = new Object[1];
        this.comm = comm;
        size = comm.getSize();
        rank = comm.getRank();
    }

    /**
     * Creates a helper bound to {@code MPI.COMM_WORLD}.
     *
     * @throws MPIException if querying the communicator's size or rank fails
     */
    public MpiOps() throws MPIException {
        this(MPI.COMM_WORLD);
    }

    /** @return the size of this instance's communicator, as cached at construction */
    public int getSize() {
        return size;
    }

    /** @return this process's rank in this instance's communicator, as cached at construction */
    public int getRank() {
        return rank;
    }

    /** @return the communicator this instance was constructed with */
    public Intracomm getComm() {
        return comm;
    }

    /* AllReduce */

    /**
     * All-reduces a single int over this instance's communicator.
     *
     * @param value    this process's contribution
     * @param reduceOp reduction operation
     * @return the reduced value
     * @throws MPIException if the underlying MPI call fails
     */
    public int allReduce(int value, Op reduceOp) throws MPIException {
        return allReduce(value, reduceOp, comm);
    }

    /**
     * All-reduces a single int over the given communicator.
     *
     * @param value    this process's contribution
     * @param reduceOp reduction operation
     * @param comm     communicator to reduce over
     * @return the reduced value
     * @throws MPIException if the underlying MPI call fails
     */
    public int allReduce(int value, Op reduceOp, Intracomm comm) throws MPIException {
        intBuff.put(0, value);
        comm.allReduce(intBuff, 1, MPI.INT, reduceOp);
        return intBuff.get(0);
    }

    /**
     * All-reduces an int array in place over this instance's communicator.
     *
     * @param values   in/out array; passed to the MPI call as the single buffer
     * @param reduceOp reduction operation
     * @throws MPIException if the underlying MPI call fails
     */
    public void allReduce(int[] values, Op reduceOp) throws MPIException {
        allReduce(values, reduceOp, comm);
    }

    /**
     * All-reduces an int array in place over the given communicator.
     *
     * @param values   in/out array; passed to the MPI call as the single buffer
     * @param reduceOp reduction operation
     * @param comm     communicator to reduce over
     * @throws MPIException if the underlying MPI call fails
     */
    public void allReduce(int[] values, Op reduceOp, Intracomm comm) throws MPIException {
        comm.allReduce(values, values.length, MPI.INT, reduceOp);
    }

    /**
     * All-reduces a single double over this instance's communicator.
     *
     * @param value    this process's contribution
     * @param reduceOp reduction operation
     * @return the reduced value
     * @throws MPIException if the underlying MPI call fails
     */
    public double allReduce(double value, Op reduceOp) throws MPIException {
        return allReduce(value, reduceOp, comm);
    }

    /**
     * All-reduces a single double over the given communicator.
     *
     * @param value    this process's contribution
     * @param reduceOp reduction operation
     * @param comm     communicator to reduce over
     * @return the reduced value
     * @throws MPIException if the underlying MPI call fails
     */
    public double allReduce(double value, Op reduceOp, Intracomm comm) throws MPIException {
        doubleBuff.put(0, value);
        comm.allReduce(doubleBuff, 1, MPI.DOUBLE, reduceOp);
        return doubleBuff.get(0);
    }

    /**
     * All-reduces a double array in place over this instance's communicator.
     *
     * @param values   in/out array; passed to the MPI call as the single buffer
     * @param reduceOp reduction operation
     * @throws MPIException if the underlying MPI call fails
     */
    public void allReduce(double[] values, Op reduceOp) throws MPIException {
        allReduce(values, reduceOp, comm);
    }

    /**
     * All-reduces a double array in place over the given communicator.
     *
     * @param values   in/out array; passed to the MPI call as the single buffer
     * @param reduceOp reduction operation
     * @param comm     communicator to reduce over
     * @throws MPIException if the underlying MPI call fails
     */
    public void allReduce(double[] values, Op reduceOp, Intracomm comm) throws MPIException {
        comm.allReduce(values, values.length, MPI.DOUBLE, reduceOp);
    }

    /**
     * All-reduces a single boolean over this instance's communicator.
     *
     * @param value    this process's contribution
     * @param reduceOp reduction operation
     * @return the reduced value
     * @throws MPIException if the underlying MPI call fails
     */
    public boolean allReduce(boolean value, Op reduceOp) throws MPIException {
        return allReduce(value, reduceOp, comm);
    }

    /**
     * All-reduces a single boolean over the given communicator.
     *
     * @param value    this process's contribution
     * @param reduceOp reduction operation
     * @param comm     communicator to reduce over
     * @return the reduced value
     * @throws MPIException if the underlying MPI call fails
     */
    public boolean allReduce(boolean value, Op reduceOp, Intracomm comm) throws MPIException {
        booleanBuff[0] = value;
        comm.allReduce(booleanBuff, 1, MPI.BOOLEAN, reduceOp);
        return booleanBuff[0];
    }

    /**
     * Gathers every process's string over this instance's communicator and returns them
     * joined into one string.
     *
     * @param value this process's string
     * @return the gathered characters of all processes as a single string
     * @throws MPIException if an underlying MPI call fails
     */
    public String allReduce(String value) throws MPIException {
        return allReduce(value, comm);
    }

    // TODO - Perf - Probably need to check performance
    /**
     * Gathers every process's string over the given communicator and returns them joined
     * into one string. Each process's characters are placed at an offset equal to the sum
     * of the lengths reported by the lower-indexed slots, so the pieces appear in slot
     * order.
     *
     * <p>The size and rank are taken from {@code comm} itself rather than from the values
     * cached for this instance's communicator, so the method is also correct when
     * {@code comm} differs from {@link #getComm()}.
     *
     * @param value this process's string
     * @param comm  communicator to gather over
     * @return the gathered characters of all processes as a single string
     * @throws MPIException if an underlying MPI call fails
     */
    public String allReduce(String value, Intracomm comm) throws MPIException {
        final int commSize = comm.getSize();
        final int commRank = comm.getRank();

        // Step 1: exchange the per-process string lengths (in-place all-gather).
        int[] lengths = new int[commSize];
        int length = value.length();
        lengths[commRank] = length;
        comm.allGather(lengths, 1, MPI.INT);

        // Step 2: displacements are the exclusive prefix sum of the lengths.
        int[] displas = new int[commSize];
        for (int i = 1; i < commSize; ++i) {
            displas[i] = displas[i - 1] + lengths[i - 1];
        }
        int count = displas[commSize - 1] + lengths[commSize - 1];

        // Step 3: place our own characters at our offset and gather the rest in place.
        char[] recv = new char[count];
        System.arraycopy(value.toCharArray(), 0, recv, displas[commRank], length);
        comm.allGatherv(recv, lengths, displas, MPI.CHAR);
        return new String(recv);
    }

    /**
     * All-reduces a value/index pair over this instance's communicator.
     *
     * @param value    this process's contribution; its buffer is reduced in place
     * @param reduceOp {@code MAX_WITH_INDEX} or {@code MIN_WITH_INDEX}
     * @return the same {@code value} object that was passed in
     * @throws MPIException if the underlying MPI call fails
     */
    public MPIReducePlusIndex allReduce(MPIReducePlusIndex value, MPIReducePlusIndex.Op reduceOp) throws MPIException {
        return allReduce(value, reduceOp, comm);
    }

    /**
     * All-reduces a value/index pair over the given communicator, using the buffer
     * returned by {@code value.getBuffer()} as the in-place MPI buffer.
     *
     * <p>If {@code reduceOp} is neither {@code MAX_WITH_INDEX} nor {@code MIN_WITH_INDEX},
     * no MPI call is made and {@code value} is returned untouched.
     *
     * @param value    this process's contribution; its buffer is reduced in place
     * @param reduceOp {@code MAX_WITH_INDEX} or {@code MIN_WITH_INDEX}
     * @param comm     communicator to reduce over
     * @return the same {@code value} object that was passed in
     * @throws MPIException if the underlying MPI call fails
     */
    public MPIReducePlusIndex allReduce(MPIReducePlusIndex value, MPIReducePlusIndex.Op reduceOp, Intracomm comm) throws MPIException {
        ByteBuffer buffer = value.getBuffer();
        if (reduceOp == MPIReducePlusIndex.Op.MAX_WITH_INDEX) {
            comm.allReduce(buffer, MPIReducePlusIndex.extent, MPI.BYTE, MPIReducePlusIndex.getMaxWithIndex());
        } else if (reduceOp == MPIReducePlusIndex.Op.MIN_WITH_INDEX) {
            comm.allReduce(buffer, MPIReducePlusIndex.extent, MPI.BYTE, MPIReducePlusIndex.getMinWithIndex());
        }
        return value;
    }

    /* AllGather */

    /**
     * Gathers one int from every process of this instance's communicator.
     *
     * @param value this process's contribution
     * @return a new array with one slot per process of this instance's communicator
     * @throws MPIException if the underlying MPI call fails
     */
    public int[] allGather(int value) throws MPIException {
        int[] result = new int[size];
        allGather(value, result, comm);
        return result;
    }

    /**
     * Gathers one int from every process of this instance's communicator into
     * {@code result}.
     *
     * @param value  this process's contribution
     * @param result receive array supplied by the caller
     * @throws MPIException if the underlying MPI call fails
     */
    public void allGather(int value, int[] result) throws MPIException {
        allGather(value, result, comm);
    }

    /**
     * Gathers one int from every process of the given communicator into {@code result}.
     *
     * @param value  this process's contribution
     * @param result receive array supplied by the caller; presumably needs one slot per
     *               process of {@code comm} - not checked here
     * @param comm   communicator to gather over
     * @throws MPIException if the underlying MPI call fails
     */
    public void allGather(int value, int[] result, Intracomm comm) throws MPIException {
        intBuff.put(0, value);
        comm.allGather(intBuff, 1, MPI.INT, result, 1, MPI.INT);
    }

    /**
     * Gathers one double from every process of this instance's communicator.
     *
     * @param value this process's contribution
     * @return a new array with one slot per process of this instance's communicator
     * @throws MPIException if the underlying MPI call fails
     */
    public double[] allGather(double value) throws MPIException {
        double[] result = new double[size];
        allGather(value, result, comm);
        return result;
    }

    /**
     * Gathers one double from every process of this instance's communicator into
     * {@code result}.
     *
     * @param value  this process's contribution
     * @param result receive array supplied by the caller
     * @throws MPIException if the underlying MPI call fails
     */
    public void allGather(double value, double[] result) throws MPIException {
        allGather(value, result, comm);
    }

    /**
     * Gathers one double from every process of the given communicator into
     * {@code result}.
     *
     * @param value  this process's contribution
     * @param result receive array supplied by the caller; presumably needs one slot per
     *               process of {@code comm} - not checked here
     * @param comm   communicator to gather over
     * @throws MPIException if the underlying MPI call fails
     */
    public void allGather(double value, double[] result, Intracomm comm) throws MPIException {
        doubleBuff.put(0, value);
        comm.allGather(doubleBuff, 1, MPI.DOUBLE, result, 1, MPI.DOUBLE);
    }

    /* Broadcast */

    /**
     * Broadcasts a single int from {@code root} over this instance's communicator.
     *
     * @param value value to stage before the broadcast
     * @param root  rank of the broadcasting process
     * @return the content of the staging buffer after the broadcast
     * @throws MPIException if the underlying MPI call fails
     */
    public int broadcast(int value, int root) throws MPIException {
        return broadcast(value, root, comm);
    }

    /**
     * Broadcasts a single int from {@code root} over the given communicator.
     *
     * @param value value to stage before the broadcast
     * @param root  rank of the broadcasting process
     * @param comm  communicator to broadcast over
     * @return the content of the staging buffer after the broadcast
     * @throws MPIException if the underlying MPI call fails
     */
    public int broadcast(int value, int root, Intracomm comm) throws MPIException {
        intBuff.put(0, value);
        comm.bcast(intBuff, 1, MPI.INT, root);
        return intBuff.get(0);
    }

    /**
     * Broadcasts an int array in place from {@code root} over this instance's
     * communicator.
     *
     * @param values in/out array
     * @param root   rank of the broadcasting process
     * @throws MPIException if the underlying MPI call fails
     */
    public void broadcast(int[] values, int root) throws MPIException {
        broadcast(values, root, comm);
    }

    /**
     * Broadcasts an int array in place from {@code root} over the given communicator.
     *
     * @param values in/out array
     * @param root   rank of the broadcasting process
     * @param comm   communicator to broadcast over
     * @throws MPIException if the underlying MPI call fails
     */
    public void broadcast(int[] values, int root, Intracomm comm) throws MPIException {
        comm.bcast(values, values.length, MPI.INT, root);
    }

    /**
     * Broadcasts a single double from {@code root} over this instance's communicator.
     *
     * @param value value to stage before the broadcast
     * @param root  rank of the broadcasting process
     * @return the content of the staging buffer after the broadcast
     * @throws MPIException if the underlying MPI call fails
     */
    public double broadcast(double value, int root) throws MPIException {
        return broadcast(value, root, comm);
    }

    /**
     * Broadcasts a single double from {@code root} over the given communicator.
     *
     * @param value value to stage before the broadcast
     * @param root  rank of the broadcasting process
     * @param comm  communicator to broadcast over
     * @return the content of the staging buffer after the broadcast
     * @throws MPIException if the underlying MPI call fails
     */
    public double broadcast(double value, int root, Intracomm comm) throws MPIException {
        doubleBuff.put(0, value);
        comm.bcast(doubleBuff, 1, MPI.DOUBLE, root);
        return doubleBuff.get(0);
    }

    /**
     * Broadcasts a double array in place from {@code root} over this instance's
     * communicator.
     *
     * @param values in/out array
     * @param root   rank of the broadcasting process
     * @throws MPIException if the underlying MPI call fails
     */
    public void broadcast(double[] values, int root) throws MPIException {
        broadcast(values, root, comm);
    }

    /**
     * Broadcasts a double array in place from {@code root} over the given communicator.
     *
     * @param values in/out array
     * @param root   rank of the broadcasting process
     * @param comm   communicator to broadcast over
     * @throws MPIException if the underlying MPI call fails
     */
    public void broadcast(double[] values, int root, Intracomm comm) throws MPIException {
        comm.bcast(values, values.length, MPI.DOUBLE, root);
    }

    /**
     * Broadcasts a single boolean from {@code root} over this instance's communicator.
     *
     * @param value value to stage before the broadcast
     * @param root  rank of the broadcasting process
     * @return the content of the staging array after the broadcast
     * @throws MPIException if the underlying MPI call fails
     */
    public boolean broadcast(boolean value, int root) throws MPIException {
        return broadcast(value, root, comm);
    }

    /**
     * Broadcasts a single boolean from {@code root} over the given communicator.
     *
     * @param value value to stage before the broadcast
     * @param root  rank of the broadcasting process
     * @param comm  communicator to broadcast over
     * @return the content of the staging array after the broadcast
     * @throws MPIException if the underlying MPI call fails
     */
    public boolean broadcast(boolean value, int root, Intracomm comm) throws MPIException {
        booleanBuff[0] = value;
        comm.bcast(booleanBuff, 1, MPI.BOOLEAN, root);
        return booleanBuff[0];
    }

    /**
     * Broadcasts a boolean array in place from {@code root} over this instance's
     * communicator.
     *
     * @param values in/out array
     * @param root   rank of the broadcasting process
     * @throws MPIException if the underlying MPI call fails
     */
    public void broadcast(boolean[] values, int root) throws MPIException {
        broadcast(values, root, comm);
    }

    /**
     * Broadcasts a boolean array in place from {@code root} over the given communicator.
     *
     * @param values in/out array
     * @param root   rank of the broadcasting process
     * @param comm   communicator to broadcast over
     * @throws MPIException if the underlying MPI call fails
     */
    public void broadcast(boolean[] values, int root, Intracomm comm) throws MPIException {
        comm.bcast(values, values.length, MPI.BOOLEAN, root);
    }

    /**
     * Broadcasts a packet's buffer from {@code root} over this instance's communicator.
     *
     * @param value packet whose buffer is broadcast in place
     * @param root  rank of the broadcasting process
     * @return the same {@code value} object that was passed in
     * @throws MPIException if the underlying MPI call fails
     */
    public MPIPacket broadcast(MPIPacket value, int root) throws MPIException {
        return broadcast(value, root, comm);
    }

    /**
     * Broadcasts {@code value.getExtent()} bytes of a packet's buffer from {@code root}
     * over the given communicator.
     *
     * @param value packet whose buffer is broadcast in place
     * @param root  rank of the broadcasting process
     * @param comm  communicator to broadcast over
     * @return the same {@code value} object that was passed in
     * @throws MPIException if the underlying MPI call fails
     */
    public MPIPacket broadcast(MPIPacket value, int root, Intracomm comm) throws MPIException {
        comm.bcast(value.getBuffer(), value.getExtent(), MPI.BYTE, root);
        return value;
    }

    /* Sendrecv */

    /**
     * Exchanges transport packets with peers over this instance's communicator.
     *
     * @param sendValue packet to send
     * @param dest      rank to send to
     * @param destTag   tag used for the sends
     * @param src       rank to receive from
     * @param srcTag    tag used for the receives
     * @return the packet loaded from the received bytes
     * @throws MPIException if an underlying MPI call fails
     */
    public MPITransportComponentPacket sendReceive(MPITransportComponentPacket sendValue, int dest, int destTag, int src, int srcTag) throws MPIException {
        return sendReceive(sendValue, dest, destTag, src, srcTag, comm);
    }

    /**
     * Exchanges transport packets with peers over the given communicator in two steps:
     * first the byte extents are swapped so the receive buffer can be sized, then the
     * payloads themselves are swapped.
     *
     * @param sendValue packet to send
     * @param dest      rank to send to
     * @param destTag   tag used for both sends
     * @param src       rank to receive from
     * @param srcTag    tag used for both receives
     * @param comm      communicator to exchange over
     * @return the packet loaded from the received bytes
     * @throws MPIException if an underlying MPI call fails
     */
    public MPITransportComponentPacket sendReceive(MPITransportComponentPacket sendValue, int dest, int destTag, int src, int srcTag, Intracomm comm) throws MPIException {
        // Round 1: swap extents so we know how many bytes to expect.
        int sendExtent = sendValue.getExtent();
        intBuff.put(0, sendExtent);
        comm.sendRecv(intBuff, 1, MPI.INT, dest, destTag, intBuff2, 1, MPI.INT, src, srcTag);
        int recvExtent = intBuff2.get(0);

        // Round 2: swap the payloads.
        ByteBuffer recvBuff = MPI.newByteBuffer(recvExtent);
        comm.sendRecv(sendValue.getBuffer(), sendExtent, MPI.BYTE, dest, destTag, recvBuff, recvExtent, MPI.BYTE, src, srcTag);
        return MPITransportComponentPacket.loadMPITransportComponentPacket(recvBuff);
    }

    /* Send */

    /**
     * Sends a packet to {@code dest} over this instance's communicator.
     *
     * @param value packet to send
     * @param dest  rank to send to
     * @param tag   message tag
     * @throws MPIException if an underlying MPI call fails
     */
    public void send(MPIPacket value, int dest, int tag) throws MPIException {
        send(value, dest, tag, comm);
    }

    /**
     * Sends a packet to {@code dest} over the given communicator as two messages with the
     * same tag: first the extent as one int, then that many bytes of the packet's buffer.
     * This is the message sequence the private receive helper of this class reads.
     *
     * @param value packet to send
     * @param dest  rank to send to
     * @param tag   message tag used for both messages
     * @param comm  communicator to send over
     * @throws MPIException if an underlying MPI call fails
     */
    public void send(MPIPacket value, int dest, int tag, Intracomm comm) throws MPIException {
        int extent = value.getExtent();
        intBuff.put(0, extent);
        comm.send(intBuff, 1, MPI.INT, dest, tag);
        comm.send(value.getBuffer(), extent, MPI.BYTE, dest, tag);
    }

    /* Receive */

    /**
     * Receives a packet from {@code src} over this instance's communicator and loads it
     * as the requested packet type. The type is checked before any message is received.
     *
     * @param src  rank to receive from
     * @param tag  message tag
     * @param type {@code MPIPacket.Type.Integer} or {@code MPIPacket.Type.Double}
     * @return the packet loaded from the received bytes
     * @throws MPIException                  if an underlying MPI call fails
     * @throws UnsupportedOperationException if {@code type} is neither of the supported types
     */
    public MPIPacket receive(int src, int tag, MPIPacket.Type type) throws MPIException {
        if (type == MPIPacket.Type.Integer) {
            return MPIPacket.loadIntegerPacket(receive(src, tag, comm));
        } else if (type == MPIPacket.Type.Double) {
            return MPIPacket.loadDoublePacket(receive(src, tag, comm));
        }
        throw new UnsupportedOperationException(ERR_RECV_UNSUPPORTED_MPIPACKET);
    }

    /**
     * Receives an extent (one int) followed by that many bytes from {@code src}.
     *
     * @param src  rank to receive from
     * @param tag  message tag used for both messages
     * @param comm communicator to receive over
     * @return a newly allocated buffer holding the received bytes
     * @throws MPIException if an underlying MPI call fails
     */
    private ByteBuffer receive(int src, int tag, Intracomm comm) throws MPIException {
        comm.recv(intBuff, 1, MPI.INT, src, tag);
        int extent = intBuff.get(0);
        ByteBuffer buffer = MPI.newByteBuffer(extent);
        comm.recv(buffer, extent, MPI.BYTE, src, tag);
        return buffer;
    }

    /* Barrier */

    /**
     * Blocks on a barrier over this instance's communicator, matching every other
     * default-communicator overload of this class (for instances created with the no-arg
     * constructor that communicator is {@code MPI.COMM_WORLD}).
     *
     * @throws MPIException if the underlying MPI call fails
     */
    public void barrier() throws MPIException {
        barrier(comm);
    }

    /**
     * Blocks on a barrier over the given communicator.
     *
     * @param comm communicator to synchronise on
     * @throws MPIException if the underlying MPI call fails
     */
    public void barrier(Intracomm comm) throws MPIException {
        comm.barrier();
    }
}