package org.deeplearning4j.nn.conf.distribution.serde; import org.deeplearning4j.nn.conf.distribution.*; import org.nd4j.shade.jackson.core.JsonParseException; import org.nd4j.shade.jackson.core.JsonParser; import org.nd4j.shade.jackson.core.JsonProcessingException; import org.nd4j.shade.jackson.databind.DeserializationContext; import org.nd4j.shade.jackson.databind.JsonDeserializer; import org.nd4j.shade.jackson.databind.JsonNode; import java.io.IOException; /** * Jackson Json deserializer to handle legacy format for distributions.<br> * Now, we use 'type' field which contains class information.<br> * Previously, we used wrapper objects for type information instead (see TestDistributionDeserializer for examples) * * @author Alex Black */ public class LegacyDistributionDeserializer extends JsonDeserializer<Distribution> { @Override public Distribution deserialize(JsonParser jp, DeserializationContext deserializationContext) throws IOException, JsonProcessingException { //Manually parse old format JsonNode node = jp.getCodec().readTree(jp); if(node.has("normal")){ JsonNode n = node.get("normal"); if(!n.has("mean") || !n.has("std")){ throw new JsonParseException("Cannot deserialize Distribution: legacy format 'normal' wrapper object " + " is missing 'mean' or 'std' field", jp.getCurrentLocation()); } double m = n.get("mean").asDouble(); double s = n.get("std").asDouble(); return new NormalDistribution(m,s); } else if(node.has("gaussian")){ JsonNode n = node.get("gaussian"); if(!n.has("mean") || !n.has("std")){ throw new JsonParseException("Cannot deserialize Distribution: legacy format 'gaussian' wrapper object " + " is missing 'mean' or 'std' field", jp.getCurrentLocation()); } double m = n.get("mean").asDouble(); double s = n.get("std").asDouble(); return new GaussianDistribution(m,s); } else if(node.has("uniform")){ JsonNode n = node.get("uniform"); if(!n.has("lower") || !n.has("upper")){ throw new JsonParseException("Cannot deserialize Distribution: legacy format 'uniform' wrapper object " + " is missing 'lower' or 'upper' field", jp.getCurrentLocation()); } double l = n.get("lower").asDouble(); double u = n.get("upper").asDouble(); return new UniformDistribution(l,u); } else if(node.has("binomial")){ JsonNode n = node.get("binomial"); if(!n.has("numberOfTrials") || !n.has("probabilityOfSuccess")){ throw new JsonParseException("Cannot deserialize Distribution: legacy format 'binomial' wrapper object " + " is missing 'lower' or 'upper' field", jp.getCurrentLocation()); } int num = n.get("numberOfTrials").asInt(); double p = n.get("probabilityOfSuccess").asDouble(); return new BinomialDistribution(num, p); } else { throw new JsonParseException("Cannot deserialize Distribution: expected type field or legacy format wrapper" + " object with name being one of {normal, gaussian, uniform, binomial}" , jp.getCurrentLocation()); } } }