Update default gym env version to CartPole-v1 (#7079)
Version v1 provides a better baseline, attaining a higher max_episode_steps and reward_threshold. Changed registration parameters --> register( id='CartPole-v1', entry_point='gym.envs.classic_control:CartPoleEnv', max_episode_steps=500, reward_threshold=475.0, )
This commit is contained in:
parent
97be843226
commit
71b4611c64
|
@ -20,7 +20,7 @@ classic CartPole environment.
|
|||
To run the template, just run:
|
||||
`python reinforce_learn_Qnet.py`
|
||||
|
||||
After ~1500 steps, you will see the total_reward hitting the max score of 200.
|
||||
After ~1500 steps, you will see the total_reward hitting the max score of 475+.
|
||||
Open up TensorBoard to see the metrics:
|
||||
|
||||
`tensorboard --logdir default`
|
||||
|
@ -149,7 +149,7 @@ class Agent:
|
|||
"""
|
||||
Base Agent class handling the interaction with the environment
|
||||
|
||||
>>> env = gym.make("CartPole-v0")
|
||||
>>> env = gym.make("CartPole-v1")
|
||||
>>> buffer = ReplayBuffer(10)
|
||||
>>> Agent(env, buffer) # doctest: +ELLIPSIS
|
||||
<...reinforce_learn_Qnet.Agent object at ...>
|
||||
|
@ -229,7 +229,7 @@ class Agent:
|
|||
class DQNLightning(pl.LightningModule):
|
||||
""" Basic DQN Model
|
||||
|
||||
>>> DQNLightning(env="CartPole-v0") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
|
||||
>>> DQNLightning(env="CartPole-v1") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
|
||||
DQNLightning(
|
||||
(net): DQN(
|
||||
(net): Sequential(...)
|
||||
|
@ -393,7 +393,7 @@ class DQNLightning(pl.LightningModule):
|
|||
parser = parent_parser.add_argument_group("DQNLightning")
|
||||
parser.add_argument("--batch_size", type=int, default=16, help="size of the batches")
|
||||
parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
|
||||
parser.add_argument("--env", type=str, default="CartPole-v0", help="gym environment tag")
|
||||
parser.add_argument("--env", type=str, default="CartPole-v1", help="gym environment tag")
|
||||
parser.add_argument("--gamma", type=float, default=0.99, help="discount factor")
|
||||
parser.add_argument("--sync_rate", type=int, default=10, help="how many frames do we update the target network")
|
||||
parser.add_argument("--replay_size", type=int, default=1000, help="capacity of the replay buffer")
|
||||
|
|
Loading…
Reference in New Issue