From 85f7e1c9c852e94f7e786751da55ffd7aea3f474 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 3 Feb 2023 00:56:12 +0100 Subject: [PATCH] Show tf32 info only on rank 0 (#16152) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- src/lightning/fabric/accelerators/cuda.py | 6 ++---- tests/tests_fabric/accelerators/test_cuda.py | 6 +++++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/lightning/fabric/accelerators/cuda.py b/src/lightning/fabric/accelerators/cuda.py index 82bd5560b0..90ed59b4e2 100644 --- a/src/lightning/fabric/accelerators/cuda.py +++ b/src/lightning/fabric/accelerators/cuda.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging import os import warnings from contextlib import contextmanager @@ -19,6 +18,7 @@ from functools import lru_cache from typing import Dict, Generator, List, Optional, Set, Union import torch +from lightning_utilities.core.rank_zero import rank_zero_info from lightning.fabric.accelerators.accelerator import Accelerator from lightning.fabric.utilities.imports import ( @@ -27,8 +27,6 @@ from lightning.fabric.utilities.imports import ( _TORCH_GREATER_EQUAL_2_0, ) -_log = logging.getLogger(__name__) - class CUDAAccelerator(Accelerator): """Accelerator for NVIDIA CUDA devices.""" @@ -252,7 +250,7 @@ def _check_cuda_matmul_precision(device: torch.device) -> None: # check that the user hasn't changed the precision already, this works for both `allow_tf32 = True` and # `set_float32_matmul_precision` if torch.get_float32_matmul_precision() == "highest": # default - _log.info( + rank_zero_info( f"You are using a CUDA device ({torch.cuda.get_device_name(device)!r}) that has Tensor Cores. 
To properly" " utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off" " precision for performance. For more details, read https://pytorch.org/docs/stable/generated/" diff --git a/tests/tests_fabric/accelerators/test_cuda.py b/tests/tests_fabric/accelerators/test_cuda.py index 47a7b59765..f93970f816 100644 --- a/tests/tests_fabric/accelerators/test_cuda.py +++ b/tests/tests_fabric/accelerators/test_cuda.py @@ -89,7 +89,11 @@ def test_force_nvml_based_cuda_check(): @RunIf(min_torch="1.12") @mock.patch("torch.cuda.get_device_capability", return_value=(10, 1)) @mock.patch("torch.cuda.get_device_name", return_value="Z100") -def test_tf32_message(_, __, caplog): +def test_tf32_message(_, __, caplog, monkeypatch): + + # for some reason, caplog doesn't work with our rank_zero_info utilities + monkeypatch.setattr(lightning.fabric.accelerators.cuda, "rank_zero_info", logging.info) + device = Mock() expected = "Z100') that has Tensor Cores" assert torch.get_float32_matmul_precision() == "highest" # default in torch