From 9059d21042a5f18fcb18a1792a901e8e62a3b61a Mon Sep 17 00:00:00 2001 From: Oliver Neumann Date: Tue, 12 May 2020 14:53:26 +0200 Subject: [PATCH] Missing profiler attribute in add_argparse_args() ArgumentParser (#1794) * Fixed typing annotation by adding boolean type. After that Profiler flag will be added to argparse. * Updated CHANGELOG.md * Updated git_init_arguments_and_types() to pass doctests. * Added doctest example to add_argparse_parser() --- CHANGELOG.md | 2 ++ pytorch_lightning/trainer/trainer.py | 29 +++++++++++++++++++++++++++- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 57143d63a7..ce4c8dbae5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -62,6 +62,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed saving native AMP scaler state (introduced in [#1561](https://github.com/PyTorchLightning/pytorch-lightning/pull/1561)) +- Fixed missing profiler attribute in add_argparse_args() ArgumentParser ([#1794](https://github.com/PyTorchLightning/pytorch-lightning/pull/1794)) + ## [0.7.5] - 2020-04-27 diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index d401a94645..aa3d87db01 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -124,7 +124,7 @@ class Trainer( num_sanity_val_steps: int = 5, truncated_bptt_steps: Optional[int] = None, resume_from_checkpoint: Optional[str] = None, - profiler: Optional[BaseProfiler] = None, + profiler: Optional[Union[BaseProfiler, bool]] = None, benchmark: bool = False, deterministic: bool = False, reload_dataloaders_every_epoch: bool = False, @@ -584,6 +584,7 @@ class Trainer( ('process_position', (,), 0), ('profiler', (, + , ), None), ... @@ -622,6 +623,32 @@ class Trainer( Only arguments of the allowed types (str, float, int, bool) will extend the `parent_parser`. + + Examples: + >>> import argparse + >>> import pprint + >>> parser = argparse.ArgumentParser() + >>> parser = Trainer.add_argparse_args(parser) + >>> args = parser.parse_args([]) + >>> pprint.pprint(vars(args)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + {... + 'check_val_every_n_epoch': 1, + 'checkpoint_callback': True, + 'default_root_dir': None, + 'distributed_backend': None, + 'early_stop_callback': False, + ... + 'logger': True, + 'max_epochs': 1000, + 'max_steps': None, + 'min_epochs': 1, + 'min_steps': None, + ... + 'profiler': None, + 'progress_bar_callback': True, + 'progress_bar_refresh_rate': 1, + ...} + """ parser = ArgumentParser(parents=[parent_parser], add_help=False, )