package org.nd4j.jita.constant; import lombok.extern.slf4j.Slf4j; import org.junit.Before; import org.junit.Test; import org.nd4j.jita.allocator.impl.AllocationPoint; import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.ShapeDescriptor; import org.nd4j.linalg.factory.Nd4j; import static org.junit.Assert.*; /** * Created by raver119 on 30.09.16. */ @Slf4j public class ProtectedCudaShapeInfoProviderTest { @Before public void setUp() throws Exception { } @Test public void testPurge1() throws Exception { INDArray array = Nd4j.create(10, 10); ProtectedCudaShapeInfoProvider provider = (ProtectedCudaShapeInfoProvider) ProtectedCudaShapeInfoProvider.getInstance(); assertEquals(true, provider.protector.containsDataBuffer(0, new ShapeDescriptor(array.shape(), array.stride(),0, array.elementWiseStride(), array.ordering()))); Nd4j.getMemoryManager().purgeCaches(); assertEquals(false, provider.protector.containsDataBuffer(0, new ShapeDescriptor(array.shape(), array.stride(),0, array.elementWiseStride(), array.ordering()))); // INDArray array2 = Nd4j.create(10, 10); } @Test public void testPurge2() throws Exception { INDArray arrayA = Nd4j.create(10, 10); DataBuffer shapeInfoA = arrayA.shapeInfoDataBuffer(); INDArray arrayE = Nd4j.create(10, 10); DataBuffer shapeInfoE = arrayE.shapeInfoDataBuffer(); int[] arrayShapeA = shapeInfoA.asInt(); assertTrue(shapeInfoA == shapeInfoE); ShapeDescriptor descriptor = new ShapeDescriptor(arrayA.shape(), arrayA.stride(), 0, arrayA.elementWiseStride(), arrayA.ordering()); ConstantProtector protector = ConstantProtector.getInstance(); AllocationPoint pointA = AtomicAllocator.getInstance().getAllocationPoint(arrayA.shapeInfoDataBuffer()); assertEquals(true, protector.containsDataBuffer(0, descriptor)); //////////////////////////////////// Nd4j.getMemoryManager().purgeCaches(); //////////////////////////////////// assertEquals(false, protector.containsDataBuffer(0, descriptor)); INDArray arrayB = Nd4j.create(10, 10); DataBuffer shapeInfoB = arrayB.shapeInfoDataBuffer(); assertFalse(shapeInfoA == shapeInfoB); AllocationPoint pointB = AtomicAllocator.getInstance().getAllocationPoint(arrayB.shapeInfoDataBuffer()); assertArrayEquals(arrayShapeA, shapeInfoB.asInt()); // pointers should be equal, due to offsets reset assertEquals(pointA.getPointers().getDevicePointer().address(), pointB.getPointers().getDevicePointer().address()); } @Test public void testPurge3() throws Exception { INDArray arrayA = Nd4j.create(10, 10); DataBuffer shapeInfoA = arrayA.shapeInfoDataBuffer(); int[] shapeA = shapeInfoA.asInt(); log.info("ShapeA: {}", shapeA); Nd4j.getMemoryManager().purgeCaches(); INDArray arrayB = Nd4j.create(20, 20); DataBuffer shapeInfoB = arrayB.shapeInfoDataBuffer(); int[] shapeB = shapeInfoB.asInt(); log.info("ShapeB: {}", shapeB); } }