/*
 *  RapidMiner
 *
 *  Copyright (C) 2001-2011 by Rapid-I and the contributors
 *
 *  Complete list of developers available at our web site:
 *
 *       http://rapid-i.com
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Affero General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Affero General Public License for more details.
 *
 *  You should have received a copy of the GNU Affero General Public License
 *  along with this program.  If not, see http://www.gnu.org/licenses/.
 */
package com.rapidminer.operator.learner.functions.neuralnet;

import java.io.Serializable;

import com.rapidminer.example.Example;

/**
 * A node is the abstract superclass for all types of neural net nodes and also represents
 * the connection between other nodes of the neural net. It performs
 * most of the calculations and the feedforward / backpropagation mechanism.
 *
 * <p>
 * Connections are stored in two pairs of parallel arrays. For every position <code>i</code>,
 * <code>inputNodeOutputIndices[i]</code> is the position of this node inside the
 * <code>outputNodes</code> array of <code>inputNodes[i]</code>, and
 * <code>outputNodeInputIndices[i]</code> is the position of this node inside the
 * <code>inputNodes</code> array of <code>outputNodes[i]</code>. Use
 * {@link #connect(Node, Node)} and {@link #disconnect(Node, Node)} to keep both ends
 * consistent.
 * </p>
 *
 * @author Ingo Mierswa, Sebastian Land
 */
public abstract class Node implements Serializable {

    private static final long serialVersionUID = -4888796462060891114L;

    /** Node type constant named for input nodes. */
    public static final int INPUT = -1;

    /** Node type constant named for hidden nodes. */
    public static final int HIDDEN = 0;

    /** Node type constant named for output nodes. */
    public static final int OUTPUT = -2;

    private int layerIndex;

    private String nodeName;

    private int nodeType;

    /** Guards {@link #update(Example, double, double)} against visiting this node twice. */
    private boolean weightsAreUpdated = false;

    /** The nodes feeding into this node. */
    protected Node[] inputNodes = new Node[0];

    /** The nodes this node feeds into. */
    protected Node[] outputNodes = new Node[0];

    /** For each input node: the index of this node in that node's {@link #outputNodes}. */
    protected int[] inputNodeOutputIndices = new int[0];

    /** For each output node: the index of this node in that node's {@link #inputNodes}. */
    protected int[] outputNodeInputIndices = new int[0];

    /** The cached value of this node; NaN until calculated and after {@link #reset()}. */
    protected double currentValue = Double.NaN;

    /** The cached error of this node; NaN until calculated and after {@link #reset()}. */
    protected double currentError = Double.NaN;

    /**
     * Creates a new, unconnected node.
     *
     * @param nodeName   the name of this node
     * @param layerIndex the index of the layer this node belongs to
     * @param nodeType   the type of this node (see {@link #INPUT}, {@link #HIDDEN}, {@link #OUTPUT})
     */
    public Node(String nodeName, int layerIndex, int nodeType) {
        this.layerIndex = layerIndex;
        this.nodeName = nodeName;
        this.nodeType = nodeType;
    }

    /** Calculates the output for this node. */
    public abstract double calculateValue(boolean calculate, Example example);

    /** Calculates the error for this node. */
    public abstract double calculateError(boolean calculate, Example example);

    /** Returns 1. Subclasses should override this method. */
    public double getWeight(int n) {
        return 1;
    }

    public int getLayerIndex() {
        return this.layerIndex;
    }

    public String getNodeName() {
        return this.nodeName;
    }

    public int getNodeType() {
        return this.nodeType;
    }

    /**
     * Propagates the update call to all input nodes and marks this node as updated.
     * Does nothing if this node was already updated since the last {@link #reset()}.
     * This base implementation changes no weights itself.
     */
    public void update(Example example, double learningRate, double momentum) {
        if (!weightsAreUpdated) {
            for (int i = 0; i < inputNodes.length; i++) {
                inputNodes[i].update(example, learningRate, momentum);
            }
            weightsAreUpdated = true;
        }
    }

    /** Returns true if {@link #update(Example, double, double)} ran since the last reset. */
    public boolean areWeightsUpdated() {
        return this.weightsAreUpdated;
    }

    /**
     * Clears the cached value, the cached error and the update flag, then resets all input
     * nodes. Nothing happens (and the call is not propagated) if both the value and the
     * error are already NaN.
     */
    public void reset() {
        if (!Double.isNaN(currentValue) || !Double.isNaN(currentError)) {
            weightsAreUpdated = false;
            currentValue = Double.NaN;
            currentError = Double.NaN;
            for (int i = 0; i < inputNodes.length; i++) {
                inputNodes[i].reset();
            }
        }
    }

    /** Returns the internal array of input nodes (not a copy). */
    public Node[] getInputNodes() {
        return inputNodes;
    }

    /** Returns the internal array of output nodes (not a copy). */
    public Node[] getOutputNodes() {
        return outputNodes;
    }

    /** Returns the internal array of input node output indices (not a copy). */
    public int[] getInputNodeOutputIndices() {
        return inputNodeOutputIndices;
    }

    /** Returns the internal array of output node input indices (not a copy). */
    public int[] getOutputNodeInputIndices() {
        return outputNodeInputIndices;
    }

    /**
     * Appends the given node to the inputs of this node. Only this end of the connection
     * is changed.
     *
     * @param inputNode            the node to add as input
     * @param inputNodeOutputIndex the index of this node in the outputs of the input node
     * @return always true in this implementation
     */
    protected boolean connectInput(Node inputNode, int inputNodeOutputIndex) {
        Node[] newInputNodes = new Node[inputNodes.length + 1];
        System.arraycopy(inputNodes, 0, newInputNodes, 0, inputNodes.length);
        newInputNodes[newInputNodes.length - 1] = inputNode;
        inputNodes = newInputNodes;

        int[] newInputNodeOutputIndices = new int[inputNodeOutputIndices.length + 1];
        System.arraycopy(inputNodeOutputIndices, 0, newInputNodeOutputIndices, 0, inputNodeOutputIndices.length);
        newInputNodeOutputIndices[newInputNodeOutputIndices.length - 1] = inputNodeOutputIndex;
        inputNodeOutputIndices = newInputNodeOutputIndices;

        return true;
    }

    /**
     * Appends the given node to the outputs of this node. Only this end of the connection
     * is changed.
     *
     * @param outputNode           the node to add as output
     * @param outputNodeInputIndex the index of this node in the inputs of the output node
     * @return always true in this implementation
     */
    protected boolean connectOutput(Node outputNode, int outputNodeInputIndex) {
        Node[] newOutputNodes = new Node[outputNodes.length + 1];
        System.arraycopy(outputNodes, 0, newOutputNodes, 0, outputNodes.length);
        newOutputNodes[newOutputNodes.length - 1] = outputNode;
        outputNodes = newOutputNodes;

        int[] newOutputNodeInputIndices = new int[outputNodeInputIndices.length + 1];
        System.arraycopy(outputNodeInputIndices, 0, newOutputNodeInputIndices, 0, outputNodeInputIndices.length);
        newOutputNodeInputIndices[newOutputNodeInputIndices.length - 1] = outputNodeInputIndex;
        outputNodeInputIndices = newOutputNodeInputIndices;

        return true;
    }

    /**
     * Removes the given node from the inputs of this node. Inputs behind the removed
     * position move one slot forward, and the index stored at the other end of each moved
     * connection is adjusted accordingly.
     *
     * @param inputNode            the node to remove from the inputs
     * @param inputNodeOutputIndex the stored output index the entry must have, or -1 to
     *                             remove all entries referring to the given node
     * @return true if at least one entry was removed
     */
    protected boolean disconnectInput(Node inputNode, int inputNodeOutputIndex) {
        int deleteIndex = -1;
        boolean removed = false;
        int numberOfInputs = inputNodes.length;
        do {
            deleteIndex = -1;
            // Search only the live prefix. Slots at or beyond numberOfInputs still hold stale
            // references from earlier shifts; matching them again would loop forever.
            for (int i = 0; i < numberOfInputs; i++) {
                if (inputNode == inputNodes[i]
                        && (inputNodeOutputIndex == -1 || inputNodeOutputIndex == inputNodeOutputIndices[i])) {
                    deleteIndex = i;
                    break;
                }
            }

            if (deleteIndex >= 0) {
                for (int i = deleteIndex + 1; i < numberOfInputs; i++) {
                    inputNodes[i - 1] = inputNodes[i];
                    inputNodeOutputIndices[i - 1] = inputNodeOutputIndices[i];
                    // tell the other end of the moved connection about the new position
                    inputNodes[i - 1].outputNodeInputIndices[inputNodeOutputIndices[i - 1]] = i - 1;
                }
                numberOfInputs--;
                removed = true;
            }
        } while (inputNodeOutputIndex == -1 && deleteIndex != -1);

        // trim both arrays down to the remaining live entries
        Node[] newInputNodes = new Node[numberOfInputs];
        System.arraycopy(inputNodes, 0, newInputNodes, 0, numberOfInputs);
        inputNodes = newInputNodes;

        int[] newInputNodeOutputIndices = new int[numberOfInputs];
        System.arraycopy(inputNodeOutputIndices, 0, newInputNodeOutputIndices, 0, numberOfInputs);
        inputNodeOutputIndices = newInputNodeOutputIndices;

        return removed;
    }

    /**
     * Removes the given node from the outputs of this node. Outputs behind the removed
     * position move one slot forward, and the index stored at the other end of each moved
     * connection is adjusted accordingly.
     *
     * @param outputNode           the node to remove from the outputs
     * @param outputNodeInputIndex the stored input index the entry must have, or -1 to
     *                             remove all entries referring to the given node
     * @return true if at least one entry was removed
     */
    protected boolean disconnectOutput(Node outputNode, int outputNodeInputIndex) {
        int deleteIndex = -1;
        boolean removed = false;
        int numberOfOutputs = outputNodes.length;
        do {
            deleteIndex = -1;
            // Search only the live prefix. Slots at or beyond numberOfOutputs still hold stale
            // references from earlier shifts; matching them again would loop forever.
            for (int i = 0; i < numberOfOutputs; i++) {
                if (outputNode == outputNodes[i]
                        && (outputNodeInputIndex == -1 || outputNodeInputIndex == outputNodeInputIndices[i])) {
                    deleteIndex = i;
                    break;
                }
            }

            if (deleteIndex >= 0) {
                for (int i = deleteIndex + 1; i < numberOfOutputs; i++) {
                    outputNodes[i - 1] = outputNodes[i];
                    outputNodeInputIndices[i - 1] = outputNodeInputIndices[i];
                    // tell the other end of the moved connection about the new position
                    outputNodes[i - 1].inputNodeOutputIndices[outputNodeInputIndices[i - 1]] = i - 1;
                }
                numberOfOutputs--;
                removed = true;
            }
        } while (outputNodeInputIndex == -1 && deleteIndex != -1);

        // trim both arrays down to the remaining live entries
        Node[] newOutputNodes = new Node[numberOfOutputs];
        System.arraycopy(outputNodes, 0, newOutputNodes, 0, numberOfOutputs);
        outputNodes = newOutputNodes;

        int[] newOutputNodeInputIndices = new int[numberOfOutputs];
        System.arraycopy(outputNodeInputIndices, 0, newOutputNodeInputIndices, 0, numberOfOutputs);
        outputNodeInputIndices = newOutputNodeInputIndices;

        return removed;
    }

    /**
     * Connects the first node to the second one, i.e. the second node becomes an output of
     * the first node and the first node becomes an input of the second node. An already
     * existing connection between both nodes is removed first.
     *
     * @return true if both ends accepted the connection, false otherwise
     */
    public static boolean connect(Node firstNode, Node secondNode) {
        disconnect(firstNode, secondNode);

        // the new input will be appended to secondNode, hence its index is the current length
        if (!firstNode.connectOutput(secondNode, secondNode.inputNodes.length)) {
            return false;
        }

        // the output was just appended to firstNode, hence its index is length - 1
        if (!secondNode.connectInput(firstNode, firstNode.outputNodes.length - 1)) {
            // roll back the output side so that no half-open connection remains
            firstNode.disconnectOutput(secondNode, secondNode.inputNodes.length);
            return false;
        }
        return true;
    }

    /**
     * Removes all connections leading from the first node to the second node. The input
     * side of the second node is only processed if the first node actually had the second
     * node among its outputs.
     *
     * @return true if entries were removed at both ends
     */
    public static boolean disconnect(Node firstNode, Node secondNode) {
        return firstNode.disconnectOutput(secondNode, -1) && secondNode.disconnectInput(firstNode, -1);
    }
}