# lightning/pytorch_lightning/trainer/auto_mix_precision.py
# (54 lines, 1.8 KiB, Python)

from abc import ABC
import torch
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import rank_zero_warn
# apex is an optional dependency: record whether it can be imported rather
# than letting a missing package raise ImportError when this module loads.
try:
    from apex import amp
except ImportError:
    APEX_AVAILABLE = False
else:
    APEX_AVAILABLE = True
class TrainerAMPMixin(ABC):
    """Trainer mixin that resolves and validates the 16/32-bit precision setting.

    The concrete trainer class is expected to provide ``precision`` and
    ``use_native_amp``; this mixin only reads and updates them.
    """

    # this is just a summary on variables used in this abstract class,
    #  the proper values/initialisation should be done in child class
    precision: int
    use_native_amp: bool

    def init_amp(self, use_amp):
        """Apply the deprecated ``use_amp`` flag and validate ``self.precision``.

        Args:
            use_amp: Deprecated switch. ``None`` leaves ``self.precision``
                untouched; any other value sets ``self.precision`` to 16 when
                truthy and to 32 when falsy, after a deprecation warning.

        Raises:
            AssertionError: If ``self.precision`` is neither 16 nor 32.
            ModuleNotFoundError: If ``use_amp`` is truthy, native AMP is not in
                use and apex could not be imported.
        """
        # Emitted whenever native AMP is in use; this method does not see
        # whether the caller actually passed `amp_level`.
        if self.use_native_amp:
            rank_zero_warn("`amp_level` has been deprecated since v0.7.4 (native amp does not require it)"
                           " and this argument will be removed in v0.9.0", DeprecationWarning)

        # Backward compatibility, TODO: remove in v0.9.0
        if use_amp is not None:
            rank_zero_warn("`use_amp` has been replaced by `precision` since v0.7.0"
                           " and this argument will be removed in v0.9.0", DeprecationWarning)
            self.precision = 16 if use_amp else 32

        assert self.precision in (16, 32), 'only 32 or 16 bit precision supported'

        # Native AMP does not need apex, so the apex check below is skipped.
        if use_amp and self.use_native_amp:
            log.info('Using 16bit precision.')
            return

        # TODO: remove all below for v0.9.0
        if use_amp and not APEX_AVAILABLE:  # pragma: no-cover
            raise ModuleNotFoundError("""
You set `use_amp=True` but do not have apex installed.
Install apex first using this guide and rerun with use_amp=True:
https://github.com/NVIDIA/apex#linux
this run will NOT use 16 bit precision
""")

        # `self.use_amp` is the property below (precision == 16), not the
        # `use_amp` argument, so this also covers precision=16 set directly.
        if self.use_amp:
            log.info('Using 16bit precision.')

    @property
    def use_amp(self) -> bool:
        """Return ``True`` when ``self.precision`` is 16."""
        return self.precision == 16