Avoid `inference_mode` with FSDP (#17064)
This commit is contained in:
parent
8434ee7402
commit
67b94ef124
|
@@ -13,19 +13,21 @@
|
|||
# limitations under the License.
|
||||
import inspect
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Generator, Optional, Tuple
|
||||
from typing import Any, Callable, ContextManager, Generator, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
|
||||
import lightning.pytorch as pl
|
||||
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_13
|
||||
from lightning.fabric.utilities.warnings import PossibleUserWarning
|
||||
from lightning.pytorch.accelerators import TPUAccelerator
|
||||
from lightning.pytorch.callbacks.timer import Timer
|
||||
from lightning.pytorch.loops import _Loop
|
||||
from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher, _PrefetchDataFetcher
|
||||
from lightning.pytorch.loops.progress import _BaseProgress
|
||||
from lightning.pytorch.strategies import FSDPStrategy
|
||||
from lightning.pytorch.strategies.parallel import ParallelStrategy
|
||||
from lightning.pytorch.strategies.strategy import Strategy
|
||||
from lightning.pytorch.trainer.states import RunningStage
|
||||
|
@@ -153,16 +155,21 @@ def _no_grad_context(loop_run: Callable) -> Callable:
|
|||
raise TypeError(f"`{type(self).__name__}` needs to be a Loop.")
|
||||
if not hasattr(self, "inference_mode"):
|
||||
raise TypeError(f"`{type(self).__name__}.inference_mode` needs to be defined")
|
||||
context_manager = (
|
||||
torch.inference_mode
|
||||
if (
|
||||
self.inference_mode
|
||||
# inference mode is not supported with gloo backend (#9431) and TPU accelerators.
|
||||
and not (dist.is_available() and dist.is_initialized() and dist.get_backend() == "gloo")
|
||||
and not isinstance(self.trainer.accelerator, TPUAccelerator)
|
||||
)
|
||||
else torch.no_grad
|
||||
)
|
||||
context_manager: Type[ContextManager]
|
||||
if dist.is_available() and dist.is_initialized() and dist.get_backend() == "gloo":
|
||||
# gloo backend does not work properly.
|
||||
# https://github.com/Lightning-AI/lightning/pull/12715/files#r854569110
|
||||
# TODO: explore why and possibly open an issue in PyTorch repository
|
||||
context_manager = torch.no_grad
|
||||
elif isinstance(self.trainer.accelerator, TPUAccelerator):
|
||||
context_manager = torch.no_grad
|
||||
elif _TORCH_GREATER_EQUAL_1_13 and isinstance(self.trainer.strategy, FSDPStrategy):
|
||||
# https://github.com/pytorch/pytorch/issues/95957
|
||||
context_manager = torch.no_grad
|
||||
elif self.inference_mode:
|
||||
context_manager = torch.inference_mode
|
||||
else:
|
||||
context_manager = torch.no_grad
|
||||
with context_manager():
|
||||
return loop_run(self, *args, **kwargs)
|
||||
|
||||
|
|
|
@@ -247,7 +247,6 @@ def test_fsdp_checkpoint_multi_gpus(tmpdir, model, strategy):
|
|||
limit_test_batches=2,
|
||||
limit_predict_batches=2,
|
||||
callbacks=[ck],
|
||||
inference_mode=not _TORCH_GREATER_EQUAL_2_0, # TODO(carmocca): inference_mode raises RuntimeError
|
||||
)
|
||||
_run_multiple_stages(trainer, model)
|
||||
|
||||
|
|
Loading…
Reference in New Issue