diff --git a/CHANGELOG.md b/CHANGELOG.md index 9957fd4ffa..058c92fe20 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -430,6 +430,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `LightningLoggerBase.agg_and_log_metrics` in favor of `LightningLoggerBase.log_metrics` ([#11832](https://github.com/PyTorchLightning/pytorch-lightning/pull/11832)) +- Deprecated `BaseProfiler.profile_iterable` ([#12102](https://github.com/PyTorchLightning/pytorch-lightning/pull/12102)) + + ### Removed - Removed deprecated parameter `method` in `pytorch_lightning.utilities.model_helpers.is_overridden` ([#10507](https://github.com/PyTorchLightning/pytorch-lightning/pull/10507)) diff --git a/pytorch_lightning/profiler/base.py b/pytorch_lightning/profiler/base.py index d453095b24..706e71fa5a 100644 --- a/pytorch_lightning/profiler/base.py +++ b/pytorch_lightning/profiler/base.py @@ -20,6 +20,7 @@ from pathlib import Path from typing import Any, Callable, Dict, Generator, Iterable, Optional, TextIO, Union from pytorch_lightning.utilities.cloud_io import get_filesystem +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation log = logging.getLogger(__name__) @@ -83,6 +84,14 @@ class BaseProfiler(AbstractProfiler): self.stop(action_name) def profile_iterable(self, iterable: Iterable, action_name: str) -> Generator: + """Profiles over each value of an iterable. + + See deprecation message below. + + .. deprecated:: v1.6 + `BaseProfiler.profile_iterable` is deprecated in v1.6 and will be removed in v1.8. + """ + rank_zero_deprecation("`BaseProfiler.profile_iterable` is deprecated in v1.6 and will be removed in v1.8.") iterator = iter(iterable) while True: try: diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py index d51c5638f0..09e82ddbe3 100644 --- a/tests/deprecated_api/test_remove_1-8.py +++ b/tests/deprecated_api/test_remove_1-8.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Test deprecated functionality which will be removed in v1.8.0.""" +import time from unittest.mock import Mock import numpy as np @@ -33,6 +34,7 @@ from pytorch_lightning.plugins.training_type.sharded_spawn import DDPSpawnSharde from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin +from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.apply_func import move_data_to_device from pytorch_lightning.utilities.enums import DeviceType, DistributedType @@ -608,3 +610,45 @@ def test_v1_8_0_callback_on_pretrain_routine_start_end(tmpdir): match="The `Callback.on_pretrain_routine_end` hook has been deprecated in v1.6" " and will be removed in v1.8" ): trainer.fit(model) + + +@pytest.mark.flaky(reruns=3) +@pytest.mark.parametrize(["action", "expected"], [("a", [3, 1]), ("b", [2]), ("c", [1])]) +def test_simple_profiler_iterable_durations(tmpdir, action: str, expected: list): + """Ensure the reported durations are reasonably accurate.""" + + def _sleep_generator(durations): + """the profile_iterable method needs an iterable in which we can ensure that we're properly timing how long + it takes to call __next__""" + for duration in durations: + time.sleep(duration) + yield duration + + def _get_python_cprofile_total_duration(profile): + return sum(x.inlinetime for x in profile.getstats()) + + simple_profiler = SimpleProfiler() + iterable = _sleep_generator(expected) + + with pytest.deprecated_call( + match="`BaseProfiler.profile_iterable` is deprecated in v1.6 and will be removed in v1.8." + ): + for _ in simple_profiler.profile_iterable(iterable, action): + pass + + # we exclude the last item in the recorded durations since that's when StopIteration is raised + np.testing.assert_allclose(simple_profiler.recorded_durations[action][:-1], expected, rtol=0.2) + + advanced_profiler = AdvancedProfiler(dirpath=tmpdir, filename="profiler") + + iterable = _sleep_generator(expected) + + with pytest.deprecated_call( + match="`BaseProfiler.profile_iterable` is deprecated in v1.6 and will be removed in v1.8." + ): + for _ in advanced_profiler.profile_iterable(iterable, action): + pass + + recorded_total_duration = _get_python_cprofile_total_duration(advanced_profiler.profiled_actions[action]) + expected_total_duration = np.sum(expected) + np.testing.assert_allclose(recorded_total_duration, expected_total_duration, rtol=0.2) diff --git a/tests/profiler/test_profiler.py b/tests/profiler/test_profiler.py index 137068626b..f63a31f6d8 100644 --- a/tests/profiler/test_profiler.py +++ b/tests/profiler/test_profiler.py @@ -67,19 +67,6 @@ def test_simple_profiler_durations(simple_profiler, action: str, expected: list) np.testing.assert_allclose(simple_profiler.recorded_durations[action], expected, rtol=0.2) -@pytest.mark.flaky(reruns=3) -@pytest.mark.parametrize(["action", "expected"], [("a", [3, 1]), ("b", [2]), ("c", [1])]) -def test_simple_profiler_iterable_durations(simple_profiler, action: str, expected: list): - """Ensure the reported durations are reasonably accurate.""" - iterable = _sleep_generator(expected) - - for _ in simple_profiler.profile_iterable(iterable, action): - pass - - # we exclude the last item in the recorded durations since that's when StopIteration is raised - np.testing.assert_allclose(simple_profiler.recorded_durations[action][:-1], expected, rtol=0.2) - - def test_simple_profiler_overhead(simple_profiler, n_iter=5): """Ensure that the profiler doesn't introduce too much overhead during training.""" for _ in range(n_iter): @@ -289,20 +276,6 @@ def test_advanced_profiler_durations(advanced_profiler, action: str, expected: l np.testing.assert_allclose(recorded_total_duration, expected_total_duration, rtol=0.2) -@pytest.mark.flaky(reruns=3) -@pytest.mark.parametrize(["action", "expected"], [("a", [3, 1]), ("b", [2]), ("c", [1])]) -def test_advanced_profiler_iterable_durations(advanced_profiler, action: str, expected: list): - """Ensure the reported durations are reasonably accurate.""" - iterable = _sleep_generator(expected) - - for _ in advanced_profiler.profile_iterable(iterable, action): - pass - - recorded_total_duration = _get_python_cprofile_total_duration(advanced_profiler.profiled_actions[action]) - expected_total_duration = np.sum(expected) - np.testing.assert_allclose(recorded_total_duration, expected_total_duration, rtol=0.2) - - @pytest.mark.flaky(reruns=3) def test_advanced_profiler_overhead(advanced_profiler, n_iter=5): """ensure that the profiler doesn't introduce too much overhead during training."""