2020-04-24 21:21:00 +00:00
|
|
|
from functools import wraps
|
|
|
|
import warnings
|
2020-06-13 07:47:45 +00:00
|
|
|
from pytorch_lightning import _logger as log
|
2020-06-19 06:38:10 +00:00
|
|
|
import os
|
2020-04-24 21:21:00 +00:00
|
|
|
|
|
|
|
|
|
|
|
def rank_zero_only(fn):
    """Decorate ``fn`` so that it only runs when ``rank_zero_only.rank == 0``.

    The rank is read from the ``rank_zero_only.rank`` attribute on every call,
    not at decoration time. It can therefore be reassigned after functions
    have been decorated (for example by the Trainer).

    Args:
        fn: the callable to guard.

    Returns:
        A wrapper carrying ``fn``'s metadata (via ``functools.wraps``). When
        the rank is zero it forwards all arguments to ``fn`` and returns its
        result. Otherwise it returns ``None`` without calling ``fn``.
    """

    @wraps(fn)
    def wrapped_fn(*args, **kwargs):
        # Look the rank up at call time so later updates take effect.
        if rank_zero_only.rank == 0:
            return fn(*args, **kwargs)

    return wrapped_fn


# Default the rank from the LOCAL_RANK environment variable, but don't
# overwrite a value in case the Trainer has already set it.
# Environment variables are always strings, so convert to ``int``.
# Otherwise ``'0' == 0`` is False and, whenever LOCAL_RANK is set, no process
# would ever be treated as rank zero.
# A non-integer LOCAL_RANK now raises ValueError here instead of silently
# disabling every rank-zero call.
if not hasattr(rank_zero_only, 'rank'):
    rank_zero_only.rank = int(os.environ.get('LOCAL_RANK', 0))
|
2020-04-24 21:21:00 +00:00
|
|
|
|
|
|
|
|
|
|
|
def _warn(*args, **kwargs):
|
|
|
|
warnings.warn(*args, **kwargs)
|
|
|
|
|
|
|
|
|
2020-06-13 07:47:45 +00:00
|
|
|
def _info(*args, **kwargs):
    """Emit an INFO record by passing all arguments through to ``log.info``.

    ``log`` is the module-level logger imported as ``_logger`` from
    ``pytorch_lightning``.
    """
    emit_info = log.info
    emit_info(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
# Exported helpers: ``_info`` and ``_warn`` wrapped with ``rank_zero_only``,
# so they only act when ``rank_zero_only.rank == 0``.
# Tuple elements are evaluated left to right, so ``_info`` is wrapped first.
rank_zero_info, rank_zero_warn = rank_zero_only(_info), rank_zero_only(_warn)
|