/*
* Copyright [2013-2015] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core.dtrain.dataset;
import org.encog.engine.network.activation.ActivationLinear;
import org.encog.neural.NeuralNetworkError;
import org.encog.neural.flat.FlatLayer;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.structure.NeuralStructure;
/**
* Extend {@link NeuralStructure} to set {@link FloatFlatNetwork}.
*
* <p>
* {@link #finalizeStruct()} is used in place of {@link #finalizeStructure()}, because {@link #finalizeStructure()} is
* declared final and cannot be overridden.
*/
public class FloatNeuralStructure extends NeuralStructure {

    private static final long serialVersionUID = 7662087479144051670L;

    /**
     * Creates a structure attached to the given network.
     *
     * @param network
     *            the network handed to the parent {@link NeuralStructure} constructor
     */
    public FloatNeuralStructure(BasicNetwork network) {
        super(network);
    }

    /**
     * Finalizes the layer structure and installs a {@link FloatFlatNetwork} as the flat network. Call this once all
     * layers have been added to the network, or after changing the network's logic property.
     *
     * <p>
     * Every layer that has no activation function is given an {@link ActivationLinear}. After the flat network is
     * set, the layer list is cleared.
     *
     * @throws NeuralNetworkError
     *             if fewer than two layers are present
     */
    public void finalizeStruct() {
        final int layerCount = this.getLayers().size();
        if (layerCount < 2) {
            throw new NeuralNetworkError("There must be at least two layers before the structure is finalized.");
        }

        // Collect the layers, in order, into the array the flat network is built from.
        final FlatLayer[] flattened = new FlatLayer[layerCount];
        int position = 0;
        while (position < layerCount) {
            final BasicLayer current = (BasicLayer) this.getLayers().get(position);
            if (null == current.getActivation()) {
                // A layer without an activation function falls back to linear.
                current.setActivation(new ActivationLinear());
            }
            flattened[position] = current;
            position++;
        }

        final FloatFlatNetwork flatNetwork = new FloatFlatNetwork(flattened);
        this.setFlat(flatNetwork);

        // Same order as the original: finalize the limit, drop the layer list, then enforce the limit.
        finalizeLimit();
        this.getLayers().clear();
        enforceLimit();
    }
}