Show tf32 info only on rank 0 (#16152)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
This commit is contained in:
Adrian Wälchli 2023-02-03 00:56:12 +01:00 committed by GitHub
parent 2e9861f10d
commit 85f7e1c9c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 5 deletions

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import warnings
from contextlib import contextmanager
@ -19,6 +18,7 @@ from functools import lru_cache
from typing import Dict, Generator, List, Optional, Set, Union
import torch
from lightning_utilities.core.rank_zero import rank_zero_info
from lightning.fabric.accelerators.accelerator import Accelerator
from lightning.fabric.utilities.imports import (
@ -27,8 +27,6 @@ from lightning.fabric.utilities.imports import (
_TORCH_GREATER_EQUAL_2_0,
)
_log = logging.getLogger(__name__)
class CUDAAccelerator(Accelerator):
"""Accelerator for NVIDIA CUDA devices."""
@ -252,7 +250,7 @@ def _check_cuda_matmul_precision(device: torch.device) -> None:
# check that the user hasn't changed the precision already, this works for both `allow_tf32 = True` and
# `set_float32_matmul_precision`
if torch.get_float32_matmul_precision() == "highest": # default
_log.info(
rank_zero_info(
f"You are using a CUDA device ({torch.cuda.get_device_name(device)!r}) that has Tensor Cores. To properly"
" utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off"
" precision for performance. For more details, read https://pytorch.org/docs/stable/generated/"

View File

@ -89,7 +89,11 @@ def test_force_nvml_based_cuda_check():
@RunIf(min_torch="1.12")
@mock.patch("torch.cuda.get_device_capability", return_value=(10, 1))
@mock.patch("torch.cuda.get_device_name", return_value="Z100")
def test_tf32_message(_, __, caplog):
def test_tf32_message(_, __, caplog, monkeypatch):
# for some reason, caplog doesn't work with our rank_zero_info utilities
monkeypatch.setattr(lightning.fabric.accelerators.cuda, "rank_zero_info", logging.info)
device = Mock()
expected = "Z100') that has Tensor Cores"
assert torch.get_float32_matmul_precision() == "highest" # default in torch