""" Proximal Policy Optimization (PPO) - Accelerated with Lightning Fabric Author: Federico Belotti @belerico Adapted from https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo.py Based on the paper: https://arxiv.org/abs/1707.06347 Requirements: - gymnasium[box2d]>=0.27.1 - moviepy - lightning - torchmetrics - tensorboard Run it with: lightning run model --accelerator=cpu --strategy=ddp --devices=2 train_fabric.py """ import argparse import os import time from datetime import datetime from typing import Dict import gymnasium as gym import torch import torchmetrics from rl.agent import PPOLightningAgent from rl.utils import linear_annealing, make_env, parse_args, test from torch import Tensor from torch.utils.data import BatchSampler, DistributedSampler, RandomSampler from lightning.fabric import Fabric from lightning.fabric.loggers import TensorBoardLogger def train( fabric: Fabric, agent: PPOLightningAgent, optimizer: torch.optim.Optimizer, data: Dict[str, Tensor], global_step: int, args: argparse.Namespace, ): indexes = list(range(data["obs"].shape[0])) if args.share_data: sampler = DistributedSampler( indexes, num_replicas=fabric.world_size, rank=fabric.global_rank, shuffle=True, seed=args.seed ) else: sampler = RandomSampler(indexes) sampler = BatchSampler(sampler, batch_size=args.per_rank_batch_size, drop_last=False) for epoch in range(args.update_epochs): if args.share_data: sampler.sampler.set_epoch(epoch) for batch_idxes in sampler: loss = agent.training_step({k: v[batch_idxes] for k, v in data.items()}) optimizer.zero_grad(set_to_none=True) fabric.backward(loss) fabric.clip_gradients(agent, optimizer, max_norm=args.max_grad_norm) optimizer.step() agent.on_train_epoch_end(global_step) def main(args: argparse.Namespace): run_name = f"{args.env_id}_{args.exp_name}_{args.seed}_{int(time.time())}" logger = TensorBoardLogger( root_dir=os.path.join("logs", "fabric_logs", datetime.today().strftime("%Y-%m-%d_%H-%M-%S")), name=run_name ) # Initialize Fabric fabric = Fabric(loggers=logger) rank = fabric.global_rank world_size = fabric.world_size device = fabric.device fabric.seed_everything(args.seed) torch.backends.cudnn.deterministic = args.torch_deterministic # Log hyperparameters fabric.logger.experiment.add_text( "hyperparameters", "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), ) # Environment setup envs = gym.vector.SyncVectorEnv( [ make_env( args.env_id, args.seed + rank * args.num_envs + i, rank, args.capture_video, logger.log_dir, "train" ) for i in range(args.num_envs) ] ) assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported" # Define the agent and the optimizer and setup them with Fabric agent: PPOLightningAgent = PPOLightningAgent( envs, act_fun=args.activation_function, vf_coef=args.vf_coef, ent_coef=args.ent_coef, clip_coef=args.clip_coef, clip_vloss=args.clip_vloss, ortho_init=args.ortho_init, normalize_advantages=args.normalize_advantages, ) optimizer = agent.configure_optimizers(args.learning_rate) agent, optimizer = fabric.setup(agent, optimizer) # Player metrics rew_avg = torchmetrics.MeanMetric().to(device) ep_len_avg = torchmetrics.MeanMetric().to(device) # Local data obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape, device=device) actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape, device=device) logprobs = torch.zeros((args.num_steps, args.num_envs), device=device) rewards = 

    # Global variables
    global_step = 0
    start_time = time.time()
    single_global_rollout = int(args.num_envs * args.num_steps * world_size)
    num_updates = args.total_timesteps // single_global_rollout

    # Get the first environment observation and start the optimization
    next_obs = torch.tensor(envs.reset(seed=args.seed)[0], device=device)
    next_done = torch.zeros(args.num_envs, device=device)
    for update in range(1, num_updates + 1):
        # Learning rate annealing
        if args.anneal_lr:
            linear_annealing(optimizer, update, num_updates, args.learning_rate)
        fabric.log("Info/learning_rate", optimizer.param_groups[0]["lr"], global_step)

        for step in range(0, args.num_steps):
            global_step += args.num_envs * world_size
            obs[step] = next_obs
            dones[step] = next_done

            # Sample an action given the observation received by the environment
            with torch.no_grad():
                action, logprob, _, value = agent.get_action_and_value(next_obs)
                values[step] = value.flatten()
            actions[step] = action
            logprobs[step] = logprob

            # Single environment step
            next_obs, reward, done, truncated, info = envs.step(action.cpu().numpy())
            done = torch.logical_or(torch.tensor(done), torch.tensor(truncated))
            rewards[step] = torch.tensor(reward, device=device).view(-1)
            next_obs, next_done = torch.tensor(next_obs, device=device), done.to(device)

            if "final_info" in info:
                for i, agent_final_info in enumerate(info["final_info"]):
                    if agent_final_info is not None and "episode" in agent_final_info:
                        fabric.print(
                            f"Rank-0: global_step={global_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}"
                        )
                        rew_avg(agent_final_info["episode"]["r"][0])
                        ep_len_avg(agent_final_info["episode"]["l"][0])

        # Sync the metrics
        rew_avg_reduced = rew_avg.compute()
        if not rew_avg_reduced.isnan():
            fabric.log("Rewards/rew_avg", rew_avg_reduced, global_step)
        ep_len_avg_reduced = ep_len_avg.compute()
        if not ep_len_avg_reduced.isnan():
            fabric.log("Game/ep_len_avg", ep_len_avg_reduced, global_step)
        rew_avg.reset()
        ep_len_avg.reset()

        # Estimate returns with GAE (https://arxiv.org/abs/1506.02438)
        returns, advantages = agent.estimate_returns_and_advantages(
            rewards, values, dones, next_obs, next_done, args.num_steps, args.gamma, args.gae_lambda
        )

        # Flatten the batch
        local_data = {
            "obs": obs.reshape((-1,) + envs.single_observation_space.shape),
            "logprobs": logprobs.reshape(-1),
            "actions": actions.reshape((-1,) + envs.single_action_space.shape),
            "advantages": advantages.reshape(-1),
            "returns": returns.reshape(-1),
            "values": values.reshape(-1),
        }

        if args.share_data:
            # Gather all the tensors from every rank and reshape them
            gathered_data = fabric.all_gather(local_data)
            for k, v in gathered_data.items():
                if k == "obs":
                    gathered_data[k] = v.reshape((-1,) + envs.single_observation_space.shape)
                elif k == "actions":
                    gathered_data[k] = v.reshape((-1,) + envs.single_action_space.shape)
                else:
                    gathered_data[k] = v.reshape(-1)
        else:
            gathered_data = local_data

        # Train the agent
        train(fabric, agent, optimizer, gathered_data, global_step, args)
        fabric.log("Time/step_per_second", int(global_step / (time.time() - start_time)), global_step)

    envs.close()
    if fabric.is_global_zero:
        test(agent.module, device, fabric.logger.experiment, args)


if __name__ == "__main__":
    args = parse_args()
    main(args)
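
# Note: `linear_annealing` is imported from `rl.utils`. Following the CleanRL script this
# example is adapted from, it is assumed to decay the learning rate linearly to zero over
# the remaining updates, roughly:
#   frac = 1.0 - (update - 1.0) / num_updates
#   param_group["lr"] = frac * initial_learning_rate
# See rl/utils.py for the actual implementation.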