diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 2b1114a61b..897bef1d11 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -195,6 +195,7 @@ class Trainer(TrainerIO): # ----------------------------- def fit(self, model): self.model = model + model.trainer = self # transfer data loaders from model self.__get_dataloaders(model) diff --git a/pytorch_lightning/root_module/root_module.py b/pytorch_lightning/root_module/root_module.py index ab49a9a1af..319af7d6ab 100644 --- a/pytorch_lightning/root_module/root_module.py +++ b/pytorch_lightning/root_module/root_module.py @@ -24,6 +24,7 @@ class RootModule(GradInformation, ModelIO, OptimizerConfig, ModelHooks): self.overfit = hparams.overfit self.gradient_clip = hparams.gradient_clip self.num = 2 + self.trainer = None # track if gpu was requested for checkpointing self.on_gpu = False diff --git a/setup.py b/setup.py index 4ae0236c04..1d049abec2 100755 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ from setuptools import setup, find_packages # http://blog.ionelmc.ro/2014/05/25/python-packaging/ setup( name="pytorch-lightning", - version='0.1.dev13', + version='0.1.dev14', description="The Keras for ML researchers using PyTorch", author="William Falcon", author_email="waf2107@columbia.edu",