package org.deeplearning4j.nn.conf.layers;
import lombok.*;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.optimize.api.IterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Collection;
import java.util.Map;
/**Embedding layer: feed-forward layer that expects a single integer per example as input (a class number, in range 0 to
* numClasses-1). This input has shape [numExamples,1] instead of [numExamples,numClasses] for the equivalent one-hot representation.
* Mathematically, EmbeddingLayer is equivalent to using a DenseLayer with a one-hot representation for the input; however,
* it can be much more efficient with a large number of classes (as a dense layer + one-hot input does a matrix multiply
* with all but one value being zero).<br>
* <b>Note</b>: can only be used as the first layer for a network.<br>
* <b>Note 2</b>: For an example whose input class index is i, the output is activationFunction(weights.getRow(i) + bias), hence
* the weight rows can be considered a vector/embedding for each class.
* @author Alex Black
*/
@Data
@NoArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
public class EmbeddingLayer extends FeedForwardLayer {

    /**
     * Creates the configuration from a populated {@link Builder}. This layer adds no settings of its own;
     * everything is handed to the {@link FeedForwardLayer} constructor.
     *
     * @param builder the builder holding the configured values
     */
    private EmbeddingLayer(Builder builder) {
        super(builder);
    }

    /**
     * Creates the runtime embedding layer
     * ({@code org.deeplearning4j.nn.layers.feedforward.embedding.EmbeddingLayer}) for this configuration.
     *
     * @param conf               network configuration passed to the runtime layer's constructor and, at the end, to
     *                           its {@code setConf}
     * @param iterationListeners listeners attached to the runtime layer
     * @param layerIndex         index assigned to the runtime layer
     * @param layerParamsView    array given to the runtime layer as its parameter view, and passed to
     *                           {@link #initializer()} to build the parameter table
     * @param initializeParams   forwarded to the initializer's {@code init} call
     * @return the configured runtime layer
     */
    @Override
    public Layer instantiate(NeuralNetConfiguration conf, Collection<IterationListener> iterationListeners,
                    int layerIndex, INDArray layerParamsView, boolean initializeParams) {
        org.deeplearning4j.nn.layers.feedforward.embedding.EmbeddingLayer layer =
                        new org.deeplearning4j.nn.layers.feedforward.embedding.EmbeddingLayer(conf);
        layer.setListeners(iterationListeners);
        layer.setIndex(layerIndex);
        layer.setParamsViewArray(layerParamsView);

        // The parameter table is built from the same view array that was just attached to the layer.
        Map<String, INDArray> params = initializer().init(conf, layerParamsView, initializeParams);
        layer.setParamTable(params);

        layer.setConf(conf);
        return layer;
    }

    /**
     * Returns the parameter initializer for this layer: the object obtained from
     * {@link DefaultParamInitializer#getInstance()}.
     *
     * @return the parameter initializer
     */
    @Override
    public ParamInitializer initializer() {
        return DefaultParamInitializer.getInstance();
    }

    /**
     * Builder for {@link EmbeddingLayer}. All options are inherited from {@link FeedForwardLayer.Builder};
     * this builder declares none of its own.
     */
    public static class Builder extends FeedForwardLayer.Builder<Builder> {

        /** Creates a builder with only the inherited settings. */
        public Builder() {
        }

        /**
         * Creates an {@link EmbeddingLayer} configuration from this builder's current settings.
         *
         * @return the new layer configuration
         */
        @Override
        public EmbeddingLayer build() {
            return new EmbeddingLayer(this);
        }
    }
}