A LightningModule is a subclass of torch.nn.Module, so it supports the full nn.Module API. In addition, it offers the following methods.

---
### freeze
Freeze all parameters for inference.

```{.python}
model = MyLightningModule(...)
model.freeze()
```
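As a rough illustration of what freezing amounts to in plain PyTorch, the sketch below assumes that `freeze()` turns off gradient tracking for every parameter and switches the module to evaluation mode. `TinyModel` and `freeze_module` are hypothetical names used only for this example; this is not the Lightning source.

```{.python}
import torch
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for a LightningModule."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 2)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        return self.dropout(self.layer(x))


def freeze_module(module: nn.Module) -> nn.Module:
    """Hypothetical sketch of freeze(): stop gradient tracking, switch to eval mode."""
    for param in module.parameters():
        param.requires_grad = False
    # In eval mode dropout is a no-op and batch norm uses its running statistics
    module.eval()
    return module


model = freeze_module(TinyModel())
with torch.no_grad():
    y_hat = model(torch.randn(3, 4))
```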
---
### load_from_metrics
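This is the easiest and fastest way to restore a trained model. It reads the hyperparameters stored in the meta_tags.csv file written by test-tube, rebuilds the model from them, and then loads the weights from the checkpoint.

The meta_tags.csv file can be found in the save_dir of the test-tube experiment.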
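To make the "rebuild from meta_tags.csv" step concrete, here is a minimal standard-library sketch. It assumes the file has a `key,value` header with one hyperparameter per row, and that values are either Python literals (numbers, booleans, lists) or plain strings. `read_meta_tags` is a hypothetical helper, not the test-tube or Lightning implementation.

```{.python}
import argparse
import ast
import csv


def read_meta_tags(tags_csv: str) -> argparse.Namespace:
    """Hypothetical helper: turn a key,value CSV into an argparse.Namespace."""
    hparams = {}
    with open(tags_csv, newline="") as f:
        for row in csv.DictReader(f):
            raw = row["value"]
            try:
                # Numbers, booleans, None, lists, ... stored as Python literals
                hparams[row["key"]] = ast.literal_eval(raw)
            except (ValueError, SyntaxError):
                # Anything else (e.g. an optimizer name) is kept as a string
                hparams[row["key"]] = raw
    return argparse.Namespace(**hparams)


# Example (assumes the model constructor takes the hyperparameters as its argument):
#   hparams = read_meta_tags('/path/to/test_tube/experiment/version/meta_tags.csv')
#   model = MyLightningModule(hparams)
```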
```{.python}
pretrained_model = MyLightningModule.load_from_metrics(
    weights_path='/path/to/pytorch_checkpoint.ckpt',
    tags_csv='/path/to/test_tube/experiment/version/meta_tags.csv',
    on_gpu=True,
    map_location=None
)

# predict (x is a batch of input data)
pretrained_model.freeze()
y_hat = pretrained_model(x)
```
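The restore step can be approximated in plain PyTorch. The sketch below assumes the checkpoint is a dict that holds the weights under a `state_dict` key and that the model constructor accepts the hyperparameters; `HparamsModel` and `restore_model` are hypothetical names. `map_location` is handed straight to `torch.load`, which accepts a dict such as `{'cuda:1': 'cuda:0'}` to remap the devices the weights were saved on, or `'cpu'` to load everything onto the CPU.

```{.python}
import torch
from torch import nn


class HparamsModel(nn.Module):
    """Hypothetical model whose architecture is driven by hyperparameters."""

    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        self.layer = nn.Linear(hparams.in_features, hparams.out_features)

    def forward(self, x):
        return self.layer(x)


def restore_model(model_cls, hparams, weights_path, on_gpu=False, map_location=None):
    """Hypothetical sketch: rebuild the model from hparams, then load its weights."""
    if not on_gpu:
        # Put every saved storage on the CPU, whatever device it came from
        map_location = "cpu"
    # With on_gpu=True and map_location=None, storages go back to their saved devices
    checkpoint = torch.load(weights_path, map_location=map_location)
    model = model_cls(hparams)
    model.load_state_dict(checkpoint["state_dict"])
    if on_gpu:
        model.cuda()
    return model
```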
**Params**

| Param | Description |
|---|---|
| weights_path | Path to a PyTorch checkpoint |
| tags_csv | Path to the meta_tags.csv file generated by the test-tube Experiment |
| on_gpu | If True, puts the model on the GPU. If the GPU devices differ from those the weights were saved on, use map_location to remap them |
| map_location | Optional. A dictionary mapping the GPU devices the weights were saved on to new GPU devices |

**Returns**

LightningModule - The pretrained LightningModule

---
### unfreeze
Unfreeze all parameters for training.

```{.python}
model = MyLightningModule(...)
model.unfreeze()
```
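A plain-PyTorch sketch of unfreezing, assuming that `unfreeze()` re-enables gradients on every parameter and switches the module back to training mode. `unfreeze_module` is a hypothetical helper, and the small `nn.Sequential` model stands in for a frozen pretrained model about to be fine-tuned.

```{.python}
import torch
from torch import nn


def unfreeze_module(module: nn.Module) -> nn.Module:
    """Hypothetical sketch of unfreeze(): re-enable gradients, switch to train mode."""
    for param in module.parameters():
        param.requires_grad = True
    module.train()
    return module


# Start from a frozen model: gradients off, eval mode ...
model = nn.Sequential(nn.Linear(4, 2), nn.Dropout(p=0.5))
for p in model.parameters():
    p.requires_grad = False
model.eval()

# ... then unfreeze it so an optimizer step can update the weights again
unfreeze_module(model)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss = model(torch.randn(3, 4)).sum()
loss.backward()
optimizer.step()
```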