A LightningModule is a subclass of torch.nn.Module, so the full nn.Module API is available. It also offers the following methods.

---
### freeze
Freeze all params for inference
```{.python}
model = MyLightningModule(...)
model.freeze()
```
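Freezing amounts to switching off gradient tracking on every parameter. The plain-PyTorch sketch below approximates the effect; freeze_module is a hypothetical helper, not Lightning's API. It also calls eval() so dropout and batch-norm layers behave as at inference time. Whether freeze() itself does this may depend on your Lightning version, which is why the load_from_metrics example calls eval() explicitly.

```{.python}
import torch
from torch import nn


def freeze_module(module: nn.Module) -> nn.Module:
    """Hypothetical helper approximating what LightningModule.freeze() does."""
    for param in module.parameters():
        # Stop autograd from tracking this tensor.
        param.requires_grad = False
    # Assumption: also switch dropout/batch-norm layers to inference mode.
    module.eval()
    return module


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
freeze_module(model)

# With every parameter frozen (and a plain input tensor), the output
# does not require gradients, so no autograd graph is recorded.
y_hat = model(torch.randn(3, 4))
```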
---
### load_from_metrics
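This is the easiest way to restore a trained model. It uses the meta_tags.csv file written by test-tube to rebuild the model with its original hyperparameters, then loads the weights from the checkpoint.
The meta_tags.csv file can be found in the test-tube experiment's save_dir.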
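For intuition, the hyperparameter half of this process can be approximated with the standard library. This is a minimal sketch, not Lightning's implementation. It assumes meta_tags.csv has a key,value header row with one hyperparameter per row. read_meta_tags and _coerce are hypothetical names.

```{.python}
import csv
from argparse import Namespace


def _coerce(value: str):
    """Best-effort conversion of a CSV string back to bool, int, or float."""
    if value in ("True", "False"):
        return value == "True"
    for cast in (int, float):
        try:
            return cast(value)
        except ValueError:
            pass
    return value  # anything else stays a string


def read_meta_tags(tags_csv: str) -> Namespace:
    """Hypothetical helper: rebuild hparams from a key,value meta_tags.csv."""
    with open(tags_csv, newline="") as f:
        rows = csv.DictReader(f)  # assumes the header is "key,value"
        return Namespace(**{row["key"]: _coerce(row["value"]) for row in rows})
```

Assuming the model's constructor accepts an hparams namespace, the result would then be passed to it before the checkpoint weights are loaded.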
```{.python}
pretrained_model = MyLightningModule.load_from_metrics(
    weights_path='/path/to/pytorch_checkpoint.ckpt',
    tags_csv='/path/to/test_tube/experiment/version/meta_tags.csv',
    on_gpu=True,
    map_location=None
)
# predict
pretrained_model.eval()
pretrained_model.freeze()
y_hat = pretrained_model(x)  # x: a batch of inputs in the format the model expects
```
**Params**

| Param | Description |
|---|---|
| weights_path | Path to a PyTorch checkpoint |
| tags_csv | Path to the meta_tags.csv file generated by the test-tube Experiment |
| on_gpu | If True, puts the model on the GPU. If the available GPU devices have changed since the checkpoint was saved, use map_location to remap them |
| map_location | A dictionary mapping the GPU devices the weights were saved on to new GPU devices, e.g. {'cuda:1': 'cuda:0'} |

**Returns**

LightningModule - The pretrained LightningModule

---
### unfreeze
Unfreeze all params so they can be trained again (e.g. to fine-tune a model after freeze)
```{.python}
model = MyLightningModule(...)
model.unfreeze()
```
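Unfreezing reverses freezing by re-enabling gradient tracking. The plain-PyTorch sketch below approximates the effect; unfreeze_module is a hypothetical helper, not Lightning's API. It also calls train() on the assumption that training is about to resume.

```{.python}
import torch
from torch import nn


def unfreeze_module(module: nn.Module) -> nn.Module:
    """Hypothetical helper approximating what LightningModule.unfreeze() does."""
    for param in module.parameters():
        # Let autograd compute gradients for this tensor again.
        param.requires_grad = True
    # Assumption: restore training behaviour for dropout, batch norm, etc.
    module.train()
    return module


# Start from a frozen model in inference mode.
model = nn.Linear(4, 2)
for p in model.parameters():
    p.requires_grad = False
model.eval()

unfreeze_module(model)

# A training step now produces gradients for the weights.
loss = model(torch.randn(3, 4)).sum()
loss.backward()
```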