diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 27539dbd1f..9c06d87708 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -46,10 +46,10 @@ class TrainerCallbackHookMixin(ABC): for callback in self.callbacks: callback.on_init_end(self) - def on_fit_start(self, model): + def on_fit_start(self): """Called when the trainer initialization begins, model has not yet been set.""" for callback in self.callbacks: - callback.on_fit_start(self, model) + callback.on_fit_start(self, self.get_model()) def on_fit_end(self): """Called when the trainer initialization begins, model has not yet been set.""" diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 245b240525..9f89c5383c 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -284,9 +284,6 @@ class Trainer( # setup data, etc... self.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule) - # hook - self.call_hook('on_fit_start', model) - # hook self.data_connector.prepare_data(model) @@ -298,6 +295,10 @@ class Trainer( # ------------------------- self.accelerator_backend = self.accelerator_connector.select_accelerator() self.accelerator_backend.setup(model) + + # hook + self.call_hook('on_fit_start') + results = self.accelerator_backend.train() self.accelerator_backend.teardown()