/*
* To change this template, choose Tools | Templates
* and open the template in the editor.
*/
package de.gaalop.gappopencl;
import de.gaalop.gapp.ConstantSetVectorArgument;
import de.gaalop.gapp.PairSetOfVariablesAndIndices;
import de.gaalop.gapp.PosSelector;
import de.gaalop.gapp.Selector;
import de.gaalop.gapp.SelectorIndex;
import de.gaalop.gapp.SetVectorArgument;
import de.gaalop.gapp.instructionSet.CalculationType;
import de.gaalop.gapp.instructionSet.GAPPAssignInputsVector;
import de.gaalop.gapp.instructionSet.GAPPAssignMv;
import de.gaalop.gapp.instructionSet.GAPPCalculateMv;
import de.gaalop.gapp.instructionSet.GAPPCalculateMvCoeff;
import de.gaalop.gapp.instructionSet.GAPPDotVectors;
import de.gaalop.gapp.instructionSet.GAPPResetMv;
import de.gaalop.gapp.instructionSet.GAPPSetMv;
import de.gaalop.gapp.instructionSet.GAPPSetVector;
import de.gaalop.gapp.variables.GAPPConstant;
import de.gaalop.gapp.variables.GAPPMultivector;
import de.gaalop.gapp.variables.GAPPMultivectorComponent;
import de.gaalop.gapp.variables.GAPPScalarVariable;
import de.gaalop.gapp.variables.GAPPValueHolder;
import de.gaalop.gapp.variables.GAPPVector;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.StringTokenizer;
/**
 * Visitor that turns GAPP instructions into OpenCL C source text.
 *
 * <p>The generated code is accumulated in {@link #result} and can be fetched
 * with {@link #getCode()}. For every multivector / vector that has been
 * declared, {@link #mvBladeMap} remembers which OpenCL expression holds
 * which blade index, so later instructions can refer to already computed
 * coefficients by name.
 *
 * <p>Note that {@link #dotCount} is static and therefore shared between all
 * instances of this class.
 *
 * @author patrick
 */
public class GAPPOpenCLVisitor extends de.gaalop.gapp.visitor.CFGGAPPVisitor
        implements de.gaalop.gapp.variables.GAPPVariableVisitor {

    /** Running number used to build unique names for dot product temporaries. */
    protected static int dotCount = 0;
    /** OpenCL suffix selecting the lower half of a vector. */
    protected static final String lo = ".lo";
    /** OpenCL suffix selecting the upper half of a vector. */
    protected static final String hi = ".hi";
    /** Largest vector width used for generated OpenCL vector types. */
    protected static final int maxOpenCLVectorSize = 16;

    /** Number of coefficients per multivector, keyed by variable name. */
    protected Map<String,Integer> mvSizes;
    /** If set, {@code //#pragma gpc ...} comment lines are emitted. */
    protected boolean gpcMetaInfo = true;
    /** variable name -> (blade index -> OpenCL expression holding that blade). */
    protected Map<String,Map<Integer,String>> mvBladeMap = new HashMap<String,Map<Integer,String>>();
    /** Buffer collecting the generated OpenCL code. */
    protected StringBuilder result = new StringBuilder();

    /**
     * Creates a visitor.
     *
     * @param mvSizes number of coefficients for every multivector name
     */
    public GAPPOpenCLVisitor(Map<String, Integer> mvSizes) {
        this.mvSizes = mvSizes;
    }

    /**
     * Emits the declaration of the destination multivector and registers an
     * empty blade map for it.
     */
    @Override
    public Object visitResetMv(GAPPResetMv gappResetMv, Object arg) {
        final String destMv = GAPPOpenCLCodeGenerator.getVarName(gappResetMv.getDestination().getName());

        if (gpcMetaInfo && !destMv.startsWith(GAPPOpenCLCodeGenerator.tempMv)) {
            result.append("//#pragma gpc multivector ").append(destMv).append("\n");
        }

        printOpenCLVectorType(computeNearestOpenCLVectorSize(mvSizes.get(destMv)));
        result.append(" ");
        result.append(destMv).append(";\n");

        mvBladeMap.put(destMv, new HashMap<Integer,String>());
        return null;
    }

    /**
     * Emits one assignment per destination selector, copying the matching
     * source blade (negated if the source selector has a negative sign) into
     * the next free slot of the destination multivector.
     */
    @Override
    public Object visitSetMv(GAPPSetMv gappSetMv, Object arg) {
        final String destMv = GAPPOpenCLCodeGenerator.getVarName(gappSetMv.getDestination().getName());
        final String sourceName = GAPPOpenCLCodeGenerator.getVarName(gappSetMv.getSource().getName());

        Integer thisMvSetCount = mvBladeMap.get(destMv).size();
        int selCount = 0;
        for (PosSelector sel : gappSetMv.getSelectorsDest()) {
            // source selector that belongs to this destination selector
            final int srcPos = selCount++;

            declareGPCMultivectorComponent(destMv, thisMvSetCount, sel);

            final String bladeCoeff = getBladeCoeff(destMv, thisMvSetCount);
            result.append(bladeCoeff);
            result.append(" = ");
            // the sign has to come from the selector that is actually copied,
            // not always from the first one
            if (gappSetMv.getSelectorsSrc().get(srcPos).getSign() < 0) {
                result.append("-");
            }
            result.append(mvBladeMap.get(sourceName).get(gappSetMv.getSelectorsSrc().get(srcPos).getIndex()));
            result.append(";\n");

            mvBladeMap.get(destMv).put(sel.getIndex(), bladeCoeff);
            ++thisMvSetCount;
        }
        return null;
    }

    /**
     * Builds the OpenCL expression for slot {@code blade} of {@code mv}.
     * A multivector of size 1 is a plain scalar and gets no component suffix.
     *
     * @param mv    multivector variable name
     * @param blade slot number inside the OpenCL vector
     * @return the variable name, followed by {@code .s<hex index>} unless the size is 1
     */
    protected String getBladeCoeff(String mv, int blade) {
        return mv + ((mvSizes.get(mv) == 1) ? "" : (".s" + getOpenCLIndex(blade)));
    }

    /**
     * Normalises a blade name for the gpc pragma: {@code "1.0"} and
     * {@code "1.0f"} become {@code "1"}; otherwise all whitespace and
     * parentheses are stripped.
     *
     * @param bladeName blade name as delivered by the selector
     * @return the compacted blade name
     */
    protected static String formatBladeName(final String bladeName) {
        if (bladeName.equals("1.0") || bladeName.equals("1.0f")) {
            return "1";
        }

        // drop whitespace and parentheses by gluing the remaining tokens together
        final StringBuilder compact = new StringBuilder();
        for (StringTokenizer tokens = new StringTokenizer(bladeName, " \t\n\r\f()"); tokens.hasMoreTokens();) {
            compact.append(tokens.nextToken());
        }
        return compact.toString();
    }

    /**
     * Emits a {@code //#pragma gpc multivector_component} line that links a
     * blade name to the OpenCL expression storing it. Nothing is emitted when
     * meta info is disabled or {@code mv} is a temporary multivector.
     *
     * @param mv    multivector variable name
     * @param blade slot number inside the OpenCL vector
     * @param sel   selector providing the blade name
     */
    protected void declareGPCMultivectorComponent(String mv, Integer blade, SelectorIndex sel) {
        if (!gpcMetaInfo || mv.startsWith(GAPPOpenCLCodeGenerator.tempMv)) {
            return;
        }

        result.append("//#pragma gpc multivector_component ");
        result.append(mv);
        result.append(" ").append(formatBladeName(sel.getBladeName()));
        // use the very same expression the assignment uses (hex component
        // index), so pragma and code cannot disagree for slots >= 10
        result.append(" ").append(getBladeCoeff(mv, blade));
        result.append("\n");
    }

    /**
     * Emits the declaration of a GAPP vector. The entries are split into
     * OpenCL vectors named {@code <dest>_<n>}; a vector that is not completely
     * filled is padded with zeros. Every slot (padding included) is recorded
     * in the blade map of the destination.
     */
    @Override
    public Object visitSetVector(GAPPSetVector gappSetVector, Object arg) {
        final String destVecBase = GAPPOpenCLCodeGenerator.getVarName(gappSetVector.getDestination().getName());

        // collect all entries as strings
        final ArrayList<String> entries = new ArrayList<String>();
        for (SetVectorArgument setVectorArg : gappSetVector.getEntries()) {
            if (setVectorArg.isConstant()) {
                entries.add(String.valueOf(((ConstantSetVectorArgument) setVectorArg).getValue()));
            } else {
                final PairSetOfVariablesAndIndices pair = (PairSetOfVariablesAndIndices) setVectorArg;
                for (Selector sel : pair.getSelectors()) {
                    entries.add(visitSelector(sel,
                            GAPPOpenCLCodeGenerator.getVarName(pair.getSetOfVariable().getName())));
                }
            }
        }

        final Map<Integer,String> bladeMap = new HashMap<Integer,String>();
        final Iterator<String> itEntries = entries.iterator();
        // one slot per constant and one per selector, i.e. the entry count
        int vectorSizeRemainder = entries.size();
        int subvectorIndex = 0;
        int bladeGlobalIndex = 0;
        do {
            final int openCLVectorSize = computeNearestOpenCLVectorSize(vectorSizeRemainder);
            final String destVec = destVecBase + "_" + subvectorIndex;

            // declaration: "<type> <name> = (<type>)("
            printOpenCLVectorType(openCLVectorSize);
            result.append(" ");
            result.append(destVec);
            result.append(" = (");
            printOpenCLVectorType(openCLVectorSize);
            result.append(")(");

            // exactly openCLVectorSize components: real entries first,
            // zeros for whatever is left
            for (int lane = 0; lane < openCLVectorSize; ++lane) {
                if (lane > 0) {
                    result.append(",");
                }
                result.append(itEntries.hasNext() ? itEntries.next() : "0");
            }
            result.append(");\n");

            // remember where every slot lives; scalars have no component
            // suffix and lanes >= 10 need the hex spelling
            for (int lane = 0; lane < openCLVectorSize; ++lane) {
                bladeMap.put(bladeGlobalIndex++,
                        (openCLVectorSize == 1) ? destVec : (destVec + ".s" + getOpenCLIndex(lane)));
            }

            // a remainder of at most 16 is finished after this round
            vectorSizeRemainder -= maxOpenCLVectorSize;
            ++subvectorIndex;
        } while (vectorSizeRemainder > 0);

        mvBladeMap.put(destVecBase, bladeMap);
        return null;
    }

    /**
     * Appends the OpenCL type name for the given width: {@code float} for 1,
     * otherwise {@code float<width>}.
     *
     * @param openCLVectorSize vector width
     */
    protected void printOpenCLVectorType(final int openCLVectorSize) {
        result.append("float");
        if (openCLVectorSize != 1) {
            result.append(openCLVectorSize);
        }
    }

    /**
     * Resolves a selector to the OpenCL expression of the selected blade,
     * prefixed with {@code -} for a negative sign. If the source (or the
     * selected blade) is not known to the blade map, the source name itself
     * is used, with {@code "1.0"} printed as {@code "1"}.
     *
     * @param sel        selector to resolve
     * @param sourceName variable name of the source
     * @return the OpenCL expression
     */
    protected String visitSelector(final Selector sel, final String sourceName) {
        final StringBuilder out = new StringBuilder();
        if (sel.getSign() < 0) {
            out.append("-");
        }

        // an unregistered source must fall through to the name fallback
        // instead of failing with a NullPointerException
        final Map<Integer,String> sourceBlades = mvBladeMap.get(sourceName);
        final String lookupBladeCoeff = (sourceBlades == null) ? null : sourceBlades.get(sel.getIndex());

        if (lookupBladeCoeff != null) {
            out.append(lookupBladeCoeff);
        } else if (sourceName.equals("1.0")) {
            out.append("1");
        } else {
            out.append(sourceName);
        }
        return out.toString();
    }

    /** Not translated; emits nothing. */
    @Override
    public Object visitCalculateMv(GAPPCalculateMv gappCalculateMv, Object arg) {
        return null;
    }

    /**
     * Rounds a size up to the vector width used for it.
     *
     * @param in requested number of components
     * @return -1 for {@code in <= 0}; 1 and 2 unchanged; otherwise 4, 8 or
     *         {@link #maxOpenCLVectorSize}, whichever is the first to fit
     *         (sizes above the maximum also yield the maximum)
     */
    protected int computeNearestOpenCLVectorSize(final int in) {
        if (in <= 0) {
            return -1;
        }
        if (in <= 2) {
            return in;
        }
        if (in <= 4) {
            return 4;
        }
        return (in <= 8) ? 8 : maxOpenCLVectorSize;
    }

    /**
     * Converts a component number to the digit used after {@code .s}.
     *
     * @param index component number
     * @return the decimal string for values below 10, {@code a}..{@code f}
     *         for 10..15, and {@code "fail"} (after a failing assertion) for
     *         anything larger
     */
    protected String getOpenCLIndex(Integer index) {
        if (index < 10) {
            return index.toString();
        }
        if (index <= 15) {
            return Integer.toHexString(index);
        }
        assert(false);
        return "fail";
    }

    /**
     * Emits {@code <dest slot> = <operation>;} for a coefficient calculation
     * and records the slot under the destination's blade index.
     */
    @Override
    public Object visitCalculateMvCoeff(GAPPCalculateMvCoeff gappCalculateMvCoeff, Object arg) {
        final String destMv = GAPPOpenCLCodeGenerator.getVarName(gappCalculateMvCoeff.getDestination().getName());
        final Integer thisMvSetCount = mvBladeMap.get(destMv).size();
        final String bladeCoeff = getBladeCoeff(destMv, thisMvSetCount);

        final String operand1 = GAPPOpenCLCodeGenerator.getVarName(gappCalculateMvCoeff.getOperand1().getName());
        // NOTE(review): all operations except DIVISION only use operand 1;
        // presumably operand 2 is absent (null) for them - guard against it.
        final String operand2 = (gappCalculateMvCoeff.getOperand2() == null)
                ? null
                : GAPPOpenCLCodeGenerator.getVarName(gappCalculateMvCoeff.getOperand2().getName());

        result.append(bladeCoeff);
        result.append(" = ");
        visitCalculateOp(gappCalculateMvCoeff.getType(), operand1, operand2);
        result.append(";\n");

        mvBladeMap.get(destMv).put(gappCalculateMvCoeff.getDestination().getBladeIndex(), bladeCoeff);
        return null;
    }

    /**
     * Appends the OpenCL expression for a calculation.
     *
     * @param type     operation to emit
     * @param operand1 first operand expression
     * @param operand2 second operand expression; only used by DIVISION
     * @throws UnsupportedOperationException if {@code type} has no translation
     */
    protected void visitCalculateOp(CalculationType type, String operand1, String operand2) {
        switch (type) {
            case DIVISION:
                result.append("(");
                result.append(operand1);
                result.append(") / (");
                result.append(operand2);
                result.append(")");
                break;
            case ABS:
                // operands are floats; OpenCL's abs() is integer only
                appendUnaryCall("fabs", operand1);
                break;
            case ACOS:
                appendUnaryCall("acos", operand1);
                break;
            case ASIN:
                appendUnaryCall("asin", operand1);
                break;
            case ATAN:
                appendUnaryCall("atan", operand1);
                break;
            case CEIL:
                appendUnaryCall("ceil", operand1);
                break;
            case COS:
                appendUnaryCall("cos", operand1);
                break;
            case FLOOR:
                appendUnaryCall("floor", operand1);
                break;
            case LOG:
                appendUnaryCall("log", operand1);
                break;
            case SIN:
                appendUnaryCall("sin", operand1);
                break;
            case SQRT:
                appendUnaryCall("sqrt", operand1);
                break;
            case TAN:
                appendUnaryCall("tan", operand1);
                break;
            default:
                // an empty right hand side would only produce broken code
                throw new UnsupportedOperationException("Calculation type not supported: " + type);
        }
    }

    /** Appends {@code function(operand)} to the generated code. */
    private void appendUnaryCall(String function, String operand) {
        result.append(function).append("(");
        result.append(operand);
        result.append(")");
    }

    /**
     * Emits one assignment per value, storing it in the next free slot of
     * the destination multivector and recording the slot under the index of
     * the selector at the same position.
     */
    @Override
    public Object visitAssignMv(GAPPAssignMv gappAssignMv, Object arg) {
        final String destMv = GAPPOpenCLCodeGenerator.getVarName(gappAssignMv.getDestination().getName());

        int slot = mvBladeMap.get(destMv).size();
        int position = 0;
        for (GAPPValueHolder val : gappAssignMv.getValues()) {
            final String bladeCoeff = getBladeCoeff(destMv, slot);

            result.append(bladeCoeff);
            result.append(" = ");
            result.append(val.toString());
            result.append(";\n");

            mvBladeMap.get(destMv).put(gappAssignMv.getSelectors().get(position).getIndex(), bladeCoeff);
            ++position;
            ++slot;
        }
        return null;
    }

    /**
     * Emits a dot product: the parts are multiplied component-wise per
     * subvector, multiple float16 products are added up, and the result is
     * reduced by repeatedly adding lower and upper halves until one scalar is
     * left, which is assigned to the destination slot.
     *
     * <p>The operand size is taken from the blade map of the first part,
     * i.e. it is assumed to be a size as produced by
     * {@link #visitSetVector}.
     */
    @Override
    public Object visitDotVectors(GAPPDotVectors gappDotVectors, Object arg) {
        final String destMv = GAPPOpenCLCodeGenerator.getVarName(gappDotVectors.getDestination().getName());
        final Integer thisMvSetCount = mvBladeMap.get(destMv).size();

        // gpc meta info, destination expression and blade map update
        declareGPCMultivectorComponent(destMv, thisMvSetCount, gappDotVectors.getDestSelector());
        final String bladeCoeff = getBladeCoeff(destMv, thisMvSetCount);
        mvBladeMap.get(destMv).put(gappDotVectors.getDestSelector().getIndex(), bladeCoeff);

        // operands of size 1: the product already is the result
        final int operandSize = mvBladeMap.get(
                GAPPOpenCLCodeGenerator.getVarName(gappDotVectors.getParts().get(0).getName())).size();
        if (operandSize == 1) {
            result.append(bladeCoeff).append(" = ");
            visitDotVectorsParallelMultiply(gappDotVectors, 0);
            return null;
        }

        // the products keep this number; reduction steps count on from it
        final int multiplyDotCount = dotCount;

        // parallel multiply, one statement per subvector
        int operandSizeRemainder = operandSize;
        int subvectorIndex = 0;
        do {
            final int productSize = computeNearestOpenCLVectorSize(operandSizeRemainder);
            printOpenCLVectorType(productSize);
            result.append(" ").append(dotVarName(multiplyDotCount, subvectorIndex));
            result.append(" = ");
            visitDotVectorsParallelMultiply(gappDotVectors, subvectorIndex);
            operandSizeRemainder -= productSize;
            ++subvectorIndex;
        } while (operandSizeRemainder > 0);

        // several float16 products are added together first
        final int fullVectorCount = operandSize / maxOpenCLVectorSize;
        if (fullVectorCount > 1) {
            printOpenCLVectorType(maxOpenCLVectorSize);
            result.append(" ").append(dotVarName(dotCount + 1, 0));
            result.append(" = ");
            result.append(" ").append(dotVarName(dotCount, 0));
            for (int full = 1; full < fullVectorCount; ++full) {
                result.append(" + ");
                result.append(" ").append(dotVarName(dotCount, full));
            }
            result.append(";\n");
            // only here a new variable was created; without the merge the
            // reduction has to start from the product itself
            ++dotCount;
        }

        // smaller products start behind the float16 ones (which are either
        // the single leading vector or have just been merged)
        int multiplyIndex = Math.max(1, fullVectorCount);

        // parallel pyramid sum reduction
        int openCLVectorSize = computeNearestOpenCLVectorSize(operandSize);
        while ((openCLVectorSize >>= 1) > 1) {
            printOpenCLVectorType(openCLVectorSize);
            result.append(" ").append(dotVarName(dotCount + 1, 0));
            result.append(" = ");
            appendHalvesSum(dotCount);
            // at most one product of exactly this size can exist;
            // fold it into the sum at this level
            if ((operandSize % (openCLVectorSize << 1)) / openCLVectorSize > 0) {
                result.append(" + ");
                result.append(dotVarName(multiplyDotCount, multiplyIndex++));
            }
            result.append(";\n");
            ++dotCount;
        }

        // last step assigns directly to the destination
        result.append(bladeCoeff);
        result.append(" = ");
        appendHalvesSum(dotCount);
        // a leftover scalar product (odd operand size) still has to be added
        if (operandSize % 2 > 0) {
            result.append(" + ");
            result.append(dotVarName(multiplyDotCount, multiplyIndex));
        }
        result.append(";\n");
        ++dotCount;
        return null;
    }

    /** Name of the dot product temporary with the given number and subvector index. */
    private static String dotVarName(int count, int subvectorIndex) {
        return "" + GAPPOpenCLCodeGenerator.dot + count + "_" + subvectorIndex;
    }

    /** Appends the sum of lower and upper half of dot temporary {@code count}. */
    private void appendHalvesSum(int count) {
        result.append(dotVarName(count, 0)).append(lo);
        result.append(" + ");
        result.append(dotVarName(count, 0)).append(hi);
    }

    /**
     * Appends the component-wise product of subvector {@code subvectorIndex}
     * of all parts, terminated by {@code ";\n"}.
     *
     * @param gappDotVectors dot product instruction providing the parts
     * @param subvectorIndex index of the subvector to multiply
     */
    public void visitDotVectorsParallelMultiply(GAPPDotVectors gappDotVectors,
                                                final int subvectorIndex) {
        Iterator<GAPPVector> it = gappDotVectors.getParts().iterator();
        // same name mapping as used when the vectors were declared
        result.append(GAPPOpenCLCodeGenerator.getVarName(it.next().getName()));
        result.append("_").append(subvectorIndex);
        while (it.hasNext()) {
            result.append(" * ");
            result.append(GAPPOpenCLCodeGenerator.getVarName(it.next().getName()));
            result.append("_").append(subvectorIndex);
        }
        result.append(";\n");
    }

    /**
     * Appends a component selection {@code .s0123...} covering the first
     * {@code operandSize} components.
     *
     * @param operandSize number of components to select
     */
    public void visitWriteMask(int operandSize) {
        result.append(".s");
        for (int counter = 0; counter < operandSize; ++counter) {
            result.append(getOpenCLIndex(counter));
        }
    }

    /**
     * @return the OpenCL code generated so far
     */
    String getCode() {
        return result.toString();
    }

    /** Not supported by this visitor. */
    @Override
    public Object visitConstant(GAPPConstant gappConstant, Object arg) {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    /** Not supported by this visitor. */
    @Override
    public Object visitMultivector(GAPPMultivector gappMultivector, Object arg) {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    /** Not supported by this visitor. */
    @Override
    public Object visitMultivectorComponent(GAPPMultivectorComponent gappMultivectorComponent, Object arg) {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    /** Not supported by this visitor. */
    @Override
    public Object visitScalarVariable(GAPPScalarVariable gappScalarVariable, Object arg) {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    /** Not supported by this visitor. */
    @Override
    public Object visitVector(GAPPVector gappVector, Object arg) {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    /**
     * Registers the input values under the inputs vector name. No code is
     * emitted: the pretty printed values are put straight into the blade map,
     * numbered in iteration order starting at 0.
     */
    @Override
    public Object visitAssignInputsVector(GAPPAssignInputsVector gappAssignInputsVector, Object arg) {
        final String inputsArrayName = GAPPOpenCLCodeGenerator.getVarName(GAPPOpenCLCodeGenerator.inputsVector);

        final Map<Integer,String> bladeMap = new HashMap<Integer,String>();
        int position = 0;
        for (Iterator<GAPPValueHolder> it = gappAssignInputsVector.getValues().iterator(); it.hasNext(); ++position) {
            bladeMap.put(position, it.next().prettyPrint());
        }

        mvBladeMap.put(inputsArrayName, bladeMap);
        return null;
    }
}