From 6599ced17d939900c9cffdbf5a856af74d54f2d5 Mon Sep 17 00:00:00 2001
From: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Date: Mon, 6 Dec 2021 12:01:38 +0530
Subject: [PATCH] Don't import torch_xla.debug for torch-xla<1.8 (#10836)

---
 CHANGELOG.md                        | 3 +++
 pytorch_lightning/profiler/xla.py   | 9 +++++++--
 tests/profiler/test_xla_profiler.py | 6 ++++--
 3 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0a168db7b4..f49707b296 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -225,6 +225,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Disabled batch_size extraction for torchmetric instances because they accumulate the metrics internally ([#10815](https://github.com/PyTorchLightning/pytorch-lightning/pull/10815))
 
 
+- Fixed importing `torch_xla.debug` for `torch-xla<1.8` ([#10836](https://github.com/PyTorchLightning/pytorch-lightning/pull/10836))
+
+
 - Fixed an issue to return the results for each dataloader separately instead of duplicating them for each ([#10810](https://github.com/PyTorchLightning/pytorch-lightning/pull/10810))
 
 
diff --git a/pytorch_lightning/profiler/xla.py b/pytorch_lightning/profiler/xla.py
index e30f06f84e..c89685bcad 100644
--- a/pytorch_lightning/profiler/xla.py
+++ b/pytorch_lightning/profiler/xla.py
@@ -42,9 +42,10 @@ import logging
 from typing import Dict
 
 from pytorch_lightning.profiler.base import BaseProfiler
-from pytorch_lightning.utilities import _TPU_AVAILABLE
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, _TPU_AVAILABLE
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
-if _TPU_AVAILABLE:
+if _TPU_AVAILABLE and _TORCH_GREATER_EQUAL_1_8:
     import torch_xla.debug.profiler as xp
 
 log = logging.getLogger(__name__)
@@ -65,6 +66,10 @@ class XLAProfiler(BaseProfiler):
     def __init__(self, port: int = 9012) -> None:
         """This Profiler will help you debug and optimize training workload performance for your models using Cloud TPU
         performance tools."""
+        if not _TPU_AVAILABLE:
+            raise MisconfigurationException("`XLAProfiler` is only supported on TPUs")
+        if not _TORCH_GREATER_EQUAL_1_8:
+            raise MisconfigurationException("`XLAProfiler` is only supported with `torch-xla >= 1.8`")
         super().__init__(dirpath=None, filename=None)
         self.port = port
         self._recording_map: Dict = {}
diff --git a/tests/profiler/test_xla_profiler.py b/tests/profiler/test_xla_profiler.py
index 2afbf69a6d..7f460ea11d 100644
--- a/tests/profiler/test_xla_profiler.py
+++ b/tests/profiler/test_xla_profiler.py
@@ -18,14 +18,16 @@ import pytest
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.profiler import XLAProfiler
-from pytorch_lightning.utilities import _TPU_AVAILABLE
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, _TPU_AVAILABLE
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
 
 if _TPU_AVAILABLE:
-    import torch_xla.debug.profiler as xp
     import torch_xla.utils.utils as xu
 
+    if _TORCH_GREATER_EQUAL_1_8:
+        import torch_xla.debug.profiler as xp
+
 
 @RunIf(tpu=True)
 def test_xla_profiler_instance(tmpdir):
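
Illustrative note (not part of the patch itself): a minimal sketch of how the new constructor guards behave, assuming a host where `_TPU_AVAILABLE` is False or torch-xla is older than 1.8. With this change the profiler fails fast with a `MisconfigurationException` at construction time instead of the module breaking on the `torch_xla.debug` import.

    # Minimal sketch, assuming an environment without TPUs (or with torch-xla < 1.8).
    from pytorch_lightning.profiler import XLAProfiler
    from pytorch_lightning.utilities.exceptions import MisconfigurationException

    try:
        profiler = XLAProfiler(port=9012)
    except MisconfigurationException as err:
        # e.g. "`XLAProfiler` is only supported on TPUs"
        print(err)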