diff --git a/docs/source/hooks.rst b/docs/source/hooks.rst
index d2f86531c1..c0da6ed3ec 100644
--- a/docs/source/hooks.rst
+++ b/docs/source/hooks.rst
@@ -24,10 +24,7 @@ Training set-up
 
 - :meth:`~pytorch_lightning.core.datamodule.LightningDataModule.prepare_data`
 - :meth:`~pytorch_lightning.core.datamodule.LightningDataModule.setup`
-- :meth:`~pytorch_lightning.core.lightning.LightningModule.init_ddp_connection`
 - :meth:`~pytorch_lightning.trainer.optimizers.TrainerOptimizersMixin.init_optimizers`
-- :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_apex`
-- :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`
 - :meth:`~pytorch_lightning.core.datamodule.LightningDataModule.train_dataloader`
 - :meth:`~pytorch_lightning.core.datamodule.LightningDataModule.test_dataloader`
 - :meth:`~pytorch_lightning.core.datamodule.LightningDataModule.val_dataloader`
diff --git a/docs/source/lightning_module.rst b/docs/source/lightning_module.rst
index dfa8977299..c51718c023 100644
--- a/docs/source/lightning_module.rst
+++ b/docs/source/lightning_module.rst
@@ -1024,17 +1024,6 @@ Advanced hooks
 ^^^^^^^^^^^^^^
 Use these hooks to modify advanced functionality
 
-configure_apex
-~~~~~~~~~~~~~~
-
-.. autofunction:: pytorch_lightning.core.lightning.LightningModule.configure_apex
-   :noindex:
-
-configure_ddp
-~~~~~~~~~~~~~
-
-.. autofunction:: pytorch_lightning.core.lightning.LightningModule.configure_ddp
-   :noindex:
 
 configure_sync_batchnorm
 ~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index e1f82968c7..c2f0ce611c 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -985,40 +985,6 @@ class LightningModule(
 
         return model
 
-    def configure_apex(
-        self,
-        amp: object,
-        model: "LightningModule",
-        optimizers: List[Optimizer],
-        amp_level: str,
-    ) -> Tuple["LightningModule", List[Optimizer]]:
-        r"""
-        Override to init AMP your own way.
-        Must return a model and list of optimizers.
-
-        Args:
-            amp: pointer to amp library object.
-            model: pointer to current :class:`LightningModule`.
-            optimizers: list of optimizers passed in :meth:`configure_optimizers`.
-            amp_level: AMP mode chosen ('O1', 'O2', etc...)
-
-        Return:
-            Apex wrapped model and optimizers
-
-        Examples:
-            .. code-block:: python
-
-                # Default implementation used by Trainer.
-                def configure_apex(self, amp, model, optimizers, amp_level):
-                    model, optimizers = amp.initialize(
-                        model, optimizers, opt_level=amp_level,
-                    )
-
-                    return model, optimizers
-        """
-        model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)
-        return model, optimizers
-
     def configure_optimizers(
         self,
     ):
diff --git a/pytorch_lightning/plugins/apex.py b/pytorch_lightning/plugins/apex.py
index ee5c17a4e6..40317ac7bf 100644
--- a/pytorch_lightning/plugins/apex.py
+++ b/pytorch_lightning/plugins/apex.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import List, Tuple
+from torch.optim.optimizer import Optimizer
 
 try:
     from apex import amp
@@ -24,10 +26,44 @@ class ApexPlugin:
         self.trainer = trainer
 
     def connect(self, model, optimizers):
-        model, optimizers = model.configure_apex(amp, model, optimizers, self.trainer.amp_level)
+        model, optimizers = self.configure_apex(amp, model, optimizers, self.trainer.amp_level)
         self.trainer.reinit_scheduler_properties(optimizers, self.trainer.lr_schedulers)
         return model, optimizers
 
     def training_step(self, fx, args):
         output = fx(args)
         return output
+
+    def configure_apex(
+        self,
+        amp: object,
+        model: "LightningModule",
+        optimizers: List[Optimizer],
+        amp_level: str,
+    ) -> Tuple["LightningModule", List[Optimizer]]:
+        r"""
+        Override to init AMP your own way.
+        Must return a model and list of optimizers.
+
+        Args:
+            amp: pointer to amp library object.
+            model: pointer to current :class:`LightningModule`.
+            optimizers: list of optimizers passed in :meth:`configure_optimizers`.
+            amp_level: AMP mode chosen ('O1', 'O2', etc...)
+
+        Return:
+            Apex wrapped model and optimizers
+
+        Examples:
+            .. code-block:: python
+
+                # Default implementation used by Trainer.
+                def configure_apex(self, amp, model, optimizers, amp_level):
+                    model, optimizers = amp.initialize(
+                        model, optimizers, opt_level=amp_level,
+                    )
+
+                    return model, optimizers
+        """
+        model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)
+        return model, optimizers
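
Note (not part of the patch): after this change, custom Apex initialization is expressed by overriding `configure_apex` on the plugin rather than on the `LightningModule`. The sketch below only illustrates that override pattern; the `MyApexPlugin` name is hypothetical, and how a custom plugin gets wired into the `Trainer` is version-dependent, so it is not shown here.

```python
# Illustrative sketch, assuming the ApexPlugin introduced in this patch.
from typing import List, Tuple

from torch.optim.optimizer import Optimizer

from pytorch_lightning import LightningModule
from pytorch_lightning.plugins.apex import ApexPlugin


class MyApexPlugin(ApexPlugin):
    """Hypothetical subclass that customizes how Apex AMP is initialized."""

    def configure_apex(
        self,
        amp: object,
        model: LightningModule,
        optimizers: List[Optimizer],
        amp_level: str,
    ) -> Tuple[LightningModule, List[Optimizer]]:
        # Same call as the default implementation, with room for extra
        # amp.initialize() options (e.g. loss_scale) if they are needed.
        model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)
        return model, optimizers
```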