Minor BYOT Follow-up (#17076)

This commit is contained in:
Justus Schock 2023-03-14 12:54:15 +01:00 committed by GitHub
parent f1f8050e0e
commit 0154c6cd6f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 13 additions and 6 deletions

View File

@@ -124,6 +124,7 @@ class MyCustomTrainer:
model: L.LightningModule,
train_loader: torch.utils.data.DataLoader,
val_loader: torch.utils.data.DataLoader,
ckpt_path: Optional[str] = None,
):
"""The main entrypoint of the trainer, triggering the actual training.
@@ -133,6 +134,8 @@ class MyCustomTrainer:
train_loader: the training dataloader. Has to be an iterable returning batches.
val_loader: the validation dataloader. Has to be an iterable returning batches.
If not specified, no validation will run.
ckpt_path: Path to a directory containing previous checkpoints to resume training from.
                If specified, training will resume from the latest checkpoint found within that directory.
"""
self.fabric.launch()
@@ -155,13 +158,14 @@ class MyCustomTrainer:
state = {"model": model, "optim": optimizer, "scheduler": scheduler_cfg}
# load last checkpoint if available
latest_checkpoint_path = self.get_latest_checkpoint(self.checkpoint_dir)
if latest_checkpoint_path is not None:
self.load(state, latest_checkpoint_path)
if ckpt_path is not None and os.path.isdir(ckpt_path):
latest_checkpoint_path = self.get_latest_checkpoint(self.checkpoint_dir)
if latest_checkpoint_path is not None:
self.load(state, latest_checkpoint_path)
# check if we even need to train here
if self.max_epochs is not None and self.current_epoch >= self.max_epochs:
self.should_stop = True
# check if we even need to train here
if self.max_epochs is not None and self.current_epoch >= self.max_epochs:
self.should_stop = True
while not self.should_stop:
self.train_loop(
@@ -181,6 +185,9 @@ class MyCustomTrainer:
self.save(state)
# reset for next fit call
self.should_stop = False
def train_loop(
self,
model: L.LightningModule,