feature: Allow str arguments in Trainer.profiler (#3656)

* allow trainer's profiler param to have a str value

* add tests

* update docs

* update exception message

* Update CHANGELOG

* fix pep8 issues

* cleanup test code

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Add deprecation warning if using bool for profiler

* Add deprecation tests and move deprecated tests

* Remove bool option to profiler from docs

* Deprecate bool args to profiler in CHANGELOG

* fixup! Add deprecation warning if using bool for profiler

* fixup! Add deprecation tests and move deprecated tests

* Apply suggestions from code review

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* Implement suggestions, remove whitespace

* fixup! Implement suggestions, remove whitespace

* Allow bool, str (case insensitive), BaseProfiler

* Add info about bool deprecation to trainer

* fixup! Add info about bool deprecation to trainer

* Move deprecate todo to test_deprecated

* Test wrong profiler type, improve error message

* fixup! Test wrong profiler type, improve error message

* Update pytorch_lightning/trainer/connectors/profiler_connector.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Apply suggestions from code review

* Readd bool to profiler types, test cli profiler arg

* Remove extra whitespace in doc

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Update deprecation versions

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
Dusan Drevicky 2020-10-27 11:57:16 +01:00 committed by GitHub
parent 48b6de0c40
commit c50c225f05
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 115 additions and 19 deletions

View File

@ -24,6 +24,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added autogenerated helptext to `Trainer.add_argparse_args`. ([#4344](https://github.com/PyTorchLightning/pytorch-lightning/pull/4344))
- Added support for string values in `Trainer`'s `profiler` parameter ([#3656](https://github.com/PyTorchLightning/pytorch-lightning/pull/3656))
### Changed
@ -48,6 +51,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `reorder` parameter of the `auc` metric ([#4237](https://github.com/PyTorchLightning/pytorch-lightning/pull/4237))
- Deprecated bool values in `Trainer`'s `profiler` parameter ([#3656](https://github.com/PyTorchLightning/pytorch-lightning/pull/3656))
### Removed

View File

@ -22,12 +22,12 @@ PyTorch Lightning supports profiling standard actions in the training loop out o
Enable simple profiling
-----------------------
If you only wish to profile the standard actions, you can set `profiler=True` when constructing
your `Trainer` object.
If you only wish to profile the standard actions, you can set `profiler="simple"`
when constructing your `Trainer` object.
.. code-block:: python
trainer = Trainer(..., profiler=True)
trainer = Trainer(..., profiler="simple")
The profiler's results will be printed at the completion of a training `fit()`.
@ -59,6 +59,10 @@ This option uses Python's cProfiler_ to provide a report of time spent on *each*
.. code-block:: python
trainer = Trainer(..., profiler="advanced")
or
profiler = AdvancedProfiler()
trainer = Trainer(..., profiler=profiler)

View File

@ -1199,14 +1199,11 @@ See the :ref:`profiler documentation <profiler>`. for more details.
# default used by the Trainer
trainer = Trainer(profiler=None)
# to profile standard training events
trainer = Trainer(profiler=True)
# to profile standard training events, equivalent to `profiler=SimpleProfiler()`
trainer = Trainer(profiler="simple")
# equivalent to profiler=True
trainer = Trainer(profiler=SimpleProfiler())
# advanced profiler for function-level stats
trainer = Trainer(profiler=AdvancedProfiler())
# advanced profiler for function-level stats, equivalent to `profiler=AdvancedProfiler()`
trainer = Trainer(profiler="advanced")
progress_bar_refresh_rate
^^^^^^^^^^^^^^^^^^^^^^^^^

View File

@ -11,7 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from pytorch_lightning.profiler import PassThroughProfiler, SimpleProfiler
from typing import Union
from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler, SimpleProfiler, AdvancedProfiler
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
class ProfilerConnector:
@ -19,8 +24,27 @@ class ProfilerConnector:
def __init__(self, trainer):
self.trainer = trainer
def on_trainer_init(self, profiler):
# configure profiler
if profiler is True:
profiler = SimpleProfiler()
def on_trainer_init(self, profiler: Union[BaseProfiler, bool, str]):
if profiler and not isinstance(profiler, (bool, str, BaseProfiler)):
# TODO: Update exception on removal of bool
raise MisconfigurationException("Only None, bool, str and subclasses of `BaseProfiler` "
"are valid values for `Trainer`'s `profiler` parameter. "
f"Received {profiler} which is of type {type(profiler)}.")
if isinstance(profiler, bool):
rank_zero_warn("Passing a bool value as a `profiler` argument to `Trainer` is deprecated"
" and will be removed in v1.3. Use str ('simple' or 'advanced') instead.",
DeprecationWarning)
if profiler:
profiler = SimpleProfiler()
elif isinstance(profiler, str):
profiler = profiler.lower()
if profiler == "simple":
profiler = SimpleProfiler()
elif profiler == "advanced":
profiler = AdvancedProfiler()
else:
raise ValueError("When passing string value for the `profiler` parameter of"
" `Trainer`, it can only be 'simple' or 'advanced'")
self.trainer.profiler = profiler or PassThroughProfiler()

View File

@ -120,7 +120,7 @@ class Trainer(
num_sanity_val_steps: int = 2,
truncated_bptt_steps: Optional[int] = None,
resume_from_checkpoint: Optional[str] = None,
profiler: Optional[Union[BaseProfiler, bool]] = None,
profiler: Optional[Union[BaseProfiler, bool, str]] = None,
benchmark: bool = False,
deterministic: bool = False,
reload_dataloaders_every_epoch: bool = False,
@ -212,7 +212,8 @@ class Trainer(
progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar.
Ignored when a custom callback is passed to :paramref:`~Trainer.callbacks`.
profiler: To profile individual steps during training and assist in identifying bottlenecks.
profiler: To profile individual steps during training and assist in identifying bottlenecks. Passing bool
value is deprecated in v1.1 and will be removed in v1.3.
overfit_batches: Overfit a percent of training data (float) or a set number of batches (int). Default: 0.0

View File

@ -174,8 +174,7 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
# if the only arg type is bool
if len(arg_types) == 1:
use_type = parsing.str_to_bool
# if only two args (str, bool)
elif len(arg_types) == 2 and set(arg_types) == {str, bool}:
elif str in arg_types:
use_type = parsing.str_to_bool_or_str
else:
# filter out the bool as we need to use more general

View File

@ -1,13 +1,17 @@
"""Test deprecated functionality which will be removed in vX.Y.Z"""
from argparse import ArgumentParser
import pytest
import sys
from unittest import mock
import torch
from tests.base import EvalModelTemplate
from pytorch_lightning.metrics.functional.classification import auc
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.profiler.profilers import PassThroughProfiler, SimpleProfiler
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@ -22,6 +26,37 @@ def test_tbd_remove_in_v1_2_0():
checkpoint_cb = ModelCheckpoint(filepath='.', dirpath='.')
# TODO: remove bool from Trainer.profiler param in v1.3.0, update profiler_connector.py
@pytest.mark.parametrize(['profiler', 'expected'], [
(True, SimpleProfiler),
(False, PassThroughProfiler),
])
def test_trainer_profiler_remove_in_v1_3_0(profiler, expected):
with pytest.deprecated_call(match='will be removed in v1.3'):
trainer = Trainer(profiler=profiler)
assert isinstance(trainer.profiler, expected)
@pytest.mark.parametrize(
['cli_args', 'expected_parsed_arg', 'expected_profiler'],
[
('--profiler', True, SimpleProfiler),
('--profiler True', True, SimpleProfiler),
('--profiler False', False, PassThroughProfiler),
],
)
def test_trainer_cli_profiler_remove_in_v1_3_0(cli_args, expected_parsed_arg, expected_profiler):
cli_args = cli_args.split(' ')
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
parser = ArgumentParser(add_help=False)
parser = Trainer.add_argparse_args(parent_parser=parser)
args = Trainer.parse_argparser(parser)
assert getattr(args, "profiler") == expected_parsed_arg
trainer = Trainer.from_argparse_args(args)
assert isinstance(trainer.profiler, expected_profiler)
def _soft_unimport_module(str_module):
# once the module is imported e.g with parsing with pytest it lives in memory
if str_module in sys.modules:

View File

@ -32,6 +32,7 @@ from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.profiler.profilers import AdvancedProfiler, PassThroughProfiler, SimpleProfiler
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@ -1408,3 +1409,32 @@ def test_log_every_n_steps(log_metrics_mock, tmpdir, train_batches, max_steps, l
trainer.fit(model)
expected_calls = [call(metrics=ANY, step=s) for s in range(log_interval - 1, max_steps, log_interval)]
log_metrics_mock.assert_has_calls(expected_calls)
@pytest.mark.parametrize(['profiler', 'expected'], [
(None, PassThroughProfiler),
(SimpleProfiler(), SimpleProfiler),
(AdvancedProfiler(), AdvancedProfiler),
('simple', SimpleProfiler),
('Simple', SimpleProfiler),
('advanced', AdvancedProfiler),
])
def test_trainer_profiler_correct_args(profiler, expected):
kwargs = {'profiler': profiler} if profiler is not None else {}
trainer = Trainer(**kwargs)
assert isinstance(trainer.profiler, expected)
def test_trainer_profiler_incorrect_str_arg():
with pytest.raises(ValueError, match=r".*can only be 'simple' or 'advanced'"):
Trainer(profiler="unknown_profiler")
@pytest.mark.parametrize('profiler', (
42, [42], {"a": 42}, torch.tensor(42), Trainer(),
))
def test_trainer_profiler_incorrect_arg_type(profiler):
with pytest.raises(MisconfigurationException,
match=r"Only None, bool, str and subclasses of `BaseProfiler` "
r"are valid values for `Trainer`'s `profiler` parameter. *"):
Trainer(profiler=profiler)