A LightningModule is a subclass of `torch.nn.Module`, so its API is a strict superset of the `nn.Module` API. In addition to that API, a LightningModule offers the following.

---

### freeze

Freeze all parameters for inference.

```{.python}
model = MyLightningModule(...)
model.freeze()
```
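Freezing amounts to turning off gradient tracking for every parameter. The sketch below reproduces that in plain PyTorch. It assumes `freeze()` sets `requires_grad = False` on all parameters and also switches the module to eval mode. `freeze_module` is a hypothetical helper, not part of Lightning.

```python
import torch
from torch import nn


def freeze_module(module: nn.Module) -> nn.Module:
    """Hypothetical helper approximating LightningModule.freeze()."""
    # Stop autograd from tracking gradients for every parameter.
    for param in module.parameters():
        param.requires_grad = False
    # Assumption: freezing also puts the module in eval mode
    # (dropout disabled, batch norm uses its running statistics).
    module.eval()
    return module


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Dropout(0.5), nn.Linear(8, 2))
freeze_module(model)

with torch.no_grad():
    y_hat = model(torch.randn(3, 4))

print(y_hat.shape)  # torch.Size([3, 2])
print(all(not p.requires_grad for p in model.parameters()))  # True
```

The two mechanisms are independent: freezing changes the parameters' flags permanently, while `torch.no_grad()` only affects the code inside its block.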

---

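### load_from_metrics

This is the easiest and fastest way to rebuild a trained model. It reads the hyperparameters that test-tube saved in the `meta_tags.csv` file, uses them to construct the model, and then loads the weights from the checkpoint.

The `meta_tags.csv` file can be found in the test-tube experiment's `save_dir`.
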

```{.python}
pretrained_model = MyLightningModule.load_from_metrics(
    weights_path='/path/to/pytorch_checkpoint.ckpt',
    tags_csv='/path/to/test_tube/experiment/version/meta_tags.csv',
    on_gpu=True,
    map_location=None
)

# predict
pretrained_model.eval()
pretrained_model.freeze()
y_hat = pretrained_model(x)  # x: a batch of input data
```

**Params**

| Param | Description |
|---|---|
| weights_path | Path to a PyTorch checkpoint |
| tags_csv | Path to the meta_tags.csv file generated by the test-tube Experiment |
| on_gpu | If True, puts the model on the GPU. If the GPU devices have changed since the checkpoint was saved, use `map_location` to remap them |
| map_location | A dictionary mapping the GPU devices the weights were saved on to new GPU devices, e.g. `{'cuda:1': 'cuda:0'}` |

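The `map_location` dictionary has the same shape as the argument `torch.load` accepts for remapping storage devices. Assuming `load_from_metrics` passes it through to `torch.load`, the sketch below shows the forms that function takes. It maps everything to `cpu` so it runs on a machine without GPUs; the `{'cuda:1': 'cuda:0'}` case appears only in a comment, and `load_checkpoint` is a hypothetical helper.

```python
import io

import torch
from torch import nn

# Save a small checkpoint into an in-memory buffer (stands in for a .ckpt file).
buffer = io.BytesIO()
torch.save({"state_dict": nn.Linear(4, 2).state_dict()}, buffer)


def load_checkpoint(map_location):
    """Hypothetical helper: reload the buffered checkpoint with a given mapping."""
    buffer.seek(0)
    return torch.load(buffer, map_location=map_location)


# 1. Dictionary form: remap the device each tensor was saved on.
#    Weights saved on 'cuda:1' could go to the first GPU with {'cuda:1': 'cuda:0'};
#    here CPU storage is simply mapped to itself.
ckpt_dict = load_checkpoint({"cpu": "cpu"})

# 2. String form: put every tensor on one device.
ckpt_cpu = load_checkpoint("cpu")

# 3. Function form: keep each storage where it was deserialized (the CPU).
ckpt_fn = load_checkpoint(lambda storage, loc: storage)

for ckpt in (ckpt_dict, ckpt_cpu, ckpt_fn):
    print(ckpt["state_dict"]["weight"].device)  # cpu
```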

**Returns**

LightningModule - The pretrained LightningModule

---

### unfreeze

Unfreeze all parameters for training.

```{.python}
model = MyLightningModule(...)
model.unfreeze()
```
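Unfreezing is the inverse operation, for example when resuming training or fine-tuning a model that was frozen earlier. A minimal plain-PyTorch sketch, assuming `unfreeze()` sets `requires_grad = True` on every parameter and returns the module to train mode. `unfreeze_module` is a hypothetical helper, not part of Lightning.

```python
import torch
from torch import nn


def unfreeze_module(module: nn.Module) -> nn.Module:
    """Hypothetical helper approximating LightningModule.unfreeze()."""
    for param in module.parameters():
        param.requires_grad = True
    # Assumption: unfreezing also switches the module back to train mode.
    module.train()
    return module


torch.manual_seed(0)
model = nn.Linear(4, 2)

# Start from a frozen state, as if freeze() had been called earlier.
for p in model.parameters():
    p.requires_grad = False
model.eval()

unfreeze_module(model)

# Gradients flow again, so an optimizer step can update the weights.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
before = model.weight.detach().clone()
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()
print(torch.equal(before, model.weight.detach()))  # False
```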