2019-10-22 08:32:40 +00:00
|
|
|
import warnings
|
2019-10-31 10:45:28 +00:00
|
|
|
import collections
|
2019-10-23 08:48:24 +00:00
|
|
|
from argparse import Namespace
|
2019-10-22 08:32:40 +00:00
|
|
|
|
2019-03-31 01:45:16 +00:00
|
|
|
import torch
|
2019-08-05 08:52:09 +00:00
|
|
|
|
2019-10-22 08:32:40 +00:00
|
|
|
from pytorch_lightning.root_module.decorators import data_loader
|
2019-08-07 14:14:59 +00:00
|
|
|
from pytorch_lightning.root_module.grads import GradInformation
|
|
|
|
from pytorch_lightning.root_module.hooks import ModelHooks
|
2019-10-22 08:32:40 +00:00
|
|
|
from pytorch_lightning.root_module.memory import ModelSummary
|
|
|
|
from pytorch_lightning.root_module.model_saving import ModelIO
|
|
|
|
from pytorch_lightning.trainer.trainer_io import load_hparams_from_tags_csv
|
2019-11-05 13:43:21 +00:00
|
|
|
import logging
|
2019-10-21 06:16:55 +00:00
|
|
|
|
2019-03-31 01:45:16 +00:00
|
|
|
|
2019-07-24 19:33:08 +00:00
|
|
|
class LightningModule(GradInformation, ModelIO, ModelHooks):
    """
    Base class for Lightning models.

    Subclasses must override forward, training_step and configure_optimizers
    (the defaults raise NotImplementedError) and provide a train_dataloader.
    The validation/test hooks and dataloaders are optional and default to
    no-ops / None. The attributes set in __init__ are placeholders; presumably
    the trainer overwrites them (trainer, logger, device/precision flags)
    when fitting — TODO confirm.
    """

    def __init__(self, *args, **kwargs):
        super(LightningModule, self).__init__(*args, **kwargs)

        # default tensor type for this module (CPU float32)
        self.dtype = torch.FloatTensor
        self.exp_save_path = None

        # training progress counters, both start at zero
        self.current_epoch = 0
        self.global_step = 0

        # optimizer states restored from a checkpoint (empty until something fills it)
        self.loaded_optimizer_states_dict = {}

        self.trainer = None
        self.logger = None

        # optional example input; None until a subclass sets it
        self.example_input_array = None

        # track if gpu was requested for checkpointing
        self.on_gpu = False
        # distributed / precision mode flags, all off by default
        self.use_dp = False
        self.use_ddp = False
        self.use_ddp2 = False
        self.use_amp = False

    def forward(self, *args, **kwargs):
        """
        Compute the model output; must be overridden by subclasses.

        Expand the model into whatever you need; the arguments are whatever
        the subclass's own step methods pass in.

        :param args: positional inputs supplied by the caller
        :param kwargs: keyword inputs supplied by the caller
        :return: the model output
        :raises NotImplementedError: if not overridden
        """
        raise NotImplementedError

    def training_step(self, *args, **kwargs):
        """
        Run one training step; must be overridden by subclasses.

        Called with ``(batch, batch_nb)``; additionally ``optimizer_i`` when
        multiple optimizers are used.

        :return: the loss and a dict with metrics for tqdm
        :raises NotImplementedError: if not overridden
        """
        raise NotImplementedError

    def validation_step(self, *args, **kwargs):
        """
        Return whatever outputs will need to be aggregated in validation_end.

        OPTIONAL. Called with ``(batch, batch_nb)``; additionally ``dataset_i``
        when multiple val datasets are used.

        :return: per-step outputs (the default implementation returns None)
        """
        return None

    def test_step(self, *args, **kwargs):
        """
        Return whatever outputs will need to be aggregated in test_end.

        OPTIONAL. Called with ``(batch, batch_nb)``; additionally ``dataset_i``
        when multiple test datasets are used.

        :return: per-step outputs (the default implementation returns None)
        """
        return None

    def validation_end(self, outputs):
        """
        Aggregate the outputs appended after each validation step.

        OPTIONAL.

        :param outputs: the collected validation_step outputs
        :return: dict with metrics for tqdm (the default implementation returns None)
        """
        return None

    def test_end(self, outputs):
        """
        Aggregate the outputs appended after each test step.

        OPTIONAL.

        :param outputs: the collected test_step outputs
        :return: dict with metrics for tqdm (the default implementation returns None)
        """
        return None

    def configure_optimizers(self):
        """
        Return a list of optimizers and a list of schedulers (could be empty).

        :return: optimizers and schedulers
        :raises NotImplementedError: if not overridden
        """
        raise NotImplementedError

    def optimizer_step(self, epoch_nb, batch_nb, optimizer, optimizer_i, second_order_closure=None):
        """
        Step ``optimizer`` and then clear its gradients.

        Override to do something instead of the standard optimizer behavior.

        :param epoch_nb: current epoch index
        :param batch_nb: current batch index
        :param optimizer: the optimizer to step
        :param optimizer_i: index identifying ``optimizer`` (presumably its position
            in the configure_optimizers output — TODO confirm)
        :param second_order_closure: closure for second order methods; only
            forwarded to ``torch.optim.LBFGS``
        :return: None
        """
        if isinstance(optimizer, torch.optim.LBFGS):
            # LBFGS.step requires a closure that re-evaluates the model and returns the loss
            optimizer.step(second_order_closure)
        else:
            optimizer.step()

        # clear gradients
        optimizer.zero_grad()

    def tbptt_split_batch(self, batch, split_size):
        """
        Split ``batch`` along time for truncated back-propagation through time.

        Each split will be passed to the training step. Root-level Tensors are
        sliced along dim=1 (the time dim); for root-level Sequences every item
        is sliced ``[t:t + split_size]``. Any other element is not time-indexed
        and is passed through unchanged into every split. All splittable
        elements must share the same time-dimension length.

        :param batch: sequence of batch elements (Tensors, Sequences or other values)
        :param split_size: number of time steps per split
        :return: list of splits, each a list with one entry per element of ``batch``
        :raises AssertionError: if no element has a time dimension or the lengths differ
        """
        # collections.Sequence was an alias removed in Python 3.10; collections.abc is the real home
        from collections.abc import Sequence

        time_dims = [len(x[0]) for x in batch if isinstance(x, (torch.Tensor, Sequence))]
        assert len(time_dims) >= 1, "Unable to determine batch time dimension"
        assert all(x == time_dims[0] for x in time_dims), "Batch time dimension length is ambiguous"

        splits = []
        for t in range(0, time_dims[0], split_size):
            batch_split = []
            for x in batch:
                if isinstance(x, torch.Tensor):
                    split_x = x[:, t:t + split_size]
                elif isinstance(x, Sequence):
                    split_x = [item[t:t + split_size] for item in x]
                else:
                    # no time dimension: previously this reused a stale split_x (or raised
                    # UnboundLocalError); give every split the element unchanged instead
                    split_x = x
                batch_split.append(split_x)
            splits.append(batch_split)

        return splits

    @data_loader
    def tng_dataloader(self):
        """
        Implement a PyTorch DataLoader.

        * Deprecated in v0.5.0. use train_dataloader instead. *

        :return: the training DataLoader
        :raises NotImplementedError: if not overridden
        """
        raise NotImplementedError

    @data_loader
    def train_dataloader(self):
        """
        Implement a PyTorch DataLoader for training.

        The default falls back to the deprecated ``tng_dataloader`` so older
        models keep working, emitting a DeprecationWarning when that succeeds.

        :return: the training DataLoader
        :raises NotImplementedError: if neither train_dataloader nor
            tng_dataloader is implemented
        """
        # a NotImplementedError from tng_dataloader propagates unchanged
        output = self.tng_dataloader()
        warnings.warn("tng_dataloader has been renamed to train_dataloader since v0.5.0",
                      DeprecationWarning)
        return output

    @data_loader
    def test_dataloader(self):
        """
        Implement a PyTorch DataLoader.

        OPTIONAL.

        :return: the test DataLoader (the default implementation returns None)
        """
        return None

    @data_loader
    def val_dataloader(self):
        """
        Implement a PyTorch DataLoader.

        OPTIONAL.

        :return: the validation DataLoader (the default implementation returns None)
        """
        return None

    @classmethod
    def load_from_metrics(cls, weights_path, tags_csv, map_location=None):
        """
        Primary way of loading a model from a weights file and a csv of tags.

        :param weights_path: path of the file passed to ``torch.load``; must contain
            a ``'state_dict'`` entry
        :param tags_csv: csv of hyperparameters, read with load_hparams_from_tags_csv
        :param map_location: optional ``torch.load`` map_location, e.g. a dict for
            mapping storage {'cuda:1':'cuda:0'}. Defaults to loading everything on CPU.
        :return: an instance of ``cls`` built from the hparams with the weights loaded
        """
        hparams = load_hparams_from_tags_csv(tags_csv)
        hparams.on_gpu = False

        # by default load on CPU only to avoid OOM issues;
        # then it's up to the user to put the model back on GPUs
        checkpoint = torch.load(
            weights_path,
            map_location=(lambda storage, loc: storage) if map_location is None else map_location,
        )

        # load the state_dict on the model automatically
        model = cls(hparams)
        model.load_state_dict(checkpoint['state_dict'])

        # give model a chance to load something
        model.on_load_checkpoint(checkpoint)

        return model

    @classmethod
    def load_from_checkpoint(cls, checkpoint_path, map_location=None):
        """
        Primary way of loading a model from a checkpoint.

        :param checkpoint_path: path of the checkpoint; it must contain ``'hparams'``
            (a dict) and ``'state_dict'`` entries
        :param map_location: optional ``torch.load`` map_location, e.g. a dict for
            mapping storage {'cuda:1':'cuda:0'}. Defaults to loading everything on CPU.
        :return: an instance of ``cls`` built from the stored hparams with the weights loaded
        :raises IOError: if the checkpoint has no ``'hparams'`` entry
        """
        # by default load on CPU only to avoid OOM issues;
        # then it's up to the user to put the model back on GPUs
        checkpoint = torch.load(
            checkpoint_path,
            map_location=(lambda storage, loc: storage) if map_location is None else map_location,
        )

        try:
            ckpt_hparams = checkpoint['hparams']
        except KeyError:
            raise IOError(
                "Checkpoint does not contain hyperparameters. Are your model hyperparameters stored "
                "in self.hparams?"
            )
        hparams = Namespace(**ckpt_hparams)

        # load the state_dict on the model automatically
        model = cls(hparams)
        model.load_state_dict(checkpoint['state_dict'])

        # give model a chance to load something
        model.on_load_checkpoint(checkpoint)

        return model

    def summarize(self, mode):
        """
        Build a ModelSummary of this module and log it at INFO level.

        :param mode: forwarded to ModelSummary as its ``mode`` argument
        """
        model_summary = ModelSummary(self, mode=mode)
        logging.info(model_summary)

    def freeze(self):
        """Disable gradients for every parameter and switch the module to eval mode."""
        for param in self.parameters():
            param.requires_grad = False

        self.eval()

    def unfreeze(self):
        """Re-enable gradients for every parameter and switch the module to train mode."""
        for param in self.parameters():
            param.requires_grad = True

        self.train()