From f5304897cec970bb1cfaef901bbe12532242f1e3 Mon Sep 17 00:00:00 2001
From: Akash Kwatra <akashkw@gmail.com>
Date: Fri, 25 Feb 2022 07:26:20 -0800
Subject: [PATCH] Deprecate `BaseProfiler.profile_iterable` (#12102)

---
 CHANGELOG.md                            |  3 ++
 pytorch_lightning/profiler/base.py      |  9 +++++
 tests/deprecated_api/test_remove_1-8.py | 44 +++++++++++++++++++++++++
 tests/profiler/test_profiler.py         | 27 ---------------
 4 files changed, 56 insertions(+), 27 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9957fd4ffa..058c92fe20 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -430,6 +430,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `LightningLoggerBase.agg_and_log_metrics` in favor of `LightningLoggerBase.log_metrics` ([#11832](https://github.com/PyTorchLightning/pytorch-lightning/pull/11832))
 
 
+- Deprecated `BaseProfiler.profile_iterable` ([#12102](https://github.com/PyTorchLightning/pytorch-lightning/pull/12102))
+
+
 ### Removed
 
 - Removed deprecated parameter `method` in `pytorch_lightning.utilities.model_helpers.is_overridden` ([#10507](https://github.com/PyTorchLightning/pytorch-lightning/pull/10507))
diff --git a/pytorch_lightning/profiler/base.py b/pytorch_lightning/profiler/base.py
index d453095b24..706e71fa5a 100644
--- a/pytorch_lightning/profiler/base.py
+++ b/pytorch_lightning/profiler/base.py
@@ -20,6 +20,7 @@ from pathlib import Path
 from typing import Any, Callable, Dict, Generator, Iterable, Optional, TextIO, Union
 
 from pytorch_lightning.utilities.cloud_io import get_filesystem
+from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
 
 log = logging.getLogger(__name__)
 
@@ -83,6 +84,14 @@ class BaseProfiler(AbstractProfiler):
             self.stop(action_name)
 
     def profile_iterable(self, iterable: Iterable, action_name: str) -> Generator:
+        """Profiles over each value of an iterable.
+
+        See deprecation message below.
+
+        .. deprecated:: v1.6
+            `BaseProfiler.profile_iterable` is deprecated in v1.6 and will be removed in v1.8.
+        """
+        rank_zero_deprecation("`BaseProfiler.profile_iterable` is deprecated in v1.6 and will be removed in v1.8.")
         iterator = iter(iterable)
         while True:
             try:
diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py
index d51c5638f0..09e82ddbe3 100644
--- a/tests/deprecated_api/test_remove_1-8.py
+++ b/tests/deprecated_api/test_remove_1-8.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Test deprecated functionality which will be removed in v1.8.0."""
+import time
 from unittest.mock import Mock
 
 import numpy as np
@@ -33,6 +34,7 @@ from pytorch_lightning.plugins.training_type.sharded_spawn import DDPSpawnSharde
 from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
 from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
 from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
+from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler
 from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.enums import DeviceType, DistributedType
@@ -608,3 +610,45 @@ def test_v1_8_0_callback_on_pretrain_routine_start_end(tmpdir):
         match="The `Callback.on_pretrain_routine_end` hook has been deprecated in v1.6" " and will be removed in v1.8"
     ):
         trainer.fit(model)
+
+
+@pytest.mark.flaky(reruns=3)
+@pytest.mark.parametrize(["action", "expected"], [("a", [3, 1]), ("b", [2]), ("c", [1])])
+def test_simple_profiler_iterable_durations(tmpdir, action: str, expected: list):
+    """Ensure the reported durations are reasonably accurate."""
+
+    def _sleep_generator(durations):
+        """the profile_iterable method needs an iterable in which we can ensure that we're properly timing how long
+        it takes to call __next__"""
+        for duration in durations:
+            time.sleep(duration)
+            yield duration
+
+    def _get_python_cprofile_total_duration(profile):
+        return sum(x.inlinetime for x in profile.getstats())
+
+    simple_profiler = SimpleProfiler()
+    iterable = _sleep_generator(expected)
+
+    with pytest.deprecated_call(
+        match="`BaseProfiler.profile_iterable` is deprecated in v1.6 and will be removed in v1.8."
+    ):
+        for _ in simple_profiler.profile_iterable(iterable, action):
+            pass
+
+    # we exclude the last item in the recorded durations since that's when StopIteration is raised
+    np.testing.assert_allclose(simple_profiler.recorded_durations[action][:-1], expected, rtol=0.2)
+
+    advanced_profiler = AdvancedProfiler(dirpath=tmpdir, filename="profiler")
+
+    iterable = _sleep_generator(expected)
+
+    with pytest.deprecated_call(
+        match="`BaseProfiler.profile_iterable` is deprecated in v1.6 and will be removed in v1.8."
+    ):
+        for _ in advanced_profiler.profile_iterable(iterable, action):
+            pass
+
+    recorded_total_duration = _get_python_cprofile_total_duration(advanced_profiler.profiled_actions[action])
+    expected_total_duration = np.sum(expected)
+    np.testing.assert_allclose(recorded_total_duration, expected_total_duration, rtol=0.2)
diff --git a/tests/profiler/test_profiler.py b/tests/profiler/test_profiler.py
index 137068626b..f63a31f6d8 100644
--- a/tests/profiler/test_profiler.py
+++ b/tests/profiler/test_profiler.py
@@ -67,19 +67,6 @@ def test_simple_profiler_durations(simple_profiler, action: str, expected: list)
     np.testing.assert_allclose(simple_profiler.recorded_durations[action], expected, rtol=0.2)
 
 
-@pytest.mark.flaky(reruns=3)
-@pytest.mark.parametrize(["action", "expected"], [("a", [3, 1]), ("b", [2]), ("c", [1])])
-def test_simple_profiler_iterable_durations(simple_profiler, action: str, expected: list):
-    """Ensure the reported durations are reasonably accurate."""
-    iterable = _sleep_generator(expected)
-
-    for _ in simple_profiler.profile_iterable(iterable, action):
-        pass
-
-    # we exclude the last item in the recorded durations since that's when StopIteration is raised
-    np.testing.assert_allclose(simple_profiler.recorded_durations[action][:-1], expected, rtol=0.2)
-
-
 def test_simple_profiler_overhead(simple_profiler, n_iter=5):
     """Ensure that the profiler doesn't introduce too much overhead during training."""
     for _ in range(n_iter):
@@ -289,20 +276,6 @@ def test_advanced_profiler_durations(advanced_profiler, action: str, expected: l
     np.testing.assert_allclose(recorded_total_duration, expected_total_duration, rtol=0.2)
 
 
-@pytest.mark.flaky(reruns=3)
-@pytest.mark.parametrize(["action", "expected"], [("a", [3, 1]), ("b", [2]), ("c", [1])])
-def test_advanced_profiler_iterable_durations(advanced_profiler, action: str, expected: list):
-    """Ensure the reported durations are reasonably accurate."""
-    iterable = _sleep_generator(expected)
-
-    for _ in advanced_profiler.profile_iterable(iterable, action):
-        pass
-
-    recorded_total_duration = _get_python_cprofile_total_duration(advanced_profiler.profiled_actions[action])
-    expected_total_duration = np.sum(expected)
-    np.testing.assert_allclose(recorded_total_duration, expected_total_duration, rtol=0.2)
-
-
 @pytest.mark.flaky(reruns=3)
 def test_advanced_profiler_overhead(advanced_profiler, n_iter=5):
     """ensure that the profiler doesn't introduce too much overhead during training."""