package org.nd4j.linalg.lossfunctions;

import lombok.extern.slf4j.Slf4j;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.*;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.lossfunctions.impl.*;

import java.util.Arrays;

/**
 * Gradient checks for ILossFunction implementations: the analytic gradient from
 * {@link ILossFunction#computeGradient} is compared against a central-difference
 * numerical gradient of {@link ILossFunction#computeScore}.
 *
 * Created by Alex on 08/08/2016.
 */
@Slf4j
public class LossFunctionGradientChecks extends BaseNd4jTest {

    public static final double epsilon = 1e-6;
    private static final double maxRelError = 5.0; //5% relative error

    DataBuffer.Type initialType;

    public LossFunctionGradientChecks(Nd4jBackend backend) {
        super(backend);
        this.initialType = Nd4j.dataType();
    }

    @Before
    public void before() throws Exception {
        super.before();
        Nd4j.zeros(1);
        DataTypeUtil.setDTypeForContext(DataBuffer.Type.DOUBLE);
        Nd4j.getRandom().setSeed(123);
    }

    @Test
    public void testLossFunctionGradients() {
        //The four arrays below are index-aligned: labels[i], preOut[i], lossFn[i] and
        //activationFns[i] together define one gradient-check case.
        INDArray[] labels = new INDArray[] {
                        Nd4j.create(new double[] {0, 1, 0}),
                        Nd4j.create(new double[] {0, 1, 1}),
                        /*Nd4j.create(new double[][]{{1,0,0},{0,1,0},{0,0,1}}),
                        Nd4j.create(new double[]{1,2,1}),
                        Nd4j.create(new double[][]{{1,2,1},{0.1,1,0.5},{20,3,1}}),
                        Nd4j.create(new double[]{1,0,0}),
                        Nd4j.create(new double[][]{{1,0,0,0},{0,1,0,0},{0,0,1,0},{0,0,0,1}}),
                        Nd4j.create(new double[]{1,2,1}),
                        Nd4j.create(new double[][]{{101,21,110},{10.1,1,0.5},{200,30,0.001}}),*/
                        Nd4j.create(new double[] {1, 2, 1}),
                        Nd4j.create(new double[][] {{101, 21, 110}, {10.1, 1, 0.5}, {200, 30, 0.001}}),
                        Nd4j.create(new double[] {1, 2, 1}),
                        Nd4j.create(new double[][] {{101, 21, 110}, {10.1, 1, 0.5}, {200, 30, 0.001}}),
                        Nd4j.create(new double[] {1, 2, 1}),
                        Nd4j.create(new double[][] {{101, 21, 110}, {10.1, 1, 0.5}, {200, 30, 0.001}}),
                        Nd4j.create(new double[] {1, 2, 1}),
                        Nd4j.create(new double[][] {{101, 21, 110}, {10.1, 1, 0.5}, {200, 30, 0.001}}),
                        //Nd4j.create(new double[][] {{-1,-1,1},{-1,1,1},{-1,1,1}}),
                        Nd4j.create(new double[][] {{-1, 1, -1}, {1, 1, -1}, {-1, 1, 1}}),
                        Nd4j.create(new double[][] {{-1, 1, -1}, {1, 1, -1}, {-1, 1, 1}}),
                        //Nd4j.create(new double[][] {{10,1,3},{1,10,1},{1,2,5}}),
                        //Nd4j.create(new double[][] {{10,-1,3},{1,10,1},{1,2,-5}}),
        };

        INDArray[] preOut = new INDArray[] {
                        Nd4j.rand(1, 3),
                        Nd4j.rand(1, 3),
                        /*Nd4j.rand(3,3),
                        Nd4j.rand(1,3).add(5),
                        Nd4j.rand(3,3),
                        Nd4j.rand(1,3).add(5),
                        Nd4j.rand(4,4),*/
                        Nd4j.randn(1, 3),
                        Nd4j.randn(3, 3).add(10),
                        Nd4j.rand(1, 3),
                        Nd4j.randn(3, 3).add(10),
                        Nd4j.randn(1, 3),
                        Nd4j.randn(3, 3).add(10),
                        Nd4j.rand(1, 3),
                        Nd4j.randn(3, 3).add(10),
                        /*Nd4j.rand(1,3),
                        Nd4j.randn(3,3).add(10),*/
                        Nd4j.rand(3, 3).addi(-0.5), //subtracting 0.5 makes some values positive, some negative
                        Nd4j.rand(3, 3).addi(-0.5), //subtracting 0.5 makes some values positive, some negative
                        //Nd4j.rand(3,3),
                        //Nd4j.randn(3,3)
        };

        ILossFunction[] lossFn = new ILossFunction[] {
                        new LossBinaryXENT(), new LossBinaryXENT(),
                        /*new LossMCXENT(), new LossMCXENT(), new LossMCXENT(), new LossMSE(), new LossMSE(),
                        new LossKLD(), new LossKLD(), new LossMAE(), new LossMAE(),*/
                        new LossMAE(), new LossMAE(), new LossMSE(), new LossMSE(),
                        new LossL1(), new LossL1(), new LossL2(), new LossL2(),
                        new LossSquaredHinge(), new LossHinge(),
                        //new LossPoisson(),
                        //new LossCosineProximity()
        };

        String[] activationFns = new String[] {"identity", "tanh",
                        /*"softmax","tanh","identity","tanh","tanh","identity","identity","identity","identity",*/
                        "identity", "identity", "identity", "identity",
                        "sigmoid", "relu", "sigmoid", "relu",
                        "identity", "identity",
                        //"relu",
                        //"identity"
        };
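        //Gradient check by central differences: perturb each pre-activation value by +/- epsilon,
        //recompute the score, and take (scorePlus - scoreMinus) / (2 * epsilon) as the numerical
        //gradient. A case passes if the relative error against the analytic gradient is below
        //maxRelError (in percent).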
/*"softmax","tanh","identity","tanh", "tanh","identity","identity","identity","identity",*/ "identity", "identity", "identity", "identity", "sigmoid", "relu", "sigmoid", "relu", "identity", "identity", //"relu", //"identity" }; for (int i = 0; i < labels.length; i++) { //if (i != labels.length-2) continue; int totalNFailures = 0; ILossFunction lf = lossFn[i]; INDArray l = labels[i]; INDArray p = preOut[i]; String afn = activationFns[i]; System.out.printf("Starting test: %s, %s, input shape = %s\n", lf, afn, Arrays.toString(p.shape())); INDArray grad = lf.computeGradient(l, p, activationInstance(afn), null); NdIndexIterator iter = new NdIndexIterator(l.shape()); while (iter.hasNext()) { int[] next = iter.next(); double before = p.getDouble(next); p.putScalar(next, before + epsilon); double scorePlus = lf.computeScore(l, p, activationInstance(afn), null, true); p.putScalar(next, before - epsilon); double scoreMinus = lf.computeScore(l, p, activationInstance(afn), null, true); p.putScalar(next, before); double scoreDelta = scorePlus - scoreMinus; double numericalGradient = scoreDelta / (2 * epsilon); double analyticGradient = grad.getDouble(next) / l.size(0); //Analytic gradient method is before dividing by minibatch double relError = Math.abs(analyticGradient - numericalGradient) * 100 / (Math.abs(numericalGradient)); if (analyticGradient == 0.0 && numericalGradient == 0.0) relError = 0.0; //Edge case: i.e., RNNs with time series length of 1.0 if (relError > maxRelError || Double.isNaN(relError)) { System.out.println("Param " + i + " FAILED: grad= " + analyticGradient + ", numericalGrad= " + numericalGradient + ", relErrorPerc= " + relError + ", scorePlus=" + scorePlus + ", scoreMinus= " + scoreMinus); totalNFailures++; } else { System.out.println("Param " + i + " passed: grad= " + analyticGradient + ", numericalGrad= " + numericalGradient + ", relError= " + relError + ", scorePlus=" + scorePlus + ", scoreMinus= " + scoreMinus); } //System.out.println("Param " + i + " passed: grad= " + analyticGradient + ", numericalGrad= " + numericalGradient // + ", relError= " + relError + ", scorePlus="+scorePlus+", scoreMinus= " + scoreMinus ); } if (totalNFailures > 0) System.out.println("Gradient check failed for loss function " + lf + "; total num failures = " + totalNFailures); System.out.println("DONE"); } } /* public static List<INDArray> makeLabels(String activation,int[]labelSize) { //edge cases are label size of one for everything except softmax which is two //+ve and -ve values, zero and non zero values, less than zero/greater than zero List<INDArray> returnVals = new ArrayList<>(labelSize.length); for (int i=0; i< labelSize.length; i++) { int aLabelSize = labelSize[i]; Random r = new Random(); double[] someVals = new double[aLabelSize]; double someValsSum = 0; for (int j=0; j<aLabelSize; j++) { double someVal = r.nextGaussian(); double transformVal = 0; switch (activation) { case "identity": transformVal = someVal; case "softmax": transformVal = someVal; break; case "sigmoid": transformVal = Math.sin(someVal); break; case "tanh": transformVal = Math.tan(someVal); break; case "reul": transformVal = someVal * someVal + 4; break; } someVals[j] = transformVal; someValsSum += someVals[j]; } if (activation == "sigmoid") { for (int j=0; j<aLabelSize; j++) { someVals[j] /= someValsSum; } } returnVals.add(Nd4j.create(someVals)); } return returnVals; } */ public static IActivation activationInstance(String activation) { IActivation activationFn = new ActivationSigmoid(); switch (activation) { case 
"identity": activationFn = new ActivationIdentity(); case "softmax": activationFn = new ActivationSoftmax(); break; case "sigmoid": activationFn = new ActivationSigmoid(); break; case "tanh": activationFn = new ActivationTanH(); break; case "reul": activationFn = new ActivationReLU(); break; } return activationFn; } @After public void after() { DataTypeUtil.setDTypeForContext(this.initialType); System.out.println("AFTER DATATYPE HERE: " + Nd4j.dataType()); } @Override public char ordering() { return 'f'; } }