fix on_fit_start (#3616)

* init

* fix call_hook args
This commit is contained in:
s-rog 2020-09-23 16:38:33 +08:00 committed by GitHub
parent 12184854f9
commit 49767e424f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 5 deletions

View File

@ -46,10 +46,10 @@ class TrainerCallbackHookMixin(ABC):
for callback in self.callbacks:
callback.on_init_end(self)
def on_fit_start(self, model):
def on_fit_start(self):
"""Called when the trainer initialization begins, model has not yet been set."""
for callback in self.callbacks:
callback.on_fit_start(self, model)
callback.on_fit_start(self, self.get_model())
def on_fit_end(self):
"""Called when the trainer initialization begins, model has not yet been set."""

View File

@ -284,9 +284,6 @@ class Trainer(
# setup data, etc...
self.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule)
# hook
self.call_hook('on_fit_start', model)
# hook
self.data_connector.prepare_data(model)
@ -298,6 +295,10 @@ class Trainer(
# -------------------------
self.accelerator_backend = self.accelerator_connector.select_accelerator()
self.accelerator_backend.setup(model)
# hook
self.call_hook('on_fit_start')
results = self.accelerator_backend.train()
self.accelerator_backend.teardown()