Fix docs for 'nn.Module from checkpoint' (#19053)

This commit is contained in:
Adrian Wälchli 2023-11-23 04:23:24 +01:00 committed by GitHub
parent 9a26da8081
commit 85adf17328
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 11 additions and 3 deletions

View File

@ -58,8 +58,10 @@ To change the checkpoint path use the `default_root_dir` argument:
# saves checkpoints to 'some/path/' at every epoch end
trainer = Trainer(default_root_dir="some/path/")
----
*******************************
LightningModule from checkpoint
*******************************
@ -136,8 +138,10 @@ In some cases, we may also pass entire PyTorch modules to the ``__init__`` metho
model = LitAutoEncoder.load_from_checkpoint(PATH, encoder=encoder, decoder=decoder)
----
*************************
nn.Module from checkpoint
*************************
@ -162,7 +166,9 @@ For example, let's pretend we created a LightningModule like so:
class Autoencoder(L.LightningModule):
def __init__(self, encoder, decoder, *args, **kwargs):
...
super().__init__()
self.encoder = encoder
self.decoder = decoder
autoencoder = Autoencoder(Encoder(), Decoder())
@ -172,11 +178,13 @@ Once the autoencoder has trained, pull out the relevant weights for your torch n
.. code-block:: python
checkpoint = torch.load(CKPT_PATH)
encoder_weights = checkpoint["encoder"]
decoder_weights = checkpoint["decoder"]
encoder_weights = {k: v for k, v in checkpoint["state_dict"].items() if k.startswith("encoder.")}
decoder_weights = {k: v for k, v in checkpoint["state_dict"].items() if k.startswith("decoder.")}
----
*********************
Disable checkpointing
*********************