Deprecate `BaseProfiler.profile_iterable` (#12102)

This commit is contained in:
Akash Kwatra 2022-02-25 07:26:20 -08:00 committed by GitHub
parent 61dd5e4d5e
commit f5304897ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 56 additions and 27 deletions

View File

@ -430,6 +430,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `LightningLoggerBase.agg_and_log_metrics` in favor of `LightningLoggerBase.log_metrics` ([#11832](https://github.com/PyTorchLightning/pytorch-lightning/pull/11832))
- Deprecated `BaseProfiler.profile_iterable` ([#12102](https://github.com/PyTorchLightning/pytorch-lightning/pull/12102))
### Removed
- Removed deprecated parameter `method` in `pytorch_lightning.utilities.model_helpers.is_overridden` ([#10507](https://github.com/PyTorchLightning/pytorch-lightning/pull/10507))

View File

@ -20,6 +20,7 @@ from pathlib import Path
from typing import Any, Callable, Dict, Generator, Iterable, Optional, TextIO, Union
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
log = logging.getLogger(__name__)
@ -83,6 +84,14 @@ class BaseProfiler(AbstractProfiler):
self.stop(action_name)
def profile_iterable(self, iterable: Iterable, action_name: str) -> Generator:
"""Profiles over each value of an iterable.
See deprecation message below.
.. deprecated:: v1.6
`BaseProfiler.profile_iterable` is deprecated in v1.6 and will be removed in v1.8.
"""
rank_zero_deprecation("`BaseProfiler.profile_iterable` is deprecated in v1.6 and will be removed in v1.8.")
iterator = iter(iterable)
while True:
try:

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test deprecated functionality which will be removed in v1.8.0."""
import time
from unittest.mock import Mock
import numpy as np
@ -33,6 +34,7 @@ from pytorch_lightning.plugins.training_type.sharded_spawn import DDPSpawnSharde
from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.enums import DeviceType, DistributedType
@ -608,3 +610,45 @@ def test_v1_8_0_callback_on_pretrain_routine_start_end(tmpdir):
match="The `Callback.on_pretrain_routine_end` hook has been deprecated in v1.6" " and will be removed in v1.8"
):
trainer.fit(model)
@pytest.mark.flaky(reruns=3)
@pytest.mark.parametrize(["action", "expected"], [("a", [3, 1]), ("b", [2]), ("c", [1])])
def test_simple_profiler_iterable_durations(tmpdir, action: str, expected: list):
"""Ensure the reported durations are reasonably accurate."""
def _sleep_generator(durations):
"""the profile_iterable method needs an iterable in which we can ensure that we're properly timing how long
it takes to call __next__"""
for duration in durations:
time.sleep(duration)
yield duration
def _get_python_cprofile_total_duration(profile):
return sum(x.inlinetime for x in profile.getstats())
simple_profiler = SimpleProfiler()
iterable = _sleep_generator(expected)
with pytest.deprecated_call(
match="`BaseProfiler.profile_iterable` is deprecated in v1.6 and will be removed in v1.8."
):
for _ in simple_profiler.profile_iterable(iterable, action):
pass
# we exclude the last item in the recorded durations since that's when StopIteration is raised
np.testing.assert_allclose(simple_profiler.recorded_durations[action][:-1], expected, rtol=0.2)
advanced_profiler = AdvancedProfiler(dirpath=tmpdir, filename="profiler")
iterable = _sleep_generator(expected)
with pytest.deprecated_call(
match="`BaseProfiler.profile_iterable` is deprecated in v1.6 and will be removed in v1.8."
):
for _ in advanced_profiler.profile_iterable(iterable, action):
pass
recorded_total_duration = _get_python_cprofile_total_duration(advanced_profiler.profiled_actions[action])
expected_total_duration = np.sum(expected)
np.testing.assert_allclose(recorded_total_duration, expected_total_duration, rtol=0.2)

View File

@ -67,19 +67,6 @@ def test_simple_profiler_durations(simple_profiler, action: str, expected: list)
np.testing.assert_allclose(simple_profiler.recorded_durations[action], expected, rtol=0.2)
@pytest.mark.flaky(reruns=3)
@pytest.mark.parametrize(["action", "expected"], [("a", [3, 1]), ("b", [2]), ("c", [1])])
def test_simple_profiler_iterable_durations(simple_profiler, action: str, expected: list):
"""Ensure the reported durations are reasonably accurate."""
iterable = _sleep_generator(expected)
for _ in simple_profiler.profile_iterable(iterable, action):
pass
# we exclude the last item in the recorded durations since that's when StopIteration is raised
np.testing.assert_allclose(simple_profiler.recorded_durations[action][:-1], expected, rtol=0.2)
def test_simple_profiler_overhead(simple_profiler, n_iter=5):
"""Ensure that the profiler doesn't introduce too much overhead during training."""
for _ in range(n_iter):
@ -289,20 +276,6 @@ def test_advanced_profiler_durations(advanced_profiler, action: str, expected: l
np.testing.assert_allclose(recorded_total_duration, expected_total_duration, rtol=0.2)
@pytest.mark.flaky(reruns=3)
@pytest.mark.parametrize(["action", "expected"], [("a", [3, 1]), ("b", [2]), ("c", [1])])
def test_advanced_profiler_iterable_durations(advanced_profiler, action: str, expected: list):
"""Ensure the reported durations are reasonably accurate."""
iterable = _sleep_generator(expected)
for _ in advanced_profiler.profile_iterable(iterable, action):
pass
recorded_total_duration = _get_python_cprofile_total_duration(advanced_profiler.profiled_actions[action])
expected_total_duration = np.sum(expected)
np.testing.assert_allclose(recorded_total_duration, expected_total_duration, rtol=0.2)
@pytest.mark.flaky(reruns=3)
def test_advanced_profiler_overhead(advanced_profiler, n_iter=5):
"""ensure that the profiler doesn't introduce too much overhead during training."""