update usage of deprecated profiler (#5010)

* drop deprecated profiler

* lut

Co-authored-by: Roger Shieh <sh.rog@protonmail.ch>
This commit is contained in:
Jirka Borovec 2020-12-10 08:38:14 +01:00 committed by GitHub
parent cdbddbe99f
commit 77fb425dd4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 14 additions and 11 deletions

View File

@ -237,7 +237,7 @@ def run_cli():
help='seed for initializing training.')
parser = ImageNetLightningModel.add_model_specific_args(parent_parser)
parser.set_defaults(
profiler=True,
profiler="simple",
deterministic=True,
max_epochs=90,
)

View File

@ -18,6 +18,11 @@ from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler, Simple
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
PROFILERS = {
"simple": SimpleProfiler,
"advanced": AdvancedProfiler,
}
class ProfilerConnector:
@ -28,9 +33,9 @@ class ProfilerConnector:
if profiler and not isinstance(profiler, (bool, str, BaseProfiler)):
# TODO: Update exception on removal of bool
raise MisconfigurationException("Only None, bool, str and subclasses of `BaseProfiler` "
"are valid values for `Trainer`'s `profiler` parameter. "
f"Received {profiler} which is of type {type(profiler)}.")
raise MisconfigurationException("Only None, bool, str and subclasses of `BaseProfiler`"
" are valid values for `Trainer`'s `profiler` parameter."
f" Received {profiler} which is of type {type(profiler)}.")
if isinstance(profiler, bool):
rank_zero_warn("Passing a bool value as a `profiler` argument to `Trainer` is deprecated"
@ -39,11 +44,9 @@ class ProfilerConnector:
if profiler:
profiler = SimpleProfiler()
elif isinstance(profiler, str):
profiler = profiler.lower()
if profiler == "simple":
profiler = SimpleProfiler()
elif profiler == "advanced":
profiler = AdvancedProfiler()
if profiler.lower() in PROFILERS:
profiler_class = PROFILERS[profiler.lower()]
profiler = profiler_class()
else:
raise ValueError("When passing string value for the `profiler` parameter of"
" `Trainer`, it can only be 'simple' or 'advanced'")

View File

@ -1476,6 +1476,6 @@ def test_trainer_profiler_incorrect_str_arg():
))
def test_trainer_profiler_incorrect_arg_type(profiler):
with pytest.raises(MisconfigurationException,
match=r"Only None, bool, str and subclasses of `BaseProfiler` "
r"are valid values for `Trainer`'s `profiler` parameter. *"):
match=r"Only None, bool, str and subclasses of `BaseProfiler`"
r" are valid values for `Trainer`'s `profiler` parameter. *"):
Trainer(profiler=profiler)