/*-
 *
 *  * Copyright 2016 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 */

package org.deeplearning4j.nn.conf;

import lombok.*;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.graph.LayerVertex;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BasePretrainNetwork;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.shade.jackson.databind.JsonNode;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.Serializable;
import java.util.*;

/**
 * ComputationGraphConfiguration is a configuration object for neural networks with arbitrary connection structure.
 * It is analogous to {@link MultiLayerConfiguration}, but allows considerably greater flexibility for the network
 * architecture.<br>
 * Specifically, the network architecture is a directed acyclic graph, where each vertex in the graph is a {@link GraphVertex},
 * which may for example be a layer or a vertex/object that defines arbitrary forward and backward pass functionality.<br>
 * Note that the ComputationGraph may have an arbitrary number of inputs (multiple independent inputs, possibly of different
 * types), and an arbitrary number of outputs (for example, multiple {@link OutputLayer} instances).<br>
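 * <br>
 * For example, a two-input network could be configured roughly as follows (an illustrative sketch; the input and
 * layer names are arbitrary labels chosen for this example):
 * <pre>{@code
 * ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
 *     .graphBuilder()
 *     .addInputs("input1", "input2")
 *     .addLayer("L1", new DenseLayer.Builder().nIn(10).nOut(5).build(), "input1", "input2")
 *     .addLayer("out", new OutputLayer.Builder().nIn(5).nOut(3).build(), "L1")
 *     .setOutputs("out")
 *     .build();
 * }</pre>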
 * More generally, typical usage follows the pattern:<br>
 * {@code ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()....graphBuilder()...build();}
 *
 * @author Alex Black
 */
@Data
@EqualsAndHashCode
@AllArgsConstructor(access = AccessLevel.PRIVATE)
@NoArgsConstructor
public class ComputationGraphConfiguration implements Serializable, Cloneable {

    private static Logger log = LoggerFactory.getLogger(ComputationGraphConfiguration.class);

    protected Map<String, GraphVertex> vertices = new LinkedHashMap<>();
    protected Map<String, List<String>> vertexInputs = new LinkedHashMap<>();

    @Getter
    @Setter
    protected WorkspaceMode trainingWorkspaceMode;

    @Getter
    @Setter
    protected WorkspaceMode inferenceWorkspaceMode;

    /**
     * List of inputs to the network, by name
     */
    protected List<String> networkInputs;

    /**
     * List of network outputs, by name
     */
    protected List<String> networkOutputs;

    protected boolean pretrain = false;
    protected boolean backprop = true;
    protected BackpropType backpropType = BackpropType.Standard;
    protected int tbpttFwdLength = 20;
    protected int tbpttBackLength = 20;

    protected NeuralNetConfiguration defaultConfiguration;

    //Counter for the number of parameter updates so far.
    //This is important for learning rate schedules, for example, and is stored here to ensure it is persisted
    //for Spark and model serialization
    protected int iterationCount = 0;

    /**
     * @return YAML representation of configuration
     */
    public String toYaml() {
        ObjectMapper mapper = NeuralNetConfiguration.mapperYaml();
        synchronized (mapper) {
            try {
                return mapper.writeValueAsString(this);
            } catch (org.nd4j.shade.jackson.core.JsonProcessingException e) {
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * Create a neural net configuration from YAML
     *
     * @param json the neural net configuration, as YAML
     * @return {@link ComputationGraphConfiguration}
     */
    public static ComputationGraphConfiguration fromYaml(String json) {
        ObjectMapper mapper = NeuralNetConfiguration.mapperYaml();
        try {
            return mapper.readValue(json, ComputationGraphConfiguration.class);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * @return JSON representation of computation graph configuration
     */
    public String toJson() {
        //As per MultiLayerConfiguration.toJson()
        ObjectMapper mapper = NeuralNetConfiguration.mapper();
        synchronized (mapper) {
            //JSON mappers are supposed to be thread safe: however, in practice they seem to miss fields occasionally
            //when writeValueAsString is used by multiple threads. This results in invalid JSON. See issue #3243
            try {
                return mapper.writeValueAsString(this);
            } catch (org.nd4j.shade.jackson.core.JsonProcessingException e) {
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * Create a computation graph configuration from JSON
     *
     * @param json the neural net configuration, as JSON
     * @return {@link ComputationGraphConfiguration}
     */
    public static ComputationGraphConfiguration fromJson(String json) {
        //As per MultiLayerConfiguration.fromJson()
        ObjectMapper mapper = NeuralNetConfiguration.mapper();
        ComputationGraphConfiguration conf;
        try {
            conf = mapper.readValue(json, ComputationGraphConfiguration.class);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }

        //To maintain backward compatibility after activation function refactoring (configs generated with v0.7.1 or earlier)
        //Previously: enumeration used for activation functions.
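        //As a rough sketch (inferred from the parsing below; names and values here are illustrative), a legacy
        //layer entry in the JSON has the shape:
        //  "vertices" : { "myLayer" : { "LayerVertex" : { "layerConf" : {
        //      "layer" : { "dense" : { ..., "activationFunction" : "relu" } } } } } }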
        //Now: IActivation classes are used instead
        Map<String, GraphVertex> vertexMap = conf.getVertices();
        JsonNode vertices = null;
        for (Map.Entry<String, GraphVertex> entry : vertexMap.entrySet()) {
            if (!(entry.getValue() instanceof LayerVertex)) {
                continue;
            }

            LayerVertex lv = (LayerVertex) entry.getValue();
            if (lv.getLayerConf() != null && lv.getLayerConf().getLayer() != null) {
                Layer layer = lv.getLayerConf().getLayer();

                if (layer.getActivationFn() == null) {
                    String layerName = layer.getLayerName();

                    try {
                        if (vertices == null) {
                            JsonNode jsonNode = mapper.readTree(json);
                            vertices = jsonNode.get("vertices");
                        }

                        JsonNode vertexNode = vertices.get(layerName);
                        JsonNode layerVertexNode = vertexNode.get("LayerVertex");
                        if (layerVertexNode == null || !layerVertexNode.has("layerConf")
                                        || !layerVertexNode.get("layerConf").has("layer")) {
                            continue;
                        }
                        JsonNode layerWrapperNode = layerVertexNode.get("layerConf").get("layer");

                        if (layerWrapperNode == null || layerWrapperNode.size() != 1) {
                            //Should only have 1 element: "dense", "output", etc
                            continue;
                        }

                        JsonNode layerNode = layerWrapperNode.elements().next();
                        JsonNode activationFunction = layerNode.get("activationFunction");

                        if (activationFunction != null) {
                            IActivation ia = Activation.fromString(activationFunction.asText()).getActivationFunction();
                            layer.setActivationFn(ia);
                        }

                    } catch (IOException e) {
                        log.warn("Layer with null ActivationFn field or pre-0.7.2 activation function detected: could not parse JSON",
                                        e);
                    }
                }
            }
        }

        return conf;
    }

    @Override
    public String toString() {
        return toJson();
    }

    @Override
    public ComputationGraphConfiguration clone() {
        ComputationGraphConfiguration conf = new ComputationGraphConfiguration();

        conf.vertices = new LinkedHashMap<>();
        for (Map.Entry<String, GraphVertex> entry : this.vertices.entrySet()) {
            conf.vertices.put(entry.getKey(), entry.getValue().clone());
        }

        conf.vertexInputs = new LinkedHashMap<>();
        for (Map.Entry<String, List<String>> entry : this.vertexInputs.entrySet()) {
            conf.vertexInputs.put(entry.getKey(), new ArrayList<>(entry.getValue()));
        }

        conf.networkInputs = new ArrayList<>(this.networkInputs);
        conf.networkOutputs = new ArrayList<>(this.networkOutputs);

        conf.pretrain = pretrain;
        conf.backprop = backprop;
        conf.backpropType = backpropType;
        conf.tbpttFwdLength = tbpttFwdLength;
        conf.tbpttBackLength = tbpttBackLength;
        conf.defaultConfiguration = defaultConfiguration.clone();
        conf.trainingWorkspaceMode = trainingWorkspaceMode;
        conf.inferenceWorkspaceMode = inferenceWorkspaceMode;

        return conf;
    }

    /**
     * Check the configuration, make sure it is valid
     *
     * @throws IllegalStateException if configuration is not valid
     */
    public void validate() {
        if (networkInputs == null || networkInputs.size() < 1) {
            throw new IllegalStateException(
                            "Invalid configuration: network has no inputs. Use .addInputs(String...) to label (and give an ordering to) the network inputs");
        }
        if (networkOutputs == null || networkOutputs.size() < 1) {
            throw new IllegalStateException("Invalid configuration: network has no outputs. Use .setOutputs(String...) "
                            + "to specify (and give an ordering to) the output vertices");
        }

        //Check uniqueness of names for inputs, layers, GraphNodes
        for (String s : networkInputs) {
            if (vertices.containsKey(s)) {
                throw new IllegalStateException("Invalid configuration: name \"" + s
                                + "\" is present in both network inputs and graph vertices/layers");
            }
        }

        //Check: each layer & node has at least one input
        for (Map.Entry<String, List<String>> e : vertexInputs.entrySet()) {
            String nodeName = e.getKey();
            if (e.getValue() == null || e.getValue().isEmpty()) {
                throw new IllegalStateException("Invalid configuration: vertex \"" + nodeName + "\" has no inputs");
            }
            for (String inputName : e.getValue()) {
                if (!vertices.containsKey(inputName) && !networkInputs.contains(inputName)) {
                    throw new IllegalStateException("Invalid configuration: Vertex \"" + nodeName + "\" has input \""
                                    + inputName + "\" that does not exist");
                }
            }
        }

        //Check output names:
        for (String s : networkOutputs) {
            if (!vertices.containsKey(s)) {
                throw new IllegalStateException("Invalid configuration: Output name \"" + s + "\" is not a valid vertex");
            }
        }

        //Check for no graph cycles: done in ComputationGraph.init()
    }

    /**
     * Add preprocessors automatically, given the specified types of inputs for the network. Inputs are specified using the
     * {@link InputType} class, in the same order in which the inputs were defined in the original configuration.<br>
     * For example, in a network with two inputs: a convolutional input (28x28x1 images) and feed forward inputs, use
     * {@code .addPreProcessors(InputType.convolutional(1,28,28),InputType.feedForward())}.<br>
     * For the CNN->Dense and CNN->RNN transitions, the nIns on the Dense/RNN layers will also be set automatically.
     * <b>NOTE</b>: This method will be called automatically when using the
     * {@link GraphBuilder#setInputTypes(InputType...)} functionality. See that method for details.
     */
    public void addPreProcessors(InputType... inputTypes) {

        if (inputTypes == null || inputTypes.length != networkInputs.size()) {
            throw new IllegalArgumentException("Invalid number of InputTypes: cannot add preprocessors if number of InputType "
                            + "objects differs from number of network inputs");
        }

        //Now: need to do essentially a forward pass through the network, to work out what type of preprocessors to add
        //To do this: need to know what the output types are for each GraphVertex.

        //First step: build network in reverse order (i.e., define map of a -> list(b) instead of list(a) -> b)
        Map<String, List<String>> verticesOutputTo = new HashMap<>(); //Key: vertex. Values: vertices that this node is an input for
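        //For example (an illustrative sketch; the names are arbitrary): if vertexInputs contains {"c" -> ["a","b"]},
        //the reversed structure is verticesOutputTo = {"a" -> ["c"], "b" -> ["c"]}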
        for (Map.Entry<String, GraphVertex> entry : vertices.entrySet()) {
            String vertexName = entry.getKey();
            List<String> vertexInputNames = vertexInputs.get(vertexName);
            if (vertexInputNames == null)
                continue;

            //Build reverse network structure:
            for (String s : vertexInputNames) {
                List<String> list = verticesOutputTo.get(s);
                if (list == null) {
                    list = new ArrayList<>();
                    verticesOutputTo.put(s, list);
                }
                list.add(vertexName); //Edge: s -> vertexName
            }
        }

        //Now: do topological sort
        LinkedList<String> noIncomingEdges = new LinkedList<>(networkInputs); //Set of all nodes with no incoming edges
        List<String> topologicalOrdering = new ArrayList<>();

        Map<String, Set<String>> inputEdges = new HashMap<>();
        for (Map.Entry<String, List<String>> entry : vertexInputs.entrySet()) {
            inputEdges.put(entry.getKey(), new HashSet<>(entry.getValue()));
        }

        while (!noIncomingEdges.isEmpty()) {
            String next = noIncomingEdges.removeFirst();
            topologicalOrdering.add(next);

            //Remove edges next -> verticesOutputTo[...] from graph
            List<String> nextEdges = verticesOutputTo.get(next);

            if (nextEdges != null && !nextEdges.isEmpty()) {
                for (String s : nextEdges) {
                    Set<String> set = inputEdges.get(s);
                    set.remove(next);
                    if (set.isEmpty()) {
                        noIncomingEdges.add(s); //No remaining edges for vertex s -> add to list for processing
                    }
                }
            }
        }

        //If any edges remain in the graph: graph has cycles:
        for (Map.Entry<String, Set<String>> entry : inputEdges.entrySet()) {
            Set<String> set = entry.getValue();
            if (set == null)
                continue;
            if (!set.isEmpty())
                throw new IllegalStateException("Invalid configuration: cycle detected in graph. Cannot calculate topological ordering with graph cycle ("
                                + "cycle includes vertex \"" + entry.getKey() + "\")");
        }

        //Now, given the topological sort: do equivalent of forward pass
        Map<String, InputType> vertexOutputs = new HashMap<>();
        int currLayerIdx = -1;
        for (String s : topologicalOrdering) {
            int inputIdx = networkInputs.indexOf(s);
            if (inputIdx != -1) {
                vertexOutputs.put(s, inputTypes[inputIdx]);
                continue;
            }

            GraphVertex gv = vertices.get(s);

            List<InputType> inputTypeList = new ArrayList<>();

            if (gv instanceof LayerVertex) {
                //Add preprocessor, if necessary:
                String in = vertexInputs.get(s).get(0);
                InputType layerInput = vertexOutputs.get(in);
                inputTypeList.add(layerInput);

                LayerVertex lv = (LayerVertex) gv;
                Layer l = lv.getLayerConf().getLayer();

                //Preprocessors - add if necessary
                if (lv.getPreProcessor() == null) {
                    //But don't override preprocessors that are manually defined; if none has been defined,
                    //add the appropriate preprocessor for this input type/layer combination
                    InputPreProcessor preproc = l.getPreProcessorForInputType(layerInput);
                    lv.setPreProcessor(preproc);
                }

                //Set nIn value for layer (if not already set)
                InputType afterPreproc = layerInput;
                if (lv.getPreProcessor() != null) {
                    InputPreProcessor ip = lv.getPreProcessor();
                    afterPreproc = ip.getOutputType(layerInput);
                }
                l.setNIn(afterPreproc, false);

                currLayerIdx++;
            } else {
                List<String> inputs = vertexInputs.get(s);
                if (inputs != null) {
                    for (String inputVertexName : inputs) {
                        inputTypeList.add(vertexOutputs.get(inputVertexName));
                    }
                }
            }

            InputType outputFromVertex =
                            gv.getOutputType(currLayerIdx, inputTypeList.toArray(new InputType[inputTypeList.size()]));
            vertexOutputs.put(s, outputFromVertex);
        }
    }

    @Data
    public static class GraphBuilder {

        protected Map<String, GraphVertex> vertices = new LinkedHashMap<>();

        /**
         * Key: graph node.
         * Values: input to that node
         */
        protected Map<String, List<String>> vertexInputs = new LinkedHashMap<>();

        protected List<String> networkInputs = new ArrayList<>();
        protected List<InputType> networkInputTypes = new ArrayList<>();
        protected List<String> networkOutputs = new ArrayList<>();

        protected boolean pretrain = false;
        protected boolean backprop = true;
        protected BackpropType backpropType = BackpropType.Standard;
        protected int tbpttFwdLength = 20;
        protected int tbpttBackLength = 20;

        protected Map<String, InputPreProcessor> inputPreProcessors = new LinkedHashMap<>();

        protected NeuralNetConfiguration.Builder globalConfiguration;

        public GraphBuilder(NeuralNetConfiguration.Builder globalConfiguration) {
            this.globalConfiguration = globalConfiguration;
        }

        public GraphBuilder(ComputationGraphConfiguration newConf, NeuralNetConfiguration.Builder globalConfiguration) {
            ComputationGraphConfiguration clonedConf = newConf.clone();

            this.vertices = clonedConf.getVertices();
            this.vertexInputs = clonedConf.getVertexInputs();

            this.networkInputs = clonedConf.getNetworkInputs();
            this.networkOutputs = clonedConf.getNetworkOutputs();

            this.pretrain = clonedConf.isPretrain();
            this.backprop = clonedConf.isBackprop();
            this.backpropType = clonedConf.getBackpropType();
            this.tbpttFwdLength = clonedConf.getTbpttFwdLength();
            this.tbpttBackLength = clonedConf.getTbpttBackLength();
            this.globalConfiguration = globalConfiguration;
            //this.getGlobalConfiguration().setSeed(clonedConf.getDefaultConfiguration().getSeed());
        }

        /**
         * Specify the processors for a given layer.
         * These are used at each layer for doing things like normalization and shaping of input.<br>
         * <b>Note</b>: preprocessors can also be defined using the {@link #addLayer(String, Layer, InputPreProcessor, String...)} method.
         *
         * @param layer     the name of the layer that this preprocessor will be used with
         * @param processor the preprocessor to use for the specified layer
         */
        public GraphBuilder inputPreProcessor(String layer, InputPreProcessor processor) {
            inputPreProcessors.put(layer, processor);
            return this;
        }

        /**
         * Whether to do backprop (standard supervised learning) or not
         *
         * @param backprop whether to do backprop or not
         */
        public GraphBuilder backprop(boolean backprop) {
            this.backprop = backprop;
            return this;
        }

        /**
         * Whether to do layerwise pretraining or not
         *
         * @param pretrain whether to do pretraining or not
         */
        public GraphBuilder pretrain(boolean pretrain) {
            this.pretrain = pretrain;
            return this;
        }

        /**
         * The type of backprop. Default setting is used for most networks (MLP, CNN etc),
         * but optionally truncated BPTT can be used for training recurrent neural networks.
         * If using TruncatedBPTT make sure you set both tBPTTForwardLength() and tBPTTBackwardLength()
         *
         * @param type Type of backprop.
         *             Default: BackpropType.Standard
         */
        public GraphBuilder backpropType(BackpropType type) {
            this.backpropType = type;
            return this;
        }

        /**
         * When doing truncated BPTT: how many steps of forward pass should we do
         * before doing (truncated) backprop?<br>
         * Only applicable when doing backpropType(BackpropType.TruncatedBPTT)<br>
         * Typically the tBPTTForwardLength parameter is the same as the tBPTTBackwardLength parameter,
         * but may be larger than it in some circumstances (but never smaller)<br>
         * Ideally your training data time series length should be divisible by this.
         * This is the k1 parameter on pg23 of
         * http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf
         *
         * @param forwardLength Forward length > 0, >= backwardLength
         */
        public GraphBuilder tBPTTForwardLength(int forwardLength) {
            this.tbpttFwdLength = forwardLength;
            return this;
        }

        /**
         * When doing truncated BPTT: how many steps of backward should we do?<br>
         * Only applicable when doing backpropType(BackpropType.TruncatedBPTT)<br>
         * This is the k2 parameter on pg23 of
         * http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf
         *
         * @param backwardLength <= forwardLength
         */
        public GraphBuilder tBPTTBackwardLength(int backwardLength) {
            this.tbpttBackLength = backwardLength;
            return this;
        }

        /**
         * Add a layer, with no {@link InputPreProcessor}, with the specified name and specified inputs.
         *
         * @param layerName   Name/label of the layer to add
         * @param layer       The layer configuration
         * @param layerInputs Inputs to this layer (must be 1 or more). Inputs may be other layers, GraphVertex objects,
         *                    or a combination of the two.
         * @see #addLayer(String, Layer, InputPreProcessor, String...)
         */
        public GraphBuilder addLayer(String layerName, Layer layer, String... layerInputs) {
            return addLayer(layerName, layer, null, layerInputs);
        }

        /**
         * Add a layer and an {@link InputPreProcessor}, with the specified name and specified inputs.
         *
         * @param layerName    Name/label of the layer to add
         * @param layer        The layer configuration
         * @param preProcessor The InputPreProcessor to use with this layer.
         * @param layerInputs  Inputs to this layer (must be 1 or more). Inputs may be other layers, GraphVertex objects,
         *                     or a combination of the two.
         */
        public GraphBuilder addLayer(String layerName, Layer layer, InputPreProcessor preProcessor,
                        String... layerInputs) {
            NeuralNetConfiguration.Builder builder = globalConfiguration.clone();
            builder.layer(layer);
            vertices.put(layerName, new LayerVertex(builder.build(), preProcessor));

            //Automatically insert a MergeNode if layerInputs.length > 1
            //Layers can only have 1 input
            if (layerInputs != null && layerInputs.length > 1) {
                String mergeName = layerName + "-merge";
                addVertex(mergeName, new MergeVertex(), layerInputs);
                this.vertexInputs.put(layerName, Collections.singletonList(mergeName));
            } else if (layerInputs != null) {
                this.vertexInputs.put(layerName, Arrays.asList(layerInputs));
            }
            layer.setLayerName(layerName);
            return this;
        }

        /**
         * Intended for use with the transfer learning API. Users are discouraged from employing it directly.
         * Removes the specified vertex from the vertices list, its connections and associated preprocessor.
         * If the vertex removed is an output vertex it will also be removed from the list of outputs.
         *
         * @param vertexName Name of the vertex to remove
         */
        public GraphBuilder removeVertex(String vertexName) {
            removeVertex(vertexName, true);
            return this;
        }

        /**
         * Intended for use with the transfer learning API. Users are discouraged from employing it directly.
         * Removes the specified vertex from the vertices list. If "removeConnections" is specified as true, also removes
         * its connections and associated preprocessor and, if it is an output, removes it from the list of outputs.
         * Specifying false can leave the graph in an invalid state, with references to vertices that do not exist,
         * unless a new vertex is added back in with the same name.
         *
         * @param vertexName        Name of the vertex to remove
         * @param removeConnections Specify true to remove connections
         */
        public GraphBuilder removeVertex(String vertexName, boolean removeConnections) {
            vertices.remove(vertexName);
            vertexInputs.remove(vertexName);
            if (networkInputs.contains(vertexName)) {
                networkInputs.remove(vertexName);
            }
            if (removeConnections) {
                if (networkOutputs.contains(vertexName)) {
                    networkOutputs.remove(vertexName);
                }
                for (Map.Entry<String, List<String>> entry : this.vertexInputs.entrySet()) {
                    List<String> inputs = entry.getValue();
                    if (inputs.contains(vertexName)) {
                        inputs.remove(vertexName);
                    }
                }
                if (inputPreProcessors.containsKey(vertexName)) {
                    inputPreProcessors.remove(vertexName);
                }
            }
            return this;
        }

        /**
         * Specify the inputs to the network, and their associated labels.
         *
         * @param inputNames The names of the inputs. This also defines their order
         */
        public GraphBuilder addInputs(String... inputNames) {
            Collections.addAll(networkInputs, inputNames);
            return this;
        }

        /**
         * Specify the types of inputs to the network, so that:<br>
         * (a) preprocessors can be automatically added, and<br>
         * (b) the nIns (input size) for each layer can be automatically calculated and set<br>
         * The order here is the same order as .addInputs(). Thus, if you do .addInputs("a","b") and .setInputTypes(InputType.feedForward(),
         * InputType.convolutional(1,28,28)) then the input labelled "a" is a feed forward input, whereas the input labelled "b" is a CNN
         * input, with 28x28x1 images as input.<br>
         * <b>Note</b>: Using setInputTypes is not always necessary, but can be especially helpful for example with CNNs such that
         * the calculations on input/output sizes (width, height, depth, etc) don't need to be done manually.<br>
         * <b>Note 2</b>: If a preprocessor is manually added for a given layer, it will not be overridden by the automatic
         * addition of preprocessors.<br>
         * <b>Note 3</b>: If a layer has an nIn set manually, this will not be overridden
         */
        public GraphBuilder setInputTypes(InputType... inputTypes) {
            if (inputTypes != null && inputTypes.length > 0)
                Collections.addAll(networkInputTypes, inputTypes);
            return this;
        }

        /**
         * Set the network output labels. These should be the names of the OutputLayer instances in the network
         *
         * @param outputNames The names of the output layers. This also defines their order.
         */
        public GraphBuilder setOutputs(String... outputNames) {
            Collections.addAll(networkOutputs, outputNames);
            return this;
        }

        /**
         * Add a {@link GraphVertex} to the network configuration. A GraphVertex defines forward and backward pass methods,
         * and can contain a {@link LayerVertex}, a {@link org.deeplearning4j.nn.conf.graph.ElementWiseVertex} to do element-wise
         * addition/subtraction, a {@link MergeVertex} to combine/concatenate the activations out of multiple layers or vertices,
         * and a {@link org.deeplearning4j.nn.conf.graph.SubsetVertex} to select a subset of the activations out of another layer/GraphVertex.<br>
         * Custom GraphVertex objects (that extend the abstract {@link GraphVertex} class) may also be used.
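         * <br>
         * For example (an illustrative sketch; the vertex and input names are arbitrary labels), merging the
         * activations of two layers:
         * <pre>{@code
         * .addVertex("merge", new MergeVertex(), "layer1", "layer2")
         * }</pre>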
         *
         * @param vertexName   The name of the GraphVertex to add
         * @param vertex       The GraphVertex to add
         * @param vertexInputs The inputs/activations to this GraphVertex
         */
        public GraphBuilder addVertex(String vertexName, GraphVertex vertex, String... vertexInputs) {
            vertices.put(vertexName, vertex);
            this.vertexInputs.put(vertexName, Arrays.asList(vertexInputs));
            return this;
        }

        /**
         * Create the ComputationGraphConfiguration from the Builder pattern
         */
        public ComputationGraphConfiguration build() {

            ComputationGraphConfiguration conf = new ComputationGraphConfiguration();

            conf.backprop = backprop;
            conf.pretrain = pretrain;
            conf.backpropType = backpropType;
            conf.tbpttBackLength = tbpttBackLength;
            conf.tbpttFwdLength = tbpttFwdLength;

            conf.networkInputs = networkInputs;
            conf.networkOutputs = networkOutputs;

            conf.vertices = this.vertices;
            conf.vertexInputs = this.vertexInputs;
            conf.trainingWorkspaceMode = globalConfiguration.trainingWorkspaceMode;
            conf.inferenceWorkspaceMode = globalConfiguration.inferenceWorkspaceMode;

            conf.defaultConfiguration = globalConfiguration.build();
            conf.getDefaultConfiguration().setPretrain(pretrain);

            //Add preprocessors that were defined separately to the Layers to which they belong
            for (Map.Entry<String, InputPreProcessor> entry : inputPreProcessors.entrySet()) {
                GraphVertex gv = vertices.get(entry.getKey());
                if (gv instanceof LayerVertex) {
                    LayerVertex lv = (LayerVertex) gv;
                    lv.setPreProcessor(entry.getValue());
                } else {
                    throw new IllegalStateException("Invalid configuration: InputPreProcessor defined for GraphVertex \""
                                    + entry.getKey() + "\", but this vertex is not a LayerVertex");
                }
            }

            for (Map.Entry<String, GraphVertex> gv : vertices.entrySet()) {
                if (gv.getValue() instanceof LayerVertex) {
                    LayerVertex lv = (LayerVertex) gv.getValue();
                    Layer l = lv.getLayerConf().getLayer();
                    if (l instanceof BasePretrainNetwork)
                        lv.getLayerConf().setPretrain(pretrain);
                }
            }

            conf.validate(); //throws exception for invalid configuration

            //Automatically add preprocessors, set nIns for CNN->dense transitions, etc
            if (!networkInputTypes.isEmpty()) {
                conf.addPreProcessors(networkInputTypes.toArray(new InputType[networkInputTypes.size()]));
            }

            return conf;
        }
    }
}