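"""Mixin handling 16-bit (AMP) precision setup for the Trainer."""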
from abc import ABC

from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import rank_zero_warn, APEX_AVAILABLE, NATIVE_AMP_AVALAIBLE
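# NOTE: `NATIVE_AMP_AVALAIBLE` above is spelled exactly as this version of
# `pytorch_lightning.utilities` exports it; "correcting" the typo would break the import.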
from pytorch_lightning.utilities.distributed import rank_zero_debug


class TrainerAMPMixin(ABC):

    # This is just a summary of the variables used in this abstract class;
    # the proper values/initialisation should be done in the child class.
    precision: int

    def init_amp(self):
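        """Validate the configured precision and log which AMP backend this run will use."""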
        if NATIVE_AMP_AVALAIBLE:
            log.debug("`amp_level` has been deprecated since v0.7.4 (native amp does not require it)")

        assert self.precision in (16, 32), 'only 32 or 16 bit precision supported'

        # Native (torch.cuda.amp) AMP takes precedence over apex when both are usable.
        if self.use_amp and NATIVE_AMP_AVALAIBLE:
            log.info('Using native 16bit precision.')
            return

        if self.use_amp and not APEX_AVAILABLE:  # pragma: no-cover
            raise ModuleNotFoundError(
                "You set `use_amp=True` but do not have apex installed."
                " Install apex first using this guide: https://github.com/NVIDIA/apex#linux"
                " and rerun with `use_amp=True`."
                " Without apex, this run cannot use 16 bit precision."
            )

        if self.use_amp:
            log.info('Using APEX 16bit precision.')

    @property
    def use_amp(self) -> bool:
        """True when 16-bit precision is configured (covers both native and apex AMP)."""
        return self.precision == 16
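
# --- Usage sketch (illustrative; not part of the upstream file) ---
# A minimal sketch of how a child class satisfies this mixin's contract:
# it supplies `precision` and calls `init_amp()`. `_DemoTrainer` is a
# hypothetical name used here for illustration, not upstream API.
if __name__ == "__main__":

    class _DemoTrainer(TrainerAMPMixin):
        def __init__(self, precision: int = 32):
            self.precision = precision
            self.init_amp()

    # 32-bit precision runs everywhere; AMP stays off.
    assert not _DemoTrainer(precision=32).use_amp

    # 16-bit precision additionally requires native AMP or apex at runtime;
    # without either backend, init_amp() raises ModuleNotFoundError.
    try:
        assert _DemoTrainer(precision=16).use_amp
    except ModuleNotFoundError:
        pass  # no AMP backend available on this machine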