"""
|
|
Proximal Policy Optimization (PPO) - Accelerated with Lightning Fabric
|
|
|
|
Author: Federico Belotti @belerico
|
|
Adapted from https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo.py
|
|
Based on the paper: https://arxiv.org/abs/1707.06347
|
|
|
|
Requirements:
|
|
- gymnasium[box2d]>=0.27.1
|
|
- moviepy
|
|
- lightning
|
|
- torchmetrics
|
|
- tensorboard
|
|
|
|
|
|
Run it with:
|
|
lightning run model --devices=2 train_fabric_decoupled.py
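
With 2 or more devices, rank-0 acts as the player: it interacts with the vectorized
environments and collects experiences, while the remaining ranks act as trainers that
optimize the agent and send the updated parameters back to the player.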
"""

import argparse
import os
import time
from contextlib import nullcontext
from datetime import datetime

import gymnasium as gym
import torch
from rl.agent import PPOLightningAgent
from rl.utils import linear_annealing, make_env, parse_args, test
from torch.distributed.algorithms.join import Join
from torch.utils.data import BatchSampler, DistributedSampler, RandomSampler
from torchmetrics import MeanMetric

from lightning.fabric import Fabric
from lightning.fabric.loggers import TensorBoardLogger
from lightning.fabric.plugins.collectives import TorchCollective
from lightning.fabric.plugins.collectives.collective import CollectibleGroup
from lightning.fabric.strategies import DDPStrategy


@torch.no_grad()
def player(args, world_collective: TorchCollective, player_trainer_collective: TorchCollective):
    run_name = f"{args.env_id}_{args.exp_name}_{args.seed}"
    logger = TensorBoardLogger(
        root_dir=os.path.join("logs", "fabric_decoupled_logs", datetime.today().strftime("%Y-%m-%d_%H-%M-%S")),
        name=run_name,
    )
    log_dir = logger.log_dir

    # Initialize Fabric object
    fabric = Fabric(loggers=logger, accelerator="cuda" if args.player_on_gpu else "cpu")
    device = fabric.device
    fabric.seed_everything(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    # Log hyperparameters
    logger.experiment.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # Environment setup
    envs = gym.vector.SyncVectorEnv(
        [make_env(args.env_id, args.seed + i, 0, args.capture_video, log_dir, "train") for i in range(args.num_envs)]
    )
    assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

    # Define the agent
    agent: PPOLightningAgent = PPOLightningAgent(
        envs,
        act_fun=args.activation_function,
        vf_coef=args.vf_coef,
        ent_coef=args.ent_coef,
        clip_coef=args.clip_coef,
        clip_vloss=args.clip_vloss,
        ortho_init=args.ortho_init,
        normalize_advantages=args.normalize_advantages,
    ).to(device)
    flattened_parameters = torch.empty_like(
        torch.nn.utils.convert_parameters.parameters_to_vector(agent.parameters()), device=device
    )

    # Receive the first weights from rank-1, a.k.a. the first of the trainers
    # This guarantees that everyone starts from the same parameters before the first iteration
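    # The parameters travel as a single flat tensor (parameters_to_vector), so a plain tensor
    # broadcast over the player-trainer group is enough; vector_to_parameters unpacks them below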
    player_trainer_collective.broadcast(flattened_parameters, src=1)
    torch.nn.utils.convert_parameters.vector_to_parameters(flattened_parameters, agent.parameters())

    # Player metrics
    rew_avg = MeanMetric(sync_on_compute=False).to(device)
    ep_len_avg = MeanMetric(sync_on_compute=False).to(device)

    # Local data
    obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
    actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
    logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
    rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
    dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
    values = torch.zeros((args.num_steps, args.num_envs)).to(device)

    # Global variables
    global_step = 0
    start_time = time.time()
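    # Each update consumes num_envs * num_steps environment steps, so the total number of
    # policy updates is total_timesteps // (num_envs * num_steps)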
    single_global_step = int(args.num_envs * args.num_steps)
    num_updates = args.total_timesteps // single_global_step
    if not args.share_data:
        if single_global_step < world_collective.world_size - 1:
            raise RuntimeError(
                "The number of trainers ({}) is greater than the available collected data ({}). ".format(
                    world_collective.world_size - 1, single_global_step
                )
                + "Consider lowering the number of trainers to at most the size of the collected data"
            )
        chunks_sizes = [
            len(chunk)
            for chunk in torch.tensor_split(torch.arange(single_global_step), world_collective.world_size - 1)
        ]

    # Broadcast num_updates to every process in the world
    update_t = torch.tensor([num_updates], device=device, dtype=torch.float32)
    world_collective.broadcast(update_t, src=0)

    # Get the first environment observation and start the optimization
    next_obs = torch.tensor(envs.reset(seed=args.seed)[0], device=device)
    next_done = torch.zeros(args.num_envs).to(device)
    for _ in range(1, num_updates + 1):
        for step in range(0, args.num_steps):
            global_step += args.num_envs
            obs[step] = next_obs
            dones[step] = next_done

            # Sample an action given the observation received by the environment
            action, logprob, _, value = agent.get_action_and_value(next_obs)
            values[step] = value.flatten()
            actions[step] = action
            logprobs[step] = logprob

            # Single environment step
            next_obs, reward, done, truncated, info = envs.step(action.cpu().numpy())
            done = torch.logical_or(torch.tensor(done), torch.tensor(truncated))
            rewards[step] = torch.tensor(reward, device=device).view(-1)
            next_obs, next_done = torch.tensor(next_obs, device=device), done.to(device)

            if "final_info" in info:
                for i, agent_final_info in enumerate(info["final_info"]):
                    if agent_final_info is not None and "episode" in agent_final_info:
                        fabric.print(
                            f"Rank-0: global_step={global_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}"
                        )
                        rew_avg(agent_final_info["episode"]["r"][0])
                        ep_len_avg(agent_final_info["episode"]["l"][0])

        # Sync the metrics
        rew_avg_reduced = rew_avg.compute()
        if not rew_avg_reduced.isnan():
            fabric.log("Rewards/rew_avg", rew_avg_reduced, global_step)
        ep_len_avg_reduced = ep_len_avg.compute()
        if not ep_len_avg_reduced.isnan():
            fabric.log("Game/ep_len_avg", ep_len_avg_reduced, global_step)
        rew_avg.reset()
        ep_len_avg.reset()

        # Estimate returns with GAE (https://arxiv.org/abs/1506.02438)
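        # Standard GAE recursion:
        #   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_{t+1}) - V(s_t)
        #   A_t = delta_t + gamma * lambda * (1 - done_{t+1}) * A_{t+1}
        # with the value targets given by returns_t = A_t + V(s_t)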
        returns, advantages = agent.estimate_returns_and_advantages(
            rewards, values, dones, next_obs, next_done, args.num_steps, args.gamma, args.gae_lambda
        )

        # Flatten the batch
        local_data = {
            "obs": obs.reshape((-1,) + envs.single_observation_space.shape),
            "logprobs": logprobs.reshape(-1),
            "actions": actions.reshape((-1,) + envs.single_action_space.shape),
            "advantages": advantages.reshape(-1),
            "returns": returns.reshape(-1),
            "values": values.reshape(-1),
        }
        if not args.player_on_gpu and args.cuda:
            # Pin the CPU tensors so that the trainers' host-to-device copies are faster
            # (pin_memory returns a new tensor, so the dict entries must be reassigned)
            local_data = {k: v.pin_memory() for k, v in local_data.items()}

        # Send data to the training agents
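        # With share_data every trainer receives the full batch and partitions it later with a
        # DistributedSampler; otherwise the batch is shuffled, split into world_size - 1 chunks
        # and scattered so that each trainer gets its own shard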
        if args.share_data:
            world_collective.broadcast_object_list([local_data], src=0)
        else:
            # Split the data as evenly as possible
            perm = torch.randperm(single_global_step, device=device)
            chunks = [{} for _ in range(world_collective.world_size - 1)]
            for k, v in local_data.items():
                chunked_local_data = v[perm].split(chunks_sizes)
                for i in range(len(chunks)):
                    chunks[i][k] = chunked_local_data[i]

            world_collective.scatter_object_list([None], [None] + chunks, src=0)

        # Gather metrics from the trainers to be plotted
        metrics = [None]
        player_trainer_collective.broadcast_object_list(metrics, src=1)

        # Wait for the trainers to finish, receiving the updated parameters
        player_trainer_collective.broadcast(flattened_parameters, src=1)

        # Convert the received parameters back into the agent
        torch.nn.utils.convert_parameters.vector_to_parameters(flattened_parameters, agent.parameters())

        fabric.log_dict(metrics[0], global_step)
        fabric.log_dict({"Time/step_per_second": int(global_step / (time.time() - start_time))}, global_step)
    if args.share_data:
        world_collective.broadcast_object_list([-1], src=0)
    else:
        world_collective.scatter_object_list([None], [None] + [-1] * (world_collective.world_size - 1), src=0)
    envs.close()
    test(agent, device, fabric.logger.experiment, args)


def trainer(
    args,
    world_collective: TorchCollective,
    player_trainer_collective: TorchCollective,
    optimization_pg: CollectibleGroup,
):
    global_rank = world_collective.rank
    group_rank = global_rank - 1
    group_world_size = world_collective.world_size - 1

    # Initialize Fabric
    fabric = Fabric(strategy=DDPStrategy(process_group=optimization_pg), accelerator="cuda" if args.cuda else "cpu")
    device = fabric.device
    fabric.seed_everything(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    # Environment setup
    envs = gym.vector.SyncVectorEnv([make_env(args.env_id, 0, 0, False, None)])
    assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

    # Define the agent and the optimizer and set them up with Fabric
    agent: PPOLightningAgent = PPOLightningAgent(
        envs,
        act_fun=args.activation_function,
        vf_coef=args.vf_coef,
        ent_coef=args.ent_coef,
        clip_coef=args.clip_coef,
        clip_vloss=args.clip_vloss,
        ortho_init=args.ortho_init,
        normalize_advantages=args.normalize_advantages,
        process_group=optimization_pg,
    )
    optimizer = agent.configure_optimizers(args.learning_rate)
    agent, optimizer = fabric.setup(agent, optimizer)

    # Send weights to rank-0, a.k.a. the player
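    # Only rank-1 talks to the player over the player-trainer group: DDP over optimization_pg
    # keeps all the trainers' parameters identical, so a single sender is enough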
    if global_rank == 1:
        player_trainer_collective.broadcast(
            torch.nn.utils.convert_parameters.parameters_to_vector(agent.parameters()), src=1
        )

    # Receive maximum number of updates from the player
    update = 0
    num_updates = torch.zeros(1, device=device)
    world_collective.broadcast(num_updates, src=0)
    num_updates = num_updates.item()

    # Start training
    while True:
        # Wait for data
        data = [None]
        if args.share_data:
            world_collective.broadcast_object_list(data, src=0)
        else:
            world_collective.scatter_object_list(data, [None for _ in range(world_collective.world_size)], src=0)
        data = data[0]
        if data == -1:
            return

        # Metrics dict to be sent to the player
        if group_rank == 0:
            metrics = {}

        # Learning rate annealing
        if args.anneal_lr:
            linear_annealing(optimizer, update, num_updates, args.learning_rate)
        if group_rank == 0:
            metrics["Info/learning_rate"] = optimizer.param_groups[0]["lr"]
        update += 1

        indexes = list(range(data["obs"].shape[0]))
        if args.share_data:
            sampler = DistributedSampler(
                indexes, num_replicas=group_world_size, rank=group_rank, shuffle=True, seed=args.seed, drop_last=False
            )
        else:
            sampler = RandomSampler(indexes)
        sampler = BatchSampler(sampler, batch_size=args.per_rank_batch_size, drop_last=False)

        # The Join context is needed because some ranks may receive less data than others
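        # (Join lets ranks that exhaust their smaller shard first shadow the collective
        # operations of the ranks that are still training, so DDP does not deadlock)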
        with Join([agent._forward_module]) if not args.share_data else nullcontext():
            for epoch in range(args.update_epochs):
                if args.share_data:
                    sampler.sampler.set_epoch(epoch)
                for batch_idxes in sampler:
                    loss = agent.training_step({k: v[batch_idxes].to(device) for k, v in data.items()})
                    optimizer.zero_grad(set_to_none=True)
                    fabric.backward(loss)
                    fabric.clip_gradients(agent, optimizer, max_norm=args.max_grad_norm)
                    optimizer.step()

        # Sync metrics
        avg_pg_loss = agent.avg_pg_loss.compute()
        avg_value_loss = agent.avg_value_loss.compute()
        avg_ent_loss = agent.avg_ent_loss.compute()
        agent.reset_metrics()

        # Send updated weights to the player
        if global_rank == 1:
            metrics["Loss/policy_loss"] = avg_pg_loss
            metrics["Loss/value_loss"] = avg_value_loss
            metrics["Loss/entropy_loss"] = avg_ent_loss
            player_trainer_collective.broadcast_object_list(
                [metrics], src=1
            )  # Broadcast metrics: fake send with object list between rank-0 and rank-1
            player_trainer_collective.broadcast(
                torch.nn.utils.convert_parameters.parameters_to_vector(agent.parameters()), src=1
            )


def main(args: argparse.Namespace):
    world_collective = TorchCollective()
    player_trainer_collective = TorchCollective()
    world_collective.setup(backend="nccl" if args.player_on_gpu and args.cuda else "gloo")
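    # NCCL is selected only when the player keeps its tensors on the GPU; otherwise gloo is
    # enough for the CPU tensor broadcasts and the Python-object collectives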

    # Create a global group, assigning it to the collective: used by the player to exchange
    # collected experiences with the trainers
    world_collective.create_group()
    global_rank = world_collective.rank

    # Create a group between rank-0 (player) and rank-1 (trainer), assigning it to the collective:
    # used by rank-1 to send metrics to be tracked by rank-0 at the end of a training episode
    player_trainer_collective.create_group(ranks=[0, 1])

    # Create a new group, without assigning it to the collective: in this way the trainers can
    # still communicate with the player through the global group, but they can optimize the agent
    # among themselves
    optimization_pg = world_collective.new_group(ranks=list(range(1, world_collective.world_size)))
    if global_rank == 0:
        player(args, world_collective, player_trainer_collective)
    else:
        trainer(args, world_collective, player_trainer_collective, optimization_pg)


if __name__ == "__main__":
    args = parse_args()
    main(args)