33 lines
661 B
Python
33 lines
661 B
Python
from functools import wraps
|
|
import warnings
|
|
from pytorch_lightning import _logger as log
|
|
|
|
|
|
def rank_zero_only(fn):
    """Wrap ``fn`` so that it only executes on the process whose rank is zero.

    The rank is read from ``rank_zero_only.rank`` each time the wrapper is
    called, not when ``fn`` is decorated. Updating that attribute later
    therefore affects every function that was already wrapped.

    Args:
        fn: the callable to guard.

    Returns:
        A wrapper carrying ``fn``'s metadata (via ``functools.wraps``).
        - On rank 0 it returns whatever ``fn`` returns.
        - On any other rank it skips the call and returns ``None``.
    """

    @wraps(fn)
    def _guarded(*fn_args, **fn_kwargs):
        # Guard clause: non-zero ranks do nothing.
        if rank_zero_only.rank != 0:
            return None
        return fn(*fn_args, **fn_kwargs)

    return _guarded
|
|
|
|
|
|
# Default the rank to 0, but keep any value already stored on the function
# (e.g. when the Trainer has set it before this module is imported).
rank_zero_only.rank = getattr(rank_zero_only, 'rank', 0)
|
|
|
|
|
|
def _warn(*args, **kwargs):
|
|
warnings.warn(*args, **kwargs)
|
|
|
|
|
|
def _info(*msg_args, **log_kwargs):
    """Forward a message to the package logger's ``info`` method.

    All positional and keyword arguments go unchanged to ``log.info``.
    The result is discarded, so this always returns ``None``.
    """
    log.info(*msg_args, **log_kwargs)
|
|
|
|
|
|
# Public helpers guarded by rank_zero_only.
# On processes whose rank is not zero, calling them does nothing and
# returns None.
rank_zero_warn = rank_zero_only(_warn)
rank_zero_info = rank_zero_only(_info)
|