/*-
 * Copyright 2015 Skymind,Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.deeplearning4j.optimize.solvers;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.api.StepFunction;
import org.deeplearning4j.optimize.api.TerminationCondition;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.util.Collection;

/**
 * Stochastic gradient descent.
 * Uses a standard fixed step size; no line search is performed.
 *
 * @author Adam Gibson
 */
public class StochasticGradientDescent extends BaseOptimizer {

    public StochasticGradientDescent(NeuralNetConfiguration conf, StepFunction stepFunction,
                    Collection<IterationListener> iterationListeners, Model model) {
        super(conf, stepFunction, iterationListeners, model);
    }

    public StochasticGradientDescent(NeuralNetConfiguration conf, StepFunction stepFunction,
                    Collection<IterationListener> iterationListeners,
                    Collection<TerminationCondition> terminationConditions, Model model) {
        super(conf, stepFunction, iterationListeners, terminationConditions, model);
    }

    @Override
    public boolean optimize() {
        for (int i = 0; i < conf.getNumIterations(); i++) {
            // Compute the gradient and score for the model's current parameters
            Pair<Gradient, Double> pair = gradientAndScore();
            Gradient gradient = pair.getFirst();

            // Apply the fixed-size step to the parameters in the direction given by the step function
            INDArray params = model.params();
            stepFunction.step(params, gradient.gradient());

            // Note: model.params() is always a view (in-place) for MultiLayerNetwork and
            // ComputationGraph, hence no setParams call is necessary there.
            // However, for pretrain layers, params are NOT a view, so a setParams call is required.
            // setParams should be a no-op for MultiLayerNetwork and ComputationGraph.
            model.setParams(params);

            int iterationCount = BaseOptimizer.getIterationCount(model);
            // Notify listeners outside of any workspace scope, so they can safely copy/detach arrays
            try (MemoryWorkspace workspace = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                for (IterationListener listener : iterationListeners)
                    listener.iterationDone(model, iterationCount);
            }

            // Stop early if any termination condition (e.g., score/gradient thresholds) is met
            checkTerminalConditions(pair.getFirst().gradient(), oldScore, score, i);

            BaseOptimizer.incrementIterationCount(model, 1);
        }
        return true;
    }

    @Override
    public void preProcessLine() {}

    @Override
    public void postStep(INDArray gradient) {}
}
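
/*
 * Usage sketch (illustrative only, not part of the original source): a minimal example of
 * driving this optimizer directly through the constructors above. In practice the Solver
 * class usually instantiates it when OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT is
 * configured; the particular step function and listener shown here (NegativeGradientStepFunction,
 * ScoreIterationListener) are assumptions about the surrounding DL4J API of this version.
 *
 *   Model model = ...;                                            // a prepared network or layer
 *   NeuralNetConfiguration conf = model.conf();                   // reuse the model's configuration
 *   StepFunction step = new NegativeGradientStepFunction();       // assumed step function impl
 *   Collection<IterationListener> listeners =
 *           Collections.<IterationListener>singletonList(new ScoreIterationListener(10));
 *   StochasticGradientDescent sgd =
 *           new StochasticGradientDescent(conf, step, listeners, model);
 *   sgd.optimize();                                               // runs conf.getNumIterations() SGD steps
 */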