from pytorch_lightning.utilities.distributed import rank_zero_warn
class WarningCache:
    """Emit each distinct warning message at most once.

    Messages are remembered in ``self.warnings``; a message that has already
    been passed to :meth:`warn` is silently ignored on later calls.
    """

    def __init__(self):
        # Messages already handed to ``rank_zero_warn`` by this cache.
        self.warnings = set()

    def warn(self, m, *args, **kwargs):
        """Warn with message ``m`` via ``rank_zero_warn`` unless already seen.

        Args:
            m: The warning message. It is used as the de-duplication key, so
                it must be hashable.
            *args: Extra positional arguments forwarded to ``rank_zero_warn``
                (e.g. a warning category).
            **kwargs: Extra keyword arguments forwarded to ``rank_zero_warn``
                (e.g. ``stacklevel``).

        De-duplication is keyed on ``m`` only: a repeated message is
        suppressed even if different ``args``/``kwargs`` are supplied.
        """
        if m in self.warnings:
            return
        # Record before emitting so the message is marked as seen even if
        # the underlying call does not actually emit on this process.
        self.warnings.add(m)
        rank_zero_warn(m, *args, **kwargs)