Minor BYOT Follow-up (#17076)
This commit is contained in:
parent
f1f8050e0e
commit
0154c6cd6f
|
@ -124,6 +124,7 @@ class MyCustomTrainer:
|
|||
model: L.LightningModule,
|
||||
train_loader: torch.utils.data.DataLoader,
|
||||
val_loader: torch.utils.data.DataLoader,
|
||||
ckpt_path: Optional[str] = None,
|
||||
):
|
||||
"""The main entrypoint of the trainer, triggering the actual training.
|
||||
|
||||
|
@ -133,6 +134,8 @@ class MyCustomTrainer:
|
|||
train_loader: the training dataloader. Has to be an iterable returning batches.
|
||||
val_loader: the validation dataloader. Has to be an iterable returning batches.
|
||||
If not specified, no validation will run.
|
||||
ckpt_path: Path to previous checkpoints to resume training from.
|
||||
If specified, will always look for the latest checkpoint within the given directory.
|
||||
"""
|
||||
self.fabric.launch()
|
||||
|
||||
|
@ -155,13 +158,14 @@ class MyCustomTrainer:
|
|||
state = {"model": model, "optim": optimizer, "scheduler": scheduler_cfg}
|
||||
|
||||
# load last checkpoint if available
|
||||
latest_checkpoint_path = self.get_latest_checkpoint(self.checkpoint_dir)
|
||||
if latest_checkpoint_path is not None:
|
||||
self.load(state, latest_checkpoint_path)
|
||||
if ckpt_path is not None and os.path.isdir(ckpt_path):
|
||||
latest_checkpoint_path = self.get_latest_checkpoint(self.checkpoint_dir)
|
||||
if latest_checkpoint_path is not None:
|
||||
self.load(state, latest_checkpoint_path)
|
||||
|
||||
# check if we even need to train here
|
||||
if self.max_epochs is not None and self.current_epoch >= self.max_epochs:
|
||||
self.should_stop = True
|
||||
# check if we even need to train here
|
||||
if self.max_epochs is not None and self.current_epoch >= self.max_epochs:
|
||||
self.should_stop = True
|
||||
|
||||
while not self.should_stop:
|
||||
self.train_loop(
|
||||
|
@ -181,6 +185,9 @@ class MyCustomTrainer:
|
|||
|
||||
self.save(state)
|
||||
|
||||
# reset for next fit call
|
||||
self.should_stop = False
|
||||
|
||||
def train_loop(
|
||||
self,
|
||||
model: L.LightningModule,
|
||||
|
|
Loading…
Reference in New Issue