From 49767e424f2a2ba2893d27ed0c757815e355c86e Mon Sep 17 00:00:00 2001 From: s-rog <55400948+s-rog@users.noreply.github.com> Date: Wed, 23 Sep 2020 16:38:33 +0800 Subject: [PATCH] fix on_fit_start (#3616) * init * fix call_hook args --- pytorch_lightning/trainer/callback_hook.py | 4 ++-- pytorch_lightning/trainer/trainer.py | 7 ++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 27539dbd1f..9c06d87708 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -46,10 +46,10 @@ class TrainerCallbackHookMixin(ABC): for callback in self.callbacks: callback.on_init_end(self) - def on_fit_start(self, model): + def on_fit_start(self): """Called when the trainer initialization begins, model has not yet been set.""" for callback in self.callbacks: - callback.on_fit_start(self, model) + callback.on_fit_start(self, self.get_model()) def on_fit_end(self): """Called when the trainer initialization begins, model has not yet been set.""" diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 245b240525..9f89c5383c 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -284,9 +284,6 @@ class Trainer( # setup data, etc... self.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule) - # hook - self.call_hook('on_fit_start', model) - # hook self.data_connector.prepare_data(model) @@ -298,6 +295,10 @@ class Trainer( # ------------------------- self.accelerator_backend = self.accelerator_connector.select_accelerator() self.accelerator_backend.setup(model) + + # hook + self.call_hook('on_fit_start') + results = self.accelerator_backend.train() self.accelerator_backend.teardown()