/*
* To change this template, choose Tools | Templates
* and open the template in the editor.
*/
package com.nativelibs4java.opencl.util;
import com.nativelibs4java.opencl.CLBuildException;
import com.nativelibs4java.opencl.CLContext;
import com.nativelibs4java.opencl.CLBuffer;
import com.nativelibs4java.opencl.CLEvent;
import com.nativelibs4java.opencl.CLKernel;
import com.nativelibs4java.opencl.CLProgram;
import com.nativelibs4java.opencl.CLQueue;
import com.nativelibs4java.opencl.JavaCL;
import com.nativelibs4java.opencl.util.ReductionUtils;
import com.nativelibs4java.opencl.util.ReductionUtils.Reductor;
import com.nativelibs4java.util.IOUtils;
import com.nativelibs4java.util.Pair;
import static com.nativelibs4java.util.NIOUtils.*;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
 * Generates, compiles and caches simple element-wise OpenCL kernels for
 * unary ({@link Fun1}) and binary ({@link Fun2}) operations, bound to a single
 * {@link CLQueue} and its {@link CLContext}.
 * <p>
 * The kernel lookup methods are {@code synchronized} on this instance, so
 * concurrent callers do not build duplicate programs for the same key.
 *
 * @author ochafik
 */
@SuppressWarnings("unused")
public class ParallelMath {

    /** Context of {@link #queue}, captured once at construction time. */
    protected CLContext context;

    /** Queue this instance was created for. */
    protected CLQueue queue;

    /**
     * Creates an instance on the default queue of the context returned by
     * {@code JavaCL.createBestContext()}.
     */
    public ParallelMath() {
        this(JavaCL.createBestContext().createDefaultQueue());
    }

    /**
     * Creates an instance bound to the given queue.
     *
     * @param queue queue used for kernel compilation context; must not be null
     * @throws NullPointerException if {@code queue} is null
     */
    public ParallelMath(CLQueue queue) {
        if (queue == null)
            throw new NullPointerException("queue");
        this.queue = queue;
        // Assign the field itself (not a shadowing local) so subclasses reading
        // the protected 'context' field see a non-null value.
        this.context = queue.getContext();
    }

    /** @return the queue this instance was created for */
    public CLQueue getQueue() {
        return queue;
    }

    /** @return the context of {@link #getQueue()} */
    public CLContext getContext() {
        return getQueue().getContext();
    }

    /**
     * Appends the OpenCL C source of an element-wise unary kernel to {@code out}.
     * <p>
     * The generated kernel takes {@code (in, out, length)} and computes
     * {@code out[i] = f(in[i])} for each work item {@code i < length}.
     * <p>
     * NOTE(review): the generated kernel stores the work-item index in an
     * {@code int} while {@code length} is a {@code long}; inputs larger than
     * {@code Integer.MAX_VALUE} elements would overflow the index — confirm
     * whether such sizes are expected.
     *
     * @param function unary operation whose expression is emitted
     * @param type     element type of both input and output buffers
     * @param out      builder receiving the kernel source
     * @return the name of the generated kernel ({@code vect_<function>_<type>})
     */
    protected String createVectFun1Source(Fun1 function, Primitive type, StringBuilder out) {
        String t = type.clTypeName();
        String kernelName = "vect_" + function.name() + "_" + t;
        out.append("__kernel void " + kernelName + "(\n");
        out.append("\t__global const " + t + "* in,\n");
        out.append("\t__global " + t + "* out,\n");
        out.append("\tlong length\n");
        out.append(") {\n");
        out.append("\tint i = get_global_id(0);\n");
        out.append("\tif (i >= length) return;\n");
        out.append("\tout[i] = ");
        function.expr("in[i]", out);
        out.append(";\n");
        out.append("}\n");
        return kernelName;
    }

    /**
     * Appends the OpenCL C source of an element-wise binary kernel to {@code out}.
     * <p>
     * The generated kernel takes {@code (in1, in2, out, length)} and computes
     * {@code out[i] = (typeOut) f(in1[i], in2[i])}, or {@code f(in1[i], in2)}
     * when the second operand is a scalar kernel argument.
     * <p>
     * NOTE(review): as for the unary kernel, the work-item index is an
     * {@code int} compared against a {@code long} length.
     *
     * @param function               binary operation whose expression is emitted
     * @param type1                  element type of the first operand buffer
     * @param type2                  type of the second operand (buffer element or scalar)
     * @param typeOut                element type of the output buffer
     * @param out                    builder receiving the kernel source
     * @param secondOperandIsScalar  whether {@code in2} is passed by value instead of as a buffer
     * @return the name of the generated kernel ({@code vect_<function>_<t1>_<t2>_<tOut>})
     */
    protected String createVectFun2Source(Fun2 function, Primitive type1, Primitive type2, Primitive typeOut, StringBuilder out, boolean secondOperandIsScalar) {
        String t1 = type1.clTypeName(), t2 = type2.clTypeName(), to = typeOut.clTypeName();
        String kernelName = "vect_" + function.name() + "_" + t1 + "_" + t2 + "_" + to;
        out.append("__kernel void " + kernelName + "(\n");
        out.append("\t__global const " + t1 + "* in1,\n");
        if (secondOperandIsScalar)
            out.append("\t" + t2 + " in2,\n");
        else
            out.append("\t__global const " + t2 + "* in2,\n");
        out.append("\t__global " + to + "* out,\n");
        out.append("\tlong length\n");
        out.append(") {\n");
        out.append("\tint i = get_global_id(0);\n");
        out.append("\tif (i >= length) return;\n");
        out.append("\tout[i] = (" + to + ")");
        function.expr("in1[i]", (secondOperandIsScalar ? "in2" : "in2[i]"), out);
        out.append(";\n");
        out.append("}\n");
        return kernelName;
    }

    /** Cache of compiled unary kernels, keyed by operation then element type. */
    private final EnumMap<Fun1, EnumMap<Primitive, CLKernel>> fun1Kernels = new EnumMap<Fun1, EnumMap<Primitive, CLKernel>>(Fun1.class);

    /**
     * Returns the element-wise kernel for {@code op} on {@code prim}, generating
     * and building its program on first use and caching it afterwards.
     *
     * @throws CLBuildException if the generated program fails to build
     */
    public synchronized CLKernel getKernel(Fun1 op, Primitive prim) throws CLBuildException {
        EnumMap<Primitive, CLKernel> byType = fun1Kernels.get(op);
        if (byType == null) {
            byType = new EnumMap<Primitive, CLKernel>(Primitive.class);
            fun1Kernels.put(op, byType);
        }
        CLKernel kernel = byType.get(prim);
        if (kernel == null) {
            StringBuilder out = new StringBuilder(300);
            String name = createVectFun1Source(op, prim, out);
            CLProgram prog = getContext().createProgram(out.toString()).build();
            kernel = prog.createKernel(name);
            byType.put(prim, kernel);
        }
        return kernel;
    }

    /**
     * Cache key for binary kernels: {@code ((in1 type, in2 type), (out type, scalar flag))}.
     * <p>
     * NOTE(review): lookups rely on {@code Pair}'s {@code equals}/{@code hashCode};
     * confirm {@code Pair} implements value equality, otherwise every lookup
     * misses and a new program is built per call.
     */
    static class PrimitiveTrio extends Pair<Pair<Primitive, Primitive>, Pair<Primitive, Boolean>> {
        public PrimitiveTrio(Primitive a, Primitive b, Primitive c, boolean secondOperandIsScalar) {
            super(new Pair<Primitive, Primitive>(a, b), new Pair<Primitive, Boolean>(c, secondOperandIsScalar));
        }
    }

    /** Cache of compiled binary kernels, keyed by operation then operand/output types. */
    private final EnumMap<Fun2, Map<PrimitiveTrio, CLKernel>> fun2Kernels = new EnumMap<Fun2, Map<PrimitiveTrio, CLKernel>>(Fun2.class);

    /**
     * Returns the binary kernel for {@code op} where both operands and the output
     * share the element type {@code prim}.
     *
     * @throws CLBuildException if the generated program fails to build
     */
    public synchronized CLKernel getKernel(Fun2 op, Primitive prim, boolean secondOperandIsScalar) throws CLBuildException {
        return getKernel(op, prim, prim, prim, secondOperandIsScalar);
    }

    /**
     * Returns the binary kernel for {@code op} with the given operand and output
     * types, generating and building its program on first use and caching it.
     *
     * @throws CLBuildException if the generated program fails to build
     */
    public synchronized CLKernel getKernel(Fun2 op, Primitive prim1, Primitive prim2, Primitive primOut, boolean secondOperandIsScalar) throws CLBuildException {
        Map<PrimitiveTrio, CLKernel> byTypes = fun2Kernels.get(op);
        if (byTypes == null) {
            byTypes = new HashMap<PrimitiveTrio, CLKernel>();
            fun2Kernels.put(op, byTypes);
        }
        PrimitiveTrio key = new PrimitiveTrio(prim1, prim2, primOut, secondOperandIsScalar);
        CLKernel kernel = byTypes.get(key);
        if (kernel == null) {
            StringBuilder out = new StringBuilder(300);
            String name = createVectFun2Source(op, prim1, prim2, primOut, out, secondOperandIsScalar);
            CLProgram prog = getContext().createProgram(out.toString()).build();
            kernel = prog.createKernel(name);
            byTypes.put(key, kernel);
        }
        return kernel;
    }
}