moves init apex from LM to apex connector (#3923)

William Falcon, 2020-10-06 21:31:56 -04:00, committed by GitHub
parent c1559a1476
commit d71ed277d4
4 changed files with 37 additions and 49 deletions


@@ -24,10 +24,7 @@ Training set-up
- :meth:`~pytorch_lightning.core.datamodule.LightningDataModule.prepare_data`
- :meth:`~pytorch_lightning.core.datamodule.LightningDataModule.setup`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.init_ddp_connection`
- :meth:`~pytorch_lightning.trainer.optimizers.TrainerOptimizersMixin.init_optimizers`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_apex`
- :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`
- :meth:`~pytorch_lightning.core.datamodule.LightningDataModule.train_dataloader`
- :meth:`~pytorch_lightning.core.datamodule.LightningDataModule.test_dataloader`
- :meth:`~pytorch_lightning.core.datamodule.LightningDataModule.val_dataloader`

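For orientation, here is a minimal sketch of a LightningDataModule that implements the set-up hooks listed above; the toy dataset and batch size are illustrative placeholders, not part of this commit.

.. code-block:: python

    # Illustrative only: a bare-bones LightningDataModule wiring up the
    # set-up hooks enumerated in the docs list above. Dataset and sizes
    # are placeholders.
    import torch
    import pytorch_lightning as pl
    from torch.utils.data import DataLoader, TensorDataset


    class ToyDataModule(pl.LightningDataModule):
        def prepare_data(self):
            # download or write data to disk; called once per node
            pass

        def setup(self, stage=None):
            # build/split datasets; called on every process
            data = TensorDataset(torch.randn(64, 4), torch.randint(0, 2, (64,)))
            self.train_set = self.val_set = self.test_set = data

        def train_dataloader(self):
            return DataLoader(self.train_set, batch_size=8)

        def val_dataloader(self):
            return DataLoader(self.val_set, batch_size=8)

        def test_dataloader(self):
            return DataLoader(self.test_set, batch_size=8)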

@@ -1024,17 +1024,6 @@ Advanced hooks
^^^^^^^^^^^^^^
Use these hooks to modify advanced functionality
configure_apex
~~~~~~~~~~~~~~
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.configure_apex
:noindex:
configure_ddp
~~~~~~~~~~~~~
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.configure_ddp
:noindex:
configure_sync_batchnorm
~~~~~~~~~~~~~~~~~~~~~~~~


@@ -985,40 +985,6 @@ class LightningModule(
return model
def configure_apex(
self,
amp: object,
model: "LightningModule",
optimizers: List[Optimizer],
amp_level: str,
) -> Tuple["LightningModule", List[Optimizer]]:
r"""
Override to init AMP your own way.
Must return a model and list of optimizers.
Args:
amp: pointer to amp library object.
model: pointer to current :class:`LightningModule`.
optimizers: list of optimizers passed in :meth:`configure_optimizers`.
amp_level: AMP mode chosen ('O1', 'O2', etc...)
Return:
Apex wrapped model and optimizers
Examples:
.. code-block:: python
# Default implementation used by Trainer.
def configure_apex(self, amp, model, optimizers, amp_level):
model, optimizers = amp.initialize(
model, optimizers, opt_level=amp_level,
)
return model, optimizers
"""
model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)
return model, optimizers
def configure_optimizers(
self,
):


@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple
from torch.optim.optimizer import Optimizer
try:
from apex import amp
@@ -24,10 +26,44 @@ class ApexPlugin:
self.trainer = trainer
def connect(self, model, optimizers):
model, optimizers = model.configure_apex(amp, model, optimizers, self.trainer.amp_level)
model, optimizers = self.configure_apex(amp, model, optimizers, self.trainer.amp_level)
self.trainer.reinit_scheduler_properties(optimizers, self.trainer.lr_schedulers)
return model, optimizers
def training_step(self, fx, args):
output = fx(args)
return output
def configure_apex(
self,
amp: object,
model: "LightningModule",
optimizers: List[Optimizer],
amp_level: str,
) -> Tuple["LightningModule", List[Optimizer]]:
r"""
Override to init AMP your own way.
Must return a model and list of optimizers.
Args:
amp: pointer to amp library object.
model: pointer to current :class:`LightningModule`.
optimizers: list of optimizers passed in :meth:`configure_optimizers`.
amp_level: AMP mode chosen ('O1', 'O2', etc...)
Return:
Apex wrapped model and optimizers
Examples:
.. code-block:: python
# Default implementation used by Trainer.
def configure_apex(self, amp, model, optimizers, amp_level):
model, optimizers = amp.initialize(
model, optimizers, opt_level=amp_level,
)
return model, optimizers
"""
model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)
return model, optimizers
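With this change, AMP initialization is customized by overriding configure_apex on the plugin rather than on the LightningModule. Below is a minimal sketch of such an override, assuming NVIDIA apex is installed; the subclass name, the verbosity argument, and the import path of ApexPlugin are assumptions for illustration, not part of this diff.

.. code-block:: python

    # Hypothetical sketch: customizing AMP initialization on the plugin.
    # Assumes apex is installed; the ApexPlugin import path is assumed.
    from pytorch_lightning.plugins.apex import ApexPlugin


    class VerboseApexPlugin(ApexPlugin):
        def configure_apex(self, amp, model, optimizers, amp_level):
            # Same wrapping as the default implementation above,
            # but with apex's own initialization logging turned on.
            model, optimizers = amp.initialize(
                model, optimizers, opt_level=amp_level, verbosity=1
            )
            return model, optimizers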