Avoid `inference_mode` with FSDP (#17064)

This commit is contained in:
Carlos Mocholí 2023-03-14 17:03:12 +01:00 committed by GitHub
parent 8434ee7402
commit 67b94ef124
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 18 additions and 12 deletions

View File

@ -13,19 +13,21 @@
# limitations under the License.
import inspect
from contextlib import contextmanager
from typing import Any, Callable, Generator, Optional, Tuple
from typing import Any, Callable, ContextManager, Generator, Optional, Tuple, Type
import torch
import torch.distributed as dist
from torch import Tensor
import lightning.pytorch as pl
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_13
from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.pytorch.accelerators import TPUAccelerator
from lightning.pytorch.callbacks.timer import Timer
from lightning.pytorch.loops import _Loop
from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher, _PrefetchDataFetcher
from lightning.pytorch.loops.progress import _BaseProgress
from lightning.pytorch.strategies import FSDPStrategy
from lightning.pytorch.strategies.parallel import ParallelStrategy
from lightning.pytorch.strategies.strategy import Strategy
from lightning.pytorch.trainer.states import RunningStage
@ -153,16 +155,21 @@ def _no_grad_context(loop_run: Callable) -> Callable:
raise TypeError(f"`{type(self).__name__}` needs to be a Loop.")
if not hasattr(self, "inference_mode"):
raise TypeError(f"`{type(self).__name__}.inference_mode` needs to be defined")
context_manager = (
torch.inference_mode
if (
self.inference_mode
# inference mode is not supported with gloo backend (#9431) and TPU accelerators.
and not (dist.is_available() and dist.is_initialized() and dist.get_backend() == "gloo")
and not isinstance(self.trainer.accelerator, TPUAccelerator)
)
else torch.no_grad
)
context_manager: Type[ContextManager]
if dist.is_available() and dist.is_initialized() and dist.get_backend() == "gloo":
# `torch.inference_mode` does not work properly with the gloo backend, so fall back to `torch.no_grad`.
# https://github.com/Lightning-AI/lightning/pull/12715/files#r854569110
# TODO: explore why and possibly open an issue in PyTorch repository
context_manager = torch.no_grad
elif isinstance(self.trainer.accelerator, TPUAccelerator):
context_manager = torch.no_grad
elif _TORCH_GREATER_EQUAL_1_13 and isinstance(self.trainer.strategy, FSDPStrategy):
# https://github.com/pytorch/pytorch/issues/95957
context_manager = torch.no_grad
elif self.inference_mode:
context_manager = torch.inference_mode
else:
context_manager = torch.no_grad
with context_manager():
return loop_run(self, *args, **kwargs)

View File

@ -247,7 +247,6 @@ def test_fsdp_checkpoint_multi_gpus(tmpdir, model, strategy):
limit_test_batches=2,
limit_predict_batches=2,
callbacks=[ck],
inference_mode=not _TORCH_GREATER_EQUAL_2_0, # TODO(carmocca): inference_mode raises RuntimeError
)
_run_multiple_stages(trainer, model)