A LightningModule is a subclass of torch.nn.Module, so it supports the full nn.Module API. In addition, it offers the following methods.
freeze
Freeze all params for inference
model = MyLightningModule(...)
model.freeze()
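For intuition, here is a plain-PyTorch sketch of what freeze is assumed to do: turn off requires_grad on every parameter and switch the module to eval mode, so dropout and batch norm use their inference behaviour. The helper name freeze and the toy nn.Sequential model are illustrative stand-ins, not Lightning's own source.

```python
import torch
from torch import nn


def freeze(module: nn.Module) -> None:
    """Plain-PyTorch sketch of what LightningModule.freeze() is assumed to do."""
    # Stop autograd from tracking gradients for every parameter.
    for param in module.parameters():
        param.requires_grad = False
    # Switch layers such as dropout and batch norm to inference behaviour.
    module.eval()


# Toy model standing in for MyLightningModule (an assumption for illustration).
model = nn.Sequential(nn.Linear(4, 8), nn.Dropout(0.5), nn.Linear(8, 2))
freeze(model)

with torch.no_grad():
    y_hat = model(torch.randn(3, 4))

print(model.training)                                    # False
print(any(p.requires_grad for p in model.parameters()))  # False
print(y_hat.shape)                                       # torch.Size([3, 2])
```

Because the module is in eval mode, the dropout layer is a no-op here, so repeated calls on the same input give the same output.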
load_from_checkpoint
This is the easiest and fastest way to load hyperparameters and weights from a checkpoint, such as one saved by the ModelCheckpoint callback.
pretrained_model = MyLightningModule.load_from_checkpoint(
checkpoint_path='/path/to/pytorch_checkpoint.ckpt'
)
# predict
pretrained_model.eval()
pretrained_model.freeze()
y_hat = pretrained_model(x)
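The idea behind a single-call restore can be sketched in plain PyTorch: store the constructor's hyperparameters next to the weights, rebuild the model from the hyperparameters, then load the weights into it. The TinyModel class and the "hparams" and "state_dict" key names are assumptions chosen for illustration; this is not the exact file layout Lightning writes, which has varied between versions.

```python
import io

import torch
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical model whose constructor takes one hyperparameter."""

    def __init__(self, hidden_dim: int):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.layer = nn.Linear(4, hidden_dim)

    def forward(self, x):
        return self.layer(x)


# Save hyperparameters and weights together (key names are assumptions).
original = TinyModel(hidden_dim=8)
buffer = io.BytesIO()  # stands in for a .ckpt file on disk
torch.save({"hparams": {"hidden_dim": original.hidden_dim},
            "state_dict": original.state_dict()}, buffer)

# Rebuild the model from the stored hyperparameters, then load the weights.
buffer.seek(0)
checkpoint = torch.load(buffer, map_location="cpu")
restored = TinyModel(**checkpoint["hparams"])
restored.load_state_dict(checkpoint["state_dict"])
restored.eval()

# Both models now compute identical outputs for the same input.
x = torch.randn(2, 4)
with torch.no_grad():
    same = torch.equal(original(x), restored(x))
print(same)  # True
```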
load_from_metrics
If you're using test-tube, there is an alternative method that uses the meta_tags.csv file from test-tube to rebuild the model. The meta_tags.csv file is located in the test-tube experiment's save_dir.
pretrained_model = MyLightningModule.load_from_metrics(
weights_path='/path/to/pytorch_checkpoint.ckpt',
tags_csv='/path/to/test_tube/experiment/version/meta_tags.csv',
on_gpu=True,
map_location=None
)
# predict
pretrained_model.eval()
pretrained_model.freeze()
y_hat = pretrained_model(x)
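Rebuilding the model from meta_tags.csv amounts to turning the CSV rows back into hyperparameters. The sketch below assumes the file has a key,value header with one hyperparameter per row, and that values are stored as their string form; the helpers read_meta_tags and convert are hypothetical, and the sample file contents are made up for illustration.

```python
import csv
import io
from argparse import Namespace


def convert(value: str):
    """Best-effort conversion of a CSV string back to bool, int, or float."""
    if value in ("True", "False"):
        return value == "True"
    for cast in (int, float):
        try:
            return cast(value)
        except ValueError:
            pass
    return value  # leave anything else as a string


def read_meta_tags(csv_file) -> Namespace:
    """Hypothetical helper: rebuild hyperparameters from key,value rows."""
    reader = csv.DictReader(csv_file)
    return Namespace(**{row["key"]: convert(row["value"]) for row in reader})


# Stand-in for /path/to/test_tube/experiment/version/meta_tags.csv
sample = io.StringIO("key,value\nlearning_rate,0.02\nbatch_size,32\nuse_amp,False\n")
hparams = read_meta_tags(sample)
print(hparams.learning_rate, hparams.batch_size, hparams.use_amp)  # 0.02 32 False
```

The resulting Namespace could then be passed to the model's constructor before the weights from weights_path are loaded.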
Params
Param | Description |
---|---|
weights_path | Path to a PyTorch checkpoint |
tags_csv | Path to the meta_tags.csv file generated by the test-tube Experiment |
on_gpu | If True, puts the model on the GPU. If the GPU devices have changed since the weights were saved, use the map_location option to remap them |
map_location | A dictionary mapping the GPU devices the weights were saved on to new GPU devices |
Returns
LightningModule - The pretrained LightningModule
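Assuming map_location is forwarded to torch.load, its dictionary form behaves as sketched below: each saved storage location found in the dict is replaced by the mapped device, and locations not listed are left unchanged. The example runs on a CPU-only machine, so the "cuda:1" to "cuda:0" entry is illustrative and never actually triggers.

```python
import io

import torch

# Serialize a small CPU tensor, as a checkpoint would store weights.
buffer = io.BytesIO()
torch.save({"weight": torch.ones(2, 2)}, buffer)

# A dict remaps saved storage locations: anything saved on "cuda:1" would be
# loaded onto "cuda:0". Locations not in the dict (here "cpu") stay as they are.
buffer.seek(0)
remapped = torch.load(buffer, map_location={"cuda:1": "cuda:0"})
print(remapped["weight"].device)  # cpu

# A string or torch.device instead forces every tensor onto a single device.
buffer.seek(0)
on_cpu = torch.load(buffer, map_location=torch.device("cpu"))
print(on_cpu["weight"].device)  # cpu
```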
unfreeze
Unfreeze all params for training
model = MyLightningModule(...)
model.unfreeze()
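A typical use is fine-tuning a model that was frozen for inference. The plain-PyTorch sketch below assumes unfreeze is the inverse of freeze: it re-enables requires_grad on every parameter and switches the module back to training mode, after which gradients flow again. The unfreeze helper and the nn.Linear model are illustrative stand-ins.

```python
import torch
from torch import nn


def unfreeze(module: nn.Module) -> None:
    """Plain-PyTorch sketch of what LightningModule.unfreeze() is assumed to do."""
    # Let autograd track gradients for every parameter again.
    for param in module.parameters():
        param.requires_grad = True
    # Restore training behaviour for dropout, batch norm, and similar layers.
    module.train()


model = nn.Linear(4, 2)
# Start from a frozen state, as freeze() would leave it.
for param in model.parameters():
    param.requires_grad = False
model.eval()

unfreeze(model)
loss = model(torch.randn(3, 4)).sum()
loss.backward()  # would raise an error if the parameters were still frozen

print(model.training)                                       # True
print(all(p.grad is not None for p in model.parameters()))  # True
```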