package func.nn.backprop;
import func.nn.Link;
import func.nn.activation.DifferentiableActivationFunction;
import func.nn.feedfwd.FeedForwardNode;
/**
* A back propagation node
* @author Andrew Guillory gtg008g@mail.gatech.edu
* @version 1.0
*/
public class BackPropagationNode extends FeedForwardNode {
/**
* The derivative of the error with respect to
* the activation of this node.
*/
private double inputError;
/**
* The deriviative of the error with respect to
* the activation of this node.
*/
private double outputError;
/**
* Create a new back propogation node
* @param function the differentiable activation function
* @param learningRate the learning rate
* @param momentum the momentum
*/
public BackPropagationNode(DifferentiableActivationFunction function) {
super(function);
}
/**
* Back propagate error values.
* For nodes that have output links, first
* calculates the derivative of the error function
* with respect to this node by finding the weighted
* sum of the errors of nodes this node outputs to,
* and multiplying that by the derivative of the activation
* function applied to the weighted input sum.
* For nodes with output links, simply moves the error
* to the output (assuming that the appropriate error
* function / activation function combination was
* used).
*/
public void backpropagate() {
if (getOutLinkCount() > 0) {
double weightedErrorSum = 0;
for (int i = 0; i < getOutLinkCount(); i++) {
BackPropagationLink outLink =
(BackPropagationLink) getOutLink(i);
weightedErrorSum += outLink.getWeightedOutError();
}
setOutputError(weightedErrorSum);
DifferentiableActivationFunction act =
(DifferentiableActivationFunction) getActivationFunction();
setInputError(act.derivative(getWeightedInputSum()) * getOutputError());
} else {
setInputError(getOutputError());
}
}
/**
* Backpropagate error into the incoming links
* from this node
*/
public void backpropagateLinks() {
for (int i = 0; i < getInLinkCount(); i++) {
BackPropagationLink inLink =
(BackPropagationLink) getInLink(i);
inLink.backpropagate();
}
}
/**
* Update the incoming weights with the given rule
* @param rule the rule to use
*/
public void updateWeights(WeightUpdateRule rule) {
for (int i = 0; i < getInLinkCount(); i++) {
BackPropagationLink inLink =
(BackPropagationLink) getInLink(i);
rule.update(inLink);
}
}
/**
* Set the error for this node with respect to
* the output of the node
* @param error the new error value
*/
public void setOutputError(double error) {
outputError = error;
}
/**
* Get the error for this node with respect
* to the output of the node
* @return the error
*/
public double getOutputError() {
return outputError;
}
/**
* Get the error for this node with respect to
* the weighted input of the node
* @return the error
*/
public double getInputError() {
return inputError;
}
/**
* Set the error with respect
* to the weighted input of the node
* @param error the error
*/
public void setInputError(double error) {
inputError = error;
}
/**
* Clears all of the error derivatives for
* the incoming links.
*/
public void clearError() {
for (int i = 0; i < getInLinkCount(); i++) {
BackPropagationLink inLink =
(BackPropagationLink) getInLink(i);
inLink.clearError();
}
}
/**
* @see nn.Node#createLink()
*/
public Link createLink() {
return new BackPropagationLink();
}
}