2020-01-17 11:03:31 +00:00
|
|
|
|
2019-12-04 15:57:32 +00:00
|
|
|
from abc import ABC
|
|
|
|
|
2019-10-22 01:16:51 +00:00
|
|
|
try:
|
|
|
|
from apex import amp
|
2019-10-22 08:32:40 +00:00
|
|
|
|
2019-10-22 01:16:51 +00:00
|
|
|
APEX_AVAILABLE = True
|
|
|
|
except ImportError:
|
|
|
|
APEX_AVAILABLE = False
|
2020-02-01 20:47:58 +00:00
|
|
|
import logging as log
|
2019-10-22 01:16:51 +00:00
|
|
|
|
|
|
|
|
2019-12-04 15:57:32 +00:00
|
|
|
class TrainerAMPMixin(ABC):
    """Trainer mixin that tracks whether apex-based 16-bit precision is in use."""

    def __init__(self):
        # Populated by ``init_amp``; ``None`` means it has not been called yet.
        self.use_amp = None

    def init_amp(self, use_amp):
        """Enable 16-bit precision when requested and NVIDIA apex is importable.

        Args:
            use_amp: whether 16-bit (apex amp) precision was requested. Any
                truthy/falsy value is accepted; ``self.use_amp`` is always
                stored as a real ``bool``.

        Raises:
            ModuleNotFoundError: if ``use_amp`` is truthy but ``apex`` could
                not be imported.
        """
        # bool() so self.use_amp is never a leftover non-bool like 0 or None;
        # the `and` short-circuits, so APEX_AVAILABLE is only consulted when asked for.
        self.use_amp = bool(use_amp) and APEX_AVAILABLE
        if self.use_amp:
            log.info('Using 16bit precision.')

        if use_amp and not APEX_AVAILABLE:  # pragma: no cover
            # This aborts the run, so the message must not claim the run will
            # continue without 16-bit precision.
            msg = (
                "You set `use_amp=True` but do not have apex installed.\n"
                "Install apex first using this guide and rerun with use_amp=True:\n"
                "https://github.com/NVIDIA/apex#linux"
            )
            raise ModuleNotFoundError(msg)
|