moves init apex from LM to apex connector (#3923)
This commit is contained in:
parent
c1559a1476
commit
d71ed277d4
|
@ -24,10 +24,7 @@ Training set-up
|
|||
|
||||
- :meth:`~pytorch_lightning.core.datamodule.LightningDataModule.prepare_data`
|
||||
- :meth:`~pytorch_lightning.core.datamodule.LightningDataModule.setup`
|
||||
- :meth:`~pytorch_lightning.core.lightning.LightningModule.init_ddp_connection`
|
||||
- :meth:`~pytorch_lightning.trainer.optimizers.TrainerOptimizersMixin.init_optimizers`
|
||||
- :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_apex`
|
||||
- :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`
|
||||
- :meth:`~pytorch_lightning.core.datamodule.LightningDataModule.train_dataloader`
|
||||
- :meth:`~pytorch_lightning.core.datamodule.LightningDataModule.test_dataloader`
|
||||
- :meth:`~pytorch_lightning.core.datamodule.LightningDataModule.val_dataloader`
|
||||
|
|
|
@ -1024,17 +1024,6 @@ Advanced hooks
|
|||
^^^^^^^^^^^^^^
|
||||
Use these hooks to modify advanced functionality
|
||||
|
||||
configure_apex
|
||||
~~~~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.configure_apex
|
||||
:noindex:
|
||||
|
||||
configure_ddp
|
||||
~~~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.configure_ddp
|
||||
:noindex:
|
||||
|
||||
configure_sync_batchnorm
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
|
|
@ -985,40 +985,6 @@ class LightningModule(
|
|||
|
||||
return model
|
||||
|
||||
def configure_apex(
|
||||
self,
|
||||
amp: object,
|
||||
model: "LightningModule",
|
||||
optimizers: List[Optimizer],
|
||||
amp_level: str,
|
||||
) -> Tuple["LightningModule", List[Optimizer]]:
|
||||
r"""
|
||||
Override to init AMP your own way.
|
||||
Must return a model and list of optimizers.
|
||||
|
||||
Args:
|
||||
amp: pointer to amp library object.
|
||||
model: pointer to current :class:`LightningModule`.
|
||||
optimizers: list of optimizers passed in :meth:`configure_optimizers`.
|
||||
amp_level: AMP mode chosen ('O1', 'O2', etc...)
|
||||
|
||||
Return:
|
||||
Apex wrapped model and optimizers
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
# Default implementation used by Trainer.
|
||||
def configure_apex(self, amp, model, optimizers, amp_level):
|
||||
model, optimizers = amp.initialize(
|
||||
model, optimizers, opt_level=amp_level,
|
||||
)
|
||||
|
||||
return model, optimizers
|
||||
"""
|
||||
model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)
|
||||
return model, optimizers
|
||||
|
||||
def configure_optimizers(
|
||||
self,
|
||||
):
|
||||
|
|
|
@ -11,6 +11,8 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import List, Tuple
|
||||
from torch.optim.optimizer import Optimizer
|
||||
|
||||
try:
|
||||
from apex import amp
|
||||
|
@ -24,10 +26,44 @@ class ApexPlugin:
|
|||
self.trainer = trainer
|
||||
|
||||
def connect(self, model, optimizers):
|
||||
model, optimizers = model.configure_apex(amp, model, optimizers, self.trainer.amp_level)
|
||||
model, optimizers = self.configure_apex(amp, model, optimizers, self.trainer.amp_level)
|
||||
self.trainer.reinit_scheduler_properties(optimizers, self.trainer.lr_schedulers)
|
||||
return model, optimizers
|
||||
|
||||
def training_step(self, fx, args):
|
||||
output = fx(args)
|
||||
return output
|
||||
|
||||
def configure_apex(
|
||||
self,
|
||||
amp: object,
|
||||
model: "LightningModule",
|
||||
optimizers: List[Optimizer],
|
||||
amp_level: str,
|
||||
) -> Tuple["LightningModule", List[Optimizer]]:
|
||||
r"""
|
||||
Override to init AMP your own way.
|
||||
Must return a model and list of optimizers.
|
||||
|
||||
Args:
|
||||
amp: pointer to amp library object.
|
||||
model: pointer to current :class:`LightningModule`.
|
||||
optimizers: list of optimizers passed in :meth:`configure_optimizers`.
|
||||
amp_level: AMP mode chosen ('O1', 'O2', etc...)
|
||||
|
||||
Return:
|
||||
Apex wrapped model and optimizers
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
# Default implementation used by Trainer.
|
||||
def configure_apex(self, amp, model, optimizers, amp_level):
|
||||
model, optimizers = amp.initialize(
|
||||
model, optimizers, opt_level=amp_level,
|
||||
)
|
||||
|
||||
return model, optimizers
|
||||
"""
|
||||
model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)
|
||||
return model, optimizers
|
||||
|
|
Loading…
Reference in New Issue