Show tf32 info only on rank 0 (#16152)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
This commit is contained in:
parent
2e9861f10d
commit
85f7e1c9c8
|
@ -11,7 +11,6 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
|
@ -19,6 +18,7 @@ from functools import lru_cache
|
|||
from typing import Dict, Generator, List, Optional, Set, Union
|
||||
|
||||
import torch
|
||||
from lightning_utilities.core.rank_zero import rank_zero_info
|
||||
|
||||
from lightning.fabric.accelerators.accelerator import Accelerator
|
||||
from lightning.fabric.utilities.imports import (
|
||||
|
@ -27,8 +27,6 @@ from lightning.fabric.utilities.imports import (
|
|||
_TORCH_GREATER_EQUAL_2_0,
|
||||
)
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CUDAAccelerator(Accelerator):
|
||||
"""Accelerator for NVIDIA CUDA devices."""
|
||||
|
@ -252,7 +250,7 @@ def _check_cuda_matmul_precision(device: torch.device) -> None:
|
|||
# check that the user hasn't changed the precision already, this works for both `allow_tf32 = True` and
|
||||
# `set_float32_matmul_precision`
|
||||
if torch.get_float32_matmul_precision() == "highest": # default
|
||||
_log.info(
|
||||
rank_zero_info(
|
||||
f"You are using a CUDA device ({torch.cuda.get_device_name(device)!r}) that has Tensor Cores. To properly"
|
||||
" utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off"
|
||||
" precision for performance. For more details, read https://pytorch.org/docs/stable/generated/"
|
||||
|
|
|
@ -89,7 +89,11 @@ def test_force_nvml_based_cuda_check():
|
|||
@RunIf(min_torch="1.12")
|
||||
@mock.patch("torch.cuda.get_device_capability", return_value=(10, 1))
|
||||
@mock.patch("torch.cuda.get_device_name", return_value="Z100")
|
||||
def test_tf32_message(_, __, caplog):
|
||||
def test_tf32_message(_, __, caplog, monkeypatch):
|
||||
|
||||
# for some reason, caplog doesn't work with our rank_zero_info utilities
|
||||
monkeypatch.setattr(lightning.fabric.accelerators.cuda, "rank_zero_info", logging.info)
|
||||
|
||||
device = Mock()
|
||||
expected = "Z100') that has Tensor Cores"
|
||||
assert torch.get_float32_matmul_precision() == "highest" # default in torch
|
||||
|
|
Loading…
Reference in New Issue