diff --git a/examples/fabric/build_your_own_trainer/trainer.py b/examples/fabric/build_your_own_trainer/trainer.py
index 4b93ea171d..d30eacfaf5 100644
--- a/examples/fabric/build_your_own_trainer/trainer.py
+++ b/examples/fabric/build_your_own_trainer/trainer.py
@@ -124,6 +124,7 @@ class MyCustomTrainer:
         model: L.LightningModule,
         train_loader: torch.utils.data.DataLoader,
         val_loader: torch.utils.data.DataLoader,
+        ckpt_path: Optional[str] = None,
     ):
         """The main entrypoint of the trainer, triggering the actual training.
 
@@ -133,6 +134,8 @@ class MyCustomTrainer:
             train_loader: the training dataloader. Has to be an iterable returning batches.
             val_loader: the validation dataloader. Has to be an iterable returning batches.
                 If not specified, no validation will run.
+            ckpt_path: Path to previous checkpoints to resume training from.
+                If specified, will always look for the latest checkpoint within the given directory.
         """
         self.fabric.launch()
 
@@ -155,13 +158,14 @@ class MyCustomTrainer:
         state = {"model": model, "optim": optimizer, "scheduler": scheduler_cfg}
 
         # load last checkpoint if available
-        latest_checkpoint_path = self.get_latest_checkpoint(self.checkpoint_dir)
-        if latest_checkpoint_path is not None:
-            self.load(state, latest_checkpoint_path)
+        if ckpt_path is not None and os.path.isdir(ckpt_path):
+            latest_checkpoint_path = self.get_latest_checkpoint(self.checkpoint_dir)
+            if latest_checkpoint_path is not None:
+                self.load(state, latest_checkpoint_path)
 
-            # check if we even need to train here
-            if self.max_epochs is not None and self.current_epoch >= self.max_epochs:
-                self.should_stop = True
+                # check if we even need to train here
+                if self.max_epochs is not None and self.current_epoch >= self.max_epochs:
+                    self.should_stop = True
 
         while not self.should_stop:
             self.train_loop(
@@ -181,6 +185,9 @@ class MyCustomTrainer:
 
         self.save(state)
 
+        # reset for next fit call
+        self.should_stop = False
+
     def train_loop(
         self,
         model: L.LightningModule,
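
For context, a minimal usage sketch of the updated `fit` signature follows. It is not part of the diff: the model and dataloader factories are hypothetical stand-ins, and the `checkpoint_dir` keyword is assumed from the surrounding example. Note that the new code gates loading on `ckpt_path` being an existing directory but still calls `self.get_latest_checkpoint(self.checkpoint_dir)` internally, so the two are expected to point at the same location.

```python
# Illustrative sketch only. MyLightningModule, make_train_loader and
# make_val_loader are hypothetical placeholders; checkpoint_dir is an
# assumed constructor argument of MyCustomTrainer from this example.
from trainer import MyCustomTrainer  # the class this diff modifies

model = MyLightningModule()          # hypothetical L.LightningModule subclass
train_loader = make_train_loader()   # hypothetical torch DataLoader
val_loader = make_val_loader()       # hypothetical torch DataLoader

trainer = MyCustomTrainer(checkpoint_dir="./checkpoints")

# First call: ckpt_path is None, so nothing is loaded and training starts
# from scratch; checkpoints get written into checkpoint_dir as before.
trainer.fit(model, train_loader, val_loader)

# Second call: ckpt_path is an existing directory, so fit() restores the
# latest checkpoint before training. Internally the lookup still uses
# self.checkpoint_dir, so ckpt_path and checkpoint_dir should match.
trainer.fit(model, train_loader, val_loader, ckpt_path="./checkpoints")
```

The trailing `self.should_stop = False` reset in the diff is what makes a second `fit` call like the one above run at all: without it, a completed first run would leave the flag set and the next call would exit the training loop immediately.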