package org.deeplearning4j.examples.rl4j;

import org.deeplearning4j.rl4j.learning.async.nstep.discrete.AsyncNStepQLearningDiscrete;
import org.deeplearning4j.rl4j.learning.async.nstep.discrete.AsyncNStepQLearningDiscreteDense;
import org.deeplearning4j.rl4j.mdp.gym.GymEnv;
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
import org.deeplearning4j.rl4j.space.Box;
import org.deeplearning4j.rl4j.util.DataManager;

import java.io.IOException;

/**
 * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/18/16.
 *
 * Main example for Async N-step Q-Learning on the CartPole environment.
 */
public class AsyncNStepCartpole {
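
    //Learning hyperparameters: 16 asynchronous learner threads collect 5-step
    //(t_max) rollouts and apply n-step Q-Learning updates, with a hard
    //target-network sync every 100 steps and epsilon annealed down to its
    //0.1 minimum over the first 9000 steps.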
    public static AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration CARTPOLE_NSTEP =
            new AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration(
                    123,        //Random seed
                    200,        //Max steps per epoch
                    300000,     //Max steps (total)
                    16,         //Number of threads
                    5,          //t_max: rollout length between updates
                    100,        //Target network update frequency (hard update)
                    10,         //Number of no-op warmup steps
                    0.01,       //Reward scaling
                    0.99,       //Gamma (discount factor)
                    100.0,      //TD-error clipping
                    0.1f,       //Minimum epsilon
                    9000        //Number of steps over which epsilon is annealed
            );
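
    //Network configuration: a small multilayer perceptron (3 dense layers of
    //16 hidden units each) is sufficient for CartPole's low-dimensional
    //observation vector.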
    public static DQNFactoryStdDense.Configuration CARTPOLE_NET_NSTEP =
            DQNFactoryStdDense.Configuration.builder()
                    .l2(0.001)
                    .learningRate(0.0005)
                    .numHiddenNodes(16)
                    .numLayer(3)
                    .build();

    public static void main(String[] args) throws IOException {
        cartPole();
    }
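
    //Sets up the CartPole MDP via gym-http-api, trains the async n-step
    //Q-Learning agent, and closes the server connection when done.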
    public static void cartPole() throws IOException {
        //record the training data in rl4j-data in a new folder
        DataManager manager = new DataManager(true);

        //define the mdp from gym (name, render, monitor)
        GymEnv mdp;
        try {
            mdp = new GymEnv("CartPole-v0", false, false);
        } catch (RuntimeException e) {
            System.out.println("To run this example, download and start the gym-http-api server found at https://github.com/openai/gym-http-api.");
            return; //without a running gym-http-api server there is no environment to train on
        }

        //define the training: the Dense variant builds the DQN from the factory configuration above
        AsyncNStepQLearningDiscreteDense<Box> dql =
                new AsyncNStepQLearningDiscreteDense<>(mdp, CARTPOLE_NET_NSTEP, CARTPOLE_NSTEP, manager);
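
        //train: blocks until the 300,000-step maximum from the configuration is
        //reached; models and training statistics are saved via the DataManager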
        dql.train();

        //close the mdp (close the connection to the gym-http-api server)
        mdp.close();
    }
}
}