544 lines
23 KiB
Python
544 lines
23 KiB
Python
import os
|
|
from collections.abc import Iterable, Mapping
|
|
from functools import partial
|
|
from typing import Any, Literal, Optional, Union, cast
|
|
|
|
import lightning as L
|
|
import torch
|
|
from lightning.fabric.accelerators import Accelerator
|
|
from lightning.fabric.loggers import Logger
|
|
from lightning.fabric.strategies import Strategy
|
|
from lightning.fabric.wrappers import _unwrap_objects
|
|
from lightning.pytorch.utilities.model_helpers import is_overridden
|
|
from lightning_utilities import apply_to_collection
|
|
from tqdm import tqdm
|
|
|
|
|
|
class MyCustomTrainer:
    """Minimal training/validation loop built on top of :class:`lightning.Fabric`."""

    def __init__(
        self,
        accelerator: Union[str, Accelerator] = "auto",
        strategy: Union[str, Strategy] = "auto",
        devices: Union[list[int], str, int] = "auto",
        precision: Union[str, int] = "32-true",
        plugins: Optional[Union[str, Any]] = None,
        callbacks: Optional[Union[list[Any], Any]] = None,
        loggers: Optional[Union[Logger, list[Logger]]] = None,
        max_epochs: Optional[int] = 1000,
        max_steps: Optional[int] = None,
        grad_accum_steps: int = 1,
        limit_train_batches: Union[int, float] = float("inf"),
        limit_val_batches: Union[int, float] = float("inf"),
        validation_frequency: int = 1,
        use_distributed_sampler: bool = True,
        checkpoint_dir: str = "./checkpoints",
        checkpoint_frequency: int = 1,
    ) -> None:
        """Exemplary Trainer with Fabric. This is a very simple trainer focused on readability but with reduced
        featureset. As a trainer with more included features, we recommend using the
        :class:`lightning.pytorch.Trainer`.

        Args:
            accelerator: The hardware to run on. Possible choices are:
                ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
            strategy: Strategy for how to run across multiple devices. Possible choices are:
                ``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"fsdp"``.
            devices: Number of devices to train on (``int``),
                which GPUs to train on (``list`` or ``str``), or ``"auto"``.
                The value applies per node.
            precision: Double precision (``"64"``), full precision (``"32"``), half precision AMP (``"16-mixed"``),
                or bfloat16 precision AMP (``"bf16-mixed"``).
            plugins: One or several custom plugins
            callbacks: A single callback or a list of callbacks. The following hooks are supported:
                - on_train_epoch_start
                - on_train_epoch_end
                - on_train_batch_start
                - on_train_batch_end
                - on_before_backward
                - on_after_backward
                - on_before_zero_grad
                - on_before_optimizer_step
                - on_validation_model_eval
                - on_validation_model_train
                - on_validation_epoch_start
                - on_validation_epoch_end
                - on_validation_batch_start
                - on_validation_batch_end

            loggers: A single logger or a list of loggers. See :meth:`~lightning.fabric.fabric.Fabric.log` for more
                information.

            max_epochs: The maximum number of epochs to train
            max_steps: The maximum number of (optimizer) steps to train
            grad_accum_steps: How many batches to process before each optimizer step
            limit_train_batches: Limits the number of train batches per epoch
                If greater than number of batches in the dataloader, this has no effect.
            limit_val_batches: Limits the number of validation batches per epoch.
                If greater than number of batches in the dataloader, this has no effect.
            validation_frequency: How many epochs to run before each validation epoch.
            use_distributed_sampler: Wraps the sampler of each dataloader with a respective distributed-aware sampler
                in case of distributed training.
            checkpoint_dir: Directory to store checkpoints to.
            checkpoint_frequency: How many epochs to run before each checkpoint is written.
                A checkpoint is additionally written when training stops.

        Warning:
            callbacks written for the lightning trainer (especially making assumptions on the trainer), won't work!

        """
        self.fabric = L.Fabric(
            accelerator=accelerator,
            strategy=strategy,
            devices=devices,
            precision=precision,
            plugins=plugins,
            callbacks=callbacks,
            loggers=loggers,
        )
        self.global_step = 0
        self.grad_accum_steps: int = grad_accum_steps
        self.current_epoch = 0

        self.max_epochs = max_epochs
        self.max_steps = max_steps
        self.should_stop = False

        # ensures limit_X_batches is either int or inf
        # (kept as assertions so the exception type seen by existing callers does not change)
        if not isinstance(limit_train_batches, int):
            assert limit_train_batches == float("inf")

        if not isinstance(limit_val_batches, int):
            assert limit_val_batches == float("inf")

        self.limit_train_batches = limit_train_batches
        self.limit_val_batches = limit_val_batches
        self.validation_frequency = validation_frequency
        self.use_distributed_sampler = use_distributed_sampler

        # detached outputs of the most recent training/validation step; read by the
        # scheduler (monitored values), the progress bar and the batch-end callbacks
        self._current_train_return: Union[torch.Tensor, Mapping[str, Any]] = {}
        self._current_val_return: Optional[Union[torch.Tensor, Mapping[str, Any]]] = {}

        self.checkpoint_dir = checkpoint_dir
        self.checkpoint_frequency = checkpoint_frequency

    def fit(
        self,
        model: L.LightningModule,
        train_loader: torch.utils.data.DataLoader,
        val_loader: Optional[torch.utils.data.DataLoader] = None,
        ckpt_path: Optional[str] = None,
    ):
        """The main entrypoint of the trainer, triggering the actual training.

        Args:
            model: the LightningModule to train.
                Can have the same hooks as :attr:`callbacks` (see :meth:`MyCustomTrainer.__init__`).
            train_loader: the training dataloader. Has to be an iterable returning batches.
            val_loader: the validation dataloader. Has to be an iterable returning batches.
                If not specified, no validation will run.
            ckpt_path: Path to previous checkpoints to resume training from.
                If specified, will always look for the latest checkpoint within the given directory.

        Raises:
            NotImplementedError: if the configured strategy is FSDP.

        """
        self.fabric.launch()

        # setup dataloaders
        train_loader = self.fabric.setup_dataloaders(train_loader, use_distributed_sampler=self.use_distributed_sampler)
        if val_loader is not None:
            val_loader = self.fabric.setup_dataloaders(val_loader, use_distributed_sampler=self.use_distributed_sampler)

        # setup model and optimizer
        if isinstance(self.fabric.strategy, L.fabric.strategies.fsdp.FSDPStrategy):
            # currently, there is no way to support fsdp with model.configure_optimizers in fabric
            # as it would require fabric to hold a reference to the model, which we don't want to.
            raise NotImplementedError("BYOT currently does not support FSDP")

        optimizer, scheduler_cfg = self._parse_optimizers_schedulers(model.configure_optimizers())
        assert optimizer is not None
        model, optimizer = self.fabric.setup(model, optimizer)

        # assemble state (current epoch and global step will be added in save)
        state = {"model": model, "optim": optimizer, "scheduler": scheduler_cfg}

        # load last checkpoint if available
        if ckpt_path is not None and os.path.isdir(ckpt_path):
            # search the directory the caller asked to resume from, not the output directory
            latest_checkpoint_path = self.get_latest_checkpoint(ckpt_path)
            if latest_checkpoint_path is not None:
                self.load(state, latest_checkpoint_path)

                # check if we even need to train here
                if self.max_epochs is not None and self.current_epoch >= self.max_epochs:
                    self.should_stop = True

        while not self.should_stop:
            self.train_loop(
                model, optimizer, train_loader, limit_batches=self.limit_train_batches, scheduler_cfg=scheduler_cfg
            )

            if self.should_validate:
                self.val_loop(model, val_loader, limit_batches=self.limit_val_batches)

            self.step_scheduler(model, scheduler_cfg, level="epoch", current_value=self.current_epoch)

            self.current_epoch += 1

            # stopping condition on epoch level
            if self.max_epochs is not None and self.current_epoch >= self.max_epochs:
                self.should_stop = True

            # write a checkpoint every `checkpoint_frequency` epochs and always when stopping,
            # so the final state is never lost
            if self.should_stop or self.current_epoch % self.checkpoint_frequency == 0:
                self.save(state)

        # reset for next fit call
        self.should_stop = False

    def train_loop(
        self,
        model: L.LightningModule,
        optimizer: torch.optim.Optimizer,
        train_loader: torch.utils.data.DataLoader,
        limit_batches: Union[int, float] = float("inf"),
        scheduler_cfg: Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]] = None,
    ):
        """The training loop running a single training epoch.

        Args:
            model: the LightningModule to train
            optimizer: the optimizer, optimizing the LightningModule.
            train_loader: The dataloader yielding the training batches.
            limit_batches: Limits the batches during this training epoch.
                If greater than the number of batches in the ``train_loader``, this has no effect.
            scheduler_cfg: The learning rate scheduler configuration.
                Have a look at :meth:`~lightning.pytorch.core.LightningModule.configure_optimizers`
                for supported values.

        """
        self.fabric.call("on_train_epoch_start")
        iterable = self.progbar_wrapper(
            train_loader, total=min(len(train_loader), limit_batches), desc=f"Epoch {self.current_epoch}"
        )

        for batch_idx, batch in enumerate(iterable):
            # end epoch if stopping training completely or max batches for this epoch reached
            if self.should_stop or batch_idx >= limit_batches:
                break

            self.fabric.call("on_train_batch_start", batch, batch_idx)

            # check if optimizer should step in gradient accumulation.
            # This is derived from the batch index: `global_step` only advances on an optimizer
            # step, so basing the decision on it would stall forever for grad_accum_steps > 1.
            # Gradients left over at the end of an epoch are carried into the next one.
            should_optim_step = (batch_idx + 1) % self.grad_accum_steps == 0
            if should_optim_step:
                # currently only supports a single optimizer
                self.fabric.call("on_before_optimizer_step", optimizer)

                # optimizer step runs train step internally through closure
                optimizer.step(partial(self.training_step, model=model, batch=batch, batch_idx=batch_idx))
                self.fabric.call("on_before_zero_grad", optimizer)

                optimizer.zero_grad()

            else:
                # gradient accumulation -> no optimizer step
                self.training_step(model=model, batch=batch, batch_idx=batch_idx)

            self.fabric.call("on_train_batch_end", self._current_train_return, batch, batch_idx)

            # this guard ensures, we only step the scheduler once per global step
            if should_optim_step:
                self.step_scheduler(model, scheduler_cfg, level="step", current_value=self.global_step)

            # add output values to progress bar
            self._format_iterable(iterable, self._current_train_return, "train")

            # only increase global step if optimizer stepped
            self.global_step += int(should_optim_step)

            # stopping criterion on step level
            if self.max_steps is not None and self.global_step >= self.max_steps:
                self.should_stop = True
                break

        self.fabric.call("on_train_epoch_end")

    def val_loop(
        self,
        model: L.LightningModule,
        val_loader: Optional[torch.utils.data.DataLoader],
        limit_batches: Union[int, float] = float("inf"),
    ):
        """The validation loop running a single validation epoch.

        Args:
            model: the LightningModule to evaluate
            val_loader: The dataloader yielding the validation batches.
            limit_batches: Limits the batches during this validation epoch.
                If greater than the number of batches in the ``val_loader``, this has no effect.

        """
        # no validation if val_loader wasn't passed
        if val_loader is None:
            return

        # no validation but warning if val_loader was passed, but validation_step not implemented
        if not is_overridden("validation_step", _unwrap_objects(model)):
            L.fabric.utilities.rank_zero_warn(
                "Your LightningModule does not have a validation_step implemented, "
                "but you passed a validation dataloder. Skipping Validation."
            )
            return

        if not is_overridden("on_validation_model_eval", _unwrap_objects(model)):
            model.eval()
        else:
            self.fabric.call("on_validation_model_eval")  # calls `model.eval()`

        torch.set_grad_enabled(False)
        try:
            self.fabric.call("on_validation_epoch_start")

            iterable = self.progbar_wrapper(val_loader, total=min(len(val_loader), limit_batches), desc="Validation")

            for batch_idx, batch in enumerate(iterable):
                # end epoch if stopping training completely or max batches for this epoch reached
                if self.should_stop or batch_idx >= limit_batches:
                    break

                self.fabric.call("on_validation_batch_start", batch, batch_idx)

                out = model.validation_step(batch, batch_idx)
                # avoid gradients in stored/accumulated values -> prevents potential OOM
                out = apply_to_collection(out, torch.Tensor, lambda x: x.detach())

                self.fabric.call("on_validation_batch_end", out, batch, batch_idx)
                self._current_val_return = out

                self._format_iterable(iterable, self._current_val_return, "val")

            self.fabric.call("on_validation_epoch_end")

            if not is_overridden("on_validation_model_train", _unwrap_objects(model)):
                model.train()
            else:
                self.fabric.call("on_validation_model_train")
        finally:
            # re-enable autograd even if validation raised, so a caller that handles the
            # error is not left with gradients globally disabled
            torch.set_grad_enabled(True)

    def training_step(self, model: L.LightningModule, batch: Any, batch_idx: int) -> torch.Tensor:
        """A single training step, running forward and backward. The optimizer step is called separately, as this is
        given as a closure to the optimizer step.

        Args:
            model: the lightning module to train
            batch: the batch to run the forward on
            batch_idx: index of the current batch w.r.t the current epoch

        Returns:
            The loss: the step output itself if it is a tensor, otherwise its ``"loss"`` entry.

        """
        outputs: Union[torch.Tensor, Mapping[str, Any]] = model.training_step(batch, batch_idx=batch_idx)

        if isinstance(outputs, torch.Tensor):
            loss = outputs
        else:
            loss = outputs["loss"]

        self.fabric.call("on_before_backward", loss)
        self.fabric.backward(loss)
        self.fabric.call("on_after_backward")

        # avoid gradients in stored/accumulated values -> prevents potential OOM
        self._current_train_return = apply_to_collection(outputs, dtype=torch.Tensor, function=lambda x: x.detach())

        return loss

    def step_scheduler(
        self,
        model: L.LightningModule,
        scheduler_cfg: Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]],
        level: Literal["step", "epoch"],
        current_value: int,
    ) -> None:
        """Steps the learning rate scheduler if necessary.

        Args:
            model: The LightningModule to train
            scheduler_cfg: The learning rate scheduler configuration.
                Have a look at :meth:`lightning.pytorch.LightningModule.configure_optimizers` for supported values.
            level: whether we are trying to step on epoch- or step-level
            current_value: Holds the current_epoch if ``level==epoch``, else holds the ``global_step``

        Raises:
            KeyError: if the configured ``monitor`` is not among the values returned by the last
                training/validation step.

        """
        # no scheduler
        if scheduler_cfg is None:
            return

        # wrong interval (step vs. epoch)
        if scheduler_cfg["interval"] != level:
            return

        # right interval, but wrong step wrt frequency
        if current_value % cast(int, scheduler_cfg["frequency"]) != 0:
            return

        # assemble potential monitored values; the ``None`` key covers schedulers without a monitor
        possible_monitor_vals: dict[Optional[str], Any] = {None: None}
        if isinstance(self._current_train_return, torch.Tensor):
            # keyword form: ``dict.update`` accepts at most one positional argument
            possible_monitor_vals.update(train_loss=self._current_train_return)
        elif isinstance(self._current_train_return, Mapping):
            possible_monitor_vals.update({"train_" + k: v for k, v in self._current_train_return.items()})

        if isinstance(self._current_val_return, torch.Tensor):
            possible_monitor_vals.update(val_loss=self._current_val_return)
        elif isinstance(self._current_val_return, Mapping):
            possible_monitor_vals.update({"val_" + k: v for k, v in self._current_val_return.items()})

        try:
            monitor = possible_monitor_vals[cast(Optional[str], scheduler_cfg["monitor"])]
        except KeyError as ex:
            possible_keys = list(possible_monitor_vals.keys())
            raise KeyError(
                f"monitor {scheduler_cfg['monitor']} is invalid. Possible values are {possible_keys}."
            ) from ex

        # rely on model hook for actual step
        model.lr_scheduler_step(scheduler_cfg["scheduler"], monitor)

    @property
    def should_validate(self) -> bool:
        """Whether to currently run validation."""
        return self.current_epoch % self.validation_frequency == 0

    def progbar_wrapper(self, iterable: Iterable, total: int, **kwargs: Any):
        """Wraps the iterable with tqdm for global rank zero.

        Args:
            iterable: the iterable to wrap with tqdm
            total: the total length of the iterable, necessary in case the number of batches was limited.
            **kwargs: forwarded to :class:`tqdm.tqdm`.

        Returns:
            A ``tqdm`` progress bar on global rank zero, the unchanged ``iterable`` on every other rank.

        """
        if self.fabric.is_global_zero:
            return tqdm(iterable, total=total, **kwargs)
        return iterable

    def load(self, state: Optional[Mapping], path: str) -> None:
        """Loads a checkpoint from a given file into state.

        Args:
            state: a mapping containing model, optimizer and lr scheduler
            path: the path to load the checkpoint from

        Raises:
            RuntimeError: if the checkpoint holds entries besides ``state``, ``global_step``
                and ``current_epoch``.

        """
        if state is None:
            state = {}

        # everything in the checkpoint that is not part of `state` is handed back by fabric
        remainder = self.fabric.load(path, state)
        self.global_step = remainder.pop("global_step")
        self.current_epoch = remainder.pop("current_epoch")

        if remainder:
            raise RuntimeError(f"Unused Checkpoint Values: {remainder}")

    def save(self, state: Optional[Mapping]) -> None:
        """Saves a checkpoint to the ``checkpoint_dir``

        Args:
            state: A mapping containing model, optimizer and lr scheduler.

        """
        if state is None:
            state = {}

        state.update(global_step=self.global_step, current_epoch=self.current_epoch)

        # zero-padded epoch keeps lexicographic and chronological order in sync
        # (relied upon by `get_latest_checkpoint`)
        self.fabric.save(os.path.join(self.checkpoint_dir, f"epoch-{self.current_epoch:04d}.ckpt"), state)

    @staticmethod
    def get_latest_checkpoint(checkpoint_dir: str) -> Optional[str]:
        """Returns the latest checkpoint from the ``checkpoint_dir``

        Args:
            checkpoint_dir: the directory to search for checkpoints

        Returns:
            The path of the lexicographically last entry in ``checkpoint_dir``, or ``None`` if the
            directory does not exist or is empty.

        """
        if not os.path.isdir(checkpoint_dir):
            return None

        items = sorted(os.listdir(checkpoint_dir))

        if not items:
            return None

        return os.path.join(checkpoint_dir, items[-1])

    def _parse_optimizers_schedulers(
        self, configure_optim_output
    ) -> tuple[
        Optional[L.fabric.utilities.types.Optimizable],
        Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]],
    ]:
        """Recursively parses the output of :meth:`lightning.pytorch.LightningModule.configure_optimizers`.

        Args:
            configure_optim_output: The output of ``configure_optimizers``.
                For supported values, please refer to :meth:`lightning.pytorch.LightningModule.configure_optimizers`.

        Returns:
            A tuple ``(optimizer, scheduler_cfg)``; either element is ``None`` if it could not be
            derived from ``configure_optim_output``.

        Raises:
            NotImplementedError: if a list/tuple consisting only of optimizers does not hold exactly one.

        """
        _lr_sched_defaults = {"interval": "epoch", "frequency": 1, "monitor": "val_loss"}

        # single optimizer
        if isinstance(configure_optim_output, L.fabric.utilities.types.Optimizable):
            return configure_optim_output, None

        # single lr scheduler
        if isinstance(configure_optim_output, L.fabric.utilities.types.LRScheduler):
            # ``dict.update`` returns None, so update first and return the dict itself
            _lr_sched_defaults.update(scheduler=configure_optim_output)
            return None, _lr_sched_defaults

        # single lr scheduler config
        if isinstance(configure_optim_output, Mapping):
            _lr_sched_defaults.update(configure_optim_output)
            return None, _lr_sched_defaults

        # list or tuple
        if isinstance(configure_optim_output, (list, tuple)):
            if all(isinstance(_opt_cand, L.fabric.utilities.types.Optimizable) for _opt_cand in configure_optim_output):
                # single optimizer in list
                if len(configure_optim_output) == 1:
                    return configure_optim_output[0], None

                raise NotImplementedError("BYOT only supports a single optimizer")

            if all(
                isinstance(_lr_cand, (L.fabric.utilities.types.LRScheduler, Mapping))
                for _lr_cand in configure_optim_output
            ):
                # single scheduler in list
                if len(configure_optim_output) == 1:
                    return None, self._parse_optimizers_schedulers(configure_optim_output[0])[1]

            # optimizer and lr scheduler
            elif len(configure_optim_output) == 2:
                opt_cands, lr_cands = (
                    self._parse_optimizers_schedulers(configure_optim_output[0])[0],
                    self._parse_optimizers_schedulers(configure_optim_output[1])[1],
                )
                return opt_cands, lr_cands

        return None, None

    @staticmethod
    def _format_iterable(
        prog_bar, candidates: Optional[Union[torch.Tensor, Mapping[str, Union[torch.Tensor, float, int]]]], prefix: str
    ):
        """Adds values as postfix string to progressbar.

        Args:
            prog_bar: a progressbar (on global rank zero) or an iterable (every other rank).
            candidates: the values to add as postfix strings to the progressbar.
            prefix: the prefix to add to each of these values.

        """
        # plain iterables (non-zero ranks) have no postfix to update
        if not isinstance(prog_bar, tqdm) or candidates is None:
            return

        float_candidates = apply_to_collection(candidates, torch.Tensor, lambda x: x.item())

        postfix_str = ""
        if isinstance(candidates, torch.Tensor):
            postfix_str = f" {prefix}_loss: {float_candidates:.3f}"
        elif isinstance(candidates, Mapping):
            postfix_str = "".join(f" {prefix}_{k}: {v:.3f}" for k, v in float_candidates.items())

        if postfix_str:
            prog_bar.set_postfix_str(postfix_str)
|