lightning/pytorch_lightning/root_module/root_module.py

164 lines
4.5 KiB
Python
Raw Normal View History

2019-03-31 01:45:16 +00:00
import os
import torch
import math
from pytorch_lightning.root_module.memory import ModelSummary
from pytorch_lightning.root_module.grads import GradInformation
from pytorch_lightning.root_module.model_saving import ModelIO, load_hparams_from_tags_csv
from pytorch_lightning.root_module.hooks import ModelHooks
2019-07-24 19:33:08 +00:00
class LightningModule(GradInformation, ModelIO, ModelHooks):
    """
    Base class for user-defined models.

    Subclasses are expected to override every method and dataloader property
    below that raises NotImplementedError.

    NOTE(review): ``self.parameters()``, ``self.load_state_dict()`` and
    ``self.load_model_specific()`` are used here but not defined in this
    class; presumably they come from one of the base classes - confirm.
    """

    def __init__(self, hparams):
        """
        Store the hyperparameters and reset the bookkeeping attributes.

        :param hparams: hyperparameter object, kept as-is on ``self.hparams``
        """
        super(LightningModule, self).__init__()
        self.hparams = hparams

        self.dtype = torch.FloatTensor
        self.exp_save_path = None
        self.current_epoch = 0
        self.global_step = 0
        self.loaded_optimizer_states_dict = {}
        self.trainer = None
        self.experiment = None

        # track if gpu was requested for checkpointing
        self.on_gpu = False

        # computed vars for the dataloaders
        self._tng_dataloader = None
        self._val_dataloader = None
        self._test_dataloader = None

    def forward(self, *args, **kwargs):
        """
        Expand the model into whatever you need.
        Also needs to return the target.

        :param args: positional inputs defined by the subclass
        :param kwargs: keyword inputs defined by the subclass
        :return: defined by the subclass
        """
        raise NotImplementedError

    def validation_step(self, data_batch, batch_nb):
        """
        Return whatever outputs will need to be aggregated in validation_end.

        :param data_batch: the batch to validate on
        :param batch_nb: presumably the index of the batch - TODO confirm
        :return: outputs to be collected for validation_end
        """
        raise NotImplementedError

    def validation_end(self, outputs):
        """
        Outputs has the appended output after each validation step.

        :param outputs: collected results of the validation steps
        :return: dic_with_metrics for tqdm
        """
        raise NotImplementedError

    def training_step(self, data_batch, batch_nb):
        """
        Return loss, dict with metrics for tqdm.

        :param data_batch: the batch to train on
        :param batch_nb: presumably the index of the batch - TODO confirm
        :return: loss, dict with metrics for tqdm
        """
        raise NotImplementedError

    def configure_optimizers(self):
        """
        Return array of optimizers.

        :return: array of optimizers
        """
        raise NotImplementedError

    def update_tng_log_metrics(self, logs):
        """
        Chance to update metrics to be logged for training step.
        For example, add music, images, etc... to log.

        The default implementation returns ``logs`` untouched.

        :param logs: metrics about to be logged
        :return: the metrics to log
        """
        return logs

    def loss(self, *args, **kwargs):
        """
        Expand model_out into your components.

        :param args: positional inputs defined by the subclass
        :param kwargs: keyword inputs defined by the subclass
        :return: defined by the subclass
        """
        raise NotImplementedError

    def summarize(self):
        """
        Print the ModelSummary built from this model.
        """
        model_summary = ModelSummary(self)
        print(model_summary)

    def freeze(self):
        """
        Set ``requires_grad = False`` on every parameter of the model.
        """
        for param in self.parameters():
            param.requires_grad = False

    def unfreeze(self):
        """
        Set ``requires_grad = True`` on every parameter of the model.
        """
        for param in self.parameters():
            param.requires_grad = True

    @property
    def tng_dataloader(self):
        """
        Training dataloader. Must be implemented by the subclass.

        :return: defined by the subclass
        """
        raise NotImplementedError

    @property
    def test_dataloader(self):
        """
        Test dataloader. Must be implemented by the subclass.

        :return: defined by the subclass
        """
        raise NotImplementedError

    @property
    def val_dataloader(self):
        """
        Validation dataloader. Must be implemented by the subclass.

        :return: defined by the subclass
        """
        raise NotImplementedError

    @staticmethod
    def get_process_position(gpus):
        """
        Locate the value of CUDA_VISIBLE_DEVICES inside the requested gpu list.

        This is a best-effort lookup: any failure falls back to ``(0, 0)``.

        :param gpus: comma separated string of gpu ids, e.g. '0,1,2'
        :return: (position of CUDA_VISIBLE_DEVICES in gpus, CUDA_VISIBLE_DEVICES)
            or (0, 0) when the variable is unset, gpus is not a string, or the
            current gpu is not listed in gpus
        """
        try:
            current_gpu = os.environ["CUDA_VISIBLE_DEVICES"]
            # tolerate spaces after the commas ('0, 1') so the match
            # does not silently fail
            gpu_ids = [gpu_id.strip() for gpu_id in gpus.split(',')]
            process_position = gpu_ids.index(current_gpu)
        except (KeyError, AttributeError, TypeError, ValueError):
            # KeyError: env var unset; AttributeError/TypeError: gpus is not
            # a str; ValueError: current gpu not present in gpus
            return 0, 0
        return process_position, current_gpu

    @classmethod
    def load_from_metrics(cls, weights_path, tags_csv, on_gpu, map_location=None):
        """
        Primary way of loading model from csv weights path.

        :param weights_path: path of the checkpoint handed to torch.load
        :param tags_csv: csv file the hparams are read from
        :param on_gpu: when falsy, every storage is mapped to the CPU and
            map_location is ignored
        :param map_location: dic for mapping storage {'cuda:1':'cuda:0'},
            only used when on_gpu is truthy
        :return: a new instance of cls with the checkpoint weights loaded
        """
        hparams = load_hparams_from_tags_csv(tags_csv)
        hparams.on_gpu = on_gpu

        # NOTE(security): torch.load unpickles the file - only load
        # checkpoints that come from a trusted source.
        if on_gpu:
            # map_location=None is torch.load's default, so no extra branch
            # is needed for the "no mapping requested" case
            checkpoint = torch.load(weights_path, map_location=map_location)
        else:
            checkpoint = torch.load(weights_path, map_location=lambda storage, loc: storage)

        model = cls(hparams)

        # allow model to load
        model.load_model_specific(checkpoint)
        # strict=False: checkpoint and model keys are allowed to differ
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        return model