# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
from functools import wraps
from typing import Any, Callable, Dict, List, Optional

from torch.utils.data import DataLoader

import pytorch_lightning as pl


def enabled_only(fn: Callable) -> Callable:
    """Decorate a method so it only runs when its instance has debugging enabled.

    The wrapped method is executed only if ``self.enabled`` is truthy; otherwise the
    call is a no-op that returns ``None``.

    Args:
        fn: Method to decorate. Its first positional argument must be an object that
            exposes an ``enabled`` attribute.

    Returns:
        The wrapped method. When enabled, it returns whatever ``fn`` returns;
        when disabled, it returns ``None``.
    """

    @wraps(fn)
    def wrapped_fn(self: Any, *args: Any, **kwargs: Any) -> Optional[Any]:
        if self.enabled:
            # Propagate the result instead of silently discarding it.
            return fn(self, *args, **kwargs)
        return None

    return wrapped_fn


def _dataloader_length(dataloader: Any) -> int:
    """Return ``len(dataloader)``, or ``-1`` when the length cannot be determined."""
    try:
        return len(dataloader)
    # Best-effort debug bookkeeping: objects without ``__len__`` raise ``TypeError``,
    # and any other failure is also recorded as an unknown length (-1) rather than
    # interrupting the caller.
    # TODO: narrow to the specific exceptions dataloaders can raise here.
    except Exception:
        return -1


class InternalDebugger:
    """Record trainer events and dataloader loads for internal debugging.

    Tracking is enabled only when the ``PL_DEV_DEBUG`` environment variable equals
    ``"1"`` at construction time. Otherwise every ``@enabled_only`` method is a no-op.

    Args:
        trainer: The trainer whose ``global_step`` / ``current_epoch`` are read when
            recording dataloader loads.
    """

    def __init__(self, trainer: "pl.Trainer") -> None:
        self.enabled = os.environ.get("PL_DEV_DEBUG", "0") == "1"
        self.trainer = trainer
        # Free-form events appended by ``track_event``.
        self.events: List[Dict[str, Any]] = []
        # Dataloader load records, routed by a substring match on the loader ``name``.
        self.train_dataloader_calls: List[Dict[str, Any]] = []
        self.val_dataloader_calls: List[Dict[str, Any]] = []
        self.test_dataloader_calls: List[Dict[str, Any]] = []
        # Every dataloader load record, in call order, regardless of ``name``.
        self.dataloader_sequence_calls: List[Dict[str, Any]] = []

    @enabled_only
    def track_event(
        self,
        evt_type: str,
        evt_value: Any = None,
        global_rank: Optional[int] = None,
        local_rank: Optional[int] = None,
        comment: str = "",
    ) -> None:
        """Append a timestamped event record to ``self.events``.

        Args:
            evt_type: Event identifier, stored under ``"event"``.
            evt_value: Optional payload, stored under ``"value"``.
            global_rank: Optional global rank to record alongside the event.
            local_rank: Optional local rank to record alongside the event.
            comment: Optional free-form note.
        """
        self.events.append(
            {
                "timestamp": time.time(),
                "event": evt_type,
                "value": evt_value,
                "global_rank": global_rank,
                "local_rank": local_rank,
                "comment": comment,
            }
        )

    @enabled_only
    def track_load_dataloader_call(self, name: str, dataloaders: List[DataLoader]) -> None:
        """Record that the dataloaders called ``name`` were loaded.

        The record is always appended to ``dataloader_sequence_calls``. It is also
        appended to the first matching per-stage list: ``"train"``, then ``"val"``,
        then ``"test"``, each matched as a substring of ``name``. A name matching none
        of them is only kept in the sequence list. Both lists hold the same dict object.

        Args:
            name: Identifier of the dataloader hook/stage being loaded.
            dataloaders: The loaded dataloaders. A length of ``-1`` is recorded for any
                loader whose ``len()`` fails.
        """
        # Lengths are computed before the trainer attributes are read, matching the
        # original evaluation order.
        loader_counts = len(dataloaders)
        lengths = [_dataloader_length(dl) for dl in dataloaders]

        values = {
            "global_step": self.trainer.global_step,
            "epoch": self.trainer.current_epoch,
            "num_loaders": loader_counts,
            "lengths": lengths,
            "name": name,
        }

        # Track the full sequence in case the call order needs to be verified.
        self.dataloader_sequence_calls.append(values)

        if "train" in name:
            self.train_dataloader_calls.append(values)
        elif "val" in name:
            self.val_dataloader_calls.append(values)
        elif "test" in name:
            self.test_dataloader_calls.append(values)