package org.deeplearning4j.examples.rl4j; import org.deeplearning4j.rl4j.space.Box; import org.deeplearning4j.rl4j.learning.async.a3c.discrete.A3CDiscrete; import org.deeplearning4j.rl4j.learning.async.a3c.discrete.A3CDiscreteDense; import org.deeplearning4j.rl4j.mdp.gym.GymEnv; import org.deeplearning4j.rl4j.network.ac.ActorCriticFactorySeparate; import org.deeplearning4j.rl4j.network.ac.ActorCriticFactorySeparateStdDense; import org.deeplearning4j.rl4j.util.DataManager; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/18/16. * * main example for A3C on cartpole * */ public class A3CCartpole { private static A3CDiscrete.A3CConfiguration CARTPOLE_A3C = new A3CDiscrete.A3CConfiguration( 123, //Random seed 200, //Max step By epoch 500000, //Max step 16, //Number of threads 5, //t_max 10, //num step noop warmup 0.01, //reward scaling 0.99, //gamma 10.0 //td-error clipping ); private static final ActorCriticFactorySeparateStdDense.Configuration CARTPOLE_NET_A3C = ActorCriticFactorySeparateStdDense.Configuration .builder().learningRate(1e-2).l2(0).numHiddenNodes(16).numLayer(3).build(); public static void main( String[] args ) { A3CcartPole(); } public static void A3CcartPole() { //record the training data in rl4j-data in a new folder DataManager manager = new DataManager(true); //define the mdp from gym (name, render) GymEnv mdp = null; try { mdp = new GymEnv("CartPole-v0", false, false); } catch (RuntimeException e){ System.out.print("To run this example, download and start the gym-http-api repo found at https://github.com/openai/gym-http-api."); } //define the training A3CDiscreteDense<Box> dql = new A3CDiscreteDense<Box>(mdp, CARTPOLE_NET_A3C, CARTPOLE_A3C, manager); //start the training dql.train(); //close the mdp (http connection) mdp.close(); } }