package jcuda.jcublas.ops;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import static org.junit.Assert.assertEquals;
/**
 * Tests for double-precision BLAS operations executed through the ND4J BLAS wrapper.
 *
 * <p>The whole class is currently disabled via {@link Ignore}. NOTE(review): presumably because it
 * needs a CUDA-capable device to run -- confirm before re-enabling.
 *
 * @author raver119@gmail.com
 */
@Ignore
public class DoublesTests {

    /** Number of rows of the matrices used by the AXPY test. */
    private static final int ROWS = 7;

    /** Number of columns of the matrices used by the AXPY test. */
    private static final int COLUMNS = 9;

    /** Raw values ({@code ROWS * COLUMNS} = 63 entries) backing the {@code y} operand of the AXPY test. */
    private static final double[] AXPY_Y_VALUES = {
        0.48634444232816687, 1.4758265649675548, 0.39963731960854953, 1.0023591510099152,
        0.7645957605153649, 1.9310904186956557, 1.1878964257563174, 0.9057360169583474,
        -0.3769285854145248, 0.2946010549062492, 0.46557669032521287, 1.2115125297848275,
        0.9569626633310937, 0.3256059072916878, 1.612267239273259, 0.33003744088867437,
        1.0449266064014164, -0.00789237850243385, 0.5410744173090415, 2.782774008354224,
        1.2842283430247856, 0.9086056301544619, 1.1085112167932198, 0.7433898520033356,
        1.2140223632630698, 0.7934105561182277, 1.005842641658745, 0.9997499007926636,
        -0.1593645983224512, 0.07349026680376536, -0.5085137730665015, 0.850035725832587,
        0.24118248705567213, -0.13896796919660326, -0.43713991780505523, 0.6690021182865782,
        0.17830184441787855, 0.29319561397733207, -0.1418393347014404, -0.2680684817530423,
        0.17735833749207552, -0.004662964475220743, 1.0057286813222013, 0.4512230513884966,
        0.9534626972218946, 0.40334611958442246, 1.019885308172407, 1.2501698497386884,
        0.7623575059565331, 1.887393331295686, 0.9690210825194697, 2.0731574687887475,
        1.0805132391495538, 2.8244644868991746, 3.0849853112831913, 2.2252621118259084,
        1.0998660836316718, 0.5441178083600947, 1.0045439544127797, 0.3382649318030707,
        1.0090081066003418, 0.5477619833704549, 0.7327435087799476
    };

    /**
     * Prepares the environment before each test: prints a separator line, switches the context data
     * type to {@code DOUBLE}, and configures the CUDA environment with debug and verbose output
     * enabled and multi-GPU disallowed.
     *
     * @throws Exception declared for JUnit compatibility
     */
    @Before
    public void setUp() throws Exception {
        System.out.println("----------------------");
        DataTypeUtil.setDTypeForContext(DataBuffer.Type.DOUBLE);
        CudaEnvironment.getInstance().getConfiguration()
                .enableDebug(true)
                .setVerbose(true)
                .allowMultiGPU(false);
    }

    /**
     * Runs AXPY with {@code alpha = 1.0}, an all-zero {@code x} and a pre-filled {@code y}, then
     * checks that the first two elements of {@code y} still hold their initial values (which is
     * what the standard BLAS definition {@code y := alpha * x + y} yields for a zero {@code x}).
     * The wall-clock time of the AXPY call is printed for information only and is not asserted.
     *
     * @throws Exception declared for JUnit compatibility
     */
    @Test
    public void testDoubleAxpy1() throws Exception {
        // NOTE(review): the returned buffers are unused; presumably these calls only pre-populate
        // the constant handler before the measured operation -- confirm the intent.
        Nd4j.getConstantHandler().getConstantBuffer(new int[10]);
        Nd4j.getConstantHandler().getConstantBuffer(new int[7]);

        // x operand: 63 zeros viewed as a 7x9 matrix in 'f' ordering.
        INDArray array1 = Nd4j.zeros(ROWS * COLUMNS).reshape('f', ROWS, COLUMNS);
        // y operand: the fixture values viewed as a 7x9 matrix in 'f' ordering.
        INDArray array2 = Nd4j.create(AXPY_Y_VALUES).reshape('f', ROWS, COLUMNS);

        long time1 = System.nanoTime();
        // Double.valueOf replaces the deprecated Double(double) constructor while keeping the
        // argument's static type (Double), so the same axpy overload is selected as before.
        Nd4j.getBlasWrapper().axpy(Double.valueOf(1.0), array1, array2);
        long time2 = System.nanoTime();

        System.out.println("AXPY execution time: [" + (time2 - time1) + "] ns");

        assertEquals(0.4863444, array2.getDouble(0), 0.001);
        assertEquals(1.4758265, array2.getDouble(1), 0.001);
    }
}