package org.nd4j.linalg.api.ndarray;

import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.io.Assert;
import org.nd4j.linalg.ops.transforms.Transforms;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;

/**
 * Round-trip tests for {@code Nd4j.writeTxt} / {@code Nd4j.readTxt} on the 'c'-ordering backend.
 *
 * <p>Each test writes a random array to a text file, reads it back, and checks that the largest
 * absolute per-element difference is below {@value #TOLERANCE}. Inputs are scaled by 100 because,
 * per the original author's note, only two decimal places are written.
 *
 * Created by susaneraly on 6/18/16.
 */
@RunWith(Parameterized.class)
public class TestNdArrReadWriteTxtC extends BaseNd4jTest {

    /** Maximum absolute per-element difference tolerated after a write/read round trip. */
    private static final double TOLERANCE = 0.01;

    public TestNdArrReadWriteTxtC(Nd4jBackend backend) {
        super(backend);
    }

    /** 2-D 'c'-ordered matrix round trip. */
    @Test
    public void TestReadWrite() throws IOException {
        // scaled by 100 since we write only two decimal points
        assertRoundTrip(Nd4j.rand('c', 10, 10).muli(100));
    }

    /** Single-element array round trip. */
    @Test
    public void TestReadWriteSimple() throws IOException {
        // scaled by 100 since we write only two decimal points
        assertRoundTrip(Nd4j.rand(1, 1).muli(100));
    }

    /** Rank-6 array round trip. */
    @Test
    public void TestReadWriteNd() throws IOException {
        // scaled by 100 since we write only two decimal points
        assertRoundTrip(Nd4j.rand(13, 2, 11, 3, 7, 19).muli(100));
    }

    /** Array with several size-1 dimensions surrounding a single non-trivial one. */
    @Test
    public void TestWierdShape() throws IOException {
        // scaled by 100 since we write only two decimal points
        assertRoundTrip(Nd4j.rand(1, 1, 2, 1, 1).muli(100));
    }

    @Override
    public char ordering() {
        return 'c';
    }

    /**
     * Writes {@code origArr} to a fresh temporary file, reads it back, and asserts that the two
     * arrays agree element-wise to within {@link #TOLERANCE}.
     *
     * <p>A unique temp file per call keeps parameterized or concurrent runs from clobbering each
     * other's output. The file is removed even if writing, reading or the assertion fails.
     *
     * @param origArr array to round-trip; it is not modified
     * @throws IOException if the temporary file cannot be created or deleted
     */
    private static void assertRoundTrip(INDArray origArr) throws IOException {
        Path file = Files.createTempFile("someArr", ".txt");
        try {
            Nd4j.writeTxt(origArr, file.toString());
            INDArray readBack = Nd4j.readTxt(file.toString());
            // sub (not subi) so the caller's array is left untouched
            double maxDiff = Transforms.abs(origArr.sub(readBack)).maxNumber().doubleValue();
            Assert.isTrue(maxDiff < TOLERANCE,
                    "Round-trip max abs difference " + maxDiff + " is not below " + TOLERANCE);
        } finally {
            Files.deleteIfExists(file);
        }
    }
}