package com.nativelibs4java.opencl.blas.ujmp;
import java.io.IOException;
import org.bridj.Pointer;
import static org.bridj.Pointer.*;
import com.nativelibs4java.opencl.CLBuffer;
import com.nativelibs4java.opencl.CLPlatform.DeviceFeature;
import com.nativelibs4java.opencl.CLQueue;
import com.nativelibs4java.opencl.JavaCL;
import com.nativelibs4java.opencl.blas.CLKernels;
import static com.nativelibs4java.opencl.blas.ujmp.MatrixUtils.read;
import static com.nativelibs4java.opencl.blas.ujmp.MatrixUtils.write;
import com.nativelibs4java.util.Pair;
import static org.junit.Assert.*;
import static java.lang.Math.*;
import org.junit.Before;
import org.junit.Test;
import org.ujmp.core.Matrix;
import org.ujmp.core.MatrixFactory;
import org.ujmp.core.calculation.Calculation.Ret;
import org.ujmp.core.doublematrix.DenseDoubleMatrix2D;
import org.ujmp.core.doublematrix.impl.DefaultDenseDoubleMatrix2D;
import org.ujmp.core.floatmatrix.DenseFloatMatrix2D;
import org.ujmp.core.floatmatrix.FloatMatrix2D;
import org.ujmp.core.floatmatrix.impl.DefaultDenseFloatMatrix2D;
import org.ujmp.core.mapper.MatrixMapper;
/**
 * JUnit tests for the OpenCL-backed UJMP dense matrices
 * ({@link CLDenseFloatMatrix2D} and {@link CLDenseDoubleMatrix2D}).
 *
 * <p>Before each test the CL matrix classes are registered with the UJMP
 * {@link MatrixMapper}, so that matrices created through
 * {@link MatrixFactory} are expected to be CL-backed.
 *
 * @author ochafik
 */
public class UJMPOpenCLTest {

    final CLDenseDoubleMatrix2DFactory doubleFactory = new CLDenseDoubleMatrix2DFactory();
    final CLDenseFloatMatrix2DFactory floatFactory = new CLDenseFloatMatrix2DFactory();

    /**
     * Registers the CL dense float/double matrix classes with the UJMP
     * {@link MatrixMapper} singleton.
     *
     * @throws RuntimeException wrapping any failure raised while registering
     */
    @Before
    public void installUJMPCL() {
        try {
            MatrixMapper.getInstance().setDenseFloatMatrix2DClassName(CLDenseFloatMatrix2D.class.getName());
            MatrixMapper.getInstance().setDenseDoubleMatrix2DClass(CLDenseDoubleMatrix2D.class);
        } catch (Exception ex) {
            // Rethrow with the cause attached; the test runner reports it, so
            // there is no need to also print the stack trace here.
            throw new RuntimeException("Failed to install the OpenCL matrix classes into the UJMP MatrixMapper", ex);
        }
    }

    /** Checks that {@link MatrixFactory#dense} now produces a CL-backed double matrix. */
    @Test
    public void testInstalledUJMPCL() {
        Matrix m = MatrixFactory.dense(1, 1);
        assertTrue(m instanceof CLDenseDoubleMatrix2D);
    }

    /**
     * Reads an operation result back as a flat float array.
     *
     * @param result matrix returned by the operation under test; must be a {@link FloatMatrix2D}
     * @param layout matrix whose column count is passed to {@code read} as the stride argument
     * @return the floats read back from {@code result}
     */
    private static float[] resultFloats(Matrix result, DenseFloatMatrix2D layout) {
        return read((FloatMatrix2D) result, layout.getColumnCount()).getFloats();
    }

    /**
     * Exercises element-wise binary operations (matrix/matrix and
     * matrix/scalar) and the sin/cos/tan functions on 2x2 float matrices.
     */
    @Test
    public void testOp2() {
        DenseFloatMatrix2D a = floatFactory.dense(2, 2);
        DenseFloatMatrix2D b = floatFactory.dense(2, 2);
        float[] fa = { 10, 20, 30, 40 };
        float[] fb = { 1, 2, 3, 4 };
        write(fa, a.getColumnCount(), a);
        write(fb, b.getColumnCount(), b);

        // Matrix (op) matrix: arithmetic results are compared exactly (delta 0).
        assertArrayEquals("failed plus", new float[] { 11, 22, 33, 44 }, resultFloats(a.plus(b), a), 0);
        assertArrayEquals("failed minus", new float[] { 9, 18, 27, 36 }, resultFloats(a.minus(b), a), 0);
        assertArrayEquals("failed times", new float[] { 10, 40, 90, 160 }, resultFloats(a.times(b), a), 0);
        assertArrayEquals("failed divide", new float[] { 10, 10, 10, 10 }, resultFloats(a.divide(b), a), 0);

        // Matrix (op) scalar.
        assertArrayEquals("failed scalar plus", new float[] { 11, 21, 31, 41 }, resultFloats(a.plus(1), a), 0);
        assertArrayEquals("failed scalar minus", new float[] { 9, 19, 29, 39 }, resultFloats(a.minus(1), a), 0);
        assertArrayEquals("failed scalar times", new float[] { 20, 40, 60, 80 }, resultFloats(a.times(2), a), 0);
        assertArrayEquals("failed scalar divide", new float[] { 1, 2, 3, 4 }, resultFloats(a.divide(10), a), 0);

        // Reference values for the trigonometric functions, computed with java.lang.Math.
        float[] expectedSin = new float[fa.length];
        float[] expectedCos = new float[fa.length];
        float[] expectedTan = new float[fa.length];
        for (int i = 0; i < fa.length; i++) {
            expectedSin[i] = (float) sin(fa[i]);
            expectedCos[i] = (float) cos(fa[i]);
            expectedTan[i] = (float) tan(fa[i]);
        }

        // Trigonometric results are compared with a small tolerance.
        assertArrayEquals("failed sin", expectedSin, resultFloats(a.sin(Ret.NEW), a), 0.0001f);
        assertArrayEquals("failed cos", expectedCos, resultFloats(a.cos(Ret.NEW), a), 0.0001f);
        assertArrayEquals("failed tan", expectedTan, resultFloats(a.tan(Ret.NEW), a), 0.0001f);
    }

    /**
     * Checks a write/read round trip, then matrix-matrix and matrix-vector
     * multiplication on float matrices.
     */
    @Test
    public void testMultFloat() {
        DenseFloatMatrix2D m = floatFactory.dense(3, 3);
        DenseFloatMatrix2D v = floatFactory.dense(3, 1);
        CLQueue queue = ((CLDenseFloatMatrix2D) m).getImpl().getQueue();

        // Anti-diagonal permutation matrix: its square is the identity.
        float[] min = { 0, 0, 1, 0, 1, 0, 1, 0, 0 };
        write(min, m.getColumnCount(), m);

        // What was written must read back unchanged.
        Pointer<Float> back = read(m, m.getColumnCount());
        for (int i = 0, cap = (int) back.getValidElements(); i < cap; i++) {
            assertEquals(min[i], back.get(i), 0);
        }

        queue.finish();
        DenseFloatMatrix2D mout = (DenseFloatMatrix2D) m.mtimes(m);
        queue.finish();

        // Result intentionally discarded. The read-back is kept in case it forces
        // pending device work to complete - TODO confirm whether it can be dropped.
        read(mout, mout.getColumnCount());

        assertEquals(0, mout.getFloat(0, 1), 0);
        assertEquals(0, mout.getFloat(1, 0), 0);
        assertEquals(1, mout.getFloat(0, 0), 0);
        assertEquals(1, mout.getFloat(1, 1), 0);
        assertEquals(1, mout.getFloat(2, 2), 0);

        // m * (1, 0, 0)^T is expected to be (0, 0, 1)^T.
        write(new float[] { 1, 0, 0 }, v.getColumnCount(), v);
        DenseFloatMatrix2D vout = (DenseFloatMatrix2D) m.mtimes(v);
        assertEquals(0, vout.getFloat(0, 0), 0);
        assertEquals(0, vout.getFloat(1, 0), 0);
        assertEquals(1, vout.getFloat(2, 0), 0);
    }

    /**
     * Checks set/get and {@code containsDouble} on a CL double matrix created
     * through {@link MatrixFactory}.
     *
     * @throws IOException if creating the {@link CLKernels} instance fails
     */
    @Test
    public void testContainsDouble() throws IOException {
        CLKernels.setInstance(new CLKernels());
        CLDenseDoubleMatrix2D m = (CLDenseDoubleMatrix2D) MatrixFactory.dense(2, 2);
        int row = 1, column = 1;
        m.setDouble(1.1, row, column);
        assertEquals(1.1, m.getDouble(row, column), 0.0);
        assertTrue(m.containsDouble(1.1));
        assertFalse(m.containsDouble(2.0));
    }

    /** Checks set/get and {@code containsFloat} on a CL float matrix. */
    @Test
    public void testContainsFloat() {
        CLDenseFloatMatrix2D m = floatFactory.dense(2, 2);
        int row = 1, column = 1;
        m.setFloat(1.1f, row, column);
        assertEquals(1.1f, m.getFloat(row, column), 0.0);
        assertTrue(m.containsFloat(1.1f));
        assertFalse(m.containsFloat(2.0f));
    }

    /** Checks that {@code clear()} resets a previously set element to zero. */
    @Test
    public void testClearFloat() {
        CLDenseFloatMatrix2D m = floatFactory.dense(2, 2);
        int row = 0, column = 1;
        m.setFloat(1.1f, row, column);
        assertEquals(1.1f, m.getFloat(row, column), 0.0);
        m.clear();
        assertEquals(0f, m.getFloat(row, column), 0.0);
    }

    /**
     * Best-effort sleep helper that never throws.
     *
     * <p>If the thread is interrupted while sleeping, the interrupt status is
     * restored so that callers further up the stack can still observe it.
     *
     * @param millis time to sleep, in milliseconds
     */
    static void sleep(long millis) {
        try {
            Thread.sleep(millis);
        } catch (InterruptedException ex) {
            // Do not swallow the interruption: re-assert the flag for the caller.
            Thread.currentThread().interrupt();
        } catch (IllegalArgumentException ex) {
            // Negative durations are reported but tolerated, keeping this helper non-throwing.
            ex.printStackTrace();
        }
    }
}