added lazy decorator

Author: William Falcon, 2019-07-25 10:39:48 -04:00
parent c6da6eb46c
commit 39b15855ed

2 changed files with 24 additions and 11 deletions

pytorch_lightning/root_module/decorators.py (new file)

@@ -0,0 +1,17 @@
def data_loader(fn):
    """
    Decorator that turns the wrapped function into a lazily evaluated
    property: the function runs once on first access and its result is
    cached on the instance under '_lazy_<fn name>'.
    :param fn: function that computes the value to cache
    :return: read-only property that computes and caches fn's result
    """
    attr_name = '_lazy_' + fn.__name__

    @property
    def _data_loader(self):
        # compute once, then reuse the cached value on every later access
        if not hasattr(self, attr_name):
            setattr(self, attr_name, fn(self))
        return getattr(self, attr_name)

    return _data_loader
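
For context, a minimal sketch of how the new decorator behaves; the Example class and its fake dataloader below are illustrative only, not part of this commit:

# First access runs the wrapped function and caches the result under
# '_lazy_tng_dataloader'; every later access returns the cached value.
class Example:
    @data_loader
    def tng_dataloader(self):
        print('building dataloader')   # printed only once per instance
        return ['batch_0', 'batch_1']  # stand-in for a real DataLoader

ex = Example()
ex.tng_dataloader  # prints 'building dataloader' and caches the list
ex.tng_dataloader  # served from ex._lazy_tng_dataloader, no rebuild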

pytorch_lightning/root_module/… (the module defining LightningModule)

@@ -1,11 +1,9 @@
import os
import torch
import math
from pytorch_lightning.root_module.memory import ModelSummary
from pytorch_lightning.root_module.grads import GradInformation
from pytorch_lightning.root_module.model_saving import ModelIO, load_hparams_from_tags_csv
from pytorch_lightning.root_module.hooks import ModelHooks
from pytorch_lightning.root_module.decorators import data_loader


class LightningModule(GradInformation, ModelIO, ModelHooks):
@@ -26,11 +24,6 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
        # track if gpu was requested for checkpointing
        self.on_gpu = False

        # computed vars for the dataloaders
        self._tng_dataloader = None
        self._val_dataloader = None
        self._test_dataloader = None
    def forward(self, *args, **kwargs):
        """
        Expand model into whatever you need.
@@ -91,7 +84,7 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
        for param in self.parameters():
            param.requires_grad = True

    @property
    @data_loader
    def tng_dataloader(self):
        """
        Implement a function to load this data (e.g. from an h5py file)
@@ -99,7 +92,7 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
        """
        raise NotImplementedError

    @property
    @data_loader
    def test_dataloader(self):
        """
        Implement a function to load this data (e.g. from an h5py file)
@@ -107,7 +100,7 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
        """
        raise NotImplementedError

    @property
    @data_loader
    def val_dataloader(self):
        """
        Implement a function to load this data (e.g. from an h5py file)
@@ -142,3 +135,6 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
        model.load_model_specific(checkpoint)
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        return model
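
Taken together, the change swaps @property for @data_loader on the three dataloader hooks and drops the manual cache fields from __init__. A sketch of what a subclass looks like after this commit, assuming the standard torch.utils.data classes (they are standard PyTorch, not part of this diff):

import torch
from torch.utils.data import DataLoader, TensorDataset

class MyModel(LightningModule):
    @data_loader
    def tng_dataloader(self):
        # built on first access only; the decorator caches the result as
        # self._lazy_tng_dataloader, replacing the old self._tng_dataloader
        # field that __init__ used to initialize to None
        ds = TensorDataset(torch.randn(64, 2), torch.randint(0, 2, (64,)))
        return DataLoader(ds, batch_size=8)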