pointer to trainer in model
This commit is contained in:
parent
b625b293f4
commit
676d76d839
|
@ -195,6 +195,7 @@ class Trainer(TrainerIO):
|
|||
# -----------------------------
|
||||
def fit(self, model):
|
||||
self.model = model
|
||||
model.trainer = self
|
||||
|
||||
# transfer data loaders from model
|
||||
self.__get_dataloaders(model)
|
||||
|
|
|
@ -24,6 +24,7 @@ class RootModule(GradInformation, ModelIO, OptimizerConfig, ModelHooks):
|
|||
self.overfit = hparams.overfit
|
||||
self.gradient_clip = hparams.gradient_clip
|
||||
self.num = 2
|
||||
self.trainer = None
|
||||
|
||||
# track if gpu was requested for checkpointing
|
||||
self.on_gpu = False
|
||||
|
|
2
setup.py
2
setup.py
|
@ -7,7 +7,7 @@ from setuptools import setup, find_packages
|
|||
# http://blog.ionelmc.ro/2014/05/25/python-packaging/
|
||||
setup(
|
||||
name="pytorch-lightning",
|
||||
version='0.1.dev13',
|
||||
version='0.1.dev14',
|
||||
description="The Keras for ML researchers using PyTorch",
|
||||
author="William Falcon",
|
||||
author_email="waf2107@columbia.edu",
|
||||
|
|
Loading…
Reference in New Issue