added lazy decorator
This commit is contained in:
parent
c6da6eb46c
commit
39b15855ed
|
@ -0,0 +1,17 @@
|
|||
|
||||
def data_loader(fn):
|
||||
"""
|
||||
Decorator to make any fx with this use the lazy property
|
||||
:param fn:
|
||||
:return:
|
||||
"""
|
||||
|
||||
attr_name = '_lazy_' + fn.__name__
|
||||
|
||||
@property
|
||||
def _data_loader(self):
|
||||
if not hasattr(self, attr_name):
|
||||
setattr(self, attr_name, fn(self))
|
||||
return getattr(self, attr_name)
|
||||
|
||||
return _data_loader
|
|
@ -1,11 +1,9 @@
|
|||
import os
|
||||
import torch
|
||||
import math
|
||||
|
||||
from pytorch_lightning.root_module.memory import ModelSummary
|
||||
from pytorch_lightning.root_module.grads import GradInformation
|
||||
from pytorch_lightning.root_module.model_saving import ModelIO, load_hparams_from_tags_csv
|
||||
from pytorch_lightning.root_module.hooks import ModelHooks
|
||||
from pytorch_lightning.root_module.decorators import data_loader
|
||||
|
||||
|
||||
class LightningModule(GradInformation, ModelIO, ModelHooks):
|
||||
|
@ -26,11 +24,6 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
|
|||
# track if gpu was requested for checkpointing
|
||||
self.on_gpu = False
|
||||
|
||||
# computed vars for the dataloaders
|
||||
self._tng_dataloader = None
|
||||
self._val_dataloader = None
|
||||
self._test_dataloader = None
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
"""
|
||||
Expand model in into whatever you need.
|
||||
|
@ -91,7 +84,7 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
|
|||
for param in self.parameters():
|
||||
param.requires_grad = True
|
||||
|
||||
@property
|
||||
@data_loader
|
||||
def tng_dataloader(self):
|
||||
"""
|
||||
Implement a function to load an h5py of this data
|
||||
|
@ -99,7 +92,7 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@data_loader
|
||||
def test_dataloader(self):
|
||||
"""
|
||||
Implement a function to load an h5py of this data
|
||||
|
@ -107,7 +100,7 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@data_loader
|
||||
def val_dataloader(self):
|
||||
"""
|
||||
Implement a function to load an h5py of this data
|
||||
|
@ -142,3 +135,6 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
|
|||
model.load_model_specific(checkpoint)
|
||||
model.load_state_dict(checkpoint['state_dict'], strict=False)
|
||||
return model
|
||||
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue