From 71b4611c64059d7589e4d80115209fd2c89e8bdb Mon Sep 17 00:00:00 2001
From: Soham Roy
Date: Sun, 18 Apr 2021 15:58:04 +0530
Subject: [PATCH] Update default gym env version to CartPole-v1 (#7079)

CartPole-v1 provides a better baseline than v0: episodes run for up to
max_episode_steps=500 (instead of 200), and the reward_threshold rises
from 195.0 to 475.0.

Changed parameters in the v1 registration:

register(
    id='CartPole-v1',
    entry_point='gym.envs.classic_control:CartPoleEnv',
    max_episode_steps=500,
    reward_threshold=475.0,
)
---
 pl_examples/domain_templates/reinforce_learn_Qnet.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/pl_examples/domain_templates/reinforce_learn_Qnet.py b/pl_examples/domain_templates/reinforce_learn_Qnet.py
index 4d90faeb45..70726a7488 100644
--- a/pl_examples/domain_templates/reinforce_learn_Qnet.py
+++ b/pl_examples/domain_templates/reinforce_learn_Qnet.py
@@ -20,7 +20,7 @@ classic CartPole environment.
 To run the template, just run:
 `python reinforce_learn_Qnet.py`
 
-After ~1500 steps, you will see the total_reward hitting the max score of 200.
+After ~1500 steps, you will see the total_reward hitting the max score of 475+.
 
 Open up TensorBoard to see the metrics:
 `tensorboard --logdir default`
@@ -149,7 +149,7 @@ class Agent:
     """
     Base Agent class handling the interaction with the environment
 
-    >>> env = gym.make("CartPole-v0")
+    >>> env = gym.make("CartPole-v1")
     >>> buffer = ReplayBuffer(10)
     >>> Agent(env, buffer)  # doctest: +ELLIPSIS
     <...reinforce_learn_Qnet.Agent object at ...>
@@ -229,7 +229,7 @@ class Agent:
 class DQNLightning(pl.LightningModule):
     """ Basic DQN Model
 
-    >>> DQNLightning(env="CartPole-v0")  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
+    >>> DQNLightning(env="CartPole-v1")  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
     DQNLightning(
       (net): DQN(
         (net): Sequential(...)
@@ -393,7 +393,7 @@ class DQNLightning(pl.LightningModule):
         parser = parent_parser.add_argument_group("DQNLightning")
         parser.add_argument("--batch_size", type=int, default=16, help="size of the batches")
         parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
-        parser.add_argument("--env", type=str, default="CartPole-v0", help="gym environment tag")
+        parser.add_argument("--env", type=str, default="CartPole-v1", help="gym environment tag")
         parser.add_argument("--gamma", type=float, default=0.99, help="discount factor")
         parser.add_argument("--sync_rate", type=int, default=10, help="how many frames do we update the target network")
         parser.add_argument("--replay_size", type=int, default=1000, help="capacity of the replay buffer")
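
As a quick check of the new defaults, the registered spec values can be read
back from the environment at runtime. A minimal sketch, assuming the classic
`gym` API that this example already targets:

import gym

# Build the v1 environment and inspect its registered spec values.
env = gym.make("CartPole-v1")
print(env.spec.max_episode_steps)  # expected: 500
print(env.spec.reward_threshold)   # expected: 475.0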