/* * Encog(tm) Core v3.4 - Java Version * http://www.heatonresearch.com/encog/ * https://github.com/encog/encog-java-core * Copyright 2008-2016 Heaton Research, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * For more information on Heaton Research copyrights, licenses * and trademarks visit: * http://www.heatonresearch.com/copyright */ package org.encog.neural.flat; import java.io.Serializable; import java.util.Arrays; import org.encog.engine.network.activation.ActivationLinear; import org.encog.mathutil.rbf.RadialBasisFunction; import org.encog.util.EngineArray; /** * A flat network designed to handle an RBF. */ public class FlatNetworkRBF extends FlatNetwork implements Serializable, Cloneable { /** * */ private static final long serialVersionUID = 1L; /** * The RBF's used. */ private RadialBasisFunction[] rbf; /** * Default constructor. */ public FlatNetworkRBF() { } /** * Construct an RBF flat network. * * @param inputCount * The number of input neurons. (also the number of dimensions) * @param hiddenCount * The number of hidden neurons. * @param outputCount * The number of output neurons. * @param rbf * The radial basis functions to use. */ public FlatNetworkRBF(final int inputCount, final int hiddenCount, final int outputCount, final RadialBasisFunction[] rbf) { FlatLayer[] layers = new FlatLayer[3]; this.rbf = rbf; layers[0] = new FlatLayer(new ActivationLinear(), inputCount, 0.0); layers[1] = new FlatLayer(new ActivationLinear(), hiddenCount, 0.0); layers[2] = new FlatLayer(new ActivationLinear(), outputCount, 0.0); init(layers,false); } /** * Clone the network. * * @return A clone of the network. */ @Override public FlatNetworkRBF clone() { final FlatNetworkRBF result = new FlatNetworkRBF(); cloneFlatNetwork(result); result.rbf = this.rbf; return result; } /** * Calculate the output for the given input. * * @param x * The input. * @param output * Output will be placed here. */ @Override public void compute(final double[] x, final double[] output) { int outputIndex = this.getLayerIndex()[1]; for (int i = 0; i < rbf.length; i++) { double o = this.rbf[i].calculate(x); this.getLayerOutput()[outputIndex + i] = o; } // now compute the output computeLayer(1); EngineArray.arrayCopy(this.getLayerOutput(), 0, output, 0, this .getOutputCount()); } /** * Set the RBF's used. * @param rbf The RBF's used. */ public void setRBF(final RadialBasisFunction[] rbf) { this.rbf = Arrays.copyOf(rbf, rbf.length); } /** * @return The RBF's used. */ public RadialBasisFunction[] getRBF() { return this.rbf; } }