package org.deeplearning4j.nn.updater;

import lombok.Getter;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.accum.Norm2;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.conditions.Conditions;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * BaseMultiLayerUpdater - core functionality for applying updaters to MultiLayerNetwork and ComputationGraph.
 * <p>
 * This implements updater combining: that is, any layers (and variables) that:<br>
 * (a) have contiguous parameters/gradients in the view arrays, and<br>
 * (b) have identical updater configuration (including updater, LR, LR/momentum schedules, etc - different L1/L2
 * values are OK, however)<br>
 * are combined into a single {@link org.nd4j.linalg.learning.GradientUpdater} operation, instead of a set of
 * smaller operations. A smaller number of larger operations improves performance, especially on GPUs.
 *
 * @author Alex Black
 */
@Getter
public abstract class BaseMultiLayerUpdater<T extends Model> implements Updater {

    protected final T network;
    protected Map<String, Layer> layersByName;
    protected final List<UpdaterBlock> updaterBlocks;
    protected INDArray updaterStateViewArray;

    public BaseMultiLayerUpdater(T network) {
        this(network, null);
    }

    /**
     * @param network      Network to create the updater for
     * @param updaterState The updater state to use. Note: this array is used *directly* and isn't copied/cloned
     */
    public BaseMultiLayerUpdater(T network, INDArray updaterState) {
        this.network = network;
        Layer[] layers = getOrderedLayers();

        int updaterStateSize = 0;
        //Iterate through the layers, and the variables for each layer.
        //While the updater configuration is the same: combine into one op, rather than doing a lot of smaller
        // (yet identical) ops.
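        //Illustrative example (not from the original source): in a network with two DenseLayers that share an
        // identical Adam configuration, the variables 0_W, 0_b, 1_W and 1_b are contiguous in the parameter and
        // gradient view arrays, so all four end up in a single UpdaterBlock and are handled by one GradientUpdater
        // call over that whole region. If layer 1 used a different learning rate (or a different updater type),
        // a new block would instead be started at 1_W.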
        Layer lastLayer = null;
        String lastVariable = null;
        UpdaterBlock currentBlock = null;
        updaterBlocks = new ArrayList<>();

        INDArray paramsView = network.params();
        INDArray gradientView = getFlattenedGradientsView();
        int paramsViewSoFar = 0;
        int currentUpdaterOffset = 0;
        for (int i = 0; i < layers.length; i++) {
            Map<String, INDArray> layerParamTable = layers[i].paramTable();
            //Iteration order is fixed per layer, as the param table is backed by a LinkedHashSet
            List<String> variables = new ArrayList<>(layerParamTable.keySet());
            for (int j = 0; j < variables.size(); j++) {
                String var = variables.get(j);
                int paramSizeThisVariable = layerParamTable.get(var).length();
                int updaterStateSizeThisVariable = (int) layers[i].conf().getLayer().getIUpdaterByParam(var)
                                .stateSize(paramSizeThisVariable);

                INDArray gradientViewSubset = null;
                INDArray paramsViewSubset = null;
                if (paramSizeThisVariable > 0) {
                    paramsViewSubset = paramsView.get(NDArrayIndex.point(0),
                                    NDArrayIndex.interval(paramsViewSoFar, paramsViewSoFar + paramSizeThisVariable));
                    gradientViewSubset = gradientView.get(NDArrayIndex.point(0),
                                    NDArrayIndex.interval(paramsViewSoFar, paramsViewSoFar + paramSizeThisVariable));
                }

                //First: decide whether to add to the existing updater block, or create a new one
                if (currentBlock == null || !UpdaterUtils.updaterConfigurationsEquals(lastLayer, lastVariable,
                                layers[i], var)) {
                    //Create a new block
                    List<UpdaterBlock.ParamState> list = new ArrayList<>();
                    list.add(new UpdaterBlock.ParamState(layers[i], var, paramsViewSubset, gradientViewSubset));
                    currentBlock = new UpdaterBlock(paramsViewSoFar, paramsViewSoFar + paramSizeThisVariable,
                                    currentUpdaterOffset, currentUpdaterOffset + updaterStateSizeThisVariable, list);

                    updaterBlocks.add(currentBlock);
                } else {
                    //Add to the existing updater block
                    currentBlock.setParamOffsetEnd(currentBlock.getParamOffsetEnd() + paramSizeThisVariable);
                    currentBlock.setUpdaterViewOffsetEnd(
                                    currentBlock.getUpdaterViewOffsetEnd() + updaterStateSizeThisVariable);
                    currentBlock.getLayersAndVariablesInBlock().add(
                                    new UpdaterBlock.ParamState(layers[i], var, paramsViewSubset, gradientViewSubset));
                }

                lastLayer = layers[i];
                lastVariable = variables.get(j);
                updaterStateSize += updaterStateSizeThisVariable;
                paramsViewSoFar += paramSizeThisVariable;
                currentUpdaterOffset += updaterStateSizeThisVariable;
            }
        }

        //Initialize the updater state, if required
        boolean updaterRequiresInit = false;
        if (updaterState != null) {
            updaterStateViewArray = updaterState;
            updaterRequiresInit = false;
        } else if (updaterStateSize > 0) {
            //May be 0 if all SGD or NONE updaters, for example
            updaterStateViewArray = Nd4j.createUninitialized(new int[] {1, updaterStateSize}, Nd4j.order());
            updaterRequiresInit = true;
        }

        //Create and set up the updaters, for the updater blocks:
        int updaterViewSoFar = 0;
        paramsViewSoFar = 0;
        for (int i = 0; i < updaterBlocks.size(); i++) {
            UpdaterBlock ub = updaterBlocks.get(i);

            int viewStateSize = ub.getUpdaterViewOffsetEnd() - ub.getUpdaterViewOffsetStart();
            int gradSize = ub.getParamOffsetEnd() - ub.getParamOffsetStart();

            if (viewStateSize > 0) {
                INDArray updaterViewSubset = updaterStateViewArray.get(NDArrayIndex.point(0),
                                NDArrayIndex.interval(updaterViewSoFar, updaterViewSoFar + viewStateSize));
                ub.setUpdaterView(updaterViewSubset);
                ub.setUpdaterViewRequiresInitialization(updaterRequiresInit);
            }

            if (gradSize > 0) {
                INDArray gradientViewSubset = gradientView.get(NDArrayIndex.point(0),
                                NDArrayIndex.interval(paramsViewSoFar, paramsViewSoFar + gradSize));
                ub.setGradientView(gradientViewSubset);
            }

            ub.init();
            updaterViewSoFar += viewStateSize;
            paramsViewSoFar += gradSize;
        }
    }

    /**
     * @return Array of layers, in the correct order (i.e., same order as the parameter/gradient/updater flattening
     * order - input to output for MultiLayerNetwork, or topological order for ComputationGraph)
     */
    protected abstract Layer[] getOrderedLayers();

    /**
     * @return The flattened gradient view array for the model
     */
    protected abstract INDArray getFlattenedGradientsView();

    /**
     * @return True if the configuration for the model is set to minibatch (divide by minibatch size), false otherwise
     */
    protected abstract boolean isMiniBatch();

    /**
     * Set the view array. Note that this does an assign operation - the provided array is not stored internally.
     *
     * @param viewArray The new updater state
     */
    public void setStateViewArray(INDArray viewArray) {
        if (this.updaterStateViewArray.length() != viewArray.length())
            throw new IllegalStateException("Invalid input: view arrays differ in length. " + "Expected length "
                            + this.updaterStateViewArray.length() + ", got length " + viewArray.length());
        this.updaterStateViewArray.assign(viewArray);
    }

    @Override
    public void setStateViewArray(Layer layer, INDArray viewArray, boolean initialize) {
        this.setStateViewArray(viewArray);
    }

    @Override
    public INDArray getStateViewArray() {
        return updaterStateViewArray;
    }

    @Override
    public void update(Layer layer, Gradient gradient, int iteration, int batchSize) {
        update(gradient, iteration, batchSize);
    }

    /**
     * Update the gradient for the model.
     * This operates in 3 steps:
     * 1. Pre-apply: gradient clipping, etc., on a per-layer basis
     * 2. Execute the updater (Adam, Nesterov momentum, etc.) - in blocks of layers at a time
     * 3. Divide by the minibatch size
     *
     * @param gradient  Gradient to update
     * @param iteration The current iteration (i.e., number of parameter updates so far)
     * @param batchSize The current minibatch size (number of examples)
     */
    public void update(Gradient gradient, int iteration, int batchSize) {

        //Split up the gradients on a per-layer basis, for pre-apply
        Map<String, Gradient> layerGradients = new HashMap<>();

        Layer[] layers = getOrderedLayers();
        if (layers.length == 1 && isSingleLayerUpdater()) {
            layerGradients.put(layers[0].conf().getLayer().getLayerName(), gradient);
            // layerGradients.put("0", gradient);
        } else {
            for (Map.Entry<String, INDArray> gradientPair : gradient.gradientForVariable().entrySet()) {
                String key = gradientPair.getKey();
                int idx = key.lastIndexOf('_');
                if (idx == -1)
                    throw new IllegalStateException(
                                    "Invalid key: Gradient key does not have layer separator: \"" + key + "\"");
                String layerName = key.substring(0, idx);

                Gradient g = layerGradients.get(layerName);
                if (g == null) {
                    g = new DefaultGradient();
                    layerGradients.put(layerName, g);
                }

                String newKey = key.substring(idx + 1);
                g.setGradientFor(newKey, gradientPair.getValue());
            }
        }

        //Pre-apply (gradient clipping, etc.): done on a per-layer basis
        for (Map.Entry<String, Gradient> entry : layerGradients.entrySet()) {
            String layerName = entry.getKey();
            Layer layer = layersByName.get(layerName);

            preApply(layer, layerGradients.get(layerName), iteration);
        }

        //Apply the updaters in blocks. This also applies the LR and momentum schedules, L1 and L2
        for (UpdaterBlock ub : updaterBlocks) {
            if (ub.skipDueToPretrainConfig()) {
                //Some updater blocks should be skipped in certain cases -
                //for example, VAE decoder params while doing supervised backprop
                continue;
            }
            if (Nd4j.getWorkspaceManager().checkIfWorkspaceExists(ComputationGraph.workspaceFeedForward)) {
                try (MemoryWorkspace workspace = Nd4j.getWorkspaceManager()
                                .getAndActivateWorkspace(ComputationGraph.workspaceFeedForward)) {
                    ub.update(iteration);
                }
            } else {
                ub.update(iteration);
            }
        }

        //Divide by the minibatch size if necessary
        if (isMiniBatch()) {
            //OK even with pretrain layers: their gradients will get modified during the next backprop iteration
            getFlattenedGradientsView().divi(batchSize);
        }
    }

    protected boolean isSingleLayerUpdater() {
        return false;
    }

    /**
     * Pre-apply: apply gradient normalization/clipping
     *
     * @param layer     Layer to apply gradient normalization/clipping for
     * @param gradient  Gradient to update
     * @param iteration The current iteration (i.e., number of parameter updates so far)
     */
    public void preApply(Layer layer, Gradient gradient, int iteration) {

        GradientNormalization normalization = layer.conf().getLayer().getGradientNormalization();
        if (normalization == null || normalization == GradientNormalization.None || layer.conf().isPretrain())
            return; //no op

        final double threshold = layer.conf().getLayer().getGradientNormalizationThreshold();
        INDArray layerGradientView = layer.getGradientsViewArray();

        switch (normalization) {
            case RenormalizeL2PerLayer:
                if (layerGradientView != null) {
                    double l2 = layerGradientView.norm2Number().doubleValue();
                    layerGradientView.divi(l2);
                }
                break;
            case RenormalizeL2PerParamType:
                for (INDArray g : gradient.gradientForVariable().values()) {
                    double l2 = Nd4j.getExecutioner().execAndReturn(new Norm2(g)).getFinalResult().doubleValue();
                    g.divi(l2);
                }
                break;
            case ClipElementWiseAbsoluteValue:
                if (layerGradientView != null) {
                    BooleanIndexing.replaceWhere(layerGradientView, threshold, Conditions.greaterThan(threshold));
                    BooleanIndexing.replaceWhere(layerGradientView, -threshold, Conditions.lessThan(-threshold));
                }
                break;
            case ClipL2PerLayer:
                if (layerGradientView != null) {
                    double layerL2 = layerGradientView.norm2Number().doubleValue();
                    if (layerL2 > threshold) {
                        double scalingFactor = threshold / layerL2; //g is rescaled to: g * threshold / l2
                        layerGradientView.muli(scalingFactor);
                    }
                }
                break;
            case ClipL2PerParamType:
                for (INDArray g : gradient.gradientForVariable().values()) {
                    double l2 = g.norm2Number().doubleValue();
                    if (l2 > threshold) {
                        double scalingFactor = l2 / threshold; //equivalent to scaling g by threshold / l2
                        g.divi(scalingFactor);
                    }
                }
                break;
            default:
                throw new RuntimeException(
                                "Unknown (or not implemented) gradient normalization strategy: " + normalization);
        }
    }

    @Override
    public boolean equals(Object o) {
        if (this == o)
            return true;
        if (o == null || getClass() != o.getClass())
            return false;

        BaseMultiLayerUpdater<?> that = (BaseMultiLayerUpdater<?>) o;
        return updaterStateViewArray != null ? updaterStateViewArray.equals(that.updaterStateViewArray)
                        : that.updaterStateViewArray == null;
    }

    @Override
    public int hashCode() {
        int result = layersByName != null ? layersByName.hashCode() : 0;
        result = 31 * result + (updaterBlocks != null ? updaterBlocks.hashCode() : 0);
        result = 31 * result + (updaterStateViewArray != null ? updaterStateViewArray.hashCode() : 0);
        return result;
    }
}
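/*
 * Usage sketch (illustrative, not part of the original file; assumes "net" is an initialized MultiLayerNetwork
 * and "gradient" is the Gradient produced by backprop for the current minibatch):
 *
 *     Updater updater = net.getUpdater();                  //a concrete subclass of BaseMultiLayerUpdater
 *     updater.update(net, gradient, iteration, batchSize); //MultiLayerNetwork implements Layer
 *
 * The call above runs the three steps documented on update(Gradient, int, int): per-layer gradient
 * normalization/clipping via preApply(), block-wise application of the GradientUpdater ops, and (if the
 * configuration is set to minibatch) division of the flattened gradient view by the minibatch size.
 */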