lightning/pytorch_lightning/trainer/auto_mix_precision.py


clean v2 docs (#691)
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
2020-01-17 11:03:31 +00:00
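import logging
from abc import ABC

# apex is an optional dependency: record whether it can be imported
# so the trainer can decide at runtime if 16-bit precision is possible.
try:
    from apex import amp  # noqa: F401  (imported only to probe availability)
    APEX_AVAILABLE = True
except ImportError:
    APEX_AVAILABLE = False


class TrainerAMPMixin(ABC):

    def init_amp(self, use_amp):
        # Enable AMP only when it was requested and apex is installed.
        self.use_amp = use_amp and APEX_AVAILABLE
        if self.use_amp:
            logging.info('using 16bit precision')

        # AMP was requested but apex is missing: fail with install instructions.
        if use_amp and not APEX_AVAILABLE:  # pragma: no cover
            msg = """
            You set `use_amp=True` but do not have apex installed.
            Install apex first using this guide and rerun with use_amp=True:
            https://github.com/NVIDIA/apex#linux
            this run will NOT use 16 bit precision
            """
            raise ModuleNotFoundError(msg)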
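
The mixin above can be exercised in isolation. What follows is a minimal sketch, not the library's code: `DemoAMPMixin`, `DemoTrainer`, `apex_available`, and the `available` parameter are hypothetical names introduced here so the example runs without apex or Lightning installed. It also swaps the try/except import probe for `importlib.util.find_spec`, and it raises before setting `use_amp`, which differs slightly from the ordering in the original method.

```python
import importlib.util
import logging


def apex_available() -> bool:
    """Hypothetical helper: check for apex without importing it."""
    return importlib.util.find_spec("apex") is not None


class DemoAMPMixin:
    """Stand-in for TrainerAMPMixin so this sketch is self-contained."""

    def init_amp(self, use_amp, available=None):
        # `available` is an assumption added for testability; by default
        # the check falls back to probing the environment.
        if available is None:
            available = apex_available()

        # AMP requested but apex missing: fail early with a pointer to the guide.
        if use_amp and not available:
            raise ModuleNotFoundError(
                "use_amp=True requires apex: https://github.com/NVIDIA/apex#linux"
            )

        # Enable AMP only when requested and apex is present.
        self.use_amp = bool(use_amp and available)
        if self.use_amp:
            logging.info("using 16bit precision")


class DemoTrainer(DemoAMPMixin):
    """Hypothetical trainer that mixes in the AMP setup."""

    def __init__(self, use_amp=False, available=None):
        self.init_amp(use_amp, available)
```

Calling `DemoTrainer(use_amp=False)` leaves `use_amp` off regardless of the environment, while `DemoTrainer(use_amp=True)` either enables it or raises `ModuleNotFoundError`, depending on whether apex can be found.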