# lightning/examples/fabric/build_your_own_trainer/trainer.py

import os
from collections.abc import Mapping
from functools import partial
from typing import Any, cast, Iterable, List, Literal, Optional, Tuple, Union
import torch
from lightning_utilities import apply_to_collection, is_overridden
from tqdm import tqdm
import lightning as L
from lightning.fabric.accelerators import Accelerator
from lightning.fabric.loggers import Logger
from lightning.fabric.strategies import Strategy
from lightning.fabric.wrappers import _unwrap_objects
class MyCustomTrainer:
    """A minimal, readable training loop built on :class:`lightning.Fabric` for a single optimizer."""

    def __init__(
        self,
        accelerator: Union[str, Accelerator] = "auto",
        strategy: Union[str, Strategy] = "auto",
        devices: Union[List[int], str, int] = "auto",
        precision: Union[str, int] = "32-true",
        plugins: Optional[Union[str, Any]] = None,
        callbacks: Optional[Union[List[Any], Any]] = None,
        loggers: Optional[Union[Logger, List[Logger]]] = None,
        max_epochs: Optional[int] = 1000,
        max_steps: Optional[int] = None,
        grad_accum_steps: int = 1,
        limit_train_batches: Union[int, float] = float("inf"),
        limit_val_batches: Union[int, float] = float("inf"),
        validation_frequency: int = 1,
        use_distributed_sampler: bool = True,
        checkpoint_dir: str = "./checkpoints",
        checkpoint_frequency: int = 1,
    ) -> None:
        """Exemplary Trainer with Fabric. This is a very simple trainer focused on readability but with a reduced
        feature set. For a trainer with more features included, we recommend using the
        :class:`lightning.pytorch.Trainer`.

        Args:
            accelerator: The hardware to run on. Possible choices are:
                ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
            strategy: Strategy for how to run across multiple devices. Possible choices are:
                ``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"fsdp"``.
                Note that FSDP is rejected by :meth:`fit`.
            devices: Number of devices to train on (``int``),
                which GPUs to train on (``list`` or ``str``), or ``"auto"``.
                The value applies per node.
            precision: Double precision (``"64"``), full precision (``"32"``), half precision AMP (``"16-mixed"``),
                or bfloat16 precision AMP (``"bf16-mixed"``).
            plugins: One or several custom plugins.
            callbacks: A single callback or a list of callbacks. The following hooks are supported:
                - on_train_epoch_start
                - on_train_epoch_end
                - on_train_batch_start
                - on_train_batch_end
                - on_before_backward
                - on_after_backward
                - on_before_zero_grad
                - on_before_optimizer_step (called after backward, right before the optimizer step)
                - on_validation_model_eval
                - on_validation_model_train
                - on_validation_epoch_start
                - on_validation_epoch_end
                - on_validation_batch_start
                - on_validation_batch_end
            loggers: A single logger or a list of loggers. See :meth:`~lightning.fabric.fabric.Fabric.log` for more
                information.
            max_epochs: The maximum number of epochs to train. ``None`` disables the epoch limit.
            max_steps: The maximum number of (optimizer) steps to train. ``None`` disables the step limit.
            grad_accum_steps: How many batches to process before each optimizer step. Must be at least 1.
            limit_train_batches: Limits the number of train batches per epoch. Must be an ``int`` or ``float("inf")``.
                If greater than number of batches in the dataloader, this has no effect.
            limit_val_batches: Limits the number of validation batches per epoch. Must be an ``int`` or
                ``float("inf")``. If greater than number of batches in the dataloader, this has no effect.
            validation_frequency: How many epochs to run before each validation epoch. Must be at least 1.
            use_distributed_sampler: Wraps the sampler of each dataloader with a respective distributed-aware sampler
                in case of distributed training.
            checkpoint_dir: Directory to store checkpoints to.
            checkpoint_frequency: How many epochs to run before each checkpoint is written. Must be at least 1.

        Raises:
            ValueError: If ``limit_train_batches`` or ``limit_val_batches`` is neither an ``int`` nor
                ``float("inf")``, or if ``grad_accum_steps``, ``validation_frequency`` or ``checkpoint_frequency``
                is smaller than 1.

        Warning:
            callbacks written for the lightning trainer (especially making assumptions on the trainer), won't work!
        """
        # only plain batch counts or "no limit" are supported (no fractional limits)
        for name, limit in (("limit_train_batches", limit_train_batches), ("limit_val_batches", limit_val_batches)):
            if not isinstance(limit, int) and limit != float("inf"):
                raise ValueError(f"`{name}` must be an int or float('inf'), got {limit!r}.")
        # these values are used as divisors in modulo checks, so only positive values are meaningful
        for name, value in (
            ("grad_accum_steps", grad_accum_steps),
            ("validation_frequency", validation_frequency),
            ("checkpoint_frequency", checkpoint_frequency),
        ):
            if value < 1:
                raise ValueError(f"`{name}` must be at least 1, got {value!r}.")

        self.fabric = L.Fabric(
            accelerator=accelerator,
            strategy=strategy,
            devices=devices,
            precision=precision,
            plugins=plugins,
            callbacks=callbacks,
            loggers=loggers,
        )
        self.global_step = 0
        self.grad_accum_steps: int = grad_accum_steps
        self.current_epoch = 0

        self.max_epochs = max_epochs
        self.max_steps = max_steps
        self.should_stop = False

        self.limit_train_batches = limit_train_batches
        self.limit_val_batches = limit_val_batches
        self.validation_frequency = validation_frequency
        self.use_distributed_sampler = use_distributed_sampler
        # detached outputs of the most recent training / validation step (used for progress bar and monitoring)
        self._current_train_return: Union[torch.Tensor, Mapping[str, Any]] = {}
        self._current_val_return: Optional[Union[torch.Tensor, Mapping[str, Any]]] = {}

        self.checkpoint_dir = checkpoint_dir
        self.checkpoint_frequency = checkpoint_frequency

    def fit(
        self,
        model: L.LightningModule,
        train_loader: torch.utils.data.DataLoader,
        val_loader: Optional[torch.utils.data.DataLoader] = None,
        ckpt_path: Optional[str] = None,
    ):
        """The main entrypoint of the trainer, triggering the actual training.

        Args:
            model: the LightningModule to train. Its ``configure_optimizers`` must provide exactly one optimizer.
                Hooks are dispatched via ``fabric.call`` to the ``callbacks`` passed to
                :meth:`MyCustomTrainer.__init__`; the model's own hooks only run if it is registered there as well.
            train_loader: the training dataloader. Has to be an iterable returning batches and support ``len()``.
            val_loader: the validation dataloader. Has to be an iterable returning batches and support ``len()``.
                If not specified, no validation will run.
            ckpt_path: Path to a checkpoint file, or to a directory of checkpoints, to resume training from.
                For a directory, the latest ``.ckpt`` entry within it (by name) is loaded.
                Paths that do not exist are ignored and training starts from scratch.

        Raises:
            NotImplementedError: If the strategy is FSDP, or ``configure_optimizers`` returns more than one optimizer
                or learning rate scheduler.
            ValueError: If no optimizer could be extracted from the output of ``configure_optimizers``.
        """
        self.fabric.launch()

        # setup dataloaders
        train_loader = self.fabric.setup_dataloaders(train_loader, use_distributed_sampler=self.use_distributed_sampler)
        if val_loader is not None:
            val_loader = self.fabric.setup_dataloaders(val_loader, use_distributed_sampler=self.use_distributed_sampler)

        # setup model and optimizer
        if isinstance(self.fabric.strategy, L.fabric.strategies.fsdp.FSDPStrategy):
            # currently, there is no way to support fsdp with model.configure_optimizers in fabric
            # as it would require fabric to hold a reference to the model, which we don't want to.
            raise NotImplementedError("BYOT currently does not support FSDP")

        optimizer, scheduler_cfg = self._parse_optimizers_schedulers(model.configure_optimizers())
        if optimizer is None:
            raise ValueError("`configure_optimizers` did not return an optimizer.")
        model, optimizer = self.fabric.setup(model, optimizer)

        # assemble state (current epoch and global step will be added in save).
        # The scheduler object itself (not its config dict) is stored, so Fabric saves its ``state_dict`` and
        # restores it in place via ``load_state_dict`` when resuming.
        scheduler = None if scheduler_cfg is None else scheduler_cfg["scheduler"]
        state = {"model": model, "optim": optimizer, "scheduler": scheduler}

        # load last checkpoint if available
        if ckpt_path is not None:
            if os.path.isdir(ckpt_path):
                checkpoint_to_load = self.get_latest_checkpoint(ckpt_path)
            elif os.path.isfile(ckpt_path):
                checkpoint_to_load = ckpt_path
            else:
                checkpoint_to_load = None
            if checkpoint_to_load is not None:
                self.load(state, checkpoint_to_load)

        # check if we even need to train here
        if self.max_epochs is not None and self.current_epoch >= self.max_epochs:
            self.should_stop = True

        while not self.should_stop:
            self.train_loop(
                model, optimizer, train_loader, limit_batches=self.limit_train_batches, scheduler_cfg=scheduler_cfg
            )

            if self.should_validate:
                self.val_loop(model, val_loader, limit_batches=self.limit_val_batches)

            self.step_scheduler(model, scheduler_cfg, level="epoch", current_value=self.current_epoch)

            self.current_epoch += 1

            # stopping condition on epoch level
            if self.max_epochs is not None and self.current_epoch >= self.max_epochs:
                self.should_stop = True

            # ``current_epoch`` now counts finished epochs, so this writes after every N-th epoch
            if self.current_epoch % self.checkpoint_frequency == 0:
                self.save(state)

        # reset for next fit call
        self.should_stop = False

    def train_loop(
        self,
        model: L.LightningModule,
        optimizer: torch.optim.Optimizer,
        train_loader: torch.utils.data.DataLoader,
        limit_batches: Union[int, float] = float("inf"),
        scheduler_cfg: Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]] = None,
    ):
        """The training loop running a single training epoch.

        Every batch runs forward and backward. The optimizer steps after every ``grad_accum_steps`` batches and on
        the last batch of the epoch, so accumulated gradients never carry over into the next epoch.

        Args:
            model: the LightningModule to train
            optimizer: the optimizer, optimizing the LightningModule.
            train_loader: The dataloader yielding the training batches.
            limit_batches: Limits the batches during this training epoch.
                If greater than the number of batches in the ``train_loader``, this has no effect.
            scheduler_cfg: The learning rate scheduler configuration.
                Have a look at :meth:`lightning.pytorch.LightningModule.configure_optimizers` for supported values.
        """
        self.fabric.call("on_train_epoch_start")
        num_batches = min(len(train_loader), limit_batches)
        iterable = self.progbar_wrapper(train_loader, total=num_batches, desc=f"Epoch {self.current_epoch}")

        for batch_idx, batch in enumerate(iterable):
            # end epoch if stopping training completely or max batches for this epoch reached
            if self.should_stop or batch_idx >= limit_batches:
                break

            self.fabric.call("on_train_batch_start", batch, batch_idx)

            # step once per ``grad_accum_steps`` batches (counted within the epoch) and on the epoch's last batch
            is_last_batch = batch_idx + 1 >= num_batches
            should_optim_step = (batch_idx + 1) % self.grad_accum_steps == 0 or is_last_batch

            # forward + backward; gradients accumulate until the optimizer steps
            self.training_step(model=model, batch=batch, batch_idx=batch_idx)

            if should_optim_step:
                # currently only supports a single optimizer. Gradients are available at this point.
                self.fabric.call("on_before_optimizer_step", optimizer, 0)
                # no closure here: AMP's GradScaler and DeepSpeed do not accept one
                optimizer.step()
                self.fabric.call("on_before_zero_grad", optimizer)
                optimizer.zero_grad()

            self.fabric.call("on_train_batch_end", self._current_train_return, batch, batch_idx)

            # this guard ensures, we only step the scheduler once per global step
            if should_optim_step:
                self.step_scheduler(model, scheduler_cfg, level="step", current_value=self.global_step)

            # add output values to progress bar
            self._format_iterable(iterable, self._current_train_return, "train")

            # only increase global step if optimizer stepped
            self.global_step += int(should_optim_step)

            # stopping criterion on step level
            if self.max_steps is not None and self.global_step >= self.max_steps:
                self.should_stop = True
                break

        self.fabric.call("on_train_epoch_end")

    def val_loop(
        self,
        model: L.LightningModule,
        val_loader: Optional[torch.utils.data.DataLoader],
        limit_batches: Union[int, float] = float("inf"),
    ):
        """The validation loop running a single validation epoch.

        The model is put into eval mode and gradients are disabled for the duration of the epoch; both are restored
        afterwards even when the epoch ends early (``limit_batches`` or ``should_stop``).

        Args:
            model: the LightningModule to evaluate
            val_loader: The dataloader yielding the validation batches.
            limit_batches: Limits the batches during this validation epoch.
                If greater than the number of batches in the ``val_loader``, this has no effect.
        """
        # no validation if val_loader wasn't passed
        if val_loader is None:
            return

        # no validation but warning if val_loader was passed, but validation_step not implemented
        if not is_overridden("validation_step", _unwrap_objects(model), L.LightningModule):
            L.fabric.utilities.rank_zero_warn(
                "Your LightningModule does not have a validation_step implemented, "
                "but you passed a validation dataloder. Skipping Validation."
            )
            return

        # switch to eval mode explicitly; the hook lets callbacks customize it afterwards
        model.eval()
        self.fabric.call("on_validation_model_eval")

        # ``no_grad`` restores the previous grad mode on every exit path (early break or exception)
        with torch.no_grad():
            self.fabric.call("on_validation_epoch_start")

            iterable = self.progbar_wrapper(val_loader, total=min(len(val_loader), limit_batches), desc="Validation")

            for batch_idx, batch in enumerate(iterable):
                # end epoch if stopping training completely or max batches for this epoch reached
                if self.should_stop or batch_idx >= limit_batches:
                    break

                self.fabric.call("on_validation_batch_start", batch, batch_idx)

                out = model.validation_step(batch, batch_idx)
                # avoid gradients in stored/accumulated values -> prevents potential OOM
                out = apply_to_collection(out, torch.Tensor, lambda x: x.detach())

                self.fabric.call("on_validation_batch_end", out, batch, batch_idx)
                self._current_val_return = out

                self._format_iterable(iterable, self._current_val_return, "val")

            self.fabric.call("on_validation_epoch_end")

        model.train()
        self.fabric.call("on_validation_model_train")

    def training_step(self, model: L.LightningModule, batch: Any, batch_idx: int) -> torch.Tensor:
        """A single training step, running forward and backward. The optimizer step is performed separately by
        :meth:`train_loop`.

        Args:
            model: the lightning module to train
            batch: the batch to run the forward on
            batch_idx: index of the current batch w.r.t the current epoch

        Returns:
            The loss tensor: the output itself if ``training_step`` returned a tensor, else its ``"loss"`` entry.
        """
        outputs: Union[torch.Tensor, Mapping[str, Any]] = model.training_step(batch, batch_idx=batch_idx)

        loss = outputs if isinstance(outputs, torch.Tensor) else outputs["loss"]

        self.fabric.call("on_before_backward", loss)
        self.fabric.backward(loss)
        self.fabric.call("on_after_backward")

        # avoid gradients in stored/accumulated values -> prevents potential OOM
        self._current_train_return = apply_to_collection(outputs, dtype=torch.Tensor, function=lambda x: x.detach())

        return loss

    def step_scheduler(
        self,
        model: L.LightningModule,
        scheduler_cfg: Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]],
        level: Literal["step", "epoch"],
        current_value: int,
    ) -> None:
        """Steps the learning rate scheduler if necessary.

        A monitored metric is only looked up and passed on for plateau schedulers (instances of
        :class:`torch.optim.lr_scheduler.ReduceLROnPlateau`, or configs with ``"reduce_on_plateau": True``);
        all other schedulers are stepped without a metric.

        Args:
            model: The LightningModule to train
            scheduler_cfg: The learning rate scheduler configuration.
                Have a look at :meth:`lightning.pytorch.LightningModule.configure_optimizers` for supported values.
            level: whether we are trying to step on epoch- or step-level
            current_value: Holds the current_epoch if ``level==epoch``, else holds the ``global_step``

        Raises:
            KeyError: If a plateau scheduler monitors a value the latest training/validation step did not return.
        """
        # no scheduler
        if scheduler_cfg is None:
            return

        # wrong interval (step vs. epoch)
        if scheduler_cfg["interval"] != level:
            return

        # right interval, but wrong step wrt frequency
        if current_value % cast(int, scheduler_cfg["frequency"]) != 0:
            return

        scheduler = scheduler_cfg["scheduler"]
        reduce_on_plateau = scheduler_cfg.get(
            "reduce_on_plateau", isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
        )

        monitor = None
        if reduce_on_plateau:
            # assemble potential monitored values, e.g. "train_loss" / "val_<key>"
            possible_monitor_vals = {None: None}
            for prefix, current_return in (("train", self._current_train_return), ("val", self._current_val_return)):
                if isinstance(current_return, torch.Tensor):
                    possible_monitor_vals[f"{prefix}_loss"] = current_return
                elif isinstance(current_return, Mapping):
                    possible_monitor_vals.update({f"{prefix}_{k}": v for k, v in current_return.items()})

            try:
                monitor = possible_monitor_vals[cast(Optional[str], scheduler_cfg["monitor"])]
            except KeyError as e:
                possible_keys = list(possible_monitor_vals.keys())
                raise KeyError(
                    f"monitor {scheduler_cfg['monitor']} is invalid. Possible values are {possible_keys}."
                ) from e

        # rely on model hook for actual step
        model.lr_scheduler_step(scheduler, monitor)

    @property
    def should_validate(self) -> bool:
        """Whether to currently run validation."""
        return self.current_epoch % self.validation_frequency == 0

    def progbar_wrapper(self, iterable: Iterable, total: int, **kwargs: Any):
        """Wraps the iterable with tqdm for global rank zero.

        Args:
            iterable: the iterable to wrap with tqdm
            total: the total length of the iterable, necessary in case the number of batches was limited.
            **kwargs: further keyword arguments forwarded to ``tqdm``.

        Returns:
            A ``tqdm`` wrapper on global rank zero, the unchanged iterable on every other rank.
        """
        if self.fabric.is_global_zero:
            return tqdm(iterable, total=total, **kwargs)
        return iterable

    def load(self, state: Optional[Mapping], path: str) -> None:
        """Loads a checkpoint from a given file into state.

        Args:
            state: a mapping containing model, optimizer and lr scheduler
            path: the path to load the checkpoint from

        Raises:
            RuntimeError: If the checkpoint contains entries that were not consumed.
        """
        if state is None:
            state = {}

        remainder = self.fabric.load(path, state)
        self.global_step = remainder.pop("global_step")
        self.current_epoch = remainder.pop("current_epoch")

        if remainder:
            raise RuntimeError(f"Unused Checkpoint Values: {remainder}")

    def save(self, state: Optional[Mapping]) -> None:
        """Saves a checkpoint to the ``checkpoint_dir`` as ``epoch-<current_epoch:04d>.ckpt``.

        Args:
            state: A mapping containing model, optimizer and lr scheduler. It is not modified.
        """
        # copy instead of updating in place: the caller's mapping stays untouched (and ``Mapping`` has no ``update``)
        checkpoint = dict(state) if state is not None else {}
        checkpoint.update(global_step=self.global_step, current_epoch=self.current_epoch)

        self.fabric.save(os.path.join(self.checkpoint_dir, f"epoch-{self.current_epoch:04d}.ckpt"), checkpoint)

    @staticmethod
    def get_latest_checkpoint(checkpoint_dir: str) -> Optional[str]:
        """Returns the latest checkpoint from the ``checkpoint_dir``.

        Only entries ending in ``.ckpt`` are considered; the latest is the last one in lexicographic order, which
        matches the zero-padded ``epoch-XXXX.ckpt`` names written by :meth:`save`.

        Args:
            checkpoint_dir: the directory to search for checkpoints

        Returns:
            The path of the latest checkpoint, or ``None`` if the directory is missing or has no checkpoints.
        """
        if not os.path.isdir(checkpoint_dir):
            return None

        items = sorted(item for item in os.listdir(checkpoint_dir) if item.endswith(".ckpt"))

        if not items:
            return None

        return os.path.join(checkpoint_dir, items[-1])

    def _parse_optimizers_schedulers(
        self, configure_optim_output
    ) -> Tuple[
        Optional[L.fabric.utilities.types.Optimizable],
        Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]],
    ]:
        """Recursively parses the output of :meth:`lightning.pytorch.LightningModule.configure_optimizers`.

        Supported values: a single optimizer; a single lr scheduler; a single lr scheduler config dict (must contain
        ``"scheduler"``); an optimizer dict ``{"optimizer": ..., "lr_scheduler": ...}``; a list/tuple wrapping one of
        those; or a pair ``(optimizer(s), lr_scheduler(s))`` such as ``[optimizer], [scheduler]``.

        Args:
            configure_optim_output: The output of ``configure_optimizers``.
                For supported values, please refer to :meth:`lightning.pytorch.LightningModule.configure_optimizers`.

        Returns:
            A tuple ``(optimizer, scheduler_config)``; either element is ``None`` if not present.

        Raises:
            NotImplementedError: If more than one optimizer or more than one lr scheduler is returned.
            ValueError: If an lr scheduler config dict has no ``"scheduler"`` key.
        """
        _lr_sched_defaults = {"interval": "epoch", "frequency": 1, "monitor": "val_loss"}
        # ReduceLROnPlateau is listed explicitly because older torch versions don't match the LRScheduler protocol
        scheduler_types = (L.fabric.utilities.types.LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)

        # single optimizer
        if isinstance(configure_optim_output, L.fabric.utilities.types.Optimizable):
            return configure_optim_output, None

        # single lr scheduler
        if isinstance(configure_optim_output, scheduler_types):
            _lr_sched_defaults["scheduler"] = configure_optim_output
            return None, _lr_sched_defaults

        if isinstance(configure_optim_output, Mapping):
            # optimizer dict: {"optimizer": ..., "lr_scheduler": ...}
            if "optimizer" in configure_optim_output:
                optimizer = self._parse_optimizers_schedulers(configure_optim_output["optimizer"])[0]
                lr_scheduler = configure_optim_output.get("lr_scheduler")
                if lr_scheduler is None:
                    return optimizer, None
                return optimizer, self._parse_optimizers_schedulers(lr_scheduler)[1]

            # single lr scheduler config
            if "scheduler" not in configure_optim_output:
                raise ValueError('An lr scheduler config dict must contain the key "scheduler".')
            _lr_sched_defaults.update(configure_optim_output)
            return None, _lr_sched_defaults

        # list or tuple
        if isinstance(configure_optim_output, (list, tuple)):
            if not configure_optim_output:
                return None, None

            # a single optimizer / scheduler / config wrapped in a list or tuple
            if len(configure_optim_output) == 1:
                return self._parse_optimizers_schedulers(configure_optim_output[0])

            if all(
                isinstance(cand, L.fabric.utilities.types.Optimizable)
                or (isinstance(cand, Mapping) and "optimizer" in cand)
                for cand in configure_optim_output
            ):
                raise NotImplementedError("BYOT only supports a single optimizer")

            if all(isinstance(cand, scheduler_types + (Mapping,)) for cand in configure_optim_output):
                raise NotImplementedError("BYOT only supports a single learning rate scheduler")

            # optimizer(s) and lr scheduler(s), e.g. ``return [optimizer], [scheduler]``
            if len(configure_optim_output) == 2:
                optimizer = self._parse_optimizers_schedulers(configure_optim_output[0])[0]
                scheduler_cfg = self._parse_optimizers_schedulers(configure_optim_output[1])[1]
                return optimizer, scheduler_cfg

        return None, None

    @staticmethod
    def _format_iterable(
        prog_bar, candidates: Optional[Union[torch.Tensor, Mapping[str, Union[torch.Tensor, float, int]]]], prefix: str
    ):
        """Adds values as postfix string to progressbar.

        Only scalar values (single-element tensors, ints and floats) are shown; anything else, e.g. a batch of
        predictions returned alongside the loss, is skipped instead of raising.

        Args:
            prog_bar: a progressbar (on global rank zero) or an iterable (every other rank).
            candidates: the values to add as postfix strings to the progressbar.
            prefix: the prefix to add to each of these values.
        """
        if not isinstance(prog_bar, tqdm) or candidates is None:
            return

        if isinstance(candidates, torch.Tensor):
            named_values = {"loss": candidates}
        elif isinstance(candidates, Mapping):
            named_values = candidates
        else:
            return

        postfix_parts = []
        for name, value in named_values.items():
            if isinstance(value, torch.Tensor):
                # ``.item()`` only works on single-element tensors
                if value.numel() != 1:
                    continue
                value = value.item()
            if isinstance(value, (int, float)):
                postfix_parts.append(f" {prefix}_{name}: {value:.3f}")

        if postfix_parts:
            prog_bar.set_postfix_str("".join(postfix_parts))