Fix docs for 'nn.Module from checkpoint' (#19053)
parent 9a26da8081
commit 85adf17328
@@ -58,8 +58,10 @@ To change the checkpoint path use the `default_root_dir` argument:

    # saves checkpoints to 'some/path/' at every epoch end
    trainer = Trainer(default_root_dir="some/path/")

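As a supplement to that snippet, checkpoint placement can also be steered explicitly through the ``ModelCheckpoint`` callback. The sketch below is illustrative only and not part of this commit; the ``val_loss`` metric name is an assumption.

.. code-block:: python

    import lightning as L
    from lightning.pytorch.callbacks import ModelCheckpoint

    # Illustrative sketch: write checkpoints to an explicit directory via the
    # ModelCheckpoint callback instead of relying only on default_root_dir.
    # "val_loss" is an assumed metric name logged by the LightningModule.
    checkpoint_callback = ModelCheckpoint(dirpath="some/path/", monitor="val_loss", save_top_k=1)
    trainer = L.Trainer(default_root_dir="some/path/", callbacks=[checkpoint_callback])
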
----

*******************************
LightningModule from checkpoint
*******************************

@@ -136,8 +138,10 @@ In some cases, we may also pass entire PyTorch modules to the ``__init__`` method

    model = LitAutoEncoder.load_from_checkpoint(PATH, encoder=encoder, decoder=decoder)

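For context on that call, here is a minimal sketch of what a ``LitAutoEncoder`` accepting such modules might look like; the class body is an assumption for illustration, not the exact code from the docs. Module arguments excluded from the saved hyperparameters have to be passed back in when calling ``load_from_checkpoint``.

.. code-block:: python

    import lightning as L
    import torch.nn as nn


    class LitAutoEncoder(L.LightningModule):
        # Sketch only: encoder and decoder arrive as ready-built nn.Modules, so they
        # are excluded from save_hyperparameters and must be supplied again at
        # load_from_checkpoint time.
        def __init__(self, encoder: nn.Module, decoder: nn.Module):
            super().__init__()
            self.save_hyperparameters(ignore=["encoder", "decoder"])
            self.encoder = encoder
            self.decoder = decoder
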
----

*************************
nn.Module from checkpoint
*************************
|
@@ -162,7 +166,9 @@ For example, let's pretend we created a LightningModule like so:

    class Autoencoder(L.LightningModule):
        def __init__(self, encoder, decoder, *args, **kwargs):
-           ...
+           super().__init__()
+           self.encoder = encoder
+           self.decoder = decoder


    autoencoder = Autoencoder(Encoder(), Decoder())
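
``Encoder`` and ``Decoder`` above are assumed to be ordinary ``torch.nn.Module`` subclasses; a minimal sketch of plausible stand-ins (layer sizes are arbitrary assumptions):

.. code-block:: python

    import torch.nn as nn


    # Illustrative stand-ins for the Encoder/Decoder used above; the layer sizes
    # are arbitrary and only serve to make the example runnable.
    class Encoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))

        def forward(self, x):
            return self.net(x)


    class Decoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

        def forward(self, x):
            return self.net(x)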
|
@@ -172,11 +178,13 @@ Once the autoencoder has trained, pull out the relevant weights for your torch nn.Module:

.. code-block:: python

    checkpoint = torch.load(CKPT_PATH)
-   encoder_weights = checkpoint["encoder"]
-   decoder_weights = checkpoint["decoder"]
+   encoder_weights = {k: v for k, v in checkpoint["state_dict"].items() if k.startswith("encoder.")}
+   decoder_weights = {k: v for k, v in checkpoint["state_dict"].items() if k.startswith("decoder.")}

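To finish the round trip, the filtered dictionaries can be loaded into standalone modules once the ``encoder.``/``decoder.`` prefixes are stripped. This follow-up is a sketch added for illustration; it assumes ``Encoder``/``Decoder`` classes like the stand-ins above and the ``encoder_weights``/``decoder_weights`` from the snippet.

.. code-block:: python

    # Sketch: strip the submodule prefixes so the keys match the standalone
    # modules' own parameter names, then load them (str.removeprefix needs Python 3.9+).
    encoder = Encoder()
    decoder = Decoder()
    encoder.load_state_dict({k.removeprefix("encoder."): v for k, v in encoder_weights.items()})
    decoder.load_state_dict({k.removeprefix("decoder."): v for k, v in decoder_weights.items()})
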
----

*********************
Disable checkpointing
*********************
|