/* * To change this template, choose Tools | Templates * and open the template in the editor. */ package com.nativelibs4java.opencl.blas; import com.nativelibs4java.opencl.CLBuildException; import com.nativelibs4java.opencl.CLEvent; import com.nativelibs4java.opencl.util.Fun1; import com.nativelibs4java.opencl.util.Fun2; import com.nativelibs4java.opencl.util.Primitive; import com.nativelibs4java.opencl.util.ReductionUtils.Reductor; /** * * @author ochafik */ public class CLMatrixUtils { static CLEvent[] join(CLEvent[]... evts) { int n = 0; for (CLEvent[] e : evts) n += e.length; CLEvent[] out = new CLEvent[n]; n = 0; for (CLEvent[] e : evts) System.arraycopy(e, 0, out, n, e.length); return out; } public static long roundUp(long size, int blockSize) { return ((size + blockSize - 1) / blockSize) * blockSize; } public static <T> void matrixMultiply( final CLMatrix2D<T> a, final CLMatrix2D<T> b, final CLMatrix2D<T> out) throws CLBuildException { final CLKernels kernels = a.getKernels(); final Primitive primitive = a.getPrimitive(); a.getEvents().performRead(new CLEvents.Action() { public CLEvent perform(final CLEvent[] aevents) { return b.getEvents().performRead(new CLEvents.Action() { public CLEvent perform(final CLEvent[] bevents) { return out.getEvents().performWrite(new CLEvents.Action() { public CLEvent perform(final CLEvent[] cevents) { CLEvent evt = kernels.matrixMultiply( primitive, a.getBuffer(), a.getRowCount(), a.getColumnCount(), a.getStride(), a.getBlockSize(), b.getBuffer(), b.getRowCount(), b.getColumnCount(), b.getStride(), b.getBlockSize(), out.getBuffer(), join(aevents, bevents, cevents) ); return evt; } }); } }); } }); } static final int MAX_REDUCTION_SIZE = 32; public static <T> void reduce( final CLMatrix2D<T> in, final CLMatrix2D<T> out, final Reductor<T> reductor ) { in.getEvents().performRead(new CLEvents.Action() { public CLEvent perform(final CLEvent[] ievents) { return out.getEvents().performWrite(new CLEvents.Action() { public CLEvent perform(CLEvent[] oevents) { return reductor.reduce(in.getQueue(), in.getBuffer(), in.getBuffer().getElementCount(), out.getBuffer(), MAX_REDUCTION_SIZE, join(ievents, oevents)); } }); } }); } public static <T> void matrixTranspose( final CLMatrix2D<T> a, final CLMatrix2D<T> out) throws CLBuildException { final Primitive primitive = a.getPrimitive(); final CLKernels kernels = a.getKernels(); a.getEvents().performRead(new CLEvents.Action() { public CLEvent perform(final CLEvent[] aevents) { return out.getEvents().performWrite(new CLEvents.Action() { public CLEvent perform(final CLEvent[] cevents) { CLEvent evt = kernels.matrixTranspose( primitive, a.getBuffer(), a.getRowCount(), a.getColumnCount(), a.getStride(), out.getBuffer(), join(aevents, cevents) ); return evt; } }); } }); } public static <T> CLMatrix2D<T> clone(final CLMatrix2D<T> matrix) { final CLMatrix2D<T> out = matrix.blankClone(); matrix.getEvents().performRead(new CLEvents.Action() { public CLEvent perform(final CLEvent[] aevents) { return out.getEvents().performWrite(new CLEvents.Action() { public CLEvent perform(CLEvent[] bevents) { return matrix.getBuffer().copyTo(matrix.getQueue(), out.getBuffer(), CLMatrixUtils.join(aevents, bevents)); } }); } }); return out; } public static <T> CLMatrix2D<T> createMatrix(long rows, long columns, Class<T> elementClass, CLKernels kernels) { if (elementClass == Double.class) return (CLMatrix2D<T>)new CLDefaultMatrix2D(Primitive.Double, null, rows, columns, kernels); throw new UnsupportedOperationException("Cannot build buffers of " + elementClass.getName() + " yet"); } public static <V> CLMatrix2D<V> op1(final CLMatrix2D<V> in, final Fun1 fun, final CLMatrix2D<V> out) throws CLBuildException { in.getEvents().performRead(new CLEvents.Action() { public CLEvent perform(final CLEvent[] ievents) { return out.getEvents().performWrite(new CLEvents.Action() { public CLEvent perform(CLEvent[] oevents) { return in.getKernels().op1(in.getPrimitive(), fun, in.getBuffer(), in.getRowCount(), in.getColumnCount(), in.getStride(), out.getBuffer(), join(ievents, oevents)); } }); } }); return out; } public static <V> CLMatrix2D<V> op2(final CLMatrix2D<V> in1, final Fun2 fun, final CLMatrix2D<V> in2, final CLMatrix2D<V> out) throws CLBuildException { in1.getEvents().performRead(new CLEvents.Action() { public CLEvent perform(final CLEvent[] i1events) { return in2.getEvents().performRead(new CLEvents.Action() { public CLEvent perform(final CLEvent[] i2events) { return out.getEvents().performWrite(new CLEvents.Action() { public CLEvent perform(CLEvent[] oevents) { return in1.getKernels().op2(in1.getPrimitive(), fun, in1.getBuffer(), in2.getBuffer(), in1.getRowCount(), in1.getColumnCount(), in1.getStride(), out.getBuffer(), join(i1events, i2events, oevents)); } }); } }); } }); return out; } public static <V> CLMatrix2D<V> op2(final CLMatrix2D<V> in, final Fun2 fun, final V s2, final CLMatrix2D<V> out) throws CLBuildException { in.getEvents().performRead(new CLEvents.Action() { public CLEvent perform(final CLEvent[] ievents) { return out.getEvents().performWrite(new CLEvents.Action() { public CLEvent perform(CLEvent[] oevents) { return in.getKernels().op2( in.getPrimitive(), fun, in.getBuffer(), s2, in.getRowCount(), in.getColumnCount(), in.getStride(), out.getBuffer(), join(ievents, oevents)); } }); } }); return out; } }