package org.deeplearning4j.nn.updater;

import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;

/**
 * Factory for {@link org.deeplearning4j.nn.api.Updater} instances that picks the updater
 * implementation matching the concrete kind of {@link Model} it is given.
 *
 * <p>All functionality is exposed through a static method; the class cannot be instantiated.
 *
 * @author Adam Gibson
 */
public class UpdaterCreator {

    /** Not instantiable: this is a static-only factory. */
    private UpdaterCreator() {}

    /**
     * Creates a new updater for the given model, chosen by its runtime type.
     *
     * <ul>
     *   <li>{@link MultiLayerNetwork} &rarr; {@link MultiLayerUpdater}</li>
     *   <li>{@link ComputationGraph} &rarr; {@link ComputationGraphUpdater}</li>
     *   <li>anything else &rarr; {@link LayerUpdater}, after casting the model to {@link Layer}</li>
     * </ul>
     *
     * <p>A {@code null} argument matches neither network type and therefore reaches the
     * single-layer branch unchanged.
     *
     * @param layer the model to build an updater for
     * @return a newly constructed updater for {@code layer}
     * @throws ClassCastException if {@code layer} is neither a network type listed above nor a
     *     {@link Layer}
     */
    public static org.deeplearning4j.nn.api.Updater getUpdater(Model layer) {
        if (layer instanceof MultiLayerNetwork) {
            MultiLayerNetwork network = (MultiLayerNetwork) layer;
            return new MultiLayerUpdater(network);
        }
        if (layer instanceof ComputationGraph) {
            ComputationGraph graph = (ComputationGraph) layer;
            return new ComputationGraphUpdater(graph);
        }
        // Fallback: any other model is treated as a single layer. The cast throws
        // ClassCastException for a Model implementation that is not also a Layer.
        Layer single = (Layer) layer;
        return new LayerUpdater(single);
    }
}