Update default gym env version to CartPole-v1 (#7079)

Version v1 provides a better baseline: its registration raises max_episode_steps (500 vs. 200) and reward_threshold (475.0 vs. 195.0) compared to CartPole-v0.

Changed params:

register(
    id='CartPole-v1',
    entry_point='gym.envs.classic_control:CartPoleEnv',
    max_episode_steps=500,
    reward_threshold=475.0,
)
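
For reference, a minimal sketch (assuming the standard gym registry API; the printed values follow from the registration above) to confirm what the v1 entry changes at runtime:

import gym

# Make the newly defaulted environment and inspect its registered limits.
env = gym.make("CartPole-v1")
print(env.spec.max_episode_steps)  # 500, up from 200 in CartPole-v0
print(env.spec.reward_threshold)   # 475.0, up from 195.0 in CartPole-v0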
Soham Roy 2021-04-18 15:58:04 +05:30 committed by GitHub
parent 97be843226
commit 71b4611c64
1 changed file with 4 additions and 4 deletions

reinforce_learn_Qnet.py

@@ -20,7 +20,7 @@ classic CartPole environment.
 To run the template, just run:
 `python reinforce_learn_Qnet.py`
-After ~1500 steps, you will see the total_reward hitting the max score of 200.
+After ~1500 steps, you will see the total_reward hitting the max score of 475+.
 Open up TensorBoard to see the metrics:
 `tensorboard --logdir default`
@@ -149,7 +149,7 @@ class Agent:
     """
     Base Agent class handling the interaction with the environment
-    >>> env = gym.make("CartPole-v0")
+    >>> env = gym.make("CartPole-v1")
     >>> buffer = ReplayBuffer(10)
     >>> Agent(env, buffer) # doctest: +ELLIPSIS
     <...reinforce_learn_Qnet.Agent object at ...>
@@ -229,7 +229,7 @@ class Agent:
 class DQNLightning(pl.LightningModule):
     """ Basic DQN Model
-    >>> DQNLightning(env="CartPole-v0") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
+    >>> DQNLightning(env="CartPole-v1") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
     DQNLightning(
       (net): DQN(
         (net): Sequential(...)
@@ -393,7 +393,7 @@ class DQNLightning(pl.LightningModule):
         parser = parent_parser.add_argument_group("DQNLightning")
         parser.add_argument("--batch_size", type=int, default=16, help="size of the batches")
         parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
-        parser.add_argument("--env", type=str, default="CartPole-v0", help="gym environment tag")
+        parser.add_argument("--env", type=str, default="CartPole-v1", help="gym environment tag")
         parser.add_argument("--gamma", type=float, default=0.99, help="discount factor")
         parser.add_argument("--sync_rate", type=int, default=10, help="how many frames do we update the target network")
         parser.add_argument("--replay_size", type=int, default=1000, help="capacity of the replay buffer")