pointer to trainer in model

This commit is contained in:
William Falcon 2019-04-23 07:25:09 -04:00
parent b625b293f4
commit 676d76d839
3 changed files with 3 additions and 1 deletions

View File

@ -195,6 +195,7 @@ class Trainer(TrainerIO):
# -----------------------------
def fit(self, model):
self.model = model
model.trainer = self
# transfer data loaders from model
self.__get_dataloaders(model)

View File

@ -24,6 +24,7 @@ class RootModule(GradInformation, ModelIO, OptimizerConfig, ModelHooks):
self.overfit = hparams.overfit
self.gradient_clip = hparams.gradient_clip
self.num = 2
self.trainer = None
# track if gpu was requested for checkpointing
self.on_gpu = False

View File

@ -7,7 +7,7 @@ from setuptools import setup, find_packages
# http://blog.ionelmc.ro/2014/05/25/python-packaging/
setup(
name="pytorch-lightning",
version='0.1.dev13',
version='0.1.dev14',
description="The Keras for ML researchers using PyTorch",
author="William Falcon",
author_email="waf2107@columbia.edu",