package org.deeplearning4j.examples.rl4j;

import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.async.nstep.discrete.AsyncNStepQLearningDiscrete;
import org.deeplearning4j.rl4j.learning.async.nstep.discrete.AsyncNStepQLearningDiscreteDense;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscreteDense;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.mdp.toy.HardDeteministicToy;
import org.deeplearning4j.rl4j.mdp.toy.HardToyState;
import org.deeplearning4j.rl4j.mdp.toy.SimpleToy;
import org.deeplearning4j.rl4j.mdp.toy.SimpleToyState;
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.util.DataManager;

/**
 * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/11/16.
 *
 * Main example for DQN on the toy MDPs.
 */
public class Toy {

    public static QLearning.QLConfiguration TOY_QL =
            new QLearning.QLConfiguration(
                    123,    //Random seed
                    100000, //Max step by epoch
                    80000,  //Max step
                    10000,  //Max size of experience replay
                    32,     //Size of batches
                    100,    //Target update (hard)
                    0,      //Num steps of no-op warmup
                    0.05,   //Reward scaling
                    0.99,   //Gamma
                    10.0,   //TD-error clipping
                    0.1f,   //Min epsilon
                    2000,   //Num steps for epsilon-greedy annealing
                    true    //Double DQN
            );

    public static AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration TOY_ASYNC_QL =
            new AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration(
                    123,    //Random seed
                    100000, //Max step by epoch
                    80000,  //Max step
                    8,      //Number of threads
                    5,      //t_max
                    100,    //Target update (hard)
                    0,      //Num steps of no-op warmup
                    0.1,    //Reward scaling
                    0.99,   //Gamma
                    10.0,   //TD-error clipping
                    0.1f,   //Min epsilon
                    2000    //Num steps for epsilon-greedy annealing
            );

    public static DQNFactoryStdDense.Configuration TOY_NET =
            DQNFactoryStdDense.Configuration.builder()
                    .l2(0.01)
                    .learningRate(1e-2)
                    .numLayer(3)
                    .numHiddenNodes(16)
                    .build();

    public static void main(String[] args) {
        simpleToy();
        //toyAsyncNstep();
    }

    public static void simpleToy() {

        //record the training data in rl4j-data in a new folder
        DataManager manager = new DataManager();

        //define the mdp from toy (toy length)
        SimpleToy mdp = new SimpleToy(20);

        //define the training method
        Learning<SimpleToyState, Integer, DiscreteSpace, IDQN> dql =
                new QLearningDiscreteDense<SimpleToyState>(mdp, TOY_NET, TOY_QL, manager);

        //enable some logging for debug purposes on the toy mdp
        mdp.setFetchable(dql);

        //start the training
        dql.train();

        //useless on toy but good practice!
        mdp.close();
    }

    public static void hardToy() {

        //record the training data in rl4j-data in a new folder
        DataManager manager = new DataManager();

        //define the mdp (the class name is spelled "HardDeteministicToy" in RL4J)
        MDP<HardToyState, Integer, DiscreteSpace> mdp = new HardDeteministicToy();

        //define the training method
        ILearning<HardToyState, Integer, DiscreteSpace> dql =
                new QLearningDiscreteDense<HardToyState>(mdp, TOY_NET, TOY_QL, manager);

        //start the training
        dql.train();

        //useless on toy but good practice!
        mdp.close();
    }

    public static void toyAsyncNstep() {

        //record the training data in rl4j-data in a new folder
        DataManager manager = new DataManager();

        //define the mdp
        SimpleToy mdp = new SimpleToy(20);

        //define the training method
        AsyncNStepQLearningDiscreteDense<SimpleToyState> dql =
                new AsyncNStepQLearningDiscreteDense<SimpleToyState>(mdp, TOY_NET, TOY_ASYNC_QL, manager);

        //enable some logging for debug purposes on the toy mdp
        mdp.setFetchable(dql);

        //start the training
        dql.train();

        //useless on toy but good practice!
        mdp.close();
    }
}