import os
from collections.abc import Mapping
from functools import partial
from typing import Any, Iterable, List, Literal, Optional, Tuple, Union, cast

import lightning as L
import torch
from lightning.fabric.accelerators import Accelerator
from lightning.fabric.loggers import Logger
from lightning.fabric.strategies import Strategy
from lightning.fabric.wrappers import _unwrap_objects
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning_utilities import apply_to_collection
from tqdm import tqdm


class MyCustomTrainer:
    def __init__(
        self,
        accelerator: Union[str, Accelerator] = "auto",
        strategy: Union[str, Strategy] = "auto",
        devices: Union[List[int], str, int] = "auto",
        precision: Union[str, int] = "32-true",
        plugins: Optional[Union[str, Any]] = None,
        callbacks: Optional[Union[List[Any], Any]] = None,
        loggers: Optional[Union[Logger, List[Logger]]] = None,
        max_epochs: Optional[int] = 1000,
        max_steps: Optional[int] = None,
        grad_accum_steps: int = 1,
        limit_train_batches: Union[int, float] = float("inf"),
        limit_val_batches: Union[int, float] = float("inf"),
        validation_frequency: int = 1,
        use_distributed_sampler: bool = True,
        checkpoint_dir: str = "./checkpoints",
        checkpoint_frequency: int = 1,
    ) -> None:
        """Exemplary Trainer with Fabric. This is a very simple trainer focused on readability but with reduced
        featureset. As a trainer with more included features, we recommend using the
        :class:`lightning.pytorch.Trainer`.

        Args:
            accelerator: The hardware to run on. Possible choices are:
                ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
            strategy: Strategy for how to run across multiple devices. Possible choices are:
                ``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"fsdp"``.
            devices: Number of devices to train on (``int``),
                which GPUs to train on (``list`` or ``str``), or ``"auto"``.
                The value applies per node.
            precision: Double precision (``"64"``), full precision (``"32"``), half precision AMP (``"16-mixed"``),
                or bfloat16 precision AMP (``"bf16-mixed"``).
            plugins: One or several custom plugins
            callbacks: A single callback or a list of callbacks. The following hooks are supported:
                - on_train_epoch_start
                - on_train_epoch_end
                - on_train_batch_start
                - on_train_batch_end
                - on_before_backward
                - on_after_backward
                - on_before_zero_grad
                - on_before_optimizer_step
                - on_validation_model_eval
                - on_validation_model_train
                - on_validation_epoch_start
                - on_validation_epoch_end
                - on_validation_batch_start
                - on_validation_batch_end

            loggers: A single logger or a list of loggers. See :meth:`~lightning.fabric.fabric.Fabric.log` for more
                information.

            max_epochs: The maximum number of epochs to train
            max_steps: The maximum number of (optimizer) steps to train
            grad_accum_steps: How many batches to process before each optimizer step
            limit_train_batches: Limits the number of train batches per epoch.
                If greater than number of batches in the dataloader, this has no effect.
            limit_val_batches: Limits the number of validation batches per epoch.
                If greater than number of batches in the dataloader, this has no effect.
            validation_frequency: How many epochs to run before each validation epoch.
            use_distributed_sampler: Wraps the sampler of each dataloader with a respective distributed-aware sampler
                in case of distributed training.
            checkpoint_dir: Directory to store checkpoints to.
            checkpoint_frequency: How many epochs to run before each checkpoint is written.

        Warning:
            callbacks written for the lightning trainer (especially making assumptions on the trainer), won't work!

        """
        self.fabric = L.Fabric(
            accelerator=accelerator,
            strategy=strategy,
            devices=devices,
            precision=precision,
            plugins=plugins,
            callbacks=callbacks,
            loggers=loggers,
        )
        self.global_step = 0
        self.grad_accum_steps: int = grad_accum_steps
        self.current_epoch = 0

        self.max_epochs = max_epochs
        self.max_steps = max_steps
        self.should_stop = False

        # ensures limit_X_batches is either int or inf
        if not isinstance(limit_train_batches, int):
            assert limit_train_batches == float("inf")

        if not isinstance(limit_val_batches, int):
            assert limit_val_batches == float("inf")

        self.limit_train_batches = limit_train_batches
        self.limit_val_batches = limit_val_batches
        self.validation_frequency = validation_frequency
        self.use_distributed_sampler = use_distributed_sampler
        # last (detached) outputs of training_step / validation_step; used for the progress bar
        # and as candidate values for monitored lr schedulers
        self._current_train_return: Union[torch.Tensor, Mapping[str, Any]] = {}
        self._current_val_return: Optional[Union[torch.Tensor, Mapping[str, Any]]] = {}

        self.checkpoint_dir = checkpoint_dir
        self.checkpoint_frequency = checkpoint_frequency

    def fit(
        self,
        model: L.LightningModule,
        train_loader: torch.utils.data.DataLoader,
        val_loader: Optional[torch.utils.data.DataLoader] = None,
        ckpt_path: Optional[str] = None,
    ):
        """The main entrypoint of the trainer, triggering the actual training.

        Args:
            model: the LightningModule to train.
                Can have the same hooks as :attr:`callbacks` (see :meth:`MyCustomTrainer.__init__`).
            train_loader: the training dataloader. Has to be an iterable returning batches.
            val_loader: the validation dataloader. Has to be an iterable returning batches.
                If not specified, no validation will run.
            ckpt_path: Path to previous checkpoints to resume training from.
                If specified, will always look for the latest checkpoint within the given directory.

        Raises:
            NotImplementedError: If the configured strategy is FSDP, or if ``configure_optimizers``
                returns more than one optimizer.

        """
        self.fabric.launch()

        # setup dataloaders
        train_loader = self.fabric.setup_dataloaders(train_loader, use_distributed_sampler=self.use_distributed_sampler)
        if val_loader is not None:
            val_loader = self.fabric.setup_dataloaders(val_loader, use_distributed_sampler=self.use_distributed_sampler)

        # setup model and optimizer
        if isinstance(self.fabric.strategy, L.fabric.strategies.fsdp.FSDPStrategy):
            # currently, there is no way to support fsdp with model.configure_optimizers in fabric
            # as it would require fabric to hold a reference to the model, which we don't want to.
            raise NotImplementedError("BYOT currently does not support FSDP")

        optimizer, scheduler_cfg = self._parse_optimizers_schedulers(model.configure_optimizers())
        assert optimizer is not None
        model, optimizer = self.fabric.setup(model, optimizer)

        # assemble state (current epoch and global step will be added in save)
        state = {"model": model, "optim": optimizer, "scheduler": scheduler_cfg}

        # load last checkpoint if available; the lookup happens in the directory the caller passed,
        # which may differ from the directory new checkpoints are written to
        if ckpt_path is not None and os.path.isdir(ckpt_path):
            latest_checkpoint_path = self.get_latest_checkpoint(ckpt_path)
            if latest_checkpoint_path is not None:
                self.load(state, latest_checkpoint_path)

        # check if we even need to train here (e.g. resumed from a finished run or repeated fit call)
        if self.max_epochs is not None and self.current_epoch >= self.max_epochs:
            self.should_stop = True

        while not self.should_stop:
            self.train_loop(
                model, optimizer, train_loader, limit_batches=self.limit_train_batches, scheduler_cfg=scheduler_cfg
            )

            if self.should_validate:
                self.val_loop(model, val_loader, limit_batches=self.limit_val_batches)

            self.step_scheduler(model, scheduler_cfg, level="epoch", current_value=self.current_epoch)

            self.current_epoch += 1

            # stopping condition on epoch level
            if self.max_epochs is not None and self.current_epoch >= self.max_epochs:
                self.should_stop = True

            # write a checkpoint every ``checkpoint_frequency`` epochs and always after the last epoch,
            # so the final state is never lost
            if self.should_stop or self.current_epoch % self.checkpoint_frequency == 0:
                self.save(state)

        # reset for next fit call
        self.should_stop = False

    def train_loop(
        self,
        model: L.LightningModule,
        optimizer: torch.optim.Optimizer,
        train_loader: torch.utils.data.DataLoader,
        limit_batches: Union[int, float] = float("inf"),
        scheduler_cfg: Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]] = None,
    ):
        """The training loop running a single training epoch.

        Args:
            model: the LightningModule to train
            optimizer: the optimizer, optimizing the LightningModule.
            train_loader: The dataloader yielding the training batches.
            limit_batches: Limits the batches during this training epoch.
                If greater than the number of batches in the ``train_loader``, this has no effect.
            scheduler_cfg: The learning rate scheduler configuration.
                Have a look at :meth:`~lightning.pytorch.core.LightningModule.configure_optimizers`
                for supported values.

        """
        self.fabric.call("on_train_epoch_start")
        iterable = self.progbar_wrapper(
            train_loader, total=min(len(train_loader), limit_batches), desc=f"Epoch {self.current_epoch}"
        )

        for batch_idx, batch in enumerate(iterable):
            # end epoch if stopping training completely or max batches for this epoch reached
            if self.should_stop or batch_idx >= limit_batches:
                break

            self.fabric.call("on_train_batch_start", batch, batch_idx)

            # The accumulation boundary is derived from the batch index: ``global_step`` only advances
            # on optimizer steps, so using it here would stall after the very first step whenever
            # ``grad_accum_steps > 1``.
            should_optim_step = (batch_idx + 1) % self.grad_accum_steps == 0
            if should_optim_step:
                # currently only supports a single optimizer
                self.fabric.call("on_before_optimizer_step", optimizer)

                # optimizer step runs train step internally through closure
                optimizer.step(partial(self.training_step, model=model, batch=batch, batch_idx=batch_idx))
                self.fabric.call("on_before_zero_grad", optimizer)

                optimizer.zero_grad()

            else:
                # gradient accumulation -> no optimizer step
                self.training_step(model=model, batch=batch, batch_idx=batch_idx)

            self.fabric.call("on_train_batch_end", self._current_train_return, batch, batch_idx)

            # this guard ensures, we only step the scheduler once per global step
            if should_optim_step:
                self.step_scheduler(model, scheduler_cfg, level="step", current_value=self.global_step)

            # add output values to progress bar
            self._format_iterable(iterable, self._current_train_return, "train")

            # only increase global step if optimizer stepped
            self.global_step += int(should_optim_step)

            # stopping criterion on step level
            if self.max_steps is not None and self.global_step >= self.max_steps:
                self.should_stop = True
                break

        self.fabric.call("on_train_epoch_end")

    def val_loop(
        self,
        model: L.LightningModule,
        val_loader: Optional[torch.utils.data.DataLoader],
        limit_batches: Union[int, float] = float("inf"),
    ):
        """The validation loop running a single validation epoch.

        Args:
            model: the LightningModule to evaluate
            val_loader: The dataloader yielding the validation batches.
            limit_batches: Limits the batches during this validation epoch.
                If greater than the number of batches in the ``val_loader``, this has no effect.

        """
        # no validation if val_loader wasn't passed
        if val_loader is None:
            return

        # no validation but warning if val_loader was passed, but validation_step not implemented
        if not is_overridden("validation_step", _unwrap_objects(model)):
            L.fabric.utilities.rank_zero_warn(
                "Your LightningModule does not have a validation_step implemented, "
                "but you passed a validation dataloder. Skipping Validation."
            )
            return

        if not is_overridden("on_validation_model_eval", _unwrap_objects(model)):
            model.eval()
        else:
            self.fabric.call("on_validation_model_eval")  # calls `model.eval()`

        torch.set_grad_enabled(False)
        try:
            self.fabric.call("on_validation_epoch_start")

            iterable = self.progbar_wrapper(val_loader, total=min(len(val_loader), limit_batches), desc="Validation")

            for batch_idx, batch in enumerate(iterable):
                # end epoch if stopping training completely or max batches for this epoch reached
                if self.should_stop or batch_idx >= limit_batches:
                    break

                self.fabric.call("on_validation_batch_start", batch, batch_idx)

                out = model.validation_step(batch, batch_idx)
                # avoid gradients in stored/accumulated values -> prevents potential OOM
                out = apply_to_collection(out, torch.Tensor, lambda x: x.detach())

                self.fabric.call("on_validation_batch_end", out, batch, batch_idx)
                self._current_val_return = out

                self._format_iterable(iterable, self._current_val_return, "val")

            self.fabric.call("on_validation_epoch_end")

            if not is_overridden("on_validation_model_train", _unwrap_objects(model)):
                model.train()
            else:
                self.fabric.call("on_validation_model_train")
        finally:
            # grad mode is global state: re-enable it even if a validation step raised
            torch.set_grad_enabled(True)

    def training_step(self, model: L.LightningModule, batch: Any, batch_idx: int) -> torch.Tensor:
        """A single training step, running forward and backward. The optimizer step is called separately, as this is
        given as a closure to the optimizer step.

        Args:
            model: the lightning module to train
            batch: the batch to run the forward on
            batch_idx: index of the current batch w.r.t the current epoch

        Returns:
            The loss tensor, i.e. the step output itself or its ``"loss"`` entry.

        """
        outputs: Union[torch.Tensor, Mapping[str, Any]] = model.training_step(batch, batch_idx=batch_idx)

        loss = outputs if isinstance(outputs, torch.Tensor) else outputs["loss"]

        self.fabric.call("on_before_backward", loss)
        self.fabric.backward(loss)
        self.fabric.call("on_after_backward")

        # avoid gradients in stored/accumulated values -> prevents potential OOM
        self._current_train_return = apply_to_collection(outputs, dtype=torch.Tensor, function=lambda x: x.detach())

        return loss

    def step_scheduler(
        self,
        model: L.LightningModule,
        scheduler_cfg: Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]],
        level: Literal["step", "epoch"],
        current_value: int,
    ) -> None:
        """Steps the learning rate scheduler if necessary.

        Args:
            model: The LightningModule to train
            scheduler_cfg: The learning rate scheduler configuration.
                Have a look at :meth:`lightning.pytorch.LightningModule.configure_optimizers` for supported values.
            level: whether we are trying to step on epoch- or step-level
            current_value: Holds the current_epoch if ``level==epoch``, else holds the ``global_step``

        Raises:
            KeyError: If the configured ``monitor`` is not among the values returned by the last
                training / validation step.

        """
        # no scheduler
        if scheduler_cfg is None:
            return

        # wrong interval (step vs. epoch)
        if scheduler_cfg["interval"] != level:
            return

        # right interval, but wrong step wrt frequency
        if current_value % cast(int, scheduler_cfg["frequency"]) != 0:
            return

        # assemble potential monitored values; ``None`` maps to ``None`` so that schedulers
        # without a monitor resolve to "no monitored value"
        possible_monitor_vals = {None: None}
        if isinstance(self._current_train_return, torch.Tensor):
            possible_monitor_vals["train_loss"] = self._current_train_return
        elif isinstance(self._current_train_return, Mapping):
            possible_monitor_vals.update({"train_" + k: v for k, v in self._current_train_return.items()})

        if isinstance(self._current_val_return, torch.Tensor):
            possible_monitor_vals["val_loss"] = self._current_val_return
        elif isinstance(self._current_val_return, Mapping):
            possible_monitor_vals.update({"val_" + k: v for k, v in self._current_val_return.items()})

        try:
            monitor = possible_monitor_vals[cast(Optional[str], scheduler_cfg["monitor"])]
        except KeyError as ex:
            possible_keys = list(possible_monitor_vals.keys())
            raise KeyError(
                f"monitor {scheduler_cfg['monitor']} is invalid. Possible values are {possible_keys}."
            ) from ex

        # rely on model hook for actual step
        model.lr_scheduler_step(scheduler_cfg["scheduler"], monitor)

    @property
    def should_validate(self) -> bool:
        """Whether to currently run validation."""
        return self.current_epoch % self.validation_frequency == 0

    def progbar_wrapper(self, iterable: Iterable, total: int, **kwargs: Any):
        """Wraps the iterable with tqdm for global rank zero.

        Args:
            iterable: the iterable to wrap with tqdm
            total: the total length of the iterable, necessary in case the number of batches was limited.
            **kwargs: forwarded to ``tqdm``.

        Returns:
            A ``tqdm`` progress bar on global rank zero, the unchanged ``iterable`` on every other rank.

        """
        if self.fabric.is_global_zero:
            return tqdm(iterable, total=total, **kwargs)
        return iterable

    def load(self, state: Optional[Mapping], path: str) -> None:
        """Loads a checkpoint from a given file into state.

        Args:
            state: a mapping containing model, optimizer and lr scheduler
            path: the path to load the checkpoint from

        Raises:
            RuntimeError: If the checkpoint holds entries besides the given state,
                ``global_step`` and ``current_epoch``.

        """
        if state is None:
            state = {}

        remainder = self.fabric.load(path, state)
        self.global_step = remainder.pop("global_step")
        self.current_epoch = remainder.pop("current_epoch")

        if remainder:
            raise RuntimeError(f"Unused Checkpoint Values: {remainder}")

    def save(self, state: Optional[Mapping]) -> None:
        """Saves a checkpoint to the ``checkpoint_dir``

        Args:
            state: A mapping containing model, optimizer and lr scheduler.

        """
        if state is None:
            state = {}

        state.update(global_step=self.global_step, current_epoch=self.current_epoch)

        # zero-padded epoch number keeps lexicographic order equal to chronological order,
        # which ``get_latest_checkpoint`` relies on
        self.fabric.save(os.path.join(self.checkpoint_dir, f"epoch-{self.current_epoch:04d}.ckpt"), state)

    @staticmethod
    def get_latest_checkpoint(checkpoint_dir: str) -> Optional[str]:
        """Returns the latest checkpoint from the ``checkpoint_dir``

        Args:
            checkpoint_dir: the directory to search for checkpoints

        Returns:
            The path of the lexicographically last entry in ``checkpoint_dir`` or ``None`` if the
            directory does not exist or is empty.

        """
        if not os.path.isdir(checkpoint_dir):
            return None

        items = sorted(os.listdir(checkpoint_dir))

        if not items:
            return None

        return os.path.join(checkpoint_dir, items[-1])

    def _parse_optimizers_schedulers(
        self, configure_optim_output
    ) -> Tuple[
        Optional[L.fabric.utilities.types.Optimizable],
        Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]],
    ]:
        """Recursively parses the output of :meth:`lightning.pytorch.LightningModule.configure_optimizers`.

        Args:
            configure_optim_output: The output of ``configure_optimizers``.
                For supported values, please refer to :meth:`lightning.pytorch.LightningModule.configure_optimizers`.

        Returns:
            A tuple ``(optimizer, scheduler_cfg)`` where either entry is ``None`` if it could not be
            derived from the given output.

        Raises:
            NotImplementedError: If a list/tuple consisting only of optimizers does not hold exactly one.

        """
        _lr_sched_defaults = {"interval": "epoch", "frequency": 1, "monitor": "val_loss"}

        # single optimizer
        if isinstance(configure_optim_output, L.fabric.utilities.types.Optimizable):
            return configure_optim_output, None

        # single lr scheduler
        if isinstance(configure_optim_output, L.fabric.utilities.types.LRScheduler):
            # ``dict.update`` returns None, so fill the config first and return the dict itself
            _lr_sched_defaults["scheduler"] = configure_optim_output
            return None, _lr_sched_defaults

        # single lr scheduler config
        if isinstance(configure_optim_output, Mapping):
            _lr_sched_defaults.update(configure_optim_output)
            return None, _lr_sched_defaults

        # list or tuple
        if isinstance(configure_optim_output, (list, tuple)):
            if all(isinstance(_opt_cand, L.fabric.utilities.types.Optimizable) for _opt_cand in configure_optim_output):
                # single optimizer in list
                if len(configure_optim_output) == 1:
                    return configure_optim_output[0], None

                raise NotImplementedError("BYOT only supports a single optimizer")

            if all(
                isinstance(_lr_cand, (L.fabric.utilities.types.LRScheduler, Mapping))
                for _lr_cand in configure_optim_output
            ):
                # single scheduler in list
                if len(configure_optim_output) == 1:
                    return None, self._parse_optimizers_schedulers(configure_optim_output[0])[1]

            # optimizer and lr scheduler
            elif len(configure_optim_output) == 2:
                opt_cands, lr_cands = (
                    self._parse_optimizers_schedulers(configure_optim_output[0])[0],
                    self._parse_optimizers_schedulers(configure_optim_output[1])[1],
                )
                return opt_cands, lr_cands

        return None, None

    @staticmethod
    def _format_iterable(
        prog_bar, candidates: Optional[Union[torch.Tensor, Mapping[str, Union[torch.Tensor, float, int]]]], prefix: str
    ):
        """Adds values as postfix string to progressbar.

        Args:
            prog_bar: a progressbar (on global rank zero) or an iterable (every other rank).
            candidates: the values to add as postfix strings to the progressbar.
            prefix: the prefix to add to each of these values.

        """
        if isinstance(prog_bar, tqdm) and candidates is not None:
            postfix_str = ""
            float_candidates = apply_to_collection(candidates, torch.Tensor, lambda x: x.item())
            if isinstance(candidates, torch.Tensor):
                postfix_str += f" {prefix}_loss: {float_candidates:.3f}"
            elif isinstance(candidates, Mapping):
                for k, v in float_candidates.items():
                    postfix_str += f" {prefix}_{k}: {v:.3f}"

            if postfix_str:
                prog_bar.set_postfix_str(postfix_str)