From 39b15855ed8935f576968677cafe089638579290 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 25 Jul 2019 10:39:48 -0400
Subject: [PATCH] added lazy decorator

---
 pytorch_lightning/root_module/decorators.py  | 17 +++++++++++++++++
 pytorch_lightning/root_module/root_module.py | 18 +++++++-----------
 2 files changed, 24 insertions(+), 11 deletions(-)
 create mode 100644 pytorch_lightning/root_module/decorators.py

diff --git a/pytorch_lightning/root_module/decorators.py b/pytorch_lightning/root_module/decorators.py
new file mode 100644
index 0000000000..ef7bd50283
--- /dev/null
+++ b/pytorch_lightning/root_module/decorators.py
@@ -0,0 +1,17 @@
+
+def data_loader(fn):
+    """
+    Decorator to make any fx with this use the lazy property
+    :param fn:
+    :return:
+    """
+
+    attr_name = '_lazy_' + fn.__name__
+
+    @property
+    def _data_loader(self):
+        if not hasattr(self, attr_name):
+            setattr(self, attr_name, fn(self))
+        return getattr(self, attr_name)
+
+    return _data_loader
diff --git a/pytorch_lightning/root_module/root_module.py b/pytorch_lightning/root_module/root_module.py
index c0f401848d..d6e0ec9666 100644
--- a/pytorch_lightning/root_module/root_module.py
+++ b/pytorch_lightning/root_module/root_module.py
@@ -1,11 +1,9 @@
-import os
 import torch
-import math
-
 from pytorch_lightning.root_module.memory import ModelSummary
 from pytorch_lightning.root_module.grads import GradInformation
 from pytorch_lightning.root_module.model_saving import ModelIO, load_hparams_from_tags_csv
 from pytorch_lightning.root_module.hooks import ModelHooks
+from pytorch_lightning.root_module.decorators import data_loader
 
 
 class LightningModule(GradInformation, ModelIO, ModelHooks):
@@ -26,11 +24,6 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
         # track if gpu was requested for checkpointing
         self.on_gpu = False
 
-        # computed vars for the dataloaders
-        self._tng_dataloader = None
-        self._val_dataloader = None
-        self._test_dataloader = None
-
     def forward(self, *args, **kwargs):
         """
         Expand model in into whatever you need.
@@ -91,7 +84,7 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
         for param in self.parameters():
             param.requires_grad = True
 
-    @property
+    @data_loader
     def tng_dataloader(self):
         """
         Implement a function to load an h5py of this data
@@ -99,7 +92,7 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
         """
         raise NotImplementedError
 
-    @property
+    @data_loader
     def test_dataloader(self):
         """
         Implement a function to load an h5py of this data
@@ -107,7 +100,7 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
         """
         raise NotImplementedError
 
-    @property
+    @data_loader
     def val_dataloader(self):
         """
         Implement a function to load an h5py of this data
@@ -142,3 +135,6 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
         model.load_model_specific(checkpoint)
         model.load_state_dict(checkpoint['state_dict'], strict=False)
         return model
+
+
+
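
Usage note (not part of the commit): the sketch below illustrates the lazy-property behavior the new data_loader decorator adds, namely that the decorated dataloader function runs once and its result is cached on the instance under '_lazy_<fn name>'. The decorator body is copied from the new decorators.py above; DummyModule, its print call, and the list it returns are hypothetical stand-ins for a LightningModule subclass returning a real DataLoader.

# Minimal sketch of the lazy caching added by this patch (assumptions noted above).

def data_loader(fn):
    # copied verbatim from pytorch_lightning/root_module/decorators.py in this patch
    attr_name = '_lazy_' + fn.__name__

    @property
    def _data_loader(self):
        # build and cache the value on first access only
        if not hasattr(self, attr_name):
            setattr(self, attr_name, fn(self))
        return getattr(self, attr_name)

    return _data_loader


class DummyModule:
    # hypothetical stand-in for a LightningModule subclass

    @data_loader
    def tng_dataloader(self):
        print('building tng dataloader')  # runs only on first access
        return ['batch_0', 'batch_1']     # stand-in for a real DataLoader


model = DummyModule()
model.tng_dataloader  # prints once, caches the result as model._lazy_tng_dataloader
model.tng_dataloader  # served from the cache; the function body does not run again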