package com.nativelibs4java.opencl.generator;
import com.nativelibs4java.opencl.*;
import com.ochafik.lang.jnaerator.*;
import com.ochafik.lang.jnaerator.TypeConversion.JavaPrimitive;
import com.ochafik.lang.jnaerator.TypeConversion.TypeConversionMode;
import com.ochafik.lang.jnaerator.UniversalReconciliator;
import com.ochafik.lang.jnaerator.UnsupportedConversionException;
import com.ochafik.lang.jnaerator.parser.*;
import com.ochafik.lang.jnaerator.runtime.NativeSize;
import com.nativelibs4java.jalico.Adapter;
import com.nativelibs4java.jalico.Pair;
import com.ochafik.util.string.RegexUtils;
import com.ochafik.util.string.StringUtils;
import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.*;
import static com.ochafik.lang.jnaerator.parser.ElementsHelper.*;
import java.io.PrintStream;
import java.util.regex.Pattern;
public class JavaCLGenerator extends JNAerator {
static Pattern nameExtPatt = Pattern.compile("(.*?)\\.(\\w+)");
public JavaCLGenerator(JNAeratorConfig config) {
super(config);
config.forceOverwrite = true;
config.outputMode = JNAeratorConfig.OutputMode.Directory;
config.noCPlusPlus = true;
config.genCPlusPlus = false;
config.gccLong = true;
config.putTopStructsInSeparateFiles = false;
config.runtime = JNAeratorConfig.Runtime.BridJ;//NL4JStructs;
config.fileToLibrary = new Adapter<File, String>() {
@Override
public String adapt(File value) {
String[] m = RegexUtils.match(value.getName(), nameExtPatt);
return m == null ? null : m[1];
}
};
config.functionsAccepter = new Adapter<Function, Boolean>() {
@Override
public Boolean adapt(Function value) {
List<Modifier> mods = value.getModifiers();
if (ModifierType.__kernel.isContainedBy(mods))
return true;
if (value.getValueType() == null)
return null;
mods = value.getValueType().getModifiers();
return ModifierType.__kernel.isContainedBy(mods);
}
};
}
Map<String, Set<String>> macrosByFile = new HashMap<String, Set<String>>();
@Override
protected JNAeratorParser createJNAeratorParser() {
return new JNAeratorParser() {
@Override
protected com.ochafik.lang.jnaerator.parser.ObjCppParser newObjCppParser(TypeConversion typeConverter, String s, boolean verbose, PrintStream errorOut) throws IOException {
com.ochafik.lang.jnaerator.parser.ObjCppParser parser = super.newObjCppParser(typeConverter, s, verbose, errorOut);
parser.allowKinds(ModifierKind.OpenCL);
return parser;
}
};
}
static Set<String> openclPrimitives = new HashSet<String>();
static {
openclPrimitives.add("half");
openclPrimitives.add("image2d_t");
openclPrimitives.add("image3d_t");
openclPrimitives.add("sampler_t");
openclPrimitives.add("event_t");
}
@Override
public Result createResult(final ClassOutputter outputter, Feedback feedback) {
return new Result(config, outputter, feedback) {
@Override
public Identifier getLibraryClassFullName(String library) {
return null;
}
@Override
public void init() {
typeConverter = new BridJTypeConversion(this) {
@Override
public void initTypes() {
super.initTypes();
}
@Override
protected Identifier packageMember(Identifier libraryPackage, Identifier name) {
return name;
}
@Override
public boolean isObjCppPrimitive(String s) {
int len;
if (s == null || (len = s.length()) == 0)
return false;
if (super.isObjCppPrimitive(s))
return true;
// handle case of "(int|long|short|byte|double|float)\\d"
if (len > 1 && Character.isDigit(s.charAt(len - 1))) {
String ss = s.substring(0, len - 1);
if (ss.charAt(0) == 'u')
ss = ss.substring(1);
if (super.isObjCppPrimitive(ss))
return true;
}
return openclPrimitives.contains(s);
}
};
declarationsConverter = new BridJDeclarationsConverter(this) {
@Override
public void convertFunction(Function function, Signatures signatures, boolean isCallback, DeclarationsHolder declarations, DeclarationsHolder implementations, Identifier libraryClassName, int iConstructor) {
if (isCallback)
return;
if (!result.config.functionsAccepter.adapt(function))
return;
List<Arg> args = function.getArgs();
List<Arg> convArgs = new ArrayList<Arg>(args.size());
String queueName = "commandQueue";
convArgs.add(new Arg(queueName, typeRef(CLQueue.class)));
List<Expression> convArgExpr = new ArrayList<Expression>(args.size());
List<Statement> extraStatements = new ArrayList<Statement>();
int iArg = 1;
for (Arg arg : args) {
TypeRef tr = arg.createMutatedType();
if (tr == null)
return;
try {
tr = result.typeConverter.normalizeTypeRef(tr);//, null/*libraryClassName*/, false, false);
List<Modifier> mods = arg.harvestModifiers();
TypeRef convTr;
String argName = arg.getName() == null ? "arg" + iArg : arg.getName();
Expression argExpr;
if (ModifierType.__local.isContainedBy(mods)) {
argName += "LocalByteSize";
//convTr = typeRef(Long.TYPE);
//argExpr = new Expression.New(typeRef(LocalSize.class), varRef(argName));
convTr = typeRef(LocalSize.class);
argExpr = varRef(argName);//new Expression.New(typeRef(LocalSize.class), varRef(argName));
} else {
Conversion conv = convertTypeToJavaCL(result, argName, tr, TypeConversion.TypeConversionMode.PrimitiveOrBufferParameter, null);
convTr = conv.outerJavaTypeRef;
argExpr = conv.convertedExpr;
extraStatements.addAll(conv.extraStatements);
//String convTrStr = convTr.toString();
/*if (convTrStr.equals(NativeSize.class.getName()) || convTrStr.equals(NativeLong.class.getName()))
argExpr = new Expression.New(tr, varRef(conv.argName));
else
argExpr = varRef(ident(argName));*/
}
convArgs.add(new Arg(argName, convTr));
convArgExpr.add(argExpr);//varRef(argName));
} catch (UnsupportedConversionException ex) {
implementations.addDeclaration(skipDeclaration(function, ex.toString()));
}
iArg++;
}
String globalWSName = "globalWorkSizes", localWSName = "localWorkSizes", eventsName = "eventsToWaitFor";
convArgs.add(new Arg(globalWSName, typeRef(int[].class)));
convArgs.add(new Arg(localWSName, typeRef(int[].class)));
convArgs.add(new Arg(eventsName, typeRef(CLEvent.class)).setVarArg(true));
String functionName = function.getName().toString();
String kernelVarName = functionName + "_kernel";
if (signatures.addVariable(kernelVarName))
implementations.addDeclaration(new VariablesDeclaration(typeRef(CLKernel.class), new Declarator.DirectDeclarator(kernelVarName)));
Function method = new Function(Function.Type.JavaMethod, ident(functionName), typeRef(CLEvent.class));
method.addModifiers(ModifierType.Public, ModifierType.Synchronized);
method.addThrown(typeRef(CLBuildException.class));
method.setArgs(convArgs);
List<Statement> statements = new ArrayList<Statement>();
statements.add(
new Statement.If(
expr(varRef(kernelVarName), Expression.BinaryOperator.IsEqual, new Expression.NullExpression()),
stat(
expr(
varRef(kernelVarName), Expression.AssignmentOperator.Equal,
methodCall(
"createKernel",
new Expression.Constant(Expression.Constant.Type.String, functionName, null)
)
)
),
null
)
);
statements.addAll(extraStatements);
statements.add(
stat(methodCall(
varRef(kernelVarName),
Expression.MemberRefStyle.Dot,
"setArgs",
convArgExpr.toArray(new Expression[convArgExpr.size()])
))
);
statements.add(
new Statement.Return(methodCall(
varRef(kernelVarName),
Expression.MemberRefStyle.Dot,
"enqueueNDRange",
varRef(queueName),
varRef(globalWSName),
varRef(localWSName),
varRef(eventsName)
))
);
method.setBody(block(statements.toArray(new Statement[statements.size()])));
if (signatures.addMethod(method))
implementations.addDeclaration(method);
}
};
globalsGenerator = new BridJGlobalsGenerator(this);
objectiveCGenerator = new ObjectiveCGenerator(this);
universalReconciliator = new UniversalReconciliator();
}
};
}
static class CLPrim {
TypeConversion.JavaPrimitive javaPrim;
int arity;
boolean isLong, isShort;
Expression assertExpr;
Statement checkStatement;
Expression convertStatement;
Class<?> argClass;
public CLPrim(JavaPrimitive javaPrim, int arity) {
this.javaPrim = javaPrim;
this.arity = arity;
}
static Pattern patt = Pattern.compile("(?:(long|short)\\s+)?(float|double|u?(?:char|long|short|int))(\\d)");
public static CLPrim parse(Result result, TypeRef tr) {
String s = tr.toString();
if (s == null || s.length() == 0)
return null;
char c = s.charAt(s.length() - 1);
if (!Character.isDigit(c)) {
//JavaPrim prim = result.typeConverter.getPrimitive(
return null;
}
String[] m = RegexUtils.match(tr.toString(), patt);
if (m == null)
return null;
//boolean isShort = false,
//result.typeConverter
return null;
}
}
static class Conversion {
TypeRef outerJavaTypeRef;
Expression convertedExpr;
String argName;
List<Statement> extraStatements = new ArrayList<Statement>();
}
static Map<String, Pair<Integer, Class<?>>> buffersAndArityByType = new HashMap<String, Pair<Integer, Class<?>>>();
static Map<String, Pair<Integer, Class<?>>> arraysAndArityByType = new HashMap<String, Pair<Integer, Class<?>>>();
static {
Object[] data = new Object[] {
"char", Byte.TYPE, byte[].class, Byte.class,
"long", Long.TYPE, long[].class, Long.class,
"int", Integer.TYPE, int[].class, Integer.class,
"short", Short.TYPE, short[].class, Short.class,
"wchar_t", Character.TYPE, char[].class, Short.class,
"double", Double.TYPE, double[].class, Double.class,
"float", Float.TYPE, float[].class, Float.class,
"bool", Boolean.TYPE, boolean[].class, Boolean.class
};
for (int arity : new int[] { 1, 2, 3, 4, 8, 16 }) {
String suffix = arity == 1 ? "" : arity +"";
for (int i = 0; i < data.length; i += 4) {
String rawType = (String)data[i];
Class<?> scalClass = (Class<?>)data[i + 1];
Class<?> arrClass = (Class<?>)data[i + 2];
Class<?> buffClass = (Class<?>)data[i + 3];
Pair<Integer, Class<?>>
buffPair = new Pair<Integer, Class<?>>(arity, buffClass),
arrPair = new Pair<Integer, Class<?>>(arity, arity == 1 ? scalClass : arrClass);
for (String type : new String[] { rawType + suffix, "u" + rawType + suffix}) {
buffersAndArityByType.put(type, buffPair);
arraysAndArityByType.put(type, arrPair);
}
}
}
data = new Object[] {
"image2d_t", CLImage2D.class,
"image3d_t", CLImage3D.class
};
for (int i = 0; i < data.length; i+=2) {
String type = (String) data[i];
Class<?> scalClass = (Class<?>)data[i + 1];
Pair<Integer, Class<?>> arrPair = new Pair<Integer, Class<?>>(1, scalClass);
arraysAndArityByType.put(type, arrPair);
}
}
private Conversion convertTypeToJavaCL(Result result, String argName, TypeRef valueType, TypeConversionMode typeConversionMode, Identifier libraryClassName) throws UnsupportedConversionException {
Conversion ret = new Conversion();
ret.argName = argName;
ret.convertedExpr = varRef(argName);
if (valueType instanceof TypeRef.Pointer) {
TypeRef target = ((TypeRef.Pointer)valueType).getTarget();
if (target instanceof TypeRef.SimpleTypeRef) {
TypeRef.SimpleTypeRef starget = (TypeRef.SimpleTypeRef)target;
Identifier name = starget.getName();
Pair<Integer, Class<?>> pair = buffersAndArityByType.get((starget + "").equals("long") ? "long" : name + "");
if (pair != null) {
ret.outerJavaTypeRef = typeRef(ident(CLBuffer.class, expr(typeRef(pair.getSecond()))));
return ret;
}
Identifier ref =
result.structsFullNames.contains(name) ||
result.enumsFullNames.contains(name) ?
name : result.typeConverter.findRef(name, target, libraryClassName, true);
if (ref != null) {
ret.outerJavaTypeRef = typeRef(ident(CLBuffer.class, expr(typeRef(ref))));
return ret;
}
} else if (target instanceof Struct) {
TypeRef ref = result.typeConverter.findStructRef((Struct)target, libraryClassName);
if (ref != null) {
ret.outerJavaTypeRef = typeRef(ident(CLBuffer.class, expr(ref)));
return ret;
}
}
throw new UnsupportedConversionException(valueType, "Unknown pointed target type");
} else if (valueType instanceof TypeRef.SimpleTypeRef) {
TypeRef.SimpleTypeRef sr = (TypeRef.SimpleTypeRef)valueType;
String name = sr.getName() == null ? sr.toString() : sr.getName().toString();
if (name.equals("size_t")) {
ret.outerJavaTypeRef = typeRef(Long.TYPE);
ret.convertedExpr = new Expression.New(typeRef(NativeSize.class), ret.convertedExpr);
return ret;
} else {
Pair<Integer, Class<?>> pair = arraysAndArityByType.get(name);
if (pair != null) {
ret.outerJavaTypeRef = typeRef(pair.getSecond());
if (pair.getFirst().intValue() != 1) {
ret.extraStatements.add(
stat(
methodCall(
"checkArrayLength",
varRef(ret.argName),
expr(
Expression.Constant.Type.Int,
pair.getFirst()
),
expr(
Expression.Constant.Type.String,
ret.argName
)
)
)
);
}
return ret;
}
}
}
throw new UnsupportedConversionException(valueType, "Unhandled type : " + valueType);
}
@Override
protected void generateLibraryFiles(SourceFiles sourceFiles, Result result) throws IOException {
//super.generateLibraryFiles(sourceFiles, result);
for (SourceFile sourceFile : sourceFiles.getSourceFiles()) {
String rawSrcFilePath = new File(sourceFile.getElementFile()).getCanonicalPath();
String srcFilePath = result.config.relativizeFileForSourceComments(rawSrcFilePath);
File srcFile = new File(srcFilePath);
String srcParent = srcFile.getParent();
String srcFileName = srcFile.getName();
String[] nameExt = RegexUtils.match(srcFileName, nameExtPatt);
if (nameExt == null)
continue;
String name = nameExt[1], ext = nameExt[2];
if (!ext.equals("c") && !ext.equals("cl"))
continue;
String packageName = srcParent == null || srcParent.length() == 0 ? null : srcParent.replace('/', '.').replace('\\', '.');
Identifier packageIdent = ident(packageName);
String className = (packageName == null ? "" : packageName + ".") + name;
Struct interf = new Struct();
interf.addToCommentBefore("Wrapper around the OpenCL program " + name);
interf.addModifiers(ModifierType.Public);
interf.setTag(ident(name));
interf.addParent(ident(CLAbstractUserProgram.class));
interf.setType(Struct.Type.JavaClass);
String[] constrArgNames = new String[] { "context", "program" };
Class<?>[] constrArgTypes = new Class<?>[] { CLContext.class, CLProgram.class };
for (int i = 0; i < constrArgNames.length; i++) {
String argName = constrArgNames[i];
Function constr = new Function(Function.Type.JavaMethod, ident(name), null, new Arg(argName, typeRef(constrArgTypes[i])));
constr.addModifiers(ModifierType.Public);
constr.addThrown(typeRef(IOException.class));
constr.setBody(
block(
stat(
methodCall(
"super",
varRef(argName),
methodCall(
"readRawSourceForClass",
result.typeConverter.typeLiteral(typeRef(name))
)
)
)
)
);
interf.addDeclaration(constr);
}
//result.declarationsConverter.convertStructs(null, null, interf, null)
Signatures signatures = new Signatures();//result.getSignaturesForOutputClass(fullLibraryClassName);
result.typeConverter.allowFakePointers = true;
String library = name;
Identifier fullLibraryClassName = ident(className);
interf.setResolvedJavaIdentifier(fullLibraryClassName);
result.declarationsConverter.convertStructs(result.structsByLibrary.get(library), signatures, interf, library);
//result.declarationsConverter.convertCallbacks(result.callbacksByLibrary.get(library), signatures, interf, fullLibraryClassName);
int declCount = interf.getDeclarations().size();
result.declarationsConverter.convertFunctions(result.functionsByLibrary.get(library), signatures, interf, interf);
result.declarationsConverter.convertEnums(result.enumsByLibrary.get(library), signatures, interf);
result.declarationsConverter.convertConstants(library, result.definesByLibrary.get(library), sourceFiles, signatures, interf);
boolean hasKernels = interf.getDeclarations().size() > declCount;
if (!hasKernels)
continue;
//for ()
/*
public SampleUserProgram(CLContext context) throws IOException {
super(context, readRawSourceForClass(SampleUserProgram.class));
}*/
for (Set<String> set : macrosByFile.values()) {
for (String macroName : set) {
if (macroName.equals("__LINE__") ||
macroName.equals("__FILE__") ||
macroName.equals("__COUNTER__") ||
config.preprocessorConfig.explicitMacros.containsKey(macroName) ||
config.preprocessorConfig.implicitMacros.containsKey(macroName))
continue;
String[] parts = macroName.split("_+");
List<String> newParts = new ArrayList<String>(parts.length);
for (String part : parts) {
if (part == null || (part = part.trim()).length() == 0)
continue;
newParts.add(StringUtils.capitalize(part));
}
String functionName = "define" + StringUtils.implode(newParts, "");
Function macroDef = new Function(Function.Type.JavaMethod, ident(functionName), typeRef("void"));
String valueName = "value";
macroDef.addArg(new Arg(valueName, typeRef(String.class)));
macroDef.setBody(block(stat(methodCall("defineMacro", expr(Expression.Constant.Type.String, macroName), varRef(valueName)))));
interf.addDeclaration(macroDef);
}
}
PrintWriter out = result.classOutputter.getClassSourceWriter(className);
result.printJavaClass(packageIdent, interf, out);
//if (packageName != null)
// out.println("package " + packageName + ";");
//out.println(interf);
out.close();
}
}
//
// @Override
// protected void autoConfigure() {
// super.autoConfigure();
//
// /*
// __OPENCL_VERSION__
// __ENDIAN_LITTLE__
//
// __IMAGE_SUPPORT__
// __FAST_RELAXED_MATH__
// */
//
// }
public static void main(String[] args) {
JNAerator.main(new JavaCLGenerator(new JNAeratorConfig()),
new String[] {
"-o", "target/generated-sources/test",
//"-o", "/Users/ochafik/Prog/Java/versionedSources/nativelibs4java/trunk/libraries/OpenCL/Demos/target/generated-sources/main/java",
"-noJar",
"-noComp",
"-v",
"-addRootDir", "src/test/opencl",
"src/test/opencl/com/nativelibs4java/opencl/generator/Structs.c",
//"-addRootDir", "/Users/ochafik/Prog/Java/versionedSources/nativelibs4java/trunk/libraries/OpenCL/Blas/target/../src/main/opencl",
//"/Users/ochafik/Prog/Java/versionedSources/nativelibs4java/trunk/libraries/OpenCL/Blas/src/main/opencl/com/nativelibs4java/opencl/blas/LinearAlgebraKernels.c"
//"-addRootDir", "/Users/ochafik/Prog/Java/versionedSources/nativelibs4java/trunk/libraries/OpenCL/Demos/target/../src/main/opencl",
//"/Users/ochafik/Prog/Java/versionedSources/nativelibs4java/trunk/libraries/OpenCL/Demos/target/../src/main/opencl/com/nativelibs4java/opencl/demos/sobelfilter/SimpleSobel.cl"
}
);
}
}