Replace NumPy with Torch in examples/fabric/ (#17279)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
ef7da5c445
commit
fb775e0855
|
@ -16,7 +16,6 @@ Run it with:
|
|||
"""
|
||||
import cherry
|
||||
import learn2learn as l2l
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lightning.fabric import Fabric, seed_everything
|
||||
|
@ -31,10 +30,9 @@ def fast_adapt(batch, learner, loss, adaptation_steps, shots, ways):
|
|||
data, labels = batch
|
||||
|
||||
# Separate data into adaptation/evalutation sets
|
||||
adaptation_indices = np.zeros(data.size(0), dtype=bool)
|
||||
adaptation_indices[np.arange(shots * ways) * 2] = True
|
||||
evaluation_indices = torch.from_numpy(~adaptation_indices)
|
||||
adaptation_indices = torch.from_numpy(adaptation_indices)
|
||||
adaptation_indices = torch.zeros(data.size(0), dtype=bool)
|
||||
adaptation_indices[torch.arange(shots * ways) * 2] = True
|
||||
evaluation_indices = ~adaptation_indices
|
||||
adaptation_data, adaptation_labels = data[adaptation_indices], labels[adaptation_indices]
|
||||
evaluation_data, evaluation_labels = data[evaluation_indices], labels[evaluation_indices]
|
||||
|
||||
|
|
|
@ -20,7 +20,6 @@ import random
|
|||
|
||||
import cherry
|
||||
import learn2learn as l2l
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
@ -35,10 +34,9 @@ def fast_adapt(batch, learner, loss, adaptation_steps, shots, ways, device):
|
|||
data, labels = data.to(device), labels.to(device)
|
||||
|
||||
# Separate data into adaptation/evalutation sets
|
||||
adaptation_indices = np.zeros(data.size(0), dtype=bool)
|
||||
adaptation_indices[np.arange(shots * ways) * 2] = True
|
||||
evaluation_indices = torch.from_numpy(~adaptation_indices)
|
||||
adaptation_indices = torch.from_numpy(adaptation_indices)
|
||||
adaptation_indices = torch.zeros(data.size(0), dtype=bool)
|
||||
adaptation_indices[torch.arange(shots * ways) * 2] = True
|
||||
evaluation_indices = ~adaptation_indices
|
||||
adaptation_data, adaptation_labels = data[adaptation_indices], labels[adaptation_indices]
|
||||
evaluation_data, evaluation_labels = data[evaluation_indices], labels[evaluation_indices]
|
||||
|
||||
|
@ -76,7 +74,6 @@ def main(
|
|||
seed = seed + rank
|
||||
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
device = torch.device("cpu")
|
||||
if cuda and torch.cuda.device_count():
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import math
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from rl.loss import entropy_loss, policy_loss, value_loss
|
||||
|
@ -24,7 +24,8 @@ class PPOAgent(torch.nn.Module):
|
|||
raise ValueError("Unrecognized activation function: `act_fun` must be either `relu` or `tanh`")
|
||||
self.critic = torch.nn.Sequential(
|
||||
layer_init(
|
||||
torch.nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64), ortho_init=ortho_init
|
||||
torch.nn.Linear(math.prod(envs.single_observation_space.shape), 64),
|
||||
ortho_init=ortho_init,
|
||||
),
|
||||
act_fun,
|
||||
layer_init(torch.nn.Linear(64, 64), ortho_init=ortho_init),
|
||||
|
@ -33,7 +34,8 @@ class PPOAgent(torch.nn.Module):
|
|||
)
|
||||
self.actor = torch.nn.Sequential(
|
||||
layer_init(
|
||||
torch.nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64), ortho_init=ortho_init
|
||||
torch.nn.Linear(math.prod(envs.single_observation_space.shape), 64),
|
||||
ortho_init=ortho_init,
|
||||
),
|
||||
act_fun,
|
||||
layer_init(torch.nn.Linear(64, 64), ortho_init=ortho_init),
|
||||
|
@ -81,10 +83,10 @@ class PPOAgent(torch.nn.Module):
|
|||
lastgaelam = 0
|
||||
for t in reversed(range(num_steps)):
|
||||
if t == num_steps - 1:
|
||||
nextnonterminal = 1.0 - next_done
|
||||
nextnonterminal = torch.logical_not(next_done)
|
||||
nextvalues = next_value
|
||||
else:
|
||||
nextnonterminal = 1.0 - dones[t + 1]
|
||||
nextnonterminal = torch.logical_not(dones[t + 1])
|
||||
nextvalues = values[t + 1]
|
||||
delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
|
||||
advantages[t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
|
||||
|
@ -119,7 +121,8 @@ class PPOLightningAgent(LightningModule):
|
|||
self.normalize_advantages = normalize_advantages
|
||||
self.critic = torch.nn.Sequential(
|
||||
layer_init(
|
||||
torch.nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64), ortho_init=ortho_init
|
||||
torch.nn.Linear(math.prod(envs.single_observation_space.shape), 64),
|
||||
ortho_init=ortho_init,
|
||||
),
|
||||
act_fun,
|
||||
layer_init(torch.nn.Linear(64, 64), ortho_init=ortho_init),
|
||||
|
@ -128,7 +131,8 @@ class PPOLightningAgent(LightningModule):
|
|||
)
|
||||
self.actor = torch.nn.Sequential(
|
||||
layer_init(
|
||||
torch.nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64), ortho_init=ortho_init
|
||||
torch.nn.Linear(math.prod(envs.single_observation_space.shape), 64),
|
||||
ortho_init=ortho_init,
|
||||
),
|
||||
act_fun,
|
||||
layer_init(torch.nn.Linear(64, 64), ortho_init=ortho_init),
|
||||
|
@ -179,10 +183,10 @@ class PPOLightningAgent(LightningModule):
|
|||
lastgaelam = 0
|
||||
for t in reversed(range(num_steps)):
|
||||
if t == num_steps - 1:
|
||||
nextnonterminal = 1.0 - next_done
|
||||
nextnonterminal = torch.logical_not(next_done)
|
||||
nextvalues = next_value
|
||||
else:
|
||||
nextnonterminal = 1.0 - dones[t + 1]
|
||||
nextnonterminal = torch.logical_not(dones[t + 1])
|
||||
nextvalues = values[t + 1]
|
||||
delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
|
||||
advantages[t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
|
||||
|
|
|
@ -1,12 +1,11 @@
|
|||
import argparse
|
||||
import math
|
||||
import os
|
||||
from distutils.util import strtobool
|
||||
from typing import Optional, TYPE_CHECKING, Union
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -119,7 +118,12 @@ def parse_args():
|
|||
return args
|
||||
|
||||
|
||||
def layer_init(layer: torch.nn.Module, std: float = np.sqrt(2), bias_const: float = 0.0, ortho_init: bool = True):
|
||||
def layer_init(
|
||||
layer: torch.nn.Module,
|
||||
std: float = math.sqrt(2),
|
||||
bias_const: float = 0.0,
|
||||
ortho_init: bool = True,
|
||||
):
|
||||
if ortho_init:
|
||||
torch.nn.init.orthogonal_(layer.weight, std)
|
||||
torch.nn.init.constant_(layer.bias, bias_const)
|
||||
|
@ -157,16 +161,16 @@ def test(
|
|||
step = 0
|
||||
done = False
|
||||
cumulative_rew = 0
|
||||
next_obs = Tensor(env.reset(seed=args.seed)[0]).to(device)
|
||||
next_obs = torch.tensor(env.reset(seed=args.seed)[0], device=device)
|
||||
while not done:
|
||||
# Act greedly through the environment
|
||||
action = agent.get_greedy_action(next_obs)
|
||||
|
||||
# Single environment step
|
||||
next_obs, reward, done, truncated, info = env.step(action.cpu().numpy())
|
||||
done = np.logical_or(done, truncated)
|
||||
done = done or truncated
|
||||
cumulative_rew += reward
|
||||
next_obs = Tensor(next_obs).to(device)
|
||||
next_obs = torch.tensor(next_obs, device=device)
|
||||
step += 1
|
||||
logger.add_scalar("Test/cumulative_reward", cumulative_rew, 0)
|
||||
env.close()
|
||||
|
|
|
@ -24,7 +24,6 @@ from datetime import datetime
|
|||
from typing import Dict
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchmetrics
|
||||
from rl.agent import PPOLightningAgent
|
||||
|
@ -128,7 +127,7 @@ def main(args: argparse.Namespace):
|
|||
num_updates = args.total_timesteps // single_global_rollout
|
||||
|
||||
# Get the first environment observation and start the optimization
|
||||
next_obs = Tensor(envs.reset(seed=args.seed)[0]).to(device)
|
||||
next_obs = torch.tensor(envs.reset(seed=args.seed)[0], device=device)
|
||||
next_done = torch.zeros(args.num_envs, device=device)
|
||||
for update in range(1, num_updates + 1):
|
||||
# Learning rate annealing
|
||||
|
@ -150,9 +149,9 @@ def main(args: argparse.Namespace):
|
|||
|
||||
# Single environment step
|
||||
next_obs, reward, done, truncated, info = envs.step(action.cpu().numpy())
|
||||
done = np.logical_or(done, truncated)
|
||||
done = torch.logical_or(torch.tensor(done), torch.tensor(truncated))
|
||||
rewards[step] = torch.tensor(reward, device=device).view(-1)
|
||||
next_obs, next_done = Tensor(next_obs).to(device), Tensor(done).to(device)
|
||||
next_obs, next_done = torch.tensor(next_obs, device=device), done.to(device)
|
||||
|
||||
if "final_info" in info:
|
||||
for agent_final_info in info["final_info"]:
|
||||
|
|
|
@ -23,11 +23,9 @@ import time
|
|||
from datetime import datetime
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from rl.agent import PPOLightningAgent
|
||||
from rl.utils import linear_annealing, make_env, parse_args, test
|
||||
from torch import Tensor
|
||||
from torch.utils.data import BatchSampler, DistributedSampler
|
||||
from torchmetrics import MeanMetric
|
||||
|
||||
|
@ -108,7 +106,7 @@ def player(args, world_collective: TorchCollective, player_trainer_collective: T
|
|||
world_collective.broadcast(update_t, src=0)
|
||||
|
||||
# Get the first environment observation and start the optimization
|
||||
next_obs = Tensor(envs.reset(seed=args.seed)[0]).to(device)
|
||||
next_obs = torch.tensor(envs.reset(seed=args.seed)[0], device=device)
|
||||
next_done = torch.zeros(args.num_envs).to(device)
|
||||
for update in range(1, num_updates + 1):
|
||||
for step in range(0, args.num_steps):
|
||||
|
@ -124,9 +122,9 @@ def player(args, world_collective: TorchCollective, player_trainer_collective: T
|
|||
|
||||
# Single environment step
|
||||
next_obs, reward, done, truncated, info = envs.step(action.cpu().numpy())
|
||||
done = np.logical_or(done, truncated)
|
||||
rewards[step] = torch.tensor(reward).to(device).view(-1)
|
||||
next_obs, next_done = Tensor(next_obs).to(device), Tensor(done).to(device)
|
||||
done = torch.logical_or(torch.tensor(done), torch.tensor(truncated))
|
||||
rewards[step] = torch.tensor(reward, device=device).view(-1)
|
||||
next_obs, next_done = torch.tensor(next_obs, device=device), done.to(device)
|
||||
|
||||
if "final_info" in info:
|
||||
for agent_final_info in info["final_info"]:
|
||||
|
|
|
@ -25,7 +25,6 @@ from datetime import datetime
|
|||
from typing import Dict
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as distributed
|
||||
import torch.nn as nn
|
||||
|
@ -118,7 +117,6 @@ def main(args: argparse.Namespace):
|
|||
|
||||
# Seed everything
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
torch.backends.cudnn.deterministic = args.torch_deterministic
|
||||
|
@ -181,7 +179,7 @@ def main(args: argparse.Namespace):
|
|||
num_updates = args.total_timesteps // single_global_step
|
||||
|
||||
# Get the first environment observation and start the optimization
|
||||
next_obs = Tensor(envs.reset(seed=args.seed)[0]).to(device)
|
||||
next_obs = torch.tensor(envs.reset(seed=args.seed)[0], device=device)
|
||||
next_done = torch.zeros(args.num_envs, device=device)
|
||||
for update in range(1, num_updates + 1):
|
||||
# Learning rate annealing
|
||||
|
@ -204,9 +202,9 @@ def main(args: argparse.Namespace):
|
|||
|
||||
# Single environment step
|
||||
next_obs, reward, done, truncated, info = envs.step(action.cpu().numpy())
|
||||
done = np.logical_or(done, truncated)
|
||||
done = torch.logical_or(torch.tensor(done), torch.tensor(truncated))
|
||||
rewards[step] = torch.tensor(reward, device=device).view(-1)
|
||||
next_obs, next_done = Tensor(next_obs).to(device), Tensor(done).to(device)
|
||||
next_obs, next_done = torch.tensor(next_obs, device=device), done.to(device)
|
||||
|
||||
if "final_info" in info:
|
||||
for agent_final_info in info["final_info"]:
|
||||
|
|
Loading…
Reference in New Issue