package org.deeplearning4j.nn.conf.layers; import lombok.*; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.optimize.api.IterationListener; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Collection; import java.util.Map; /**Embedding layer: feed-forward layer that expects single integers per example as input (class numbers, in range 0 to numClass-1) * as input. This input has shape [numExamples,1] instead of [numExamples,numClasses] for the equivalent one-hot representation. * Mathematically, EmbeddingLayer is equivalent to using a DenseLayer with a one-hot representation for the input; however, * it can be much more efficient with a large number of classes (as a dense layer + one-hot input does a matrix multiply * with all but one value being zero).<br> * <b>Note</b>: can only be used as the first layer for a network<br> * <b>Note 2</b>: For a given example index i, the output is activationFunction(weights.getRow(i) + bias), hence the * weight rows can be considered a vector/embedding for each example. * @author Alex Black */ @Data @NoArgsConstructor @ToString(callSuper = true) @EqualsAndHashCode(callSuper = true) public class EmbeddingLayer extends FeedForwardLayer { private EmbeddingLayer(Builder builder) { super(builder); } @Override public Layer instantiate(NeuralNetConfiguration conf, Collection<IterationListener> iterationListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams) { org.deeplearning4j.nn.layers.feedforward.embedding.EmbeddingLayer ret = new org.deeplearning4j.nn.layers.feedforward.embedding.EmbeddingLayer(conf); ret.setListeners(iterationListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map<String, INDArray> paramTable = initializer().init(conf, layerParamsView, initializeParams); ret.setParamTable(paramTable); ret.setConf(conf); return ret; } @Override public ParamInitializer initializer() { return DefaultParamInitializer.getInstance(); } @AllArgsConstructor public static class Builder extends FeedForwardLayer.Builder<Builder> { @Override @SuppressWarnings("unchecked") public EmbeddingLayer build() { return new EmbeddingLayer(this); } } }