package org.nd4j.linalg.aggregates; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.aggregates.Aggregate; import org.nd4j.linalg.api.ops.aggregates.impl.AggregateAxpy; import org.nd4j.linalg.api.ops.aggregates.impl.AggregateSkipGram; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import java.util.ArrayList; import java.util.List; import static org.junit.Assert.assertEquals; /** * @author raver119@gmail.com */ @RunWith(Parameterized.class) public class AggregatesTests extends BaseNd4jTest { public AggregatesTests(Nd4jBackend backend) { super(backend); } @Before public void setUp() { //DataTypeUtil.setDTypeForContext(DataBuffer.Type.DOUBLE); } @Test public void testAggregate1() throws Exception { INDArray arrayX = Nd4j.ones(10); INDArray arrayY = Nd4j.zeros(10); INDArray exp1 = Nd4j.ones(10); AggregateAxpy axpy = new AggregateAxpy(arrayX, arrayY, 1.0f); Nd4j.getExecutioner().exec(axpy); assertEquals(exp1, arrayY); } @Test public void testBatchedAggregate1() throws Exception { INDArray arrayX1 = Nd4j.ones(10); INDArray arrayY1 = Nd4j.zeros(10); INDArray arrayX2 = Nd4j.ones(10); INDArray arrayY2 = Nd4j.zeros(10); INDArray exp1 = Nd4j.create(10).assign(1f); INDArray exp2 = Nd4j.create(10).assign(1f); AggregateAxpy axpy1 = new AggregateAxpy(arrayX1, arrayY1, 1.0f); AggregateAxpy axpy2 = new AggregateAxpy(arrayX2, arrayY2, 1.0f); List<Aggregate> batch = new ArrayList<>(); batch.add(axpy1); batch.add(axpy2); Nd4j.getExecutioner().exec(batch); assertEquals(exp1, arrayY1); assertEquals(exp2, arrayY2); } @Test public void testBatchedAggregate2() throws Exception { INDArray arrayX1 = Nd4j.ones(10); INDArray arrayY1 = Nd4j.zeros(10).assign(2.0f); INDArray arrayX2 = Nd4j.ones(10); INDArray arrayY2 = Nd4j.zeros(10).assign(2.0f); INDArray arrayX3 = Nd4j.ones(10); INDArray arrayY3 = Nd4j.ones(10); INDArray exp1 = Nd4j.create(10).assign(4f); INDArray exp2 = Nd4j.create(10).assign(3f); INDArray exp3 = Nd4j.create(10).assign(3f); AggregateAxpy axpy1 = new AggregateAxpy(arrayX1, arrayY1, 2.0f); AggregateAxpy axpy2 = new AggregateAxpy(arrayX2, arrayY2, 1.0f); AggregateAxpy axpy3 = new AggregateAxpy(arrayX3, arrayY3, 2.0f); List<Aggregate> batch = new ArrayList<>(); batch.add(axpy1); batch.add(axpy2); batch.add(axpy3); Nd4j.getExecutioner().exec(batch); assertEquals(exp1, arrayY1); assertEquals(exp2, arrayY2); assertEquals(exp3, arrayY3); } @Test public void testBatchedSkipGram1() throws Exception { INDArray syn0 = Nd4j.create(10, 10).assign(0.01f); INDArray syn1 = Nd4j.create(10, 10).assign(0.02f); INDArray syn1Neg = Nd4j.ones(10, 10).assign(0.03f); INDArray expTable = Nd4j.create(10000).assign(0.5f); double lr = 0.001; int idxSyn0_1 = 0; int idxSyn0_2 = 3; INDArray expSyn0 = Nd4j.create(10).assign(0.01f); INDArray expSyn1_1 = Nd4j.create(10).assign(0.020005); // gradient is 0.00005 INDArray expSyn1_2 = Nd4j.create(10).assign(0.019995f); // gradient is -0.00005 INDArray syn0row_1 = syn0.getRow(idxSyn0_1); INDArray syn0row_2 = syn0.getRow(idxSyn0_2); AggregateSkipGram op1 = new AggregateSkipGram(syn0, syn1, syn1Neg, expTable, null, idxSyn0_1, new int[] {1, 2}, new int[] {0, 1}, 0, 0, 10, lr, 1L, 10); AggregateSkipGram op2 = new AggregateSkipGram(syn0, syn1, syn1Neg, expTable, null, idxSyn0_2, new int[] {4, 5}, new int[] {0, 1}, 0, 0, 10, lr, 1L, 10); List<Aggregate> batch = new ArrayList<>(); batch.add(op1); batch.add(op2); Nd4j.getExecutioner().exec(batch); /* Since expTable contains all-equal values, and only difference for ANY index is code being 0 or 1, syn0 row will stay intact, because neu1e will be full of 0.0f, and axpy will have no actual effect */ assertEquals(expSyn0, syn0row_1); assertEquals(expSyn0, syn0row_2); // syn1 row 1 modified only once assertEquals(expSyn1_1, syn1.getRow(1)); assertEquals(expSyn1_1, syn1.getRow(4)); // syn1 row 2 modified only once assertEquals(expSyn1_2, syn1.getRow(2)); assertEquals(expSyn1_2, syn1.getRow(5)); } @Override public char ordering() { return 'c'; } }