Always use `trainer.call_hook` (#8498)
parent ad3f183bc3
commit e1442d247e
@@ -89,6 +89,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- The accelerator and training type plugin `setup` hooks no longer have a `model` argument ([#8536](https://github.com/PyTorchLightning/pytorch-lightning/pull/8536))

- Improve coverage of `self.log`-ing in any `LightningModule` or `Callback` hook ([#8498](https://github.com/PyTorchLightning/pytorch-lightning/pull/8498))

- Removed restrictions in the trainer that loggers can only log from rank 0. Existing logger behavior has not changed. ([#8608](https://github.com/PyTorchLightning/pytorch-lightning/pull/8608))
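For context on the `self.log` coverage entry above, here is a minimal sketch of what it enables; the callback class and metric name are illustrative and not part of this commit:

    from pytorch_lightning import Callback

    class BatchLoggingCallback(Callback):
        """Hypothetical callback: logging from a Callback hook is validated by FxValidator."""

        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
            # logging is routed through the attached LightningModule
            pl_module.log("train_batch_idx", float(batch_idx))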
@@ -402,9 +402,22 @@ class LightningModule(
         on_step = self.__auto_choose_log_on_step(on_step)
         on_epoch = self.__auto_choose_log_on_epoch(on_epoch)

+        if self.trainer is None:
+            raise MisconfigurationException(
+                "You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet."
+                " This is most likely because the model hasn't been passed to the `Trainer`"
+            )
         results = self.trainer._results
-        assert results is not None
-        assert self._current_fx_name is not None
+        if results is None:
+            raise MisconfigurationException(
+                "You are trying to `self.log()` but the loop `ResultCollection` is not registered"
+                " yet. This is most likely because you are trying to log in a `predict` hook,"
+                " but it doesn't support logging"
+            )
+        if self._current_fx_name is None:
+            raise MisconfigurationException(
+                "You are trying to `self.log()` but it is not managed by the `Trainer` control flow"
+            )
         FxValidator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch)

         # make sure user doesn't introduce logic for multi-dataloaders
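As a quick illustration of the guard clauses added above (mirroring the `model.log("foo", 1)` check in the test added at the bottom of this diff):

    from pytorch_lightning import LightningModule
    from pytorch_lightning.utilities.exceptions import MisconfigurationException

    model = LightningModule()
    try:
        model.log("foo", 1)  # no Trainer attached, so `self.trainer` is None
    except MisconfigurationException as err:
        print(err)  # "... the `self.trainer` reference is not registered on the model yet. ..."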
@@ -179,11 +179,10 @@ class EvaluationLoop(DataLoaderLoop):

     def on_evaluation_model_eval(self) -> None:
         """Sets model to eval mode"""
-        model_ref = self.trainer.lightning_module
         if self.trainer.testing:
-            model_ref.on_test_model_eval()
+            self.trainer.call_hook("on_test_model_eval")
         else:
-            model_ref.on_validation_model_eval()
+            self.trainer.call_hook("on_validation_model_eval")

     def on_evaluation_model_train(self) -> None:
         """Sets model to train mode"""
@@ -139,7 +139,7 @@ class CallbackConnector:
         In addition, all :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks
         will be pushed to the end of the list, ensuring they run last.
         """
-        model_callbacks = self.trainer.lightning_module.configure_callbacks()
+        model_callbacks = self.trainer.call_hook("configure_callbacks")
         if not model_callbacks:
             return
         model_callback_types = {type(c) for c in model_callbacks}
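For reference, a minimal sketch of the model-side hook that is now reached through `call_hook("configure_callbacks")`; the model class and callback choice are illustrative:

    from pytorch_lightning import LightningModule
    from pytorch_lightning.callbacks import EarlyStopping

    class MyModel(LightningModule):
        def configure_callbacks(self):
            # callbacks returned here are merged with the Trainer's;
            # ModelCheckpoint instances are pushed to the end of the list
            return [EarlyStopping(monitor="val_loss")]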
@@ -73,7 +73,7 @@ class DataConnector:
         if self.can_prepare_data():
             if self.trainer.datamodule is not None:
                 self.trainer.datamodule.prepare_data()
-            self.trainer.lightning_module.prepare_data()
+            self.trainer.call_hook("prepare_data")
             self.trainer._is_data_prepared = True

     def can_prepare_data(self):
@@ -77,12 +77,17 @@ class FxValidator:
        training_epoch_end=dict(on_step=(False,), on_epoch=(True,)),
        validation_epoch_end=dict(on_step=(False,), on_epoch=(True,)),
        test_epoch_end=dict(on_step=(False,), on_epoch=(True,)),
        on_before_batch_transfer=None,
        transfer_batch_to_device=None,
        on_after_batch_transfer=None,
        backward=None,
        optimizer_step=None,
        # TODO(@carmocca): some {step,epoch}_{start,end} are missing
        configure_optimizers=None,
        on_train_dataloader=None,
        train_dataloader=None,
        on_val_dataloader=None,
        val_dataloader=None,
        on_test_dataloader=None,
        test_dataloader=None,
        prepare_data=None,
        configure_callbacks=None,
        on_validation_model_eval=None,
        on_test_model_eval=None,
    )

    @classmethod
@@ -90,12 +95,12 @@ class FxValidator:
         """Check if the given function name is allowed to log"""
         if fx_name not in cls.functions:
             raise RuntimeError(
-                f"You are trying to `self.log()` inside `{fx_name}` but it is not implemented."
+                f"Logging inside `{fx_name}` is not implemented."
                 " Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`"
             )
         allowed = cls.functions[fx_name]
         if allowed is None:
-            raise MisconfigurationException(f"{fx_name} function doesn't support logging using `self.log()`")
+            raise MisconfigurationException(f"You can't `self.log()` inside `{fx_name}`")

         m = "You can't `self.log({}={})` inside `{}`, must be one of {}"
         if on_step not in allowed["on_step"]:
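A small sketch of the validator's behavior after this change, assuming `FxValidator` is importable from `pytorch_lightning.trainer.connectors.logger_connector.fx_validator` (the call pattern mirrors `test_fx_validator` further down):

    from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator
    from pytorch_lightning.utilities.exceptions import MisconfigurationException

    validator = FxValidator()
    validator.check_logging(fx_name="training_step", on_step=True, on_epoch=True)  # allowed
    try:
        # `prepare_data` maps to None in the table above, so logging there is rejected
        validator.check_logging(fx_name="prepare_data", on_step=False, on_epoch=False)
    except MisconfigurationException as err:
        print(err)  # "You can't `self.log()` inside `prepare_data`"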
@@ -523,8 +523,9 @@ class TrainerDataLoadingMixin(ABC):
         Returns:
             The dataloader
         """
-        self.call_hook(f"on_{stage.dataloader_prefix}_dataloader")
-        dataloader = getattr(model, f"{stage.dataloader_prefix}_dataloader")()
+        hook = f"{stage.dataloader_prefix}_dataloader"
+        self.call_hook("on_" + hook, pl_module=model)
+        dataloader = self.call_hook(hook, pl_module=model)
         if isinstance(dataloader, tuple):
             dataloader = list(dataloader)
         self.accelerator.barrier("get_dataloaders")
@@ -29,9 +29,10 @@ class TrainerOptimizersMixin(ABC):

     _lightning_optimizers: Optional[List[LightningOptimizer]]

-    def init_optimizers(self, model: "pl.LightningModule") -> Tuple[List, List, List]:
+    def init_optimizers(self, model: Optional["pl.LightningModule"]) -> Tuple[List, List, List]:
+        pl_module = self.lightning_module or model
         self._lightning_optimizers = None
-        optim_conf = model.configure_optimizers()
+        optim_conf = self.call_hook("configure_optimizers", pl_module=pl_module)
         if optim_conf is None:
             rank_zero_warn(
                 "`LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer",
@@ -95,7 +96,7 @@ class TrainerOptimizersMixin(ABC):
             ' * A list of the previously described dict format, with an optional "frequency" key (int)'
         )

-        is_manual_optimization = not model.automatic_optimization
+        is_manual_optimization = not pl_module.automatic_optimization
         lr_schedulers = self.configure_schedulers(lr_schedulers, monitor, is_manual_optimization)
         _validate_scheduler_optimizer(optimizers, lr_schedulers)
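`init_optimizers` now fetches the optimizer configuration through `call_hook("configure_optimizers", pl_module=pl_module)`. A minimal model-side hook it would pick up (illustrative only):

    import torch
    from pytorch_lightning import LightningModule

    class MyModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def configure_optimizers(self):
            # returning a single optimizer is one of the accepted formats described in the error message above
            return torch.optim.SGD(self.parameters(), lr=0.01)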
@@ -1103,20 +1103,14 @@ class Trainer(
         # --------------------------
         # Pre-train
         # --------------------------
-        # on pretrain routine start
-        ref_model = self.lightning_module
-
-        self.on_pretrain_routine_start()
-        ref_model.on_pretrain_routine_start()
+        self.call_hook("on_pretrain_routine_start")

         # print model summary
         if self.is_global_zero and self.weights_summary is not None and not self.testing:
             max_depth = ModelSummary.MODES[self.weights_summary]
-            summarize(ref_model, max_depth=max_depth)
+            summarize(self.lightning_module, max_depth=max_depth)

-        # on pretrain routine end
-        self.on_pretrain_routine_end()
-        ref_model.on_pretrain_routine_end()
+        self.call_hook("on_pretrain_routine_end")

     def _run_train(self) -> None:
         self._pre_training_routine()
@@ -1179,8 +1173,7 @@ class Trainer(
         stage = self.state.stage
         self.sanity_checking = True

-        # hook and callback
-        self.on_sanity_check_start()
+        self.call_hook("on_sanity_check_start")

         # reload dataloaders
         self._evaluation_loop.reload_evaluation_dataloaders()
@@ -1189,7 +1182,7 @@ class Trainer(
         with torch.no_grad():
             self._evaluation_loop.run()

-        self.on_sanity_check_end()
+        self.call_hook("on_sanity_check_end")

         # reset validation metrics
         self.logger_connector.reset()
@@ -1245,8 +1238,7 @@ class Trainer(

         if self.datamodule is not None:
             self.datamodule.setup(stage=fn)
-        self.setup(stage=fn)
-        self.lightning_module.setup(stage=fn)
+        self.call_hook("setup", stage=fn)

         self.accelerator.barrier("post_setup")
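For reference, `call_hook("setup", stage=fn)` now runs the trainer-level hook (callbacks) and the model's `setup` hook in a single pass. A minimal model-side override (illustrative; note that `self.log` is not allowed here, since `setup` maps to None in `FxValidator.functions`):

    from pytorch_lightning import LightningModule

    class MyModel(LightningModule):
        def setup(self, stage=None):
            # `stage` is the running phase, e.g. "fit", "validate", "test" or "predict"
            self.prepared_stage = stage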
@@ -1259,8 +1251,8 @@ class Trainer(
         model_call_configure_sharded_model_hook = getattr(model, "call_configure_sharded_model_hook", False)
         if self.accelerator.call_configure_sharded_model_hook and not model_call_configure_sharded_model_hook:
             with self.accelerator.model_sharded_context():
-                model.configure_sharded_model()
-            self.on_configure_sharded_model()
+                self.call_hook("configure_sharded_model")
+            self.call_hook("on_configure_sharded_model")
             model.call_configure_sharded_model_hook = True
             self.accelerator.call_configure_sharded_model_hook = False
@@ -1272,8 +1264,7 @@ class Trainer(

         self.data_connector.detach_data(self.lightning_module)

-        self.teardown(stage=fn)
-        self.lightning_module.teardown(stage=fn)
+        self.call_hook("teardown", stage=fn)

         self.lightning_module._current_fx_name = None
         self.lightning_module._current_dataloader_idx = None
@@ -1288,28 +1279,30 @@ class Trainer(
         # summarize profile results
         self.profiler.describe()

-    def call_hook(self, hook_name: str, *args, **kwargs) -> Any:
-        if self.lightning_module:
-            prev_fx_name = self.lightning_module._current_fx_name
-            self.lightning_module._current_fx_name = hook_name
+    def call_hook(
+        self, hook_name: str, *args: Any, pl_module: Optional["pl.LightningModule"] = None, **kwargs: Any
+    ) -> Any:
+        pl_module = self.lightning_module or pl_module
+        if pl_module:
+            prev_fx_name = pl_module._current_fx_name
+            pl_module._current_fx_name = hook_name

         # always profile hooks
         with self.profiler.profile(hook_name):

             # first call trainer hook
-            if hasattr(self, hook_name):
-                trainer_hook = getattr(self, hook_name)
-                trainer_hook(*args, **kwargs)
+            callback_fx = getattr(self, hook_name, None)
+            if callable(callback_fx):
+                callback_fx(*args, **kwargs)

             # next call hook in lightningModule
             output = None
-            model_ref = self.lightning_module
-            if is_overridden(hook_name, model_ref):
-                hook_fx = getattr(model_ref, hook_name)
-                output = hook_fx(*args, **kwargs)
+            model_fx = getattr(pl_module, hook_name, None)
+            if callable(model_fx):
+                output = model_fx(*args, **kwargs)

             # call the accelerator hook
-            if hasattr(self.accelerator, hook_name):
+            if hook_name not in ("setup", "teardown") and hasattr(self.accelerator, hook_name):
                 accelerator_hook = getattr(self.accelerator, hook_name)
                 accelerator_output = accelerator_hook(*args, **kwargs)
                 # Rely on the accelerator output if lightningModule hook returns nothing

@@ -1317,9 +1310,9 @@ class Trainer(
                 # todo: move this data parallel logic into the data parallel plugin
                 output = accelerator_output if output is None else output

-        if self.lightning_module:
+        if pl_module:
             # restore current_fx when nested context
-            self.lightning_module._current_fx_name = prev_fx_name
+            pl_module._current_fx_name = prev_fx_name

         return output
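To summarize the new signature, the call sites changed in this commit use it in two ways (a descriptive sketch, with `self` standing for the `Trainer` as in the hunks above):

    # plain call: runs the trainer-level hook (callbacks), then the LightningModule hook, then the accelerator hook
    self.call_hook("on_pretrain_routine_start")

    # explicit module: used by connectors that may run before `trainer.lightning_module` is attached;
    # the module hook's return value is passed back to the caller
    optim_conf = self.call_hook("configure_optimizers", pl_module=pl_module)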
@@ -34,11 +34,8 @@ from tests.helpers.utils import reset_seed
@mock.patch("pytorch_lightning.trainer.trainer.Trainer.node_rank", new_callable=PropertyMock)
@mock.patch("pytorch_lightning.trainer.trainer.Trainer.local_rank", new_callable=PropertyMock)
def test_can_prepare_data(local_rank, node_rank):

    model = BoringModel()
    dm = BoringDataModule()
    trainer = Trainer()
    trainer.model = model
    trainer.datamodule = dm

    # 1 no DM
@@ -12,11 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from unittest.mock import Mock

 import torch

-from pytorch_lightning import Callback, Trainer
+from pytorch_lightning import Callback, LightningModule, Trainer
 from pytorch_lightning.callbacks import (
     EarlyStopping,
     GradientAccumulationScheduler,
@@ -36,18 +35,22 @@ def test_checkpoint_callbacks_are_last(tmpdir):
     lr_monitor = LearningRateMonitor()
     progress_bar = ProgressBar()

-    # no model callbacks
-    model = Mock()
-    model.configure_callbacks.return_value = []
+    # no model reference
     trainer = Trainer(callbacks=[checkpoint1, progress_bar, lr_monitor, checkpoint2])
-    trainer.model = model
     cb_connector = CallbackConnector(trainer)
     cb_connector._attach_model_callbacks()
     assert trainer.callbacks == [progress_bar, lr_monitor, checkpoint1, checkpoint2]

+    # no model callbacks
+    model = LightningModule()
+    model.configure_callbacks = lambda: []
+    trainer.model = model
+    cb_connector._attach_model_callbacks()
+    assert trainer.callbacks == [progress_bar, lr_monitor, checkpoint1, checkpoint2]
+
     # with model-specific callbacks that substitute ones in Trainer
-    model = Mock()
-    model.configure_callbacks.return_value = [checkpoint1, early_stopping, checkpoint2]
+    model = LightningModule()
+    model.configure_callbacks = lambda: [checkpoint1, early_stopping, checkpoint2]
     trainer = Trainer(callbacks=[progress_bar, lr_monitor, ModelCheckpoint(tmpdir)])
     trainer.model = model
     cb_connector = CallbackConnector(trainer)
@@ -89,8 +92,8 @@ def test_attach_model_callbacks():
     """Test that the callbacks defined in the model and through Trainer get merged correctly."""

     def assert_composition(trainer_callbacks, model_callbacks, expected):
-        model = Mock()
-        model.configure_callbacks.return_value = model_callbacks
+        model = LightningModule()
+        model.configure_callbacks = lambda: model_callbacks
         trainer = Trainer(checkpoint_callback=False, progress_bar_refresh_rate=0, callbacks=trainer_callbacks)
         trainer.model = model
         cb_connector = CallbackConnector(trainer)
@@ -140,8 +143,8 @@ def test_attach_model_callbacks():

 def test_attach_model_callbacks_override_info(caplog):
     """Test that the logs contain the info about overriding callbacks returned by configure_callbacks."""
-    model = Mock()
-    model.configure_callbacks.return_value = [LearningRateMonitor(), EarlyStopping()]
+    model = LightningModule()
+    model.configure_callbacks = lambda: [LearningRateMonitor(), EarlyStopping()]
     trainer = Trainer(checkpoint_callback=False, callbacks=[EarlyStopping(), LearningRateMonitor(), ProgressBar()])
     trainer.model = model
     cb_connector = CallbackConnector(trainer)
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 from unittest import mock

 import pytest
@@ -140,17 +141,145 @@ def test_fx_validator(tmpdir):
         if allowed:
             validator.check_logging(fx_name=func_name, on_step=on_step, on_epoch=on_epoch)
             if not is_start and is_stage:
-                with pytest.raises(MisconfigurationException, match="You can't"):
+                with pytest.raises(MisconfigurationException, match="must be one of"):
                     validator.check_logging(fx_name=func_name, on_step=True, on_epoch=on_epoch)
         else:
             assert func_name in not_supported
-            with pytest.raises(MisconfigurationException, match="function doesn't support"):
+            with pytest.raises(MisconfigurationException, match="You can't"):
                 validator.check_logging(fx_name=func_name, on_step=on_step, on_epoch=on_epoch)

-    with pytest.raises(RuntimeError, match="`foo` but it is not implemented"):
+    with pytest.raises(RuntimeError, match="Logging inside `foo` is not implemented"):
         validator.check_logging("foo", False, False)


+class HookedCallback(Callback):
+    def __init__(self, not_supported):
+        def call(hook, trainer, model=None, *_, **__):
+            lightning_module = trainer.lightning_module or model
+            if lightning_module is None:
+                # `on_init_{start,end}` do not have the `LightningModule` available
+                assert hook in ("on_init_start", "on_init_end")
+                return
+
+            if hook in not_supported:
+                with pytest.raises(MisconfigurationException, match=not_supported[hook]):
+                    lightning_module.log("anything", 1)
+            else:
+                lightning_module.log(hook, 1)
+
+        for h in get_members(Callback):
+            setattr(self, h, partial(call, h))
+
+
+class HookedModel(BoringModel):
+    def __init__(self, not_supported):
+        super().__init__()
+        pl_module_hooks = get_members(LightningModule)
+        pl_module_hooks.difference_update(
+            {
+                "log",
+                "log_dict",
+                # the following are problematic as they do have `self._current_fx_name` defined some times but
+                # not others depending on where they were called. So we cannot reliably `self.log` in them
+                "on_before_batch_transfer",
+                "transfer_batch_to_device",
+                "on_after_batch_transfer",
+                "get_progress_bar_dict",
+            }
+        )
+        # remove `nn.Module` hooks
+        module_hooks = get_members(torch.nn.Module)
+        pl_module_hooks.difference_update(module_hooks)
+
+        def call(hook, fn, *args, **kwargs):
+            out = fn(*args, **kwargs)
+
+            if hook in not_supported:
+                with pytest.raises(MisconfigurationException, match=not_supported[hook]):
+                    self.log("anything", 1)
+            else:
+                self.log(hook, 1)
+            return out
+
+        for h in pl_module_hooks:
+            attr = getattr(self, h)
+            setattr(self, h, partial(call, h, attr))
+
+
+def test_fx_validator_integration(tmpdir):
+    """Tries to log inside all `LightningModule` and `Callback` hooks to check any expected errors"""
+    not_supported = {
+        None: "`self.trainer` reference is not registered",
+        "on_before_accelerator_backend_setup": "You can't",
+        "setup": "You can't",
+        "configure_sharded_model": "You can't",
+        "on_configure_sharded_model": "You can't",
+        "configure_optimizers": "You can't",
+        "on_fit_start": "You can't",
+        "on_pretrain_routine_start": "You can't",
+        "on_pretrain_routine_end": "You can't",
+        "on_train_dataloader": "You can't",
+        "train_dataloader": "You can't",
+        "on_val_dataloader": "You can't",
+        "val_dataloader": "You can't",
+        "on_validation_end": "You can't",
+        "on_train_end": "You can't",
+        "on_fit_end": "You can't",
+        "teardown": "You can't",
+        "on_sanity_check_start": "You can't",
+        "on_sanity_check_end": "You can't",
+        "prepare_data": "You can't",
+        "configure_callbacks": "You can't",
+        "on_validation_model_eval": "You can't",
+        "summarize": "not managed by the `Trainer",
+    }
+    model = HookedModel(not_supported)
+
+    with pytest.raises(MisconfigurationException, match=not_supported[None]):
+        model.log("foo", 1)
+
+    callback = HookedCallback(not_supported)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=2,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        limit_test_batches=1,
+        limit_predict_batches=1,
+        callbacks=callback,
+    )
+    trainer.fit(model)
+
+    not_supported.update(
+        {
+            # `lightning_module` ref is now present from the `fit` call
+            "on_before_accelerator_backend_setup": "You can't",
+            "on_test_dataloader": "You can't",
+            "test_dataloader": "You can't",
+            "on_test_model_eval": "You can't",
+            "on_test_end": "You can't",
+        }
+    )
+    trainer.test(model, verbose=False)
+
+    not_supported.update({k: "ResultCollection` is not registered yet" for k in not_supported})
+    not_supported.update(
+        {
+            "on_predict_dataloader": "ResultCollection` is not registered yet",
+            "predict_dataloader": "ResultCollection` is not registered yet",
+            "on_predict_model_eval": "ResultCollection` is not registered yet",
+            "on_predict_start": "ResultCollection` is not registered yet",
+            "on_predict_epoch_start": "ResultCollection` is not registered yet",
+            "on_predict_batch_start": "ResultCollection` is not registered yet",
+            "predict_step": "ResultCollection` is not registered yet",
+            "on_predict_batch_end": "ResultCollection` is not registered yet",
+            "on_predict_epoch_end": "ResultCollection` is not registered yet",
+            "on_predict_end": "ResultCollection` is not registered yet",
+        }
+    )
+    trainer.predict(model)
+
+
 @RunIf(min_gpus=2)
 def test_epoch_results_cache_dp(tmpdir):