package org.deeplearning4j.nn.conf.serde;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.nd4j.linalg.learning.config.*;
import org.nd4j.shade.jackson.core.JsonParser;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.DeserializationContext;
import org.nd4j.shade.jackson.databind.JsonDeserializer;
import org.nd4j.shade.jackson.databind.JsonMappingException;
import org.nd4j.shade.jackson.databind.deser.ResolvableDeserializer;
import org.nd4j.shade.jackson.databind.deser.std.StdDeserializer;
import java.io.IOException;
import java.util.Map;
/**
* A custom (abstract) deserializer that handles backward compatibility (currently only for updater refactoring that
* happened after 0.8.0). This is used for both MultiLayerConfiguration and ComputationGraphConfiguration.<br>
* We deserialize the config using the default deserializer, then handle the new IUpdater (which will be null for
* 0.8.0 and earlier configs) if necessary
*
* Overall design: http://stackoverflow.com/questions/18313323/how-do-i-call-the-default-deserializer-from-a-custom-deserializer-in-jackson
*
* @author Alex Black
*/
public abstract class BaseNetConfigDeserializer<T> extends StdDeserializer<T> implements ResolvableDeserializer {

    /** The deserializer Jackson would have used by default; subclasses delegate to it. */
    protected final JsonDeserializer<?> defaultDeserializer;

    /**
     * @param defaultDeserializer the default deserializer to wrap
     * @param deserializedType    the class of the configuration type handled by this deserializer
     */
    public BaseNetConfigDeserializer(JsonDeserializer<?> defaultDeserializer, Class<T> deserializedType) {
        super(deserializedType);
        this.defaultDeserializer = defaultDeserializer;
    }

    /**
     * Deserialize the configuration. Implementations are expected to apply any backward compatibility
     * handling (see {@link #handleUpdaterBackwardCompatibility(Layer[])}) as required.
     */
    @Override
    public abstract T deserialize(JsonParser jp, DeserializationContext ctxt) throws IOException, JsonProcessingException;

    /**
     * Updater configuration changed after the 0.8.0 release. Previously: enumerations and a bunch of
     * fields on the layer. Now: classes (IUpdater instances).<br>
     * For every layer whose IUpdater is not set, this method creates the appropriate IUpdater instance
     * from the legacy updater enumeration and the legacy hyperparameter fields, and sets it on the layer.
     * Layers that are null, already have an IUpdater, or have no legacy updater value are left untouched.
     *
     * @param layers the layers to process; may be null or contain null entries (both are skipped)
     */
    protected void handleUpdaterBackwardCompatibility(Layer[] layers) {
        if (layers == null) {
            //Nothing to process
            return;
        }

        for (Layer l : layers) {
            if (l == null || l.getIUpdater() != null) {
                //OK - no need to manually handle IUpdater instances for this layer
                continue;
            }

            Updater u = l.getUpdater();
            if (u == null) {
                //No legacy updater value to convert. Switching on a null enum would throw a
                //NullPointerException, so skip this layer instead
                continue;
            }

            double lr = l.getLearningRate();
            double eps = l.getEpsilon();
            double rho = l.getRho();

            switch (u) {
                case SGD:
                    l.setIUpdater(new Sgd(lr));
                    break;
                case ADAM:
                    double meanDecay = l.getAdamMeanDecay();
                    double varDecay = l.getAdamVarDecay();
                    l.setIUpdater(Adam.builder().learningRate(lr)
                            .beta1(meanDecay).beta2(varDecay).epsilon(eps).build());
                    break;
                case ADADELTA:
                    l.setIUpdater(new AdaDelta(rho, eps));
                    break;
                case NESTEROVS:
                    Map<Integer, Double> momentumSchedule = l.getMomentumSchedule();
                    double momentum = l.getMomentum();
                    l.setIUpdater(new Nesterovs(lr, momentum, momentumSchedule));
                    break;
                case ADAGRAD:
                    l.setIUpdater(new AdaGrad(lr, eps));
                    break;
                case RMSPROP:
                    double rmsDecay = l.getRmsDecay();
                    l.setIUpdater(new RmsProp(lr, rmsDecay, eps));
                    break;
                case NONE:
                    l.setIUpdater(new NoOp());
                    break;
                case CUSTOM:
                    //No op - shouldn't happen
                    break;
            }
        }
    }

    /**
     * Forwards resolution to the wrapped default deserializer, if it is resolvable.
     * A delegate that does not implement {@link ResolvableDeserializer} needs no resolution and is skipped,
     * rather than failing with a ClassCastException.
     */
    @Override
    public void resolve(DeserializationContext ctxt) throws JsonMappingException {
        if (defaultDeserializer instanceof ResolvableDeserializer) {
            ((ResolvableDeserializer) defaultDeserializer).resolve(ctxt);
        }
    }
}