parent
12184854f9
commit
49767e424f
|
@ -46,10 +46,10 @@ class TrainerCallbackHookMixin(ABC):
|
|||
for callback in self.callbacks:
|
||||
callback.on_init_end(self)
|
||||
|
||||
def on_fit_start(self, model):
|
||||
def on_fit_start(self):
|
||||
"""Called when the trainer initialization begins, model has not yet been set."""
|
||||
for callback in self.callbacks:
|
||||
callback.on_fit_start(self, model)
|
||||
callback.on_fit_start(self, self.get_model())
|
||||
|
||||
def on_fit_end(self):
|
||||
"""Called when the trainer initialization begins, model has not yet been set."""
|
||||
|
|
|
@ -284,9 +284,6 @@ class Trainer(
|
|||
# setup data, etc...
|
||||
self.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule)
|
||||
|
||||
# hook
|
||||
self.call_hook('on_fit_start', model)
|
||||
|
||||
# hook
|
||||
self.data_connector.prepare_data(model)
|
||||
|
||||
|
@ -298,6 +295,10 @@ class Trainer(
|
|||
# -------------------------
|
||||
self.accelerator_backend = self.accelerator_connector.select_accelerator()
|
||||
self.accelerator_backend.setup(model)
|
||||
|
||||
# hook
|
||||
self.call_hook('on_fit_start')
|
||||
|
||||
results = self.accelerator_backend.train()
|
||||
self.accelerator_backend.teardown()
|
||||
|
||||
|
|
Loading…
Reference in New Issue