package org.deeplearning4j.nn.conf.layers;

import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import lombok.ToString;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.util.LayerValidation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;

import java.util.Collection;
import java.util.Map;

/**
 * Output layer configuration for recurrent (time series) networks. Functionally equivalent to
 * {@link OutputLayer}, but accepts and produces 3d activations of shape
 * [miniBatchSize, size, timeSeriesLength], applying the loss function at each time step.
 */
@Data
@NoArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
public class RnnOutputLayer extends BaseOutputLayer {

    private RnnOutputLayer(Builder builder) {
        super(builder);
    }

    @Override
    public Layer instantiate(NeuralNetConfiguration conf, Collection<IterationListener> iterationListeners,
                             int layerIndex, INDArray layerParamsView, boolean initializeParams) {
        // Fail fast if nIn/nOut were never set (either explicitly or via setNIn(...))
        LayerValidation.assertNInNOutSet("RnnOutputLayer", getLayerName(), layerIndex, getNIn(), getNOut());

        // Create the runtime layer and wire up its listeners, index, parameter view and parameter table
        org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer ret =
                new org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer(conf);
        ret.setListeners(iterationListeners);
        ret.setIndex(layerIndex);
        ret.setParamsViewArray(layerParamsView);
        Map<String, INDArray> paramTable = initializer().init(conf, layerParamsView, initializeParams);
        ret.setParamTable(paramTable);
        ret.setConf(conf);
        return ret;
    }

    @Override
    public ParamInitializer initializer() {
        // Standard weight + bias parameters, as for a dense/output layer
        return DefaultParamInitializer.getInstance();
    }

    @Override
    public InputType getOutputType(int layerIndex, InputType inputType) {
        if (inputType == null || inputType.getType() != InputType.Type.RNN) {
            throw new IllegalStateException("Invalid input type for RnnOutputLayer (layer index = " + layerIndex
                    + ", layer name=\"" + getLayerName() + "\"): Expected RNN input, got " + inputType);
        }
        return InputType.recurrent(nOut);
    }

    @Override
    public void setNIn(InputType inputType, boolean override) {
        if (inputType == null || inputType.getType() != InputType.Type.RNN) {
            throw new IllegalStateException("Invalid input type for RnnOutputLayer (layer name=\"" + getLayerName()
                    + "\"): Expected RNN input, got " + inputType);
        }

        // Infer nIn from the size of the recurrent input, unless it has already been set
        if (nIn <= 0 || override) {
            InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
            this.nIn = r.getSize();
        }
    }

    @Override
    public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName());
    }

    public static class Builder extends BaseOutputLayer.Builder<Builder> {

        public Builder() {
        }

        public Builder(LossFunction lossFunction) {
            lossFunction(lossFunction);
        }

        public Builder(ILossFunction lossFunction) {
            this.lossFn = lossFunction;
        }

        @Override
        @SuppressWarnings("unchecked")
        public RnnOutputLayer build() {
            return new RnnOutputLayer(this);
        }
    }
}
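
/*
 * A minimal configuration sketch showing where this layer typically sits, assuming the DL4J
 * builder API of this era (GravesLSTM, Activation, MultiLayerConfiguration, with the relevant
 * imports). Exact builder methods vary between releases; the sizes, activation and loss
 * function below are illustrative only, not a prescribed setup:
 *
 *   MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
 *           .list()
 *           .layer(0, new GravesLSTM.Builder().nIn(50).nOut(200)
 *                   .activation(Activation.TANH).build())
 *           .layer(1, new RnnOutputLayer.Builder(LossFunction.MCXENT)
 *                   .activation(Activation.SOFTMAX).nIn(200).nOut(10).build())
 *           .build();
 *
 * With this configuration, setNIn(...) would not need to infer nIn, since it is set explicitly
 * on the builder; when nIn is omitted, it is derived from the preceding layer's recurrent output.
 */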