lightning/examples/pytorch/domain_templates/reinforce_learn_Qnet.py

408 lines
13 KiB
Python

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Deep Reinforcement Learning: Deep Q-network (DQN)
The template illustrates using Lightning for Reinforcement Learning. The example builds a basic DQN using the
classic CartPole environment.
To run the template, just run:
`python reinforce_learn_Qnet.py`
After ~1500 steps, you should see total_reward reach 475+ (CartPole-v1's "solved" threshold; the maximum episode return is 500).
Open up TensorBoard to see the metrics:
`tensorboard --logdir default`
References
----------
[1] https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-
Second-Edition/blob/master/Chapter06/02_dqn_pong.py
"""
import argparse
import random
from collections import OrderedDict, deque, namedtuple
from collections.abc import Iterator
import gym
import torch
import torch.nn as nn
import torch.optim as optim
from lightning.pytorch import LightningModule, Trainer, cli_lightning_logo, seed_everything
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset
class DQN(nn.Module):
    """Two-layer MLP mapping an observation vector to one Q-value per discrete action.

    >>> DQN(10, 5)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    DQN(
      (net): Sequential(...)
    )
    """

    def __init__(self, obs_size: int, n_actions: int, hidden_size: int = 128):
        """
        Args:
            obs_size: length of the observation vector fed to the first linear layer
            n_actions: number of discrete actions, i.e. the width of the output layer
            hidden_size: width of the single hidden layer
        """
        super().__init__()
        layers = [
            nn.Linear(obs_size, hidden_size),  # observation -> hidden features
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions),  # hidden features -> one Q-value per action
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Cast ``x`` to float32 and return the Q-values computed by the MLP."""
        inputs = x.float()
        return self.net(inputs)
# One transition recorded while the agent interacts with the environment: the observed state,
# the action taken, the reward received, whether the episode ended, and the next observation.
Experience = namedtuple("Experience", "state action reward done new_state")
class ReplayBuffer:
    """Replay Buffer for storing past experiences allowing the agent to learn from them.

    >>> ReplayBuffer(5)  # doctest: +ELLIPSIS
    <...reinforce_learn_Qnet.ReplayBuffer object at ...>
    """

    def __init__(self, capacity: int) -> None:
        """
        Args:
            capacity: size of the buffer; once full, each append evicts the oldest experience
        """
        self.buffer = deque(maxlen=capacity)

    def __len__(self) -> int:
        return len(self.buffer)

    def append(self, experience: "Experience") -> None:
        """Add experience to the buffer.

        Args:
            experience: tuple (state, action, reward, done, new_state)
        """
        self.buffer.append(experience)

    def sample(self, batch_size: int) -> tuple:
        """Draw ``batch_size`` distinct experiences uniformly at random (without replacement).

        Args:
            batch_size: number of experiences to draw; must not exceed ``len(self)``

        Returns:
            tuple of tensors ``(states, actions, rewards, dones, next_states)``, each with first
            dimension ``batch_size``. ``rewards`` is float32 and ``dones`` is bool; the other dtypes
            follow the stored data.

        Raises:
            ValueError: if ``batch_size`` exceeds the number of stored experiences (from ``random.sample``).
        """
        indices = random.sample(range(len(self.buffer)), batch_size)
        states, actions, rewards, dones, next_states = zip(*(self.buffer[idx] for idx in indices))
        return (
            self._stack(states),
            torch.tensor(actions),
            torch.tensor(rewards, dtype=torch.float32),
            torch.tensor(dones, dtype=torch.bool),
            self._stack(next_states),
        )

    @staticmethod
    def _stack(observations) -> torch.Tensor:
        """Stack per-step observations into a single batch tensor.

        Converting each observation with ``torch.as_tensor`` and stacking avoids building a tensor
        directly from a Python sequence of NumPy arrays, which PyTorch warns is extremely slow.
        """
        return torch.stack([torch.as_tensor(obs) for obs in observations])
class RLDataset(IterableDataset):
    """Iterable Dataset containing the ExperienceBuffer which will be updated with new experiences during training.

    >>> RLDataset(ReplayBuffer(5))  # doctest: +ELLIPSIS
    <...reinforce_learn_Qnet.RLDataset object at ...>
    """

    def __init__(self, buffer: "ReplayBuffer", sample_size: int = 200) -> None:
        """
        Args:
            buffer: replay buffer
            sample_size: number of experiences to sample at a time
        """
        self.buffer = buffer
        self.sample_size = sample_size

    def __iter__(self) -> Iterator:
        """Sample ``sample_size`` experiences once, then yield them one at a time.

        Each item is a ``(state, action, reward, done, new_state)`` tuple of per-experience tensors.
        Sampling happens lazily on the first ``next()``, so every pass sees the buffer's current contents.
        """
        states, actions, rewards, dones, new_states = self.buffer.sample(self.sample_size)
        yield from zip(states, actions, rewards, dones, new_states)
class Agent:
    """Base Agent class handling the interaction with the environment.

    >>> env = gym.make("CartPole-v1")
    >>> buffer = ReplayBuffer(10)
    >>> Agent(env, buffer)  # doctest: +ELLIPSIS
    <...reinforce_learn_Qnet.Agent object at ...>
    """

    def __init__(self, env: gym.Env, replay_buffer: ReplayBuffer) -> None:
        """
        Args:
            env: training environment
            replay_buffer: replay buffer storing experiences
        """
        self.env = env
        self.replay_buffer = replay_buffer
        # reset() stores the first observation in self.state; a second env reset is unnecessary.
        self.reset()

    def reset(self) -> None:
        """Resets the environment and updates the state.

        Accepts both the classic gym API, where ``reset()`` returns the observation, and the
        gym>=0.26 / gymnasium API, where it returns ``(observation, info)``.
        """
        observation = self.env.reset()
        if isinstance(observation, tuple) and len(observation) == 2 and isinstance(observation[1], dict):
            observation = observation[0]
        self.state = observation

    def get_action(self, net: nn.Module, epsilon: float, device: str) -> int:
        """Using the given network, decide what action to carry out using an epsilon-greedy policy.

        Args:
            net: DQN network
            epsilon: value to determine likelihood of taking a random action
            device: current device ("cpu", or a CUDA device index)

        Returns:
            action
        """
        if random.random() < epsilon:
            return self.env.action_space.sample()

        # Add a leading batch dimension of 1 so the network receives a batch holding this single state.
        state = torch.as_tensor(self.state).unsqueeze(0)
        if device not in ["cpu"]:
            state = state.cuda(device)
        q_values = net(state)
        _, action = torch.max(q_values, dim=1)
        return int(action.item())

    @torch.no_grad()
    def play_step(self, net: nn.Module, epsilon: float = 0.0, device: str = "cpu") -> tuple[float, bool]:
        """Carries out a single interaction step between the agent and the environment.

        The resulting transition is appended to the replay buffer, and the environment is reset
        when the episode ends.

        Args:
            net: DQN network
            epsilon: value to determine likelihood of taking a random action
            device: current device

        Returns:
            reward, done
        """
        action = self.get_action(net, epsilon, device)

        # do step in the environment
        step_result = self.env.step(action)
        if len(step_result) == 5:
            # gym>=0.26 / gymnasium: (obs, reward, terminated, truncated, info). Both ways of ending the
            # episode are folded into ``done``, matching what the classic API reported.
            new_state, reward, terminated, truncated, _ = step_result
            done = bool(terminated or truncated)
        else:
            # classic gym: (obs, reward, done, info)
            new_state, reward, done, _ = step_result

        exp = Experience(self.state, action, reward, done, new_state)
        self.replay_buffer.append(exp)

        self.state = new_state
        if done:
            self.reset()
        return reward, done
class DQNLightning(LightningModule):
    """Basic DQN Model.

    >>> DQNLightning(env="CartPole-v1")  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    DQNLightning(
      (net): DQN(
        (net): Sequential(...)
      )
      (target_net): DQN(
        (net): Sequential(...)
      )
    )
    """

    def __init__(
        self,
        env: str,
        replay_size: int = 200,
        warm_start_steps: int = 200,
        gamma: float = 0.99,
        eps_start: float = 1.0,
        eps_end: float = 0.01,
        eps_last_frame: int = 200,
        sync_rate: int = 10,
        lr: float = 1e-2,
        episode_length: int = 50,
        batch_size: int = 4,
        **kwargs,
    ) -> None:
        """
        Args:
            env: environment id passed to ``gym.make``
            replay_size: capacity of the replay buffer
            warm_start_steps: number of random-action steps used to pre-fill the replay buffer
            gamma: discount factor applied to the target network's next-state value
            eps_start: starting value of epsilon
            eps_end: value epsilon decays to and then stays at
            eps_last_frame: number of training steps over which epsilon decays linearly to ``eps_end``
            sync_rate: number of training steps between copies of ``net`` into ``target_net``
            lr: Adam learning rate
            episode_length: number of experiences the training dataset yields per pass
            batch_size: dataloader batch size
            **kwargs: forwarded to ``LightningModule.__init__``
        """
        super().__init__(**kwargs)
        self.replay_size = replay_size
        self.warm_start_steps = warm_start_steps
        self.gamma = gamma
        self.eps_start = eps_start
        self.eps_end = eps_end
        self.eps_last_frame = eps_last_frame
        self.sync_rate = sync_rate
        self.lr = lr
        self.episode_length = episode_length
        self.batch_size = batch_size

        self.env = gym.make(env)
        obs_size = self.env.observation_space.shape[0]
        n_actions = self.env.action_space.n

        self.net = DQN(obs_size, n_actions)
        self.target_net = DQN(obs_size, n_actions)
        # Start the target network as an exact copy of the online network, so the first TD targets
        # are not computed from an unrelated random initialisation.
        self.target_net.load_state_dict(self.net.state_dict())

        self.buffer = ReplayBuffer(self.replay_size)
        self.agent = Agent(self.env, self.buffer)
        self.total_reward = 0
        self.episode_reward = 0
        self.populate(self.warm_start_steps)

    def populate(self, steps: int = 1000) -> None:
        """Carries out several random steps through the environment to initially fill up the replay buffer with
        experiences.

        Args:
            steps: number of random steps to populate the buffer with
        """
        for _ in range(steps):
            self.agent.play_step(self.net, epsilon=1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Passes in a state `x` through the network and gets the `q_values` of each action as an output.

        Args:
            x: environment state

        Returns:
            q values
        """
        return self.net(x)

    def dqn_mse_loss(self, batch: tuple[torch.Tensor, ...]) -> torch.Tensor:
        """Calculates the mse loss using a mini batch from the replay buffer.

        Args:
            batch: ``(states, actions, rewards, dones, next_states)`` mini batch of replay data

        Returns:
            MSE between ``net``'s Q-value for the taken action and the one-step target
            ``reward + gamma * max_a Q_target(next_state, a)``, with the bootstrap term zeroed where ``done``.
        """
        states, actions, rewards, dones, next_states = batch

        # Q(s, a) of the online network for the action actually taken in each transition.
        state_action_values = self.net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)

        # Targets are built without gradient tracking, so no explicit ``detach`` is needed.
        with torch.no_grad():
            next_state_values = self.target_net(next_states).max(1)[0]
            next_state_values[dones] = 0.0

        expected_state_action_values = next_state_values * self.gamma + rewards
        return nn.functional.mse_loss(state_action_values, expected_state_action_values)

    def training_step(self, batch: tuple[torch.Tensor, ...], nb_batch) -> OrderedDict:
        """Carries out a single step through the environment to update the replay buffer. Then calculates loss based on
        the minibatch received.

        Args:
            batch: current mini batch of replay data
            nb_batch: batch number

        Returns:
            Training loss and log metrics
        """
        device = self.get_device(batch)

        # Linear decay from eps_start to eps_end over eps_last_frame steps, then held at eps_end.
        decay_progress = (self.global_step + 1) / self.eps_last_frame
        epsilon = max(self.eps_end, self.eps_start - (self.eps_start - self.eps_end) * decay_progress)

        # step through environment with agent
        reward, done = self.agent.play_step(self.net, epsilon, device)
        self.episode_reward += reward

        # calculates training loss
        loss = self.dqn_mse_loss(batch)

        if done:
            self.total_reward = self.episode_reward
            self.episode_reward = 0

        # Hard update: copy the online weights into the target network every ``sync_rate`` steps.
        if self.global_step % self.sync_rate == 0:
            self.target_net.load_state_dict(self.net.state_dict())

        log = {
            "total_reward": torch.tensor(self.total_reward).to(device),
            "reward": torch.tensor(reward).to(device),
            "steps": torch.tensor(self.global_step).to(device),
        }
        # Lightning (>=1.0) only records metrics passed to ``self.log``/``self.log_dict``; the legacy
        # "log"/"progress_bar" keys in the returned dict are kept for backward compatibility only.
        # Values are cast to float because Lightning reduces logged metrics as floating point.
        self.log_dict({name: value.float() for name, value in log.items()}, prog_bar=True)

        return OrderedDict({"loss": loss, "log": log, "progress_bar": log})

    def configure_optimizers(self) -> list[Optimizer]:
        """Initialize Adam optimizer."""
        optimizer = optim.Adam(self.net.parameters(), lr=self.lr)
        return [optimizer]

    def __dataloader(self) -> DataLoader:
        """Initialize the Replay Buffer dataset used for retrieving experiences."""
        dataset = RLDataset(self.buffer, self.episode_length)
        return DataLoader(dataset=dataset, batch_size=self.batch_size, sampler=None)

    def train_dataloader(self) -> DataLoader:
        """Get train loader."""
        return self.__dataloader()

    def get_device(self, batch) -> str:
        """Retrieve device currently being used by minibatch: the CUDA index when on GPU, else ``"cpu"``."""
        return batch[0].device.index if self.on_gpu else "cpu"
def main(args) -> None:
    """Build a ``DQNLightning`` from the parsed CLI namespace and train it on a single CPU device."""
    hyperparameters = vars(args)
    dqn_module = DQNLightning(**hyperparameters)
    cpu_trainer = Trainer(accelerator="cpu", devices=1, val_check_interval=100)
    cpu_trainer.fit(dqn_module)
if __name__ == "__main__":
    cli_lightning_logo()
    seed_everything(0)

    # Every option maps 1:1 onto a DQNLightning constructor argument (see main()).
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=16, help="size of the batches")
    parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
    parser.add_argument("--env", type=str, default="CartPole-v1", help="gym environment tag")
    parser.add_argument("--gamma", type=float, default=0.99, help="discount factor")
    parser.add_argument(
        "--sync_rate",
        type=int,
        default=10,
        help="number of training steps between target network updates",
    )
    parser.add_argument("--replay_size", type=int, default=1000, help="capacity of the replay buffer")
    parser.add_argument(
        "--warm_start_steps",
        type=int,
        default=1000,
        help="how many samples do we use to fill our buffer at the start of training",
    )
    parser.add_argument("--eps_last_frame", type=int, default=1000, help="what frame should epsilon stop decaying")
    parser.add_argument("--eps_start", type=float, default=1.0, help="starting value of epsilon")
    parser.add_argument("--eps_end", type=float, default=0.01, help="final value of epsilon")
    # The value is used as the dataset's sample size (experiences served per dataloader pass),
    # not as a cap on environment episode length.
    parser.add_argument(
        "--episode_length",
        type=int,
        default=200,
        help="number of experiences sampled from the replay buffer per training epoch",
    )
    args = parser.parse_args()

    main(args)