package org.deeplearning4j.examples.misc.activationfunctions;

import org.apache.commons.math3.util.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.Tanh;
import org.nd4j.linalg.factory.Nd4j;

/**
 * This is an example of how to implement a custom activation function that does not take any learnable parameters.
 * Custom activation functions like this one should extend BaseActivationFunction and implement the methods
 * shown here.
 * IMPORTANT: Do not forget gradient checks. Refer to the examples in the deeplearning4j repo,
 * deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java
*
* The form of the activation function implemented here is from https://arxiv.org/abs/1508.01292
* "Compact Convolutional Neural Network Cascade for Face Detection" by Kalinovskii I.A. and Spitsyn V.G.
*
 * h(x) = 1.7159 * tanh(2x/3)
*
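 * A minimal usage sketch (an illustrative assumption, based on DL4J's layer-builder API accepting
 * an IActivation instance wherever an activation is configured):
 *
 *   new DenseLayer.Builder().nIn(nIn).nOut(nOut).activation(new CustomActivation()).build()
 *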
* @author susaneraly
*/
public class CustomActivation extends BaseActivationFunction {
    /*
    For the forward pass:
    Transform "in" with the activation function. Best practice is to do the transform in place, as shown below.
    Different behaviour during training and testing can be supported via the boolean argument.
    */
@Override
public INDArray getActivation(INDArray in, boolean training) {
        //Modify array "in" in place to transform it with the activation function
        // h(x) = 1.7159*tanh(2x/3)
        //Scale the input by 2/3 in place, then apply tanh in place via the executioner
        Nd4j.getExecutioner().execAndReturn(new Tanh(in.muli(2 / 3.0)));
        in.muli(1.7159);
return in;
}

    /*
    For the backward pass:
    Given epsilon, the gradient at every activation node, calculate the next set of gradients for the backward pass.
    Best practice is to modify in place.

    Using the terminology:
        in      -> the linear input to the activation node
        out     -> the output of the activation node, in other words h(in), where h is the activation function
        epsilon -> the gradient of the loss function with respect to the output of the activation node, d(Loss)/d(out)

        h(in) = out
        d(Loss)/d(in) = d(Loss)/d(out) * d(out)/d(in)
                      = epsilon * h'(in)
    */
@Override
public Pair<INDArray,INDArray> backprop(INDArray in, INDArray epsilon) {
        //dLdz here is h'(in) in the description above
        //
        // h(x) = 1.7159*tanh(2x/3)
        // By the chain rule: h'(x) = 1.7159 * (2/3) * tanh'(2x/3), where tanh'(u) = 1 - tanh^2(u)
        //Note that in.muli(2/3.0) scales "in" in place before the tanh derivative is evaluated
        INDArray dLdz = Nd4j.getExecutioner().execAndReturn(new Tanh(in.muli(2 / 3.0)).derivative());
        dLdz.muli(2 / 3.0);
        dLdz.muli(1.7159);
        //Multiply by epsilon to apply the chain rule: d(Loss)/d(in) = epsilon * h'(in)
        dLdz.muli(epsilon);
        //The second element of the pair is the gradient with respect to learnable parameters; null here as there are none
        return new Pair<>(dLdz, null);
}
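
    /*
    A minimal sanity check (an illustrative sketch added here, not part of the original example):
    evaluates the forward pass against a direct computation of 1.7159*tanh(2x/3), and compares the
    analytic gradient from backprop(..) against a central finite difference. Assumes the legacy
    Nd4j.linspace and executioner APIs used above.
    */
    public static void main(String[] args) {
        CustomActivation h = new CustomActivation();
        INDArray x = Nd4j.linspace(-3, 3, 7);

        //Forward pass: getActivation modifies its argument in place, so pass a copy
        INDArray out = h.getActivation(x.dup(), true);
        INDArray expected = Nd4j.getExecutioner().execAndReturn(new Tanh(x.mul(2 / 3.0))).muli(1.7159);
        System.out.println("h(x):     " + out);
        System.out.println("expected: " + expected);

        //Backward pass with epsilon = 1 yields h'(x); compare to (h(x+d) - h(x-d)) / (2d)
        double d = 1e-4;
        INDArray analytic = h.backprop(x.dup(), Nd4j.ones(x.shape())).getFirst();
        INDArray numeric = h.getActivation(x.add(d), true).subi(h.getActivation(x.sub(d), true)).divi(2 * d);
        System.out.println("h'(x):    " + analytic);
        System.out.println("numeric:  " + numeric);
    }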
}