35 lines
768 B
Python
35 lines
768 B
Python
from functools import wraps
|
|
import warnings
|
|
from pytorch_lightning import _logger as log
|
|
import os
|
|
|
|
|
|
def rank_zero_only(fn):
    """Decorate ``fn`` so it only runs when ``rank_zero_only.rank`` equals 0.

    The rank is looked up on every call, not at decoration time, so changing
    ``rank_zero_only.rank`` later affects functions that are already wrapped.

    Args:
        fn: The callable to guard.

    Returns:
        A wrapper carrying ``fn``'s metadata (via ``functools.wraps``). It
        forwards all arguments to ``fn`` and returns its result when the rank
        is 0; for any other rank it does nothing and returns ``None``.
    """

    @wraps(fn)
    def _guarded(*args, **kwargs):
        # Guard clause: every rank other than 0 skips the call entirely.
        if rank_zero_only.rank != 0:
            return None
        return fn(*args, **kwargs)

    return _guarded
|
|
|
|
|
|
# Store the rank as an attribute on the decorator itself. A value that is
# already present (e.g. one the Trainer has set) is kept; otherwise fall back
# to the LOCAL_RANK environment variable, defaulting to 0.
rank_zero_only.rank = getattr(
    rank_zero_only,
    'rank',
    int(os.getenv('LOCAL_RANK', 0)),
)
|
|
|
|
def _warn(*args, **kwargs):
    """Forward all positional and keyword arguments to ``warnings.warn``."""
    warnings.warn(*args, **kwargs)
|
|
|
|
|
|
def _info(*args, **kwargs):
    """Forward all positional and keyword arguments to ``log.info``."""
    log.info(*args, **kwargs)
|
|
|
|
|
|
def _debug(*args, **kwargs):
    """Log a message at DEBUG level through the package logger.

    All positional and keyword arguments are forwarded unchanged to
    ``log.debug``.
    """
    # Must be ``log.debug``: forwarding to ``log.info`` emits debug messages
    # at INFO level, making this helper indistinguishable from ``_info``.
    log.debug(*args, **kwargs)
|
|
|
|
|
|
# Public helpers: each wraps the matching private function with
# ``rank_zero_only``, so a call does nothing unless ``rank_zero_only.rank`` is 0.
rank_zero_warn = rank_zero_only(_warn)
rank_zero_info = rank_zero_only(_info)
rank_zero_debug = rank_zero_only(_debug)
|