package org.nd4j.linalg.api.blas;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.Assert.assertEquals;
/**
 * Parameterized checks for the BLAS level-1 routines ({@code dot} and {@code axpy}) reached
 * through {@link Nd4j#getBlasWrapper()}, run once per backend supplied by the runner.
 *
 * @author Adam Gibson
 */
@RunWith(Parameterized.class)
public class Level1Test extends BaseNd4jTest {

    /** Absolute tolerance for the dot-product comparisons. */
    private static final double DOT_TOLERANCE = 1e-1;

    /**
     * Builds the test instance for a single backend.
     *
     * @param backend backend handed in by the parameterized runner
     */
    public Level1Test(Nd4jBackend backend) {
        super(backend);
    }

    /**
     * Verifies {@code dot} on two independent vectors and then on a row view of a matrix.
     *
     * <p>NOTE(review): the row-view expectation (20) corresponds to row 1 holding {2, 4}, i.e.
     * the linspace data laid out column-major; this presumably relies on the 'f' value returned
     * by {@link #ordering()} being applied by the base class — confirm in BaseNd4jTest.
     */
    @Test
    public void testDot() {
        // 1*1 + 2*2 + 3*3 + 4*4 = 30; two separate arrays so the vectors do not share storage.
        INDArray left = Nd4j.create(new float[] {1, 2, 3, 4});
        INDArray right = Nd4j.create(new float[] {1, 2, 3, 4});
        assertEquals(30.0, Nd4j.getBlasWrapper().dot(left, right), DOT_TOLERANCE);

        INDArray secondRow = Nd4j.linspace(1, 4, 4).reshape(2, 2).getRow(1);
        assertEquals(20.0, Nd4j.getBlasWrapper().dot(secondRow, secondRow), DOT_TOLERANCE);
    }

    /**
     * Verifies {@code axpy} when x and y are the same row view of a matrix.
     *
     * <p>Under the usual BLAS contract {@code y := alpha * x + y} with alpha = 1.0 and x == y,
     * each element is doubled in place, so the expected {4, 8} corresponds to a starting row of
     * {2, 4} (see the ordering note on {@link #testDot()}).
     */
    @Test
    public void testAxpy() {
        INDArray secondRow = Nd4j.linspace(1, 4, 4).reshape(2, 2).getRow(1);
        Nd4j.getBlasWrapper().level1().axpy(secondRow.length(), 1.0, secondRow, secondRow);

        INDArray expected = Nd4j.create(new double[] {4, 8});
        assertEquals(getFailureMessage(), expected, secondRow);
    }

    /**
     * Requests 'f' (Fortran / column-major) array ordering for this test class; the expected
     * values in both tests are consistent with that layout.
     *
     * @return {@code 'f'}
     */
    @Override
    public char ordering() {
        return 'f';
    }
}