diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 6b77e4b05a..bf9c30190c 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -164,6 +164,13 @@ class LightningModule(
         """
         return self.device.type == "cuda"
 
+    @property
+    def automatic_optimization(self) -> bool:
+        """
+        If False you are responsible for calling .backward, .step, zero_grad.
+        """
+        return True
+
     def print(self, *args, **kwargs) -> None:
         r"""
         Prints only from process 0. Use this in any distributed mode to log only once.
diff --git a/pytorch_lightning/trainer/connectors/model_connector.py b/pytorch_lightning/trainer/connectors/model_connector.py
index ca7fce6e2d..dbdceb1532 100644
--- a/pytorch_lightning/trainer/connectors/model_connector.py
+++ b/pytorch_lightning/trainer/connectors/model_connector.py
@@ -35,6 +35,9 @@ class ModelConnector:
         else:
             ref_model = model
 
+        automatic_optimization = ref_model.automatic_optimization and self.trainer.train_loop.automatic_optimization
+        self.trainer.train_loop.automatic_optimization = automatic_optimization
+
         for m in [model, ref_model]:
             m.trainer = self.trainer
             m.logger = self.trainer.logger
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index d5bdd9e318..46e4abbe58 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -137,7 +137,7 @@ class Trainer(
         amp_backend: str = 'native',
         amp_level: str = 'O2',
         distributed_backend: Optional[str] = None,
-        automatic_optimization: bool = True,
+        automatic_optimization: Optional[bool] = None,
         move_metrics_to_cpu: bool = False,
     ):
         r"""
@@ -212,7 +212,9 @@
             log_every_n_steps: How often to log within steps (defaults to every 50 steps).
 
             automatic_optimization: If False you are responsible for calling .backward, .step, zero_grad.
-                Meant to be used with multiple optimizers by advanced users.
+                This argument has been moved to LightningModule. It is deprecated here in v1.1
+                and will be removed in v1.3. Override the ``automatic_optimization`` property
+                on your LightningModule instead.
 
             prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare data.
                 Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data
@@ -355,6 +357,14 @@
         )
 
         # init train loop related flags
+        # TODO: remove in v1.3.0
+        if automatic_optimization is None:
+            automatic_optimization = True
+        else:
+            rank_zero_warn(
+                "Disabling automatic optimization via the trainer flag is deprecated and will be removed in v1.3.0! "
+                "Please override the `automatic_optimization` property in your LightningModule instead."
+            )
        self.train_loop.on_trainer_init(
             max_epochs,
             min_epochs,
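
For reference, here is a minimal sketch of how a user would opt out of automatic optimization through the new `LightningModule.automatic_optimization` property added above. The module name, layer sizes, and the use of `self.optimizers()` / `self.manual_backward()` are illustrative assumptions, not part of this diff:

```python
import torch
from torch import nn
import pytorch_lightning as pl


class ManualOptimModel(pl.LightningModule):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    @property
    def automatic_optimization(self) -> bool:
        # Opting out: the trainer will no longer call .backward, .step, or zero_grad for us.
        return False

    def training_step(self, batch, batch_idx):
        # Assumed helpers from the same Lightning release: self.optimizers() returns the
        # configured optimizer(s); manual_backward() stands in for loss.backward() so
        # precision/accelerator hooks still run.
        opt = self.optimizers()
        loss = self.layer(batch).sum()  # placeholder loss on a (N, 32) float batch
        self.manual_backward(loss, opt)
        opt.step()
        opt.zero_grad()
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```

Note that `ModelConnector` ANDs the module property with the (now deprecated) trainer flag, so automatic optimization stays enabled only when neither side disables it.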