30 lines
849 B
Python
30 lines
849 B
Python
import torch
|
|
import torch.nn.functional as F
|
|
from torch import Tensor
|
|
|
|
|
|
def policy_loss(advantages: torch.Tensor, ratio: torch.Tensor, clip_coef: float) -> torch.Tensor:
    """Return the PPO-style clipped surrogate policy loss.

    Args:
        advantages: Advantage estimates, multiplied elementwise with ``ratio``.
        ratio: Probability ratios between the new and old policy (assumed to
            broadcast with ``advantages`` — confirm against caller).
        clip_coef: Half-width of the clipping interval ``[1 - clip_coef, 1 + clip_coef]``
            applied to ``ratio``.

    Returns:
        Scalar tensor: the mean over elements of the elementwise maximum of the
        negated unclipped and negated clipped surrogate terms.
    """
    lower, upper = 1 - clip_coef, 1 + clip_coef
    clipped_ratio = ratio.clamp(lower, upper)

    unclipped_term = -advantages * ratio
    clipped_term = -advantages * clipped_ratio

    # Taking the max of the *negated* objectives is the pessimistic (min) bound
    # of the surrogate objective, expressed as a loss to minimize.
    worst_case = torch.maximum(unclipped_term, clipped_term)
    return worst_case.mean()
|
|
|
|
|
|
def value_loss(
    new_values: Tensor,
    old_values: Tensor,
    returns: Tensor,
    clip_coef: float,
    clip_vloss: bool,
    vf_coef: float,
) -> Tensor:
    """Return the scaled mean-squared-error value loss, optionally clipping predictions.

    Args:
        new_values: Current value predictions. Any shape; flattened to 1-D.
        old_values: Value predictions from before the update. Only read when
            ``clip_vloss`` is True; flattened to 1-D in that case.
        returns: Regression targets. Any shape with the same number of
            elements as ``new_values``; flattened to 1-D.
        clip_coef: Maximum allowed absolute change of the prediction relative
            to ``old_values`` when ``clip_vloss`` is True.
        clip_vloss: If True, the prediction is replaced by
            ``old_values + clamp(new_values - old_values, -clip_coef, clip_coef)``.
        vf_coef: Multiplier applied to the mean-squared error.

    Returns:
        Scalar tensor ``vf_coef * mean((values_pred - returns) ** 2)``.
    """
    # Flatten every operand, not only ``new_values``: flattening just one side
    # lets an (N,) prediction broadcast against an (N, 1) target into an (N, N)
    # matrix, which silently yields a wrong loss. ``reshape`` (unlike ``view``)
    # also accepts non-contiguous inputs.
    new_values = new_values.reshape(-1)
    returns = returns.reshape(-1)

    if clip_vloss:
        # old_values is only touched here, so callers that disable clipping
        # may still pass an unused placeholder for it.
        old_values = old_values.reshape(-1)
        values_pred = old_values + torch.clamp(new_values - old_values, -clip_coef, clip_coef)
    else:
        values_pred = new_values

    return vf_coef * F.mse_loss(values_pred, returns)
|
|
|
|
|
|
def entropy_loss(entropy: Tensor, ent_coef: float) -> Tensor:
    """Return the negated, scaled mean entropy.

    Args:
        entropy: Entropy values; averaged over all elements.
        ent_coef: Scale applied to the mean entropy.

    Returns:
        Scalar tensor ``-(mean(entropy) * ent_coef)``. For positive
        ``ent_coef``, minimizing this term increases the mean entropy.
    """
    mean_entropy = entropy.mean()
    return mean_entropy.mul(ent_coef).neg()
|