package org.deeplearning4j.nn.params;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.Distributions;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.indexing.NDArrayIndex;
import java.util.LinkedHashMap;
import java.util.Map;
/**
* Parameter initializer for the Variational Autoencoder model.
*
* See: Kingma & Welling, 2013: Auto-Encoding Variational Bayes - https://arxiv.org/abs/1312.6114
*
* @author Alex Black
*/
public class VariationalAutoencoderParamInitializer extends DefaultParamInitializer {

    private static final VariationalAutoencoderParamInitializer INSTANCE = new VariationalAutoencoderParamInitializer();

    /** @return the shared initializer instance */
    public static VariationalAutoencoderParamInitializer getInstance() {
        return INSTANCE;
    }

    public static final String WEIGHT_KEY_SUFFIX = "W";
    public static final String BIAS_KEY_SUFFIX = "b";
    public static final String PZX_PREFIX = "pZX";
    public static final String PZX_MEAN_PREFIX = PZX_PREFIX + "Mean";
    public static final String PZX_LOGSTD2_PREFIX = PZX_PREFIX + "LogStd2";

    /** Prefix for encoder layer keys: layer i uses {@code "e" + i + "W"} / {@code "e" + i + "b"} */
    public static final String ENCODER_PREFIX = "e";
    /** Prefix for decoder layer keys: layer i uses {@code "d" + i + "W"} / {@code "d" + i + "b"} */
    public static final String DECODER_PREFIX = "d";

    /** Key for weight parameters connecting the last encoder layer and the mean values for p(z|data) */
    public static final String PZX_MEAN_W = PZX_MEAN_PREFIX + WEIGHT_KEY_SUFFIX;
    /** Key for bias parameters for the mean values for p(z|data) */
    public static final String PZX_MEAN_B = PZX_MEAN_PREFIX + BIAS_KEY_SUFFIX;
    /** Key for weight parameters connecting the last encoder layer and the log(sigma^2) values for p(z|data) */
    public static final String PZX_LOGSTD2_W = PZX_LOGSTD2_PREFIX + WEIGHT_KEY_SUFFIX;
    /** Key for bias parameters for log(sigma^2) in p(z|data) */
    public static final String PZX_LOGSTD2_B = PZX_LOGSTD2_PREFIX + BIAS_KEY_SUFFIX;

    public static final String PXZ_PREFIX = "pXZ";
    /** Key for weight parameters connecting the last decoder layer and p(data|z) (according to whatever
     * {@link org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution} is set for the VAE) */
    public static final String PXZ_W = PXZ_PREFIX + WEIGHT_KEY_SUFFIX;
    /** Key for bias parameters connecting the last decoder layer and p(data|z) (according to whatever
     * {@link org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution} is set for the VAE) */
    public static final String PXZ_B = PXZ_PREFIX + BIAS_KEY_SUFFIX;

    /**
     * Counts the parameters of the VAE layer held by {@code conf}: every encoder layer, the mean and
     * log(sigma^2) projections for p(z|x), every decoder layer, and the projection onto the input of the
     * reconstruction distribution p(x|z). Each of these contributes {@code (nIn + 1) * nOut} values
     * (weights plus one bias per output).
     *
     * <p>The encoder and decoder layer size arrays must each contain at least one entry; an empty array
     * causes an {@link ArrayIndexOutOfBoundsException}.
     *
     * @param conf configuration whose layer is a {@link VariationalAutoencoder}
     * @return total number of parameters
     */
    @Override
    public int numParams(NeuralNetConfiguration conf) {
        VariationalAutoencoder layer = (VariationalAutoencoder) conf.getLayer();
        int nIn = layer.getNIn();
        int nOut = layer.getNOut();
        int[] encoderLayerSizes = layer.getEncoderLayerSizes();
        int[] decoderLayerSizes = layer.getDecoderLayerSizes();

        int paramCount = stackParamCount(encoderLayerSizes, nIn);

        //Between the last encoder layer and the parameters for p(z|x):
        //mean and log(sigma^2) projections have identical shapes, hence the factor of 2
        int lastEncLayerSize = encoderLayerSizes[encoderLayerSizes.length - 1];
        paramCount += 2 * denseParamCount(lastEncLayerSize, nOut);

        //Decoder: first layer takes the nOut latent values as input
        paramCount += stackParamCount(decoderLayerSizes, nOut);

        //Between last decoder layer and parameters for p(x|z):
        int nDistributionParams = layer.getOutputDistribution().distributionInputSize(nIn);
        int lastDecLayerSize = decoderLayerSizes[decoderLayerSizes.length - 1];
        paramCount += denseParamCount(lastDecLayerSize, nDistributionParams);

        return paramCount;
    }

    /**
     * Splits {@code paramsView} into per-parameter views, registers each key on {@code conf} via
     * {@code addVariable}, and returns the views keyed by name in flattened order:
     * encoder layers ({@code e0W, e0b, ...}), {@link #PZX_MEAN_W}/{@link #PZX_MEAN_B},
     * {@link #PZX_LOGSTD2_W}/{@link #PZX_LOGSTD2_B}, decoder layers ({@code d0W, d0b, ...}), and finally
     * {@link #PXZ_W}/{@link #PXZ_B}.
     *
     * @param conf             configuration whose layer is a {@link VariationalAutoencoder}
     * @param paramsView       flattened parameter array; its length must equal {@link #numParams}
     * @param initializeParams forwarded to {@code createWeightMatrix} / {@code createBias}
     * @return insertion-ordered map from parameter key to its view
     * @throws IllegalArgumentException if {@code paramsView} does not have the expected length
     */
    @Override
    public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) {
        int expectedLength = numParams(conf);
        if (paramsView.length() != expectedLength) {
            throw new IllegalArgumentException("Incorrect paramsView length: Expected length " + expectedLength
                            + ", got length " + paramsView.length());
        }

        Map<String, INDArray> ret = new LinkedHashMap<>();
        VariationalAutoencoder layer = (VariationalAutoencoder) conf.getLayer();
        int nIn = layer.getNIn();
        int nOut = layer.getNOut();
        int[] encoderLayerSizes = layer.getEncoderLayerSizes();
        int[] decoderLayerSizes = layer.getDecoderLayerSizes();
        WeightInit weightInit = layer.getWeightInit();
        Distribution dist = Distributions.createDistribution(layer.getDist());

        int soFar = 0;
        for (int i = 0; i < encoderLayerSizes.length; i++) {
            soFar = initDense(ENCODER_PREFIX + i, layerInputSize(encoderLayerSizes, i, nIn), encoderLayerSizes[i],
                            soFar, conf, paramsView, initializeParams, weightInit, dist, ret);
        }

        //Last encoder layer -> p(z|x): mean, then log(sigma^2)
        int lastEncLayerSize = encoderLayerSizes[encoderLayerSizes.length - 1];
        soFar = initDense(PZX_MEAN_PREFIX, lastEncLayerSize, nOut, soFar, conf, paramsView, initializeParams,
                        weightInit, dist, ret);
        soFar = initDense(PZX_LOGSTD2_PREFIX, lastEncLayerSize, nOut, soFar, conf, paramsView, initializeParams,
                        weightInit, dist, ret);

        for (int i = 0; i < decoderLayerSizes.length; i++) {
            soFar = initDense(DECODER_PREFIX + i, layerInputSize(decoderLayerSizes, i, nOut), decoderLayerSizes[i],
                            soFar, conf, paramsView, initializeParams, weightInit, dist, ret);
        }

        //Finally, p(x|z):
        int nDistributionParams = layer.getOutputDistribution().distributionInputSize(nIn);
        int lastDecLayerSize = decoderLayerSizes[decoderLayerSizes.length - 1];
        initDense(PXZ_PREFIX, lastDecLayerSize, nDistributionParams, soFar, conf, paramsView, initializeParams,
                        weightInit, dist, ret);

        return ret;
    }

    /**
     * Splits a flattened gradient array into per-parameter views, using the same keys and the same
     * ordering as {@link #init}. Weight gradients are reshaped to {@code [nIn, nOut]} in 'f' order;
     * bias gradients are returned as sliced. Nothing is written to {@code gradientView} and no
     * variables are registered on {@code conf}.
     *
     * @param conf         configuration whose layer is a {@link VariationalAutoencoder}
     * @param gradientView flattened gradient array, laid out like the parameter array
     * @return insertion-ordered map from parameter key to its gradient view
     */
    @Override
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
        Map<String, INDArray> ret = new LinkedHashMap<>();
        VariationalAutoencoder layer = (VariationalAutoencoder) conf.getLayer();
        int nIn = layer.getNIn();
        int nOut = layer.getNOut();
        int[] encoderLayerSizes = layer.getEncoderLayerSizes();
        int[] decoderLayerSizes = layer.getDecoderLayerSizes();

        int soFar = 0;
        for (int i = 0; i < encoderLayerSizes.length; i++) {
            soFar = putDenseGradients(ENCODER_PREFIX + i, layerInputSize(encoderLayerSizes, i, nIn),
                            encoderLayerSizes[i], soFar, gradientView, ret);
        }

        //Last encoder layer -> p(z|x): mean, then log(sigma^2)
        int lastEncLayerSize = encoderLayerSizes[encoderLayerSizes.length - 1];
        soFar = putDenseGradients(PZX_MEAN_PREFIX, lastEncLayerSize, nOut, soFar, gradientView, ret);
        soFar = putDenseGradients(PZX_LOGSTD2_PREFIX, lastEncLayerSize, nOut, soFar, gradientView, ret);

        for (int i = 0; i < decoderLayerSizes.length; i++) {
            soFar = putDenseGradients(DECODER_PREFIX + i, layerInputSize(decoderLayerSizes, i, nOut),
                            decoderLayerSizes[i], soFar, gradientView, ret);
        }

        //Finally, p(x|z):
        int nDistributionParams = layer.getOutputDistribution().distributionInputSize(nIn);
        int lastDecLayerSize = decoderLayerSizes[decoderLayerSizes.length - 1];
        putDenseGradients(PXZ_PREFIX, lastDecLayerSize, nDistributionParams, soFar, gradientView, ret);

        return ret;
    }

    /** Number of values in one fully connected block: nIn * nOut weights plus nOut biases. */
    private static int denseParamCount(int nIn, int nOut) {
        return (nIn + 1) * nOut;
    }

    /** Total parameter count of a chain of fully connected layers whose first layer has firstLayerNIn inputs. */
    private static int stackParamCount(int[] layerSizes, int firstLayerNIn) {
        int count = 0;
        for (int i = 0; i < layerSizes.length; i++) {
            count += denseParamCount(layerInputSize(layerSizes, i, firstLayerNIn), layerSizes[i]);
        }
        return count;
    }

    /** Input size of layer i in a chain: firstLayerNIn for layer 0, otherwise the size of layer i - 1. */
    private static int layerInputSize(int[] layerSizes, int i, int firstLayerNIn) {
        return i == 0 ? firstLayerNIn : layerSizes[i - 1];
    }

    /** View of {@code length} consecutive entries of row 0 of {@code flat}, starting at {@code from}. */
    private static INDArray flatSlice(INDArray flat, int from, int length) {
        return flat.get(NDArrayIndex.point(0), NDArrayIndex.interval(from, from + length));
    }

    /**
     * Creates the weight and bias arrays of one fully connected block over {@code paramsView}
     * (weights first, then biases), stores them under {@code keyPrefix + "W"} / {@code keyPrefix + "b"}
     * and registers both keys on {@code conf}.
     *
     * @return the offset just past this block's biases
     */
    private int initDense(String keyPrefix, int nIn, int nOut, int offset, NeuralNetConfiguration conf,
                    INDArray paramsView, boolean initializeParams, WeightInit weightInit, Distribution dist,
                    Map<String, INDArray> out) {
        int weightCount = nIn * nOut;
        INDArray weightView = flatSlice(paramsView, offset, weightCount);
        INDArray biasView = flatSlice(paramsView, offset + weightCount, nOut);

        INDArray weights = createWeightMatrix(nIn, nOut, weightInit, dist, weightView, initializeParams);
        INDArray biases = createBias(nOut, 0.0, biasView, initializeParams); //TODO don't hardcode 0

        String weightKey = keyPrefix + WEIGHT_KEY_SUFFIX;
        String biasKey = keyPrefix + BIAS_KEY_SUFFIX;
        out.put(weightKey, weights);
        out.put(biasKey, biases);
        conf.addVariable(weightKey);
        conf.addVariable(biasKey);

        return offset + weightCount + nOut;
    }

    /**
     * Slices the weight and bias gradients of one fully connected block out of {@code gradientView}
     * and stores them under {@code keyPrefix + "W"} / {@code keyPrefix + "b"}.
     *
     * @return the offset just past this block's bias gradients
     */
    private static int putDenseGradients(String keyPrefix, int nIn, int nOut, int offset, INDArray gradientView,
                    Map<String, INDArray> out) {
        int weightCount = nIn * nOut;
        INDArray weightGradView = flatSlice(gradientView, offset, weightCount);
        INDArray biasGradView = flatSlice(gradientView, offset + weightCount, nOut);

        //Gradients are only re-viewed, never initialized: a plain 'f'-order reshape is used for every
        //block rather than going through createWeightMatrix with null initialization arguments.
        //NOTE(review): assumes createWeightMatrix(..., false) is exactly this reshape - confirm against superclass
        out.put(keyPrefix + WEIGHT_KEY_SUFFIX, weightGradView.reshape('f', nIn, nOut));
        out.put(keyPrefix + BIAS_KEY_SUFFIX, biasGradView); //Already correct shape (row vector)

        return offset + weightCount + nOut;
    }
}