lightning/pytorch_lightning/root_module/root_module.py

173 lines
4.8 KiB
Python
Raw Normal View History

2019-03-31 01:45:16 +00:00
import torch
2019-08-07 14:14:59 +00:00
from pytorch_lightning.root_module.memory import ModelSummary
from pytorch_lightning.root_module.grads import GradInformation
from pytorch_lightning.trainer.trainer_io import load_hparams_from_tags_csv
from pytorch_lightning.root_module.model_saving import ModelIO
2019-08-07 14:14:59 +00:00
from pytorch_lightning.root_module.hooks import ModelHooks
from pytorch_lightning.root_module.decorators import data_loader
2019-03-31 01:45:16 +00:00
2019-07-24 19:33:08 +00:00
class LightningModule(GradInformation, ModelIO, ModelHooks):
    """
    Base class for user models, combining the GradInformation, ModelIO and
    ModelHooks mixins.

    Subclasses must implement ``forward``, ``training_step``,
    ``configure_optimizers`` and ``train_dataloader`` (the defaults raise
    ``NotImplementedError``). The validation/test step, end and dataloader
    methods are optional and return ``None`` by default.
    """

    def __init__(self, *args, **kwargs):
        super(LightningModule, self).__init__(*args, **kwargs)

        # starts as the CPU float tensor type
        self.dtype = torch.FloatTensor
        self.exp_save_path = None
        self.current_epoch = 0
        self.global_step = 0
        self.loaded_optimizer_states_dict = {}

        # NOTE(review): these start empty here and are presumably attached
        # later by the trainer — confirm against the Trainer code.
        self.trainer = None
        self.logger = None

        # optional example input a subclass may assign
        self.example_input_array = None

        # track if gpu was requested for checkpointing
        self.on_gpu = False
        self.use_dp = False
        self.use_ddp = False
        self.use_ddp2 = False
        self.use_amp = False

    def forward(self, *args, **kwargs):
        """
        Forward pass of the model. Must be overridden.

        Expand the model into whatever you need; also return the target.
        :param x: model input(s)
        :return: model output
        :raises NotImplementedError: if the subclass does not override it
        """
        raise NotImplementedError

    def training_step(self, *args, **kwargs):
        """
        Run one training step. Must be overridden.

        Called with ``batch, batch_nb``; additionally ``optimizer_i`` when
        multiple optimizers are used.
        :return: loss, and a dict with metrics for tqdm
        :raises NotImplementedError: if the subclass does not override it
        """
        raise NotImplementedError

    def validation_step(self, *args, **kwargs):
        """
        OPTIONAL: run one validation step.

        Called with ``batch, batch_nb``; additionally ``dataset_i`` when
        multiple val datasets are used.
        :return: whatever outputs will need to be aggregated in validation_end
            (None by default)
        """
        pass

    def test_step(self, *args, **kwargs):
        """
        OPTIONAL: run one test step.

        Called with ``batch, batch_nb``; additionally ``dataset_i`` when
        multiple test datasets are used.
        :return: whatever outputs will need to be aggregated in test_end
            (None by default)
        """
        pass

    def validation_end(self, outputs):
        """
        OPTIONAL: aggregate the outputs of all validation steps.

        :param outputs: the appended output of each validation step
        :return: dict with metrics for tqdm (None by default)
        """
        pass

    def test_end(self, outputs):
        """
        OPTIONAL: aggregate the outputs of all test steps.

        :param outputs: the appended output of each test step
        :return: dict with metrics for tqdm (None by default)
        """
        pass

    def configure_optimizers(self):
        """
        Create the optimizers. Must be overridden.

        :return: a list of optimizers and a list of schedulers (could be empty)
        :raises NotImplementedError: if the subclass does not override it
        """
        raise NotImplementedError

    def optimizer_step(self, epoch_nb, batch_nb, optimizer, optimizer_i, second_order_closure=None):
        """
        Step ``optimizer`` and then clear its gradients.

        Override to do something instead of the standard optimizer behavior.
        :param epoch_nb: current epoch number
        :param batch_nb: current batch number
        :param optimizer: the optimizer to step
        :param optimizer_i: index of ``optimizer`` when several are used
        :param second_order_closure: closure for second order methods; only
            passed to ``step`` for ``torch.optim.LBFGS``
        :return: None
        """
        # LBFGS re-evaluates the loss several times per step, so its step()
        # needs the closure; other optimizers take no arguments here.
        if isinstance(optimizer, torch.optim.LBFGS):
            optimizer.step(second_order_closure)
        else:
            optimizer.step()

        # clear gradients
        optimizer.zero_grad()

    @data_loader
    def train_dataloader(self):
        """
        Implement a PyTorch DataLoader for training. Must be overridden.

        :return: the training DataLoader
        :raises NotImplementedError: if the subclass does not override it
        """
        raise NotImplementedError

    @data_loader
    def test_dataloader(self):
        """
        OPTIONAL: implement a PyTorch DataLoader for testing.

        :return: the test DataLoader (None by default)
        """
        return None

    @data_loader
    def val_dataloader(self):
        """
        OPTIONAL: implement a PyTorch DataLoader for validation.

        :return: the validation DataLoader (None by default)
        """
        return None

    @classmethod
    def load_from_metrics(cls, weights_path, tags_csv, map_location=None):
        """
        Primary way of loading a model from a csv weights path.

        :param weights_path: path to a checkpoint loadable with ``torch.load``;
            it must contain a ``'state_dict'`` entry
        :param tags_csv: csv file the hyperparameters are read from via
            ``load_hparams_from_tags_csv``
        :param map_location: optional ``torch.load`` map_location, e.g. a dic
            for mapping storage {'cuda:1':'cuda:0'}. When omitted, every
            storage is loaded on CPU.
        :return: an instance of ``cls`` with the checkpoint weights loaded
        """
        hparams = load_hparams_from_tags_csv(tags_csv)
        hparams.on_gpu = False

        if map_location is None:
            # load on CPU only to avoid OOM issues
            # then its up to user to put back on GPUs
            def map_location(storage, loc):
                return storage

        checkpoint = torch.load(weights_path, map_location=map_location)

        # load the state_dict on the model automatically
        # (assumes the subclass constructor takes the hparams object)
        model = cls(hparams)
        model.load_state_dict(checkpoint['state_dict'])

        # give model a chance to load something
        model.on_load_checkpoint(checkpoint)

        return model

    def summarize(self):
        """Build a ModelSummary of this model and print it."""
        model_summary = ModelSummary(self)
        print(model_summary)

    def freeze(self):
        """Set ``requires_grad = False`` on every parameter."""
        for param in self.parameters():
            param.requires_grad = False

    def unfreeze(self):
        """Set ``requires_grad = True`` on every parameter."""
        for param in self.parameters():
            param.requires_grad = True