package org.nd4j.linalg.ops;

import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.RationalTanhDerivative;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;

import static org.junit.Assert.assertTrue;

/**
 * Gradient check for the rational tanh approximation described in
 * https://arxiv.org/pdf/1508.01292v3 and
 * https://github.com/deeplearning4j/libnd4j/issues/351
 */
@RunWith(Parameterized.class)
public class RationalTanhTest extends BaseNd4jTest {

    public RationalTanhTest(Nd4jBackend backend) {
        super(backend);
    }

    @Test
    public void gradientCheck() {
        double eps = 1e-6;
        INDArray A = Nd4j.linspace(-3, 3, 10).reshape(2, 5);
        INDArray ADer = Nd4j.getExecutioner().execAndReturn(new RationalTanhDerivative(A.dup()));
        double[] a = A.data().asDouble();
        double[] aDer = ADer.data().asDouble();

        for (int i = 0; i < 10; i++) {
            // Central finite difference: f'(x) ~ (f(x+eps) - f(x-eps)) / (2*eps)
            double empirical = (f(a[i] + eps) - f(a[i] - eps)) / (2 * eps);
            double analytic = aDer[i];
            // Symmetric relative error between empirical and analytic derivatives must stay under 0.1%
            assertTrue(Math.abs(empirical - analytic) / (Math.abs(empirical) + Math.abs(analytic)) < 0.001);
        }
    }

    /** f(x) = 1.7159 * tanh(2x/3), with tanh replaced by the rational approximation below. */
    public static double f(double x) {
        return 1.7159 * tanhApprox(2.0 / 3 * x);
    }

    /** Rational approximation: tanh(y) ~ sign(y) * (1 - 1 / (1 + |y| + y^2 + 1.41645 * y^4)) */
    public static double tanhApprox(double y) {
        return Math.signum(y) * (1.0 - 1.0 / (1 + Math.abs(y) + y * y + 1.41645 * Math.pow(y, 4.0)));
    }

    /*
    Analytic derivatives, kept for reference; constants: C1 = 1.7159, C = 1.41645.

    public static double fDeriv(double x) {
        return C1 * 2.0 / 3 * tanhDeriv(2.0 / 3 * x);
    }

    public static double tanhDeriv(double y) {
        double a = 1 + Math.abs(y) + y * y + C * Math.pow(y, 4);
        return (1 + Math.signum(y) * (2 * y + 4 * C * Math.pow(y, 3))) / (a * a);
    }
    */

    @Override
    public char ordering() {
        return 'f';
    }
}
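
/*
 * A minimal standalone sketch (not part of the test suite) that compares the
 * rational approximation above against the exact java.lang.Math.tanh over the
 * range used by the gradient check. It assumes only java.lang.Math; the class
 * name TanhApproxSketch and the sampling grid are hypothetical choices for
 * illustration, not part of the original test.
 */
class TanhApproxSketch {

    // Same rational form as tanhApprox above, inlined so the sketch stands alone
    static double approx(double y) {
        return Math.signum(y) * (1.0 - 1.0 / (1 + Math.abs(y) + y * y + 1.41645 * Math.pow(y, 4.0)));
    }

    public static void main(String[] args) {
        double maxErr = 0;
        // Sample [-5, 5] on a fine grid and track the worst absolute deviation
        for (double y = -5; y <= 5; y += 0.01) {
            maxErr = Math.max(maxErr, Math.abs(approx(y) - Math.tanh(y)));
        }
        System.out.printf("max |approx(y) - tanh(y)| on [-5, 5]: %.6f%n", maxErr);
    }
}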