lightning/docs/LightningModule/methods.md

A LightningModule is a subclass of torch.nn.Module, so the full torch.nn.Module API is available on it. It offers the following methods in addition to that API.


freeze

Freeze all params for inference

model = MyLightningModule(...)
model.freeze()
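
As a rough illustration of what freezing means in plain PyTorch terms, the sketch below disables gradients on every parameter and switches the module to evaluation mode. It uses a plain `torch.nn.Linear` and a hypothetical `freeze_module` helper rather than Lightning's own code, so treat it as an approximation of the behavior and not as the library's implementation.

```python
import torch
from torch import nn


def freeze_module(module: nn.Module) -> None:
    """Hypothetical helper: disable gradients and switch to eval mode."""
    for param in module.parameters():
        param.requires_grad = False
    module.eval()


layer = nn.Linear(4, 2)
freeze_module(layer)

# No parameter tracks gradients any more, and the module is in eval mode.
all_frozen = all(not p.requires_grad for p in layer.parameters())
output = layer(torch.randn(3, 4))
```

Because nothing in the forward pass requires gradients, `output` is detached from autograd, which is what makes a frozen model suitable for inference.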

load_from_checkpoint

This is the easiest and fastest way to restore a model: it loads both the hyperparameters and the weights from a checkpoint, such as one saved by the ModelCheckpoint callback.

pretrained_model = MyLightningModule.load_from_checkpoint(
    checkpoint_path='/path/to/pytorch_checkpoint.ckpt'
)
    
# predict
pretrained_model.eval()
pretrained_model.freeze()
y_hat = pretrained_model(x)

load_from_metrics

If you're using test-tube, there is an alternate method that rebuilds the model from the meta_tags.csv file that test-tube writes. The meta_tags.csv file can be found in the test-tube experiment's save_dir.

pretrained_model = MyLightningModule.load_from_metrics(
    weights_path='/path/to/pytorch_checkpoint.ckpt',
    tags_csv='/path/to/test_tube/experiment/version/meta_tags.csv',
    on_gpu=True,
    map_location=None
)
    
# predict
pretrained_model.eval()
pretrained_model.freeze()
y_hat = pretrained_model(x)

Params

| Param | description |
|---|---|
| weights_path | Path to a PyTorch checkpoint |
| tags_csv | Path to the meta_tags.csv file generated by the test-tube Experiment |
| on_gpu | If True, puts the model on GPU. Make sure to use the map_location option if model devices have changed |
| map_location | A dictionary mapping the GPU devices the weights were saved on to new GPU devices |

Returns

LightningModule - The pretrained LightningModule
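
To make the role of meta_tags.csv more concrete, here is a minimal sketch of how hyperparameters could be recovered from such a file using only the standard library. It assumes the file has a `key,value` header row; that layout, the `read_meta_tags` helper, and the sample values are assumptions for illustration, not Lightning's or test-tube's actual loading code.

```python
import csv
import io
from argparse import Namespace


def read_meta_tags(csv_file) -> Namespace:
    """Collect key/value rows from a meta_tags.csv-style file into a Namespace.

    Assumes a header row of `key,value`; values are kept as strings.
    """
    reader = csv.DictReader(csv_file)
    return Namespace(**{row["key"]: row["value"] for row in reader})


# Hypothetical file contents standing in for a real meta_tags.csv.
example = io.StringIO("key,value\nlearning_rate,0.02\nhidden_dim,128\n")
hparams = read_meta_tags(example)
print(hparams.learning_rate, hparams.hidden_dim)
```

A real loader would also need to convert the string values back to their proper types before passing them to the model constructor.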


unfreeze

Unfreeze all params for training

model = MyLightningModule(...)
model.unfreeze()
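
The effect of unfreezing can be sketched in plain PyTorch as re-enabling gradients on every parameter and returning the module to training mode. The `unfreeze_module` helper below is hypothetical and stands in for the library method, so this is an approximation under that assumption.

```python
import torch
from torch import nn


def unfreeze_module(module: nn.Module) -> None:
    """Hypothetical helper: re-enable gradients and switch to train mode."""
    for param in module.parameters():
        param.requires_grad = True
    module.train()


layer = nn.Linear(4, 2)
for param in layer.parameters():  # start from a frozen state
    param.requires_grad = False
layer.eval()

unfreeze_module(layer)

# A backward pass now populates gradients on every parameter again.
loss = layer(torch.randn(3, 4)).sum()
loss.backward()
has_grads = all(p.grad is not None for p in layer.parameters())
```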