diff --git a/.actions/assistant.py b/.actions/assistant.py index 92239a406d..54d401b3bb 100644 --- a/.actions/assistant.py +++ b/.actions/assistant.py @@ -86,6 +86,7 @@ class _RequirementWithComment(Requirement): 'arrow>=1.2.0' >>> _RequirementWithComment("arrow").adjust("major") 'arrow' + """ out = str(self) if self.strict: @@ -115,6 +116,7 @@ def _parse_requirements(strs: Union[str, Iterable[str]]) -> Iterator[_Requiremen >>> txt = '\\n'.join(txt) >>> [r.adjust('none') for r in _parse_requirements(txt)] ['this', 'example', 'foo # strict', 'thing'] + """ lines = yield_lines(strs) pip_argument = None @@ -149,6 +151,7 @@ def load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: str >>> path_req = os.path.join(_PROJECT_ROOT, "requirements") >>> load_requirements(path_req, "docs.txt", unfreeze="major") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE ['sphinx<...] + """ assert unfreeze in {"none", "major", "all"} path = Path(path_dir) / file_name @@ -165,6 +168,7 @@ def load_readme_description(path_dir: str, homepage: str, version: str) -> str: >>> load_readme_description(_PROJECT_ROOT, "", "") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE '...PyTorch Lightning is just organized PyTorch...' + """ path_readme = os.path.join(path_dir, "README.md") with open(path_readme, encoding="utf-8") as fo: @@ -244,6 +248,7 @@ def _load_aggregate_requirements(req_dir: str = "requirements", freeze_requireme """Load all base requirements from all particular packages and prune duplicates. >>> _load_aggregate_requirements(os.path.join(_PROJECT_ROOT, "requirements")) + """ requires = [ load_requirements(d, unfreeze="none" if freeze_requirements else "major") @@ -300,6 +305,7 @@ def _replace_imports(lines: List[str], mapping: List[Tuple[str, str]], lightning 'http://pytorch_lightning.ai', \ 'from lightning_fabric import __version__', \ '@lightning.ai'] + """ out = lines[:] for source_import, target_import in mapping: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0442ca53c7..370b18b19f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -60,7 +60,8 @@ repos: rev: v1.7.3 hooks: - id: docformatter - args: [--in-place, --wrap-summaries=115, --wrap-descriptions=120] + additional_dependencies: [tomli] + args: ["--in-place"] - repo: https://github.com/asottile/yesqa rev: v1.5.0 diff --git a/docs/source-app/examples/file_server/app.py b/docs/source-app/examples/file_server/app.py index 3afba48469..ed14df22d9 100644 --- a/docs/source-app/examples/file_server/app.py +++ b/docs/source-app/examples/file_server/app.py @@ -23,6 +23,7 @@ class FileServer(L.LightningWork): drive: The drive can share data inside your application. base_dir: The local directory where the data will be stored. chunk_size: The quantity of bytes to download/upload at once. + """ super().__init__( cloud_build_config=L.BuildConfig(["flask, flask-cors"]), @@ -238,4 +239,5 @@ def test_file_server_in_cloud(): # 2. By calling logs = get_logs_fn(), # you get all the logs currently on the admin page. + """ diff --git a/docs/source-app/examples/github_repo_runner/app.py b/docs/source-app/examples/github_repo_runner/app.py index 486055f7a8..b670baa4bf 100644 --- a/docs/source-app/examples/github_repo_runner/app.py +++ b/docs/source-app/examples/github_repo_runner/app.py @@ -36,6 +36,7 @@ class GithubRepoRunner(TracerPythonScript): script_args: The arguments to be provided to the script. requirements: The python requirements tp run the script. cloud_compute: The object to select the cloud instance. 
+ """ super().__init__( script_path=script_path, diff --git a/docs/source-app/examples/model_server_app/locust_component.py b/docs/source-app/examples/model_server_app/locust_component.py index 4351506f5f..432336adf8 100644 --- a/docs/source-app/examples/model_server_app/locust_component.py +++ b/docs/source-app/examples/model_server_app/locust_component.py @@ -10,6 +10,7 @@ class Locust(LightningWork): Arguments: num_users: Number of users emulated by Locust + """ # Note: Using the default port 8089 of Locust. super().__init__( diff --git a/docs/source-app/examples/model_server_app/model_server.py b/docs/source-app/examples/model_server_app/model_server.py index 8562c63d8c..f571f613de 100644 --- a/docs/source-app/examples/model_server_app/model_server.py +++ b/docs/source-app/examples/model_server_app/model_server.py @@ -18,6 +18,7 @@ class MLServer(LightningWork): Example: "mlserver_sklearn.SKLearnModel". Learn more here: $ML_SERVER_URL/tree/master/runtimes workers: Number of server worker. + """ def __init__( @@ -51,6 +52,7 @@ class MLServer(LightningWork): Arguments: model_path: The path to the trained model. + """ # 1: Use the host and port at runtime so it works in the cloud. # $ML_SERVER_URL/blob/master/mlserver/settings.py#L50 diff --git a/examples/app/hpo/utils.py b/examples/app/hpo/utils.py index a08fda2f61..9d27c726b0 100644 --- a/examples/app/hpo/utils.py +++ b/examples/app/hpo/utils.py @@ -15,6 +15,7 @@ def download_data(url: str, path: str = "data/", verbose: bool = False) -> None: Usage: download_file('http://web4host.net/5MB.zip') + """ if url == "NEED_TO_BE_CREATED": raise NotImplementedError diff --git a/examples/app/layout/app.py b/examples/app/layout/app.py index 7048f62a94..0e9efabba7 100644 --- a/examples/app/layout/app.py +++ b/examples/app/layout/app.py @@ -5,6 +5,7 @@ Run the app: lightning run app examples/layout/demo.py This starts one server for each flow that returns a UI. Access the UI at the link printed in the terminal. + """ import os diff --git a/examples/fabric/build_your_own_trainer/trainer.py b/examples/fabric/build_your_own_trainer/trainer.py index 69895b6498..3e991de74e 100644 --- a/examples/fabric/build_your_own_trainer/trainer.py +++ b/examples/fabric/build_your_own_trainer/trainer.py @@ -137,6 +137,7 @@ class MyCustomTrainer: If not specified, no validation will run. ckpt_path: Path to previous checkpoints to resume training from. If specified, will always look for the latest checkpoint within the given directory. + """ self.fabric.launch() @@ -207,6 +208,7 @@ class MyCustomTrainer: If greater then the number of batches in the ``train_loader``, this has no effect. scheduler_cfg: The learning rate scheduler configuration. Have a look at :meth:`lightning.pytorch.LightninModule.configure_optimizers` for supported values. + """ self.fabric.call("on_train_epoch_start") iterable = self.progbar_wrapper( @@ -268,6 +270,7 @@ class MyCustomTrainer: val_loader: The dataloader yielding the validation batches. limit_batches: Limits the batches during this validation epoch. If greater then the number of batches in the ``val_loader``, this has no effect. + """ # no validation if val_loader wasn't passed if val_loader is None: @@ -311,13 +314,14 @@ class MyCustomTrainer: torch.set_grad_enabled(True) def training_step(self, model: L.LightningModule, batch: Any, batch_idx: int) -> torch.Tensor: - """A single training step, running forward and backward. The optimizer step is called separately, as this - is given as a closure to the optimizer step. 
+ """A single training step, running forward and backward. The optimizer step is called separately, as this is + given as a closure to the optimizer step. Args: model: the lightning module to train batch: the batch to run the forward on batch_idx: index of the current batch w.r.t the current epoch + """ outputs: Union[torch.Tensor, Mapping[str, Any]] = model.training_step(batch, batch_idx=batch_idx) @@ -347,6 +351,7 @@ class MyCustomTrainer: Have a look at :meth:`lightning.pytorch.LightningModule.configure_optimizers` for supported values. level: whether we are trying to step on epoch- or step-level current_value: Holds the current_epoch if ``level==epoch``, else holds the ``global_step`` + """ # no scheduler @@ -395,6 +400,7 @@ class MyCustomTrainer: Args: iterable: the iterable to wrap with tqdm total: the total length of the iterable, necessary in case the number of batches was limited. + """ if self.fabric.is_global_zero: return tqdm(iterable, total=total, **kwargs) @@ -406,6 +412,7 @@ class MyCustomTrainer: Args: state: a mapping contaning model, optimizer and lr scheduler path: the path to load the checkpoint from + """ if state is None: state = {} @@ -458,6 +465,7 @@ class MyCustomTrainer: Args: configure_optim_output: The output of ``configure_optimizers``. For supported values, please refer to :meth:`lightning.pytorch.LightningModule.configure_optimizers`. + """ _lr_sched_defaults = {"interval": "epoch", "frequency": 1, "monitor": "val_loss"} @@ -511,6 +519,7 @@ class MyCustomTrainer: prog_bar: a progressbar (on global rank zero) or an iterable (every other rank). candidates: the values to add as postfix strings to the progressbar. prefix: the prefix to add to each of these values. + """ if isinstance(prog_bar, tqdm) and candidates is not None: postfix_str = "" diff --git a/examples/fabric/image_classifier/train_fabric.py b/examples/fabric/image_classifier/train_fabric.py index 5f4d9313c6..c05a35fc8d 100644 --- a/examples/fabric/image_classifier/train_fabric.py +++ b/examples/fabric/image_classifier/train_fabric.py @@ -25,6 +25,7 @@ and replace ``loss.backward()`` with ``self.backward(loss)``. Accelerate your training loop by setting the ``--accelerator``, ``--strategy``, ``--devices`` options directly from the command line. See ``lightning run model --help`` or learn more from the documentation: https://lightning.ai/docs/fabric. + """ import argparse diff --git a/examples/pytorch/basics/autoencoder.py b/examples/pytorch/basics/autoencoder.py index 006397f8e9..377579fccd 100644 --- a/examples/pytorch/basics/autoencoder.py +++ b/examples/pytorch/basics/autoencoder.py @@ -14,6 +14,7 @@ """MNIST autoencoder example. To run: python autoencoder.py --trainer.max_epochs=50 + """ from os import path from typing import Optional, Tuple diff --git a/examples/pytorch/basics/backbone_image_classifier.py b/examples/pytorch/basics/backbone_image_classifier.py index 65cf036f70..589f632bac 100644 --- a/examples/pytorch/basics/backbone_image_classifier.py +++ b/examples/pytorch/basics/backbone_image_classifier.py @@ -14,6 +14,7 @@ """MNIST backbone image classifier example. 
To run: python backbone_image_classifier.py --trainer.max_epochs=50 + """ from os import path from typing import Optional diff --git a/examples/pytorch/basics/profiler_example.py b/examples/pytorch/basics/profiler_example.py index 0c429d2917..5fe4004946 100644 --- a/examples/pytorch/basics/profiler_example.py +++ b/examples/pytorch/basics/profiler_example.py @@ -20,6 +20,7 @@ visualized in 2 ways: * With PyTorch Tensorboard Profiler (Instructions are here: https://github.com/pytorch/kineto/tree/master/tb_plugin) 1. pip install tensorboard torch-tb-profiler 2. tensorboard --logdir={FOLDER} + """ from os import path diff --git a/examples/pytorch/domain_templates/computer_vision_fine_tuning.py b/examples/pytorch/domain_templates/computer_vision_fine_tuning.py index 4bfd9de384..f55e6aa73a 100644 --- a/examples/pytorch/domain_templates/computer_vision_fine_tuning.py +++ b/examples/pytorch/domain_templates/computer_vision_fine_tuning.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. """Computer vision example on Transfer Learning. This computer vision example illustrates how one could fine-tune a -pre-trained network (by default, a ResNet50 is used) using pytorch-lightning. For the sake of this example, the -'cats and dogs dataset' (~60MB, see `DATA_URL` below) and the proposed network (denoted by `TransferLearningModel`, -see below) is trained for 15 epochs. +pre-trained network (by default, a ResNet50 is used) using pytorch-lightning. For the sake of this example, the 'cats +and dogs dataset' (~60MB, see `DATA_URL` below) and the proposed network (denoted by `TransferLearningModel`, see +below) is trained for 15 epochs. The training consists of three stages. @@ -37,6 +37,7 @@ Note: To run: python computer_vision_fine_tuning.py fit + """ import logging @@ -97,6 +98,7 @@ class CatDogImageDataModule(LightningDataModule): dl_path: root directory where to download the data num_workers: number of CPU workers batch_size: number of sample in a batch + """ super().__init__() @@ -174,6 +176,7 @@ class TransferLearningModel(LightningModule): milestones: List of two epochs milestones lr: Initial learning rate lr_scheduler_gamma: Factor by which the learning rate is reduced at each milestone + """ super().__init__() self.backbone = backbone @@ -209,6 +212,7 @@ class TransferLearningModel(LightningModule): """Forward pass. Returns logits. + """ # 1. 
Feature extraction: x = self.feature_extractor(x) diff --git a/examples/pytorch/domain_templates/generative_adversarial_net.py b/examples/pytorch/domain_templates/generative_adversarial_net.py index 734d625629..e31dec1244 100644 --- a/examples/pytorch/domain_templates/generative_adversarial_net.py +++ b/examples/pytorch/domain_templates/generative_adversarial_net.py @@ -16,6 +16,7 @@ After a few epochs, launch TensorBoard to see the images being generated at every batch: tensorboard --logdir default + """ from argparse import ArgumentParser, Namespace diff --git a/examples/pytorch/domain_templates/imagenet.py b/examples/pytorch/domain_templates/imagenet.py index 0d7275f58d..553b500e09 100644 --- a/examples/pytorch/domain_templates/imagenet.py +++ b/examples/pytorch/domain_templates/imagenet.py @@ -28,6 +28,7 @@ or show all options you can change: python imagenet.py --help python imagenet.py fit --help + """ import os from typing import Optional diff --git a/examples/pytorch/domain_templates/reinforce_learn_Qnet.py b/examples/pytorch/domain_templates/reinforce_learn_Qnet.py index 0f3e455b73..3d1e5d0161 100644 --- a/examples/pytorch/domain_templates/reinforce_learn_Qnet.py +++ b/examples/pytorch/domain_templates/reinforce_learn_Qnet.py @@ -29,6 +29,7 @@ References [1] https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On- Second-Edition/blob/master/Chapter06/02_dqn_pong.py + """ import argparse @@ -54,6 +55,7 @@ class DQN(nn.Module): DQN( (net): Sequential(...) ) + """ def __init__(self, obs_size: int, n_actions: int, hidden_size: int = 128): @@ -79,6 +81,7 @@ class ReplayBuffer: >>> ReplayBuffer(5) # doctest: +ELLIPSIS <...reinforce_learn_Qnet.ReplayBuffer object at ...> + """ def __init__(self, capacity: int) -> None: @@ -96,6 +99,7 @@ class ReplayBuffer: Args: experience: tuple (state, action, reward, done, new_state) + """ self.buffer.append(experience) @@ -117,6 +121,7 @@ class RLDataset(IterableDataset): >>> RLDataset(ReplayBuffer(5)) # doctest: +ELLIPSIS <...reinforce_learn_Qnet.RLDataset object at ...> + """ def __init__(self, buffer: ReplayBuffer, sample_size: int = 200) -> None: @@ -141,6 +146,7 @@ class Agent: >>> buffer = ReplayBuffer(10) >>> Agent(env, buffer) # doctest: +ELLIPSIS <...reinforce_learn_Qnet.Agent object at ...> + """ def __init__(self, env: gym.Env, replay_buffer: ReplayBuffer) -> None: @@ -168,6 +174,7 @@ class Agent: Returns: action + """ if np.random.random() < epsilon: action = self.env.action_space.sample() @@ -194,6 +201,7 @@ class Agent: Returns: reward, done + """ action = self.get_action(net, epsilon, device) @@ -222,6 +230,7 @@ class DQNLightning(LightningModule): (net): Sequential(...) ) ) + """ def __init__( @@ -270,6 +279,7 @@ class DQNLightning(LightningModule): Args: steps: number of random steps to populate the buffer with + """ for i in range(steps): self.agent.play_step(self.net, epsilon=1.0) @@ -282,6 +292,7 @@ class DQNLightning(LightningModule): Returns: q values + """ return self.net(x) @@ -293,6 +304,7 @@ class DQNLightning(LightningModule): Returns: loss + """ states, actions, rewards, dones, next_states = batch @@ -308,8 +320,8 @@ class DQNLightning(LightningModule): return nn.MSELoss()(state_action_values, expected_state_action_values) def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], nb_batch) -> OrderedDict: - """Carries out a single step through the environment to update the replay buffer. Then calculates loss - based on the minibatch received. 
+ """Carries out a single step through the environment to update the replay buffer. Then calculates loss based on + the minibatch received. Args: batch: current mini batch of replay data @@ -317,6 +329,7 @@ class DQNLightning(LightningModule): Returns: Training loss and log metrics + """ device = self.get_device(batch) epsilon = max(self.eps_end, self.eps_start - (self.global_step + 1) / self.eps_last_frame) diff --git a/examples/pytorch/domain_templates/reinforce_learn_ppo.py b/examples/pytorch/domain_templates/reinforce_learn_ppo.py index b68fcf720b..16aa5ebc86 100644 --- a/examples/pytorch/domain_templates/reinforce_learn_ppo.py +++ b/examples/pytorch/domain_templates/reinforce_learn_ppo.py @@ -26,6 +26,7 @@ References [1] https://github.com/openai/baselines/blob/master/baselines/ppo2/ppo2.py [2] https://github.com/openai/spinningup [3] https://github.com/sid-sundrani/ppo_lightning + """ import argparse from typing import Callable, Iterator, List, Tuple @@ -52,8 +53,7 @@ def create_mlp(input_shape: Tuple[int], n_actions: int, hidden_size: int = 128): class ActorCategorical(nn.Module): - """Policy network, for discrete action spaces, which returns a distribution and an action given an - observation.""" + """Policy network, for discrete action spaces, which returns a distribution and an action given an observation.""" def __init__(self, actor_net): """ @@ -81,6 +81,7 @@ class ActorCategorical(nn.Module): Returns: log probability of the action under pi + """ return pi.log_prob(actions) @@ -117,6 +118,7 @@ class ActorContinuous(nn.Module): Returns: log probability of the action under pi + """ return pi.log_prob(actions).sum(axis=-1) @@ -127,6 +129,7 @@ class ExperienceSourceDataset(IterableDataset): Basic experience source dataset. Takes a generate_batch function that returns an iterator. 
The logic for the experience source and how the batch is generated is defined the Lightning model itself + """ def __init__(self, generate_batch: Callable): @@ -144,6 +147,7 @@ class PPOLightning(LightningModule): Train: trainer = Trainer() trainer.fit(model) + """ def __init__( @@ -231,6 +235,7 @@ class PPOLightning(LightningModule): Returns: Tuple of policy and action + """ pi, action = self.actor(x) value = self.critic(x) @@ -245,6 +250,7 @@ class PPOLightning(LightningModule): Returns: list of discounted rewards/advantages + """ assert isinstance(rewards[0], float) @@ -267,6 +273,7 @@ class PPOLightning(LightningModule): Returns: list of advantages + """ rews = rewards + [last_value] vals = values + [last_value] @@ -373,6 +380,7 @@ class PPOLightning(LightningModule): Args: batch: batch of replay buffer/trajectory data + """ state, action, old_logp, qval, adv = batch @@ -405,8 +413,7 @@ class PPOLightning(LightningModule): return optimizer_actor, optimizer_critic def optimizer_step(self, *args, **kwargs): - """Run 'nb_optim_iters' number of iterations of gradient descent on actor and critic for each data - sample.""" + """Run 'nb_optim_iters' number of iterations of gradient descent on actor and critic for each data sample.""" for _ in range(self.nb_optim_iters): super().optimizer_step(*args, **kwargs) diff --git a/examples/pytorch/domain_templates/semantic_segmentation.py b/examples/pytorch/domain_templates/semantic_segmentation.py index eb816756ce..286084d5ff 100644 --- a/examples/pytorch/domain_templates/semantic_segmentation.py +++ b/examples/pytorch/domain_templates/semantic_segmentation.py @@ -31,8 +31,7 @@ DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27 def _create_synth_kitti_dataset(path_dir: str, image_dims: tuple = (1024, 512)): - """Create synthetic dataset with random images, just to simulate that the dataset have been already - downloaded.""" + """Create synthetic dataset with random images, just to simulate that the dataset have been already downloaded.""" path_dir_images = os.path.join(path_dir, KITTI.IMAGE_PATH) path_dir_masks = os.path.join(path_dir, KITTI.MASK_PATH) for p_dir in (path_dir_images, path_dir_masks): @@ -65,6 +64,7 @@ class KITTI(Dataset): In the `get_item` function, images and masks are resized to the given `img_size`, masks are encoded using `encode_segmap`, and given `transform` (if any) are applied to the image only (mask does not usually require transforms, but they can be implemented in a similar way). + """ IMAGE_PATH = os.path.join("training", "image_2") @@ -154,6 +154,7 @@ class UNet(nn.Module): (5): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1)) ) ) + """ def __init__(self, num_classes: int = 19, num_layers: int = 5, features_start: int = 64, bilinear: bool = False): @@ -200,6 +201,7 @@ class DoubleConv(nn.Module): DoubleConv( (net): Sequential(...) ) + """ def __init__(self, in_ch: int, out_ch: int): @@ -229,6 +231,7 @@ class Down(nn.Module): ) ) ) + """ def __init__(self, in_ch: int, out_ch: int): @@ -240,8 +243,8 @@ class Down(nn.Module): class Up(nn.Module): - """Upsampling (by either bilinear interpolation or transpose convolutions) followed by concatenation of feature - map from contracting path, followed by double 3x3 convolution. + """Upsampling (by either bilinear interpolation or transpose convolutions) followed by concatenation of feature map + from contracting path, followed by double 3x3 convolution. 
>>> Up(8, 4) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Up( @@ -250,6 +253,7 @@ class Up(nn.Module): (net): Sequential(...) ) ) + """ def __init__(self, in_ch: int, out_ch: int, bilinear: bool = False): @@ -306,6 +310,7 @@ class SegModel(LightningModule): ) ) ) + """ def __init__( diff --git a/pyproject.toml b/pyproject.toml index 8aa9903f36..27334b8122 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,13 @@ skip = ["_notebooks"] line-length = 120 exclude = '(_notebooks/.*)' +[tool.docformatter] +recursive = true +# this need to be shorter as some docstings are r"""... +wrap-summaries = 119 +wrap-descriptions = 120 +blank = true + [tool.ruff] line-length = 120 diff --git a/requirements/collect_env_details.py b/requirements/collect_env_details.py index 3dd2b8d642..5e6f9ba3dd 100644 --- a/requirements/collect_env_details.py +++ b/requirements/collect_env_details.py @@ -14,6 +14,7 @@ """Diagnose your system and show basic information. This server mainly to get detail info for better bug reporting. + """ import os diff --git a/setup.py b/setup.py index 043308ebb6..d5d92d8228 100755 --- a/setup.py +++ b/setup.py @@ -38,6 +38,7 @@ There are considered three main scenarios for installing this project: compared against PyPI registry b) with a parameterization build desired packages in to standard `dist/` folder c) validate packages and publish to PyPI + """ import contextlib import glob diff --git a/src/lightning/app/api/http_methods.py b/src/lightning/app/api/http_methods.py index 8cab27096b..dc3dc32bee 100644 --- a/src/lightning/app/api/http_methods.py +++ b/src/lightning/app/api/http_methods.py @@ -58,6 +58,7 @@ class _FastApiMockRequest: def configure_api(self): return [Post("/api/v1/request", self.request)] + """ _body: Optional[str] = None @@ -116,6 +117,7 @@ class _HttpMethod: route: The path used to route the requests method: The associated flow method timeout: The time in seconds taken before raising a timeout exception. + """ self.route = route self.attached_to_flow = hasattr(method, "__self__") diff --git a/src/lightning/app/cli/cmd_install.py b/src/lightning/app/cli/cmd_install.py index 05ead42b42..b43aa3f88f 100644 --- a/src/lightning/app/cli/cmd_install.py +++ b/src/lightning/app/cli/cmd_install.py @@ -566,6 +566,7 @@ def _install_app_from_source( If true, overwrite the app directory without asking if it already exists git_sha: The git_sha for checking out the git repo of the app. + """ if not cwd: diff --git a/src/lightning/app/cli/commands/logs.py b/src/lightning/app/cli/commands/logs.py index eba1746cbd..4587987ae5 100644 --- a/src/lightning/app/cli/commands/logs.py +++ b/src/lightning/app/cli/commands/logs.py @@ -46,6 +46,7 @@ def logs(app_name: str, components: List[str], follow: bool) -> None: Print logs only from selected works: $ lightning show logs my-application root.work_a root.work_b + """ _show_logs(app_name, components, follow) diff --git a/src/lightning/app/cli/component-template/tests/test_placeholdername_component.py b/src/lightning/app/cli/component-template/tests/test_placeholdername_component.py index e1b30e1c11..8c7dad2fe7 100644 --- a/src/lightning/app/cli/component-template/tests/test_placeholdername_component.py +++ b/src/lightning/app/cli/component-template/tests/test_placeholdername_component.py @@ -2,6 +2,7 @@ r"""To test a lightning component: 1. Init the component. 2. 
call .run() + """ from placeholdername.component import TemplateComponent diff --git a/src/lightning/app/cli/connect/app.py b/src/lightning/app/cli/connect/app.py index 76de20256f..ebad9b1297 100644 --- a/src/lightning/app/cli/connect/app.py +++ b/src/lightning/app/cli/connect/app.py @@ -57,6 +57,7 @@ def connect_app(app_name_or_id: str): \b # once done, disconnect and go back to the standard lightning CLI commands lightning disconnect + """ from lightning.app.utilities.commands.base import _download_command diff --git a/src/lightning/app/cli/lightning_cli_delete.py b/src/lightning/app/cli/lightning_cli_delete.py index c48bee07ca..179e5b6fc3 100644 --- a/src/lightning/app/cli/lightning_cli_delete.py +++ b/src/lightning/app/cli/lightning_cli_delete.py @@ -98,6 +98,7 @@ def delete_app(app_name: str, skip_user_confirm_prompt: bool) -> None: Deleting an app also deletes all app websites, works, artifacts, and logs. This permanently removes any record of the app as well as all any of its associated resources and data. This does not affect any resources and data associated with other Lightning apps on your account. + """ console = Console() diff --git a/src/lightning/app/cli/lightning_cli_launch.py b/src/lightning/app/cli/lightning_cli_launch.py index 8cf56453d8..c171fd7b94 100644 --- a/src/lightning/app/cli/lightning_cli_launch.py +++ b/src/lightning/app/cli/lightning_cli_launch.py @@ -40,10 +40,11 @@ def launch() -> None: @click.option("--host", help="Application running host", default=APP_SERVER_HOST, type=str) @click.option("--port", help="Application running port", default=APP_SERVER_PORT, type=int) def run_server(file: str, queue_id: str, host: str, port: int) -> None: - """It takes the application file as input, build the application object and then use that to run the - application server. + """It takes the application file as input, build the application object and then use that to run the application + server. This is used by the cloud runners to start the status server for the application + """ logger.debug(f"Run Server: {file} {queue_id} {host} {port}") start_application_server(file, host, port, queue_id=queue_id) @@ -54,10 +55,11 @@ def run_server(file: str, queue_id: str, host: str, port: int) -> None: @click.option("--queue-id", help="ID for identifying queue", default="", type=str) @click.option("--base-url", help="Base url at which the app server is hosted", default="") def run_flow(file: str, queue_id: str, base_url: str) -> None: - """It takes the application file as input, build the application object, proxy all the work components and then - run the application flow defined in the root component. + """It takes the application file as input, build the application object, proxy all the work components and then run + the application flow defined in the root component. It does exactly what a singleprocess dispatcher would do but with proxied work components. 
+ """ logger.debug(f"Run Flow: {file} {queue_id} {base_url}") run_lightning_flow(file, queue_id=queue_id, base_url=base_url) @@ -68,8 +70,8 @@ def run_flow(file: str, queue_id: str, base_url: str) -> None: @click.option("--work-name", type=str) @click.option("--queue-id", help="ID for identifying queue", default="", type=str) def run_work(file: str, work_name: str, queue_id: str) -> None: - """Unlike other entrypoints, this command will take the file path or module details for a work component and - run that by fetching the states from the queues.""" + """Unlike other entrypoints, this command will take the file path or module details for a work component and run + that by fetching the states from the queues.""" logger.debug(f"Run Work: {file} {work_name} {queue_id}") run_lightning_work( file=file, @@ -109,10 +111,11 @@ def run_flow_and_servers( port: int, flow_port: Tuple[Tuple[str, int]], ) -> None: - """It takes the application file as input, build the application object and then use that to run the - application flow defined in the root component, the application server and all the flow frontends. + """It takes the application file as input, build the application object and then use that to run the application + flow defined in the root component, the application server and all the flow frontends. This is used by the cloud runners to start the flow, the status server and all frontends for the application + """ logger.debug(f"Run Flow: {file} {queue_id} {base_url}") logger.debug(f"Run Server: {file} {queue_id} {host} {port}.") diff --git a/src/lightning/app/cli/pl-app-template/core/components/logger/tensorboard.py b/src/lightning/app/cli/pl-app-template/core/components/logger/tensorboard.py index 6f5b2eb563..a2935140a2 100644 --- a/src/lightning/app/cli/pl-app-template/core/components/logger/tensorboard.py +++ b/src/lightning/app/cli/pl-app-template/core/components/logger/tensorboard.py @@ -13,6 +13,7 @@ class TensorBoard(LightningFlow): Args: log_dir: The path to the directory where the TensorBoard log-files will appear. sync_every_n_seconds: How often to sync the log directory (given as an argument to the run method) + """ super().__init__() self.worker = TensorBoardWorker(log_dir=log_dir, sync_every_n_seconds=sync_every_n_seconds) diff --git a/src/lightning/app/components/database/client.py b/src/lightning/app/components/database/client.py index 01643afbfe..81f0862918 100644 --- a/src/lightning/app/components/database/client.py +++ b/src/lightning/app/components/database/client.py @@ -29,6 +29,7 @@ def _configure_session() -> Session: """Configures the session for GET and POST requests. It enables a generous retrial strategy that waits for the application server to connect. + """ retry_strategy = Retry( # wait time between retries increases exponentially according to: backoff_factor * (2 ** (retry - 1)) diff --git a/src/lightning/app/components/database/server.py b/src/lightning/app/components/database/server.py index 6da7710cfa..3fbf75f01a 100644 --- a/src/lightning/app/components/database/server.py +++ b/src/lightning/app/components/database/server.py @@ -158,6 +158,7 @@ class Database(LightningWork): # RIGHT THERE ! You need to use Field and Column with the `pydantic_column_type` utility. 
kv: List[KeyValuePair] = Field(..., sa_column=Column(pydantic_column_type(List[KeyValuePair]))) + """ super().__init__(parallel=True, cloud_build_config=BuildConfig(["sqlmodel"])) self.db_filename = db_filename diff --git a/src/lightning/app/components/database/utilities.py b/src/lightning/app/components/database/utilities.py index dd31c12da6..e4561d9245 100644 --- a/src/lightning/app/components/database/utilities.py +++ b/src/lightning/app/components/database/utilities.py @@ -52,6 +52,7 @@ def _pydantic_column_type(pydantic_type: Any) -> Any: class TrialConfig(SQLModel, table=False): ... params: Dict[str, Union[Dict[str, float]] = Field(sa_column=Column(pydantic_column_type[Dict[str, float])) + """ class PydanticJSONType(TypeDecorator, Generic[T]): diff --git a/src/lightning/app/components/multi_node/base.py b/src/lightning/app/components/multi_node/base.py index a300918452..1da8ef34e3 100644 --- a/src/lightning/app/components/multi_node/base.py +++ b/src/lightning/app/components/multi_node/base.py @@ -67,6 +67,7 @@ class MultiNode(LightningFlow): running locally. work_args: Arguments to be provided to the work on instantiation. work_kwargs: Keywords arguments to be provided to the work on instantiation. + """ super().__init__() if num_nodes > 1 and not is_running_in_cloud(): diff --git a/src/lightning/app/components/python/popen.py b/src/lightning/app/components/python/popen.py index 34b50140c5..9e585e2c21 100644 --- a/src/lightning/app/components/python/popen.py +++ b/src/lightning/app/components/python/popen.py @@ -70,6 +70,7 @@ class PopenPythonScript(LightningWork): .. literalinclude:: ../../../../examples/app/components/python/component_popen.py :language: python + """ super().__init__(**kwargs) if not os.path.exists(script_path): diff --git a/src/lightning/app/components/python/tracer.py b/src/lightning/app/components/python/tracer.py index e0bde7c0b5..9048bceaec 100644 --- a/src/lightning/app/components/python/tracer.py +++ b/src/lightning/app/components/python/tracer.py @@ -117,6 +117,7 @@ class TracerPythonScript(LightningWork): .. literalinclude:: ../../../../examples/app/components/python/app.py :language: python + """ super().__init__(**kwargs) self.script_path = str(script_path) diff --git a/src/lightning/app/components/serve/auto_scaler.py b/src/lightning/app/components/serve/auto_scaler.py index 165a6422e4..f0838398a1 100644 --- a/src/lightning/app/components/serve/auto_scaler.py +++ b/src/lightning/app/components/serve/auto_scaler.py @@ -114,8 +114,8 @@ def _create_fastapi(title: str) -> _TrackableFastAPI: class _LoadBalancer(LightningWork): - r"""The LoadBalancer is a LightningWork component that collects the requests and sends them to the prediciton - API asynchronously using RoundRobin scheduling. It also performs auto batching of the incoming requests. + r"""The LoadBalancer is a LightningWork component that collects the requests and sends them to the prediciton API + asynchronously using RoundRobin scheduling. It also performs auto batching of the incoming requests. After enabling you will require to send username and password from the request header for the private endpoints. @@ -131,6 +131,7 @@ class _LoadBalancer(LightningWork): api_name: The name to be displayed on the UI. Normally, it is the name of the work class cold_start_proxy: The proxy service to use while the work is cold starting. **kwargs: Arguments passed to :func:`LightningWork.init` like ``CloudCompute``, ``BuildConfig``, etc. 
+ """ @requires(["aiohttp"]) @@ -236,6 +237,7 @@ class _LoadBalancer(LightningWork): Two instances of this function should not be running with shared `_state_server` as that would create race conditions + """ while True: await asyncio.sleep(0.05) @@ -293,6 +295,7 @@ class _LoadBalancer(LightningWork): """This function checks if we have processing capacity for one more request or not. Depends on the value from here, we decide whether we should proxy the request or not + """ if not self._fastapi_app: return False @@ -385,6 +388,7 @@ class _LoadBalancer(LightningWork): """Updates works that load balancer distributes requests to. AutoScaler uses this method to increase/decrease the number of works. + """ old_server_urls = set(self.servers) current_server_urls = { @@ -623,6 +627,7 @@ class AutoScaler(LightningFlow): Returns: The name of the new work attribute. + """ work_attribute = uuid.uuid4().hex work_attribute = f"worker_{self.num_replicas}_{str(work_attribute)}" @@ -665,6 +670,7 @@ class AutoScaler(LightningFlow): Returns: The target number of running works. The value will be adjusted after this method runs so that it satisfies ``min_replicas<=replicas<=max_replicas``. + """ pending_requests = metrics["pending_requests"] active_or_pending_works = replicas + metrics["pending_works"] diff --git a/src/lightning/app/components/serve/cold_start_proxy.py b/src/lightning/app/components/serve/cold_start_proxy.py index 1be5315f76..6ab829d875 100644 --- a/src/lightning/app/components/serve/cold_start_proxy.py +++ b/src/lightning/app/components/serve/cold_start_proxy.py @@ -35,6 +35,7 @@ class ColdStartProxy: Args: proxy_url (str): The url of the proxy service + """ @requires(["aiohttp"]) @@ -46,12 +47,13 @@ class ColdStartProxy: async def handle_request(self, request: BaseModel) -> Any: """This method is called when the request is received while the work is cold starting. The default - implementation of this method is to forward the request body to the proxy service with POST method but the - user can override this method to handle the request in any way. + implementation of this method is to forward the request body to the proxy service with POST method but the user + can override this method to handle the request in any way. Args: request (BaseModel): The request body, a pydantic model that is being forwarded by load balancer which is a FastAPI service + """ try: async with aiohttp.ClientSession() as session: diff --git a/src/lightning/app/components/serve/gradio_server.py b/src/lightning/app/components/serve/gradio_server.py index dc9f2d7847..4dd6ae8139 100644 --- a/src/lightning/app/components/serve/gradio_server.py +++ b/src/lightning/app/components/serve/gradio_server.py @@ -77,6 +77,7 @@ class ServeGradio(LightningWork, abc.ABC): """Override to instantiate and return your model. The model would be accessible under self.model + """ def run(self, *args: Any, **kwargs: Any): diff --git a/src/lightning/app/components/serve/python_server.py b/src/lightning/app/components/serve/python_server.py index 9e70688725..518c296285 100644 --- a/src/lightning/app/components/serve/python_server.py +++ b/src/lightning/app/components/serve/python_server.py @@ -211,6 +211,7 @@ class PythonServer(LightningWork, abc.ABC): ... return {"prediction": self._model(request.image)} ... 
>>> app = LightningApp(SimpleServer()) + """ super().__init__(parallel=True, **kwargs) if not issubclass(input_type, BaseModel): @@ -228,6 +229,7 @@ class PythonServer(LightningWork, abc.ABC): Note that this will be called exactly once on every work machines. So if you have multiple machines for serving, this will be called on each of them. + """ return @@ -243,6 +245,7 @@ class PythonServer(LightningWork, abc.ABC): This method must be overriden by the user with the prediction logic. The pre/post processing, actual prediction using the model(s) etc goes here + """ pass @@ -325,6 +328,7 @@ class PythonServer(LightningWork, abc.ABC): """Run method takes care of configuring and setting up a FastAPI server behind the scenes. Normally, you don't need to override this method. + """ self.setup(*args, **kwargs) diff --git a/src/lightning/app/components/serve/serve.py b/src/lightning/app/components/serve/serve.py index 5cf1ad6dbc..0ae4030229 100644 --- a/src/lightning/app/components/serve/serve.py +++ b/src/lightning/app/components/serve/serve.py @@ -67,6 +67,7 @@ class ModelInferenceAPI(LightningWork, abc.ABC): host: Address to be used to serve the model. port: Port to be used to serve the model. workers: Number of workers for the uvicorn. Warning, this won't work if your subclass takes more arguments. + """ super().__init__(parallel=True, host=host, port=port) if input and input not in _DESERIALIZER: @@ -151,8 +152,8 @@ class ModelInferenceAPI(LightningWork, abc.ABC): def _maybe_create_instance() -> Optional[ModelInferenceAPI]: - """This function tries to re-create the user `ModelInferenceAPI` if the environment associated with multi - workers are present.""" + """This function tries to re-create the user `ModelInferenceAPI` if the environment associated with multi workers + are present.""" render_fn_name = os.getenv("LIGHTNING_MODEL_INFERENCE_API_CLASS_NAME", None) render_fn_module_file = os.getenv("LIGHTNING_MODEL_INFERENCE_API_FILE", None) if render_fn_name is None or render_fn_module_file is None: diff --git a/src/lightning/app/components/serve/streamlit.py b/src/lightning/app/components/serve/streamlit.py index 9c39ce23d5..8e9d2d34ae 100644 --- a/src/lightning/app/components/serve/streamlit.py +++ b/src/lightning/app/components/serve/streamlit.py @@ -50,6 +50,7 @@ class ServeStreamlit(LightningWork, abc.ABC): """Optionally override to instantiate and return your model. The model will be accessible under ``self.model``. + """ return None diff --git a/src/lightning/app/components/serve/types/type.py b/src/lightning/app/components/serve/types/type.py index 54d36d8afa..157940a60f 100644 --- a/src/lightning/app/components/serve/types/type.py +++ b/src/lightning/app/components/serve/types/type.py @@ -28,4 +28,5 @@ class BaseType(abc.ABCMeta): """Take the inputs from the network and deserilize/convert them them. Output from this method will go to the exposed method as arguments. + """ diff --git a/src/lightning/app/components/training.py b/src/lightning/app/components/training.py index 7c47bea8eb..5ffe843869 100644 --- a/src/lightning/app/components/training.py +++ b/src/lightning/app/components/training.py @@ -159,6 +159,7 @@ class LightningTrainerScript(LightningFlow): cloud_compute: The cloud compute object used in the cloud. 
sanity_serving: Whether to validate that the model correctly implements the ServableModule API + """ super().__init__() self.script_path = script_path diff --git a/src/lightning/app/core/api.py b/src/lightning/app/core/api.py index d4f0760053..b34073d310 100644 --- a/src/lightning/app/core/api.py +++ b/src/lightning/app/core/api.py @@ -270,8 +270,8 @@ async def post_delta( x_lightning_session_uuid: Optional[str] = Header(None), # type: ignore[assignment] x_lightning_session_id: Optional[str] = Header(None), # type: ignore[assignment] ) -> Optional[Dict]: - """This endpoint is used to make an update to the app state using delta diff, mainly used by streamlit to - update the state.""" + """This endpoint is used to make an update to the app state using delta diff, mainly used by streamlit to update + the state.""" if x_lightning_session_uuid is None: raise Exception("Missing X-Lightning-Session-UUID header") diff --git a/src/lightning/app/core/app.py b/src/lightning/app/core/app.py index 778404bf47..e3116eca37 100644 --- a/src/lightning/app/core/app.py +++ b/src/lightning/app/core/app.py @@ -383,8 +383,7 @@ class LightningApp: return deltas def maybe_apply_changes(self) -> Optional[bool]: - """Get the deltas from both the flow queue and the work queue, merge the two deltas and update the - state.""" + """Get the deltas from both the flow queue and the work queue, merge the two deltas and update the state.""" self._send_flow_to_work_deltas(self.state) if not self.collect_changes: @@ -503,6 +502,7 @@ class LightningApp: """Entry point of the LightningApp. This would be dispatched by the Runtime objects. + """ self._original_state = deepcopy(self.state) done = False diff --git a/src/lightning/app/core/flow.py b/src/lightning/app/core/flow.py index ff21a01007..c600f08015 100644 --- a/src/lightning/app/core/flow.py +++ b/src/lightning/app/core/flow.py @@ -396,6 +396,7 @@ class LightningFlow: .. deprecated:: 1.9.0 This function is deprecated and will be removed in 2.0.0. Use :meth:`stop` instead. + """ warnings.warn( DeprecationWarning( @@ -411,6 +412,7 @@ class LightningFlow: (prefixed by '__') attributes are not. Exceptions are listed in the `_INTERNAL_STATE_VARS` class variable. + """ return name in LightningFlow._INTERNAL_STATE_VARS or not name.startswith("_") @@ -487,6 +489,7 @@ class LightningFlow:
+ """ if not user_key: frame = cast(FrameType, inspect.currentframe()).f_back @@ -626,6 +629,7 @@ class LightningFlow:
+ """ return [{"name": name, "content": component} for (name, component) in self.flows.items()] @@ -639,6 +643,7 @@ class LightningFlow: run_once: Whether to run the entire iteration only once. Otherwise, it would restart from the beginning. user_key: Key to be used to track the caching mechanism. + """ if not isinstance(iterable, Iterable): raise TypeError(f"An iterable should be provided to `self.iterate` method. Found {iterable}") @@ -708,6 +713,7 @@ class LightningFlow: .. code-block:: bash lightning my_command_name --args name=my_own_name + """ raise NotImplementedError @@ -741,6 +747,7 @@ class LightningFlow: Once the app is running, you can access the Swagger UI of the app under the ``/docs`` route. + """ raise NotImplementedError @@ -805,6 +812,7 @@ class LightningFlow: children_states: The state of the dynamic children of this flow. strict: Whether to raise an exception if a dynamic children hasn't been re-created. + """ self.set_state(flow_state, recurse=False) direct_children_states = {k: v for k, v in children_states.items() if "." not in k} diff --git a/src/lightning/app/core/queues.py b/src/lightning/app/core/queues.py index 18e02dd989..1900f961a7 100644 --- a/src/lightning/app/core/queues.py +++ b/src/lightning/app/core/queues.py @@ -182,6 +182,7 @@ class BaseQueue(ABC): timeout: Read timeout in seconds, in case of input timeout is 0, the `self.default_timeout` is used. A timeout of None can be used to block indefinitely. + """ pass @@ -190,6 +191,7 @@ class BaseQueue(ABC): """Returns True if the queue is running, False otherwise. Child classes should override this property and implement custom logic as required + """ return True @@ -286,6 +288,7 @@ class RedisQueue(BaseQueue): timeout: Read timeout in seconds, in case of input timeout is 0, the `self.default_timeout` is used. A timeout of None can be used to block indefinitely. + """ if timeout is None: # this means it's blocking in redis @@ -464,6 +467,7 @@ class HTTPQueue(BaseQueue): This can be brittle, as if the queue name creation logic changes, the response values from here wouldn't be accurate. Remove this eventually and let the Queue class take app id and name of the queue as arguments + """ if "_" not in queue_name: return "", queue_name diff --git a/src/lightning/app/core/work.py b/src/lightning/app/core/work.py index f5416c0cfa..15a762d6f9 100644 --- a/src/lightning/app/core/work.py +++ b/src/lightning/app/core/work.py @@ -124,6 +124,7 @@ class LightningWork:
+ """ from lightning.app.runners.backends.backend import Backend @@ -212,6 +213,7 @@ class LightningWork: By default, this attribute returns the empty string and the ip address will only be returned once the work runs. Locally, the address is 127.0.0.1 and in the cloud it will be determined by the cluster. + """ return self._internal_ip @@ -221,6 +223,7 @@ class LightningWork: By default, this attribute returns the empty string and the ip address will only be returned once the work runs. Locally, this address is undefined (empty string) and in the cloud it will be determined by the cluster. + """ return self._public_ip @@ -234,6 +237,7 @@ class LightningWork: (prefixed by '__') attributes are not. Exceptions are listed in the `_INTERNAL_STATE_VARS` class variable. + """ return name in LightningWork._INTERNAL_STATE_VARS or not name.startswith("_") @@ -247,6 +251,7 @@ class LightningWork: """Returns the display name of the LightningWork in the cloud. The display name needs to set before the run method of the work is called. + """ return self._display_name @@ -269,6 +274,7 @@ class LightningWork: """Whether to run in parallel mode or not. When parallel is False, the flow waits for the work to finish. + """ return self._parallel @@ -325,6 +331,7 @@ class LightningWork: """Return the current status of the work. All statuses are stored in the state. + """ call_hash = self._calls[CacheCallsKeys.LATEST_CALL_HASH] if call_hash in self._calls: @@ -628,6 +635,7 @@ class LightningWork: Raises: LightningPlatformException: If resource exceeds platform quotas or other constraints. + """ def on_exception(self, exception: BaseException) -> None: @@ -636,8 +644,7 @@ class LightningWork: raise exception def _aggregate_status_timeout(self, statuses: List[Dict]) -> WorkStatus: - """Method used to return the first request and the total count of timeout after the latest succeeded - status.""" + """Method used to return the first request and the total count of timeout after the latest succeeded status.""" succeeded_statuses = [ status_idx for status_idx, status in enumerate(statuses) if status["stage"] == WorkStageStatus.SUCCEEDED ] @@ -653,6 +660,7 @@ class LightningWork: """Override this hook to add your logic when the work is exiting. Note: This hook is not guaranteed to be called when running in the cloud. + """ pass @@ -660,6 +668,7 @@ class LightningWork: """Stops LightingWork component and shuts down hardware provisioned via L.CloudCompute. This can only be called from a ``LightningFlow``. + """ if not self._backend: raise RuntimeError(f"Only the `LightningFlow` can request this work ({self.name!r}) to stop.") @@ -675,6 +684,7 @@ class LightningWork: """Delete LightingWork component and shuts down hardware provisioned via L.CloudCompute. Locally, the work.delete() behaves as work.stop(). + """ if not self._backend: raise Exception( @@ -755,4 +765,5 @@ class LightningWork: returned URL can depend on the state. This is not the case if the work returns a :class:`~lightning.app.frontend.frontend.Frontend`. These need to be provided at the time of app creation in order for the runtime to start the server. + """ diff --git a/src/lightning/app/frontend/frontend.py b/src/lightning/app/frontend/frontend.py index d2d24b5656..2d87d4ebcd 100644 --- a/src/lightning/app/frontend/frontend.py +++ b/src/lightning/app/frontend/frontend.py @@ -23,6 +23,7 @@ class Frontend(ABC): """Base class for any frontend that gets exposed by LightningFlows. The flow attribute will be set by the app while bootstrapping. 
+ """ def __init__(self) -> None: @@ -48,6 +49,7 @@ class Frontend(ABC): def start_server(self, host, port, root_path=""): self._process = subprocess.Popen(["flask", "run" "--host", host, "--port", str(port)]) + """ @abstractmethod @@ -62,4 +64,5 @@ class Frontend(ABC): def stop_server(self): self._process.kill() + """ diff --git a/src/lightning/app/frontend/just_py/just_py.py b/src/lightning/app/frontend/just_py/just_py.py index 7d5ac44306..11a9d55799 100644 --- a/src/lightning/app/frontend/just_py/just_py.py +++ b/src/lightning/app/frontend/just_py/just_py.py @@ -81,6 +81,7 @@ class JustPyFrontend(Frontend): app = LightningApp(Flow()) + """ def __init__(self, render_fn: Callable) -> None: diff --git a/src/lightning/app/frontend/panel/__init__.py b/src/lightning/app/frontend/panel/__init__.py index bb67ee1568..96cb5550ce 100644 --- a/src/lightning/app/frontend/panel/__init__.py +++ b/src/lightning/app/frontend/panel/__init__.py @@ -1,5 +1,4 @@ -"""The PanelFrontend and AppStateWatcher make it easy to create Lightning Apps with the Panel data app -framework.""" +"""The PanelFrontend and AppStateWatcher make it easy to create Lightning Apps with the Panel data app framework.""" from lightning.app.frontend.panel.app_state_watcher import AppStateWatcher from lightning.app.frontend.panel.panel_frontend import PanelFrontend diff --git a/src/lightning/app/frontend/panel/app_state_comm.py b/src/lightning/app/frontend/panel/app_state_comm.py index eb1f018786..9fec245f7b 100644 --- a/src/lightning/app/frontend/panel/app_state_comm.py +++ b/src/lightning/app/frontend/panel/app_state_comm.py @@ -94,6 +94,7 @@ def _watch_app_state(callback: Callable): def handle_state_change(): print("The App State changed.") watch_app_state(handle_state_change) + """ _CALLBACKS.append(callback) _start_websocket() diff --git a/src/lightning/app/frontend/panel/app_state_watcher.py b/src/lightning/app/frontend/panel/app_state_watcher.py index 528a19acce..3612fb1147 100644 --- a/src/lightning/app/frontend/panel/app_state_watcher.py +++ b/src/lightning/app/frontend/panel/app_state_watcher.py @@ -72,6 +72,7 @@ class AppStateWatcher(Parameterized): Pydantic which additionally provides powerful and unique features for building reactive apps. Please note the ``AppStateWatcher`` is a singleton, i.e., only one instance is instantiated + """ state: AppState = ClassSelector( diff --git a/src/lightning/app/frontend/panel/panel_frontend.py b/src/lightning/app/frontend/panel/panel_frontend.py index f4a5c68f57..d0f4ead1ca 100644 --- a/src/lightning/app/frontend/panel/panel_frontend.py +++ b/src/lightning/app/frontend/panel/panel_frontend.py @@ -35,6 +35,7 @@ def _has_panel_autoreload() -> bool: """Returns True if the PANEL_AUTORELOAD environment variable is set to 'yes' or 'true'. Please note the casing of value does not matter + """ return os.environ.get("PANEL_AUTORELOAD", "no").lower() in ["yes", "y", "true"] @@ -98,6 +99,7 @@ class PanelFrontend(Frontend): For development you can get Panel autoreload by setting the ``PANEL_AUTORELOAD`` environment variable to 'yes', i.e. run ``PANEL_AUTORELOAD=yes lightning run app app_basic.py`` + """ @requires("panel") diff --git a/src/lightning/app/frontend/panel/panel_serve_render_fn.py b/src/lightning/app/frontend/panel/panel_serve_render_fn.py index df6a83d713..0c06cccd06 100644 --- a/src/lightning/app/frontend/panel/panel_serve_render_fn.py +++ b/src/lightning/app/frontend/panel/panel_serve_render_fn.py @@ -26,6 +26,7 @@ Example: .. 
code-block:: bash python panel_serve_render_fn + """ import inspect import os diff --git a/src/lightning/app/frontend/stream_lit.py b/src/lightning/app/frontend/stream_lit.py index 1ce03b997e..0cdc37296d 100644 --- a/src/lightning/app/frontend/stream_lit.py +++ b/src/lightning/app/frontend/stream_lit.py @@ -61,6 +61,7 @@ class StreamlitFrontend(Frontend): st.write("Hello from streamlit!") st.write(state.counter) + """ @requires("streamlit") diff --git a/src/lightning/app/frontend/streamlit_base.py b/src/lightning/app/frontend/streamlit_base.py index 189bbba82b..b03628b449 100644 --- a/src/lightning/app/frontend/streamlit_base.py +++ b/src/lightning/app/frontend/streamlit_base.py @@ -14,6 +14,7 @@ """This file gets run by streamlit, which we launch within Lightning. From here, we will call the render function that the user provided in ``configure_layout``. + """ import os import pydoc diff --git a/src/lightning/app/frontend/utils.py b/src/lightning/app/frontend/utils.py index 1399046686..80898f1213 100644 --- a/src/lightning/app/frontend/utils.py +++ b/src/lightning/app/frontend/utils.py @@ -37,6 +37,7 @@ def _get_flow_state(flow: str) -> AppState: Returns: AppState: An AppState scoped to the current Flow. + """ app_state = AppState() app_state._request_state() # pylint: disable=protected-access @@ -54,6 +55,7 @@ def _get_frontend_environment(flow: str, render_fn_or_file: Callable | str, port Returns: os._Environ: An environment + """ env = os.environ.copy() env["LIGHTNING_FLOW_NAME"] = flow diff --git a/src/lightning/app/frontend/web.py b/src/lightning/app/frontend/web.py index f6fa639d3e..2e7d9f3f2f 100644 --- a/src/lightning/app/frontend/web.py +++ b/src/lightning/app/frontend/web.py @@ -44,6 +44,7 @@ class StaticWebFrontend(Frontend): def configure_layout(self): return StaticWebFrontend("path/to/folder/to/serve") + """ def __init__(self, serve_dir: str) -> None: @@ -102,8 +103,7 @@ def _start_server( def _get_log_config(log_file: str) -> dict: - """Returns a logger configuration in the format expected by uvicorn that sends all logs to the given - logfile.""" + """Returns a logger configuration in the format expected by uvicorn that sends all logs to the given logfile.""" # Modified from the default config found in uvicorn.config.LOGGING_CONFIG return { "version": 1, diff --git a/src/lightning/app/launcher/launcher.py b/src/lightning/app/launcher/launcher.py index 8f00731161..3d2a066889 100644 --- a/src/lightning/app/launcher/launcher.py +++ b/src/lightning/app/launcher/launcher.py @@ -109,6 +109,7 @@ def run_lightning_work( It is organized under cloud runtime to indicate that it will be used by the cloud runner but otherwise, no cloud specific logic is being implemented here + """ logger.debug(f"Run Lightning Work {file} {work_name} {queue_id}") @@ -231,6 +232,7 @@ def serve_frontend(file: str, flow_name: str, host: str, port: int): It is organized under cloud runtime to indicate that it will be used by the cloud runner but otherwise, no cloud specific logic is being implemented here. + """ _set_frontend_context() logger.debug(f"Run Serve Frontend {file} {flow_name} {host} {port}") @@ -340,15 +342,16 @@ def manage_server_processes(processes: List[Tuple[str, Process]]) -> None: def _get_frontends_from_app(entrypoint_file): - """This function is used to get the frontends from the app. It will be used to start the frontends in a - separate process if the backend cannot provide flow_names_and_ports. 
This is useful if the app cannot be loaded - locally to set the frontend before dispatching to the cloud. The backend exposes by default 10 ports from 8081 - if the app.spec.frontends is not set. + """This function is used to get the frontends from the app. It will be used to start the frontends in a separate + process if the backend cannot provide flow_names_and_ports. This is useful if the app cannot be loaded locally to + set the frontend before dispatching to the cloud. The backend exposes by default 10 ports from 8081 if the + app.spec.frontends is not set. NOTE: frontend_name are sorted to ensure that they get consistent ports. :param entrypoint_file: The entrypoint file for the app :return: A list of tuples of the form (frontend_name, port_number) + """ app = load_app_from_file(entrypoint_file) diff --git a/src/lightning/app/launcher/lightning_backend.py b/src/lightning/app/launcher/lightning_backend.py index 1e3c096e45..fc100b9ead 100644 --- a/src/lightning/app/launcher/lightning_backend.py +++ b/src/lightning/app/launcher/lightning_backend.py @@ -252,6 +252,7 @@ class CloudBackend(Backend): Normally, the Lightning frameworks communicates statuses through the queues, but while the Work instance is being provisionied, the queues don't exist yet and hence we need to make API calls directly to the backend to fetch the status and update it in the states. + """ if not works: return @@ -305,6 +306,7 @@ class CloudBackend(Backend): """Stop resources for all LightningWorks in this app. The Works are stopped rather than deleted so that they can be inspected for debugging. + """ cloud_works = self._get_cloud_work_specs(self.client) diff --git a/src/lightning/app/plugin/plugin.py b/src/lightning/app/plugin/plugin.py index db66a4ad24..67d15d0c27 100644 --- a/src/lightning/app/plugin/plugin.py +++ b/src/lightning/app/plugin/plugin.py @@ -60,6 +60,7 @@ class LightningPlugin: Returns: The relative URL of the created job. + """ from lightning.app.runners.backends.cloud import CloudBackend from lightning.app.runners.cloud import CloudRuntime diff --git a/src/lightning/app/runners/cloud.py b/src/lightning/app/runners/cloud.py index 0101b7a720..d742c9d459 100644 --- a/src/lightning/app/runners/cloud.py +++ b/src/lightning/app/runners/cloud.py @@ -198,8 +198,8 @@ class CloudRuntime(Runtime): cluster_id: str, source_app: Optional[str] = None, ) -> str: - """Slim dispatch for creating runs from a cloudspace. This dispatch avoids resolution of some properties - such as the project and cluster IDs that are instead passed directly. + """Slim dispatch for creating runs from a cloudspace. This dispatch avoids resolution of some properties such + as the project and cluster IDs that are instead passed directly. Args: project_id: The ID of the project. @@ -214,6 +214,7 @@ class CloudRuntime(Runtime): Returns: The URL of the created job. + """ # Dispatch in four phases: resolution, validation, spec creation, API transactions # Resolution @@ -432,6 +433,7 @@ class CloudRuntime(Runtime): """Find and load the config file if it exists (otherwise create an empty config). Override the name if provided. + """ config_file = _get_config_file(self.entrypoint) cloudspace_config = AppConfig.load_from_file(config_file) if config_file.exists() and load else AppConfig() @@ -611,6 +613,7 @@ class CloudRuntime(Runtime): """Check if the user likely needs credits to run the app with its hardware. Returns False if user has 1 or more credits. 
+ """ balance = project.balance if balance is None: @@ -698,8 +701,8 @@ class CloudRuntime(Runtime): raise RuntimeError(f"Unknown mount protocol `{mount.protocol}` for work `{work.name}`.") def _get_flow_servers(self) -> List[V1Flowserver]: - """Collect a spec for each flow that contains a frontend so that the backend knows for which flows it needs - to start servers.""" + """Collect a spec for each flow that contains a frontend so that the backend knows for which flows it needs to + start servers.""" flow_servers: List[V1Flowserver] = [] for flow_name in self.app.frontends: flow_server = V1Flowserver(name=flow_name) @@ -889,8 +892,7 @@ class CloudRuntime(Runtime): def _get_env_vars( env_vars: Dict[str, str], secrets: Dict[str, str], run_app_comment_commands: bool ) -> List[V1EnvVar]: - """Generate the list of environment variable specs for the app, including variables set by the - framework.""" + """Generate the list of environment variable specs for the app, including variables set by the framework.""" v1_env_vars = [V1EnvVar(name=k, value=v) for k, v in env_vars.items()] if len(secrets.values()) > 0: @@ -929,6 +931,7 @@ class CloudRuntime(Runtime): """Create the cloudspace if it doesn't exist. Return the cloudspace ID. + """ if existing_cloudspace is None: cloudspace_body = ProjectIdCloudspacesBody(name=name, can_download_source_code=True) @@ -980,6 +983,7 @@ class CloudRuntime(Runtime): """Transfer an existing instance to the given run ID and update its specification. Return the instance. + """ run_instance = self.backend.client.lightningapp_instance_service_update_lightningapp_instance_release( project_id=project_id, diff --git a/src/lightning/app/runners/multiprocess.py b/src/lightning/app/runners/multiprocess.py index 93f091f870..2db6266893 100644 --- a/src/lightning/app/runners/multiprocess.py +++ b/src/lightning/app/runners/multiprocess.py @@ -39,6 +39,7 @@ class MultiProcessRuntime(Runtime): The MultiProcessRuntime will generate 1 process for each :class:`~lightning.app.core.work.LightningWork` and attach queues to enable communication between the different processes. + """ backend: Union[str, Backend] = "multiprocessing" diff --git a/src/lightning/app/runners/runtime.py b/src/lightning/app/runners/runtime.py index 375d1d16a5..d710cae985 100644 --- a/src/lightning/app/runners/runtime.py +++ b/src/lightning/app/runners/runtime.py @@ -66,6 +66,7 @@ def dispatch( run_app_comment_commands: whether to parse commands from the entrypoint file and execute them before app startup enable_basic_auth: whether to enable basic authentication for the app (use credentials in the format username:password as an argument) + """ from lightning.app.runners.runtime_type import RuntimeType from lightning.app.utilities.component import _set_flow_context diff --git a/src/lightning/app/source_code/copytree.py b/src/lightning/app/source_code/copytree.py index 592d3d7dee..ba51af98af 100644 --- a/src/lightning/app/source_code/copytree.py +++ b/src/lightning/app/source_code/copytree.py @@ -34,9 +34,9 @@ def _copytree( dirs_exist_ok=False, dry_run=False, ) -> List[str]: - """Vendor in from `shutil.copytree` to support ignoring files recursively based on `.lightningignore`, like - `git` does with `.gitignore`. Also removed a few checks from the original copytree related to symlink checks. - Differences between original and this function are. + """Vendor in from `shutil.copytree` to support ignoring files recursively based on `.lightningignore`, like `git` + does with `.gitignore`. 
Also removed a few checks from the original copytree related to symlink checks. Differences + between original and this function are. 1. It supports a list of ignore function instead of a single one in the original. We can use this for filtering out files based on nested @@ -66,6 +66,7 @@ def _copytree( If exception(s) occur, an Error is raised with a list of reasons. + """ files_copied = [] @@ -146,8 +147,8 @@ def _parse_lightningignore(lines: Tuple[str]) -> Set[str]: def _read_lightningignore(path: Path) -> Set[str]: - """Reads ignore file and filter and empty lines. This will also remove patterns that start with a `/`. That's - done to allow `glob` to simulate the behavior done by `git` where it interprets that as a root path. + """Reads ignore file and filter and empty lines. This will also remove patterns that start with a `/`. That's done + to allow `glob` to simulate the behavior done by `git` where it interprets that as a root path. Parameters ---------- @@ -158,6 +159,7 @@ def _read_lightningignore(path: Path) -> Set[str]: ------- Set[str] Set of unique lines. + """ raw_lines = path.open().readlines() return _parse_lightningignore(raw_lines) diff --git a/src/lightning/app/source_code/hashing.py b/src/lightning/app/source_code/hashing.py index 362d32f259..6cd823e9a0 100644 --- a/src/lightning/app/source_code/hashing.py +++ b/src/lightning/app/source_code/hashing.py @@ -33,6 +33,7 @@ def _get_hash(files: List[str], algorithm: str = "blake2", chunk_num_blocks: int ---------- [1] https://crypto.stackexchange.com/questions/70101/blake2-vs-md5-for-checksum-file-integrity [2] https://stackoverflow.com/questions/1131220/get-md5-hash-of-big-files-in-python + """ # validate input if algorithm == "blake2": diff --git a/src/lightning/app/source_code/local.py b/src/lightning/app/source_code/local.py index b01e0c4c0a..bfad7eb442 100644 --- a/src/lightning/app/source_code/local.py +++ b/src/lightning/app/source_code/local.py @@ -133,6 +133,7 @@ class LocalSourceCodeDir: packaged repository files which have a size > 2GB. This limitation should be removed during the datastore upload redesign + """ if self.package_path.stat().st_size > 2e9: raise OSError( diff --git a/src/lightning/app/source_code/tar.py b/src/lightning/app/source_code/tar.py index 7ca93c798b..c3aca1ae31 100644 --- a/src/lightning/app/source_code/tar.py +++ b/src/lightning/app/source_code/tar.py @@ -36,6 +36,7 @@ def _get_dir_size_and_count(source_dir: str, prefix: Optional[str] = None) -> Tu ------- Tuple[int, int] Size in megabytes and file count + """ size = 0 count = 0 @@ -61,6 +62,7 @@ class _TarResults: The total size of the original directory files in bytes after_size: int The total size of the compressed and tarred split files in bytes + """ before_size: int @@ -70,13 +72,12 @@ class _TarResults: def _get_split_size( total_size: int, minimum_split_size: int = 1024 * 1000 * 20, max_split_count: int = MAX_SPLIT_COUNT ) -> int: - """Calculate the split size we should use to split the multipart upload of an object to a bucket. We are - limited to 1000 max parts as the way we are using ListMultipartUploads. More info - https://github.com/gridai/grid/pull/5267 + """Calculate the split size we should use to split the multipart upload of an object to a bucket. We are limited + to 1000 max parts as the way we are using ListMultipartUploads. 
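# Hedged sketch of the split-size arithmetic this docstring describes; the exact
# implementation is not reproduced here. The constraints are: at most
# `max_split_count` (1000) parts, at most 2**31 bytes per part, and never less
# than the minimum split size.
import math


def sketch_split_size(total_size: int, minimum_split_size: int = 1024 * 1000 * 20, max_split_count: int = 1000) -> int:
    max_size = max_split_count * (1 << 31)  # ceiling implied by the per-part upload limit
    if total_size > max_size:
        raise ValueError(f"Cannot upload {total_size} bytes in at most {max_split_count} parts")
    # smallest split size that keeps the part count within max_split_count
    return max(minimum_split_size, math.ceil(total_size / max_split_count))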
More info https://github.com/gridai/grid/pull/5267 https://docs.aws.amazon.com/AmazonS3/latest/userguide/mpuoverview.html#mpu-process https://docs.aws.amazon.com/AmazonS3/latest/API/API_ListMultipartUploads.html - https://github.com/psf/requests/issues/2717#issuecomment-724725392 Python or requests has a limit of 2**31 - bytes for a single file upload. + https://github.com/psf/requests/issues/2717#issuecomment-724725392 Python or requests has a limit of 2**31 bytes + for a single file upload. Parameters ---------- @@ -91,6 +92,7 @@ def _get_split_size( ------- int Split size + """ max_size = max_split_count * (1 << 31) # max size per part limited by Requests or urllib as shown in ref above if total_size > max_size: @@ -123,6 +125,7 @@ def _tar_path(source_path: str, target_file: str, compression: bool = False) -> ------- TarResults Results that holds file counts and sizes + """ if os.path.isdir(source_path): before_size, _ = _get_dir_size_and_count(source_path) @@ -149,6 +152,7 @@ def _tar_path_python(source_path: str, target_file: str, compression: bool = Fal Target tar file compression: bool, default False Enable compression, which is disabled by default. + """ file_mode = "w:gz" if compression else "w:" @@ -172,6 +176,7 @@ def _tar_path_subprocess(source_path: str, target_file: str, compression: bool = Target tar file compression: bool, default False Enable compression, which is disabled by default. + """ # Only add compression when users explicitly request it. # We do this because it takes too long to compress diff --git a/src/lightning/app/source_code/uploader.py b/src/lightning/app/source_code/uploader.py index 306ee96d7e..82336c7b0b 100644 --- a/src/lightning/app/source_code/uploader.py +++ b/src/lightning/app/source_code/uploader.py @@ -35,6 +35,7 @@ class FileUploader: Size of all files to upload name: str Name of this upload to display progress + """ workers: int = 8 @@ -72,6 +73,7 @@ class FileUploader: ------- str ETag from response + """ disconnect_retries = retries while disconnect_retries > 0: diff --git a/src/lightning/app/storage/copier.py b/src/lightning/app/storage/copier.py index 144c2335a7..3619654cfb 100644 --- a/src/lightning/app/storage/copier.py +++ b/src/lightning/app/storage/copier.py @@ -52,6 +52,7 @@ class _Copier(Thread): will send requests to this queue. copy_response_queue: A queue connecting the central StorageOrchestrator with the Copier. The Copier will send a response to this queue whenever a requested copy has finished. + """ def __init__( @@ -116,6 +117,7 @@ def _copy_files( interpreted as a folder as well. If the source is a file, the destination path is interpreted as a file too. Files in a folder are copied recursively and efficiently using multiple threads. + """ if fs is None: fs = _filesystem() diff --git a/src/lightning/app/storage/drive.py b/src/lightning/app/storage/drive.py index 0cdb310046..d90c5a0d41 100644 --- a/src/lightning/app/storage/drive.py +++ b/src/lightning/app/storage/drive.py @@ -46,6 +46,7 @@ class Drive: component_name: The component name which owns this drive. When not provided, it is automatically inferred by Lightning. root_folder: This is the folder from where the Drive perceives the data (e.g this acts as a mount dir). + """ if id.startswith("s3://"): raise ValueError( @@ -96,6 +97,7 @@ class Drive: Arguments: path: The relative path to your files to be added to the Drive. 
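# Minimal usage sketch for the Drive API documented in this hunk, assuming the
# usual "lit://" drive scheme; put() shares a path relative to `root_folder`,
# and list()/get()/delete() (shown further below) are the counterparts.
import lightning as L
from lightning.app.storage import Drive


class Producer(L.LightningWork):
    def run(self):
        drive = Drive("lit://artifacts")
        with open("results.txt", "w") as f:  # create a local file ...
            f.write("done")
        drive.put("results.txt")  # ... and share it with other components through the Drive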
+ """ if not self.component_name: raise Exception("The component name needs to be known to put a path to the Drive.") @@ -121,6 +123,7 @@ class Drive: path: The relative path you want to list files from the Drive. component_name: By default, the Drive lists files across all components. If you provide a component name, the listing is specific to this component. + """ if _is_flow_context(): raise Exception("The flow isn't allowed to list files from a Drive.") @@ -165,6 +168,7 @@ class Drive: If you provide a component name, the matching is specific to this component. timeout: Whether to wait for the files to be available if not created yet. overwrite: Whether to override the provided path if it exists. + """ if _is_flow_context(): raise Exception("The flow isn't allowed to get files from a Drive.") @@ -207,11 +211,12 @@ class Drive: self._get(self.fs, match, pathlib.Path(os.path.join(self.root_folder, path)).resolve(), overwrite=overwrite) def delete(self, path: str) -> None: - """This method enables to delete files under the provided path from the Drive in a blocking fashion. Only - the component which added a file can delete them. + """This method enables to delete files under the provided path from the Drive in a blocking fashion. Only the + component which added a file can delete them. Arguments: path: The relative path you want to delete files from the Drive. + """ if not self.component_name: raise Exception("The component name needs to be known to delete a path to the Drive.") diff --git a/src/lightning/app/storage/filesystem.py b/src/lightning/app/storage/filesystem.py index 141a29e8a1..943a6a750b 100644 --- a/src/lightning/app/storage/filesystem.py +++ b/src/lightning/app/storage/filesystem.py @@ -42,6 +42,7 @@ class FileSystem: src_path: The path to your files locally dst_path: The path to your files transfered in the shared storage. put_fn: The method to use to put files in the shared storage. + """ if not os.path.exists(Path(src_path).resolve()): raise FileExistsError(f"The provided path {src_path} doesn't exist") @@ -66,6 +67,7 @@ class FileSystem: src_path: The path to your files in the shared storage dst_path: The path to your files transfered locally get_fn: The method to use to put files in the shared storage. + """ if not src_path.startswith("/"): raise Exception(f"The provided destination {src_path} needs to start with `/`.") @@ -80,6 +82,7 @@ class FileSystem: Arguments: path: The path to files to list. + """ if not path.startswith("/"): raise Exception(f"The provided destination {path} needs to start with `/`.") @@ -104,6 +107,7 @@ class FileSystem: Arguments: path: The path to files to list. + """ if not path.startswith("/"): raise Exception(f"The provided destination {path} needs to start with `/`.") diff --git a/src/lightning/app/storage/mount.py b/src/lightning/app/storage/mount.py index efe922aa97..8142b4574a 100644 --- a/src/lightning/app/storage/mount.py +++ b/src/lightning/app/storage/mount.py @@ -34,6 +34,7 @@ class Mount: mount_path: An absolute directory path in the work where external data source should be mounted as a filesystem. This path should not already exist in your codebase. 
If not included, then the root_dir will be set to `/data/` + """ source: str = "" diff --git a/src/lightning/app/storage/orchestrator.py b/src/lightning/app/storage/orchestrator.py index 406539f4c7..43ce7b76e5 100644 --- a/src/lightning/app/storage/orchestrator.py +++ b/src/lightning/app/storage/orchestrator.py @@ -47,6 +47,7 @@ class StorageOrchestrator(Thread): put requests on this queue for the file-transfer thread to complete. copy_response_queues: A dictionary of Queues where each Queue connects to one Work. The queue is expected to contain the completion response from the file-transfer thread running in the Work process. + """ def __init__( diff --git a/src/lightning/app/storage/path.py b/src/lightning/app/storage/path.py index f0b7ee9560..0ecb79a7f5 100644 --- a/src/lightning/app/storage/path.py +++ b/src/lightning/app/storage/path.py @@ -53,6 +53,7 @@ class Path(PathlibPath): Args: *args: Accepts the same arguments as in :class:`pathlib.Path` **kwargs: Accepts the same keyword arguments as in :class:`pathlib.Path` + """ @classmethod @@ -105,6 +106,7 @@ class Path(PathlibPath): """The name of the LightningWork where this path was first created. Attaching a Path to a LightningWork will automatically make it the `origin`. + """ from lightning.app.core.work import LightningWork @@ -115,6 +117,7 @@ class Path(PathlibPath): """The name of the LightningWork where this path is being accessed. By default, this is the same as the :attr:`origin_name`. + """ from lightning.app.core.work import LightningWork @@ -125,6 +128,7 @@ class Path(PathlibPath): """The hash of this Path uniquely identifies the file path and the associated origin Work. Returns ``None`` if the origin is not defined, i.e., this Path did not yet get attached to a LightningWork. + """ if self._origin is None: return None @@ -152,6 +156,7 @@ class Path(PathlibPath): If you strictly want to check local existence only, use :meth:`exists_local` instead. If you strictly want to check existence on the remote (regardless of whether the file exists locally or not), use :meth:`exists_remote`. + """ return self.exists_local() or (self._origin and self.exists_remote()) @@ -164,6 +169,7 @@ class Path(PathlibPath): Raises: RuntimeError: If the path is not attached to any Work (origin undefined). + """ # Fail early if we need to check the remote but an origin is not defined if not self._origin or self._request_queue is None or self._response_queue is None: @@ -272,6 +278,7 @@ class Path(PathlibPath): Args: work: LightningWork to be attached to this Path. + """ if self._origin is None: # Can become an owner only if there is not already one @@ -374,11 +381,11 @@ def _is_lit_path(path: Union[str, Path]) -> bool: def _shared_local_mount_path() -> pathlib.Path: - """Returns the shared directory through which the Copier threads move files from one Work filesystem to - another. + """Returns the shared directory through which the Copier threads move files from one Work filesystem to another. The shared directory can be set via the environment variable ``SHARED_MOUNT_DIRECTORY`` and should be pointing to a directory that all Works have mounted (shared filesystem). + """ path = pathlib.Path(os.environ.get("SHARED_MOUNT_DIRECTORY", ".shared")) path.mkdir(parents=True, exist_ok=True) @@ -397,6 +404,7 @@ def _shared_storage_path() -> pathlib.Path: The shared path gets set by the environment. Locally, it is pointing to a directory determined by the ``SHARED_MOUNT_DIRECTORY`` environment variable. In the cloud, the shared path will point to a S3 bucket. 
All Works have access to this shared dropbox. + """ storage_path = os.getenv("LIGHTNING_STORAGE_PATH", "") if storage_path != "": diff --git a/src/lightning/app/storage/payload.py b/src/lightning/app/storage/payload.py index 05d3463488..255a60c3fe 100644 --- a/src/lightning/app/storage/payload.py +++ b/src/lightning/app/storage/payload.py @@ -61,6 +61,7 @@ class _BasePayload(ABC): """The hash of this Payload uniquely identifies the payload and the associated origin Work. Returns ``None`` if the origin is not defined, i.e., this Path did not yet get attached to a LightningWork. + """ if self._origin is None: return None @@ -72,6 +73,7 @@ class _BasePayload(ABC): """The name of the LightningWork where this payload was first created. Attaching a Payload to a LightningWork will automatically make it the `origin`. + """ from lightning.app.core.work import LightningWork @@ -82,6 +84,7 @@ class _BasePayload(ABC): """The name of the LightningWork where this payload is being accessed. By default, this is the same as the :attr:`origin_name`. + """ from lightning.app.core.work import LightningWork @@ -107,6 +110,7 @@ class _BasePayload(ABC): Args: work: LightningWork to be attached to this Payload. + """ if self._origin is None: # Can become an owner only if there is not already one @@ -130,6 +134,7 @@ class _BasePayload(ABC): Raises: RuntimeError: If the payload is not attached to any Work (origin undefined). + """ # Fail early if we need to check the remote but an origin is not defined if not self._origin or self._request_queue is None or self._response_queue is None: diff --git a/src/lightning/app/testing/helpers.py b/src/lightning/app/testing/helpers.py index 99a81b2523..c1eafce723 100644 --- a/src/lightning/app/testing/helpers.py +++ b/src/lightning/app/testing/helpers.py @@ -62,6 +62,7 @@ class _RunIf: @pytest.mark.parametrize("arg1", [1, 2.0]) def test_wrapper(arg1): assert arg1 > 0.0 + """ def __new__( @@ -155,6 +156,7 @@ class EmptyFlow(LightningFlow): """A LightningFlow that implements all abstract methods to do nothing. Useful for mocking in tests. + """ def run(self): @@ -165,6 +167,7 @@ class EmptyWork(LightningWork): """A LightningWork that implements all abstract methods to do nothing. Useful for mocking in tests. + """ def run(self): diff --git a/src/lightning/app/testing/testing.py b/src/lightning/app/testing/testing.py index 93a6be1997..f873355ab0 100644 --- a/src/lightning/app/testing/testing.py +++ b/src/lightning/app/testing/testing.py @@ -503,6 +503,7 @@ def delete_cloud_lightning_apps(name=None): """Cleanup cloud apps that start with the name test-{PR_NUMBER}-{TEST_APP_NAME}. PR_NUMBER and TEST_APP_NAME are environment variables. + """ client = LightningClient() diff --git a/src/lightning/app/utilities/app_commands.py b/src/lightning/app/utilities/app_commands.py index 2db33d9ebb..e3e8af50d3 100644 --- a/src/lightning/app/utilities/app_commands.py +++ b/src/lightning/app/utilities/app_commands.py @@ -44,6 +44,7 @@ def _extract_commands_from_file(file_name: str) -> CommandLines: """Extract all lines at the top of the file which contain commands to execute. The return struct contains a list of commands to execute with the corresponding line number the command executed on. + """ cl = CommandLines( file=file_name, @@ -83,6 +84,7 @@ def _execute_app_commands(cl: CommandLines) -> None: """Open a subprocess shell to execute app commands. 
The calling app environment is used in the current environment the code is running in + """ for command, line_number in zip(cl.commands, cl.line_numbers): logger.info(f"Running app setup command: {command}") @@ -116,6 +118,7 @@ def run_app_commands(file: str) -> None: foo! bar <--- not a command import lightning <--- not a command, end parsing. where `echo "hello world" && pip install foo` would be executed in the current running environment. + """ cl = _extract_commands_from_file(file_name=file) if len(cl.commands) == 0: diff --git a/src/lightning/app/utilities/app_helpers.py b/src/lightning/app/utilities/app_helpers.py index 018de7ab32..7878fac4e3 100644 --- a/src/lightning/app/utilities/app_helpers.py +++ b/src/lightning/app/utilities/app_helpers.py @@ -57,8 +57,7 @@ class StateEntry: class StateStore(ABC): - """Base class of State store that provides simple key, value store to keep track of app state, served app - state.""" + """Base class of State store that provides simple key, value store to keep track of app state, served app state.""" @abstractmethod def __init__(self): @@ -352,6 +351,7 @@ def _walk_to_component( """Returns a generator that runs through the tree starting from the root down to the given component. At each node, yields parent and child as a tuple. + """ from lightning.app.structures import Dict, List @@ -469,6 +469,7 @@ def _load_state_dict(root_flow: "LightningFlow", state: Dict[str, Any], strict: root_flow: The flow at the top of the component tree. state: The collected state dict. strict: Whether to validate all components have been re-created. + """ # 1: Reload the state of the existing works for w in root_flow.works(): diff --git a/src/lightning/app/utilities/app_logs.py b/src/lightning/app/utilities/app_logs.py index 903bc615ba..446418f9b1 100644 --- a/src/lightning/app/utilities/app_logs.py +++ b/src/lightning/app/utilities/app_logs.py @@ -49,6 +49,7 @@ def _push_log_events_to_read_queue_callback(component_name: str, read_queue: que """Pushes _LogEvents from websocket to read_queue. Returns callback function used with `on_message_callback` of websocket.WebSocketApp. + """ def callback(ws_app: WebSocketApp, msg: str): diff --git a/src/lightning/app/utilities/auth.py b/src/lightning/app/utilities/auth.py index b29801c9fd..2ccb2d0681 100644 --- a/src/lightning/app/utilities/auth.py +++ b/src/lightning/app/utilities/auth.py @@ -42,6 +42,7 @@ def _credential_string_to_basic_auth_params(credential_string: str) -> Dict[str, """Returns the name/ID pair for each given Secret name. Raises a `ValueError` if any of the given Secret names do not exist. + """ if credential_string.count(":") != 1: raise ValueError( diff --git a/src/lightning/app/utilities/cli_helpers.py b/src/lightning/app/utilities/cli_helpers.py index 6280fdecad..ad9c835ab4 100644 --- a/src/lightning/app/utilities/cli_helpers.py +++ b/src/lightning/app/utilities/cli_helpers.py @@ -129,6 +129,7 @@ class _LightningAppOpenAPIRetriever: Arguments: app_id_or_name_or_url: An identified for the app. use_cache: Whether to load the openapi spec from the cache. 
+ """ self.app_id_or_name_or_url = app_id_or_name_or_url self.url = None diff --git a/src/lightning/app/utilities/clusters.py b/src/lightning/app/utilities/clusters.py index a083e41c71..663ba66d45 100644 --- a/src/lightning/app/utilities/clusters.py +++ b/src/lightning/app/utilities/clusters.py @@ -25,6 +25,7 @@ def _get_default_cluster(client: LightningClient, project_id: str) -> str: """This utility implements a minimal version of the cluster selection logic used in the cloud. TODO: This should be requested directly from the platform. + """ cluster_bindings = client.projects_service_list_project_cluster_bindings(project_id=project_id).clusters diff --git a/src/lightning/app/utilities/component.py b/src/lightning/app/utilities/component.py index 75c4c09d56..bb4bc5ecfd 100644 --- a/src/lightning/app/utilities/component.py +++ b/src/lightning/app/utilities/component.py @@ -36,6 +36,7 @@ def _convert_paths_after_init(root: "LightningFlow"): This is necessary because at the time of instantiating the component, its full affiliation is not known and Paths that get passed to other componenets during ``__init__`` are otherwise not able to reference their origin or consumer. + """ from lightning.app import LightningFlow, LightningWork from lightning.app.storage import Path @@ -52,6 +53,7 @@ def _sanitize_state(state: Dict[str, Any]) -> Dict[str, Any]: """Utility function to sanitize the state of a component. Sanitization enables the state to be deep-copied and hashed. + """ from lightning.app.storage import Drive, Path from lightning.app.storage.payload import _BasePayload @@ -132,6 +134,7 @@ def _context(ctx: str) -> Generator[None, None, None]: The context is used to determine whether the current process is running for a LightningFlow or for a LightningWork. See also :func:`_get_context`, :func:`_set_context`. For internal use only. + """ prev = _get_context() _set_context(ctx) diff --git a/src/lightning/app/utilities/data_structures.py b/src/lightning/app/utilities/data_structures.py index 626e381d73..495c43fd0e 100644 --- a/src/lightning/app/utilities/data_structures.py +++ b/src/lightning/app/utilities/data_structures.py @@ -29,6 +29,7 @@ class AttributeDict(Dict): "key2": abc "my-key": 3.14 "new_key": 42 + """ def __getattr__(self, key: str) -> Optional[Any]: diff --git a/src/lightning/app/utilities/exceptions.py b/src/lightning/app/utilities/exceptions.py index fb63a6ce7b..3bb5eced46 100644 --- a/src/lightning/app/utilities/exceptions.py +++ b/src/lightning/app/utilities/exceptions.py @@ -29,6 +29,7 @@ class _ApiExceptionHandler(Group): However, if the ApiException cannot be decoded, or is not a 4xx error, the original ApiException will be re-raised. + """ def invoke(self, ctx: Context) -> Any: @@ -81,6 +82,7 @@ class LightningPlatformException(Exception): # pragma: no cover It gets raised by the Lightning Launcher on the platform side when the app is running in the cloud, and is useful when framework or user code needs to catch exceptions specific to the platform, e.g., when resources exceed quotas. + """ diff --git a/src/lightning/app/utilities/git.py b/src/lightning/app/utilities/git.py index aa3294aa19..1293a2a509 100644 --- a/src/lightning/app/utilities/git.py +++ b/src/lightning/app/utilities/git.py @@ -26,6 +26,7 @@ def execute_git_command(args: List[str], cwd=None) -> str: ------- output: str String combining stdout and stderr. 
+ """ process = subprocess.run(["git"] + args, capture_output=True, text=True, cwd=cwd, check=False) return process.stdout.strip() + process.stderr.strip() @@ -61,6 +62,7 @@ def check_if_remote_head_is_different() -> Union[bool, None]: This only compares the local SHA to the HEAD commit of a given branch. This check won't be used if user isn't in a HEAD locally. + """ # Check SHA values. local_sha = execute_git_command(["rev-parse", "@"]) @@ -78,6 +80,7 @@ def has_uncommitted_files() -> bool: """Checks if user has uncommited files in local repository. If there are uncommited files, then show a prompt indicating that uncommited files exist locally. + """ files = execute_git_command(["update-index", "--refresh"]) return bool(files) diff --git a/src/lightning/app/utilities/imports.py b/src/lightning/app/utilities/imports.py index 7c4c43a3c2..092917660f 100644 --- a/src/lightning/app/utilities/imports.py +++ b/src/lightning/app/utilities/imports.py @@ -32,6 +32,7 @@ def _get_extras(extras: str) -> str: """Get the given extras as a space delimited string. Used by the platform to install cloud extras in the cloud. + """ from lightning.app import __package_name__ diff --git a/src/lightning/app/utilities/introspection.py b/src/lightning/app/utilities/introspection.py index 394c5da593..87200576e6 100644 --- a/src/lightning/app/utilities/introspection.py +++ b/src/lightning/app/utilities/introspection.py @@ -22,13 +22,14 @@ if TYPE_CHECKING: class LightningVisitor(ast.NodeVisitor): - """Base class for visitor that finds class definitions based on class inheritance. Derived classes are expected - to define class_name and implement the analyze_class_def method. + """Base class for visitor that finds class definitions based on class inheritance. Derived classes are expected to + define class_name and implement the analyze_class_def method. Attributes ---------- class_name: str Name of class to identify, to be defined in subclasses. + """ class_name: Optional[str] = None @@ -63,6 +64,7 @@ class LightningModuleVisitor(LightningVisitor): Names of methods that are part of the LightningModule API. hooks: Set[str] Names of hooks that are part of the LightningModule API. + """ class_name: Optional[str] = "LightningModule" @@ -132,6 +134,7 @@ class LightningDataModuleVisitor(LightningVisitor): Name of class to identify. methods: Set[str] Names of methods that are part of the LightningDataModule API. + """ class_name = "LightningDataModule" @@ -155,6 +158,7 @@ class LightningLoggerVisitor(LightningVisitor): Name of class to identify. methods: Set[str] Names of methods that are part of the Logger API. + """ class_name = "Logger" @@ -171,6 +175,7 @@ class LightningCallbackVisitor(LightningVisitor): Name of class to identify. methods: Set[str] Names of methods that are part of the Logger API. + """ class_name = "Callback" @@ -223,6 +228,7 @@ class LightningStrategyVisitor(LightningVisitor): Name of class to identify. methods: Set[str] Names of methods that are part of the Logger API. + """ class_name = "Strategy" @@ -282,6 +288,7 @@ class Scanner: glob_pattern: str Glob pattern to use when looking for files in the path, applied when path is a directory. Default is "**/*.py". + """ # TODO: Finalize introspecting the methods from all the discovered methods. @@ -341,6 +348,7 @@ class Scanner: List[Dict[str, Any]] List of dicts containing all metadata required to import modules found. 
+ """ modules_found: Dict[str, List[Dict[str, Any]]] = {} diff --git a/src/lightning/app/utilities/layout.py b/src/lightning/app/utilities/layout.py index 553bcd6e91..e1a61539a0 100644 --- a/src/lightning/app/utilities/layout.py +++ b/src/lightning/app/utilities/layout.py @@ -26,6 +26,7 @@ def _add_comment_to_literal_code(method, contains, comment): """Inspects a method's code and adds a message to it. This is a nice to have, so if it fails for some reason, it shouldn't affect the program. + """ try: lines = inspect.getsource(method) diff --git a/src/lightning/app/utilities/load_app.py b/src/lightning/app/utilities/load_app.py index f22aafa3ce..4504704c3d 100644 --- a/src/lightning/app/utilities/load_app.py +++ b/src/lightning/app/utilities/load_app.py @@ -61,6 +61,7 @@ def _load_objects_from_file( raise_exception: If ``True`` exceptions will be raised, otherwise exceptions will trigger system exit. mock_imports: If ``True`` imports of missing packages will be replaced with a mock. This can allow the object to be loaded without installing dependencies. + """ # Taken from StreamLit: https://github.com/streamlit/streamlit/blob/develop/lib/streamlit/script_runner.py#L313 @@ -110,6 +111,7 @@ def load_app_from_file( Arguments: filepath: The path to the file containing the LightningApp. raise_exception: If True, raise an exception if the app cannot be loaded. + """ from lightning.app.core.app import LightningApp @@ -142,6 +144,7 @@ def open_python_file(filename): In Python 3, we would like all files to be opened with utf-8 encoding. However, some author like to specify PEP263 headers in their source files with their own encodings. In that case, we should respect the author's encoding. + """ import tokenize @@ -204,6 +207,7 @@ def _patch_sys_path(append): Args: append: The value to append to the path. + """ if append in sys.path: yield diff --git a/src/lightning/app/utilities/login.py b/src/lightning/app/utilities/login.py index bc2d5d713c..3db7d1cb3b 100644 --- a/src/lightning/app/utilities/login.py +++ b/src/lightning/app/utilities/login.py @@ -60,6 +60,7 @@ class Auth: Returns ---------- True if credentials are available. + """ if not self.secrets_file.exists(): logger.debug("Credentials file not found.") @@ -117,6 +118,7 @@ class Auth: Returns ---------- authorization header to use when authentication completes. 
+ """ if not self.load(): # First try to authenticate from env diff --git a/src/lightning/app/utilities/logs_socket_api.py b/src/lightning/app/utilities/logs_socket_api.py index 98d95bfa96..8bc8fa47d1 100644 --- a/src/lightning/app/utilities/logs_socket_api.py +++ b/src/lightning/app/utilities/logs_socket_api.py @@ -79,6 +79,7 @@ class _LightningLogsSocketAPI(_AuthTokenGetter): Returns: WebSocketApp of the wanted socket + """ _token = self._get_api_token() clean_ws_host = urlparse(self.api_client.configuration.host).netloc diff --git a/src/lightning/app/utilities/name_generator.py b/src/lightning/app/utilities/name_generator.py index 28c43c241c..c57a65f63a 100644 --- a/src/lightning/app/utilities/name_generator.py +++ b/src/lightning/app/utilities/name_generator.py @@ -1353,6 +1353,7 @@ def get_unique_name(): 'meek-ardinghelli-4506' >>> get_unique_name() 'truthful-dijkstra-2286' + """ adjective, surname, i = choice(_adjectives), choice(_surnames), randint(0, 9999) # noqa: S311 return f"{adjective}-{surname}-{i}" diff --git a/src/lightning/app/utilities/network.py b/src/lightning/app/utilities/network.py index 3d80479c3f..314631b559 100644 --- a/src/lightning/app/utilities/network.py +++ b/src/lightning/app/utilities/network.py @@ -95,6 +95,7 @@ def _configure_session() -> Session: """Configures the session for GET and POST requests. It enables a generous retrial strategy that waits for the application server to connect. + """ retry_strategy = Retry( # wait time between retries increases exponentially according to: backoff_factor * (2 ** (retry - 1)) @@ -124,10 +125,10 @@ def _get_next_backoff_time(num_retries: int, backoff_value: float = 0.5) -> floa def _retry_wrapper(self, func: Callable, max_tries: Optional[int] = None) -> Callable: - """Returns the function decorated by a wrapper that retries the call several times if a connection error - occurs. + """Returns the function decorated by a wrapper that retries the call several times if a connection error occurs. The retries follow an exponential backoff. + """ @wraps(func) @@ -175,6 +176,7 @@ class LightningClient(GridRestClient): Args: retry: Whether API calls should follow a retry mechanism with exponential backoff. max_tries: Maximum number of attempts (or -1 to retry forever). + """ def __init__(self, retry: bool = True, max_tries: Optional[int] = None) -> None: @@ -275,5 +277,6 @@ class HTTPClient: We enabled customisation here instead of just using `logger.debug` because HTTP logging can be very noisy, but it is crucial for finding bugs when we have them + """ pass diff --git a/src/lightning/app/utilities/openapi.py b/src/lightning/app/utilities/openapi.py index c79501bc6b..f210c3cd47 100644 --- a/src/lightning/app/utilities/openapi.py +++ b/src/lightning/app/utilities/openapi.py @@ -49,6 +49,7 @@ def create_openapi_object(json_obj: Dict, target: Any): Lightning AI uses the target object to make new objects from the given JSON spec so the target must be a valid object. + """ if not isinstance(json_obj, dict): raise TypeError("json_obj must be a dictionary") diff --git a/src/lightning/app/utilities/packaging/app_config.py b/src/lightning/app/utilities/packaging/app_config.py index f22ffa99d1..57177344a8 100644 --- a/src/lightning/app/utilities/packaging/app_config.py +++ b/src/lightning/app/utilities/packaging/app_config.py @@ -29,6 +29,7 @@ class AppConfig: Args: name: Optional name of the application. If not provided, auto-generates a new name. 
+ """ name: str = field(default_factory=get_unique_name) @@ -56,6 +57,7 @@ class AppConfig: Args: directory: Path to a folder which contains the '.lightning' config file to load. + """ return cls.load_from_file(pathlib.Path(directory, _APP_CONFIG_FILENAME)) @@ -65,6 +67,7 @@ def _get_config_file(source_path: Union[str, pathlib.Path]) -> pathlib.Path: Args: source_path: A path to a folder or a file. + """ source_path = pathlib.Path(source_path).absolute() if source_path.is_file(): diff --git a/src/lightning/app/utilities/packaging/build_config.py b/src/lightning/app/utilities/packaging/build_config.py index fb247309d8..8a580da71d 100644 --- a/src/lightning/app/utilities/packaging/build_config.py +++ b/src/lightning/app/utilities/packaging/build_config.py @@ -81,6 +81,7 @@ class BuildConfig: image: The base image that the work runs on. This should be a publicly accessible image from a registry that doesn't enforce rate limits (such as DockerHub) to pull this image, otherwise your application will not start. + """ requirements: List[str] = field(default_factory=list) @@ -111,6 +112,7 @@ class BuildConfig: return ["apt-get install libsparsehash-dev"] BuildConfig(requirements=["git+https://github.com/mit-han-lab/torchsparse.git@v1.4.0"]) + """ return [] diff --git a/src/lightning/app/utilities/packaging/cloud_compute.py b/src/lightning/app/utilities/packaging/cloud_compute.py index 58ac06afdb..246c04b148 100644 --- a/src/lightning/app/utilities/packaging/cloud_compute.py +++ b/src/lightning/app/utilities/packaging/cloud_compute.py @@ -88,6 +88,7 @@ class CloudCompute: interruptible: Whether to run on a interruptible machine e.g the machine can be stopped at any time by the providers. This is also known as spot or preemptible machines. Compared to on-demand machines, they tend to be cheaper. + """ name: str = "default" diff --git a/src/lightning/app/utilities/packaging/lightning_utils.py b/src/lightning/app/utilities/packaging/lightning_utils.py index 3852c941ed..c49a6d2d88 100644 --- a/src/lightning/app/utilities/packaging/lightning_utils.py +++ b/src/lightning/app/utilities/packaging/lightning_utils.py @@ -108,10 +108,11 @@ def get_dist_path_if_editable_install(project_name) -> str: def _prepare_lightning_wheels_and_requirements(root: Path, package_name: str = "lightning") -> Optional[Callable]: - """This function determines if lightning is installed in editable mode (for developers) and packages the - current lightning source along with the app. + """This function determines if lightning is installed in editable mode (for developers) and packages the current + lightning source along with the app. For normal users who install via PyPi or Conda, then this function does not do anything. + """ if not get_dist_path_if_editable_install(package_name): return None diff --git a/src/lightning/app/utilities/proxies.py b/src/lightning/app/utilities/proxies.py index 39d3378506..9ef301f31f 100644 --- a/src/lightning/app/utilities/proxies.py +++ b/src/lightning/app/utilities/proxies.py @@ -152,6 +152,7 @@ class ProxyWorkRun: Currently, this performs a check against strings that look like filesystem paths and may need to be wrapped with a Lightning Path by the user. 
+ """ def warn_if_pathlike(obj: Union[os.PathLike, str]): @@ -172,8 +173,8 @@ class ProxyWorkRun: @staticmethod def _process_call_args(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: - """Processes all positional and keyword arguments before they get passed to the caller queue and sent to - the LightningWork. + """Processes all positional and keyword arguments before they get passed to the caller queue and sent to the + LightningWork. Currently, this method only applies sanitization to Lightning Path objects. @@ -183,6 +184,7 @@ class ProxyWorkRun: Returns: The positional and keyword arguments in the same order they were passed in. + """ def sanitize(obj: Union[Path, Drive]) -> Union[Path, Dict]: @@ -200,8 +202,8 @@ class ProxyWorkRun: @staticmethod def _convert_hashable(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: - """Processes all positional and keyword arguments before they get passed to the caller queue and sent to - the LightningWork. + """Processes all positional and keyword arguments before they get passed to the caller queue and sent to the + LightningWork. Currently, this method only applies sanitization to Hashable Objects. @@ -211,6 +213,7 @@ class ProxyWorkRun: Returns: The positional and keyword arguments in the same order they were passed in. + """ from lightning.app.utilities.types import Hashable @@ -221,9 +224,9 @@ class ProxyWorkRun: class WorkStateObserver(Thread): - """This thread runs alongside LightningWork and periodically checks for state changes. If the state changed - from one interval to the next, it will compute the delta and add it to the queue which is connected to the - Flow. This enables state changes to be captured that are not triggered through a setattr call. + """This thread runs alongside LightningWork and periodically checks for state changes. If the state changed from + one interval to the next, it will compute the delta and add it to the queue which is connected to the Flow. This + enables state changes to be captured that are not triggered through a setattr call. Args: work: The LightningWork for which the state should be monitored @@ -238,6 +241,7 @@ class WorkStateObserver(Thread): def run(self): # This update gets sent to the Flow once the thread compares the new state with the previous one self.list.append(1) + """ def __init__( @@ -690,6 +694,7 @@ def persist_artifacts(work: "LightningWork") -> None: storage. Files that don't exist or do not originate from the given Work will be skipped. + """ artifact_paths = [getattr(work, name) for name in work._paths] # only copy files that belong to this Work, i.e., when the path's origin refers to the current Work diff --git a/src/lightning/app/utilities/secrets.py b/src/lightning/app/utilities/secrets.py index 9ba758d4d2..57347c92a1 100644 --- a/src/lightning/app/utilities/secrets.py +++ b/src/lightning/app/utilities/secrets.py @@ -22,6 +22,7 @@ def _names_to_ids(secret_names: Iterable[str]) -> Dict[str, str]: """Returns the name/ID pair for each given Secret name. Raises a `ValueError` if any of the given Secret names do not exist. + """ lightning_client = LightningClient() diff --git a/src/lightning/app/utilities/state.py b/src/lightning/app/utilities/state.py index 014eaa4fdf..df187776d6 100644 --- a/src/lightning/app/utilities/state.py +++ b/src/lightning/app/utilities/state.py @@ -93,6 +93,7 @@ class AppState: my_affiliation: A tuple describing the affiliation this app state represents. 
When storing a state dict on this AppState, this affiliation will be used to reduce the scope of the given state. plugin: A plugin to handle authorization. + """ self._use_localhost = "LIGHTNING_APP_STATE_URL" not in os.environ self._host = host or ("http://127.0.0.1" if self._use_localhost else None) @@ -132,6 +133,7 @@ class AppState: For example, if the affiliation is ``("root", "subflow")``, then the returned state will be ``state["flows"]["subflow"]``. + """ children_state = state for name in my_affiliation: diff --git a/src/lightning/app/utilities/tracer.py b/src/lightning/app/utilities/tracer.py index c5a0e56b01..fe44f91947 100644 --- a/src/lightning/app/utilities/tracer.py +++ b/src/lightning/app/utilities/tracer.py @@ -101,6 +101,7 @@ class Tracer: Optionally provide two functions that will execute prior to and after the method. The functions also have a chance to modify the input arguments and the return values of the methods. + """ self.methods.append((cls, method_name, stack_level, pre_fn, post_fn)) @@ -108,6 +109,7 @@ class Tracer: """Modify classes by wrapping methods that need to be traced. Initialize the output trace dict. + """ self.res = {} for cls, method, stack_level, pre_fn, post_fn in self.methods: @@ -163,6 +165,7 @@ class Tracer: """Execute the command-line arguments in args after instrumenting for tracing. Restore the classes to their initial state after tracing. + """ args = list(args) script = args[0] diff --git a/src/lightning/app/utilities/tree.py b/src/lightning/app/utilities/tree.py index 5dafaee6bf..69ae7d144b 100644 --- a/src/lightning/app/utilities/tree.py +++ b/src/lightning/app/utilities/tree.py @@ -26,6 +26,7 @@ def breadth_first(root: "Component", types: Type["ComponentTuple"] = None): Arguments: root: The root component of the tree types: If provided, only the component types in this list will be visited. + """ yield from _BreadthFirstVisitor(root, types) diff --git a/src/lightning/data/backends.py b/src/lightning/data/backends.py index 1d4dbcccd4..5fcb32eb78 100644 --- a/src/lightning/data/backends.py +++ b/src/lightning/data/backends.py @@ -29,6 +29,7 @@ class S3DatasetBackend: Returns: credentials object to be used for file reading + """ from botocore.credentials import InstanceMetadataProvider from botocore.utils import InstanceMetadataFetcher diff --git a/src/lightning/data/datasets/base.py b/src/lightning/data/datasets/base.py index d0ebc16340..c8ee89ac96 100644 --- a/src/lightning/data/datasets/base.py +++ b/src/lightning/data/datasets/base.py @@ -11,6 +11,7 @@ class _Dataset(TorchDataset): Args: backend: storage location of the data_source. current options are "s3" or "local" + """ def __init__(self, backend: Literal["local", "s3"] = "local"): @@ -31,6 +32,7 @@ class _Dataset(TorchDataset): Returns: A stream object of the file. + """ return OpenCloudFileObj( path=file, mode=mode, kwargs_for_open={**self.backend.credentials(), **kwargs_for_open}, **kwargs diff --git a/src/lightning/data/datasets/env.py b/src/lightning/data/datasets/env.py index 4e923f3473..51a9f21271 100644 --- a/src/lightning/data/datasets/env.py +++ b/src/lightning/data/datasets/env.py @@ -10,6 +10,7 @@ class _DistributedEnv: Args: world_size: The number of total distributed training processes global_rank: The rank of the current process within this pool of training processes + """ def __init__(self, world_size: int, global_rank: int): @@ -24,6 +25,7 @@ class _DistributedEnv: This detection may not work in processes spawned from the distributed processes (e.g. 
DataLoader workers) as the distributed framework won't be initialized there. It will default to 1 distributed process in this case. + """ if torch.distributed.is_available() and torch.distributed.is_initialized(): world_size = torch.distributed.get_world_size() @@ -50,6 +52,7 @@ class _WorkerEnv: Args: world_size: The number of dataloader workers for the current training process rank: The rank of the current worker within the number of workers + """ def __init__(self, world_size: int, rank: int): @@ -63,6 +66,7 @@ class _WorkerEnv: Note: This only works reliably within a dataloader worker as otherwise the necessary information won't be present. In such a case it will default to 1 worker + """ worker_info = get_worker_info() num_workers = worker_info.num_workers if worker_info is not None else 1 @@ -83,6 +87,7 @@ class Environment: Args: dist_env: The distributed environment (distributed worldsize and global rank) worker_env: The worker environment (number of workers, worker rank) + """ def __init__(self, dist_env: Optional[_DistributedEnv], worker_env: Optional[_WorkerEnv]): @@ -105,6 +110,7 @@ class Environment: num_workers: The number of workers per distributed training process current_worker_rank: The rank of the current worker within the number of workers of the current training process + """ dist_env = _DistributedEnv(dist_world_size, global_rank) worker_env = _WorkerEnv(num_workers, current_worker_rank) @@ -117,6 +123,7 @@ class Environment: Note: This may not be accurate in a non-dataloader-worker process like the main training process as it doesn't necessarily know about the number of dataloader workers. + """ assert self.worker_env is not None assert self.dist_env is not None @@ -129,6 +136,7 @@ class Environment: Note: This may not be accurate in a non-dataloader-worker process like the main training process as it doesn't necessarily know about the number of dataloader workers. + """ assert self.worker_env is not None assert self.dist_env is not None diff --git a/src/lightning/data/datasets/index.py b/src/lightning/data/datasets/index.py index 9a0b81b42f..34ecaaf095 100644 --- a/src/lightning/data/datasets/index.py +++ b/src/lightning/data/datasets/index.py @@ -10,6 +10,7 @@ def get_index(s3_connection_path: str, index_file_path: str) -> bool: Returns: Returns True is the index got created and False if it wasn't + """ if s3_connection_path.startswith("/data/"): @@ -82,6 +83,7 @@ def _get_index(data_connection_path: str, index_file_path: str) -> bool: Returns: True if the index retrieved + """ PROJECT_ID_ENV = "LCP_ID" diff --git a/src/lightning/data/datasets/iterable.py b/src/lightning/data/datasets/iterable.py index f97b16f89e..54388138bc 100644 --- a/src/lightning/data/datasets/iterable.py +++ b/src/lightning/data/datasets/iterable.py @@ -29,6 +29,7 @@ class _Chunk: chunk_data: The original data contained by this chunk chunk_size: The number of samples contained in this chunk start_index: the index from where to start sampling the chunk (already retrieved samples) + """ def __init__(self, chunk_data: Any, chunk_size: int, start_index: int = 0): @@ -62,8 +63,8 @@ class _Chunk: class LightningIterableDataset(_StatefulIterableDataset, _Dataset): - """An iterable dataset that can be resumed mid-epoch, implements chunking and sharding of chunks. The behavior - of this dataset can be customized with the following hooks: + """An iterable dataset that can be resumed mid-epoch, implements chunking and sharding of chunks. 
The behavior of + this dataset can be customized with the following hooks: - ``prepare_chunk`` gives the possibility to prepare the chunk one iteration before its actually loaded (e.g. download from s3). @@ -100,6 +101,7 @@ class LightningIterableDataset(_StatefulIterableDataset, _Dataset): Note: Order of data is only guaranteed when resuming with the same distributed settings and the same number of workers. Everything else leads to different sharding and therefore results in different data order. + """ def __init__( @@ -156,11 +158,12 @@ class LightningIterableDataset(_StatefulIterableDataset, _Dataset): @abstractmethod def load_chunk(self, chunk: Any) -> Any: - """Implement this to load a single chunk into memory. This could e.g. mean loading the file that has - previously been downloaded from s3. + """Implement this to load a single chunk into memory. This could e.g. mean loading the file that has previously + been downloaded from s3. Args: chunk: The chunk that should be currently loaded + """ @abstractmethod @@ -171,6 +174,7 @@ class LightningIterableDataset(_StatefulIterableDataset, _Dataset): Args: chunk: The chunk the sample should be retrieved from index: The index of the current sample to retrieve within the chunk. + """ def prepare_chunk(self, chunk: Any) -> None: @@ -178,6 +182,7 @@ class LightningIterableDataset(_StatefulIterableDataset, _Dataset): Args: chunk: the chunk data to prepare. + """ def __iter__(self) -> "LightningIterableDataset": @@ -185,6 +190,7 @@ class LightningIterableDataset(_StatefulIterableDataset, _Dataset): Before that, detects the env if necessary, shuffles chunks, shards the data and shuffles sample orders within chunks. + """ self._curr_chunk_index = self._start_index_chunk self._curr_sample_index = self._start_index_sample @@ -206,6 +212,7 @@ class LightningIterableDataset(_StatefulIterableDataset, _Dataset): """Returns the next sample. If necessary, this also loads the new chunks. + """ self._check_if_sharded() self._ensure_chunks_loaded() @@ -243,6 +250,7 @@ class LightningIterableDataset(_StatefulIterableDataset, _Dataset): returned_samples: the number of totally returned samples by the dataloader(s) (across all distributed training processes). num_workers: number of dataloader workers per distributed training process. + """ # compute indices locally again since other workers may have different offsets @@ -275,6 +283,7 @@ class LightningIterableDataset(_StatefulIterableDataset, _Dataset): Note: Some of the changes only take effect when creating a new iterator + """ state_dict = deepcopy(state_dict) self._start_index_chunk = state_dict.pop("current_chunk") @@ -322,6 +331,7 @@ class LightningIterableDataset(_StatefulIterableDataset, _Dataset): """Shards the chunks if necessary. 
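# Hedged sketch of a LightningIterableDataset subclass wiring up the two
# required hooks shown above; the base-class constructor arguments are not
# visible in this diff, so instantiation is intentionally omitted.
import torch

from lightning.data.datasets.iterable import LightningIterableDataset


class TensorChunkDataset(LightningIterableDataset):
    def load_chunk(self, chunk):
        # e.g. `chunk` could be a local file path previously downloaded from s3
        return torch.load(chunk)

    def load_sample_from_chunk(self, chunk, index):
        # return a single sample from the in-memory chunk
        return chunk[index]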
No-op if already sharded + """ if not self._local_chunks: num_shards = self._env.num_shards @@ -365,6 +375,7 @@ class LightningIterableDataset(_StatefulIterableDataset, _Dataset): first_chunk_index: The point to which the generator should be replayed shuffle_chunk_order: Whether to shuffle the order of chunks shuffle_sample_order: Whether to shuffle the order of samples within a chunk + """ # re-seed generator if self._generator is not None and self._initial_generator_state is not None: diff --git a/src/lightning/data/datasets/mapping.py b/src/lightning/data/datasets/mapping.py index 046824cb2d..811c8fa909 100644 --- a/src/lightning/data/datasets/mapping.py +++ b/src/lightning/data/datasets/mapping.py @@ -15,6 +15,7 @@ class LightningDataset(_Dataset, ABC): data_source: path of data directory. ex. s3://mybucket/path backend: storage location of the data_source. current options are "s3" or "local" path_to_index_file: path to index file that lists all file contents of the data_source. + """ def __init__( @@ -36,6 +37,7 @@ class LightningDataset(_Dataset, ABC): Returns: The contents of the index file (all the file paths in the data_source) + """ if not os.path.isfile(self.index_file): get_index(self.data_source, self.index_file) @@ -49,6 +51,7 @@ class LightningDataset(_Dataset, ABC): Returns: The loaded item + """ file_path = self.files[idx] @@ -66,6 +69,7 @@ class LightningDataset(_Dataset, ABC): """Loads each sample in the dataset. Any data prep/cleaning logic goes here. For ex. image transformations, text cleaning, etc. + """ pass diff --git a/src/lightning/data/fileio.py b/src/lightning/data/fileio.py index a6e2ec887d..abacb1acfe 100644 --- a/src/lightning/data/fileio.py +++ b/src/lightning/data/fileio.py @@ -21,6 +21,7 @@ def path_to_url(path: str, bucket_name: str, bucket_root_path: str = "/") -> str Returns: Full S3 url path + """ if not path.startswith(bucket_root_path): raise ValueError(f"Cannot create a path from {path} relative to {bucket_root_path}") @@ -36,6 +37,7 @@ def open_single_file( Returns: The opened file stream. + """ from torchdata.datapipes.iter import FSSpecFileOpener, IterableWrapper @@ -54,6 +56,7 @@ def open_single_file_with_retry( Returns: The opened file stream. + """ from torchdata.datapipes.iter import FSSpecFileOpener, IterableWrapper @@ -83,6 +86,7 @@ class OpenCloudFileObj: mode: An optional string that specifies the mode in which the file is opened (``"r"`` by default). kwargs_for_open: Optional Dict to specify kwargs for opening files (``fs.open()``). + """ def __init__( diff --git a/src/lightning/fabric/_graveyard/tpu.py b/src/lightning/fabric/_graveyard/tpu.py index 2a45f928b3..b38cfc4708 100644 --- a/src/lightning/fabric/_graveyard/tpu.py +++ b/src/lightning/fabric/_graveyard/tpu.py @@ -34,6 +34,7 @@ class SingleTPUStrategy(SingleDeviceXLAStrategy): """Legacy class. Use :class:`~lightning.fabric.strategies.single_xla.SingleDeviceXLAStrategy` instead. + """ def __init__(self, *args: Any, **kwargs: Any) -> None: @@ -50,6 +51,7 @@ class TPUAccelerator(XLAAccelerator): """Legacy class. Use :class:`~lightning.fabric.accelerators.xla.XLAAccelerator` instead. + """ def __init__(self, *args: Any, **kwargs: Any) -> None: @@ -63,6 +65,7 @@ class TPUPrecision(XLAPrecision): """Legacy class. Use :class:`~lightning.fabric.plugins.precision.xla.XLAPrecision` instead. + """ def __init__(self, *args: Any, **kwargs: Any) -> None: @@ -76,6 +79,7 @@ class TPUBf16Precision(XLABf16Precision): """Legacy class. 
Use :class:`~lightning.fabric.plugins.precision.xlabf16.XLABf16Precision` instead. + """ def __init__(self, *args: Any, **kwargs: Any) -> None: diff --git a/src/lightning/fabric/accelerators/accelerator.py b/src/lightning/fabric/accelerators/accelerator.py index f843f05f21..3a8aa85ad0 100644 --- a/src/lightning/fabric/accelerators/accelerator.py +++ b/src/lightning/fabric/accelerators/accelerator.py @@ -25,6 +25,7 @@ class Accelerator(ABC): An Accelerator is meant to deal with one type of hardware. .. warning:: Writing your own accelerator is an :ref:`experimental ` feature. + """ @abstractmethod diff --git a/src/lightning/fabric/accelerators/cuda.py b/src/lightning/fabric/accelerators/cuda.py index 2f19529597..c79c943afc 100644 --- a/src/lightning/fabric/accelerators/cuda.py +++ b/src/lightning/fabric/accelerators/cuda.py @@ -88,6 +88,7 @@ def find_usable_cuda_devices(num_devices: int = -1) -> List[int]: Warning: If multiple processes call this function at the same time, there can be race conditions in the case where both processes determine that the device is unoccupied, leading into one of them crashing later on. + """ visible_devices = _get_all_visible_cuda_devices() if not visible_devices: @@ -128,6 +129,7 @@ def _get_all_visible_cuda_devices() -> List[int]: Devices masked by the environment variabale ``CUDA_VISIBLE_DEVICES`` won't be returned here. For example, assume you have 8 physical GPUs. If ``CUDA_VISIBLE_DEVICES="1,3,6"``, then this function will return the list ``[0, 1, 2]`` because these are the three visible GPUs after applying the mask ``CUDA_VISIBLE_DEVICES``. + """ return list(range(num_cuda_devices())) @@ -135,8 +137,7 @@ def _get_all_visible_cuda_devices() -> List[int]: # TODO: Remove once minimum supported PyTorch version is 2.0 @contextmanager def _patch_cuda_is_available() -> Generator: - """Context manager that safely patches :func:`torch.cuda.is_available` with its NVML-based version if - possible.""" + """Context manager that safely patches :func:`torch.cuda.is_available` with its NVML-based version if possible.""" if hasattr(torch._C, "_cuda_getDeviceCount") and _device_count_nvml() >= 0 and not _TORCH_GREATER_EQUAL_2_0: # we can safely patch is_available if both torch has CUDA compiled and the NVML count is succeeding # otherwise, patching is_available could lead to attribute errors or infinite recursion @@ -156,6 +157,7 @@ def num_cuda_devices() -> int: Unlike :func:`torch.cuda.device_count`, this function does its best not to create a CUDA context for fork support, if the platform allows it. + """ if _TORCH_GREATER_EQUAL_2_0: return torch.cuda.device_count() @@ -171,6 +173,7 @@ def is_cuda_available() -> bool: Unlike :func:`torch.cuda.is_available`, this function does its best not to create a CUDA context for fork support, if the platform allows it. + """ # We set `PYTORCH_NVML_BASED_CUDA_CHECK=1` in lightning.fabric.__init__.py return torch.cuda.is_available() if _TORCH_GREATER_EQUAL_2_0 else num_cuda_devices() > 0 @@ -311,6 +314,7 @@ def _device_count_nvml() -> int: """Return number of devices as reported by NVML taking CUDA_VISIBLE_DEVICES into account. Negative value is returned if NVML discovery or initialization has failed. 
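# Minimal usage sketch for the CUDA helpers documented in this hunk, using the
# module path shown in the diff (lightning.fabric.accelerators.cuda).
from lightning.fabric.accelerators.cuda import find_usable_cuda_devices, is_cuda_available, num_cuda_devices

if is_cuda_available():
    print(f"{num_cuda_devices()} CUDA device(s) visible")
    # pick two currently unoccupied devices (see the race-condition warning above)
    print(f"Using devices: {find_usable_cuda_devices(2)}")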
+ """ visible_devices = _parse_visible_devices() if not visible_devices: diff --git a/src/lightning/fabric/accelerators/mps.py b/src/lightning/fabric/accelerators/mps.py index cb6ffddbd9..1126f01d1e 100644 --- a/src/lightning/fabric/accelerators/mps.py +++ b/src/lightning/fabric/accelerators/mps.py @@ -26,6 +26,7 @@ class MPSAccelerator(Accelerator): """Accelerator for Metal Apple Silicon GPU devices. .. warning:: Use of this accelerator beyond import and instantiation is experimental. + """ def setup_device(self, device: torch.device) -> None: diff --git a/src/lightning/fabric/accelerators/registry.py b/src/lightning/fabric/accelerators/registry.py index f8d79dc1b6..68b2b98f45 100644 --- a/src/lightning/fabric/accelerators/registry.py +++ b/src/lightning/fabric/accelerators/registry.py @@ -41,6 +41,7 @@ class _AcceleratorRegistry(dict): or AcceleratorRegistry.register("sota", SOTAAccelerator, description="Custom sota accelerator", a=1, b=True) + """ def register( @@ -59,6 +60,7 @@ class _AcceleratorRegistry(dict): description : accelerator description override : overrides the registered accelerator, if True init_params: parameters to initialize the accelerator + """ if not (name is None or isinstance(name, str)): raise TypeError(f"`name` must be a str, found {name}") @@ -87,6 +89,7 @@ class _AcceleratorRegistry(dict): Args: name (str): the name that identifies a accelerator, e.g. "gpu" + """ if name in self: data = self[name] diff --git a/src/lightning/fabric/accelerators/xla.py b/src/lightning/fabric/accelerators/xla.py index fbe29eee64..207eefe533 100644 --- a/src/lightning/fabric/accelerators/xla.py +++ b/src/lightning/fabric/accelerators/xla.py @@ -26,6 +26,7 @@ class XLAAccelerator(Accelerator): """Accelerator for XLA devices, normally TPUs. .. warning:: Use of this accelerator beyond import and instantiation is experimental. + """ def __init__(self, *args: Any, **kwargs: Any) -> None: diff --git a/src/lightning/fabric/cli.py b/src/lightning/fabric/cli.py index abcb8f195a..24943a528b 100644 --- a/src/lightning/fabric/cli.py +++ b/src/lightning/fabric/cli.py @@ -33,8 +33,8 @@ _SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu") def _get_supported_strategies() -> List[str]: - """Returns strategy choices from the registry, with the ones removed that are incompatible to be launched from - the CLI or ones that require further configuration by the user.""" + """Returns strategy choices from the registry, with the ones removed that are incompatible to be launched from the + CLI or ones that require further configuration by the user.""" available_strategies = STRATEGY_REGISTRY.available_strategies() excluded = r".*(spawn|fork|notebook|xla|tpu|offload).*" return [strategy for strategy in available_strategies if not re.match(excluded, strategy)] @@ -122,6 +122,7 @@ if _CLICK_AVAILABLE: SCRIPT_ARGS are the remaining arguments that you can pass to the script itself and are expected to be parsed there. + """ script_args = list(kwargs.pop("script_args", [])) main(args=Namespace(**kwargs), script_args=script_args) @@ -131,6 +132,7 @@ def _set_env_variables(args: Namespace) -> None: """Set the environment variables for the new processes. The Fabric connector will parse the arguments set here. 
+ """ os.environ["LT_CLI_USED"] = "1" if args.accelerator is not None: diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py index 401d6d4912..0852e1e179 100644 --- a/src/lightning/fabric/connector.py +++ b/src/lightning/fabric/connector.py @@ -98,6 +98,7 @@ class _Connector: priorities which to take when: A. Class > str B. Strategy > Accelerator/precision/plugins + """ def __init__( @@ -182,6 +183,7 @@ class _Connector: 4. plugins: The list of plugins may contain a Precision plugin, CheckpointIO, ClusterEnvironment and others. Additionally, other flags such as `precision` can populate the list with the corresponding plugin instances. + """ if plugins is not None: plugins = [plugins] if not isinstance(plugins, list) else plugins @@ -390,8 +392,8 @@ class _Connector: return "ddp" def _check_strategy_and_fallback(self) -> None: - """Checks edge cases when the strategy selection was a string input, and we need to fall back to a - different choice depending on other parameters or the environment.""" + """Checks edge cases when the strategy selection was a string input, and we need to fall back to a different + choice depending on other parameters or the environment.""" # current fallback and check logic only apply to user pass in str config and object config # TODO this logic should apply to both str and object config strategy_flag = "" if isinstance(self._strategy_flag, Strategy) else self._strategy_flag diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py index 2a9e6088b4..0f229a5538 100644 --- a/src/lightning/fabric/fabric.py +++ b/src/lightning/fabric/fabric.py @@ -96,6 +96,7 @@ class Fabric: can be invoked through :meth:`~lightning.fabric.fabric.Fabric.call` by the user. loggers: A single logger or a list of loggers. See :meth:`~lightning.fabric.fabric.Fabric.log` for more information. + """ def __init__( @@ -147,6 +148,7 @@ class Fabric: """The current device this process runs on. Use this to create tensors directly on the device if needed. + """ return self._strategy.root_device @@ -189,6 +191,7 @@ class Fabric: """All the code inside this run method gets accelerated by Fabric. You can pass arbitrary arguments to this function when overriding it. + """ def setup( @@ -207,6 +210,7 @@ class Fabric: Returns: The tuple containing wrapped module and the optimizers, in the same order they were passed in. + """ self._validate_setup(module, optimizers) original_module = module @@ -264,6 +268,7 @@ class Fabric: Returns: The wrapped model. + """ self._validate_setup_module(module) original_module = module @@ -300,6 +305,7 @@ class Fabric: Returns: The wrapped optimizer(s). + """ self._validate_setup_optimizers(optimizers) optimizers = [self._strategy.setup_optimizer(optimizer) for optimizer in optimizers] @@ -326,6 +332,7 @@ class Fabric: Returns: The wrapped dataloaders, in the same order they were passed in. + """ self._validate_setup_dataloaders(dataloaders) dataloaders = [ @@ -353,6 +360,7 @@ class Fabric: Returns: The wrapped dataloader. + """ if use_distributed_sampler and self._requires_distributed_sampler(dataloader): sampler = self._get_distributed_sampler(dataloader, **self._strategy.distributed_sampler_kwargs) @@ -419,6 +427,7 @@ class Fabric: norm_type: The type of norm if `max_norm` was passed. Can be ``'inf'`` for infinity norm. Default is the 2-norm. error_if_nonfinite: An error is raised if the total norm of the gradients is NaN or infinite. 
+ """ if clip_val is not None and max_norm is not None: raise ValueError( @@ -444,6 +453,7 @@ class Fabric: Use this only if the `forward` method of your model does not cover all operations you wish to run with the chosen precision setting. + """ with self._precision.forward_context(): yield @@ -461,8 +471,8 @@ class Fabric: ... def to_device(self, obj: Union[nn.Module, Tensor, Any]) -> Union[nn.Module, Tensor, Any]: - """Move a :class:`torch.nn.Module` or a collection of tensors to the current device, if it is not already - on that device. + """Move a :class:`torch.nn.Module` or a collection of tensors to the current device, if it is not already on + that device. Args: obj: An object to move to the device. Can be an instance of :class:`torch.nn.Module`, a tensor, or a @@ -470,6 +480,7 @@ class Fabric: Returns: A reference to the object that was moved to the new device. + """ if isinstance(obj, nn.Module): self._accelerator.setup_device(self.device) @@ -482,6 +493,7 @@ class Fabric: process in each machine. Arguments passed to this method are forwarded to the Python built-in :func:`print` function. + """ if self.local_rank == 0: print(*args, **kwargs) @@ -492,6 +504,7 @@ class Fabric: Use this to synchronize all parallel processes, but only if necessary, otherwise the overhead of synchronization will cause your program to slow down. This method needs to be called on all processes. Failing to do so will cause your program to stall forever. + """ self._validate_launched() self._strategy.barrier(name=name) @@ -508,6 +521,7 @@ class Fabric: Return: The transferred data, the same value on every rank. + """ self._validate_launched() return self._strategy.broadcast(obj, src=src) @@ -528,6 +542,7 @@ class Fabric: Return: A tensor of shape (world_size, batch, ...), or if the input was a collection the output will also be a collection with tensors of this shape. + """ self._validate_launched() group = group if group is not None else torch.distributed.group.WORLD @@ -554,6 +569,7 @@ class Fabric: Return: A tensor of the same shape as the input with values reduced pointwise across processes. The same is applied to tensors in a collection if a collection is given as input. + """ self._validate_launched() group = group if group is not None else torch.distributed.group.WORLD @@ -573,6 +589,7 @@ class Fabric: with fabric.rank_zero_first(): dataset = MNIST("datasets/", download=True) + """ rank = self.local_rank if local else self.global_rank if rank > 0: @@ -604,6 +621,7 @@ class Fabric: module: The module for which to control the gradient synchronization. enabled: Whether the context manager is enabled or not. ``True`` means skip the sync, ``False`` means do not skip. + """ module = _unwrap_compiled(module) if not isinstance(module, _FabricModule): @@ -633,6 +651,7 @@ class Fabric: """Instantiate a model under this context manager to prepare it for model-parallel sharding. .. deprecated:: This context manager is deprecated in favor of :meth:`init_module`, use it instead. + """ rank_zero_deprecation("`Fabric.sharded_model()` is deprecated in favor of `Fabric.init_module()`.") if isinstance(self.strategy, _Sharded): @@ -643,10 +662,11 @@ class Fabric: @contextmanager def init_tensor(self) -> Generator: - """Tensors that you instantiate under this context manager will be created on the device right away and - have the right data type depending on the precision setting in Fabric. 
+ """Tensors that you instantiate under this context manager will be created on the device right away and have + the right data type depending on the precision setting in Fabric. The automatic device placement under this context manager is only supported with PyTorch 2.0 and newer. + """ if not _TORCH_GREATER_EQUAL_2_0 and self.device.type != "cpu": rank_zero_warn( @@ -670,6 +690,7 @@ class Fabric: empty_init: Whether to initialize the model with empty weights (uninitialized memory). If ``None``, the strategy will decide. Some strategies may not support all options. Set this to ``True`` if you are loading a checkpoint into a large model. Requires `torch >= 1.13`. + """ if not _TORCH_GREATER_EQUAL_2_0 and self.device.type != "cpu": rank_zero_warn( @@ -700,6 +721,7 @@ class Fabric: filter: An optional dictionary containing filter callables that return a boolean indicating whether the given item should be saved (``True``) or filtered out (``False``). Each filter key should match a state key, where its filter will be applied to the ``state_dict`` generated. + """ if filter is not None: if not isinstance(filter, dict): @@ -734,6 +756,7 @@ class Fabric: Returns: The remaining items that were not restored into the given state dictionary. If no state dictionary is given, the full checkpoint will be returned. + """ unwrapped_state = _unwrap_objects(state) remainder = self._strategy.load_checkpoint(path=path, state=unwrapped_state, strict=strict) @@ -760,6 +783,7 @@ class Fabric: obj: A :class:`~torch.nn.Module` or :class:`~torch.optim.Optimizer` instance. strict: Whether to enforce that the keys in the module's state-dict match the keys in the checkpoint. Does not apply to optimizers. + """ obj = _unwrap_objects(obj) self._strategy.load_checkpoint(path=path, state=obj, strict=strict) @@ -782,6 +806,7 @@ class Fabric: ``launch()`` from your code. ``launch()`` is a no-op when called multiple times and no function is passed in. + """ if _is_using_cli(): raise RuntimeError( @@ -825,6 +850,7 @@ class Fabric: fabric = Fabric(callbacks=[MyCallback()]) fabric.call("on_train_epoch_end", results={...}) + """ for callback in self._callbacks: method = getattr(callback, hook_name, None) @@ -853,6 +879,7 @@ class Fabric: graph automatically. step: Optional step number. Most Logger implementations auto-increment the step value by one with every log call. You can specify your own value here. + """ self.log_dict(metrics={name: value}, step=step) @@ -864,6 +891,7 @@ class Fabric: Any :class:`torch.Tensor` in the dictionary get detached from the graph automatically. step: Optional step number. Most Logger implementations auto-increment this value by one with every log call. You can specify your own value here. + """ metrics = convert_tensors_to_scalars(metrics) for logger in self._loggers: @@ -874,6 +902,7 @@ class Fabric: """Helper function to seed everything without explicitly importing Lightning. See :func:`lightning.fabric.utilities.seed.seed_everything` for more details. 
+ """ if workers is None: # Lightning sets `workers=False` by default to avoid breaking reproducibility, but since this is a new diff --git a/src/lightning/fabric/loggers/csv_logs.py b/src/lightning/fabric/loggers/csv_logs.py index fd0c69c3ab..d55860056e 100644 --- a/src/lightning/fabric/loggers/csv_logs.py +++ b/src/lightning/fabric/loggers/csv_logs.py @@ -49,6 +49,7 @@ class CSVLogger(Logger): logger = CSVLogger("path/to/logs/root", name="my_model") logger.log_metrics({"loss": 0.235, "acc": 0.75}) logger.finalize("success") + """ LOGGER_JOIN_CHAR = "-" @@ -77,6 +78,7 @@ class CSVLogger(Logger): Returns: The name of the experiment. + """ return self._name @@ -86,6 +88,7 @@ class CSVLogger(Logger): Returns: The version of the experiment if it is specified, else the next version. + """ if self._version is None: self._version = self._get_next_version() @@ -102,6 +105,7 @@ class CSVLogger(Logger): By default, it is named ``'version_${self.version}'`` but it can be overridden by passing a string value for the constructor's version parameter instead of ``None`` or an int. + """ # create a pseudo standard path version = self.version if isinstance(self.version, str) else f"version_{self.version}" @@ -110,12 +114,12 @@ class CSVLogger(Logger): @property @rank_zero_experiment def experiment(self) -> "_ExperimentWriter": - """Actual ExperimentWriter object. To use ExperimentWriter features anywhere in your code, do the - following. + """Actual ExperimentWriter object. To use ExperimentWriter features anywhere in your code, do the following. Example:: self.logger.experiment.some_experiment_writer_function() + """ if self._experiment is not None: return self._experiment @@ -177,6 +181,7 @@ class _ExperimentWriter: Args: log_dir: Directory for the experiment logs + """ NAME_METRICS_FILE = "metrics.csv" diff --git a/src/lightning/fabric/loggers/logger.py b/src/lightning/fabric/loggers/logger.py index 8efa974bd4..5647ab9c1c 100644 --- a/src/lightning/fabric/loggers/logger.py +++ b/src/lightning/fabric/loggers/logger.py @@ -39,8 +39,8 @@ class Logger(ABC): @property def root_dir(self) -> Optional[str]: - """Return the root directory where all versions of an experiment get saved, or `None` if the logger does - not save data locally.""" + """Return the root directory where all versions of an experiment get saved, or `None` if the logger does not + save data locally.""" return None @property @@ -61,6 +61,7 @@ class Logger(ABC): Args: metrics: Dictionary with metric names as keys and measured quantities as values step: Step number at which the metrics should be recorded + """ pass @@ -72,6 +73,7 @@ class Logger(ABC): params: :class:`~argparse.Namespace` or `Dict` containing the hyperparameters args: Optional positional arguments, depends on the specific logger being used kwargs: Optional keyword arguments, depends on the specific logger being used + """ def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None: @@ -80,6 +82,7 @@ class Logger(ABC): Args: model: the model with an implementation of ``forward``. input_array: input passes to `model.forward` + """ pass @@ -91,6 +94,7 @@ class Logger(ABC): Args: status: Status that the experiment finished with (e.g. 
success, failed, aborted) + """ self.save() diff --git a/src/lightning/fabric/loggers/tensorboard.py b/src/lightning/fabric/loggers/tensorboard.py index 3708ac9b73..8098881c7d 100644 --- a/src/lightning/fabric/loggers/tensorboard.py +++ b/src/lightning/fabric/loggers/tensorboard.py @@ -75,6 +75,7 @@ class TensorBoardLogger(Logger): logger.log_hyperparams({"epochs": 5, "optimizer": "Adam"}) logger.log_metrics({"acc": 0.75}) logger.finalize("success") + """ LOGGER_JOIN_CHAR = "-" @@ -113,6 +114,7 @@ class TensorBoardLogger(Logger): Returns: The name of the experiment. + """ return self._name @@ -122,6 +124,7 @@ class TensorBoardLogger(Logger): Returns: The experiment version if specified else the next version. + """ if self._version is None: self._version = self._get_next_version() @@ -133,6 +136,7 @@ class TensorBoardLogger(Logger): Returns: The local path to the save directory where the TensorBoard experiments are saved. + """ return self._root_dir @@ -142,6 +146,7 @@ class TensorBoardLogger(Logger): By default, it is named ``'version_${self.version}'`` but it can be overridden by passing a string value for the constructor's version parameter instead of ``None`` or an int. + """ version = self.version if isinstance(self.version, str) else f"version_{self.version}" log_dir = os.path.join(self.root_dir, self.name, version) @@ -157,6 +162,7 @@ class TensorBoardLogger(Logger): Returns: The local path to the sub directory where the TensorBoard experiments are saved. + """ return self._sub_dir @@ -168,6 +174,7 @@ class TensorBoardLogger(Logger): Example:: logger.experiment.some_tensorboard_function() + """ if self._experiment is not None: return self._experiment @@ -210,12 +217,13 @@ class TensorBoardLogger(Logger): self, params: Union[Dict[str, Any], Namespace], metrics: Optional[Dict[str, Any]] = None ) -> None: """Record hyperparameters. TensorBoard logs with and without saved hyperparameters are incompatible, the - hyperparameters are then not displayed in the TensorBoard. Please delete or move the previously saved logs - to display the new ones with hyperparameters. + hyperparameters are then not displayed in the TensorBoard. Please delete or move the previously saved logs to + display the new ones with hyperparameters. Args: params: a dictionary-like container with the hyperparameters metrics: Dictionary with metric names as keys and measured quantities as values + """ params = _convert_params(params) diff --git a/src/lightning/fabric/plugins/collectives/collective.py b/src/lightning/fabric/plugins/collectives/collective.py index 9a2c399883..5c655f189c 100644 --- a/src/lightning/fabric/plugins/collectives/collective.py +++ b/src/lightning/fabric/plugins/collectives/collective.py @@ -13,6 +13,7 @@ class Collective(ABC): Supports communications between multiple processes and multiple nodes. A collective owns a group. .. warning:: This is an :ref:`experimental ` feature which is still in development. + """ def __init__(self) -> None: @@ -120,6 +121,7 @@ class Collective(ABC): This assumes that :meth:`~lightning.fabric.plugins.collectives.Collective.init_group` has been called already by the user. 
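Both ``CSVLogger`` and ``TensorBoardLogger`` implement the ``Logger`` interface above and can be attached to ``Fabric`` directly, which then fans ``log``/``log_dict`` calls out to every logger. A sketch, assuming ``tensorboard`` is installed and using an illustrative log directory::

    from lightning.fabric import Fabric
    from lightning.fabric.loggers import CSVLogger, TensorBoardLogger

    loggers = [CSVLogger(root_dir="logs", name="demo"), TensorBoardLogger(root_dir="logs", name="demo")]
    fabric = Fabric(loggers=loggers)

    fabric.log("train/loss", 0.235)                    # a single scalar
    fabric.log_dict({"train/acc": 0.75, "epoch": 1})   # several metrics at once

    for logger in fabric.loggers:
        logger.finalize("success")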
+ """ if self._group is not None: raise RuntimeError(f"`{type(self).__name__}` already owns a group.") diff --git a/src/lightning/fabric/plugins/collectives/single_device.py b/src/lightning/fabric/plugins/collectives/single_device.py index 88c24c4892..fee7a05f79 100644 --- a/src/lightning/fabric/plugins/collectives/single_device.py +++ b/src/lightning/fabric/plugins/collectives/single_device.py @@ -10,6 +10,7 @@ class SingleDeviceCollective(Collective): """Support for collective operations on a single device (no-op). .. warning:: This is an :ref:`experimental ` feature which is still in development. + """ @property diff --git a/src/lightning/fabric/plugins/collectives/torch_collective.py b/src/lightning/fabric/plugins/collectives/torch_collective.py index 05c2a5b4b2..50b9a49975 100644 --- a/src/lightning/fabric/plugins/collectives/torch_collective.py +++ b/src/lightning/fabric/plugins/collectives/torch_collective.py @@ -21,6 +21,7 @@ class TorchCollective(Collective): """Collective operations using `torch.distributed `__. .. warning:: This is an :ref:`experimental ` feature which is still in development. + """ manages_default_group = False diff --git a/src/lightning/fabric/plugins/environments/kubeflow.py b/src/lightning/fabric/plugins/environments/kubeflow.py index 967c4682f2..3b14fbccae 100644 --- a/src/lightning/fabric/plugins/environments/kubeflow.py +++ b/src/lightning/fabric/plugins/environments/kubeflow.py @@ -28,6 +28,7 @@ class KubeflowEnvironment(ClusterEnvironment): .. _PyTorchJob: https://www.kubeflow.org/docs/components/training/pytorch/ .. _Kubeflow: https://www.kubeflow.org + """ @property diff --git a/src/lightning/fabric/plugins/environments/lightning.py b/src/lightning/fabric/plugins/environments/lightning.py index efb4968cc6..8a717e3bf4 100644 --- a/src/lightning/fabric/plugins/environments/lightning.py +++ b/src/lightning/fabric/plugins/environments/lightning.py @@ -32,6 +32,7 @@ class LightningEnvironment(ClusterEnvironment): If the main address and port are not provided, the default environment will choose them automatically. It is recommended to use this default environment for single-node distributed training as it provides a convenient way to launch the training script. + """ def __init__(self) -> None: @@ -46,6 +47,7 @@ class LightningEnvironment(ClusterEnvironment): If at least :code:`LOCAL_RANK` is available as environment variable, Lightning assumes the user acts as the process launcher/job scheduler and Lightning will not launch new processes. + """ return "LOCAL_RANK" in os.environ @@ -93,6 +95,7 @@ def find_free_network_port() -> int: It is useful in single-node training when we don't want to connect to a real main node but have to set the `MASTER_PORT` environment variable. + """ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.bind(("", 0)) diff --git a/src/lightning/fabric/plugins/environments/lsf.py b/src/lightning/fabric/plugins/environments/lsf.py index 8500a3e40f..2eb089f580 100644 --- a/src/lightning/fabric/plugins/environments/lsf.py +++ b/src/lightning/fabric/plugins/environments/lsf.py @@ -44,6 +44,7 @@ class LSFEnvironment(ClusterEnvironment): ``JSM_NAMESPACE_RANK`` The global rank for the task. This environment variable is set by ``jsrun`` + """ def __init__(self) -> None: @@ -128,6 +129,7 @@ class LSFEnvironment(ClusterEnvironment): The node rank is determined by the position of the current node in the list of hosts used in the job. 
This is calculated by reading all hosts from ``LSB_DJOB_RANKFILE`` and finding this node's hostname in the list. + """ hosts = self._read_hosts() count: Dict[str, int] = {} @@ -143,6 +145,7 @@ class LSFEnvironment(ClusterEnvironment): LSF uses the Job Step Manager (JSM) to manage job steps. Job steps are executed by the JSM from "launch" nodes. Each job is assigned a launch node. This launch node will be the first node in the list contained in ``LSB_DJOB_RANKFILE``. + """ var = "LSB_DJOB_RANKFILE" rankfile = os.environ.get(var) @@ -161,6 +164,7 @@ class LSFEnvironment(ClusterEnvironment): """A helper for getting the main address. The main address is assigned to the first node in the list of nodes used for the job. + """ hosts = self._read_hosts() return hosts[0] @@ -170,6 +174,7 @@ class LSFEnvironment(ClusterEnvironment): """A helper function for accessing the main port. Uses the LSF job ID so all ranks can compute the main port. + """ # check for user-specified main port if "MASTER_PORT" in os.environ: diff --git a/src/lightning/fabric/plugins/environments/mpi.py b/src/lightning/fabric/plugins/environments/mpi.py index a518da5f66..e40fe8b027 100644 --- a/src/lightning/fabric/plugins/environments/mpi.py +++ b/src/lightning/fabric/plugins/environments/mpi.py @@ -31,6 +31,7 @@ class MPIEnvironment(ClusterEnvironment): """An environment for running on clusters with processes created through MPI. Requires the installation of the `mpi4py` package. See also: https://github.com/mpi4py/mpi4py + """ def __init__(self) -> None: diff --git a/src/lightning/fabric/plugins/environments/slurm.py b/src/lightning/fabric/plugins/environments/slurm.py index 97fceb0218..b951ebedd5 100644 --- a/src/lightning/fabric/plugins/environments/slurm.py +++ b/src/lightning/fabric/plugins/environments/slurm.py @@ -39,6 +39,7 @@ class SLURMEnvironment(ClusterEnvironment): rescheduled gets determined by the owner of this plugin. requeue_signal: The signal that SLURM will send to indicate that the job should be requeued. Defaults to SIGUSR1 on Unix. + """ def __init__(self, auto_requeue: bool = True, requeue_signal: Optional[signal.Signals] = None) -> None: @@ -149,6 +150,7 @@ class SLURMEnvironment(ClusterEnvironment): - a space-separated list of host names, e.g., 'host0 host1 host3' yields 'host0' as the root - a comma-separated list of host names, e.g., 'host0,host1,host3' yields 'host0' as the root - the range notation with brackets, e.g., 'host[5-9]' yields 'host5' as the root + """ nodes = re.sub(r"\[(.*?)[,-].*\]", "\\1", nodes) # Take the first node of every node range nodes = re.sub(r"\[(.*?)\]", "\\1", nodes) # handle special case where node range is single number @@ -161,6 +163,7 @@ class SLURMEnvironment(ClusterEnvironment): Parallel jobs (multi-GPU, multi-node) in SLURM are launched by prepending `srun` in front of the Python command. Not doing so will result in processes hanging, which is a frequent user error. Lightning will emit a warning if `srun` is found but not used. + """ if _IS_WINDOWS: return @@ -176,12 +179,12 @@ class SLURMEnvironment(ClusterEnvironment): @staticmethod def _validate_srun_variables() -> None: - """Checks for conflicting or incorrectly set variables set through `srun` and raises a useful error - message. + """Checks for conflicting or incorrectly set variables set through `srun` and raises a useful error message. Right now, we only check for the most common user errors. See `the srun docs `_ for a complete list of supported srun variables. 
+ """ ntasks = int(os.environ.get("SLURM_NTASKS", "1")) if ntasks > 1 and "SLURM_NTASKS_PER_NODE" not in os.environ: diff --git a/src/lightning/fabric/plugins/environments/xla.py b/src/lightning/fabric/plugins/environments/xla.py index 0aa4671684..97657087be 100644 --- a/src/lightning/fabric/plugins/environments/xla.py +++ b/src/lightning/fabric/plugins/environments/xla.py @@ -25,6 +25,7 @@ class XLAEnvironment(ClusterEnvironment): A list of environment variables set by XLA can be found `here `_. + """ def __init__(self, *args: Any, **kwargs: Any) -> None: diff --git a/src/lightning/fabric/plugins/io/checkpoint_io.py b/src/lightning/fabric/plugins/io/checkpoint_io.py index 44e9f596d4..93e3a67b7b 100644 --- a/src/lightning/fabric/plugins/io/checkpoint_io.py +++ b/src/lightning/fabric/plugins/io/checkpoint_io.py @@ -42,6 +42,7 @@ class CheckpointIO(ABC): checkpoint: dict containing model and trainer state path: write-target path storage_options: Optional parameters when saving the model/training states. + """ @abstractmethod @@ -54,6 +55,7 @@ class CheckpointIO(ABC): locations. Returns: The loaded checkpoint. + """ @abstractmethod @@ -62,6 +64,7 @@ class CheckpointIO(ABC): Args: path: Path to checkpoint + """ def teardown(self) -> None: diff --git a/src/lightning/fabric/plugins/io/torch_io.py b/src/lightning/fabric/plugins/io/torch_io.py index 646f29ffcb..31f90d2b69 100644 --- a/src/lightning/fabric/plugins/io/torch_io.py +++ b/src/lightning/fabric/plugins/io/torch_io.py @@ -25,10 +25,11 @@ log = logging.getLogger(__name__) class TorchCheckpointIO(CheckpointIO): - """CheckpointIO that utilizes :func:`torch.save` and :func:`torch.load` to save and load checkpoints - respectively, common for most use cases. + """CheckpointIO that utilizes :func:`torch.save` and :func:`torch.load` to save and load checkpoints respectively, + common for most use cases. .. warning:: This is an :ref:`experimental ` feature. + """ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: @@ -42,6 +43,7 @@ class TorchCheckpointIO(CheckpointIO): Raises: TypeError: If ``storage_options`` arg is passed in + """ if storage_options is not None: raise TypeError( @@ -82,6 +84,7 @@ class TorchCheckpointIO(CheckpointIO): Args: path: Path to checkpoint + """ fs = get_filesystem(path) if fs.exists(path): diff --git a/src/lightning/fabric/plugins/io/xla.py b/src/lightning/fabric/plugins/io/xla.py index 509bdcdc6e..4a5c3ef96b 100644 --- a/src/lightning/fabric/plugins/io/xla.py +++ b/src/lightning/fabric/plugins/io/xla.py @@ -28,6 +28,7 @@ class XLACheckpointIO(TorchCheckpointIO): """CheckpointIO that utilizes :func:`xm.save` to save checkpoints for TPU training strategies. .. warning:: This is an :ref:`experimental ` feature. + """ def __init__(self, *args: Any, **kwargs: Any) -> None: @@ -46,6 +47,7 @@ class XLACheckpointIO(TorchCheckpointIO): Raises: TypeError: If ``storage_options`` arg is passed in + """ if storage_options is not None: raise TypeError( diff --git a/src/lightning/fabric/plugins/precision/deepspeed.py b/src/lightning/fabric/plugins/precision/deepspeed.py index d51baa5939..6a1ebb17be 100644 --- a/src/lightning/fabric/plugins/precision/deepspeed.py +++ b/src/lightning/fabric/plugins/precision/deepspeed.py @@ -43,6 +43,7 @@ class DeepSpeedPrecision(Precision): Raises: ValueError: If unsupported ``precision`` is provided. 
+ """ def __init__(self, precision: _PRECISION_INPUT) -> None: diff --git a/src/lightning/fabric/plugins/precision/double.py b/src/lightning/fabric/plugins/precision/double.py index 05419a5d00..8a2623a141 100644 --- a/src/lightning/fabric/plugins/precision/double.py +++ b/src/lightning/fabric/plugins/precision/double.py @@ -36,6 +36,7 @@ class DoublePrecision(Precision): """Instantiate module parameters or tensors in the precision type this plugin handles. This is optional and depends on the precision limitations during optimization. + """ default_dtype = torch.get_default_dtype() torch.set_default_dtype(torch.float64) @@ -47,6 +48,7 @@ class DoublePrecision(Precision): """A context manager to change the default tensor type. See: :meth:`torch.set_default_dtype` + """ default_dtype = torch.get_default_dtype() torch.set_default_dtype(torch.float64) diff --git a/src/lightning/fabric/plugins/precision/fsdp.py b/src/lightning/fabric/plugins/precision/fsdp.py index 826415568d..5c211dd025 100644 --- a/src/lightning/fabric/plugins/precision/fsdp.py +++ b/src/lightning/fabric/plugins/precision/fsdp.py @@ -47,6 +47,7 @@ class FSDPPrecision(Precision): Raises: ValueError: If unsupported ``precision`` is provided. + """ def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None) -> None: @@ -107,6 +108,7 @@ class FSDPPrecision(Precision): """A context manager to change the default tensor type when initializing module parameters or tensors. See: :meth:`torch.set_default_dtype` + """ default_dtype = torch.get_default_dtype() torch.set_default_dtype(self.mixed_precision_config.param_dtype) diff --git a/src/lightning/fabric/plugins/precision/half.py b/src/lightning/fabric/plugins/precision/half.py index aa9ac52ffd..e4b011c57f 100644 --- a/src/lightning/fabric/plugins/precision/half.py +++ b/src/lightning/fabric/plugins/precision/half.py @@ -28,6 +28,7 @@ class HalfPrecision(Precision): Args: precision: Whether to use ``torch.float16`` (``'16-true'``) or ``torch.bfloat16`` (``'bf16-true'``). + """ precision: Literal["bf16-true", "16-true"] = "16-true" @@ -44,6 +45,7 @@ class HalfPrecision(Precision): """A context manager to change the default tensor type when initializing module parameters or tensors. See: :meth:`torch.set_default_dtype` + """ default_dtype = torch.get_default_dtype() torch.set_default_dtype(self._desired_input_dtype) @@ -52,10 +54,10 @@ class HalfPrecision(Precision): @contextmanager def forward_context(self) -> Generator[None, None, None]: - """A context manager to change the default tensor type when tensors get created during the module's - forward. + """A context manager to change the default tensor type when tensors get created during the module's forward. See: :meth:`torch.set_default_dtype` + """ default_dtype = torch.get_default_dtype() torch.set_default_dtype(self._desired_input_dtype) diff --git a/src/lightning/fabric/plugins/precision/precision.py b/src/lightning/fabric/plugins/precision/precision.py index 1add95da88..76017b32c7 100644 --- a/src/lightning/fabric/plugins/precision/precision.py +++ b/src/lightning/fabric/plugins/precision/precision.py @@ -33,6 +33,7 @@ class Precision: """Base class for all plugins handling the precision-specific parts of the training. The class attribute precision must be overwritten in child classes. The default value reflects fp32 training. + """ precision: _PRECISION_INPUT_STR = "32-true" @@ -41,6 +42,7 @@ class Precision: """Convert the module parameters to the precision type this plugin handles. 
This is optional and depends on the precision limitations during optimization. + """ return module @@ -49,6 +51,7 @@ class Precision: """Instantiate module parameters or tensors in the precision type this plugin handles. This is optional and depends on the precision limitations during optimization. + """ yield @@ -62,6 +65,7 @@ class Precision: This is a no-op in the base precision plugin, since we assume the data already has the desired type (default is torch.float32). + """ return data @@ -70,6 +74,7 @@ class Precision: This is a no-op in the base precision plugin, since we assume the data already has the desired type (default is torch.float32). + """ return data @@ -79,6 +84,7 @@ class Precision: Args: tensor: The tensor that will be used for backpropagation module: The module that was involved in producing the tensor and whose parameters need the gradients + """ def backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs: Any) -> None: @@ -87,6 +93,7 @@ class Precision: Args: tensor: The tensor that will be used for backpropagation model: The module that was involved in producing the tensor and whose parameters need the gradients + """ tensor.backward(*args, **kwargs) @@ -96,6 +103,7 @@ class Precision: Args: tensor: The tensor that will be used for backpropagation module: The module that was involved in producing the tensor and whose parameters need the gradients + """ def optimizer_step( @@ -110,6 +118,7 @@ class Precision: """The main params of the model. Returns the plain model params here. Maybe different in other precision plugins. + """ for group in optimizer.param_groups: yield from group["params"] @@ -122,6 +131,7 @@ class Precision: Returns: A dictionary containing precision plugin state. + """ return {} @@ -131,6 +141,7 @@ class Precision: Args: state_dict: the precision plugin state returned by ``state_dict``. + """ pass @@ -138,4 +149,5 @@ class Precision: """This method is called to teardown the training process. It is the right place to release memory and free other resources. + """ diff --git a/src/lightning/fabric/plugins/precision/transformer_engine.py b/src/lightning/fabric/plugins/precision/transformer_engine.py index 3d8d0c4ccf..23035b80b5 100644 --- a/src/lightning/fabric/plugins/precision/transformer_engine.py +++ b/src/lightning/fabric/plugins/precision/transformer_engine.py @@ -53,6 +53,7 @@ class TransformerEnginePrecision(Precision): Support for FP8 in the linear layers with `precision='transformer-engine'` is currently limited to tensors with shapes where the dimensions are divisible by 8 and 16 respectively. You might want to add padding to your inputs to conform to this restriction. 
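These precision plugins are normally selected through Fabric's ``precision`` string rather than instantiated by hand, for example ``"64-true"`` (DoublePrecision), ``"16-true"``/``"bf16-true"`` (HalfPrecision), or ``"transformer-engine"`` for the FP8 support described just above. A short sketch::

    import torch

    from lightning.fabric import Fabric

    fabric = Fabric(precision="bf16-true")

    with fabric.init_module():
        # Parameters are created directly in the dtype implied by the precision plugin.
        model = torch.nn.Linear(16, 16)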
+ """ precision: Literal["transformer-engine"] = "transformer-engine" diff --git a/src/lightning/fabric/strategies/ddp.py b/src/lightning/fabric/strategies/ddp.py index c6c65ef4dd..a912c5d438 100644 --- a/src/lightning/fabric/strategies/ddp.py +++ b/src/lightning/fabric/strategies/ddp.py @@ -136,6 +136,7 @@ class DDPStrategy(ParallelStrategy): Return: reduced value, except when the input was not a tensor the output remains is unchanged + """ if isinstance(tensor, Tensor): return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op) diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index 15bb698177..0fe3aa6e1b 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -220,6 +220,7 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded): load_full_weights: True when loading a single checkpoint file containing the model state dict when using ZeRO Stage 3. This differs from the DeepSpeed checkpoint which contains shards per worker. + """ if not _DEEPSPEED_AVAILABLE: raise ImportError( @@ -313,6 +314,7 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded): Return: The model wrapped into a :class:`deepspeed.DeepSpeedEngine` and a list with a single deepspeed optimizer. + """ if len(optimizers) != 1: raise ValueError( @@ -328,6 +330,7 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded): """Set up a module for inference (no optimizers). For training, see :meth:`setup_module_and_optimizers`. + """ self._deepspeed_engine, _ = self._initialize_engine(module) return self._deepspeed_engine @@ -336,6 +339,7 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded): """Optimizers can only be set up jointly with the model in this strategy. Please use :meth:`setup_module_and_optimizers` to set up both module and optimizer together. + """ raise NotImplementedError(self._err_msg_joint_setup_required()) @@ -386,6 +390,7 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded): ValueError: When no :class:`deepspeed.DeepSpeedEngine` objects were found in the state, or when multiple :class:`deepspeed.DeepSpeedEngine` objects were found. + """ if storage_options is not None: raise TypeError( @@ -451,6 +456,7 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded): RuntimeError: If DeepSpeed was unable to load the checkpoint due to missing files or because the checkpoint is not in the expected DeepSpeed format. + """ if isinstance(state, (Module, Optimizer)) or self.load_full_weights and self.zero_stage_3: # This code path to enables loading a checkpoint from a non-deepspeed checkpoint or from @@ -565,6 +571,7 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded): """Initialize one model and one optimizer with an optional learning rate scheduler. This calls :func:`deepspeed.initialize` internally. + """ import deepspeed @@ -720,12 +727,13 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded): return cfg def _restore_zero_state(self, module: Module, ckpt: Mapping[str, Any]) -> None: - """Overrides the normal load_state_dict behaviour in PyTorch to ensure we gather parameters that may be - sharded across processes before loading the state dictionary when using ZeRO stage 3. This is then - automatically synced across processes. + """Overrides the normal load_state_dict behaviour in PyTorch to ensure we gather parameters that may be sharded + across processes before loading the state dictionary when using ZeRO stage 3. This is then automatically synced + across processes. Args: ckpt: The ckpt file. 
+ """ import deepspeed diff --git a/src/lightning/fabric/strategies/dp.py b/src/lightning/fabric/strategies/dp.py index d19d1a14d7..99beeca150 100644 --- a/src/lightning/fabric/strategies/dp.py +++ b/src/lightning/fabric/strategies/dp.py @@ -28,8 +28,8 @@ from lightning.fabric.utilities.distributed import ReduceOp class DataParallelStrategy(ParallelStrategy): - """Implements data-parallel training in a single process, i.e., the model gets replicated to each device and - each gets a split of the data.""" + """Implements data-parallel training in a single process, i.e., the model gets replicated to each device and each + gets a split of the data.""" def __init__( self, diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py index f5d2c5a8ef..07236abb40 100644 --- a/src/lightning/fabric/strategies/fsdp.py +++ b/src/lightning/fabric/strategies/fsdp.py @@ -137,6 +137,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded): a folder with as many files as the world size. \**kwargs: See available parameters in :class:`torch.distributed.fsdp.FullyShardedDataParallel`. + """ def __init__( @@ -296,6 +297,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded): This setup method doesn't modify the optimizer or wrap the optimizer. The only thing it currently does is verify that the optimizer was created after the model was wrapped with :meth:`setup_module` with a reference to the flattened parameters. + """ if _TORCH_GREATER_EQUAL_2_0: return optimizer @@ -414,6 +416,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded): optimizer state and other metadata. If the state-dict-type is ``'sharded'``, the checkpoint gets saved as a directory containing one file per process, with model- and optimizer shards stored per file. Additionally, it creates a metadata file `meta.pt` with the rest of the user's state (only saved from rank 0). + """ if not _TORCH_GREATER_EQUAL_2_0: raise NotImplementedError( @@ -511,6 +514,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded): The strategy currently only supports saving and loading sharded checkpoints which are stored in form of a directory of multiple files rather than a single file. + """ if not _TORCH_GREATER_EQUAL_2_0: raise NotImplementedError( @@ -846,8 +850,7 @@ def _load_raw_module_state_from_path(path: Path, module: Module, strict: bool = def _load_raw_module_state(state_dict: Dict[str, Any], module: Module, strict: bool = True) -> None: - """Loads the state dict into the module by gathering all weights first and then and writing back to each - shard.""" + """Loads the state dict into the module by gathering all weights first and then and writing back to each shard.""" with _get_full_state_dict_context(module, rank0_only=False): module.load_state_dict(state_dict, strict=strict) @@ -879,6 +882,7 @@ def _apply_optimizers_during_fsdp_backward( By moving optimizer step invocation into the backward call we can free gradients earlier and reduce peak memory. 
+ """ from torch.distributed.fsdp._common_utils import _get_module_fsdp_state from torch.distributed.fsdp._traversal_utils import _get_fsdp_handles diff --git a/src/lightning/fabric/strategies/launchers/launcher.py b/src/lightning/fabric/strategies/launchers/launcher.py index f261f81124..c22a14633e 100644 --- a/src/lightning/fabric/strategies/launchers/launcher.py +++ b/src/lightning/fabric/strategies/launchers/launcher.py @@ -23,6 +23,7 @@ class _Launcher(ABC): Subclass this class and override any of the relevant methods to provide a custom implementation depending on cluster environment, hardware, strategy, etc. + """ @property diff --git a/src/lightning/fabric/strategies/launchers/multiprocessing.py b/src/lightning/fabric/strategies/launchers/multiprocessing.py index 9d90f7953b..66c766ecb5 100644 --- a/src/lightning/fabric/strategies/launchers/multiprocessing.py +++ b/src/lightning/fabric/strategies/launchers/multiprocessing.py @@ -50,6 +50,7 @@ class _MultiProcessingLauncher(_Launcher): - 'fork': Preferable for IPython/Jupyter environments where 'spawn' is not available. Not available on the Windows platform for example. - 'forkserver': Alternative implementation to 'fork'. + """ def __init__( @@ -82,6 +83,7 @@ class _MultiProcessingLauncher(_Launcher): function: The entry point for all launched processes. *args: Optional positional arguments to be passed to the given function. **kwargs: Optional keyword arguments to be passed to the given function. + """ if self._start_method in ("fork", "forkserver"): _check_bad_cuda_fork() @@ -143,6 +145,7 @@ class _GlobalStateSnapshot: # in worker process snapshot.restore() + """ use_deterministic_algorithms: bool @@ -152,8 +155,7 @@ class _GlobalStateSnapshot: @classmethod def capture(cls) -> "_GlobalStateSnapshot": - """Capture a few global states from torch, numpy, etc., that we want to restore in a spawned worker - process.""" + """Capture a few global states from torch, numpy, etc., that we want to restore in a spawned worker process.""" return cls( use_deterministic_algorithms=torch.are_deterministic_algorithms_enabled(), use_deterministic_algorithms_warn_only=torch.is_deterministic_algorithms_warn_only_enabled(), @@ -175,6 +177,7 @@ def _check_bad_cuda_fork() -> None: The error message replaces PyTorch's 'Cannot re-initialize CUDA in forked subprocess' with helpful advice for Lightning users. + """ if not torch.cuda.is_initialized(): return diff --git a/src/lightning/fabric/strategies/launchers/subprocess_script.py b/src/lightning/fabric/strategies/launchers/subprocess_script.py index 27d23b859f..f171b3435e 100644 --- a/src/lightning/fabric/strategies/launchers/subprocess_script.py +++ b/src/lightning/fabric/strategies/launchers/subprocess_script.py @@ -65,6 +65,7 @@ class _SubprocessScriptLauncher(_Launcher): cluster_environment: A cluster environment that provides access to world size, node rank, etc. num_processes: The number of processes to launch in the current node. num_nodes: The total number of nodes that participate in this process group. + """ def __init__( @@ -91,6 +92,7 @@ class _SubprocessScriptLauncher(_Launcher): It is up to the implementation of this function to synchronize the processes, e.g., with barriers. *args: Optional positional arguments to be passed to the given function. **kwargs: Optional keyword arguments to be passed to the given function. 
+ """ if not self.cluster_environment.creates_processes_externally: self._call_children_scripts() diff --git a/src/lightning/fabric/strategies/launchers/xla.py b/src/lightning/fabric/strategies/launchers/xla.py index a538886018..d4a17256db 100644 --- a/src/lightning/fabric/strategies/launchers/xla.py +++ b/src/lightning/fabric/strategies/launchers/xla.py @@ -27,8 +27,8 @@ if TYPE_CHECKING: class _XLALauncher(_Launcher): - r"""Launches processes that run a given function in parallel on XLA supported hardware, and joins them all at - the end. + r"""Launches processes that run a given function in parallel on XLA supported hardware, and joins them all at the + end. The main process in which this launcher is invoked creates N so-called worker processes (using the `torch_xla` :func:`xmp.spawn`) that run the given function. @@ -40,6 +40,7 @@ class _XLALauncher(_Launcher): Args: strategy: A reference to the strategy that is used together with this launcher + """ def __init__(self, strategy: Union["XLAStrategy", "XLAFSDPStrategy"]) -> None: @@ -62,6 +63,7 @@ class _XLALauncher(_Launcher): function: The entry point for all launched processes. *args: Optional positional arguments to be passed to the given function. **kwargs: Optional keyword arguments to be passed to the given function. + """ using_pjrt = _using_pjrt() return_queue: Union[queue.Queue, mp.SimpleQueue] diff --git a/src/lightning/fabric/strategies/registry.py b/src/lightning/fabric/strategies/registry.py index 92c0417062..7956f7a95e 100644 --- a/src/lightning/fabric/strategies/registry.py +++ b/src/lightning/fabric/strategies/registry.py @@ -40,6 +40,7 @@ class _StrategyRegistry(dict): or StrategyRegistry.register("lightning", LightningStrategy, description="Super fast", a=1, b=True) + """ def register( @@ -58,6 +59,7 @@ class _StrategyRegistry(dict): description : strategy description override : overrides the registered strategy, if True init_params: parameters to initialize the strategy + """ if not (name is None or isinstance(name, str)): raise TypeError(f"`name` must be a str, found {name}") @@ -86,6 +88,7 @@ class _StrategyRegistry(dict): Args: name (str): the name that identifies a strategy, e.g. "deepspeed_stage_3" + """ if name in self: data = self[name] diff --git a/src/lightning/fabric/strategies/single_device.py b/src/lightning/fabric/strategies/single_device.py index 59ccf0810c..3edda3faf7 100644 --- a/src/lightning/fabric/strategies/single_device.py +++ b/src/lightning/fabric/strategies/single_device.py @@ -56,8 +56,8 @@ class SingleDeviceStrategy(Strategy): module.to(self.root_device) def all_reduce(self, tensor: Any | Tensor, *args: Any, **kwargs: Any) -> Any | Tensor: - """Reduces a tensor from several distributed processes to one aggregated tensor. As this plugin only - operates with a single device, the reduction is simply the identity. + """Reduces a tensor from several distributed processes to one aggregated tensor. As this plugin only operates + with a single device, the reduction is simply the identity. 
Args: tensor: the tensor to sync and reduce @@ -66,6 +66,7 @@ class SingleDeviceStrategy(Strategy): Return: the unmodified input as reduction is not needed for single process operation + """ return tensor diff --git a/src/lightning/fabric/strategies/strategy.py b/src/lightning/fabric/strategies/strategy.py index 3e1913a6c1..3aaa8db0da 100644 --- a/src/lightning/fabric/strategies/strategy.py +++ b/src/lightning/fabric/strategies/strategy.py @@ -102,6 +102,7 @@ class Strategy(ABC): This must be called by the framework at the beginning of every process, before any distributed communication takes place. + """ assert self.accelerator is not None self.accelerator.setup_device(self.root_device) @@ -111,6 +112,7 @@ class Strategy(ABC): Args: dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader` + """ return dataloader @@ -131,6 +133,7 @@ class Strategy(ABC): Args: empty_init: Whether to initialize the model with empty weights (uninitialized memory). If ``None``, the strategy will decide. Some strategies may not support all options. + """ empty_init_context = _EmptyInit(enabled=bool(empty_init)) if _TORCH_GREATER_EQUAL_1_13 else nullcontext() with empty_init_context, self.tensor_init_context(): @@ -143,6 +146,7 @@ class Strategy(ABC): The returned objects are expected to be in the same order they were passed in. The default implementation will call :meth:`setup_module` and :meth:`setup_optimizer` on the inputs. + """ module = self.setup_module(module) optimizers = [self.setup_optimizer(optimizer) for optimizer in optimizers] @@ -169,6 +173,7 @@ class Strategy(ABC): Args: batch: The batch of samples to move to the correct device device: The target device + """ device = device or self.root_device return move_data_to_device(batch, device) @@ -189,6 +194,7 @@ class Strategy(ABC): Args: optimizer: the optimizer performing the step **kwargs: Any extra arguments to ``optimizer.step`` + """ return self.precision.optimizer_step(optimizer, **kwargs) @@ -200,6 +206,7 @@ class Strategy(ABC): tensor: the tensor to all_gather group: the process group to gather results from sync_grads: flag that allows users to synchronize gradients for all_gather op + """ @abstractmethod @@ -216,6 +223,7 @@ class Strategy(ABC): group: the process group to reduce reduce_op: the reduction operation. Defaults to 'mean'. Can also be a string 'sum' or ReduceOp. + """ @abstractmethod @@ -224,6 +232,7 @@ class Strategy(ABC): Args: name: an optional name to pass into barrier. + """ @abstractmethod @@ -233,6 +242,7 @@ class Strategy(ABC): Args: obj: the object to broadcast src: source rank + """ def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool: @@ -256,6 +266,7 @@ class Strategy(ABC): filter: An optional dictionary containing filter callables that return a boolean indicating whether the given item should be saved (``True``) or filtered out (``False``). Each filter key should match a state key, where its filter will be applied to the ``state_dict`` generated. + """ state = self._convert_stateful_objects_in_state(state, filter=filter or {}) if self.is_global_zero: @@ -275,6 +286,7 @@ class Strategy(ABC): """Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom plugins. + """ if hasattr(optimizer, "consolidate_state_dict"): # there are optimizers like PyTorch's ZeroRedundancyOptimizer that shard their @@ -307,6 +319,7 @@ class Strategy(ABC): Returns: The remaining items that were not restored into the given state dictionary. 
If no state dictionary is given, the full checkpoint will be returned. + """ torch.cuda.empty_cache() checkpoint = self.checkpoint_io.load_checkpoint(path) @@ -338,6 +351,7 @@ class Strategy(ABC): """This method is called to teardown the training process. It is the right place to release memory and free other resources. + """ self.precision.teardown() assert self.accelerator is not None @@ -397,6 +411,7 @@ class _BackwardSyncControl(ABC): The most common use-case is gradient accumulation. If a :class:`Strategy` implements this interface, the user can implement their gradient accumulation loop very efficiently by disabling redundant gradient synchronization. + """ @contextmanager @@ -405,20 +420,21 @@ class _BackwardSyncControl(ABC): """Blocks the synchronization of gradients during the backward pass. This is a context manager. It is only effective if it wraps a call to `.backward()`. + """ class _Sharded(ABC): - """Mixin-interface for any :class:`Strategy` that wants to expose functionality for sharding model - parameters.""" + """Mixin-interface for any :class:`Strategy` that wants to expose functionality for sharding model parameters.""" @abstractmethod @contextmanager def module_sharded_context(self) -> Generator: - """A context manager that goes over the instantiation of an :class:`torch.nn.Module` and handles sharding - of parameters on creation. + """A context manager that goes over the instantiation of an :class:`torch.nn.Module` and handles sharding of + parameters on creation. By sharding layers directly on instantiation, one can reduce peak memory usage and initialization time. + """ yield diff --git a/src/lightning/fabric/strategies/xla.py b/src/lightning/fabric/strategies/xla.py index f5050889a5..955b8bfda0 100644 --- a/src/lightning/fabric/strategies/xla.py +++ b/src/lightning/fabric/strategies/xla.py @@ -151,6 +151,7 @@ class XLAStrategy(ParallelStrategy): sync_grads: flag that allows users to synchronize gradients for the all-gather operation. Return: A tensor of shape (world_size, ...) + """ if not self._launched: return tensor @@ -246,6 +247,7 @@ class XLAStrategy(ParallelStrategy): storage_options: Additional options for the ``CheckpointIO`` plugin filter: An optional dictionary of the same format as ``state`` mapping keys to callables that return a boolean indicating whether the given parameter should be saved (``True``) or filtered out (``False``). + """ import torch_xla.core.xla_model as xm diff --git a/src/lightning/fabric/strategies/xla_fsdp.py b/src/lightning/fabric/strategies/xla_fsdp.py index 81e269a794..e59c9843e0 100644 --- a/src/lightning/fabric/strategies/xla_fsdp.py +++ b/src/lightning/fabric/strategies/xla_fsdp.py @@ -188,6 +188,7 @@ class XLAFSDPStrategy(ParallelStrategy): This setup method doesn't modify the optimizer or wrap the optimizer. The only thing it currently does is verify that the optimizer was created after the model was wrapped with :meth:`setup_module` with a reference to the flattened parameters. + """ if _TORCH_GREATER_EQUAL_2_0: return optimizer @@ -210,12 +211,13 @@ class XLAFSDPStrategy(ParallelStrategy): ) def optimizer_step(self, optimizer: Optimizable, **kwargs: Any) -> Any: - """Overrides default tpu optimizer_step since FSDP should not call - `torch_xla.core.xla_model.optimizer_step`. Performs the actual optimizer step. + """Overrides default tpu optimizer_step since FSDP should not call `torch_xla.core.xla_model.optimizer_step`. + Performs the actual optimizer step. 
Args: optimizer: the optimizer performing the step **kwargs: Any extra arguments to ``optimizer.step`` + """ loss = optimizer.step(**kwargs) import torch_xla.core.xla_model as xm @@ -251,6 +253,7 @@ class XLAFSDPStrategy(ParallelStrategy): sync_grads: flag that allows users to synchronize gradients for the all-gather operation. Return: A tensor of shape (world_size, ...) + """ if not self._launched: return tensor @@ -342,6 +345,7 @@ class XLAFSDPStrategy(ParallelStrategy): If the user specifies sharded checkpointing, the directory will contain one file per process, with model- and optimizer shards stored per file. If the user specifies full checkpointing, the directory will contain a consolidated checkpoint combining all of the sharded checkpoints. + """ if not _TORCH_GREATER_EQUAL_2_0: raise NotImplementedError( @@ -421,6 +425,7 @@ class XLAFSDPStrategy(ParallelStrategy): The strategy currently only supports saving and loading sharded checkpoints which are stored in form of a directory of multiple files rather than a single file. + """ if not _TORCH_GREATER_EQUAL_2_0: raise NotImplementedError( diff --git a/src/lightning/fabric/utilities/apply_func.py b/src/lightning/fabric/utilities/apply_func.py index 1feedef96e..33231ccd19 100644 --- a/src/lightning/fabric/utilities/apply_func.py +++ b/src/lightning/fabric/utilities/apply_func.py @@ -56,6 +56,7 @@ class _TransferableDataType(ABC): ... return self >>> isinstance(CustomObject(), _TransferableDataType) True + """ @classmethod @@ -113,6 +114,7 @@ def convert_tensors_to_scalars(data: Any) -> Any: Raises: ValueError: If tensors inside ``metrics`` contains multiple elements, hence preventing conversion to a scalar. + """ def to_item(value: Tensor) -> Union[int, float, bool]: diff --git a/src/lightning/fabric/utilities/cloud_io.py b/src/lightning/fabric/utilities/cloud_io.py index 17d5d33c7e..4979e5db7c 100644 --- a/src/lightning/fabric/utilities/cloud_io.py +++ b/src/lightning/fabric/utilities/cloud_io.py @@ -34,6 +34,7 @@ def _load( Args: path_or_url: Path or URL of the checkpoint. map_location: a function, ``torch.device``, string or a dict specifying how to remap storage locations. + """ if not isinstance(path_or_url, (str, Path)): # any sort of BytesIO or similar @@ -65,6 +66,7 @@ def _atomic_save(checkpoint: Dict[str, Any], filepath: Union[str, Path]) -> None accepts. filepath: The path to which the checkpoint will be saved. This points to the file that the checkpoint will be stored in. + """ bytesbuffer = io.BytesIO() torch.save(checkpoint, bytesbuffer) @@ -107,6 +109,7 @@ def _is_dir(fs: AbstractFileSystem, path: Union[str, Path], strict: bool = False strict: A flag specific to Object Storage platforms. If set to ``False``, any non-existing path is considered as a valid directory-like path. In such cases, the directory (and any non-existing parent directories) will be created on the fly. Defaults to False. 
+ """ # Object storage fsspec's are inconsistent with other file systems because they do not have real directories, # see for instance https://gcsfs.readthedocs.io/en/latest/api.html?highlight=makedirs#gcsfs.core.GCSFileSystem.mkdir diff --git a/src/lightning/fabric/utilities/data.py b/src/lightning/fabric/utilities/data.py index 7b30e0944a..e6e53034b9 100644 --- a/src/lightning/fabric/utilities/data.py +++ b/src/lightning/fabric/utilities/data.py @@ -175,8 +175,8 @@ def _dataloader_init_kwargs_resolve_sampler( sampler: Union[Sampler, Iterable], disallow_batch_sampler: bool = False, ) -> Dict[str, Any]: - """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its - re-instantiation.""" + """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its re- + instantiation.""" batch_sampler = getattr(dataloader, "batch_sampler") if batch_sampler is not None: @@ -362,6 +362,7 @@ def _replace_dunder_methods(base_cls: Type, store_explicit_arg: Optional[str] = """This context manager is used to add support for re-instantiation of custom (subclasses) of `base_cls`. It patches the ``__init__``, ``__setattr__`` and ``__delattr__`` methods. + """ classes = get_all_subclasses(base_cls) | {base_cls} for cls in classes: @@ -399,6 +400,7 @@ def _replace_value_in_saved_args( """Tries to replace an argument value in a saved list of args and kwargs. Returns a tuple indicating success of the operation and modified saved args and kwargs + """ if replace_key in arg_names: diff --git a/src/lightning/fabric/utilities/device_dtype_mixin.py b/src/lightning/fabric/utilities/device_dtype_mixin.py index 40a171134e..cb5590c098 100644 --- a/src/lightning/fabric/utilities/device_dtype_mixin.py +++ b/src/lightning/fabric/utilities/device_dtype_mixin.py @@ -55,8 +55,8 @@ class _DeviceDtypeModuleMixin(Module): def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self: """Moves all model parameters and buffers to the GPU. This also makes associated parameters and buffers - different objects. So it should be called before constructing optimizer if the module will live on GPU - while being optimized. + different objects. So it should be called before constructing optimizer if the module will live on GPU while + being optimized. Arguments: device: If specified, all parameters will be copied to that device. If `None`, the current CUDA device @@ -64,6 +64,7 @@ class _DeviceDtypeModuleMixin(Module): Returns: Module: self + """ if device is None: device = torch.device("cuda", torch.cuda.current_device()) diff --git a/src/lightning/fabric/utilities/device_parser.py b/src/lightning/fabric/utilities/device_parser.py index 65e363cb06..2aa8872e87 100644 --- a/src/lightning/fabric/utilities/device_parser.py +++ b/src/lightning/fabric/utilities/device_parser.py @@ -113,8 +113,8 @@ def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[in def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps: bool = False) -> List[int]: - """Checks that each of the GPUs in the list is actually available. Raises a MisconfigurationException if any of - the GPUs is not available. + """Checks that each of the GPUs in the list is actually available. Raises a MisconfigurationException if any of the + GPUs is not available. 
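The ``cuda()`` docstring above stresses that a module should be moved before its optimizer is constructed, so the optimizer holds references to the GPU copies of the parameters. A small illustration of the recommended ordering::

    import torch

    model = torch.nn.Linear(8, 8)

    # move the module first (falls back to CPU when CUDA is unavailable) ...
    if torch.cuda.is_available():
        model = model.cuda()

    # ... then build the optimizer over the (possibly moved) parameters
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)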
Args: gpus: List of ints corresponding to GPU indices @@ -125,6 +125,7 @@ def _sanitize_gpu_ids(gpus: List[int], include_cuda: bool = False, include_mps: Raises: MisconfigurationException: If machine has fewer available GPUs than requested. + """ if sum((include_cuda, include_mps)) == 0: raise ValueError("At least one gpu type should be specified!") @@ -172,6 +173,7 @@ def _check_unique(device_ids: List[int]) -> None: Raises: MisconfigurationException: If ``device_ids`` of GPUs aren't unique + """ if len(device_ids) != len(set(device_ids)): raise MisconfigurationException("Device ID's (GPU) must be unique.") @@ -186,6 +188,7 @@ def _check_data_type(device_ids: object) -> None: Raises: TypeError: If ``device_ids`` of GPU/TPUs aren't ``int``, ``str`` or sequence of ``int``` + """ msg = "Device IDs (GPU/TPU) must be an int, a string, a sequence of ints, but you passed" if device_ids is None: diff --git a/src/lightning/fabric/utilities/distributed.py b/src/lightning/fabric/utilities/distributed.py index 5361566451..c7f52161c4 100644 --- a/src/lightning/fabric/utilities/distributed.py +++ b/src/lightning/fabric/utilities/distributed.py @@ -39,6 +39,7 @@ def _gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Ten Return: gathered_result: List with size equal to the process group where gathered_result[i] corresponds to result tensor from process i + """ if group is None: group = torch.distributed.group.WORLD @@ -98,6 +99,7 @@ def _sync_ddp_if_available( Return: reduced value + """ if torch.distributed.is_initialized(): return _sync_ddp(result, group=group, reduce_op=reduce_op) @@ -115,6 +117,7 @@ def _sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[U Return: reduced value + """ divide_by_world_size = False @@ -197,6 +200,7 @@ def _all_gather_ddp_if_available( Return: A tensor of shape (world_size, batch, ...) + """ if not torch.distributed.is_initialized(): return tensor @@ -213,8 +217,8 @@ def _init_dist_connection( world_size: Optional[int] = None, **kwargs: Any, ) -> None: - """Utility function to initialize distributed connection by setting env variables and initializing the - distributed process group. + """Utility function to initialize distributed connection by setting env variables and initializing the distributed + process group. Args: cluster_environment: ``ClusterEnvironment`` instance @@ -226,6 +230,7 @@ def _init_dist_connection( Raises: RuntimeError: If ``torch.distributed`` is not available + """ if not torch.distributed.is_available(): raise RuntimeError("torch.distributed is not available. 
Cannot initialize distributed process group") diff --git a/src/lightning/fabric/utilities/init.py b/src/lightning/fabric/utilities/init.py index 52cf496cf0..2031c74c1d 100644 --- a/src/lightning/fabric/utilities/init.py +++ b/src/lightning/fabric/utilities/init.py @@ -30,6 +30,7 @@ class _EmptyInit(TorchFunctionMode): with _EmptyInit(): model = BigModel() model.load_state_dict(torch.load("checkpoint.pt")) + """ def __init__(self, enabled: bool = True) -> None: diff --git a/src/lightning/fabric/utilities/logger.py b/src/lightning/fabric/utilities/logger.py index db726ede05..c3874262ca 100644 --- a/src/lightning/fabric/utilities/logger.py +++ b/src/lightning/fabric/utilities/logger.py @@ -27,6 +27,7 @@ def _convert_params(params: Optional[Union[Dict[str, Any], Namespace]]) -> Dict[ Returns: params as a dictionary + """ # in case converting from namespace if isinstance(params, Namespace): @@ -46,6 +47,7 @@ def _sanitize_callable_params(params: Dict[str, Any]) -> Dict[str, Any]: Returns: dictionary with all callables sanitized + """ def _sanitize_callable(val: Any) -> Any: @@ -81,6 +83,7 @@ def _flatten_dict(params: MutableMapping[Any, Any], delimiter: str = "/", parent {'a/b': 123} >>> _flatten_dict({5: {'a': 123}}) {'5/a': 123} + """ result: Dict[str, Any] = {} for k, v in params.items(): @@ -114,6 +117,7 @@ def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]: 'list': '[1, 2, 3]', 'namespace': 'Namespace(foo=3)', 'string': 'abc'} + """ for k in params: # convert relevant np scalars to python types first (instead of str) @@ -136,6 +140,7 @@ def _add_prefix( Returns: Dictionary with prefix and separator inserted before each key + """ if not prefix: return metrics diff --git a/src/lightning/fabric/utilities/registry.py b/src/lightning/fabric/utilities/registry.py index 609101fbb9..4c1a1f7bb0 100644 --- a/src/lightning/fabric/utilities/registry.py +++ b/src/lightning/fabric/utilities/registry.py @@ -42,6 +42,7 @@ def _load_external_callbacks(group: str) -> List[Any]: Return: A list of all callbacks collected from external factories. + """ if _PYTHON_GREATER_EQUAL_3_8_0: from importlib.metadata import entry_points diff --git a/src/lightning/fabric/utilities/seed.py b/src/lightning/fabric/utilities/seed.py index c3c6852a76..425db5ec35 100644 --- a/src/lightning/fabric/utilities/seed.py +++ b/src/lightning/fabric/utilities/seed.py @@ -17,8 +17,8 @@ min_seed_value = np.iinfo(np.uint32).min def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int: - """Function that sets seed for pseudo-random number generators in: pytorch, numpy, python.random In addition, - sets the following environment variables: + """Function that sets seed for pseudo-random number generators in: pytorch, numpy, python.random In addition, sets + the following environment variables: - `PL_GLOBAL_SEED`: will be passed to spawned subprocesses (e.g. ddp_spawn backend). - `PL_SEED_WORKERS`: (optional) is set to 1 if ``workers=True``. @@ -31,6 +31,7 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int: Trainer with a ``worker_init_fn``. If the user already provides such a function for their dataloaders, setting this argument will have no influence. See also: :func:`~lightning.fabric.utilities.seed.pl_worker_init_function`. + """ if seed is None: env_seed = os.environ.get("PL_GLOBAL_SEED") @@ -70,6 +71,7 @@ def reset_seed() -> None: """Reset the seed to the value that :func:`lightning.fabric.utilities.seed.seed_everything` previously set. 
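``_flatten_dict`` above collapses nested hyperparameter dictionaries into delimiter-joined keys before they are handed to a logger. A standalone sketch that reproduces the behaviour shown in its doctests (not the library implementation)::

    from collections.abc import MutableMapping
    from typing import Any, Dict

    def flatten(params: MutableMapping, delimiter: str = "/", parent_key: str = "") -> Dict[str, Any]:
        out: Dict[str, Any] = {}
        for k, v in params.items():
            key = f"{parent_key}{delimiter}{k}" if parent_key else str(k)
            if isinstance(v, MutableMapping):
                # recurse into nested mappings and join the keys with the delimiter
                out.update(flatten(v, delimiter, key))
            else:
                out[key] = v
        return out

    assert flatten({"a": {"b": 123}}) == {"a/b": 123}
    assert flatten({5: {"a": 123}}) == {"5/a": 123}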
If :func:`lightning.fabric.utilities.seed.seed_everything` is unused, this function will do nothing. + """ seed = os.environ.get("PL_GLOBAL_SEED", None) if seed is None: diff --git a/src/lightning/fabric/utilities/spike.py b/src/lightning/fabric/utilities/spike.py index 3a118eb56d..0d840d5ec1 100644 --- a/src/lightning/fabric/utilities/spike.py +++ b/src/lightning/fabric/utilities/spike.py @@ -35,6 +35,7 @@ class SpikeDetection: exclude_batches_path: Where to save the file that contains the batches to exclude. Will default to current directory. finite_only: If set to ``False``, consider non-finite values like NaN, inf and -inf a spike as well. + """ def __init__( diff --git a/src/lightning/fabric/utilities/testing/_runif.py b/src/lightning/fabric/utilities/testing/_runif.py index de940810a5..906e9019fb 100644 --- a/src/lightning/fabric/utilities/testing/_runif.py +++ b/src/lightning/fabric/utilities/testing/_runif.py @@ -57,6 +57,7 @@ def _runif_reasons( This requires that the ``PL_RUN_STANDALONE_TESTS=1`` environment variable is set. deepspeed: Require that microsoft/DeepSpeed is installed. dynamo: Require that `torch.dynamo` is supported. + """ reasons = [] kwargs = {} # used in conftest.py::pytest_collection_modifyitems diff --git a/src/lightning/fabric/wrappers.py b/src/lightning/fabric/wrappers.py index 861acff348..28732f9264 100644 --- a/src/lightning/fabric/wrappers.py +++ b/src/lightning/fabric/wrappers.py @@ -39,14 +39,15 @@ _LIGHTNING_MODULE_STEP_METHODS = ("training_step", "validation_step", "test_step class _FabricOptimizer: def __init__(self, optimizer: Optimizer, strategy: Strategy, callbacks: Optional[List[Callable]] = None) -> None: - """FabricOptimizer is a thin wrapper around the :class:`~torch.optim.Optimizer` that delegates the - optimizer step calls to the strategy. + """FabricOptimizer is a thin wrapper around the :class:`~torch.optim.Optimizer` that delegates the optimizer + step calls to the strategy. The underlying wrapped optimizer object can be accessed via the property :attr:`optimizer`. Args: optimizer: The optimizer to wrap strategy: Reference to the strategy for handling the optimizer step + """ # `__del__` is skipped in case the optimizer has implemented custom destructor logic which we would # not want to call on destruction of the `_FabricOptimizer @@ -96,6 +97,7 @@ class _FabricModule(_DeviceDtypeModuleMixin): original_module: The original, unmodified module as passed into the :meth:`lightning.fabric.fabric.Fabric.setup` method. This is needed when attribute lookup on this wrapper should pass through to the original module. + """ super().__init__() self._forward_module = forward_module @@ -108,8 +110,7 @@ class _FabricModule(_DeviceDtypeModuleMixin): return self._original_module or self._forward_module def forward(self, *args: Any, **kwargs: Any) -> Any: - """Casts all inputs to the right precision and handles autocast for operations in the module forward - method.""" + """Casts all inputs to the right precision and handles autocast for operations in the module forward method.""" args, kwargs = self._precision.convert_input((args, kwargs)) with self._precision.forward_context(): @@ -218,13 +219,14 @@ class _FabricModule(_DeviceDtypeModuleMixin): class _FabricDataLoader: def __init__(self, dataloader: DataLoader, device: Optional[torch.device] = None) -> None: - """The FabricDataLoader is a wrapper for the :class:`~torch.utils.data.DataLoader`. It moves the data to - the device automatically if the device is specified. 
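Putting ``seed_everything`` and ``reset_seed`` together, the intended workflow looks roughly like this, with the import path taken from the module shown above::

    from lightning.fabric.utilities.seed import reset_seed, seed_everything

    seed_everything(42, workers=True)  # seeds python.random, numpy and torch; exports PL_GLOBAL_SEED / PL_SEED_WORKERS

    # ... code that advances or perturbs the RNG state ...

    reset_seed()  # re-applies the seed recorded in PL_GLOBAL_SEED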
+ """The FabricDataLoader is a wrapper for the :class:`~torch.utils.data.DataLoader`. It moves the data to the + device automatically if the device is specified. Args: dataloader: The dataloader to wrap device: The device to which the data should be moved. By default the device is `None` and no data transfers will be made (identical behavior as :class:`~torch.utils.data.DataLoader`). + """ self.__dict__.update(dataloader.__dict__) self._dataloader = dataloader @@ -277,6 +279,7 @@ def _unwrap_compiled(obj: Any) -> Any: """Removes the :class:`torch._dynamo.OptimizedModule` around the object if it is wrapped. Use this function before instance checks against e.g. :class:`_FabricModule`. + """ if not _TORCH_GREATER_EQUAL_2_0: return obj @@ -296,6 +299,7 @@ def is_wrapped(obj: object) -> bool: Args: obj: The object to test. + """ obj = _unwrap_compiled(obj) return isinstance(obj, (_FabricModule, _FabricOptimizer, _FabricDataLoader)) diff --git a/src/lightning/pytorch/_graveyard/tpu.py b/src/lightning/pytorch/_graveyard/tpu.py index 602dc58534..dde1729735 100644 --- a/src/lightning/pytorch/_graveyard/tpu.py +++ b/src/lightning/pytorch/_graveyard/tpu.py @@ -35,6 +35,7 @@ class SingleTPUStrategy(SingleDeviceXLAStrategy): """Legacy class. Use :class:`~lightning.pytorch.strategies.single_xla.SingleDeviceXLAStrategy` instead. + """ def __init__(self, *args: Any, **kwargs: Any) -> None: @@ -51,6 +52,7 @@ class TPUAccelerator(XLAAccelerator): """Legacy class. Use :class:`~lightning.pytorch.accelerators.xla.XLAAccelerator` instead. + """ def __init__(self, *args: Any, **kwargs: Any) -> None: @@ -64,6 +66,7 @@ class TPUPrecisionPlugin(XLAPrecisionPlugin): """Legacy class. Use :class:`~lightning.pytorch.plugins.precision.xla.XLAPrecisionPlugin` instead. + """ def __init__(self, *args: Any, **kwargs: Any) -> None: @@ -78,6 +81,7 @@ class TPUBf16PrecisionPlugin(XLABf16PrecisionPlugin): """Legacy class. Use :class:`~lightning.pytorch.plugins.precision.xlabf16.XLABf16PrecisionPlugin` instead. + """ def __init__(self, *args: Any, **kwargs: Any) -> None: diff --git a/src/lightning/pytorch/accelerators/accelerator.py b/src/lightning/pytorch/accelerators/accelerator.py index 3f78b1f667..0490c2d864 100644 --- a/src/lightning/pytorch/accelerators/accelerator.py +++ b/src/lightning/pytorch/accelerators/accelerator.py @@ -23,6 +23,7 @@ class Accelerator(_Accelerator, ABC): """The Accelerator base class for Lightning PyTorch. .. warning:: Writing your own accelerator is an :ref:`experimental ` feature. 
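``is_wrapped`` above is the supported way to check whether an object has already been set up by Fabric. A quick sketch, assuming the import path of the module in this hunk::

    import torch

    import lightning as L
    from lightning.fabric.wrappers import is_wrapped

    fabric = L.Fabric(accelerator="cpu", devices=1)
    fabric.launch()

    model = torch.nn.Linear(2, 2)
    assert not is_wrapped(model)

    model = fabric.setup(model)  # returns a _FabricModule wrapper
    assert is_wrapped(model)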
+ """ def setup(self, trainer: "pl.Trainer") -> None: @@ -30,6 +31,7 @@ class Accelerator(_Accelerator, ABC): Args: trainer: the trainer instance + """ def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]: @@ -40,5 +42,6 @@ class Accelerator(_Accelerator, ABC): Returns: Dictionary of device stats + """ raise NotImplementedError diff --git a/src/lightning/pytorch/accelerators/cuda.py b/src/lightning/pytorch/accelerators/cuda.py index 9161bc92f8..b1d621ad07 100644 --- a/src/lightning/pytorch/accelerators/cuda.py +++ b/src/lightning/pytorch/accelerators/cuda.py @@ -69,6 +69,7 @@ class CUDAAccelerator(Accelerator): Raises: FileNotFoundError: If nvidia-smi installation not found + """ return torch.cuda.memory_stats(device) @@ -115,6 +116,7 @@ def get_nvidia_gpu_stats(device: _DEVICE) -> Dict[str, float]: # pragma: no-cov Raises: FileNotFoundError: If nvidia-smi installation not found + """ nvidia_smi_path = shutil.which("nvidia-smi") if nvidia_smi_path is None: diff --git a/src/lightning/pytorch/accelerators/mps.py b/src/lightning/pytorch/accelerators/mps.py index 03ba218604..f25ed82f16 100644 --- a/src/lightning/pytorch/accelerators/mps.py +++ b/src/lightning/pytorch/accelerators/mps.py @@ -28,6 +28,7 @@ class MPSAccelerator(Accelerator): """Accelerator for Metal Apple Silicon GPU devices. .. warning:: Use of this accelerator beyond import and instantiation is experimental. + """ def setup_device(self, device: torch.device) -> None: diff --git a/src/lightning/pytorch/accelerators/xla.py b/src/lightning/pytorch/accelerators/xla.py index fe9c1261c9..e1ef449e79 100644 --- a/src/lightning/pytorch/accelerators/xla.py +++ b/src/lightning/pytorch/accelerators/xla.py @@ -23,6 +23,7 @@ class XLAAccelerator(Accelerator, FabricXLAAccelerator): """Accelerator for XLA devices, normally TPUs. .. warning:: Use of this accelerator beyond import and instantiation is experimental. + """ def get_device_stats(self, device: _DEVICE) -> Dict[str, Any]: @@ -33,6 +34,7 @@ class XLAAccelerator(Accelerator, FabricXLAAccelerator): Returns: A dictionary mapping the metrics (free memory and peak memory) to their values. + """ import torch_xla.core.xla_model as xm diff --git a/src/lightning/pytorch/callbacks/callback.py b/src/lightning/pytorch/callbacks/callback.py index 197433dc64..447b7dda94 100644 --- a/src/lightning/pytorch/callbacks/callback.py +++ b/src/lightning/pytorch/callbacks/callback.py @@ -26,6 +26,7 @@ class Callback: r"""Abstract base class used to build new callbacks. Subclass this class and override any of the relevant hooks + """ @property @@ -35,6 +36,7 @@ class Callback: Used to store and retrieve a callback's state from the checkpoint dictionary by ``checkpoint["callbacks"][state_key]``. Implementations of a callback need to provide a unique state key if 1) the callback has state and 2) it is desired to maintain the state of multiple instances of that callback. + """ return self.__class__.__qualname__ @@ -44,11 +46,12 @@ class Callback: return type(self) def _generate_state_key(self, **kwargs: Any) -> str: - """Formats a set of key-value pairs into a state key string with the callback class name prefixed. Useful - for defining a :attr:`state_key`. + """Formats a set of key-value pairs into a state key string with the callback class name prefixed. Useful for + defining a :attr:`state_key`. Args: **kwargs: A set of key-value pairs. Must be serializable to :class:`str`. 
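``_generate_state_key`` described above is intended for callbacks that keep state and may be attached more than once. A hedged sketch; ``MyCounter`` is a hypothetical callback, not part of the library::

    import lightning.pytorch as pl

    class MyCounter(pl.Callback):  # hypothetical stateful callback
        def __init__(self, what: str) -> None:
            self.what = what
            self.count = 0

        @property
        def state_key(self) -> str:
            # yields e.g. "MyCounter{'what': 'batches'}", so two instances
            # tracking different things do not clobber each other's checkpoint state
            return self._generate_state_key(what=self.what)

        def state_dict(self) -> dict:
            return {"count": self.count}

        def load_state_dict(self, state_dict: dict) -> None:
            self.count = state_dict["count"]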
+ """ return f"{self.__class__.__qualname__}{repr(kwargs)}" @@ -83,6 +86,7 @@ class Callback: Note: The value ``outputs["loss"]`` here will be the normalized value w.r.t ``accumulate_grad_batches`` of the loss returned from ``training_step``. + """ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: @@ -114,6 +118,7 @@ class Callback: pl_module.log("training_epoch_mean", epoch_mean) # free up the memory pl_module.training_step_outputs.clear() + """ def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: @@ -249,6 +254,7 @@ class Callback: trainer: the current :class:`~lightning.pytorch.trainer.Trainer` instance. pl_module: the current :class:`~lightning.pytorch.core.module.LightningModule` instance. checkpoint: the checkpoint dictionary that will be saved. + """ def on_load_checkpoint( @@ -260,6 +266,7 @@ class Callback: trainer: the current :class:`~lightning.pytorch.trainer.Trainer` instance. pl_module: the current :class:`~lightning.pytorch.core.module.LightningModule` instance. checkpoint: the full checkpoint dictionary that got loaded by the Trainer. + """ def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", loss: Tensor) -> None: diff --git a/src/lightning/pytorch/callbacks/checkpoint.py b/src/lightning/pytorch/callbacks/checkpoint.py index 301761049b..3f241278b7 100644 --- a/src/lightning/pytorch/callbacks/checkpoint.py +++ b/src/lightning/pytorch/callbacks/checkpoint.py @@ -6,4 +6,5 @@ class Checkpoint(Callback): Expert users may want to subclass it in case of writing custom :class:`~lightning.pytorch.callbacksCheckpoint` callback, so that the trainer recognizes the custom class as a checkpointing callback. + """ diff --git a/src/lightning/pytorch/callbacks/early_stopping.py b/src/lightning/pytorch/callbacks/early_stopping.py index d6996e4081..23c5a2a533 100644 --- a/src/lightning/pytorch/callbacks/early_stopping.py +++ b/src/lightning/pytorch/callbacks/early_stopping.py @@ -14,6 +14,7 @@ r"""Early Stopping ^^^^^^^^^^^^^^ Monitor a metric and stop training when it stops improving. + """ import logging from typing import Any, Callable, Dict, Optional, Tuple @@ -80,6 +81,7 @@ class EarlyStopping(Callback): *monitor, mode* Read more: :ref:`Persisting Callback State ` + """ mode_dict = {"min": torch.lt, "max": torch.gt} diff --git a/src/lightning/pytorch/callbacks/finetuning.py b/src/lightning/pytorch/callbacks/finetuning.py index 9d19cc8f86..26386551bd 100644 --- a/src/lightning/pytorch/callbacks/finetuning.py +++ b/src/lightning/pytorch/callbacks/finetuning.py @@ -74,6 +74,7 @@ class BaseFinetuning(Callback): ... optimizer=optimizer, ... train_bn=True, ... ) + """ def __init__(self) -> None: @@ -106,14 +107,15 @@ class BaseFinetuning(Callback): @staticmethod def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> List[Module]: - """This function is used to flatten a module or an iterable of modules into a list of its leaf modules - (modules with no children) and parent modules that have parameters directly themselves. + """This function is used to flatten a module or an iterable of modules into a list of its leaf modules (modules + with no children) and parent modules that have parameters directly themselves. 
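For the ``EarlyStopping`` callback touched above, typical usage is just attaching a configured instance to the Trainer::

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import EarlyStopping

    # stop when ``val_loss`` has not improved for 3 consecutive validation runs
    early_stop = EarlyStopping(monitor="val_loss", mode="min", patience=3)
    trainer = Trainer(callbacks=[early_stop])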
Args: modules: A given module or an iterable of modules Returns: List of modules + """ if isinstance(modules, ModuleDict): modules = modules.values() @@ -142,6 +144,7 @@ class BaseFinetuning(Callback): requires_grad: Whether to create a generator for trainable or non-trainable parameters. Returns: Generator + """ modules = BaseFinetuning.flatten_modules(modules) for mod in modules: @@ -158,6 +161,7 @@ class BaseFinetuning(Callback): Args: modules: A given module or an iterable of modules + """ modules = BaseFinetuning.flatten_modules(modules) for module in modules: @@ -173,6 +177,7 @@ class BaseFinetuning(Callback): Args: module: A given module + """ if isinstance(module, _BatchNorm): module.track_running_stats = False @@ -190,6 +195,7 @@ class BaseFinetuning(Callback): Returns: None + """ modules = BaseFinetuning.flatten_modules(modules) for mod in modules: @@ -208,6 +214,7 @@ class BaseFinetuning(Callback): Returns: List of parameters not contained in this optimizer param groups + """ out_params = [] removed_params = [] @@ -245,6 +252,7 @@ class BaseFinetuning(Callback): initial_denom_lr: If no lr is provided, the learning from the first param group will be used and divided by `initial_denom_lr`. train_bn: Whether to train the BatchNormalization layers. + """ BaseFinetuning.make_trainable(modules) params_lr = optimizer.param_groups[0]["lr"] if lr is None else float(lr) @@ -338,6 +346,7 @@ class BackboneFinetuning(BaseFinetuning): >>> multiplicative = lambda epoch: 1.5 >>> backbone_finetuning = BackboneFinetuning(200, multiplicative) >>> trainer = Trainer(callbacks=[backbone_finetuning]) + """ def __init__( diff --git a/src/lightning/pytorch/callbacks/gradient_accumulation_scheduler.py b/src/lightning/pytorch/callbacks/gradient_accumulation_scheduler.py index 9c0b1a741f..1a18454b55 100644 --- a/src/lightning/pytorch/callbacks/gradient_accumulation_scheduler.py +++ b/src/lightning/pytorch/callbacks/gradient_accumulation_scheduler.py @@ -60,6 +60,7 @@ class GradientAccumulationScheduler(Callback): # because epoch (key) should be zero-indexed. >>> accumulator = GradientAccumulationScheduler(scheduling={4: 2}) >>> trainer = Trainer(callbacks=[accumulator]) + """ def __init__(self, scheduling: Dict[int, int]): @@ -99,8 +100,7 @@ class GradientAccumulationScheduler(Callback): return accumulate_grad_batches def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - """Performns a configuration validation before training starts and raises errors for incompatible - settings.""" + """Performns a configuration validation before training starts and raises errors for incompatible settings.""" if not pl_module.automatic_optimization: raise RuntimeError( diff --git a/src/lightning/pytorch/callbacks/lambda_function.py b/src/lightning/pytorch/callbacks/lambda_function.py index e062656313..45d7764a1b 100644 --- a/src/lightning/pytorch/callbacks/lambda_function.py +++ b/src/lightning/pytorch/callbacks/lambda_function.py @@ -14,6 +14,7 @@ r"""Lambda Callback ^^^^^^^^^^^^^^^ Create a simple callback on the fly using lambda functions. 
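The ``freeze`` and ``unfreeze_and_add_param_group`` helpers documented above can also be exercised directly. A small sketch, assuming the signatures shown in this hunk::

    import torch

    from lightning.pytorch.callbacks import BaseFinetuning

    backbone = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.BatchNorm1d(4))
    head = torch.nn.Linear(4, 2)

    # freeze the backbone, including its BatchNorm layers
    BaseFinetuning.freeze(backbone, train_bn=False)
    assert not any(p.requires_grad for p in backbone.parameters())

    # later: make it trainable again and register it as a new param group whose
    # lr defaults to the existing group's lr divided by ``initial_denom_lr``
    optimizer = torch.optim.SGD(head.parameters(), lr=0.1)
    BaseFinetuning.unfreeze_and_add_param_group(backbone, optimizer, initial_denom_lr=10.0)
    assert len(optimizer.param_groups) == 2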
+ """ from typing import Callable, Optional @@ -32,6 +33,7 @@ class LambdaCallback(Callback): >>> from lightning.pytorch import Trainer >>> from lightning.pytorch.callbacks import LambdaCallback >>> trainer = Trainer(callbacks=[LambdaCallback(setup=lambda *args: print('setup'))]) + """ def __init__( diff --git a/src/lightning/pytorch/callbacks/lr_monitor.py b/src/lightning/pytorch/callbacks/lr_monitor.py index d938db61d6..d823cf5d52 100644 --- a/src/lightning/pytorch/callbacks/lr_monitor.py +++ b/src/lightning/pytorch/callbacks/lr_monitor.py @@ -84,6 +84,7 @@ class LearningRateMonitor(Callback): ) lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, ...) return [optimizer], [lr_scheduler] + """ def __init__(self, logging_interval: Optional[str] = None, log_momentum: bool = False) -> None: @@ -95,12 +96,13 @@ class LearningRateMonitor(Callback): self.lrs: Dict[str, List[float]] = {} def on_train_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: - """Called before training, determines unique names for all lr schedulers in the case of multiple of the - same type or in the case of multiple parameter groups. + """Called before training, determines unique names for all lr schedulers in the case of multiple of the same + type or in the case of multiple parameter groups. Raises: MisconfigurationException: If ``Trainer`` has no ``logger``. + """ if not trainer.loggers: raise MisconfigurationException( diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 8a75e9e137..08518a863c 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -569,6 +569,7 @@ class ModelCheckpoint(Checkpoint): >>> ckpt = ModelCheckpoint(filename='{step}') >>> os.path.basename(ckpt.format_checkpoint_name(dict(step=0))) 'step=0.ckpt' + """ filename = filename or self.filename filename = self._format_checkpoint_name(filename, metrics, auto_insert_metric_name=self.auto_insert_metric_name) @@ -588,6 +589,7 @@ class ModelCheckpoint(Checkpoint): 3. The ``Trainer``'s ``default_root_dir`` if the trainer has no loggers The path gets extended with subdirectory "checkpoints". 
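``LearningRateMonitor`` above only needs to be attached to the Trainer; the logging interval and momentum tracking are chosen at construction::

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import LearningRateMonitor

    lr_monitor = LearningRateMonitor(logging_interval="epoch", log_momentum=True)
    trainer = Trainer(callbacks=[lr_monitor])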
+ """ if self.dirpath is not None: # short circuit if dirpath was passed to ModelCheckpoint diff --git a/src/lightning/pytorch/callbacks/model_summary.py b/src/lightning/pytorch/callbacks/model_summary.py index 4fc788ba09..870d5a73a2 100644 --- a/src/lightning/pytorch/callbacks/model_summary.py +++ b/src/lightning/pytorch/callbacks/model_summary.py @@ -47,6 +47,7 @@ class ModelSummary(Callback): >>> from lightning.pytorch import Trainer >>> from lightning.pytorch.callbacks import ModelSummary >>> trainer = Trainer(callbacks=[ModelSummary(max_depth=1)]) + """ def __init__(self, max_depth: int = 1, **summarize_kwargs: Any) -> None: diff --git a/src/lightning/pytorch/callbacks/on_exception_checkpoint.py b/src/lightning/pytorch/callbacks/on_exception_checkpoint.py index 760e774a25..0b5a953cd7 100644 --- a/src/lightning/pytorch/callbacks/on_exception_checkpoint.py +++ b/src/lightning/pytorch/callbacks/on_exception_checkpoint.py @@ -41,6 +41,7 @@ class OnExceptionCheckpoint(Checkpoint): >>> from lightning.pytorch import Trainer >>> from lightning.pytorch.callbacks import OnExceptionCheckpoint >>> trainer = Trainer(callbacks=[OnExceptionCheckpoint(".")]) + """ FILE_EXTENSION = ".ckpt" diff --git a/src/lightning/pytorch/callbacks/prediction_writer.py b/src/lightning/pytorch/callbacks/prediction_writer.py index 0f19c77102..74ee0b85a7 100644 --- a/src/lightning/pytorch/callbacks/prediction_writer.py +++ b/src/lightning/pytorch/callbacks/prediction_writer.py @@ -99,6 +99,7 @@ class BasePredictionWriter(Callback): trainer = Trainer(accelerator="gpu", strategy="ddp", devices=8, callbacks=[pred_writer]) model = BoringModel() trainer.predict(model, return_predictions=False) + """ def __init__(self, write_interval: Literal["batch", "epoch", "batch_and_epoch"] = "batch") -> None: diff --git a/src/lightning/pytorch/callbacks/progress/progress_bar.py b/src/lightning/pytorch/callbacks/progress/progress_bar.py index 7a2b57be17..b20dda1cd1 100644 --- a/src/lightning/pytorch/callbacks/progress/progress_bar.py +++ b/src/lightning/pytorch/callbacks/progress/progress_bar.py @@ -19,9 +19,9 @@ from lightning.pytorch.utilities.rank_zero import rank_zero_warn class ProgressBar(Callback): - r"""The base class for progress bars in Lightning. It is a :class:`~lightning.pytorch.callbacks.Callback` that - keeps track of the batch progress in the :class:`~lightning.pytorch.trainer.trainer.Trainer`. You should - implement your highly custom progress bars with this as the base class. + r"""The base class for progress bars in Lightning. It is a :class:`~lightning.pytorch.callbacks.Callback` that keeps + track of the batch progress in the :class:`~lightning.pytorch.trainer.trainer.Trainer`. You should implement your + highly custom progress bars with this as the base class. Example:: @@ -42,6 +42,7 @@ class ProgressBar(Callback): bar = LitProgressBar() trainer = Trainer(callbacks=[bar]) + """ def __init__(self) -> None: @@ -80,6 +81,7 @@ class ProgressBar(Callback): Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the training dataloader is of infinite size. + """ return self.trainer.num_training_batches @@ -89,6 +91,7 @@ class ProgressBar(Callback): Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the validation dataloader is of infinite size. 
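A ``BasePredictionWriter`` subclass picks a ``write_interval`` and implements the matching hook. A hedged sketch assuming the batch-level hook signature of this release; ``SimplePredictionWriter`` is hypothetical::

    import os

    import torch
    from lightning.pytorch.callbacks import BasePredictionWriter

    class SimplePredictionWriter(BasePredictionWriter):  # hypothetical writer
        def __init__(self, output_dir: str) -> None:
            super().__init__(write_interval="batch")
            self.output_dir = output_dir

        def write_on_batch_end(
            self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx
        ) -> None:
            # one file per predicted batch, written by every rank
            os.makedirs(self.output_dir, exist_ok=True)
            torch.save(prediction, os.path.join(self.output_dir, f"pred_{dataloader_idx}_{batch_idx}.pt"))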
+ """ batches = self.trainer.num_sanity_val_batches if self.trainer.sanity_checking else self.trainer.num_val_batches if isinstance(batches, list): @@ -102,6 +105,7 @@ class ProgressBar(Callback): Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the test dataloader is of infinite size. + """ batches = self.trainer.num_test_batches if isinstance(batches, list): @@ -115,6 +119,7 @@ class ProgressBar(Callback): Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the predict dataloader is of infinite size. + """ assert self._current_eval_dataloader_idx is not None return self.trainer.num_predict_batches[self._current_eval_dataloader_idx] @@ -125,6 +130,7 @@ class ProgressBar(Callback): Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the predict dataloader is of infinite size. + """ if not self.trainer.fit_loop.epoch_loop._should_check_val_epoch(): return 0 @@ -152,6 +158,7 @@ class ProgressBar(Callback): The :class:`~lightning.pytorch.trainer.trainer.Trainer` will call this in e.g. pre-training routines like the :ref:`learning rate finder `. to temporarily enable and disable the training progress bar. + """ raise NotImplementedError @@ -167,8 +174,8 @@ class ProgressBar(Callback): def get_metrics( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" ) -> Dict[str, Union[int, str, float, Dict[str, float]]]: - r"""Combines progress bar metrics collected from the trainer with standard metrics from - get_standard_metrics. Implement this to override the items displayed in the progress bar. + r"""Combines progress bar metrics collected from the trainer with standard metrics from get_standard_metrics. + Implement this to override the items displayed in the progress bar. Here is an example of how to override the defaults: @@ -182,6 +189,7 @@ class ProgressBar(Callback): Return: Dictionary with the items to be displayed in the progress bar. + """ standard_metrics = get_standard_metrics(trainer) pbar_metrics = trainer.progress_bar_metrics @@ -206,6 +214,7 @@ def get_standard_metrics(trainer: "pl.Trainer") -> Dict[str, Union[int, str]]: Return: Dictionary with the standard metrics to be displayed in the progress bar. + """ items_dict: Dict[str, Union[int, str]] = {} if trainer.loggers: diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index b6934678c9..48aee9673e 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -191,6 +191,7 @@ class RichProgressBarTheme: metrics: Style for the metrics https://rich.readthedocs.io/en/stable/style.html + """ description: Union[str, Style] = "white" @@ -234,6 +235,7 @@ class RichProgressBar(ProgressBar): PyCharm users will need to enable “emulate terminal” in output console option in run/debug configuration to see styled output. Reference: https://rich.readthedocs.io/en/latest/introduction.html#requirements + """ def __init__( diff --git a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py index d938b92571..fc36f81e82 100644 --- a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py +++ b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py @@ -407,6 +407,7 @@ def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]: """The tqdm doesn't support inf/nan values. 
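Overriding ``get_metrics`` as described above is the usual way to trim what the progress bar displays, for example hiding the logger version number. A sketch based on the built-in TQDM bar::

    from lightning.pytorch.callbacks import TQDMProgressBar

    class CleanProgressBar(TQDMProgressBar):  # hypothetical subclass
        def get_metrics(self, trainer, pl_module):
            items = super().get_metrics(trainer, pl_module)
            items.pop("v_num", None)  # drop the logger version number from the bar
            return items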
We have to convert it to None. + """ if x is None or math.isinf(x) or math.isnan(x): return None diff --git a/src/lightning/pytorch/callbacks/pruning.py b/src/lightning/pytorch/callbacks/pruning.py index 5cffe61717..83e430fa1a 100644 --- a/src/lightning/pytorch/callbacks/pruning.py +++ b/src/lightning/pytorch/callbacks/pruning.py @@ -74,8 +74,8 @@ class ModelPruning(Callback): verbose: int = 0, prune_on_train_epoch_end: bool = True, ) -> None: - """Model pruning Callback, using PyTorch's prune utilities. This callback is responsible of pruning - networks parameters during training. + """Model pruning Callback, using PyTorch's prune utilities. This callback is responsible of pruning networks + parameters during training. To learn more about pruning with PyTorch, please take a look at `this tutorial `_. @@ -152,6 +152,7 @@ class ModelPruning(Callback): if ``pruning_norm`` is not provided when ``"ln_structured"``, if ``pruning_fn`` is neither ``str`` nor :class:`torch.nn.utils.prune.BasePruningMethod`, or if ``amount`` is none of ``int``, ``float`` and ``Callable``. + """ self._use_global_unstructured = use_global_unstructured @@ -235,6 +236,7 @@ class ModelPruning(Callback): IF use_global_unstructured, pruning_fn will be resolved into its associated ``PyTorch BasePruningMethod`` ELSE, pruning_fn will be resolved into its function counterpart from `torch.nn.utils.prune`. + """ pruning_meth = ( _PYTORCH_PRUNING_METHOD[pruning_fn] @@ -259,6 +261,7 @@ class ModelPruning(Callback): """Removes pruning buffers from any pruned modules. Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/utils/prune.py#L1118-L1122 + """ for _, module in module.named_modules(): for k in list(module._forward_pre_hooks): @@ -286,6 +289,7 @@ class ModelPruning(Callback): This function implements the step 4. The ``resample_parameters`` argument can be used to reset the parameters with a new :math:`\theta_z \sim \mathcal{D}_\theta` + """ # noqa: E501 assert self._original_layers is not None for d in self._original_layers.values(): diff --git a/src/lightning/pytorch/callbacks/rich_model_summary.py b/src/lightning/pytorch/callbacks/rich_model_summary.py index f68c98259b..0a5d7d286c 100644 --- a/src/lightning/pytorch/callbacks/rich_model_summary.py +++ b/src/lightning/pytorch/callbacks/rich_model_summary.py @@ -23,8 +23,8 @@ if _RICH_AVAILABLE: # type: ignore[has-type] class RichModelSummary(ModelSummary): - r"""Generates a summary of all layers in a :class:`~lightning.pytorch.core.module.LightningModule` with `rich - text formatting `_. + r"""Generates a summary of all layers in a :class:`~lightning.pytorch.core.module.LightningModule` with `rich text + formatting `_. Install it with pip: @@ -56,6 +56,7 @@ class RichModelSummary(ModelSummary): Raises: ModuleNotFoundError: If required `rich` package is not installed on the device. + """ def __init__(self, max_depth: int = 1, **summarize_kwargs: Any) -> None: diff --git a/src/lightning/pytorch/callbacks/stochastic_weight_avg.py b/src/lightning/pytorch/callbacks/stochastic_weight_avg.py index dc6ae074a3..10a4d2fd19 100644 --- a/src/lightning/pytorch/callbacks/stochastic_weight_avg.py +++ b/src/lightning/pytorch/callbacks/stochastic_weight_avg.py @@ -89,6 +89,7 @@ class StochasticWeightAveraging(Callback): device: if provided, the averaged model will be stored on the ``device``. When None is provided, it will infer the `device` from ``pl_module``. (default: ``"cpu"``) + """ err_msg = "swa_epoch_start should be a >0 integer or a float between 0 and 1." 
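Both callbacks documented above (pruning and stochastic weight averaging) are attached to the Trainer in the usual way, for instance::

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import ModelPruning, StochasticWeightAveraging

    trainer = Trainer(
        callbacks=[
            ModelPruning("l1_unstructured", amount=0.2),  # prune 20% of the weights
            StochasticWeightAveraging(swa_lrs=1e-2),      # average weights with an SWA lr of 1e-2
        ]
    )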
diff --git a/src/lightning/pytorch/callbacks/timer.py b/src/lightning/pytorch/callbacks/timer.py index bb6245dbb0..36c0ee9719 100644 --- a/src/lightning/pytorch/callbacks/timer.py +++ b/src/lightning/pytorch/callbacks/timer.py @@ -33,8 +33,8 @@ class Interval(LightningEnum): class Timer(Callback): - """The Timer callback tracks the time spent in the training, validation, and test loops and interrupts the - Trainer if the given time limit for the training loop is reached. + """The Timer callback tracks the time spent in the training, validation, and test loops and interrupts the Trainer + if the given time limit for the training loop is reached. Args: duration: A string in the format DD:HH:MM:SS (days, hours, minutes seconds), or a :class:`datetime.timedelta`, @@ -69,6 +69,7 @@ class Timer(Callback): timer.time_elapsed("train") timer.start_time("validate") timer.end_time("test") + """ def __init__( diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py index 7c8c6f5afc..95f105402b 100644 --- a/src/lightning/pytorch/cli.py +++ b/src/lightning/pytorch/cli.py @@ -88,6 +88,7 @@ class LightningArgumentParser(ArgumentParser): description: Description of the tool shown when running ``--help``. env_prefix: Prefix for environment variables. Set ``default_env=True`` to enable env parsing. default_env: Whether to parse environment variables. + """ if not _JSONARGPARSE_SIGNATURES_AVAILABLE: raise ModuleNotFoundError(f"{_JSONARGPARSE_SIGNATURES_AVAILABLE}") @@ -120,6 +121,7 @@ class LightningArgumentParser(ArgumentParser): Returns: A list with the names of the class arguments added. + """ if callable(lightning_class) and not isinstance(lightning_class, type): lightning_class = class_from_function(lightning_class) @@ -155,6 +157,7 @@ class LightningArgumentParser(ArgumentParser): optimizer_class: Any subclass of :class:`torch.optim.Optimizer`. Use tuple to allow subclasses. nested_key: Name of the nested namespace to store arguments. link_to: Dot notation of a parser key to set arguments or AUTOMATIC. + """ if isinstance(optimizer_class, tuple): assert all(issubclass(o, Optimizer) for o in optimizer_class) @@ -180,6 +183,7 @@ class LightningArgumentParser(ArgumentParser): tuple to allow subclasses. nested_key: Name of the nested namespace to store arguments. link_to: Dot notation of a parser key to set arguments or AUTOMATIC. + """ if isinstance(lr_scheduler_class, tuple): assert all(issubclass(o, LRSchedulerTypeTuple) for o in lr_scheduler_class) @@ -206,6 +210,7 @@ class SaveConfigCallback(Callback): Raises: RuntimeError: If the config file already exists in the directory to avoid overwriting a previous run + """ def __init__( @@ -284,6 +289,7 @@ class SaveConfigCallback(Callback): worry about ranks or race conditions. Since it only runs on rank zero, any collective call will make the process hang waiting for a broadcast. If you need to make collective calls, implement the setup method instead. + """ @@ -306,8 +312,8 @@ class LightningCLI: run: bool = True, auto_configure_optimizers: bool = True, ) -> None: - """Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which - are called / instantiated using a parsed configuration file and / or command line args. + """Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which are + called / instantiated using a parsed configuration file and / or command line args. 
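The ``LightningCLI`` machinery above usually reduces to a very small entry point. ``MyModel`` and ``MyDataModule`` below are hypothetical user classes::

    # run e.g.:  python main.py fit --trainer.max_epochs=5
    from lightning.pytorch.cli import LightningCLI

    from my_project import MyDataModule, MyModel  # hypothetical imports

    def main() -> None:
        LightningCLI(MyModel, MyDataModule)

    if __name__ == "__main__":
        main()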
Parsing of configuration from environment variables can be enabled by setting ``parser_kwargs={"default_env": True}``. A full configuration yaml would be parsed from ``PL_CONFIG`` if set. Individual settings are so parsed @@ -345,6 +351,7 @@ class LightningCLI: ``dict`` or ``jsonargparse.Namespace``. run: Whether subcommands should be added to run a :class:`~lightning.pytorch.trainer.trainer.Trainer` method. If set to ``False``, the trainer and model classes will be instantiated only. + """ self.save_config_callback = save_config_callback self.save_config_kwargs = save_config_kwargs or {} @@ -450,6 +457,7 @@ class LightningCLI: Args: parser: The parser object to which arguments can be added + """ @staticmethod @@ -533,6 +541,7 @@ class LightningCLI: Args: kwargs: Any custom trainer arguments. + """ extra_callbacks = [self._get(self.config_init, c) for c in self._parser(self.subcommand).callback_keys] trainer_config = {**self._get(self.config_init, "trainer", default={}), **kwargs} @@ -580,6 +589,7 @@ class LightningCLI: lightning_module: A reference to the model. optimizer: The optimizer. lr_scheduler: The learning rate scheduler (if used). + """ if lr_scheduler is None: return optimizer @@ -591,8 +601,8 @@ class LightningCLI: return [optimizer], [lr_scheduler] def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) -> None: - """Overrides the model's :meth:`~lightning.pytorch.core.module.LightningModule.configure_optimizers` method - if a single optimizer and optionally a scheduler argument groups are added to the parser as 'AUTOMATIC'.""" + """Overrides the model's :meth:`~lightning.pytorch.core.module.LightningModule.configure_optimizers` method if + a single optimizer and optionally a scheduler argument groups are added to the parser as 'AUTOMATIC'.""" if not self.auto_configure_optimizers: return @@ -725,6 +735,7 @@ def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) - Returns: The instantiated class object. + """ kwargs = init.get("init_args", {}) if not isinstance(args, tuple): diff --git a/src/lightning/pytorch/core/datamodule.py b/src/lightning/pytorch/core/datamodule.py index b556eefc30..bda6dac711 100644 --- a/src/lightning/pytorch/core/datamodule.py +++ b/src/lightning/pytorch/core/datamodule.py @@ -28,8 +28,8 @@ from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADER class LightningDataModule(DataHooks, HyperparametersMixin): - """A DataModule standardizes the training, val, test splits, data preparation and transforms. The main - advantage is consistent data splits, data preparation and transforms across models. + """A DataModule standardizes the training, val, test splits, data preparation and transforms. The main advantage is + consistent data splits, data preparation and transforms across models. Example:: @@ -62,6 +62,7 @@ class LightningDataModule(DataHooks, HyperparametersMixin): # clean up state after the trainer stops, delete files... # called on every process in DDP ... + """ name: Optional[str] = None @@ -98,6 +99,7 @@ class LightningDataModule(DataHooks, HyperparametersMixin): data will be loaded in the main process. Number of CPUs available. This parameter gets forwarded to the ``__init__`` if the datamodule has such a name defined in its signature. **datamodule_kwargs: Additional parameters that get passed down to the datamodule's ``__init__``. 
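``LightningDataModule.from_datasets`` described above builds the dataloaders for you from plain datasets. A minimal sketch::

    import torch
    from torch.utils.data import TensorDataset

    from lightning.pytorch import LightningDataModule

    train_ds = TensorDataset(torch.randn(64, 4), torch.randint(0, 2, (64,)))
    val_ds = TensorDataset(torch.randn(16, 4), torch.randint(0, 2, (16,)))

    # builds train/val dataloaders with the given batch size and worker count
    dm = LightningDataModule.from_datasets(train_ds, val_dataset=val_ds, batch_size=8, num_workers=0)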
+ """ def dataloader(ds: Dataset, shuffle: bool = False) -> DataLoader: @@ -142,6 +144,7 @@ class LightningDataModule(DataHooks, HyperparametersMixin): Returns: A dictionary containing datamodule state. + """ return {} @@ -150,6 +153,7 @@ class LightningDataModule(DataHooks, HyperparametersMixin): Args: state_dict: the datamodule state returned by ``state_dict``. + """ pass diff --git a/src/lightning/pytorch/core/hooks.py b/src/lightning/pytorch/core/hooks.py index 5c491c1aef..e923faa2a4 100644 --- a/src/lightning/pytorch/core/hooks.py +++ b/src/lightning/pytorch/core/hooks.py @@ -31,12 +31,14 @@ class ModelHooks: """Called at the very beginning of fit. If on DDP it is called on every process + """ def on_fit_end(self) -> None: """Called at the very end of fit. If on DDP it is called on every process + """ def on_train_start(self) -> None: @@ -71,6 +73,7 @@ class ModelHooks: Args: batch: The batched data as it is returned by the training DataLoader. batch_idx: the index of the batch + """ def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None: @@ -80,6 +83,7 @@ class ModelHooks: outputs: The outputs of training_step(x) batch: The batched data as it is returned by the training DataLoader. batch_idx: the index of the batch + """ def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: @@ -89,6 +93,7 @@ class ModelHooks: batch: The batched data as it is returned by the validation DataLoader. batch_idx: the index of the batch dataloader_idx: the index of the dataloader + """ def on_validation_batch_end( @@ -101,6 +106,7 @@ class ModelHooks: batch: The batched data as it is returned by the validation DataLoader. batch_idx: the index of the batch dataloader_idx: the index of the dataloader + """ def on_test_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: @@ -110,6 +116,7 @@ class ModelHooks: batch: The batched data as it is returned by the test DataLoader. batch_idx: the index of the batch dataloader_idx: the index of the dataloader + """ def on_test_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: @@ -120,6 +127,7 @@ class ModelHooks: batch: The batched data as it is returned by the test DataLoader. batch_idx: the index of the batch dataloader_idx: the index of the dataloader + """ def on_predict_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: @@ -129,6 +137,7 @@ class ModelHooks: batch: The batched data as it is returned by the test DataLoader. batch_idx: the index of the batch dataloader_idx: the index of the dataloader + """ def on_predict_batch_end(self, outputs: Optional[Any], batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: @@ -139,6 +148,7 @@ class ModelHooks: batch: The batched data as it is returned by the prediction DataLoader. batch_idx: the index of the batch dataloader_idx: the index of the dataloader + """ def on_validation_model_eval(self) -> None: @@ -188,6 +198,7 @@ class ModelHooks: self.log("training_epoch_mean", epoch_mean) # free up the memory self.training_step_outputs.clear() + """ def on_validation_epoch_start(self) -> None: @@ -274,6 +285,7 @@ class ModelHooks: """Deprecated. Use :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` instead. + """ def configure_model(self) -> None: @@ -286,6 +298,7 @@ class ModelHooks: This hook is called during each of fit/val/test/predict stages in the same process, so ensure that implementation of this hook is idempotent. 
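The datamodule ``state_dict``/``load_state_dict`` hooks above let a datamodule persist its own state inside the checkpoint. A hypothetical example::

    from lightning.pytorch import LightningDataModule

    class KFoldDataModule(LightningDataModule):  # hypothetical datamodule
        def __init__(self) -> None:
            super().__init__()
            self.current_fold = 0

        def state_dict(self) -> dict:
            # stored in the checkpoint alongside model and trainer state
            return {"current_fold": self.current_fold}

        def load_state_dict(self, state_dict: dict) -> None:
            self.current_fold = state_dict["current_fold"]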
+ """ @@ -308,8 +321,8 @@ class DataHooks: def prepare_data(self) -> None: """Use this to download and prepare data. Downloading and saving data with multiple processes (distributed - settings) will result in corrupted data. Lightning ensures this method is called only within a single - process, so you can safely add your downloading logic within. + settings) will result in corrupted data. Lightning ensures this method is called only within a single process, + so you can safely add your downloading logic within. .. warning:: DO NOT set state to the model (use ``setup`` instead) since this is NOT called on every device @@ -359,12 +372,13 @@ class DataHooks: model.val_dataloader() model.test_dataloader() model.predict_dataloader() + """ def setup(self, stage: str) -> None: - """Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when - you need to build models dynamically or adjust something about them. This hook is called on every process - when using DDP. + """Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you + need to build models dynamically or adjust something about them. This hook is called on every process when + using DDP. Args: stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` @@ -385,6 +399,7 @@ class DataHooks: def setup(self, stage): data = load_data(...) self.l1 = nn.Linear(28, data.num_classes) + """ def teardown(self, stage: str) -> None: @@ -392,6 +407,7 @@ class DataHooks: Args: stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` + """ def train_dataloader(self) -> TRAIN_DATALOADERS: @@ -419,6 +435,7 @@ class DataHooks: Note: Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself. + """ raise MisconfigurationException("`train_dataloader` must be implemented to be used with the Lightning Trainer") @@ -448,6 +465,7 @@ class DataHooks: Note: If you don't need a test dataset and a :meth:`test_step`, you don't need to implement this method. + """ raise MisconfigurationException("`test_dataloader` must be implemented to be used with the Lightning Trainer") @@ -474,6 +492,7 @@ class DataHooks: Note: If you don't need a validation dataset and a :meth:`validation_step`, you don't need to implement this method. + """ raise MisconfigurationException("`val_dataloader` must be implemented to be used with the Lightning Trainer") @@ -494,14 +513,15 @@ class DataHooks: Return: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying prediction samples. + """ raise MisconfigurationException( "`predict_dataloader` must be implemented to be used with the Lightning Trainer" ) def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any: - """Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors wrapped in a custom - data structure. + """Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors wrapped in a custom data + structure. 
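The split between ``prepare_data`` (single process, disk work) and ``setup`` (every process, state assignment) described above typically looks like this; ``RandomDataModule`` is a hypothetical example::

    import os

    import torch
    from torch.utils.data import DataLoader, TensorDataset, random_split

    from lightning.pytorch import LightningDataModule

    class RandomDataModule(LightningDataModule):  # hypothetical datamodule
        def prepare_data(self) -> None:
            # runs on a single process only: download / write files here
            if not os.path.exists("data.pt"):
                torch.save(torch.randn(100, 4), "data.pt")

        def setup(self, stage: str) -> None:
            # runs on every process (e.g. under DDP): assign state, build splits
            dataset = TensorDataset(torch.load("data.pt"))
            self.train_set, self.val_set = random_split(dataset, [80, 20])

        def train_dataloader(self) -> DataLoader:
            return DataLoader(self.train_set, batch_size=16, shuffle=True)

        def val_dataloader(self) -> DataLoader:
            return DataLoader(self.val_set, batch_size=16)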
The data types listed below (and any arbitrary nesting of them) are supported out of the box: @@ -548,6 +568,7 @@ class DataHooks: See Also: - :meth:`move_data_to_device` - :meth:`apply_to_collection` + """ return move_data_to_device(batch, device) @@ -575,6 +596,7 @@ class DataHooks: See Also: - :meth:`on_after_batch_transfer` - :meth:`transfer_batch_to_device` + """ return batch @@ -606,6 +628,7 @@ class DataHooks: See Also: - :meth:`on_before_batch_transfer` - :meth:`transfer_batch_to_device` + """ return batch @@ -614,8 +637,8 @@ class CheckpointHooks: """Hooks to be used with Checkpointing.""" def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - r"""Called by Lightning to restore your model. If you saved something with :meth:`on_save_checkpoint` this - is your chance to restore this. + r"""Called by Lightning to restore your model. If you saved something with :meth:`on_save_checkpoint` this is + your chance to restore this. Args: checkpoint: Loaded checkpoint @@ -629,11 +652,12 @@ class CheckpointHooks: Note: Lightning auto-restores global step, epoch, and train state including amp scaling. There is no need for you to restore anything regarding training. + """ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - r"""Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want - to save. + r"""Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to + save. Args: checkpoint: The full checkpoint dictionary before it gets dumped to a file. @@ -649,4 +673,5 @@ class CheckpointHooks: Lightning saves all aspects of training (epoch, global step, etc...) including amp scaling. There is no need for you to store anything about training. + """ diff --git a/src/lightning/pytorch/core/mixins/hparams_mixin.py b/src/lightning/pytorch/core/mixins/hparams_mixin.py index d30caeda6b..ca6ad172e0 100644 --- a/src/lightning/pytorch/core/mixins/hparams_mixin.py +++ b/src/lightning/pytorch/core/mixins/hparams_mixin.py @@ -132,11 +132,12 @@ class HyperparametersMixin: @property def hparams(self) -> Union[AttributeDict, MutableMapping]: - """The collection of hyperparameters saved with :meth:`save_hyperparameters`. It is mutable by the user. - For the frozen set of initial hyperparameters, use :attr:`hparams_initial`. + """The collection of hyperparameters saved with :meth:`save_hyperparameters`. It is mutable by the user. For + the frozen set of initial hyperparameters, use :attr:`hparams_initial`. Returns: Mutable hyperparameters dictionary + """ if not hasattr(self, "_hparams"): self._hparams = AttributeDict() @@ -149,6 +150,7 @@ class HyperparametersMixin: Returns: AttributeDict: immutable initial hyperparameters + """ if not hasattr(self, "_hparams_initial"): return AttributeDict() diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index b5ab6604c3..a4237a7352 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -152,6 +152,7 @@ class LightningModule( Returns: A single optimizer, or a list of optimizers in case multiple ones are present. + """ if self._fabric: opts: MODULE_OPTIMIZERS = self._fabric_optimizers @@ -171,12 +172,12 @@ class LightningModule( return opts def lr_schedulers(self) -> Union[None, List[LRSchedulerPLType], LRSchedulerPLType]: - """Returns the learning rate scheduler(s) that are being used during training. Useful for manual - optimization. 
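When a dataloader yields a custom container, the ``transfer_batch_to_device`` hook above is where Lightning is told how to move it. A hypothetical override::

    import dataclasses

    import torch
    import lightning.pytorch as pl

    @dataclasses.dataclass
    class CustomBatch:  # hypothetical container returned by the DataLoader
        samples: torch.Tensor
        targets: torch.Tensor

    class LitModel(pl.LightningModule):  # hypothetical module
        def transfer_batch_to_device(self, batch, device, dataloader_idx):
            if isinstance(batch, CustomBatch):
                # the default hook only understands tensors and standard collections
                return CustomBatch(batch.samples.to(device), batch.targets.to(device))
            return super().transfer_batch_to_device(batch, device, dataloader_idx)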
+ """Returns the learning rate scheduler(s) that are being used during training. Useful for manual optimization. Returns: A single scheduler, or a list of schedulers in case multiple ones are present, or ``None`` if no schedulers were returned in :meth:`configure_optimizers`. + """ if not self.trainer.lr_scheduler_configs: return None @@ -224,8 +225,8 @@ class LightningModule( @property def example_input_array(self) -> Optional[Union[Tensor, Tuple, Dict]]: - """The example input array is a specification of what the module can consume in the :meth:`forward` method. - The return type is interpreted as follows: + """The example input array is a specification of what the module can consume in the :meth:`forward` method. The + return type is interpreted as follows: - Single tensor: It is assumed the model takes a single argument, i.e., ``model.forward(model.example_input_array)`` @@ -233,6 +234,7 @@ class LightningModule( ``model.forward(*model.example_input_array)`` - Dict: The input array represents named keyword arguments, i.e., ``model.forward(**model.example_input_array)`` + """ return self._example_input_array @@ -250,6 +252,7 @@ class LightningModule( """Total training batches seen across all epochs. If no Trainer is attached, this propery is 0. + """ return self.trainer.global_step if self._trainer else 0 @@ -333,6 +336,7 @@ class LightningModule( def forward(self, x): self.print(x, 'in forward') + """ if self.trainer.is_global_zero: progress_bar = self.trainer.progress_bar_callback @@ -389,6 +393,7 @@ class LightningModule( :class:`torchmetrics.Metric` in your model. This is found automatically if it is a model attribute. rank_zero_only: Whether the value will be logged only on rank 0. This will prevent synchronization which would produce a deadlock as not all processes would perform this log call. + """ if self._fabric is not None: self._log_dict_through_fabric(dictionary={name: value}, logger=logger) @@ -551,6 +556,7 @@ class LightningModule( but some data structures might need to explicitly provide it. rank_zero_only: Whether the value will be logged only on rank 0. This will prevent synchronization which would produce a deadlock as not all processes would perform this log call. + """ if self._fabric is not None: return self._log_dict_through_fabric(dictionary=dictionary, logger=logger) @@ -630,6 +636,7 @@ class LightningModule( Return: A tensor of shape (world_size, batch, ...), or if the input was a collection the output will also be a collection with tensors of this shape. + """ group = group if group is not None else torch.distributed.group.WORLD all_gather = self.trainer.strategy.all_gather @@ -645,6 +652,7 @@ class LightningModule( Return: Your model's output + """ return super().forward(*args, **kwargs) @@ -698,12 +706,13 @@ class LightningModule( Note: When ``accumulate_grad_batches`` > 1, the loss returned here will be automatically normalized by ``accumulate_grad_batches`` internally. + """ rank_zero_warn("`training_step` must be implemented to be used with the Lightning Trainer") def validation_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT: - r"""Operates on a single batch of data from the validation set. In this step you'd might generate examples - or calculate anything of interest like accuracy. + r"""Operates on a single batch of data from the validation set. In this step you'd might generate examples or + calculate anything of interest like accuracy. Args: batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`. 
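The logging arguments documented above are most commonly exercised through ``self.log`` inside a step method. A small hypothetical module::

    import torch
    import lightning.pytorch as pl

    class LitRegressor(pl.LightningModule):  # hypothetical module
        def __init__(self) -> None:
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.mse_loss(self.layer(x), y)
            # per-step value for the progress bar, per-epoch aggregate for the logger
            self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
            return loss

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)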
@@ -767,6 +776,7 @@ class LightningModule( When the :meth:`validation_step` is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled. + """ def test_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT: @@ -835,11 +845,12 @@ class LightningModule( When the :meth:`test_step` is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of the test epoch, the model goes back to training mode and gradients are enabled. + """ def predict_step(self, *args: Any, **kwargs: Any) -> Any: - """Step function called during :meth:`~lightning.pytorch.trainer.trainer.Trainer.predict`. By default, it - calls :meth:`~lightning.pytorch.core.module.LightningModule.forward`. Override to add any processing logic. + """Step function called during :meth:`~lightning.pytorch.trainer.trainer.Trainer.predict`. By default, it calls + :meth:`~lightning.pytorch.core.module.LightningModule.forward`. Override to add any processing logic. The :meth:`~lightning.pytorch.core.module.LightningModule.predict_step` is used to scale inference on multi-devices. @@ -871,6 +882,7 @@ class LightningModule( model = MyModel() trainer = Trainer(accelerator="gpu", devices=2) predictions = trainer.predict(model, dm) + """ # For backwards compatibility batch = kwargs.get("batch", args[0]) @@ -897,9 +909,9 @@ class LightningModule( return [] def configure_optimizers(self) -> Any: - r"""Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need - one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only - works in the manual optimization mode. + r"""Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one. + But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in + the manual optimization mode. Return: Any of these 6 options. @@ -994,6 +1006,7 @@ class LightningModule( - If you use multiple optimizers, you will have to switch to 'manual optimization' mode and step them yourself. - If you need to control how often the optimizer steps, override the :meth:`optimizer_step` hook. + """ rank_zero_warn("`configure_optimizers` must be implemented to be used with the Lightning Trainer") @@ -1017,6 +1030,7 @@ class LightningModule( loss: The tensor on which to compute gradients. Must have a graph attached. *args: Additional positional arguments to be forwarded to :meth:`~torch.Tensor.backward` **kwargs: Additional keyword arguments to be forwarded to :meth:`~torch.Tensor.backward` + """ if self._fabric: self._fabric.backward(loss, *args, **kwargs) @@ -1025,8 +1039,8 @@ class LightningModule( self.trainer.strategy.backward(loss, None, *args, **kwargs) def backward(self, loss: Tensor, *args: Any, **kwargs: Any) -> None: - """Called to perform backward on the loss returned in :meth:`training_step`. Override this hook with your - own implementation if you need to. + """Called to perform backward on the loss returned in :meth:`training_step`. Override this hook with your own + implementation if you need to. Args: loss: The loss tensor returned by :meth:`training_step`. 
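Since the `configure_optimizers` docstring above enumerates several return formats, a brief sketch of the common dict form; the optimizer and scheduler choices here are illustrative only:

import torch
import lightning.pytorch as pl

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
        # One of the supported return options: a dict with an optional
        # "lr_scheduler" entry ("monitor" is only required for ReduceLROnPlateau).
        return {"optimizer": optimizer, "lr_scheduler": scheduler}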
If gradient accumulation is used, the loss here @@ -1036,6 +1050,7 @@ class LightningModule( def backward(self, loss): loss.backward() + """ if self._fabric: self._fabric.backward(loss, *args, **kwargs) @@ -1043,13 +1058,14 @@ class LightningModule( loss.backward(*args, **kwargs) def toggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) -> None: - """Makes sure only the gradients of the current optimizer's parameters are calculated in the training step - to prevent dangling gradients in multiple-optimizer setup. + """Makes sure only the gradients of the current optimizer's parameters are calculated in the training step to + prevent dangling gradients in multiple-optimizer setup. It works with :meth:`untoggle_optimizer` to make sure ``param_requires_grad_state`` is properly reset. Args: optimizer: The optimizer to toggle. + """ # Iterate over all optimizer parameters to preserve their `requires_grad` information # in case these are pre-defined during `configure_optimizers` @@ -1075,6 +1091,7 @@ class LightningModule( Args: optimizer: The optimizer to untoggle. + """ for opt in self.trainer.optimizers: if not (opt is optimizer or (isinstance(optimizer, LightningOptimizer) and opt is optimizer.optimizer)): @@ -1106,6 +1123,7 @@ class LightningModule( gradient_clip_val: The value at which to clip gradients. gradient_clip_algorithm: The gradient clipping algorithm to use. Pass ``gradient_clip_algorithm="value"`` to clip by value, and ``gradient_clip_algorithm="norm"`` to clip by norm. + """ if self.fabric is not None: @@ -1177,6 +1195,7 @@ class LightningModule( gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm ) + """ self.clip_gradients( optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm @@ -1219,8 +1238,8 @@ class LightningModule( optimizer: Union[Optimizer, LightningOptimizer], optimizer_closure: Optional[Callable[[], Any]] = None, ) -> None: - r"""Override this method to adjust the default way the :class:`~lightning.pytorch.trainer.trainer.Trainer` - calls the optimizer. + r"""Override this method to adjust the default way the :class:`~lightning.pytorch.trainer.trainer.Trainer` calls + the optimizer. By default, Lightning calls ``step()`` and ``zero_grad()`` as shown in the example. This method (and ``zero_grad()``) won't be called during the accumulation phase when @@ -1249,6 +1268,7 @@ class LightningModule( lr_scale = min(1.0, float(self.trainer.global_step + 1) / 500.0) for pg in optimizer.param_groups: pg["lr"] = lr_scale * self.learning_rate + """ optimizer.step(closure=optimizer_closure) @@ -1281,6 +1301,7 @@ class LightningModule( model = MyLightningModule(...) model.freeze() + """ for param in self.parameters(): param.requires_grad = False @@ -1294,6 +1315,7 @@ class LightningModule( model = MyLightningModule(...) model.unfreeze() + """ for param in self.parameters(): param.requires_grad = True @@ -1329,6 +1351,7 @@ class LightningModule( model = SimpleModel() input_sample = torch.randn(1, 64) model.to_onnx("export.onnx", input_sample, export_params=True) + """ if _TORCH_GREATER_EQUAL_2_0 and not _ONNX_AVAILABLE: raise ModuleNotFoundError( @@ -1533,6 +1556,7 @@ class LightningModule( """Adds ShardedTensor state dict hooks if ShardedTensors are supported. These hooks ensure that ShardedTensors are included when saving, and are loaded the LightningModule correctly. 
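As the `toggle_optimizer`/`untoggle_optimizer` docstrings above target the multiple-optimizer, manual-optimization case, here is a rough sketch of the intended call pattern; the GAN-style loss helper is a placeholder, not a Lightning API:

import lightning.pytorch as pl

class LitGAN(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # multiple optimizers require manual optimization

    def training_step(self, batch, batch_idx):
        opt_g, opt_d = self.optimizers()

        # While toggled, only the generator's parameters keep requires_grad=True,
        # preventing dangling gradients on the discriminator.
        self.toggle_optimizer(opt_g)
        g_loss = self.generator_loss(batch)  # placeholder
        self.manual_backward(g_loss)
        opt_g.step()
        opt_g.zero_grad()
        self.untoggle_optimizer(opt_g)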
+ """ if _TORCH_GREATER_EQUAL_2_1: # ShardedTensor is deprecated in favor of DistributedTensor diff --git a/src/lightning/pytorch/core/optimizer.py b/src/lightning/pytorch/core/optimizer.py index afd2ca01a9..e90ff0be84 100644 --- a/src/lightning/pytorch/core/optimizer.py +++ b/src/lightning/pytorch/core/optimizer.py @@ -34,8 +34,8 @@ def do_nothing_closure() -> None: class LightningOptimizer: - """This class is used to wrap the user optimizers and handle properly the backward and optimizer_step logic - across accelerators, AMP, accumulate_grad_batches.""" + """This class is used to wrap the user optimizers and handle properly the backward and optimizer_step logic across + accelerators, AMP, accumulate_grad_batches.""" def __init__(self, optimizer: Optimizer): # copy most of the `Optimizer` methods into this instance. `__del__` is skipped in case the optimizer has @@ -73,6 +73,7 @@ class LightningOptimizer: When performing gradient accumulation, there is no need to perform grad synchronization during the accumulation phase. Setting `sync_grad` to False will block this synchronization and improve performance. + """ # local import here to avoid circular import from lightning.pytorch.loops.utilities import _block_parallel_sync_behavior @@ -144,6 +145,7 @@ class LightningOptimizer: with opt_dis.toggle_model(sync_grad=accumulated_grad_batches): opt_dis.step(closure=closure_dis) + """ self._on_before_step() @@ -237,8 +239,7 @@ def _configure_optimizers( def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]) -> List[LRSchedulerConfig]: - """Convert each scheduler into `LRSchedulerConfig` with relevant information, when using automatic - optimization.""" + """Convert each scheduler into `LRSchedulerConfig` with relevant information, when using automatic optimization.""" lr_scheduler_configs = [] for scheduler in schedulers: if isinstance(scheduler, dict): diff --git a/src/lightning/pytorch/core/saving.py b/src/lightning/pytorch/core/saving.py index 9b32058a81..1fb5e0d46b 100644 --- a/src/lightning/pytorch/core/saving.py +++ b/src/lightning/pytorch/core/saving.py @@ -215,6 +215,7 @@ def update_hparams(hparams: dict, updates: dict) -> None: Args: hparams: the original params and also target object updates: new params to be used as update + """ for k, v in updates.items(): # if missing, add the key @@ -240,6 +241,7 @@ def load_hparams_from_tags_csv(tags_csv: _PATH) -> Dict[str, Any]: >>> vars(hparams) == hparams_new True >>> os.remove(path_csv) + """ fs = get_filesystem(tags_csv) if not fs.exists(tags_csv): @@ -282,6 +284,7 @@ def load_hparams_from_yaml(config_yaml: _PATH, use_omegaconf: bool = True) -> Di >>> vars(hparams) == hparams_new True >>> os.remove(path_yaml) + """ fs = get_filesystem(config_yaml) if not fs.exists(config_yaml): diff --git a/src/lightning/pytorch/demos/boring_classes.py b/src/lightning/pytorch/demos/boring_classes.py index 918e6c9737..3dd7bd8b1a 100644 --- a/src/lightning/pytorch/demos/boring_classes.py +++ b/src/lightning/pytorch/demos/boring_classes.py @@ -105,6 +105,7 @@ class BoringModel(LightningModule): class TestModel(BoringModel): def training_step(self, ...): ... 
# do your own thing + """ def __init__(self) -> None: diff --git a/src/lightning/pytorch/demos/mnist_datamodule.py b/src/lightning/pytorch/demos/mnist_datamodule.py index 63c36e108a..5f4528a939 100644 --- a/src/lightning/pytorch/demos/mnist_datamodule.py +++ b/src/lightning/pytorch/demos/mnist_datamodule.py @@ -149,6 +149,7 @@ class MNISTDataModule(LightningDataModule): >>> MNISTDataModule() # doctest: +ELLIPSIS <...mnist_datamodule.MNISTDataModule object at ...> + """ name = "mnist" diff --git a/src/lightning/pytorch/demos/transformer.py b/src/lightning/pytorch/demos/transformer.py index c8e2d6bb88..13e220759f 100644 --- a/src/lightning/pytorch/demos/transformer.py +++ b/src/lightning/pytorch/demos/transformer.py @@ -2,6 +2,7 @@ Code is adapted from the PyTorch examples at https://github.com/pytorch/examples/blob/main/word_language_model + """ import math import os diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py index e2a095e2e6..27afe86730 100644 --- a/src/lightning/pytorch/loggers/comet.py +++ b/src/lightning/pytorch/loggers/comet.py @@ -204,6 +204,7 @@ class CometLogger(Logger): If required Comet package is not installed on the device. MisconfigurationException: If neither ``api_key`` nor ``save_dir`` are passed as arguments. + """ LOGGER_JOIN_CHAR = "-" @@ -357,6 +358,7 @@ class CometLogger(Logger): Returns: The path to the save directory. + """ return self._save_dir @@ -366,6 +368,7 @@ class CometLogger(Logger): Returns: The project name if it is specified, else "comet-default". + """ # Don't create an experiment if we don't have one if self._experiment is not None and self._experiment.project_name is not None: @@ -389,6 +392,7 @@ class CometLogger(Logger): 4. future experiment key. If none are present generates a new guid. + """ # Don't create an experiment if we don't have one if self._experiment is not None: diff --git a/src/lightning/pytorch/loggers/csv_logs.py b/src/lightning/pytorch/loggers/csv_logs.py index beeeb905c3..18cd83960d 100644 --- a/src/lightning/pytorch/loggers/csv_logs.py +++ b/src/lightning/pytorch/loggers/csv_logs.py @@ -45,6 +45,7 @@ class ExperimentWriter(_FabricExperimentWriter): Args: log_dir: Directory for the experiment logs + """ NAME_HPARAMS_FILE = "hparams.yaml" @@ -82,6 +83,7 @@ class CSVLogger(Logger, FabricCSVLogger): directory for existing versions, then automatically assigns the next available version. prefix: A string to put at the beginning of metric keys. flush_logs_every_n_steps: How often to flush logs to disk (defaults to every 100 steps). + """ LOGGER_JOIN_CHAR = "-" @@ -109,6 +111,7 @@ class CSVLogger(Logger, FabricCSVLogger): If the experiment name parameter is an empty string, no experiment subdirectory is used and the checkpoint will be saved in "save_dir/version" + """ return os.path.join(self.save_dir, self.name) @@ -118,6 +121,7 @@ class CSVLogger(Logger, FabricCSVLogger): By default, it is named ``'version_${self.version}'`` but it can be overridden by passing a string value for the constructor's version parameter instead of ``None`` or an int. + """ # create a pseudo standard path version = self.version if isinstance(self.version, str) else f"version_{self.version}" @@ -129,6 +133,7 @@ class CSVLogger(Logger, FabricCSVLogger): Returns: The path to current directory where logs are saved. 
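To ground the `CSVLogger` directory layout described above (save_dir/name/version_<n>), a short usage sketch; the paths and experiment name are arbitrary:

from lightning.pytorch import Trainer
from lightning.pytorch.loggers import CSVLogger

# Metrics are flushed to "logs/my_exp/version_<n>/metrics.csv" and the
# hyperparameters to "hparams.yaml" in the same directory.
csv_logger = CSVLogger(save_dir="logs", name="my_exp")
trainer = Trainer(logger=csv_logger, max_epochs=1)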
+ """ return self._save_dir diff --git a/src/lightning/pytorch/loggers/logger.py b/src/lightning/pytorch/loggers/logger.py index 52a51ab8eb..59ff16ac99 100644 --- a/src/lightning/pytorch/loggers/logger.py +++ b/src/lightning/pytorch/loggers/logger.py @@ -36,6 +36,7 @@ class Logger(FabricLogger, ABC): Args: checkpoint_callback: the model checkpoint callback instance + """ pass @@ -50,6 +51,7 @@ class DummyLogger(Logger): """Dummy logger for internal use. It is useful if we want to disable user's logger for a feature, but still ensure that user code can run + """ def __init__(self) -> None: @@ -96,8 +98,7 @@ def merge_dicts( # pragma: no cover agg_key_funcs: Optional[Mapping] = None, default_func: Callable[[Sequence[float]], float] = np.mean, ) -> Dict: - """Merge a sequence with dictionaries into one dictionary by aggregating the same keys with some given - function. + """Merge a sequence with dictionaries into one dictionary by aggregating the same keys with some given function. Args: dicts: @@ -128,6 +129,7 @@ def merge_dicts( # pragma: no cover 'c': 1, 'd': {'d1': 3, 'd2': 3, 'd3': 3, 'd4': {'d5': 1}}, 'v': 2.3} + """ agg_key_funcs = agg_key_funcs or {} keys = list(functools.reduce(operator.or_, [set(d.keys()) for d in dicts])) diff --git a/src/lightning/pytorch/loggers/mlflow.py b/src/lightning/pytorch/loggers/mlflow.py index 56aca09013..9834875da2 100644 --- a/src/lightning/pytorch/loggers/mlflow.py +++ b/src/lightning/pytorch/loggers/mlflow.py @@ -130,6 +130,7 @@ class MLFlowLogger(Logger): Raises: ModuleNotFoundError: If required MLFlow package is not installed on the device. + """ LOGGER_JOIN_CHAR = "-" @@ -218,6 +219,7 @@ class MLFlowLogger(Logger): Returns: The run id. + """ _ = self.experiment return self._run_id @@ -228,6 +230,7 @@ class MLFlowLogger(Logger): Returns: The experiment id. + """ _ = self.experiment return self._experiment_id @@ -295,6 +298,7 @@ class MLFlowLogger(Logger): Return: Local path to the root experiment directory if the tracking uri is local. Otherwise returns `None`. + """ if self._tracking_uri.startswith(LOCAL_FILE_URI_PREFIX): return self._tracking_uri.lstrip(LOCAL_FILE_URI_PREFIX) @@ -306,6 +310,7 @@ class MLFlowLogger(Logger): Returns: The experiment id. + """ return self.experiment_id @@ -315,6 +320,7 @@ class MLFlowLogger(Logger): Returns: The run id. + """ return self.run_id diff --git a/src/lightning/pytorch/loggers/neptune.py b/src/lightning/pytorch/loggers/neptune.py index 3e5f4f6566..b6e8b70172 100644 --- a/src/lightning/pytorch/loggers/neptune.py +++ b/src/lightning/pytorch/loggers/neptune.py @@ -223,6 +223,7 @@ class NeptuneLogger(Logger): If the required Neptune package is not installed. ValueError: If an argument passed to the logger's constructor is incorrect. + """ LOGGER_JOIN_CHAR = "/" @@ -413,6 +414,7 @@ class NeptuneLogger(Logger): ) neptune_logger.log_hyperparams(PARAMS) + """ params = _convert_params(params) params = _sanitize_callable_params(params) @@ -431,6 +433,7 @@ class NeptuneLogger(Logger): Args: metrics: Dictionary with metric names as keys and measured quantities as values. step: Step number at which the metrics should be recorded, currently ignored. + """ if rank_zero_only.rank != 0: raise ValueError("run tried to log from global_rank != 0") @@ -476,6 +479,7 @@ class NeptuneLogger(Logger): Args: checkpoint_callback: the model checkpoint callback instance + """ if not self._log_model_checkpoints: return @@ -560,5 +564,6 @@ class NeptuneLogger(Logger): """Return the experiment version. 
It's Neptune Run's short_id + """ return self._run_short_id diff --git a/src/lightning/pytorch/loggers/tensorboard.py b/src/lightning/pytorch/loggers/tensorboard.py index c99ca2ad02..567f80b4a8 100644 --- a/src/lightning/pytorch/loggers/tensorboard.py +++ b/src/lightning/pytorch/loggers/tensorboard.py @@ -96,6 +96,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger): >>> tbl.log_metrics({"acc": 0.9}) >>> tbl.finalize("success") >>> shutil.rmtree(tmp) + """ NAME_HPARAMS_FILE = "hparams.yaml" @@ -133,6 +134,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger): If the experiment name parameter is an empty string, no experiment subdirectory is used and the checkpoint will be saved in "save_dir/version" + """ return os.path.join(super().root_dir, self.name) @@ -142,6 +144,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger): By default, it is named ``'version_${self.version}'`` but it can be overridden by passing a string value for the constructor's version parameter instead of ``None`` or an int. + """ # create a pseudo standard path ala test-tube version = self.version if isinstance(self.version, str) else f"version_{self.version}" @@ -158,6 +161,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger): Returns: The local path to the save directory where the TensorBoard experiments are saved. + """ return self._root_dir @@ -166,12 +170,13 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger): self, params: Union[Dict[str, Any], Namespace], metrics: Optional[Dict[str, Any]] = None ) -> None: """Record hyperparameters. TensorBoard logs with and without saved hyperparameters are incompatible, the - hyperparameters are then not displayed in the TensorBoard. Please delete or move the previously saved logs - to display the new ones with hyperparameters. + hyperparameters are then not displayed in the TensorBoard. Please delete or move the previously saved logs to + display the new ones with hyperparameters. Args: params: a dictionary-like container with the hyperparameters metrics: Dictionary with metric names as keys and measured quantities as values + """ params = _convert_params(params) @@ -233,6 +238,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger): Args: checkpoint_callback: the model checkpoint callback instance + """ pass diff --git a/src/lightning/pytorch/loggers/utilities.py b/src/lightning/pytorch/loggers/utilities.py index ddc9e24749..588826aa41 100644 --- a/src/lightning/pytorch/loggers/utilities.py +++ b/src/lightning/pytorch/loggers/utilities.py @@ -35,6 +35,7 @@ def _scan_checkpoints(checkpoint_callback: Checkpoint, logged_model_time: dict) Args: checkpoint_callback: Checkpoint callback reference. logged_model_time: dictionary containing the logged model times. + """ # get checkpoints to be saved with associated score checkpoints = {} diff --git a/src/lightning/pytorch/loggers/wandb.py b/src/lightning/pytorch/loggers/wandb.py index df676230ed..9f9738a753 100644 --- a/src/lightning/pytorch/loggers/wandb.py +++ b/src/lightning/pytorch/loggers/wandb.py @@ -279,6 +279,7 @@ class WandbLogger(Logger): If required WandB package is not installed on the device. MisconfigurationException: If both ``log_model`` and ``offline`` is set to ``True``. + """ LOGGER_JOIN_CHAR = "-" @@ -436,6 +437,7 @@ class WandbLogger(Logger): """Log a Table containing any object type (text, image, audio, video, molecule, html, etc). Can be defined either with `columns` and `data` or with `dataframe`. 
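The `log_table` docstring above allows either `columns` plus `data` or a `dataframe`; a sketch of the first form, assuming wandb is installed and configured (the key, project name, and table contents are made up):

from lightning.pytorch.loggers import WandbLogger

wandb_logger = WandbLogger(project="my-project")
# Either `columns` and `data` ...
wandb_logger.log_table(key="samples", columns=["input", "prediction"], data=[["x0", 0], ["x1", 1]])
# ... or a pandas DataFrame via `dataframe=` (but not both at once).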
+ """ metrics = {key: wandb.Table(columns=columns, data=data, dataframe=dataframe)} @@ -453,6 +455,7 @@ class WandbLogger(Logger): """Log text as a Table. Can be defined either with `columns` and `data` or with `dataframe`. + """ self.log_table(key, columns, data, dataframe, step) @@ -462,6 +465,7 @@ class WandbLogger(Logger): """Log images (tensors, numpy arrays, PIL Images or file paths). Optional kwargs are lists passed to each image (ex: caption, masks, boxes). + """ if not isinstance(images, list): raise TypeError(f'Expected a list as "images", found {type(images)}') @@ -479,6 +483,7 @@ class WandbLogger(Logger): Returns: The path to the save directory. + """ return self._save_dir @@ -489,6 +494,7 @@ class WandbLogger(Logger): Returns: The name of the project the current experiment belongs to. This name is not the same as `wandb.Run`'s name. To access wandb's internal experiment name, use ``logger.experiment.name`` instead. + """ return self._project @@ -498,6 +504,7 @@ class WandbLogger(Logger): Returns: The id of the experiment if the experiment exists else the id given to the constructor. + """ # don't create an experiment if we don't have one return self._experiment.id if self._experiment else self._id @@ -527,6 +534,7 @@ class WandbLogger(Logger): Returns: The path to the downloaded artifact. + """ if wandb.run is not None and use_artifact: artifact = wandb.run.use_artifact(artifact) @@ -546,6 +554,7 @@ class WandbLogger(Logger): Returns: wandb Artifact object for the artifact. + """ return self.experiment.use_artifact(artifact, type=artifact_type) diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py index 14f39e5c0f..38a6b8803f 100644 --- a/src/lightning/pytorch/loops/evaluation_loop.py +++ b/src/lightning/pytorch/loops/evaluation_loop.py @@ -89,6 +89,7 @@ class _EvaluationLoop(_Loop): """In "sequential" mode, the max number of batches to run per dataloader. Otherwise, the max batches to run. + """ max_batches = self._max_batches if not self.trainer.sanity_checking: @@ -377,6 +378,7 @@ class _EvaluationLoop(_Loop): batch: The current batch to run through the step. batch_idx: The index of the current batch dataloader_idx: the index of the dataloader producing the current batch + """ trainer = self.trainer @@ -431,6 +433,7 @@ class _EvaluationLoop(_Loop): Returns: the dictionary containing all the keyboard arguments for the step + """ step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)]) if dataloader_idx is not None: diff --git a/src/lightning/pytorch/loops/fetchers.py b/src/lightning/pytorch/loops/fetchers.py index 8df73c891c..9c526c9b6b 100644 --- a/src/lightning/pytorch/loops/fetchers.py +++ b/src/lightning/pytorch/loops/fetchers.py @@ -78,6 +78,7 @@ class _PrefetchDataFetcher(_DataFetcher): Args: prefetch_batches: Number of batches to pre-fetch. Pre-fetching at least 1 batch is necessary to properly track whether a batch is the last one (available with :attr:`self.done`) when the length is not available. 
+ """ def __init__(self, prefetch_batches: int = 1) -> None: diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py index 16c11ba45c..3f4047d8f7 100644 --- a/src/lightning/pytorch/loops/fit_loop.py +++ b/src/lightning/pytorch/loops/fit_loop.py @@ -69,6 +69,7 @@ class _FitLoop(_Loop): Args: min_epochs: The minimum number of epochs max_epochs: The maximum number of epochs, can be set -1 to turn this limit off + """ def __init__( diff --git a/src/lightning/pytorch/loops/loop.py b/src/lightning/pytorch/loops/loop.py index 2a3bf1dfc4..56d520800c 100644 --- a/src/lightning/pytorch/loops/loop.py +++ b/src/lightning/pytorch/loops/loop.py @@ -42,6 +42,7 @@ class _Loop: Returns: The current loop state. + """ return {} @@ -55,6 +56,7 @@ class _Loop: destination: An existing dictionary to update with this loop's state. By default a new dictionary is returned. prefix: A prefix for each key in the state dictionary + """ if destination is None: destination = {} diff --git a/src/lightning/pytorch/loops/optimization/automatic.py b/src/lightning/pytorch/loops/optimization/automatic.py index f86de295d8..26c5d40427 100644 --- a/src/lightning/pytorch/loops/optimization/automatic.py +++ b/src/lightning/pytorch/loops/optimization/automatic.py @@ -40,6 +40,7 @@ class ClosureResult(OutputResult): closure_loss: The loss with a graph attached. loss: A detached copy of the closure loss. extra: Any keys other than the loss returned. + """ closure_loss: Optional[Tensor] @@ -158,6 +159,7 @@ class _AutomaticOptimization(_Loop): Args: kwargs: the kwargs passed down to the hooks optimizer: the optimizer + """ closure = self._make_closure(kwargs, optimizer) @@ -203,6 +205,7 @@ class _AutomaticOptimization(_Loop): """Build a `zero_grad` function that zeroes the gradients before back-propagation. Returns ``None`` in the case backward needs to be skipped. + """ if self._skip_backward: return None @@ -218,10 +221,11 @@ class _AutomaticOptimization(_Loop): return zero_grad_fn def _make_backward_fn(self, optimizer: Optimizer) -> Optional[Callable[[Tensor], None]]: - """Build a `backward` function that handles back-propagation through the output produced by the - `training_step` function. + """Build a `backward` function that handles back-propagation through the output produced by the `training_step` + function. Returns ``None`` in the case backward needs to be skipped. + """ if self._skip_backward: return None @@ -242,6 +246,7 @@ class _AutomaticOptimization(_Loop): batch_idx: the index of the current batch train_step_and_backward_closure: the closure function performing the train step and computing the gradients. By default, called by the optimizer (if possible) + """ trainer = self.trainer @@ -285,6 +290,7 @@ class _AutomaticOptimization(_Loop): Args: batch_idx: the index of the current batch optimizer: the current optimizer + """ trainer = self.trainer call._call_lightning_module_hook(trainer, "optimizer_zero_grad", trainer.current_epoch, batch_idx, optimizer) @@ -298,6 +304,7 @@ class _AutomaticOptimization(_Loop): Returns: A ``ClosureResult`` containing the training step output. 
+ """ trainer = self.trainer diff --git a/src/lightning/pytorch/loops/optimization/closure.py b/src/lightning/pytorch/loops/optimization/closure.py index ec85a96e54..4b550166b7 100644 --- a/src/lightning/pytorch/loops/optimization/closure.py +++ b/src/lightning/pytorch/loops/optimization/closure.py @@ -35,6 +35,7 @@ class AbstractClosure(ABC, Generic[T]): This class provides a simple abstraction making the instance of this class callable like a function while capturing the closure result and caching it. + """ def __init__(self) -> None: @@ -46,6 +47,7 @@ class AbstractClosure(ABC, Generic[T]): Once accessed, the internal reference gets reset and the consumer will have to hold on to the reference as long as necessary. + """ if self._result is None: raise MisconfigurationException( diff --git a/src/lightning/pytorch/loops/optimization/manual.py b/src/lightning/pytorch/loops/optimization/manual.py index 79dd4360dc..01998ae5ff 100644 --- a/src/lightning/pytorch/loops/optimization/manual.py +++ b/src/lightning/pytorch/loops/optimization/manual.py @@ -72,6 +72,7 @@ class _ManualOptimization(_Loop): This loop is a trivial case because it performs only a single iteration (calling directly into the module's :meth:`~lightning.pytorch.core.module.LightningModule.training_step`) and passing through the output(s). + """ output_result_cls = ManualResult @@ -102,6 +103,7 @@ class _ManualOptimization(_Loop): Args: kwargs: The kwargs passed down to the hooks. + """ trainer = self.trainer diff --git a/src/lightning/pytorch/loops/prediction_loop.py b/src/lightning/pytorch/loops/prediction_loop.py index 227d3246bb..6df419f412 100644 --- a/src/lightning/pytorch/loops/prediction_loop.py +++ b/src/lightning/pytorch/loops/prediction_loop.py @@ -208,6 +208,7 @@ class _PredictionLoop(_Loop): batch: the current batch to run the prediction on batch_idx: the index of the current batch dataloader_idx: the index of the dataloader producing the current batch + """ trainer = self.trainer batch = trainer.lightning_module._on_before_batch_transfer(batch, dataloader_idx=dataloader_idx) diff --git a/src/lightning/pytorch/loops/progress.py b/src/lightning/pytorch/loops/progress.py index 788f97bbc6..8ff12ba378 100644 --- a/src/lightning/pytorch/loops/progress.py +++ b/src/lightning/pytorch/loops/progress.py @@ -45,6 +45,7 @@ class _ReadyCompletedTracker(_BaseProgress): completed: Intended to be incremented after the event completes (e.g. after ``on_*_end`` runs). These attributes should be increased in order, that is, :attr:`ready` first and :attr:`completed` last. + """ ready: int = 0 @@ -60,6 +61,7 @@ class _ReadyCompletedTracker(_BaseProgress): If there is a failure before all attributes are increased, restore the attributes to the last fully completed value. + """ self.ready = self.completed @@ -74,6 +76,7 @@ class _StartedTracker(_ReadyCompletedTracker): completed: Intended to be incremented after the event completes (e.g. after ``on_*_end`` runs). These attributes should be increased in order, that is, :attr:`ready` first and :attr:`completed` last. + """ started: int = 0 @@ -98,6 +101,7 @@ class _ProcessedTracker(_StartedTracker): completed: Intended to be incremented after the event completes (e.g. after ``on_*_end`` runs). These attributes should be increased in order, that is, :attr:`ready` first and :attr:`completed` last. + """ processed: int = 0 @@ -118,6 +122,7 @@ class _Progress(_BaseProgress): Args: total: Intended to track the total progress of an event. 
current: Intended to track the current progress of an event. + """ total: _ReadyCompletedTracker = field(default_factory=_ProcessedTracker) @@ -177,6 +182,7 @@ class _BatchProgress(_Progress): total: Tracks the total batch progress. current: Tracks the current batch progress. is_last_batch: Whether the batch is the last one. This is useful for iterable datasets. + """ is_last_batch: bool = False @@ -203,6 +209,7 @@ class _SchedulerProgress(_Progress): Args: total: Tracks the total scheduler progress. current: Tracks the current scheduler progress. + """ total: _ReadyCompletedTracker = field(default_factory=_ReadyCompletedTracker) @@ -216,6 +223,7 @@ class _OptimizerProgress(_BaseProgress): Args: step: Tracks ``optimizer.step`` calls. zero_grad: Tracks ``optimizer.zero_grad`` calls. + """ step: _Progress = field(default_factory=lambda: _Progress.from_defaults(_ReadyCompletedTracker)) @@ -244,6 +252,7 @@ class _OptimizationProgress(_BaseProgress): Args: optimizer: Tracks optimizer progress. + """ optimizer: _OptimizerProgress = field(default_factory=_OptimizerProgress) diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py index b9d205070e..46835f4677 100644 --- a/src/lightning/pytorch/loops/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/training_epoch_loop.py @@ -181,6 +181,7 @@ class _TrainingEpochLoop(loops._Loop): Raises: StopIteration: When the epoch is canceled by the user returning -1 + """ if self.restarting and self._should_check_val_fx(): # skip training and run validation in `on_advance_end` @@ -288,8 +289,7 @@ class _TrainingEpochLoop(loops._Loop): return epoch_finished_on_ready or self.batch_progress.is_last_batch def _should_accumulate(self) -> bool: - """Checks if the optimizer step should be performed or gradients should be accumulated for the current - step.""" + """Checks if the optimizer step should be performed or gradients should be accumulated for the current step.""" accumulation_done = self._accumulated_batches_reached() # Lightning steps on the final batch is_final_batch = self._num_ready_batches_reached() @@ -312,6 +312,7 @@ class _TrainingEpochLoop(loops._Loop): This is used so non-plateau schedulers can be updated before running validation. Checkpoints are commonly saved during validation, however, on-plateau schedulers might monitor a validation metric so they have to be updated separately. + """ trainer = self.trainer @@ -413,6 +414,7 @@ class _TrainingEpochLoop(loops._Loop): Returns: The kwargs passed down to the hooks. + """ kwargs["batch"] = batch training_step_fx = getattr(self.trainer.lightning_module, "training_step") diff --git a/src/lightning/pytorch/loops/utilities.py b/src/lightning/pytorch/loops/utilities.py index b449355582..618717a69b 100644 --- a/src/lightning/pytorch/loops/utilities.py +++ b/src/lightning/pytorch/loops/utilities.py @@ -40,6 +40,7 @@ def check_finite_loss(loss: Optional[Tensor]) -> None: Args: loss: the loss value to check to be finite + """ if loss is not None and not torch.isfinite(loss).all(): raise ValueError(f"The loss returned in `training_step` is {loss}.") @@ -52,8 +53,8 @@ def _parse_loop_limits( max_epochs: Optional[int], trainer: "pl.Trainer", ) -> Tuple[int, int]: - """This utility computes the default values for the minimum and maximum number of steps and epochs given the - values the user has selected. + """This utility computes the default values for the minimum and maximum number of steps and epochs given the values + the user has selected. 
Args: min_steps: Minimum number of steps. @@ -64,6 +65,7 @@ def _parse_loop_limits( Returns: The parsed limits, with default values being set for the ones that the user did not specify. + """ if max_epochs is None: if max_steps == -1 and not any(isinstance(cb, Timer) for cb in trainer.callbacks): @@ -89,8 +91,8 @@ def _parse_loop_limits( @contextmanager def _block_parallel_sync_behavior(strategy: Strategy, block: bool = True) -> Generator[None, None, None]: - """Blocks synchronization in :class:`~lightning.pytorch.strategies.parallel.ParallelStrategy`. This is useful - for example when accumulating gradients to reduce communication when it is not needed. + """Blocks synchronization in :class:`~lightning.pytorch.strategies.parallel.ParallelStrategy`. This is useful for + example when accumulating gradients to reduce communication when it is not needed. Args: strategy: the strategy instance to use. @@ -98,6 +100,7 @@ def _block_parallel_sync_behavior(strategy: Strategy, block: bool = True) -> Gen Returns: context manager with sync behaviour off + """ if isinstance(strategy, ParallelStrategy) and block: with strategy.block_backward_sync(): @@ -115,6 +118,7 @@ def _is_max_limit_reached(current: int, maximum: int = -1) -> bool: Returns: bool: whether the limit has been reached + """ return maximum != -1 and current >= maximum diff --git a/src/lightning/pytorch/overrides/distributed.py b/src/lightning/pytorch/overrides/distributed.py index 1480163dc5..9b86b5db30 100644 --- a/src/lightning/pytorch/overrides/distributed.py +++ b/src/lightning/pytorch/overrides/distributed.py @@ -142,6 +142,7 @@ def _register_ddp_comm_hook( ddp_comm_hook=powerSGD.powerSGD_hook, ddp_comm_wrapper=default.fp16_compress_wrapper, ) + """ if ddp_comm_hook is None: return @@ -191,15 +192,16 @@ def _sync_module_states(module: torch.nn.Module) -> None: class UnrepeatedDistributedSampler(DistributedSampler): - """A fork of the PyTorch DistributedSampler that doesn't repeat data, instead allowing the number of batches - per process to be off-by-one from each other. This makes this sampler usable for predictions (it's - deterministic and doesn't require shuffling). It is potentially unsafe to use this sampler for training, - because during training the DistributedDataParallel syncs buffers on each forward pass, so it could freeze if - one of the processes runs one fewer batch. During prediction, buffers are only synced on the first batch, so - this is safe to use as long as each process runs at least one batch. We verify this in an assert. + """A fork of the PyTorch DistributedSampler that doesn't repeat data, instead allowing the number of batches per + process to be off-by-one from each other. This makes this sampler usable for predictions (it's deterministic and + doesn't require shuffling). It is potentially unsafe to use this sampler for training, because during training the + DistributedDataParallel syncs buffers on each forward pass, so it could freeze if one of the processes runs one + fewer batch. During prediction, buffers are only synced on the first batch, so this is safe to use as long as each + process runs at least one batch. We verify this in an assert. 
Taken from https://github.com/jpuigcerver/PyLaia/blob/v1.0.0/laia/data/unpadded_distributed_sampler.py and https://github.com/pytorch/pytorch/issues/25162#issuecomment-634146002 + """ def __init__(self, *args: Any, **kwargs: Any) -> None: diff --git a/src/lightning/pytorch/plugins/layer_sync.py b/src/lightning/pytorch/plugins/layer_sync.py index e777eb04ea..faa1ab23f1 100644 --- a/src/lightning/pytorch/plugins/layer_sync.py +++ b/src/lightning/pytorch/plugins/layer_sync.py @@ -33,10 +33,10 @@ class LayerSync(ABC): class TorchSyncBatchNorm(LayerSync): - """A plugin that wraps all batch normalization layers of a model with synchronization logic for - multiprocessing. + """A plugin that wraps all batch normalization layers of a model with synchronization logic for multiprocessing. This plugin has no effect in single-device operation. + """ def apply(self, model: Module) -> Module: @@ -50,6 +50,7 @@ class TorchSyncBatchNorm(LayerSync): Return: LightningModule with batchnorm layers synchronized within the process groups. + """ return torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) @@ -61,6 +62,7 @@ class TorchSyncBatchNorm(LayerSync): Return: LightningModule with regular batchnorm layers that will no longer sync across processes. + """ # Code adapted from https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547 # Original author: Kapil Yedidi (@kapily) diff --git a/src/lightning/pytorch/plugins/precision/deepspeed.py b/src/lightning/pytorch/plugins/precision/deepspeed.py index 510b4c7d42..99e291b733 100644 --- a/src/lightning/pytorch/plugins/precision/deepspeed.py +++ b/src/lightning/pytorch/plugins/precision/deepspeed.py @@ -50,6 +50,7 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin): Raises: ValueError: If unsupported ``precision`` is provided. + """ def __init__(self, precision: _PRECISION_INPUT) -> None: @@ -104,6 +105,7 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin): optimizer: ignored for DeepSpeed \*args: additional positional arguments for the :meth:`deepspeed.DeepSpeedEngine.backward` call \**kwargs: additional keyword arguments for the :meth:`deepspeed.DeepSpeedEngine.backward` call + """ if is_overridden("backward", model): warning_cache.warn( diff --git a/src/lightning/pytorch/plugins/precision/double.py b/src/lightning/pytorch/plugins/precision/double.py index df2387ba80..a6cc31d9b6 100644 --- a/src/lightning/pytorch/plugins/precision/double.py +++ b/src/lightning/pytorch/plugins/precision/double.py @@ -39,6 +39,7 @@ class DoublePrecisionPlugin(PrecisionPlugin): """A context manager to change the default tensor type when initializing module parameters or tensors. See: :meth:`torch.set_default_dtype` + """ default_dtype = torch.get_default_dtype() torch.set_default_dtype(torch.float64) @@ -50,6 +51,7 @@ class DoublePrecisionPlugin(PrecisionPlugin): """A context manager to change the default tensor type. See: :meth:`torch.set_default_dtype` + """ default_dtype = torch.get_default_dtype() torch.set_default_dtype(torch.float64) diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py index 5054969899..befa2d9bd0 100644 --- a/src/lightning/pytorch/plugins/precision/fsdp.py +++ b/src/lightning/pytorch/plugins/precision/fsdp.py @@ -29,6 +29,7 @@ class FSDPMixedPrecisionPlugin(MixedPrecisionPlugin): """AMP for Fully Sharded Data Parallel (FSDP) Training. .. warning:: This is an :ref:`experimental ` feature. 
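The precision plugins touched above are normally selected indirectly through the Trainer's `precision` flag rather than instantiated by hand; a minimal sketch of the corresponding settings:

from lightning.pytorch import Trainer

# Double precision, 16-bit mixed precision, and bfloat16 mixed precision:
trainer = Trainer(precision="64-true")
trainer = Trainer(precision="16-mixed")
trainer = Trainer(precision="bf16-mixed")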
+ """ def __init__( @@ -79,6 +80,7 @@ class FSDPMixedPrecisionPlugin(MixedPrecisionPlugin): """A context manager to change the default tensor type when initializing module parameters or tensors. See: :meth:`torch.set_default_dtype` + """ default_dtype = torch.get_default_dtype() torch.set_default_dtype(self.mixed_precision_config.param_dtype) @@ -90,5 +92,6 @@ class FSDPMixedPrecisionPlugin(MixedPrecisionPlugin): """For FSDP, this context manager is a no-op since conversion is already handled internally. See: https://pytorch.org/docs/stable/fsdp.html for more details on mixed precision. + """ yield diff --git a/src/lightning/pytorch/plugins/precision/half.py b/src/lightning/pytorch/plugins/precision/half.py index dcafa3b33f..9e2ed0a6a5 100644 --- a/src/lightning/pytorch/plugins/precision/half.py +++ b/src/lightning/pytorch/plugins/precision/half.py @@ -28,6 +28,7 @@ class HalfPrecisionPlugin(PrecisionPlugin): Args: precision: Whether to use ``torch.float16`` (``'16-true'``) or ``torch.bfloat16`` (``'bf16-true'``). + """ precision: Literal["bf16-true", "16-true"] = "16-true" @@ -44,6 +45,7 @@ class HalfPrecisionPlugin(PrecisionPlugin): """A context manager to change the default tensor type when initializing module parameters or tensors. See: :meth:`torch.set_default_dtype` + """ default_dtype = torch.get_default_dtype() torch.set_default_dtype(self._desired_input_dtype) @@ -52,10 +54,10 @@ class HalfPrecisionPlugin(PrecisionPlugin): @contextmanager def forward_context(self) -> Generator[None, None, None]: - """A context manager to change the default tensor type when tensors get created during the module's - forward. + """A context manager to change the default tensor type when tensors get created during the module's forward. See: :meth:`torch.set_default_tensor_type` + """ default_dtype = torch.get_default_dtype() torch.set_default_dtype(self._desired_input_dtype) diff --git a/src/lightning/pytorch/plugins/precision/precision_plugin.py b/src/lightning/pytorch/plugins/precision/precision_plugin.py index 89fa734013..0c083f9427 100644 --- a/src/lightning/pytorch/plugins/precision/precision_plugin.py +++ b/src/lightning/pytorch/plugins/precision/precision_plugin.py @@ -32,6 +32,7 @@ class PrecisionPlugin(FabricPrecision, CheckpointHooks): """Base class for all plugins handling the precision-specific parts of the training. The class attribute precision must be overwritten in child classes. The default value reflects fp32 training. + """ def connect( @@ -63,6 +64,7 @@ class PrecisionPlugin(FabricPrecision, CheckpointHooks): \*args: Positional arguments intended for the actual function that performs the backward, like :meth:`~torch.Tensor.backward`. \**kwargs: Keyword arguments for the same purpose as ``*args``. + """ model.backward(tensor, *args, **kwargs) diff --git a/src/lightning/pytorch/profilers/advanced.py b/src/lightning/pytorch/profilers/advanced.py index cc1600af34..f758439646 100644 --- a/src/lightning/pytorch/profilers/advanced.py +++ b/src/lightning/pytorch/profilers/advanced.py @@ -25,10 +25,11 @@ log = logging.getLogger(__name__) class AdvancedProfiler(Profiler): - """This profiler uses Python's cProfiler to record more detailed information about time spent in each function - call recorded during a given action. + """This profiler uses Python's cProfiler to record more detailed information about time spent in each function call + recorded during a given action. The output is quite verbose and you should only use this if you want very detailed reports. 
+ """ def __init__( diff --git a/src/lightning/pytorch/profilers/base.py b/src/lightning/pytorch/profilers/base.py index 5bf3b0e6e8..be2c50a8d0 100644 --- a/src/lightning/pytorch/profilers/base.py +++ b/src/lightning/pytorch/profilers/base.py @@ -20,6 +20,7 @@ class PassThroughProfiler(Profiler): """This class should be used when you don't want the (small) overhead of profiling. The Trainer uses this class by default. + """ def start(self, action_name: str) -> None: diff --git a/src/lightning/pytorch/profilers/profiler.py b/src/lightning/pytorch/profilers/profiler.py index 5bc23251a8..d7f168b3d2 100644 --- a/src/lightning/pytorch/profilers/profiler.py +++ b/src/lightning/pytorch/profilers/profiler.py @@ -62,6 +62,7 @@ class Profiler(ABC): The profiler will start once you've entered the context and will automatically stop once you exit the code block. + """ try: self.start(action_name) @@ -134,6 +135,7 @@ class Profiler(ABC): """Execute arbitrary post-profiling tear-down steps. Closes the currently open file and stream. + """ self._write_stream = None if self._output_file is not None: diff --git a/src/lightning/pytorch/profilers/pytorch.py b/src/lightning/pytorch/profilers/pytorch.py index fe3ab1c189..0b486f1aa5 100644 --- a/src/lightning/pytorch/profilers/pytorch.py +++ b/src/lightning/pytorch/profilers/pytorch.py @@ -44,8 +44,7 @@ _PROFILER = Union[torch.profiler.profile, torch.autograd.profiler.profile, torch class RegisterRecordFunction: - """While profiling autograd operations, this class will add labels for module names around the forward - function. + """While profiling autograd operations, this class will add labels for module names around the forward function. The Lightning PyTorch Profiler will activate this feature automatically. It can be deactivated as follows: @@ -60,6 +59,7 @@ class RegisterRecordFunction: from lightning.pytorch import Trainer, seed_everything with RegisterRecordFunction(model): out = model(batch) + """ def __init__(self, model: nn.Module) -> None: @@ -288,6 +288,7 @@ class PyTorchProfiler(Profiler): If arg ``sort_by_key`` is not present in ``AVAILABLE_SORT_KEYS``. If arg ``schedule`` is not a ``Callable``. If arg ``schedule`` does not return a ``torch.profiler.ProfilerAction``. + """ super().__init__(dirpath=dirpath, filename=filename) diff --git a/src/lightning/pytorch/profilers/simple.py b/src/lightning/pytorch/profilers/simple.py index 3af44d4178..528290545e 100644 --- a/src/lightning/pytorch/profilers/simple.py +++ b/src/lightning/pytorch/profilers/simple.py @@ -32,8 +32,8 @@ _TABLE_DATA = List[_TABLE_ROW] class SimpleProfiler(Profiler): - """This profiler simply records the duration of actions (in seconds) and reports the mean duration of each - action and the total time spent over the entire training run.""" + """This profiler simply records the duration of actions (in seconds) and reports the mean duration of each action + and the total time spent over the entire training run.""" def __init__( self, diff --git a/src/lightning/pytorch/profilers/xla.py b/src/lightning/pytorch/profilers/xla.py index b6ebe70fd2..2d1db1d3e5 100644 --- a/src/lightning/pytorch/profilers/xla.py +++ b/src/lightning/pytorch/profilers/xla.py @@ -31,12 +31,13 @@ class XLAProfiler(Profiler): } def __init__(self, port: int = 9012) -> None: - """XLA Profiler will help you debug and optimize training workload performance for your models using Cloud - TPU performance tools. 
+ """XLA Profiler will help you debug and optimize training workload performance for your models using Cloud TPU + performance tools. Args: port: the port to start the profiler server on. An exception is raised if the provided port is invalid or busy. + """ if not _XLA_AVAILABLE: raise ModuleNotFoundError(str(_XLA_AVAILABLE)) diff --git a/src/lightning/pytorch/serve/servable_module.py b/src/lightning/pytorch/serve/servable_module.py index 33efa9956a..f715f4b3ca 100644 --- a/src/lightning/pytorch/serve/servable_module.py +++ b/src/lightning/pytorch/serve/servable_module.py @@ -52,6 +52,7 @@ class ServableModule(ABC, torch.nn.Module): ) trainer.fit(ServableBoringModel()) assert serve_cb.resp.json() == {"output": [0, 1]} + """ @abstractmethod @@ -67,6 +68,7 @@ class ServableModule(ABC, torch.nn.Module): The second dictionary contains the name of the ``serve_step`` output variables name as its keys and the associated serialization function (e.g function to convert a tensors into payload). + """ @abstractmethod @@ -84,6 +86,7 @@ class ServableModule(ABC, torch.nn.Module): Return: - ``dict`` - A dictionary with their associated tensors. + """ @abstractmethod diff --git a/src/lightning/pytorch/serve/servable_module_validator.py b/src/lightning/pytorch/serve/servable_module_validator.py index e6db99091e..9e669a0455 100644 --- a/src/lightning/pytorch/serve/servable_module_validator.py +++ b/src/lightning/pytorch/serve/servable_module_validator.py @@ -36,6 +36,7 @@ class ServableModuleValidator(Callback): port: The port associated with the server. timeout: Timeout period in seconds, that the process should wait for the server to start. exit_on_failure: Whether to exit the process on failure. + """ def __init__( diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py index fd1f72ab34..55708768a6 100644 --- a/src/lightning/pytorch/strategies/ddp.py +++ b/src/lightning/pytorch/strategies/ddp.py @@ -258,6 +258,7 @@ class DDPStrategy(ParallelStrategy): closure: closure calculating the loss value model: reference to the model, optionally defining optimizer step related hooks **kwargs: Any extra arguments to ``optimizer.step`` + """ optimizer_output = super().optimizer_step(optimizer, closure, model, **kwargs) @@ -323,6 +324,7 @@ class DDPStrategy(ParallelStrategy): Return: reduced value, except when the input was not a tensor the output remains is unchanged + """ if isinstance(tensor, Tensor): return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op) diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py index 82e9b108fe..8b9da45cd1 100644 --- a/src/lightning/pytorch/strategies/deepspeed.py +++ b/src/lightning/pytorch/strategies/deepspeed.py @@ -246,6 +246,7 @@ class DeepSpeedStrategy(DDPStrategy): load_full_weights: True when loading a single checkpoint file containing the model state dict when using ZeRO Stage 3. This differs from the DeepSpeed checkpoint which contains shards per worker. + """ if not _DEEPSPEED_AVAILABLE: raise MisconfigurationException( @@ -389,6 +390,7 @@ class DeepSpeedStrategy(DDPStrategy): Return: The model wrapped into a :class:`deepspeed.DeepSpeedEngine` and a list with a single deepspeed optimizer. + """ if len(optimizers) != 1: raise ValueError( @@ -414,6 +416,7 @@ class DeepSpeedStrategy(DDPStrategy): """Initialize one model and one optimizer with an optional learning rate scheduler. This calls :func:`deepspeed.initialize` internally. 
+ """ import deepspeed @@ -577,6 +580,7 @@ class DeepSpeedStrategy(DDPStrategy): Args: trainer: the Trainer, these optimizers should be connected to + """ if trainer.state.fn != TrainerFn.FITTING: return @@ -739,6 +743,7 @@ class DeepSpeedStrategy(DDPStrategy): Raises: TypeError: If ``storage_options`` arg is passed in + """ # broadcast the filepath from rank 0 to ensure all the states are saved in a common filepath filepath = self.broadcast(filepath) @@ -808,12 +813,13 @@ class DeepSpeedStrategy(DDPStrategy): self._restore_zero_state(checkpoint) def _restore_zero_state(self, ckpt: Mapping[str, Any]) -> None: - """Overrides the normal load_state_dict behaviour in PyTorch to ensure we gather parameters that may be - sharded across processes before loading the state dictionary when using ZeRO stage 3. This is then - automatically synced across processes. + """Overrides the normal load_state_dict behaviour in PyTorch to ensure we gather parameters that may be sharded + across processes before loading the state dictionary when using ZeRO stage 3. This is then automatically synced + across processes. Args: ckpt: The ckpt file. + """ import deepspeed diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py index 9555d105ee..c160daba6f 100644 --- a/src/lightning/pytorch/strategies/fsdp.py +++ b/src/lightning/pytorch/strategies/fsdp.py @@ -117,6 +117,7 @@ class FSDPStrategy(ParallelStrategy): Also accepts a :class:`torch.distributed.fsdp.ShardingStrategy` enum value. \**kwargs: See available parameters in :class:`torch.distributed.fsdp.FullyShardedDataParallel`. + """ strategy_name = "fsdp" @@ -172,6 +173,7 @@ class FSDPStrategy(ParallelStrategy): To avoid OOM, the returned parameters will only be returned on rank 0 and on CPU. All other ranks get an empty dict. + """ from torch.distributed.fsdp import FullyShardedDataParallel from torch.distributed.fsdp.api import FullStateDictConfig, StateDictType @@ -386,6 +388,7 @@ class FSDPStrategy(ParallelStrategy): Return: reduced value, except when the input was not a tensor the output remains is unchanged + """ if isinstance(tensor, Tensor): return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op) diff --git a/src/lightning/pytorch/strategies/launchers/multiprocessing.py b/src/lightning/pytorch/strategies/launchers/multiprocessing.py index ebd30a19ee..6b087db5e4 100644 --- a/src/lightning/pytorch/strategies/launchers/multiprocessing.py +++ b/src/lightning/pytorch/strategies/launchers/multiprocessing.py @@ -60,6 +60,7 @@ class _MultiProcessingLauncher(_Launcher): - 'fork': Preferable for IPython/Jupyter environments where 'spawn' is not available. Not available on the Windows platform for example. - 'forkserver': Alternative implementation to 'fork'. + """ def __init__( @@ -93,6 +94,7 @@ class _MultiProcessingLauncher(_Launcher): trainer: Optional reference to the :class:`~lightning.pytorch.trainer.trainer.Trainer` for which a selected set of attributes get restored in the main process after processes join. **kwargs: Optional keyword arguments to be passed to the given function. + """ if self._start_method in ("fork", "forkserver"): _check_bad_cuda_fork() @@ -198,8 +200,8 @@ class _MultiProcessingLauncher(_Launcher): return _WorkerOutput(best_model_path, weights_path, trainer.state, results, extra) def get_extra_results(self, trainer: "pl.Trainer") -> Dict[str, Any]: - """Gather extra state from the Trainer and return it as a dictionary for sending back to the main process. 
- To avoid issues with memory sharing, we cast the data to numpy. + """Gather extra state from the Trainer and return it as a dictionary for sending back to the main process. To + avoid issues with memory sharing, we cast the data to numpy. Args: trainer: reference to the Trainer. @@ -207,6 +209,7 @@ class _MultiProcessingLauncher(_Launcher): Returns: A dictionary with items to send back to the main process where :meth:`update_main_process_results` will process this output. + """ callback_metrics: dict = apply_to_collection( trainer.callback_metrics, Tensor, lambda x: x.cpu().numpy() @@ -263,6 +266,7 @@ class _GlobalStateSnapshot: # in worker process snapshot.restore() + """ use_deterministic_algorithms: bool @@ -272,8 +276,7 @@ class _GlobalStateSnapshot: @classmethod def capture(cls) -> "_GlobalStateSnapshot": - """Capture a few global states from torch, numpy, etc., that we want to restore in a spawned worker - process.""" + """Capture a few global states from torch, numpy, etc., that we want to restore in a spawned worker process.""" return cls( use_deterministic_algorithms=torch.are_deterministic_algorithms_enabled(), use_deterministic_algorithms_warn_only=torch.is_deterministic_algorithms_warn_only_enabled(), diff --git a/src/lightning/pytorch/strategies/launchers/subprocess_script.py b/src/lightning/pytorch/strategies/launchers/subprocess_script.py index 0a0170f2eb..5afdcbec2f 100644 --- a/src/lightning/pytorch/strategies/launchers/subprocess_script.py +++ b/src/lightning/pytorch/strategies/launchers/subprocess_script.py @@ -67,6 +67,7 @@ class _SubprocessScriptLauncher(_Launcher): cluster_environment: A cluster environment that provides access to world size, node rank, etc. num_processes: The number of processes to launch in the current node. num_nodes: The total number of nodes that participate in this process group. + """ def __init__(self, cluster_environment: ClusterEnvironment, num_processes: int, num_nodes: int) -> None: @@ -89,6 +90,7 @@ class _SubprocessScriptLauncher(_Launcher): *args: Optional positional arguments to be passed to the given function. trainer: Optional reference to the :class:`~lightning.pytorch.trainer.trainer.Trainer`. **kwargs: Optional keyword arguments to be passed to the given function. + """ if not self.cluster_environment.creates_processes_externally: self._call_children_scripts() diff --git a/src/lightning/pytorch/strategies/launchers/xla.py b/src/lightning/pytorch/strategies/launchers/xla.py index 961bc9bbb6..032ec26150 100644 --- a/src/lightning/pytorch/strategies/launchers/xla.py +++ b/src/lightning/pytorch/strategies/launchers/xla.py @@ -31,8 +31,8 @@ from lightning.pytorch.utilities.rank_zero import rank_zero_debug class _XLALauncher(_MultiProcessingLauncher): - r"""Launches processes that run a given function in parallel on XLA supported hardware, and joins them all at - the end. + r"""Launches processes that run a given function in parallel on XLA supported hardware, and joins them all at the + end. The main process in which this launcher is invoked creates N so-called worker processes (using the `torch_xla` :func:`xmp.spawn`) that run the given function. 
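The launcher described above is created internally by the XLA strategy; from the user side it is reached through the accelerator flags. A sketch, assuming a TPU environment with torch_xla installed:

from lightning.pytorch import Trainer

# Selecting the TPU accelerator makes the XLAStrategy (and this launcher)
# spawn one worker process per TPU core.
trainer = Trainer(accelerator="tpu", devices=8)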
@@ -44,6 +44,7 @@ class _XLALauncher(_MultiProcessingLauncher): Args: strategy: A reference to the strategy that is used together with this launcher + """ def __init__(self, strategy: "pl.strategies.XLAStrategy") -> None: @@ -67,6 +68,7 @@ class _XLALauncher(_MultiProcessingLauncher): trainer: Optional reference to the :class:`~lightning.pytorch.trainer.trainer.Trainer` for which a selected set of attributes get restored in the main process after processes join. **kwargs: Optional keyword arguments to be passed to the given function. + """ using_pjrt = _using_pjrt() # pjrt requires that the queue is serializable diff --git a/src/lightning/pytorch/strategies/parallel.py b/src/lightning/pytorch/strategies/parallel.py index 30439cdbd4..33dcd4be0b 100644 --- a/src/lightning/pytorch/strategies/parallel.py +++ b/src/lightning/pytorch/strategies/parallel.py @@ -112,6 +112,7 @@ class ParallelStrategy(Strategy, ABC): This is useful for skipping sync when accumulating gradients, reducing communication overhead Returns: context manager with sync behaviour off + """ if isinstance(self.model, pl.utilities.types.DistributedDataParallel): with self.model.no_sync(): diff --git a/src/lightning/pytorch/strategies/single_device.py b/src/lightning/pytorch/strategies/single_device.py index 8083cccec8..a9809abe7c 100644 --- a/src/lightning/pytorch/strategies/single_device.py +++ b/src/lightning/pytorch/strategies/single_device.py @@ -57,6 +57,7 @@ class SingleDeviceStrategy(Strategy): Return: the unmodified input as reduction is not needed for single process operation + """ return tensor diff --git a/src/lightning/pytorch/strategies/strategy.py b/src/lightning/pytorch/strategies/strategy.py index 1b575e0d6b..6fb5636c2f 100644 --- a/src/lightning/pytorch/strategies/strategy.py +++ b/src/lightning/pytorch/strategies/strategy.py @@ -120,6 +120,7 @@ class Strategy(ABC): This is called before the LightningModule/DataModule setup hook which allows the user to access the accelerator environment before setup is complete. + """ assert self.accelerator is not None self.accelerator.setup_device(self.root_device) @@ -129,6 +130,7 @@ class Strategy(ABC): Args: trainer: the Trainer, these optimizers should be connected to + """ if trainer.state.fn != TrainerFn.FITTING: return @@ -140,6 +142,7 @@ class Strategy(ABC): Args: trainer: the trainer instance + """ assert self.accelerator is not None self.accelerator.setup(trainer) @@ -161,6 +164,7 @@ class Strategy(ABC): """Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom strategies. + """ if isinstance(optimizer, LightningOptimizer): optimizer = optimizer._optimizer @@ -189,6 +193,7 @@ class Strategy(ABC): \*args: Positional arguments that get passed down to the precision plugin's backward, intended as arguments for the actual function that performs the backward, like :meth:`~torch.Tensor.backward`. \**kwargs: Keyword arguments for the same purpose as ``*args``. + """ self.pre_backward(closure_loss) assert self.lightning_module is not None @@ -215,6 +220,7 @@ class Strategy(ABC): closure: closure calculating the loss value model: reference to the model, optionally defining optimizer step related hooks \**kwargs: Keyword arguments to ``optimizer.step`` + """ model = model or self.lightning_module # TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed @@ -226,6 +232,7 @@ class Strategy(ABC): The returned objects are expected to be in the same order they were passed in. 
The default implementation will call :meth:`_setup_model` and :meth:`_setup_optimizer` on the inputs. + """ # TODO: standardize this across all plugins in Lightning and Fabric. Related refactor: #7324 model = self._setup_model(model) @@ -252,6 +259,7 @@ class Strategy(ABC): batch: The batch of samples to move to the correct device device: The target device dataloader_idx: The index of the dataloader to which the batch belongs. + """ model = self.lightning_module device = device or self.root_device @@ -287,6 +295,7 @@ class Strategy(ABC): group: the process group to reduce reduce_op: the reduction operation. Defaults to 'mean'. Can also be a string 'sum' or ReduceOp. + """ @abstractmethod @@ -295,6 +304,7 @@ class Strategy(ABC): Args: name: an optional name to pass into barrier. + """ @abstractmethod @@ -304,6 +314,7 @@ class Strategy(ABC): Args: obj: the object to broadcast src: source rank + """ @abstractmethod @@ -314,6 +325,7 @@ class Strategy(ABC): tensor: the tensor to all_gather group: the process group to gather results from sync_grads: flag that allows users to synchronize gradients for all_gather op + """ def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool: @@ -358,6 +370,7 @@ class Strategy(ABC): """The actual training step. See :meth:`~lightning.pytorch.core.module.LightningModule.training_step` for more details + """ args, kwargs = self.precision_plugin.convert_input((args, kwargs)) assert self.lightning_module is not None @@ -371,6 +384,7 @@ class Strategy(ABC): """This hook is deprecated. Override :meth:`training_step` instead. + """ pass @@ -378,6 +392,7 @@ class Strategy(ABC): """The actual validation step. See :meth:`~lightning.pytorch.core.module.LightningModule.validation_step` for more details + """ args, kwargs = self.precision_plugin.convert_input((args, kwargs)) assert self.lightning_module is not None @@ -391,6 +406,7 @@ class Strategy(ABC): """The actual test step. See :meth:`~lightning.pytorch.core.module.LightningModule.test_step` for more details + """ args, kwargs = self.precision_plugin.convert_input((args, kwargs)) assert self.lightning_module is not None @@ -404,6 +420,7 @@ class Strategy(ABC): """The actual predict step. See :meth:`~lightning.pytorch.core.module.LightningModule.predict_step` for more details + """ args, kwargs = self.precision_plugin.convert_input((args, kwargs)) assert self.lightning_module is not None @@ -418,16 +435,18 @@ class Strategy(ABC): Args: dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader` + """ return dataloader @property def restore_checkpoint_after_setup(self) -> bool: - """Override to delay restoring from checkpoint till after the setup phase has completed. This is useful - when the strategy requires all the setup hooks to run before loading checkpoint. + """Override to delay restoring from checkpoint till after the setup phase has completed. This is useful when + the strategy requires all the setup hooks to run before loading checkpoint. Returns: If ``True``, restore checkpoint after strategy setup. + """ return False @@ -436,6 +455,7 @@ class Strategy(ABC): """Override to disable Lightning restoring optimizers/schedulers. This is useful for strategies which manage restoring optimizers/schedulers. 
+ """ return True @@ -458,6 +478,7 @@ class Strategy(ABC): checkpoint: dict containing model and trainer state filepath: write-target file's path storage_options: parameter for how to save to storage, passed to ``CheckpointIO`` plugin + """ if self.is_global_zero: self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options) @@ -467,6 +488,7 @@ class Strategy(ABC): Args: filepath: Path to checkpoint + """ if self.is_global_zero: self.checkpoint_io.remove_checkpoint(filepath) @@ -478,6 +500,7 @@ class Strategy(ABC): Args: empty_init: Whether to initialize the model with empty weights (uninitialized memory). If ``None``, the strategy will decide. Some strategies may not support all options. + """ device_context = self.root_device if _TORCH_GREATER_EQUAL_2_0 else nullcontext() empty_init_context = _EmptyInit(enabled=bool(empty_init)) if _TORCH_GREATER_EQUAL_1_13 else nullcontext() @@ -486,11 +509,11 @@ class Strategy(ABC): @contextmanager def model_sharded_context(self) -> Generator[None, None, None]: - """Provide hook to create modules in a distributed aware context. This is useful for when we'd like to - shard the model instantly, which is useful for extremely large models which can save memory and - initialization time. + """Provide hook to create modules in a distributed aware context. This is useful for when we'd like to shard + the model instantly, which is useful for extremely large models which can save memory and initialization time. Returns: Model parallel context. + """ yield @@ -498,6 +521,7 @@ class Strategy(ABC): """This method is called to teardown the training process. It is the right place to release memory and free other resources. + """ _optimizers_to_device(self.optimizers, torch.device("cpu")) @@ -568,6 +592,7 @@ class _ForwardRedirection: """Implements the `forward-redirection`. A method call to a wrapped module gets rerouted through the wrapper's `forward` method instead. + """ def __call__( @@ -584,6 +609,7 @@ class _ForwardRedirection: `forward` method instead. **kwargs: The keyword arguments to the method `method_name`. They will get passed to a patched `forward` method instead. + """ assert method_name != "forward" original_forward = original_module.forward diff --git a/src/lightning/pytorch/strategies/xla.py b/src/lightning/pytorch/strategies/xla.py index 99ce29b7c6..39458a8cd7 100644 --- a/src/lightning/pytorch/strategies/xla.py +++ b/src/lightning/pytorch/strategies/xla.py @@ -266,6 +266,7 @@ class XLAStrategy(DDPStrategy): Args: filepath: Path to checkpoint + """ if self.local_rank == 0: self.checkpoint_io.remove_checkpoint(filepath) @@ -279,6 +280,7 @@ class XLAStrategy(DDPStrategy): sync_grads: flag that allows users to synchronize gradients for the all-gather operation. Return: A tensor of shape (world_size, ...) + """ if not self._launched: return tensor diff --git a/src/lightning/pytorch/trainer/call.py b/src/lightning/pytorch/trainer/call.py index 1647794f23..2eab1bac09 100644 --- a/src/lightning/pytorch/trainer/call.py +++ b/src/lightning/pytorch/trainer/call.py @@ -29,13 +29,14 @@ log = logging.getLogger(__name__) def _call_and_handle_interrupt(trainer: "pl.Trainer", trainer_fn: Callable, *args: Any, **kwargs: Any) -> Any: - r"""Error handling, intended to be used only for main trainer function entry points (fit, validate, test, - predict) as all errors should funnel through them. 
+ r"""Error handling, intended to be used only for main trainer function entry points (fit, validate, test, predict) + as all errors should funnel through them. Args: trainer_fn: one of (fit, validate, test, predict) *args: positional arguments to be passed to the `trainer_fn` **kwargs: keyword arguments to be passed to `trainer_fn` + """ try: if trainer.strategy.launcher is not None: @@ -243,6 +244,7 @@ def _call_callbacks_on_load_checkpoint(trainer: "pl.Trainer", checkpoint: Dict[s Calls every callback's `on_load_checkpoint` hook. We have a dedicated function for this rather than using `_call_callback_hooks` because we have special logic for getting callback_states. + """ pl_module = trainer.lightning_module if pl_module: diff --git a/src/lightning/pytorch/trainer/configuration_validator.py b/src/lightning/pytorch/trainer/configuration_validator.py index bebe781f2e..4a9c9c45c4 100644 --- a/src/lightning/pytorch/trainer/configuration_validator.py +++ b/src/lightning/pytorch/trainer/configuration_validator.py @@ -27,6 +27,7 @@ def _verify_loop_configurations(trainer: "pl.Trainer") -> None: Args: trainer: Lightning Trainer. Its `lightning_module` (the model) to check the configuration. + """ model = trainer.lightning_module diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py index 0e52a4b2fb..2bd55e0326 100644 --- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py +++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py @@ -114,6 +114,7 @@ class _AcceleratorConnector: priorities which to take when: A. Class > str B. Strategy > Accelerator/precision/plugins + """ self.use_distributed_sampler = use_distributed_sampler _set_torch_flags(deterministic=deterministic, benchmark=benchmark) @@ -187,6 +188,7 @@ class _AcceleratorConnector: 4. plugins: The list of plugins may contain a Precision plugin, CheckpointIO, ClusterEnvironment and others. Additionally, other flags such as `precision` or `sync_batchnorm` can populate the list with the corresponding plugin instances. + """ if plugins is not None: plugins = [plugins] if not isinstance(plugins, list) else plugins @@ -456,8 +458,8 @@ class _AcceleratorConnector: return "ddp" def _check_strategy_and_fallback(self) -> None: - """Checks edge cases when the strategy selection was a string input, and we need to fall back to a - different choice depending on other parameters or the environment.""" + """Checks edge cases when the strategy selection was a string input, and we need to fall back to a different + choice depending on other parameters or the environment.""" # current fallback and check logic only apply to user pass in str config and object config # TODO this logic should apply to both str and object config strategy_flag = "" if isinstance(self._strategy_flag, Strategy) else self._strategy_flag diff --git a/src/lightning/pytorch/trainer/connectors/callback_connector.py b/src/lightning/pytorch/trainer/connectors/callback_connector.py index d649755172..bc4ae50164 100644 --- a/src/lightning/pytorch/trainer/connectors/callback_connector.py +++ b/src/lightning/pytorch/trainer/connectors/callback_connector.py @@ -161,6 +161,7 @@ class _CallbackConnector: callbacks already present in the trainer callbacks list, it will replace them. In addition, all :class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint` callbacks will be pushed to the end of the list, ensuring they run last. 
+ """ trainer = self.trainer @@ -187,9 +188,9 @@ class _CallbackConnector: @staticmethod def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]: - """Moves all the tuner specific callbacks at the beginning of the list and all the `ModelCheckpoint` - callbacks to the end of the list. The sequential order within the group of checkpoint callbacks is - preserved, as well as the order of all other callbacks. + """Moves all the tuner specific callbacks at the beginning of the list and all the `ModelCheckpoint` callbacks + to the end of the list. The sequential order within the group of checkpoint callbacks is preserved, as well as + the order of all other callbacks. Args: callbacks: A list of callbacks. @@ -197,6 +198,7 @@ class _CallbackConnector: Return: A new list in which the first elements are tuner specific callbacks and last elements are ModelCheckpoints if there were any present in the input. + """ tuner_callbacks: List[Callback] = [] other_callbacks: List[Callback] = [] diff --git a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py index ed54a9b4eb..7b66048c46 100644 --- a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py +++ b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py @@ -71,6 +71,7 @@ class _CheckpointConnector: 2. from fault-tolerant auto-saved checkpoint if found 3. from `checkpoint_path` file if provided 4. don't restore + """ self._ckpt_path = checkpoint_path if not checkpoint_path: @@ -209,8 +210,7 @@ class _CheckpointConnector: return ckpt_path def resume_end(self) -> None: - """Signal the connector that all states have resumed and memory for the checkpoint object can be - released.""" + """Signal the connector that all states have resumed and memory for the checkpoint object can be released.""" assert self.trainer.state.fn is not None if self._ckpt_path: message = "Restored all states" if self.trainer.state.fn == TrainerFn.FITTING else "Loaded model weights" @@ -235,6 +235,7 @@ class _CheckpointConnector: Args: checkpoint_path: Path to a PyTorch Lightning checkpoint file. + """ self.resume_start(checkpoint_path) @@ -266,6 +267,7 @@ class _CheckpointConnector: Hooks are called first to give the LightningModule a chance to modify the contents, then finally the model gets updated with the loaded weights. + """ if not self._loaded_checkpoint: return @@ -281,6 +283,7 @@ class _CheckpointConnector: """Restore the trainer state from the pre-loaded checkpoint. This includes the precision settings, loop progress, optimizer states and learning rate scheduler states. + """ if not self._loaded_checkpoint: return @@ -320,6 +323,7 @@ class _CheckpointConnector: """Restores the loop progress from the pre-loaded checkpoint. Calls hooks on the loops to give it a chance to restore its state from the checkpoint. 
+ """ if not self._loaded_checkpoint: return @@ -420,6 +424,7 @@ class _CheckpointConnector: something_cool_i_want_to_save: anything you define through model.on_save_checkpoint LightningDataModule.__class__.__qualname__: pl DataModule's state } + """ trainer = self.trainer model = trainer.lightning_module @@ -507,6 +512,7 @@ class _CheckpointConnector: name_key: file name prefix Returns: None if no-corresponding-file else maximum suffix number + """ # check directory existence fs, uri = url_to_fs(str(dir_path)) diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py index 0b8726cbd1..424cfe7cf6 100644 --- a/src/lightning/pytorch/trainer/connectors/data_connector.py +++ b/src/lightning/pytorch/trainer/connectors/data_connector.py @@ -185,6 +185,7 @@ class _DataConnector: - Injecting a `DistributedDataSamplerWrapper` into the `DataLoader` if on a distributed environment - Wrapping the dataloader based on strategy-specific logic + """ # don't do anything if it's not a dataloader if not isinstance(dataloader, DataLoader): @@ -289,6 +290,7 @@ class _DataLoaderSource: instance: A LightningModule, LightningDataModule, or (a collection of) iterable(s). name: A name for this dataloader source. If the instance is a module, the name corresponds to the hook that returns the desired dataloader(s). + """ instance: Optional[Union[TRAIN_DATALOADERS, EVAL_DATALOADERS, "pl.LightningModule", "pl.LightningDataModule"]] @@ -298,6 +300,7 @@ class _DataLoaderSource: """Returns the dataloader from the source. If the source is a module, the method with the corresponding :attr:`name` gets called. + """ if isinstance(self.instance, pl.LightningModule): return call._call_lightning_module_hook(self.instance.trainer, self.name, pl_module=self.instance) @@ -311,6 +314,7 @@ class _DataLoaderSource: """Returns whether the source dataloader can be retrieved or not. If the source is a module it checks that the method with given :attr:`name` is overridden. + """ return not self.is_module() or is_overridden(self.name, self.instance) @@ -318,6 +322,7 @@ class _DataLoaderSource: """Returns whether the DataLoader source is a LightningModule or a LightningDataModule. It does not check whether ``*_dataloader`` methods are actually overridden. + """ return isinstance(self.instance, (pl.LightningModule, pl.LightningDataModule)) @@ -327,6 +332,7 @@ def _request_dataloader(data_source: _DataLoaderSource) -> Union[TRAIN_DATALOADE Returns: The requested dataloader + """ with _replace_dunder_methods(DataLoader, "dataset"), _replace_dunder_methods(BatchSampler): # under this context manager, the arguments passed to `DataLoader.__init__` will be captured and saved as diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py b/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py index 773257f119..61dd62cf9a 100644 --- a/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py +++ b/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py @@ -86,6 +86,7 @@ class _LoggerConnector: metrics: Metric values step: Step for which metrics should be logged. Default value is `self.global_step` during training or the total validation / test log step count during validation and testing. 
+ """ if not self.trainer.loggers or not metrics: return diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py index 06ce4d021d..c74f34dbe1 100644 --- a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py +++ b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py @@ -403,6 +403,7 @@ class _ResultCollection(dict): """Create one _ResultMetric object per value. Value can be provided as a nested collection + """ metric = _ResultMetric(meta, isinstance(value, Tensor)).to(value.device) self[key] = metric @@ -493,6 +494,7 @@ class _ResultCollection(dict): if False, only ``torch.Tensors`` are reset, if ``None``, both are. fx: Function to reset + """ for item in self.values(): requested_type = metrics is None or metrics ^ item.is_tensor diff --git a/src/lightning/pytorch/trainer/connectors/signal_connector.py b/src/lightning/pytorch/trainer/connectors/signal_connector.py index 7e6b7cd0c5..a5a0a5b693 100644 --- a/src/lightning/pytorch/trainer/connectors/signal_connector.py +++ b/src/lightning/pytorch/trainer/connectors/signal_connector.py @@ -139,6 +139,7 @@ class _SignalConnector: Behaves identically to :func:`signals.valid_signals` in Python 3.8+ and implements the equivalent behavior for older Python versions. + """ if _PYTHON_GREATER_EQUAL_3_8_0: return signal.valid_signals() diff --git a/src/lightning/pytorch/trainer/states.py b/src/lightning/pytorch/trainer/states.py index 73b7cb71dc..d386538f41 100644 --- a/src/lightning/pytorch/trainer/states.py +++ b/src/lightning/pytorch/trainer/states.py @@ -53,6 +53,7 @@ class RunningStage(LightningEnum): - ``TrainerFn.VALIDATING`` - ``RunningStage.VALIDATING`` - ``TrainerFn.TESTING`` - ``RunningStage.TESTING`` - ``TrainerFn.PREDICTING`` - ``RunningStage.PREDICTING`` + """ TRAINING = "train" diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index ba668009a5..1dfc680c85 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -301,6 +301,7 @@ class Trainer: MisconfigurationException: If ``gradient_clip_algorithm`` is invalid. If ``track_grad_norm`` is not a positive number or inf. + """ super().__init__() log.debug(f"{self.__class__.__name__}: Initializing trainer with parameters: {locals()}") @@ -533,6 +534,7 @@ class Trainer: :class:`torch._dynamo.OptimizedModule` for torch versions greater than or equal to 2.0.0 . For more information about multiple dataloaders, see this :ref:`section `. + """ model = _maybe_unwrap_optimized(model) self.strategy._lightning_module = model @@ -626,6 +628,7 @@ class Trainer: RuntimeError: If a compiled ``model`` is passed and the strategy is not supported. + """ if model is None: # do we still have a reference from a previous call? @@ -697,8 +700,8 @@ class Trainer: verbose: bool = True, datamodule: Optional[LightningDataModule] = None, ) -> _EVALUATE_OUTPUT: - r"""Perform one evaluation epoch over the test set. It's separated from fit to make sure you never run on - your test set until you want to. + r"""Perform one evaluation epoch over the test set. It's separated from fit to make sure you never run on your + test set until you want to. Args: model: The model to test. @@ -734,6 +737,7 @@ class Trainer: RuntimeError: If a compiled ``model`` is passed and the strategy is not supported. + """ if model is None: # do we still have a reference from a previous call? 
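The Trainer entry points whose summaries are re-wrapped above (fit, validate, test, predict) all follow the same calling pattern, so a minimal end-to-end sketch may help when reviewing these hunks. Everything in it (the module, the random dataset, the Trainer flags) is a placeholder chosen for illustration rather than something taken from this diff.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    import lightning.pytorch as pl


    class TinyModule(pl.LightningModule):
        # Deliberately minimal; only meant to exercise fit() and test().
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.cross_entropy(self.layer(x), y)

        def test_step(self, batch, batch_idx):
            x, y = batch
            self.log("test_loss", torch.nn.functional.cross_entropy(self.layer(x), y))

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)


    def make_loader() -> DataLoader:
        dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
        return DataLoader(dataset, batch_size=16)


    model = TinyModule()
    trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
    trainer.fit(model, train_dataloaders=make_loader())
    # test() runs a single evaluation epoch; it is kept separate from fit() so the
    # test set is only evaluated when explicitly requested.
    trainer.test(model, dataloaders=make_loader())
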
@@ -843,6 +847,7 @@ class Trainer: If a compiled ``model`` is passed and the strategy is not supported. See :ref:`Lightning inference section` for more. + """ if model is None: # do we still have a reference from a previous call? @@ -1003,8 +1008,8 @@ class Trainer: return results def _teardown(self) -> None: - """This is the Trainer's internal teardown, unrelated to the `teardown` hooks in LightningModule and - Callback; those are handled by :meth:`_call_teardown_hook`.""" + """This is the Trainer's internal teardown, unrelated to the `teardown` hooks in LightningModule and Callback; + those are handled by :meth:`_call_teardown_hook`.""" self.strategy.teardown() loop = self._active_loop # loop should never be `None` here but it can because we don't know the trainer stage with `ddp_spawn` @@ -1075,8 +1080,8 @@ class Trainer: @contextmanager def init_module(self, empty_init: Optional[bool] = None) -> Generator: - """Tensors that you instantiate under this context manager will be created on the device right away and - have the right data type depending on the precision setting in the Trainer. + """Tensors that you instantiate under this context manager will be created on the device right away and have + the right data type depending on the precision setting in the Trainer. The parameters and tensors get created on the device and with the right data type right away without wasting memory being allocated unnecessarily. The automatic device placement under this context manager is only @@ -1086,6 +1091,7 @@ class Trainer: empty_init: Whether to initialize the model with empty weights (uninitialized memory). If ``None``, the strategy will decide. Some strategies may not support all options. Set this to ``True`` if you are loading a checkpoint into a large model. Requires `torch >= 1.13`. + """ if not _TORCH_GREATER_EQUAL_2_0 and self.strategy.root_device.type != "cpu": rank_zero_warn( @@ -1113,6 +1119,7 @@ class Trainer: process in each machine. Arguments passed to this method are forwarded to the Python built-in :func:`print` function. + """ if self.local_rank == 0: print(*args, **kwargs) @@ -1210,6 +1217,7 @@ class Trainer: To access the pure LightningModule, use :meth:`~lightning.pytorch.trainer.trainer.Trainer.lightning_module` instead. + """ return self.strategy.model @@ -1228,6 +1236,7 @@ class Trainer: def training_step(self, batch, batch_idx): img = ... save_img(img, self.trainer.log_dir) + """ if len(self.loggers) > 0: if not isinstance(self.loggers[0], TensorBoardLogger): @@ -1249,6 +1258,7 @@ class Trainer: def training_step(self, batch, batch_idx): if self.trainer.is_global_zero: print("in node 0, accelerator 0") + """ return self.strategy.is_global_zero @@ -1272,6 +1282,7 @@ class Trainer: """The default location to save artifacts of loggers, checkpoints etc. It is used as a fallback if logger or checkpoint callback do not define specific save paths. 
+ """ if get_filesystem(self._default_root_dir).protocol == "file": return os.path.normpath(self._default_root_dir) @@ -1286,8 +1297,8 @@ class Trainer: @property def early_stopping_callbacks(self) -> List[EarlyStopping]: - """A list of all instances of :class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping` found in - the Trainer.callbacks list.""" + """A list of all instances of :class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping` found in the + Trainer.callbacks list.""" return [c for c in self.callbacks if isinstance(c, EarlyStopping)] @property @@ -1299,8 +1310,8 @@ class Trainer: @property def checkpoint_callbacks(self) -> List[Checkpoint]: - """A list of all instances of :class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint` found - in the Trainer.callbacks list.""" + """A list of all instances of :class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint` found in + the Trainer.callbacks list.""" return [c for c in self.callbacks if isinstance(c, Checkpoint)] @property @@ -1334,6 +1345,7 @@ class Trainer: # you will be in charge of resetting this trainer.ckpt_path = None trainer.test(model) + """ self._checkpoint_connector._ckpt_path = ckpt_path self._checkpoint_connector._user_managed = bool(ckpt_path) @@ -1351,6 +1363,7 @@ class Trainer: Raises: AttributeError: If the model is not attached to the Trainer before calling this method. + """ if self.model is None: raise AttributeError( @@ -1422,6 +1435,7 @@ class Trainer: """Whether sanity checking is running. Useful to disable some hooks, logging or callbacks during the sanity checking. + """ return self.state.stage == RunningStage.SANITY_CHECKING @@ -1449,6 +1463,7 @@ class Trainer: """The number of optimizer steps taken (does not reset each epoch). This includes multiple optimizers (if enabled). + """ return self.fit_loop.epoch_loop.global_step @@ -1588,6 +1603,7 @@ class Trainer: for logger in trainer.loggers: logger.log_metrics({"foo": 1.0}) + """ return self._loggers @@ -1607,6 +1623,7 @@ class Trainer: callback_metrics = trainer.callback_metrics assert callback_metrics["a_val"] == 2.0 + """ return self._logger_connector.callback_metrics @@ -1616,6 +1633,7 @@ class Trainer: This includes metrics logged via :meth:`~lightning.pytorch.core.module.LightningModule.log` with the :paramref:`~lightning.pytorch.core.module.LightningModule.log.logger` argument set. + """ return self._logger_connector.logged_metrics @@ -1625,6 +1643,7 @@ class Trainer: This includes metrics logged via :meth:`~lightning.pytorch.core.module.LightningModule.log` with the :paramref:`~lightning.pytorch.core.module.LightningModule.log.prog_bar` argument set. + """ return self._logger_connector.progress_bar_metrics diff --git a/src/lightning/pytorch/tuner/batch_size_scaling.py b/src/lightning/pytorch/tuner/batch_size_scaling.py index 2b3ec7ef7f..e8ab5afbaa 100644 --- a/src/lightning/pytorch/tuner/batch_size_scaling.py +++ b/src/lightning/pytorch/tuner/batch_size_scaling.py @@ -57,6 +57,7 @@ def _scale_batch_size( - ``model`` - ``model.hparams`` - ``trainer.datamodule`` (the datamodule passed to the tune method) + """ if trainer.fast_dev_run: rank_zero_warn("Skipping batch size scaler since `fast_dev_run` is enabled.") @@ -212,10 +213,10 @@ def _run_binary_scaling( max_trials: int, params: Dict[str, Any], ) -> int: - """Batch scaling mode where the size is initially is doubled at each iteration until an OOM error is - encountered. 
+ """Batch scaling mode where the size is initially is doubled at each iteration until an OOM error is encountered. Hereafter, the batch size is further refined using a binary search + """ low = 1 high = None @@ -289,6 +290,7 @@ def _adjust_batch_size( Returns: The new batch size for the next trial and a bool that signals whether the new value is different than the previous batch size. + """ model = trainer.lightning_module batch_size = lightning_getattr(model, batch_arg_name) diff --git a/src/lightning/pytorch/tuner/lr_finder.py b/src/lightning/pytorch/tuner/lr_finder.py index f83c390d27..b7c3aba5f4 100644 --- a/src/lightning/pytorch/tuner/lr_finder.py +++ b/src/lightning/pytorch/tuner/lr_finder.py @@ -88,6 +88,7 @@ class _LRFinder: # Get suggestion lr = lr_finder.suggestion() + """ def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int) -> None: @@ -172,8 +173,8 @@ class _LRFinder: return fig def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]: - """This will propose a suggestion for an initial learning rate based on the point with the steepest - negative gradient. + """This will propose a suggestion for an initial learning rate based on the point with the steepest negative + gradient. Args: skip_begin: how many samples to skip in the beginning; helps to avoid too naive estimates @@ -182,6 +183,7 @@ class _LRFinder: Returns: The suggested initial learning rate to use, or `None` if a suggestion is not possible due to too few loss samples. + """ losses = torch.tensor(self.results["loss"][skip_begin:-skip_end]) losses = losses[torch.isfinite(losses)] @@ -215,8 +217,8 @@ def _lr_find( update_attr: bool = False, attr_name: str = "", ) -> Optional[_LRFinder]: - """Enables the user to do a range test of good initial learning rates, to reduce the amount of guesswork in - picking a good starting learning rate. + """Enables the user to do a range test of good initial learning rates, to reduce the amount of guesswork in picking + a good starting learning rate. Args: trainer: A Trainer instance. @@ -235,6 +237,7 @@ def _lr_find( update_attr: Whether to update the learning rate attribute or not. attr_name: Name of the attribute which stores the learning rate. The names 'learning_rate' or 'lr' get automatically detected. Otherwise, set the name here. + """ if trainer.fast_dev_run: rank_zero_warn("Skipping learning rate finder since `fast_dev_run` is enabled.") @@ -342,8 +345,8 @@ def __lr_finder_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) -> class _LRCallback(Callback): - """Special callback used by the learning rate finder. This callback logs the learning rate before each batch - and logs the corresponding loss after each batch. + """Special callback used by the learning rate finder. This callback logs the learning rate before each batch and + logs the corresponding loss after each batch. Args: num_training: number of iterations done by the learning rate finder @@ -355,6 +358,7 @@ class _LRCallback(Callback): beta: smoothing value, the loss being logged is a running average of loss values logged until now. ``beta`` controls the forget rate i.e. if ``beta=0`` all past information is ignored. + """ def __init__( @@ -443,6 +447,7 @@ class _LinearLR(_TORCH_LRSCHEDULER): num_iter: the number of iterations over which the test occurs. last_epoch: the index of last epoch. Default: -1. 
+ """ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1): @@ -478,6 +483,7 @@ class _ExponentialLR(_TORCH_LRSCHEDULER): num_iter: the number of iterations over which the test occurs. last_epoch: the index of last epoch. Default: -1. + """ def __init__(self, optimizer: torch.optim.Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1): diff --git a/src/lightning/pytorch/tuner/tuning.py b/src/lightning/pytorch/tuner/tuning.py index 53b7b45210..2fd0399549 100644 --- a/src/lightning/pytorch/tuner/tuning.py +++ b/src/lightning/pytorch/tuner/tuning.py @@ -39,8 +39,8 @@ class Tuner: max_trials: int = 25, batch_arg_name: str = "batch_size", ) -> Optional[int]: - """Iteratively try to find the largest batch size for a given model that does not give an out of memory - (OOM) error. + """Iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM) + error. Args: model: Model to tune. @@ -71,6 +71,7 @@ class Tuner: - ``model`` - ``model.hparams`` - ``trainer.datamodule`` (the datamodule passed to the tune method) + """ _check_tuner_configuration(train_dataloaders, val_dataloaders, dataloaders, method) _check_scale_batch_size_configuration(self._trainer) @@ -149,6 +150,7 @@ class Tuner: MisconfigurationException: If learning rate/lr in ``model`` or ``model.hparams`` isn't overridden, or if you are using more than one optimizer. + """ if method != "fit": raise MisconfigurationException("method='fit' is an invalid configuration to run lr finder.") diff --git a/src/lightning/pytorch/utilities/argparse.py b/src/lightning/pytorch/utilities/argparse.py index 888a3b3755..72d550ac10 100644 --- a/src/lightning/pytorch/utilities/argparse.py +++ b/src/lightning/pytorch/utilities/argparse.py @@ -38,6 +38,7 @@ def _parse_env_variables(cls: Type, template: str = "PL_%(cls_name)s_%(cls_argum >>> _parse_env_variables(Trainer) Namespace(devices=42) >>> del os.environ["PL_TRAINER_DEVICES"] + """ env_args = {} for arg_name in inspect.signature(cls).parameters: diff --git a/src/lightning/pytorch/utilities/combined_loader.py b/src/lightning/pytorch/utilities/combined_loader.py index 0e012dbae1..d7194b9196 100644 --- a/src/lightning/pytorch/utilities/combined_loader.py +++ b/src/lightning/pytorch/utilities/combined_loader.py @@ -238,6 +238,7 @@ class CombinedLoader(Iterable): tensor([0, 1, 2, 3, 4]) batch_idx=0 dataloader_idx=1 tensor([5, 6, 7, 8, 9]) batch_idx=1 dataloader_idx=1 tensor([10, 11, 12, 13, 14]) batch_idx=2 dataloader_idx=1 + """ def __init__(self, iterables: Any, mode: _LITERAL_SUPPORTED_MODES = "min_size") -> None: diff --git a/src/lightning/pytorch/utilities/compile.py b/src/lightning/pytorch/utilities/compile.py index ba9cd9c93f..ea2dc146bf 100644 --- a/src/lightning/pytorch/utilities/compile.py +++ b/src/lightning/pytorch/utilities/compile.py @@ -79,6 +79,7 @@ def to_uncompiled(model: Union["pl.LightningModule", "torch._dynamo.OptimizedMod returned by ``from_compiled``. Note: this method will in-place modify the ``LightningModule`` that is passed in. 
+ """ if not _TORCH_GREATER_EQUAL_2_0: raise ModuleNotFoundError("`to_uncompiled` requires torch>=2.0") diff --git a/src/lightning/pytorch/utilities/data.py b/src/lightning/pytorch/utilities/data.py index ce69cdc4d6..a42e3053a2 100644 --- a/src/lightning/pytorch/utilities/data.py +++ b/src/lightning/pytorch/utilities/data.py @@ -234,14 +234,15 @@ def _dataloader_init_kwargs_resolve_sampler( mode: Optional[RunningStage] = None, disallow_batch_sampler: bool = False, ) -> Dict[str, Any]: - """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its - re-instantiation. + """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its re- + instantiation. If the dataloader is being used for prediction, the sampler will be wrapped into an `_IndexBatchSamplerWrapper`, so Lightning can keep track of its indices. If there are multiple devices in IPU mode, it is necessary to disallow BatchSampler that isn't instantiated automatically, since `poptorch.DataLoader` will try to increase the batch_size + """ is_predicting = mode == RunningStage.PREDICTING batch_sampler = getattr(dataloader, "batch_sampler") diff --git a/src/lightning/pytorch/utilities/exceptions.py b/src/lightning/pytorch/utilities/exceptions.py index cb06cc7572..b1ca189f69 100644 --- a/src/lightning/pytorch/utilities/exceptions.py +++ b/src/lightning/pytorch/utilities/exceptions.py @@ -25,6 +25,7 @@ class SIGTERMException(SystemExit): For example, you could use the :class:`lightning.pytorch.callbacks.fault_tolerance.OnExceptionCheckpoint` callback that saves a checkpoint for you when this exception is raised. + """ diff --git a/src/lightning/pytorch/utilities/grads.py b/src/lightning/pytorch/utilities/grads.py index c6f0d062df..21f737f8b6 100644 --- a/src/lightning/pytorch/utilities/grads.py +++ b/src/lightning/pytorch/utilities/grads.py @@ -35,6 +35,7 @@ def grad_norm(module: Module, norm_type: Union[float, int, str], group_separator norms: The dictionary of p-norms of each parameter's gradient and a special entry for the total p-norm of the gradients viewed as a single vector. + """ norm_type = float(norm_type) if norm_type <= 0: diff --git a/src/lightning/pytorch/utilities/memory.py b/src/lightning/pytorch/utilities/memory.py index 0922b63e0b..698de64d48 100644 --- a/src/lightning/pytorch/utilities/memory.py +++ b/src/lightning/pytorch/utilities/memory.py @@ -34,6 +34,7 @@ def recursive_detach(in_dict: Any, to_cpu: bool = False) -> Any: Return: out_dict: Dictionary with detached tensors + """ def detach_and_move(t: Tensor, to_cpu: bool) -> Tensor: diff --git a/src/lightning/pytorch/utilities/migration/migration.py b/src/lightning/pytorch/utilities/migration/migration.py index 40803650e5..f935a1061c 100644 --- a/src/lightning/pytorch/utilities/migration/migration.py +++ b/src/lightning/pytorch/utilities/migration/migration.py @@ -27,6 +27,7 @@ For the Lightning developer: How to add a new migration? 
cp model.ckpt model.ckpt.backup python -m lightning.pytorch.utilities.upgrade_checkpoint model.ckpt + """ import re from typing import Any, Callable, Dict, List @@ -60,6 +61,7 @@ def _migrate_model_checkpoint_early_stopping(checkpoint: _CHECKPOINT) -> _CHECKP Version: 0.10.0 Commit: a5d1176 + """ keys_mapping = { "checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"), @@ -87,6 +89,7 @@ def _migrate_loop_global_step_to_progress_tracking(checkpoint: _CHECKPOINT) -> _ Version: 1.6.0 Commit: c67b075 PR: #13645, #11805 + """ global_step = checkpoint["global_step"] checkpoint.setdefault("loops", {"fit_loop": _get_fit_loop_initial_state_1_6_0()}) @@ -107,6 +110,7 @@ def _migrate_loop_current_epoch_to_progress_tracking(checkpoint: _CHECKPOINT) -> Version: 1.6.0 Commit: aea96e4 PR: #11805 + """ epoch = checkpoint["epoch"] checkpoint.setdefault("loops", {"fit_loop": _get_fit_loop_initial_state_1_6_0()}) @@ -121,6 +125,7 @@ def _migrate_loop_batches_that_stepped(checkpoint: _CHECKPOINT) -> _CHECKPOINT: Version: 1.6.5 Commit: c67b075 PR: #13645 + """ global_step = checkpoint["global_step"] checkpoint["loops"]["fit_loop"]["epoch_loop.state_dict"].setdefault("_batches_that_stepped", global_step) @@ -218,6 +223,7 @@ def _drop_apex_amp_state(checkpoint: _CHECKPOINT) -> _CHECKPOINT: Version: 2.0.0 Commit: e544676ff434ed96c6dd3b4e73a708bcb27ebcf1 PR: #16149 + """ key = "amp_scaling_state" if key in checkpoint: @@ -234,6 +240,7 @@ def _migrate_loop_structure_after_tbptt_removal(checkpoint: _CHECKPOINT) -> _CHE Version: 2.0.0 Commit: 7807454 PR: #16337, #16172 + """ if "loops" not in checkpoint: return checkpoint @@ -265,13 +272,13 @@ def _migrate_loop_structure_after_tbptt_removal(checkpoint: _CHECKPOINT) -> _CHE def _migrate_loop_structure_after_optimizer_loop_removal(checkpoint: _CHECKPOINT) -> _CHECKPOINT: - """Adjusts the loop structure since it changed when the support for multiple optimizers in automatic - optimization mode was removed. There is no longer a loop over optimizer, and hence no position to store for - resuming the loop. + """Adjusts the loop structure since it changed when the support for multiple optimizers in automatic optimization + mode was removed. There is no longer a loop over optimizer, and hence no position to store for resuming the loop. Version: 2.0.0 Commit: 6a56586 PR: #16539, #16598 + """ if "loops" not in checkpoint: return checkpoint diff --git a/src/lightning/pytorch/utilities/migration/utils.py b/src/lightning/pytorch/utilities/migration/utils.py index 49ae913263..3929ea4a47 100644 --- a/src/lightning/pytorch/utilities/migration/utils.py +++ b/src/lightning/pytorch/utilities/migration/utils.py @@ -48,6 +48,7 @@ def migrate_checkpoint( Note: The migration happens in-place. We specifically avoid copying the dict to avoid memory spikes for large checkpoints and objects that do not support being deep-copied. + """ ckpt_version = _get_version(checkpoint) if Version(ckpt_version) > Version(pl.__version__): @@ -91,6 +92,7 @@ class pl_legacy_patch: with pl_legacy_patch(): torch.load("path/to/legacy/checkpoint.ckpt") + """ def __enter__(self) -> "pl_legacy_patch": @@ -135,6 +137,7 @@ def _pl_migrate_checkpoint(checkpoint: _CHECKPOINT, checkpoint_path: Optional[_P """Applies Lightning version migrations to a checkpoint dictionary and prints infos for the user. This function is used by the Lightning Trainer when resuming from a checkpoint. 
+ """ old_version = _get_version(checkpoint) checkpoint, migrations = migrate_checkpoint(checkpoint) @@ -182,6 +185,7 @@ class _RedirectingUnpickler(pickle._Unpickler): In legacy versions of Lightning, callback classes got pickled into the checkpoint. These classes are defined in the `pytorch_lightning` but need to be loaded from `lightning.pytorch`. + """ def find_class(self, module: str, name: str) -> Any: diff --git a/src/lightning/pytorch/utilities/model_summary/model_summary.py b/src/lightning/pytorch/utilities/model_summary/model_summary.py index 4476ac5b25..8c50181f3b 100644 --- a/src/lightning/pytorch/utilities/model_summary/model_summary.py +++ b/src/lightning/pytorch/utilities/model_summary/model_summary.py @@ -38,8 +38,8 @@ NOT_APPLICABLE = "n/a" class LayerSummary: - """Summary class for a single layer in a :class:`~lightning.pytorch.core.module.LightningModule`. It collects - the following information: + """Summary class for a single layer in a :class:`~lightning.pytorch.core.module.LightningModule`. It collects the + following information: - Type of the layer (e.g. Linear, BatchNorm1d, ...) - Input shape @@ -65,6 +65,7 @@ class LayerSummary: Args: module: A module to summarize + """ def __init__(self, module: nn.Module) -> None: @@ -78,13 +79,13 @@ class LayerSummary: self.detach_hook() def _register_hook(self) -> Optional[RemovableHandle]: - """Registers a hook on the module that computes the input- and output size(s) on the first forward pass. If - the hook is called, it will remove itself from the from the module, meaning that recursive models will only - record their input- and output shapes once. Registering hooks on :class:`~torch.jit.ScriptModule` is not - supported. + """Registers a hook on the module that computes the input- and output size(s) on the first forward pass. If the + hook is called, it will remove itself from the from the module, meaning that recursive models will only record + their input- and output shapes once. Registering hooks on :class:`~torch.jit.ScriptModule` is not supported. Return: A handle for the installed hook, or ``None`` if registering the hook is not possible. + """ def hook(_: nn.Module, inp: Any, out: Any) -> None: @@ -116,6 +117,7 @@ class LayerSummary: """Removes the forward hook if it was not already removed in the forward pass. Will be called after the summary is created. + """ if self._hook_handle is not None: self._hook_handle.remove() @@ -194,6 +196,7 @@ class ModelSummary: 0 Non-trainable params 132 K Total params 0.530 Total estimated model params size (MB) + """ def __init__(self, model: "pl.LightningModule", max_depth: int = 1) -> None: @@ -303,6 +306,7 @@ class ModelSummary: """Makes a summary listing with: Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes, Model Size + """ arrays = [ (" ", list(map(str, range(len(self._layer_summary))))), @@ -361,8 +365,8 @@ def _format_summary_table( model_size: float, *cols: Tuple[str, List[str]], ) -> str: - """Takes in a number of arrays, each specifying a column in the summary table, and combines them all into one - big string defining the summary table that are nicely formatted.""" + """Takes in a number of arrays, each specifying a column in the summary table, and combines them all into one big + string defining the summary table that are nicely formatted.""" n_rows = len(cols[0][1]) n_cols = 1 + len(cols) @@ -425,6 +429,7 @@ def get_human_readable_count(number: int) -> str: Return: A string formatted according to the pattern described above. 
+ """ assert number >= 0 labels = PARAMETER_NUM_UNITS @@ -463,5 +468,6 @@ def summarize(lightning_module: "pl.LightningModule", max_depth: int = 1) -> Mod Return: The model summary object + """ return ModelSummary(lightning_module, max_depth=max_depth) diff --git a/src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py b/src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py index fc84fbeb54..b9a4993941 100644 --- a/src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py +++ b/src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py @@ -85,6 +85,7 @@ class DeepSpeedSummary(ModelSummary): """Makes a summary listing with: Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes, Model Size + """ arrays = [ (" ", list(map(str, range(len(self._layer_summary))))), diff --git a/src/lightning/pytorch/utilities/parameter_tying.py b/src/lightning/pytorch/utilities/parameter_tying.py index 9b12b456db..5f5ea505dc 100644 --- a/src/lightning/pytorch/utilities/parameter_tying.py +++ b/src/lightning/pytorch/utilities/parameter_tying.py @@ -15,6 +15,7 @@ Reference: https://github.com/pytorch/fairseq/blob/1f7ef9ed1e1061f8c7f88f8b94c7186834398690/fairseq/trainer.py#L110-L118 + """ from typing import Dict, List, Optional diff --git a/src/lightning/pytorch/utilities/parsing.py b/src/lightning/pytorch/utilities/parsing.py index 9958bc9f24..7ed77af189 100644 --- a/src/lightning/pytorch/utilities/parsing.py +++ b/src/lightning/pytorch/utilities/parsing.py @@ -120,6 +120,7 @@ def collect_init_args( A list of dictionaries where each dictionary contains the arguments passed to the constructor at that level. The last entry corresponds to the constructor call of the most specific class in the hierarchy. + """ _, _, _, local_vars = inspect.getargvalues(frame) # frame.f_back must be of a type types.FrameType for get_init_args/collect_init_args due to mypy @@ -216,6 +217,7 @@ class AttributeDict(Dict): "key2": abc "my-key": 3.14 "new_key": 42 + """ def __getattr__(self, key: str) -> Optional[Any]: @@ -241,6 +243,7 @@ def _lightning_get_all_attr_holders(model: "pl.LightningModule", attribute: str) Gets all of the objects or dicts that holds attribute. Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. + """ holders: List[Any] = [] @@ -269,6 +272,7 @@ def _lightning_get_first_attr_holder(model: "pl.LightningModule", attribute: str Gets the object or dict that holds attribute, or None. Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule, returns the last one that has it. + """ holders = _lightning_get_all_attr_holders(model, attribute) if len(holders) == 0: @@ -281,18 +285,20 @@ def lightning_hasattr(model: "pl.LightningModule", attribute: str) -> bool: """Special hasattr for Lightning. Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. + """ return _lightning_get_first_attr_holder(model, attribute) is not None def lightning_getattr(model: "pl.LightningModule", attribute: str) -> Optional[Any]: - """Special getattr for Lightning. Checks for attribute in model namespace, the old hparams namespace/dict, and - the datamodule. + """Special getattr for Lightning. Checks for attribute in model namespace, the old hparams namespace/dict, and the + datamodule. Raises: AttributeError: If ``model`` doesn't have ``attribute`` in any of model namespace, the hparams namespace/dict, and the datamodule. 
+ """ holder = _lightning_get_first_attr_holder(model, attribute) if holder is None: @@ -307,13 +313,14 @@ def lightning_getattr(model: "pl.LightningModule", attribute: str) -> Optional[A def lightning_setattr(model: "pl.LightningModule", attribute: str, value: Any) -> None: - """Special setattr for Lightning. Checks for attribute in model namespace and the old hparams namespace/dict. - Will also set the attribute on datamodule, if it exists. + """Special setattr for Lightning. Checks for attribute in model namespace and the old hparams namespace/dict. Will + also set the attribute on datamodule, if it exists. Raises: AttributeError: If ``model`` doesn't have ``attribute`` in any of model namespace, the hparams namespace/dict, and the datamodule. + """ holders = _lightning_get_all_attr_holders(model, attribute) if len(holders) == 0: diff --git a/src/lightning/pytorch/utilities/seed.py b/src/lightning/pytorch/utilities/seed.py index 50c88dbad6..10badab69c 100644 --- a/src/lightning/pytorch/utilities/seed.py +++ b/src/lightning/pytorch/utilities/seed.py @@ -38,6 +38,7 @@ def isolate_rng(include_cuda: bool = True) -> Generator[None, None, None]: [tensor([0.7576]), tensor([0.2793]), tensor([0.4031])] >>> torch.rand(1) tensor([0.7576]) + """ states = _collect_rng_states(include_cuda) yield diff --git a/src/lightning/pytorch/utilities/testing/_runif.py b/src/lightning/pytorch/utilities/testing/_runif.py index 732bc26cf5..3c67260a88 100644 --- a/src/lightning/pytorch/utilities/testing/_runif.py +++ b/src/lightning/pytorch/utilities/testing/_runif.py @@ -65,6 +65,7 @@ def _runif_reasons( psutil: Require that psutil is installed. sklearn: Require that scikit-learn is installed. onnx: Require that onnx is installed. + """ reasons, kwargs = FabricRunIf( diff --git a/src/lightning/store/store.py b/src/lightning/store/store.py index ede05d5400..f2389f8b8a 100644 --- a/src/lightning/store/store.py +++ b/src/lightning/store/store.py @@ -37,6 +37,7 @@ def upload_model( The version of the model to be uploaded. If not provided, default will be latest (not overridden). progress_bar: A progress bar to show the uploading status. Disable this if not needed, by setting to `False`. + """ client = _Client() user = client.auth_service_get_user() @@ -71,6 +72,7 @@ def download_model( The version of the model to be uploaded. If not provided, default will be latest (not overridden). progress_bar: Show progress on download. + """ client = _Client() download_url = client.models_store_download_model(name=name, version=version).download_url @@ -82,6 +84,7 @@ def list_models() -> List[V1Model]: Returns: A list of model objects. + """ client = _Client() # TODO: Allow passing this diff --git a/tests/integrations_app/flagship/test_flashy.py b/tests/integrations_app/flagship/test_flashy.py index c40cb155c2..a69ed1fc5d 100644 --- a/tests/integrations_app/flagship/test_flashy.py +++ b/tests/integrations_app/flagship/test_flashy.py @@ -20,6 +20,7 @@ def validate_app_functionalities(app_page: "Page") -> None: https://github.com/Lightning-AI/LAI-Flashy-App/blob/main/tests/test_app_gallery.py#L205 app_page: The UI page of the app to be validated. 
+ """ while True: with contextlib.suppress(playwright._impl._api_types.Error, playwright._impl._api_types.TimeoutError): diff --git a/tests/parity_pytorch/test_sync_batchnorm_parity.py b/tests/parity_pytorch/test_sync_batchnorm_parity.py index 11aca56510..7a0c0658e3 100644 --- a/tests/parity_pytorch/test_sync_batchnorm_parity.py +++ b/tests/parity_pytorch/test_sync_batchnorm_parity.py @@ -52,8 +52,8 @@ class SyncBNModule(LightningModule): @RunIf(min_cuda_gpus=2, standalone=True) def test_sync_batchnorm_parity(tmpdir): - """Test parity between 1) Training a synced batch-norm layer on 2 GPUs with batch size B per device 2) Training - a batch-norm layer on CPU with twice the batch size.""" + """Test parity between 1) Training a synced batch-norm layer on 2 GPUs with batch size B per device 2) Training a + batch-norm layer on CPU with twice the batch size.""" seed_everything(3) # 2 GPUS, batch size = 4 per GPU => total batch size = 8 model = SyncBNModule(batch_size=4) diff --git a/tests/tests_app/cli/test_cmd_launch.py b/tests/tests_app/cli/test_cmd_launch.py index b1fdf89ac9..9610c463c9 100644 --- a/tests/tests_app/cli/test_cmd_launch.py +++ b/tests/tests_app/cli/test_cmd_launch.py @@ -27,6 +27,7 @@ def test_run_frontend(monkeypatch): dispatcher. This CLI call is made by Lightning AI and is not meant to be invoked by the user directly. + """ runner = CliRunner() diff --git a/tests/tests_app/cli/test_run_app.py b/tests/tests_app/cli/test_run_app.py index c3ca5c1ac0..a457b7b03e 100644 --- a/tests/tests_app/cli/test_run_app.py +++ b/tests/tests_app/cli/test_run_app.py @@ -68,6 +68,7 @@ def test_lightning_run_app_cloud(mock_dispatch: mock.MagicMock, open_ui, caplog, It tests it by checking if the click.launch is called with the right url if --open-ui was true and also checks the call to `dispatch` for the right arguments. + """ monkeypatch.setattr("lightning.app.runners.cloud.logger", logging.getLogger()) @@ -116,6 +117,7 @@ def test_lightning_run_app_cloud_with_run_app_commands(mock_dispatch: mock.Magic It tests it by checking if the click.launch is called with the right url if --open-ui was true and also checks the call to `dispatch` for the right arguments. + """ monkeypatch.setattr("lightning.app.runners.cloud.logger", logging.getLogger()) @@ -182,6 +184,7 @@ def test_lightning_run_app_enable_basic_auth_passed(mock_dispatch: mock.MagicMoc """This test just validates the command has ran properly when --enable-basic-auth argument is passed. It checks the call to `dispatch` for the right arguments. 
+ """ monkeypatch.setattr("lightning.app.runners.cloud.logger", logging.getLogger()) diff --git a/tests/tests_app/components/python/test_python.py b/tests/tests_app/components/python/test_python.py index 1475ec57dc..14cfb9b9c4 100644 --- a/tests/tests_app/components/python/test_python.py +++ b/tests/tests_app/components/python/test_python.py @@ -77,8 +77,8 @@ def test_tracer_python_script_with_kwargs(): def test_tracer_component_with_code(): - """This test ensures the Tracer Component gets the latest code from the code object that is provided and - arguments are cleaned.""" + """This test ensures the Tracer Component gets the latest code from the code object that is provided and arguments + are cleaned.""" drive = Drive("lit://code") drive.component_name = "something" @@ -125,8 +125,8 @@ def test_tracer_component_with_code(): def test_tracer_component_with_code_in_dir(tmp_path): - """This test ensures the Tracer Component gets the latest code from the code object that is provided and - arguments are cleaned.""" + """This test ensures the Tracer Component gets the latest code from the code object that is provided and arguments + are cleaned.""" drive = Drive("lit://code") drive.component_name = "something" diff --git a/tests/tests_app/conftest.py b/tests/tests_app/conftest.py index 1a03b6e356..01e7c1cb17 100644 --- a/tests/tests_app/conftest.py +++ b/tests/tests_app/conftest.py @@ -95,6 +95,7 @@ def caplog(caplog): """Workaround for https://github.com/pytest-dev/pytest/issues/3697. Setting ``filterwarnings`` with pytest breaks ``caplog`` when ``not logger.propagate``. + """ import logging @@ -119,14 +120,15 @@ def caplog(caplog): @pytest.fixture() def patch_constants(request): - """This fixture can be used with indirect parametrization to patch values in `lightning.app.core.constants` for - the duration of a test. + """This fixture can be used with indirect parametrization to patch values in `lightning.app.core.constants` for the + duration of a test. Example:: @pytest.mark.parametrize("patch_constants", [{"LIGHTNING_CLOUDSPACE_HOST": "any"}], indirect=True) def test_my_stuff(patch_constants): ... + """ # Set constants old_constants = {} diff --git a/tests/tests_app/core/test_lightning_api.py b/tests/tests_app/core/test_lightning_api.py index b35d064772..367e822b9d 100644 --- a/tests/tests_app/core/test_lightning_api.py +++ b/tests/tests_app/core/test_lightning_api.py @@ -181,8 +181,8 @@ class AppStageTestingApp(LightningApp): # FIXME: This test doesn't assert anything @pytest.mark.skip(reason="TODO: Resolve flaky test.") def test_app_stage_from_frontend(): - """This test validates that delta from the `api_delta_queue` manipulating the ['app_state']['stage'] would - start and stop the app.""" + """This test validates that delta from the `api_delta_queue` manipulating the ['app_state']['stage'] would start + and stop the app.""" app = AppStageTestingApp(FlowA(), log_level="debug") app.stage = AppStage.BLOCKING MultiProcessRuntime(app, start_server=True).dispatch() @@ -193,6 +193,7 @@ def test_update_publish_state_and_maybe_refresh_ui(): - receives the state from the `publish_state_queue` and populates the app_state_store - receives a notification to refresh the UI and makes a GET Request (streamlit). 
+ """ app = AppStageTestingApp(FlowA(), log_level="debug") publish_state_queue = _MockQueue("publish_state_queue") @@ -215,6 +216,7 @@ async def test_start_server(x_lightning_type, monkeypatch): - the state on GET /api/v1/state - push a delta when making a POST request to /api/v1/state + """ class InfiniteQueue(_MockQueue): diff --git a/tests/tests_app/core/test_lightning_app.py b/tests/tests_app/core/test_lightning_app.py index b42103fcb7..aa0c97f65f 100644 --- a/tests/tests_app/core/test_lightning_app.py +++ b/tests/tests_app/core/test_lightning_app.py @@ -1020,8 +1020,8 @@ def test_non_updated_flow(caplog): def test_debug_mode_logging(): - """This test validates the DEBUG messages are collected when activated by the LightningApp(debug=True) and - cleanup once finished.""" + """This test validates the DEBUG messages are collected when activated by the LightningApp(debug=True) and cleanup + once finished.""" from lightning.app.core.app import _console diff --git a/tests/tests_app/core/test_lightning_flow.py b/tests/tests_app/core/test_lightning_flow.py index 36388de3de..dd89bb739f 100644 --- a/tests/tests_app/core/test_lightning_flow.py +++ b/tests/tests_app/core/test_lightning_flow.py @@ -77,8 +77,7 @@ def test_unsupported_attribute_types(cls, attribute): ], ) def test_unsupported_attribute_declaration_outside_init_or_run(name, value): - """Test that LightningFlow attributes (with a few exceptions) are not allowed to be declared outside - __init__.""" + """Test that LightningFlow attributes (with a few exceptions) are not allowed to be declared outside __init__.""" flow = EmptyFlow() with pytest.raises(AttributeError, match=f"Cannot set attributes that were not defined in __init__: {name}"): setattr(flow, name, value) @@ -102,8 +101,8 @@ def test_unsupported_attribute_declaration_outside_init_or_run(name, value): ) @pytest.mark.parametrize("defined", [False, True]) def test_unsupported_attribute_declaration_inside_run(defined, name, value): - """Test that LightningFlow attributes can set LightningFlow or LightningWork inside its run method, but - everything else needs to be defined in the __init__ method.""" + """Test that LightningFlow attributes can set LightningFlow or LightningWork inside its run method, but everything + else needs to be defined in the __init__ method.""" class Flow(LightningFlow): def __init__(self): @@ -163,8 +162,8 @@ def test_name_gets_removed_from_state_when_defined_as_flow_works(value): ], ) def test_supported_attribute_declaration_outside_init(name, value): - """Test the custom LightningFlow setattr implementation for the few reserved attributes that are allowed to be - set from outside __init__.""" + """Test the custom LightningFlow setattr implementation for the few reserved attributes that are allowed to be set + from outside __init__.""" flow = EmptyFlow() setattr(flow, name, value) assert getattr(flow, name) == value diff --git a/tests/tests_app/core/test_lightning_work.py b/tests/tests_app/core/test_lightning_work.py index 3b7437f95b..d7870d3e09 100644 --- a/tests/tests_app/core/test_lightning_work.py +++ b/tests/tests_app/core/test_lightning_work.py @@ -70,8 +70,7 @@ def test_lightning_work_no_children_allowed(): def test_forgot_to_call_init(): - """This test validates the error message for user registering state without calling __init__ is - comprehensible.""" + """This test validates the error message for user registering state without calling __init__ is comprehensible.""" class W(LightningWork): def __init__(self): @@ -110,8 +109,8 @@ def 
test_unsupported_attribute_declaration_outside_init(name, value): ], ) def test_supported_attribute_declaration_outside_init(name, value): - """Test the custom LightningWork setattr implementation for the few reserved attributes that are allowed to be - set from outside __init__.""" + """Test the custom LightningWork setattr implementation for the few reserved attributes that are allowed to be set + from outside __init__.""" flow = EmptyWork() setattr(flow, name, value) assert getattr(flow, name) == value diff --git a/tests/tests_app/core/test_queues.py b/tests/tests_app/core/test_queues.py index 8dd6d7d3a0..40abe438a3 100644 --- a/tests/tests_app/core/test_queues.py +++ b/tests/tests_app/core/test_queues.py @@ -21,6 +21,7 @@ def test_queue_api(queue_type, monkeypatch): """Test the Queue API. This test run all the Queue implementation but we monkeypatch the Redis Queues to avoid external interaction + """ import redis diff --git a/tests/tests_app/frontend/panel/test_app_state_watcher.py b/tests/tests_app/frontend/panel/test_app_state_watcher.py index a9c23b1619..21faeba495 100644 --- a/tests/tests_app/frontend/panel/test_app_state_watcher.py +++ b/tests/tests_app/frontend/panel/test_app_state_watcher.py @@ -4,6 +4,7 @@ - to access and change the App state. This is particularly useful for the PanelFrontend, but can be used by other Frontends too. + """ # pylint: disable=protected-access import os @@ -38,6 +39,7 @@ def test_init(flow_state_state: dict): - the .state is set - the .state is scoped to the flow state + """ # When app = AppStateWatcher() @@ -54,6 +56,7 @@ def test_update_flow_state(flow_state_state: dict): """We can update the state. - the .state is scoped to the flow state + """ app = AppStateWatcher() org_state = app.state @@ -67,6 +70,7 @@ def test_is_singleton(): Its key that __new__ and __init__ of AppStateWatcher is only called once. See https://github.com/holoviz/param/issues/643 + """ # When app1 = AppStateWatcher() diff --git a/tests/tests_app/frontend/panel/test_panel_serve_render_fn.py b/tests/tests_app/frontend/panel/test_panel_serve_render_fn.py index 3244c07af0..8e8bc4d415 100644 --- a/tests/tests_app/frontend/panel/test_panel_serve_render_fn.py +++ b/tests/tests_app/frontend/panel/test_panel_serve_render_fn.py @@ -1,6 +1,7 @@ """The panel_serve_render_fn_or_file file gets run by Python to launch a Panel Server with Lightning. These tests are for serving a render_fn function. + """ import inspect import os @@ -41,6 +42,7 @@ def test_get_view_fn_args(): """We have a helper get_view_fn function that create a function for our view. If the render_fn provides an argument an AppStateWatcher is provided as argument + """ result = _get_render_fn() assert isinstance(result(), AppStateWatcher) @@ -61,6 +63,7 @@ def test_get_view_fn_no_args(): """We have a helper get_view_fn function that create a function for our view. 
If the render_fn provides an argument an AppStateWatcher is provided as argument + """ result = _get_render_fn() assert result() == "no_args" diff --git a/tests/tests_app/plugin/test_plugin.py b/tests/tests_app/plugin/test_plugin.py index eaa75a60dc..fef33ace92 100644 --- a/tests/tests_app/plugin/test_plugin.py +++ b/tests/tests_app/plugin/test_plugin.py @@ -38,8 +38,8 @@ class _MockResponse: def mock_requests_get(valid_url, return_value): - """Used to replace `requests.get` with a function that returns the given value for the given valid URL and - raises otherwise.""" + """Used to replace `requests.get` with a function that returns the given value for the given valid URL and raises + otherwise.""" def inner(url): if url == valid_url: diff --git a/tests/tests_app/runners/test_cloud.py b/tests/tests_app/runners/test_cloud.py index af569b5ccc..8df770d611 100644 --- a/tests/tests_app/runners/test_cloud.py +++ b/tests/tests_app/runners/test_cloud.py @@ -254,6 +254,7 @@ class TestAppCreationClient: """Deleted apps show up in list apps but not in list instances. This tests that we don't try to reacreate a previously deleted app. + """ entrypoint = Path(tmpdir) / "entrypoint.py" entrypoint.touch() @@ -1881,10 +1882,11 @@ def test_print_specs(tmpdir, caplog, monkeypatch, print_format, expected): def test_incompatible_cloud_compute_and_build_config(monkeypatch): - """Test that an exception is raised when a build config has a custom image defined, but the cloud compute is - the default. + """Test that an exception is raised when a build config has a custom image defined, but the cloud compute is the + default. This combination is not supported by the platform. + """ mock_client = mock.MagicMock() cloud_backend = mock.MagicMock(client=mock_client) diff --git a/tests/tests_app/storage/test_copier.py b/tests/tests_app/storage/test_copier.py index fd4e274b91..f16ce57c7f 100644 --- a/tests/tests_app/storage/test_copier.py +++ b/tests/tests_app/storage/test_copier.py @@ -45,8 +45,8 @@ def test_copier_copies_all_files(fs_mock, stat_mock, dir_mock, tmpdir): @mock.patch("lightning.app.storage.path.pathlib.Path.is_dir") @mock.patch("lightning.app.storage.path.pathlib.Path.stat") def test_copier_handles_exception(stat_mock, dir_mock, monkeypatch): - """Test that the Copier captures exceptions from the file copy and forwards them through the queue without - raising it.""" + """Test that the Copier captures exceptions from the file copy and forwards them through the queue without raising + it.""" stat_mock().st_size = 0 dir_mock.return_value = False copy_request_queue = _MockQueue() diff --git a/tests/tests_app/storage/test_path.py b/tests/tests_app/storage/test_path.py index 423a4e7117..56bf5dbc1e 100644 --- a/tests/tests_app/storage/test_path.py +++ b/tests/tests_app/storage/test_path.py @@ -483,8 +483,8 @@ class RunPathWork(LightningWork): def test_path_as_argument_to_run_method(): - """Test that Path objects can be passed as arguments to the run() method of a Work in various ways such that - the origin, consumer and queues get automatically attached.""" + """Test that Path objects can be passed as arguments to the run() method of a Work in various ways such that the + origin, consumer and queues get automatically attached.""" root = RunPathFlow() app = LightningApp(root) MultiProcessRuntime(app, start_server=False).dispatch() @@ -621,8 +621,8 @@ def test_path_response_not_matching_reqeuest(tmpdir): def test_path_exists(tmpdir): - """Test that the Path.exists() behaves as expected: First it should check if 
the file exists locally, and if - not, send a message to the orchestrator to eventually check the existenc on the origin Work.""" + """Test that the Path.exists() behaves as expected: First it should check if the file exists locally, and if not, + send a message to the orchestrator to eventually check the existenc on the origin Work.""" # Local Path (no Work queues attached) assert not Path("file").exists() assert Path(tmpdir).exists() diff --git a/tests/tests_app/utilities/test_introspection.py b/tests/tests_app/utilities/test_introspection.py index b3371e2348..ce5d54f4e0 100644 --- a/tests/tests_app/utilities/test_introspection.py +++ b/tests/tests_app/utilities/test_introspection.py @@ -38,8 +38,8 @@ def test_introspection_lightning(): @_RunIf(pl=True) def test_introspection_lightning_overrides(): - """This test validates the scanner can find all the subclasses from primitives classes from PyTorch Lightning - in the provided files.""" + """This test validates the scanner can find all the subclasses from primitives classes from PyTorch Lightning in + the provided files.""" scanner = Scanner(str(os.path.join(_PROJECT_ROOT, "tests/tests_app/core/scripts/lightning_cli.py"))) scan = scanner.scan() assert set(scan) == {"LightningDataModule", "LightningModule"} diff --git a/tests/tests_app/utilities/test_proxies.py b/tests/tests_app/utilities/test_proxies.py index d62cf2d6a2..1134574f24 100644 --- a/tests/tests_app/utilities/test_proxies.py +++ b/tests/tests_app/utilities/test_proxies.py @@ -288,8 +288,7 @@ def test_proxy_timeout(): @mock.patch("lightning.app.utilities.proxies._Copier") def test_path_argument_to_transfer(*_): - """Test that any Lightning Path objects passed to the run method get transferred automatically (if they - exist).""" + """Test that any Lightning Path objects passed to the run method get transferred automatically (if they exist).""" class TransferPathWork(LightningWork): def run(self, *args, **kwargs): @@ -372,8 +371,7 @@ def test_path_argument_to_transfer(*_): ) @mock.patch("lightning.app.utilities.proxies._Copier") def test_path_attributes_to_transfer(_, origin, exists_remote, expected_get): - """Test that any Lightning Path objects passed to the run method get transferred automatically (if they - exist).""" + """Test that any Lightning Path objects passed to the run method get transferred automatically (if they exist).""" path_mock = Mock() path_mock.origin_name = origin path_mock.exists_remote = Mock(return_value=exists_remote) @@ -518,8 +516,8 @@ def test_persist_artifacts(tmp_path): def test_work_state_observer(): - """Tests that the WorkStateObserver sends deltas to the queue when state residuals remain that haven't been - handled by the setattr.""" + """Tests that the WorkStateObserver sends deltas to the queue when state residuals remain that haven't been handled + by the setattr.""" class WorkWithoutSetattr(LightningWork): def __init__(self): diff --git a/tests/tests_fabric/accelerators/test_cuda.py b/tests/tests_fabric/accelerators/test_cuda.py index 9408ad4292..37ae024335 100644 --- a/tests/tests_fabric/accelerators/test_cuda.py +++ b/tests/tests_fabric/accelerators/test_cuda.py @@ -71,8 +71,8 @@ def test_set_cuda_device(_, set_device_mock): @mock.patch("torch.cuda.is_available", return_value=True) @mock.patch("torch.cuda.device_count", return_value=100) def test_num_cuda_devices_without_nvml(*_): - """Test that if NVML can't be loaded, our helper functions fall back to the default implementation for - determining CUDA availability.""" + """Test that if 
NVML can't be loaded, our helper functions fall back to the default implementation for determining + CUDA availability.""" num_cuda_devices.cache_clear() assert is_cuda_available() assert num_cuda_devices() == 100 diff --git a/tests/tests_fabric/conftest.py b/tests/tests_fabric/conftest.py index 77f78da4b3..b1dbc89633 100644 --- a/tests/tests_fabric/conftest.py +++ b/tests/tests_fabric/conftest.py @@ -121,6 +121,7 @@ def caplog(caplog): """Workaround for https://github.com/pytest-dev/pytest/issues/3697. Setting ``filterwarnings`` with pytest breaks ``caplog`` when ``not logger.propagate``. + """ import logging diff --git a/tests/tests_fabric/plugins/environments/test_mpi.py b/tests/tests_fabric/plugins/environments/test_mpi.py index fb32b80a2a..0f02572c5c 100644 --- a/tests/tests_fabric/plugins/environments/test_mpi.py +++ b/tests/tests_fabric/plugins/environments/test_mpi.py @@ -70,8 +70,7 @@ def test_default_attributes(monkeypatch): def test_init_local_comm(monkeypatch): - """Test that it can determine the node rank and local rank based on the hostnames of all participating - nodes.""" + """Test that it can determine the node rank and local rank based on the hostnames of all participating nodes.""" # pretend mpi4py is available monkeypatch.setattr(lightning.fabric.plugins.environments.mpi, "_MPI4PY_AVAILABLE", True) mpi4py_mock = MagicMock() diff --git a/tests/tests_fabric/plugins/precision/test_deepspeed_integration.py b/tests/tests_fabric/plugins/precision/test_deepspeed_integration.py index 76226ea2f2..e989534343 100644 --- a/tests/tests_fabric/plugins/precision/test_deepspeed_integration.py +++ b/tests/tests_fabric/plugins/precision/test_deepspeed_integration.py @@ -28,6 +28,7 @@ def test_deepspeed_precision_choice(_, precision): """Test to ensure precision plugin is correctly chosen. DeepSpeed handles precision via custom DeepSpeedPrecision. + """ connector = _Connector( accelerator="auto", diff --git a/tests/tests_fabric/strategies/test_ddp.py b/tests/tests_fabric/strategies/test_ddp.py index c66782b61d..117d0bb3db 100644 --- a/tests/tests_fabric/strategies/test_ddp.py +++ b/tests/tests_fabric/strategies/test_ddp.py @@ -74,8 +74,7 @@ def test_ddp_no_backward_sync(): @mock.patch("lightning.fabric.strategies.ddp.DistributedDataParallel") def test_ddp_extra_kwargs(ddp_mock): - """Test that additional kwargs passed to the DDPStrategy get passed down to the DistributedDataParallel - wrapper.""" + """Test that additional kwargs passed to the DDPStrategy get passed down to the DistributedDataParallel wrapper.""" module = torch.nn.Linear(1, 1) strategy = DDPStrategy(parallel_devices=[torch.device("cpu"), torch.device("cpu")]) strategy.setup_module(module) diff --git a/tests/tests_fabric/strategies/test_deepspeed.py b/tests/tests_fabric/strategies/test_deepspeed.py index 84cd86ffc3..2edded9ca8 100644 --- a/tests/tests_fabric/strategies/test_deepspeed.py +++ b/tests/tests_fabric/strategies/test_deepspeed.py @@ -399,6 +399,7 @@ def test_validate_parallel_devices_indices(device_indices): """Test that the strategy validates that it doesn't support selecting specific devices by index. DeepSpeed doesn't support it and needs the index to match to the local rank of the process. 
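The device-index rule that test enforces amounts to a simple per-rank check; the `validate_device_indices` helper below is a simplified, illustrative sketch, not the strategy's real validation code:

import torch

def validate_device_indices(parallel_devices):
    # The i-th selected device must be cuda:i, because DeepSpeed addresses GPUs
    # by the local rank of each process.
    for local_rank, device in enumerate(parallel_devices):
        if device.index != local_rank:
            raise RuntimeError(
                f"DeepSpeed only supports device index == local rank, got {device} for rank {local_rank}"
            )

validate_device_indices([torch.device("cuda", 0), torch.device("cuda", 1)])  # passes
# validate_device_indices([torch.device("cuda", 1)]) would raise: index 1 for local rank 0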
+ """ strategy = DeepSpeedStrategy( accelerator=CUDAAccelerator(), parallel_devices=[torch.device("cuda", i) for i in device_indices] diff --git a/tests/tests_fabric/strategies/test_deepspeed_integration.py b/tests/tests_fabric/strategies/test_deepspeed_integration.py index ce879e050f..3db926a6e9 100644 --- a/tests/tests_fabric/strategies/test_deepspeed_integration.py +++ b/tests/tests_fabric/strategies/test_deepspeed_integration.py @@ -245,6 +245,7 @@ def test_deepspeed_env_variables_on_platforms(_, deepspeed_dist_mock, platform): """Test to ensure that we set up distributed communication correctly. When using Windows, ranks environment variables should not be set, and DeepSpeed should handle this. + """ fabric = Fabric(strategy=DeepSpeedStrategy(stage=3)) strategy = fabric._strategy diff --git a/tests/tests_fabric/strategies/test_fsdp.py b/tests/tests_fabric/strategies/test_fsdp.py index dd782d636f..71de0b2534 100644 --- a/tests/tests_fabric/strategies/test_fsdp.py +++ b/tests/tests_fabric/strategies/test_fsdp.py @@ -361,8 +361,7 @@ def test_fsdp_load_unknown_checkpoint_type(tmp_path): @RunIf(min_torch="2.0.0") def test_fsdp_load_raw_checkpoint_validate_single_file(tmp_path): - """Test that we validate the given checkpoint is a single file when loading a raw PyTorch state-dict - checkpoint.""" + """Test that we validate the given checkpoint is a single file when loading a raw PyTorch state-dict checkpoint.""" strategy = FSDPStrategy() model = Mock(spec=nn.Module) path = tmp_path / "folder" @@ -451,6 +450,7 @@ class StatusChecker: This is confusing (since it logs "FAILED"), but more importantly the orphan rank will continue trying to execute the rest of the test suite. So instead we add calls to `os._exit` which actually forces the process to shut down. 
+ """ success = False try: diff --git a/tests/tests_fabric/strategies/test_strategy.py b/tests/tests_fabric/strategies/test_strategy.py index 5eb2119847..ae101e4272 100644 --- a/tests/tests_fabric/strategies/test_strategy.py +++ b/tests/tests_fabric/strategies/test_strategy.py @@ -155,8 +155,7 @@ def test_load_checkpoint_strict_loading(tmp_path): def test_load_checkpoint_non_strict_loading(tmp_path): - """Test that no error is raised if `strict=False` and state is requested that does not exist in the - checkpoint.""" + """Test that no error is raised if `strict=False` and state is requested that does not exist in the checkpoint.""" strategy = SingleDeviceStrategy() # surrogate class to test implementation in base class # objects with initial state diff --git a/tests/tests_fabric/test_cli.py b/tests/tests_fabric/test_cli.py index 164882d013..7db7fcf8da 100644 --- a/tests/tests_fabric/test_cli.py +++ b/tests/tests_fabric/test_cli.py @@ -69,8 +69,8 @@ def test_cli_env_vars_strategy(_, strategy, monkeypatch, fake_script): def test_cli_get_supported_strategies(): - """Test to ensure that when new strategies get added, we must consider updating the list of supported ones in - the CLI.""" + """Test to ensure that when new strategies get added, we must consider updating the list of supported ones in the + CLI.""" if _TORCH_GREATER_EQUAL_1_12 and torch.distributed.is_available(): assert len(_get_supported_strategies()) == 7 assert "fsdp" in _get_supported_strategies() diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py index ae99397de3..045ab64d1e 100644 --- a/tests/tests_fabric/test_fabric.py +++ b/tests/tests_fabric/test_fabric.py @@ -120,8 +120,8 @@ def test_setup_compiled_module(setup_method): @pytest.mark.parametrize("move_to_device", [True, False]) @pytest.mark.parametrize("setup_method", ["setup", "setup_module"]) def test_setup_module_move_to_device(setup_method, move_to_device, accelerator, initial_device, target_device): - """Test that `move_to_device` leads to parameters being moved to the correct device and that the device - attributes on the wrapper are updated.""" + """Test that `move_to_device` leads to parameters being moved to the correct device and that the device attributes + on the wrapper are updated.""" initial_device = torch.device(initial_device) target_device = torch.device(target_device) expected_device = target_device if move_to_device else initial_device @@ -149,8 +149,7 @@ def test_setup_module_move_to_device(setup_method, move_to_device, accelerator, @pytest.mark.parametrize("move_to_device", [True, False]) @pytest.mark.parametrize("setup_method", ["setup", "setup_module"]) def test_setup_module_parameters_on_different_devices(setup_method, move_to_device): - """Test that a warning is emitted when model parameters are on a different device prior to calling - `setup()`.""" + """Test that a warning is emitted when model parameters are on a different device prior to calling `setup()`.""" device0 = torch.device("cpu") device1 = torch.device("cuda", 0) @@ -262,8 +261,7 @@ def test_setup_optimizers_twice_fails(): @pytest.mark.parametrize("strategy_cls", [DeepSpeedStrategy, XLAStrategy]) def test_setup_optimizers_not_supported(strategy_cls): - """Test that `setup_optimizers` validates the strategy supports setting up model and optimizers - independently.""" + """Test that `setup_optimizers` validates the strategy supports setting up model and optimizers independently.""" fabric = Fabric() fabric._launched = True # pretend we have launched 
multiple processes model = nn.Linear(1, 2) @@ -275,8 +273,7 @@ def test_setup_optimizers_not_supported(strategy_cls): @RunIf(min_cuda_gpus=1, min_torch="2.1") def test_setup_optimizer_on_meta_device(): - """Test that the setup-methods validate that the optimizer doesn't have references to meta-device - parameters.""" + """Test that the setup-methods validate that the optimizer doesn't have references to meta-device parameters.""" fabric = Fabric(strategy="fsdp", devices=1) fabric._launched = True # pretend we have launched multiple processes with fabric.init_module(empty_init=True): @@ -350,8 +347,8 @@ def test_setup_dataloaders_captures_dataloader_arguments(ctx_manager): def test_setup_dataloaders_raises_for_unknown_custom_args(): - """Test that an error raises when custom dataloaders with unknown arguments are created from outside Fabric's - run method.""" + """Test that an error raises when custom dataloaders with unknown arguments are created from outside Fabric's run + method.""" class CustomDataLoader(DataLoader): def __init__(self, new_arg, *args, **kwargs): @@ -508,8 +505,7 @@ def test_seed_everything(): ], ) def test_setup_dataloaders_replace_custom_sampler(strategy): - """Test that asking to replace a custom sampler results in an error when a distributed sampler would be - needed.""" + """Test that asking to replace a custom sampler results in an error when a distributed sampler would be needed.""" custom_sampler = Mock(spec=Sampler) dataloader = DataLoader(Mock(), sampler=custom_sampler) @@ -744,8 +740,7 @@ def test_overridden_run_and_cli_not_allowed(): def test_module_sharding_context(): - """Test that the sharding context manager gets applied when the strategy supports it and is a no-op - otherwise.""" + """Test that the sharding context manager gets applied when the strategy supports it and is a no-op otherwise.""" fabric = Fabric() fabric._strategy = MagicMock(spec=DDPStrategy, module_sharded_context=Mock()) with pytest.warns(DeprecationWarning, match="sharded_model"), fabric.sharded_model(): diff --git a/tests/tests_fabric/test_wrappers.py b/tests/tests_fabric/test_wrappers.py index 79b0a33c57..1b0f7333db 100644 --- a/tests/tests_fabric/test_wrappers.py +++ b/tests/tests_fabric/test_wrappers.py @@ -267,8 +267,8 @@ def test_fabric_module_device_dtype_propagation(device_str, dtype): def test_fabric_dataloader_iterator(): - """Test that the iteration over a FabricDataLoader wraps the iterator of the underlying dataloader (no - automatic device placement).""" + """Test that the iteration over a FabricDataLoader wraps the iterator of the underlying dataloader (no automatic + device placement).""" dataloader = DataLoader(range(5), batch_size=2) fabric_dataloader = _FabricDataLoader(dataloader) assert len(fabric_dataloader) == len(dataloader) == 3 diff --git a/tests/tests_fabric/utilities/test_data.py b/tests/tests_fabric/utilities/test_data.py index 533e52c299..82b601f366 100644 --- a/tests/tests_fabric/utilities/test_data.py +++ b/tests/tests_fabric/utilities/test_data.py @@ -48,13 +48,13 @@ def test_has_len(): def test_replace_dunder_methods_multiple_loaders_without_init(): """In case of a class, that inherits from a class that we are patching, but doesn't define its own `__init__` method (the one we are wrapping), it can happen, that `hasattr(cls, "__old__init__")` is True because of parent - class, but it is impossible to delete, because that method is owned by parent class. 
Furthermore, the error - occured only sometimes because it depends on the order in which we are iterating over a set of classes we are - patching. + class, but it is impossible to delete, because that method is owned by parent class. Furthermore, the error occured + only sometimes because it depends on the order in which we are iterating over a set of classes we are patching. This test simulates the behavior by generating sufficient number of dummy classes, which do not define `__init__` and are children of `DataLoader`. We are testing that a) context manager `_replace_dunder_method` exits cleanly, and b) the mechanism checking for presence of `__old__init__` works as expected. + """ classes = [DataLoader] for i in range(100): @@ -253,10 +253,11 @@ def test_replace_dunder_methods_extra_kwargs(): def test_replace_dunder_methods_attrs(): - """This test checks, that all the calls from setting and deleting attributes within `_replace_dunder_methods` - are correctly preserved even after reinstantiation. + """This test checks, that all the calls from setting and deleting attributes within `_replace_dunder_methods` are + correctly preserved even after reinstantiation. It also includes a custom `__setattr__` + """ class Loader(DataLoader): @@ -413,11 +414,12 @@ def test_update_dataloader_typerror_custom_exception(): def test_custom_batch_sampler(): - """This test asserts, that custom `BatchSampler`, with all the arguments, that are required in order to - properly reinstantiate the class, is invoked properly. + """This test asserts, that custom `BatchSampler`, with all the arguments, that are required in order to properly + reinstantiate the class, is invoked properly. It also asserts, that during the reinstantiation, the wrapper of `__init__` method is not present anymore, therefore not setting `__pl_saved_{args,arg_names,kwargs}` attributes. + """ class MyBatchSampler(BatchSampler): @@ -456,8 +458,7 @@ def test_custom_batch_sampler(): def test_custom_batch_sampler_no_sampler(): - """Tests whether appropriate error is raised when the custom `BatchSampler` does not support sampler - argument.""" + """Tests whether appropriate error is raised when the custom `BatchSampler` does not support sampler argument.""" class MyBatchSampler(BatchSampler): # Custom batch sampler, without sampler argument. @@ -511,10 +512,10 @@ def test_dataloader_kwargs_replacement_with_iterable_dataset(): def test_dataloader_kwargs_replacement_with_array_default_comparison(): - """Test that the comparison of attributes and default argument values works with arrays (truth value - ambiguous). + """Test that the comparison of attributes and default argument values works with arrays (truth value ambiguous). Regression test for issue #15408. + """ dataset = RandomDataset(5, 100) diff --git a/tests/tests_fabric/utilities/test_device_dtype_mixin.py b/tests/tests_fabric/utilities/test_device_dtype_mixin.py index bb1570eed0..1261ca5e0a 100644 --- a/tests/tests_fabric/utilities/test_device_dtype_mixin.py +++ b/tests/tests_fabric/utilities/test_device_dtype_mixin.py @@ -36,8 +36,8 @@ class TopModule(_DeviceDtypeModuleMixin): ) @RunIf(min_cuda_gpus=1) def test_submodules_device_and_dtype(dst_device_str, dst_type): - """Test that the device and dtype property updates propagate through mixed nesting of regular nn.Modules and - the special modules of type DeviceDtypeModuleMixin (e.g. 
Metric or LightningModule).""" + """Test that the device and dtype property updates propagate through mixed nesting of regular nn.Modules and the + special modules of type DeviceDtypeModuleMixin (e.g. Metric or LightningModule).""" dst_device = torch.device(dst_device_str) model = TopModule() assert model.device == torch.device("cpu") diff --git a/tests/tests_fabric/utilities/test_logger.py b/tests/tests_fabric/utilities/test_logger.py index c5286d0dbe..4cecc4658c 100644 --- a/tests/tests_fabric/utilities/test_logger.py +++ b/tests/tests_fabric/utilities/test_logger.py @@ -78,6 +78,7 @@ def test_sanitize_callable_params(): """Callback function are not serializiable. Therefore, we get them a chance to return something and if the returned type is not accepted, return None. + """ def return_something(): diff --git a/tests/tests_fabric/utilities/test_warnings.py b/tests/tests_fabric/utilities/test_warnings.py index a0165961cc..e39d52556b 100644 --- a/tests/tests_fabric/utilities/test_warnings.py +++ b/tests/tests_fabric/utilities/test_warnings.py @@ -14,6 +14,7 @@ """Test that the warnings actually appear and they have the correct `stacklevel` Needs to be run outside of `pytest` as it captures all the warnings. + """ from contextlib import redirect_stderr from io import StringIO @@ -39,16 +40,16 @@ if __name__ == "__main__": cache.deprecation("test7") output = stderr.getvalue() - assert "test_warnings.py:29: UserWarning: test1" in output - assert "test_warnings.py:30: DeprecationWarning: test2" in output + assert "test_warnings.py:30: UserWarning: test1" in output + assert "test_warnings.py:31: DeprecationWarning: test2" in output - assert "test_warnings.py:32: UserWarning: test3" in output - assert "test_warnings.py:33: DeprecationWarning: test4" in output + assert "test_warnings.py:33: UserWarning: test3" in output + assert "test_warnings.py:34: DeprecationWarning: test4" in output - assert "test_warnings.py:35: LightningDeprecationWarning: test5" in output + assert "test_warnings.py:36: LightningDeprecationWarning: test5" in output - assert "test_warnings.py:38: UserWarning: test6" in output - assert "test_warnings.py:39: LightningDeprecationWarning: test7" in output + assert "test_warnings.py:39: UserWarning: test6" in output + assert "test_warnings.py:40: LightningDeprecationWarning: test7" in output # check that logging is properly configured import logging diff --git a/tests/tests_pytorch/accelerators/test_cpu.py b/tests/tests_pytorch/accelerators/test_cpu.py index e724652a07..ac5b3443ae 100644 --- a/tests/tests_pytorch/accelerators/test_cpu.py +++ b/tests/tests_pytorch/accelerators/test_cpu.py @@ -39,8 +39,8 @@ def test_get_device_stats(tmpdir): @pytest.mark.parametrize("restore_after_pre_setup", [True, False]) def test_restore_checkpoint_after_pre_setup(tmpdir, restore_after_pre_setup): - """Test to ensure that if restore_checkpoint_after_setup is True, then we only load the state after pre- - dispatch is called.""" + """Test to ensure that if restore_checkpoint_after_setup is True, then we only load the state after pre- dispatch + is called.""" class TestPlugin(SingleDeviceStrategy): setup_called = False diff --git a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py index d4c8e85509..7c6a75ea3b 100644 --- a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py @@ -353,6 +353,7 @@ def test_train_progress_bar_update_amount( 
At the end of the epoch, the progress must not overshoot if the number of steps is not divisible by the refresh rate. + """ model = BoringModel() progress_bar = TQDMProgressBar(refresh_rate=refresh_rate) diff --git a/tests/tests_pytorch/callbacks/test_early_stopping.py b/tests/tests_pytorch/callbacks/test_early_stopping.py index 375734dead..e43122f402 100644 --- a/tests/tests_pytorch/callbacks/test_early_stopping.py +++ b/tests/tests_pytorch/callbacks/test_early_stopping.py @@ -63,6 +63,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir): https://github.com/Lightning-AI/lightning/issues/1464 https://github.com/Lightning-AI/lightning/issues/1463 + """ seed_everything(42) model = ClassificationModel() diff --git a/tests/tests_pytorch/callbacks/test_finetuning_callback.py b/tests/tests_pytorch/callbacks/test_finetuning_callback.py index 129b434e6b..7fbb7f580b 100644 --- a/tests/tests_pytorch/callbacks/test_finetuning_callback.py +++ b/tests/tests_pytorch/callbacks/test_finetuning_callback.py @@ -209,8 +209,7 @@ class OnEpochLayerFinetuning(BaseFinetuning): def test_base_finetuning_internal_optimizer_metadata(tmpdir): - """Test the param_groups updates are properly saved within the internal state of the BaseFinetuning - Callbacks.""" + """Test the param_groups updates are properly saved within the internal state of the BaseFinetuning Callbacks.""" seed_everything(42) @@ -325,8 +324,7 @@ class FinetuningBoringModel(BoringModel): def test_callbacks_restore(tmpdir): - """Test callbacks restore is called after optimizers have been re-created but before optimizer states - reload.""" + """Test callbacks restore is called after optimizers have been re-created but before optimizer states reload.""" chk = ModelCheckpoint(dirpath=tmpdir, save_last=True) model = FinetuningBoringModel() @@ -400,8 +398,7 @@ class BackboneBoringModel(BoringModel): def test_callbacks_restore_backbone(tmpdir): - """Test callbacks restore is called after optimizers have been re-created but before optimizer states - reload.""" + """Test callbacks restore is called after optimizers have been re-created but before optimizer states reload.""" ckpt = ModelCheckpoint(dirpath=tmpdir, save_last=True) trainer = Trainer( diff --git a/tests/tests_pytorch/callbacks/test_prediction_writer.py b/tests/tests_pytorch/callbacks/test_prediction_writer.py index aba956c414..59c1a145b1 100644 --- a/tests/tests_pytorch/callbacks/test_prediction_writer.py +++ b/tests/tests_pytorch/callbacks/test_prediction_writer.py @@ -37,8 +37,7 @@ def test_prediction_writer_invalid_write_interval(): def test_prediction_writer_hook_call_intervals(): - """Test that the `write_on_batch_end` and `write_on_epoch_end` hooks get invoked based on the defined - interval.""" + """Test that the `write_on_batch_end` and `write_on_epoch_end` hooks get invoked based on the defined interval.""" DummyPredictionWriter.write_on_batch_end = Mock() DummyPredictionWriter.write_on_epoch_end = Mock() diff --git a/tests/tests_pytorch/callbacks/test_pruning.py b/tests/tests_pytorch/callbacks/test_pruning.py index d206fadb59..61676864af 100644 --- a/tests/tests_pytorch/callbacks/test_pruning.py +++ b/tests/tests_pytorch/callbacks/test_pruning.py @@ -281,8 +281,8 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent: bool def test_permanent_when_model_is_saved_multiple_times( tmpdir, caplog, prune_on_train_epoch_end, save_on_train_epoch_end ): - """When a model is saved multiple times and make_permanent=True, we need to make sure a copy is pruned and not - 
the trained model if we want to continue with the same pruning buffers.""" + """When a model is saved multiple times and make_permanent=True, we need to make sure a copy is pruned and not the + trained model if we want to continue with the same pruning buffers.""" if prune_on_train_epoch_end and save_on_train_epoch_end: pytest.xfail( "Pruning sets the `grad_fn` of the parameters so we can't save" diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index c8172e413c..cccbb3252e 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -85,8 +85,8 @@ def mock_training_epoch_loop(trainer): def test_model_checkpoint_score_and_ckpt( tmpdir, validation_step_none: bool, val_dataloaders_none: bool, monitor: str, reduce_lr_on_plateau: bool ): - """Test that when a model checkpoint is saved, it saves with the correct score appended to ckpt_path and - checkpoint data.""" + """Test that when a model checkpoint is saved, it saves with the correct score appended to ckpt_path and checkpoint + data.""" max_epochs = 3 limit_train_batches = 5 limit_val_batches = 7 @@ -190,8 +190,8 @@ def test_model_checkpoint_score_and_ckpt( def test_model_checkpoint_score_and_ckpt_val_check_interval( tmpdir, val_check_interval, reduce_lr_on_plateau, epoch_aligned ): - """Test that when a model checkpoint is saved, it saves with the correct score appended to ckpt_path and - checkpoint data with val_check_interval.""" + """Test that when a model checkpoint is saved, it saves with the correct score appended to ckpt_path and checkpoint + data with val_check_interval.""" seed_everything(0) max_epochs = 3 limit_train_batches = 12 @@ -1131,8 +1131,8 @@ def test_hparams_type(tmpdir, use_omegaconf): def test_ckpt_version_after_rerun_new_trainer(tmpdir): - """Check that previous checkpoints are renamed to have the correct version suffix when new trainer instances - are used.""" + """Check that previous checkpoints are renamed to have the correct version suffix when new trainer instances are + used.""" epochs = 2 for i in range(epochs): mc = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1, monitor="epoch", filename="{epoch}") @@ -1158,8 +1158,8 @@ def test_ckpt_version_after_rerun_new_trainer(tmpdir): def test_ckpt_version_after_rerun_same_trainer(tmpdir): - """Check that previous checkpoints are renamed to have the correct version suffix when the same trainer - instance is used.""" + """Check that previous checkpoints are renamed to have the correct version suffix when the same trainer instance is + used.""" mc = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1, monitor="epoch", filename="test") mc.STARTING_VERSION = 9 trainer = Trainer( @@ -1303,8 +1303,8 @@ def test_model_checkpoint_saveload_ckpt(tmpdir): def test_resume_training_preserves_old_ckpt_last(tmpdir): - """Ensures that the last saved checkpoint is not deleted from the previous folder when training is resumed from - the old checkpoint.""" + """Ensures that the last saved checkpoint is not deleted from the previous folder when training is resumed from the + old checkpoint.""" model = BoringModel() trainer_kwargs = { "default_root_dir": tmpdir, diff --git a/tests/tests_pytorch/conftest.py b/tests/tests_pytorch/conftest.py index fc6762b373..47a4ea9063 100644 --- a/tests/tests_pytorch/conftest.py +++ b/tests/tests_pytorch/conftest.py @@ -94,6 +94,7 @@ def restore_signal_handlers(): """Ensures that signal handlers get 
restored before the next test runs. This is a safety net for tests that don't run Trainer's teardown. + """ valid_signals = _SignalConnector._valid_signals() if not _IS_WINDOWS: @@ -207,6 +208,7 @@ def caplog(caplog): """Workaround for https://github.com/pytest-dev/pytest/issues/3697. Setting ``filterwarnings`` with pytest breaks ``caplog`` when ``not logger.propagate``. + """ import logging @@ -248,6 +250,7 @@ def single_process_pg(): """Initialize the default process group with only the current process for testing purposes. The process group is destroyed when the with block is exited. + """ if torch.distributed.is_initialized(): raise RuntimeError("Can't use `single_process_pg` when the default process group is already initialized.") diff --git a/tests/tests_pytorch/core/test_datamodules.py b/tests/tests_pytorch/core/test_datamodules.py index 751f580bd4..d254ec4385 100644 --- a/tests/tests_pytorch/core/test_datamodules.py +++ b/tests/tests_pytorch/core/test_datamodules.py @@ -237,8 +237,7 @@ def test_full_loop(tmpdir): def test_dm_reload_dataloaders_every_n_epochs(tmpdir): - """Test datamodule, where trainer argument reload_dataloaders_every_n_epochs is set to a non negative - integer.""" + """Test datamodule, where trainer argument reload_dataloaders_every_n_epochs is set to a non negative integer.""" class CustomBoringDataModule(BoringDataModule): def __init__(self): diff --git a/tests/tests_pytorch/core/test_lightning_module.py b/tests/tests_pytorch/core/test_lightning_module.py index 6d612dd3c9..3c57791e26 100644 --- a/tests/tests_pytorch/core/test_lightning_module.py +++ b/tests/tests_pytorch/core/test_lightning_module.py @@ -424,6 +424,7 @@ def test_lightning_module_scriptable(): """Test that the LightningModule is `torch.jit.script`-able. Regression test for #15917. + """ model = BoringModel() trainer = Trainer() diff --git a/tests/tests_pytorch/core/test_lightning_optimizer.py b/tests/tests_pytorch/core/test_lightning_optimizer.py index 22f93ee790..3d0ee4d7a9 100644 --- a/tests/tests_pytorch/core/test_lightning_optimizer.py +++ b/tests/tests_pytorch/core/test_lightning_optimizer.py @@ -73,6 +73,7 @@ def test_lightning_optimizer_manual_optimization_and_accumulated_gradients(tmpdi """Test that the user can use our LightningOptimizer. Not recommended. + """ class TestModel(BoringModel): diff --git a/tests/tests_pytorch/helpers/datasets.py b/tests/tests_pytorch/helpers/datasets.py index e2954b0d73..8860160d6f 100644 --- a/tests/tests_pytorch/helpers/datasets.py +++ b/tests/tests_pytorch/helpers/datasets.py @@ -46,6 +46,7 @@ class MNIST(Dataset): 60000 >>> torch.bincount(dataset.targets) tensor([5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949]) + """ RESOURCES = ( @@ -148,6 +149,7 @@ class TrialMNIST(MNIST): [0, 1, 2] >>> torch.bincount(dataset.targets) tensor([100, 100, 100]) + """ def __init__(self, root: str, num_samples: int = 100, digits: Optional[Sequence] = (0, 1, 2), **kwargs): diff --git a/tests/tests_pytorch/loggers/test_all.py b/tests/tests_pytorch/loggers/test_all.py index 3fab1b6906..3f090b0264 100644 --- a/tests/tests_pytorch/loggers/test_all.py +++ b/tests/tests_pytorch/loggers/test_all.py @@ -170,6 +170,7 @@ def test_loggers_pickle_all(tmpdir, monkeypatch, logger_class): """Test that the logger objects can be pickled. This test only makes sense if the packages are installed. 
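The `single_process_pg` fixture described above can be sketched as a small context manager; this version assumes a free local TCP port and the gloo backend, and is a simplification rather than the fixture's exact implementation:

import contextlib
import socket
import torch
import torch.distributed as dist

@contextlib.contextmanager
def single_process_pg():
    # A one-process gloo group on a free local port, so collective calls can be
    # exercised without spawning extra processes; torn down when the block exits.
    if dist.is_initialized():
        raise RuntimeError("the default process group is already initialized")
    with socket.socket() as s:
        s.bind(("127.0.0.1", 0))
        port = s.getsockname()[1]
    dist.init_process_group("gloo", init_method=f"tcp://127.0.0.1:{port}", rank=0, world_size=1)
    try:
        yield
    finally:
        dist.destroy_process_group()

with single_process_pg():
    t = torch.tensor([1.0])
    dist.all_reduce(t)  # trivially a no-op with world_size == 1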
+ """ _patch_comet_atexit(monkeypatch) try: @@ -270,8 +271,8 @@ class CustomLoggerWithoutExperiment(Logger): @pytest.mark.parametrize("logger_class", [*ALL_LOGGER_CLASSES_WO_NEPTUNE, CustomLoggerWithoutExperiment]) @RunIf(skip_windows=True) def test_logger_initialization(tmpdir, monkeypatch, logger_class): - """Test that loggers get replaced by dummy loggers on global rank > 0 and that the experiment object is - available at the right time in Trainer.""" + """Test that loggers get replaced by dummy loggers on global rank > 0 and that the experiment object is available + at the right time in Trainer.""" _patch_comet_atexit(monkeypatch) try: _test_logger_initialization(tmpdir, logger_class) diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py index 239645c2a7..c1f6821311 100644 --- a/tests/tests_pytorch/loggers/test_mlflow.py +++ b/tests/tests_pytorch/loggers/test_mlflow.py @@ -275,8 +275,7 @@ def test_mlflow_logger_with_long_param_value(client, _, param, tmpdir): @mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True) @mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient") def test_mlflow_logger_with_many_params(client, _, param, tmpdir): - """Test that the when logging more than 100 parameters, it will be split into batches of at most 100 - parameters.""" + """Test that the when logging more than 100 parameters, it will be split into batches of at most 100 parameters.""" logger = MLFlowLogger("test", save_dir=tmpdir) params = {f"test_{idx}": f"test_param_{idx}" for idx in range(150)} diff --git a/tests/tests_pytorch/loggers/test_neptune.py b/tests/tests_pytorch/loggers/test_neptune.py index 81f73b02d7..b2a7133240 100644 --- a/tests/tests_pytorch/loggers/test_neptune.py +++ b/tests/tests_pytorch/loggers/test_neptune.py @@ -41,11 +41,12 @@ def create_run_mock(mode="async", **kwargs): def create_neptune_mock(): - """Mock with provides nice `logger.name` and `logger.version` values. Additionally, it allows `mode` as an - argument to test different Neptune modes. + """Mock with provides nice `logger.name` and `logger.version` values. Additionally, it allows `mode` as an argument + to test different Neptune modes. Mostly due to fact, that windows tests were failing with MagicMock based strings, which were used to create local directories in FS. 
+ """ return MagicMock(init_run=MagicMock(side_effect=create_run_mock)) @@ -88,6 +89,7 @@ def tmpdir_unittest_fixture(request, tmpdir): Resources: * https://docs.pytest.org/en/6.2.x/tmpdir.html#the-tmpdir-fixture * https://towardsdatascience.com/mixing-pytest-fixture-and-unittest-testcase-for-selenium-test-9162218e8c8e + """ request.cls.tmpdir = tmpdir @@ -152,8 +154,8 @@ class TestNeptuneLogger(unittest.TestCase): @patch("lightning.pytorch.loggers.neptune.Run", Run) @patch("lightning.pytorch.loggers.neptune.Handler", Run) def test_online_with_wrong_kwargs(self, neptune): - """Tests combinations of kwargs together with `run` kwarg which makes some of other parameters unavailable - in init.""" + """Tests combinations of kwargs together with `run` kwarg which makes some of other parameters unavailable in + init.""" with self.assertRaises(ValueError): NeptuneLogger(run="some string") diff --git a/tests/tests_pytorch/loggers/test_wandb.py b/tests/tests_pytorch/loggers/test_wandb.py index 3929144b7b..1ec5327612 100644 --- a/tests/tests_pytorch/loggers/test_wandb.py +++ b/tests/tests_pytorch/loggers/test_wandb.py @@ -53,6 +53,7 @@ def test_wandb_logger_init(wandb, monkeypatch): """Verify that basic functionality of wandb logger works. Wandb doesn't work well with pytest so we have to mock it out here. + """ # test wandb.init called when there is no W&B run wandb.run = None @@ -142,6 +143,7 @@ def test_wandb_pickle(wandb, tmpdir): """Verify that pickling trainer with wandb logger works. Wandb doesn't work well with pytest so we have to mock it out here. + """ class Experiment: @@ -373,8 +375,7 @@ def test_wandb_log_model(wandb, monkeypatch, tmpdir): @mock.patch("lightning.pytorch.loggers.wandb.Run", new=mock.Mock) @mock.patch("lightning.pytorch.loggers.wandb.wandb") def test_wandb_log_model_with_score(wandb, monkeypatch, tmpdir): - """Test to prevent regression on #15543, ensuring the score is logged as a Python number, not a scalar - tensor.""" + """Test to prevent regression on #15543, ensuring the score is logged as a Python number, not a scalar tensor.""" wandb.run = None model = BoringModel() diff --git a/tests/tests_pytorch/loops/test_evaluation_loop.py b/tests/tests_pytorch/loops/test_evaluation_loop.py index 5758ecb1f6..ea8f760a59 100644 --- a/tests/tests_pytorch/loops/test_evaluation_loop.py +++ b/tests/tests_pytorch/loops/test_evaluation_loop.py @@ -28,8 +28,7 @@ from tests_pytorch.helpers.runif import RunIf @mock.patch("lightning.pytorch.loops.evaluation_loop._EvaluationLoop._on_evaluation_epoch_end") def test_on_evaluation_epoch_end(eval_epoch_end_mock, tmpdir): - """Tests that `on_evaluation_epoch_end` is called for `on_validation_epoch_end` and `on_test_epoch_end` - hooks.""" + """Tests that `on_evaluation_epoch_end` is called for `on_validation_epoch_end` and `on_test_epoch_end` hooks.""" model = BoringModel() trainer = Trainer( @@ -112,6 +111,7 @@ def test_memory_consumption_validation(tmpdir): Cannot run with MPS, since there we can only measure shared memory and not dedicated, which device has how much memory allocated. + """ def get_memory(): diff --git a/tests/tests_pytorch/loops/test_fetchers.py b/tests/tests_pytorch/loops/test_fetchers.py index f5990f266d..39dd766f06 100644 --- a/tests/tests_pytorch/loops/test_fetchers.py +++ b/tests/tests_pytorch/loops/test_fetchers.py @@ -129,12 +129,14 @@ def get_cycles_per_ms() -> float: This is to avoid system disturbance that skew the results, e.g. 
the very first cuda call likely does a bunch of init, which takes much longer than subsequent calls. + """ def measure() -> float: """Measure and return approximate number of cycles per millisecond for `torch.cuda._sleep` Copied from: https://github.com/pytorch/pytorch/blob/v1.9.0/test/test_cuda.py#L81. + """ start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) diff --git a/tests/tests_pytorch/loops/test_progress.py b/tests/tests_pytorch/loops/test_progress.py index daa381e6f8..27184d7b17 100644 --- a/tests/tests_pytorch/loops/test_progress.py +++ b/tests/tests_pytorch/loops/test_progress.py @@ -93,6 +93,7 @@ def test_optimizer_progress_default_factory(): """Ensure that the defaults are created appropriately. If `default_factory` was not used, the default would be shared between instances. + """ p1 = _OptimizerProgress() p2 = _OptimizerProgress() diff --git a/tests/tests_pytorch/loops/test_training_epoch_loop.py b/tests/tests_pytorch/loops/test_training_epoch_loop.py index 7d0690c5ac..7814f56cff 100644 --- a/tests/tests_pytorch/loops/test_training_epoch_loop.py +++ b/tests/tests_pytorch/loops/test_training_epoch_loop.py @@ -63,8 +63,7 @@ def test_no_val_on_train_epoch_loop_restart(tmpdir): def test_should_stop_early_stopping_conditions_not_met( caplog, min_epochs, min_steps, current_epoch, global_step, early_stop, epoch_loop_done, raise_info_msg ): - """Test that checks that info message is logged when users sets `should_stop` but min conditions are not - met.""" + """Test that checks that info message is logged when users sets `should_stop` but min conditions are not met.""" trainer = Trainer(min_epochs=min_epochs, min_steps=min_steps, limit_val_batches=0) trainer.fit_loop.max_batches = 10 trainer.should_stop = True @@ -86,6 +85,7 @@ def test_should_stop_triggers_validation_once(min_epochs, min_steps, val_count, Test that the request for `should_stop=True` only triggers validation when Trainer is allowed to stop (min_epochs/steps is satisfied). + """ model = BoringModel() trainer = Trainer( diff --git a/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py b/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py index b6529f2614..dc5e12defa 100644 --- a/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py +++ b/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py @@ -138,8 +138,7 @@ def test__training_step__epoch_end__flow_scalar(tmpdir): def test_train_step_no_return(tmpdir): - """Tests that only training_step raises a warning when nothing is returned in case of - automatic_optimization.""" + """Tests that only training_step raises a warning when nothing is returned in case of automatic_optimization.""" class TestModel(BoringModel): def training_step(self, batch): diff --git a/tests/tests_pytorch/models/test_cpu.py b/tests/tests_pytorch/models/test_cpu.py index e79f014c7e..92423b85cb 100644 --- a/tests/tests_pytorch/models/test_cpu.py +++ b/tests/tests_pytorch/models/test_cpu.py @@ -155,6 +155,7 @@ def test_lbfgs_cpu_model(tmpdir): """Test each of the trainer options. Testing LBFGS optimizer + """ seed_everything(42) @@ -247,6 +248,7 @@ def test_running_test_no_val(tmpdir): """Verify `test()` works on a model with no `val_dataloader`. 
It performs train and test only + """ seed_everything(42) diff --git a/tests/tests_pytorch/models/test_restore.py b/tests/tests_pytorch/models/test_restore.py index df0046daab..61bdfc0305 100644 --- a/tests/tests_pytorch/models/test_restore.py +++ b/tests/tests_pytorch/models/test_restore.py @@ -105,8 +105,7 @@ def test_model_properties_fit_ckpt_path(tmpdir): @RunIf(sklearn=True) def test_trainer_properties_restore_ckpt_path(tmpdir): - """Test that required trainer properties are set correctly when resuming from checkpoint in different - phases.""" + """Test that required trainer properties are set correctly when resuming from checkpoint in different phases.""" class CustomClassifModel(ClassificationModel): def configure_optimizers(self): diff --git a/tests/tests_pytorch/profilers/test_profiler.py b/tests/tests_pytorch/profilers/test_profiler.py index 9c7c265fb8..00595bfd25 100644 --- a/tests/tests_pytorch/profilers/test_profiler.py +++ b/tests/tests_pytorch/profilers/test_profiler.py @@ -457,6 +457,7 @@ def test_pytorch_profiler_multiple_loggers(tmpdir): multiple loggers. See issue #8157. + """ def look_for_trace(trace_dir): diff --git a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py index 2eb367bee2..73b006e6f7 100644 --- a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py +++ b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py @@ -90,8 +90,7 @@ def test_global_state_snapshot(): @pytest.mark.parametrize("fake_node_rank", [0, 1]) @pytest.mark.parametrize("fake_local_rank", [0, 1]) def test_collect_rank_zero_results(trainer_fn, fake_node_rank, fake_local_rank, tmpdir): - """Tests that the spawn strategy transfers the new weights to the main process and deletes the temporary - file.""" + """Tests that the spawn strategy transfers the new weights to the main process and deletes the temporary file.""" model = Mock(wraps=BoringModel(), spec=BoringModel) fake_global_rank = 2 * fake_node_rank + fake_local_rank @@ -130,8 +129,8 @@ def test_collect_rank_zero_results(trainer_fn, fake_node_rank, fake_local_rank, @pytest.mark.parametrize("trainer_fn", [TrainerFn.FITTING, "other"]) def test_transfer_weights(tmpdir, trainer_fn): - """Tests that the multiprocessing launcher transfers the new weights to the main process and deletes the - temporary file.""" + """Tests that the multiprocessing launcher transfers the new weights to the main process and deletes the temporary + file.""" model = Mock(wraps=BoringModel(), spec=BoringModel) strategy = DDPStrategy(start_method="spawn") trainer = Trainer(accelerator="cpu", default_root_dir=tmpdir, strategy=strategy) diff --git a/tests/tests_pytorch/strategies/test_ddp_spawn.py b/tests/tests_pytorch/strategies/test_ddp_spawn.py index 4c03af87fc..74f562f3c1 100644 --- a/tests/tests_pytorch/strategies/test_ddp_spawn.py +++ b/tests/tests_pytorch/strategies/test_ddp_spawn.py @@ -84,8 +84,7 @@ def test_ddp_all_dataloaders_passed_to_fit(tmpdir): def test_ddp_spawn_find_unused_parameters_exception(): - """Test that the DDP strategy can change PyTorch's error message so that it's more useful for Lightning - users.""" + """Test that the DDP strategy can change PyTorch's error message so that it's more useful for Lightning users.""" trainer = Trainer(accelerator="cpu", devices=1, strategy="ddp_spawn", max_steps=2) with pytest.raises( ProcessRaisedException, match="It looks like your LightningModule has parameters that were not used in" diff --git 
a/tests/tests_pytorch/strategies/test_ddp_strategy.py b/tests/tests_pytorch/strategies/test_ddp_strategy.py index 3e18264553..aeabb378d6 100644 --- a/tests/tests_pytorch/strategies/test_ddp_strategy.py +++ b/tests/tests_pytorch/strategies/test_ddp_strategy.py @@ -297,8 +297,7 @@ class UnusedParametersModel(BoringModel): def test_ddp_strategy_find_unused_parameters_exception(): - """Test that the DDP strategy can change PyTorch's error message so that it's more useful for Lightning - users.""" + """Test that the DDP strategy can change PyTorch's error message so that it's more useful for Lightning users.""" trainer = Trainer(accelerator="cpu", devices=1, strategy="ddp", max_steps=2) with pytest.raises(RuntimeError, match="It looks like your LightningModule has parameters that were not used in"): trainer.fit(UnusedParametersModel()) diff --git a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py index 6e432c0d9c..d4dcb38a39 100644 --- a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py +++ b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py @@ -106,8 +106,7 @@ def deepspeed_zero_config(deepspeed_config): @RunIf(deepspeed=True) @pytest.mark.parametrize("strategy", ["deepspeed", DeepSpeedStrategy]) def test_deepspeed_strategy_string(tmpdir, strategy): - """Test to ensure that the strategy can be passed via string or instance, and parallel devices is correctly - set.""" + """Test to ensure that the strategy can be passed via string or instance, and parallel devices is correctly set.""" trainer = Trainer( accelerator="cpu", @@ -141,6 +140,7 @@ def test_deepspeed_precision_choice(cuda_count_1, tmpdir): """Test to ensure precision plugin is also correctly chosen. DeepSpeed handles precision via Custom DeepSpeedPrecisionPlugin + """ trainer = Trainer( fast_dev_run=True, @@ -286,8 +286,8 @@ def test_deepspeed_run_configure_optimizers(tmpdir): @RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True) def test_deepspeed_config(tmpdir, deepspeed_zero_config): - """Test to ensure deepspeed works correctly when passed a DeepSpeed config object including - optimizers/schedulers and saves the model weights to load correctly.""" + """Test to ensure deepspeed works correctly when passed a DeepSpeed config object including optimizers/schedulers + and saves the model weights to load correctly.""" class TestCB(Callback): def on_train_start(self, trainer, pl_module) -> None: @@ -358,8 +358,8 @@ def test_deepspeed_custom_precision_params(tmpdir): @RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True) @pytest.mark.parametrize("precision", ["fp16", "bf16"]) def test_deepspeed_inference_precision_during_inference(precision, tmpdir): - """Ensure if we modify the precision for deepspeed and execute inference-only, the deepspeed config contains - these changes.""" + """Ensure if we modify the precision for deepspeed and execute inference-only, the deepspeed config contains these + changes.""" class TestCB(Callback): def on_validation_start(self, trainer, pl_module) -> None: @@ -399,8 +399,8 @@ def test_deepspeed_custom_activation_checkpointing_params(tmpdir): @RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True) def test_deepspeed_custom_activation_checkpointing_params_forwarded(tmpdir): - """Ensure if we modify the activation checkpointing parameters, we pass these to - deepspeed.checkpointing.configure correctly.""" + """Ensure if we modify the activation checkpointing parameters, we pass these to deepspeed.checkpointing.configure 
+ correctly.""" ds = DeepSpeedStrategy( partition_activations=True, cpu_checkpointing=True, @@ -456,8 +456,7 @@ def test_deepspeed_assert_config_zero_offload_disabled(tmpdir, deepspeed_zero_co @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True) def test_deepspeed_multigpu(tmpdir): - """Test to ensure that DeepSpeed with multiple GPUs works and deepspeed distributed is initialized - correctly.""" + """Test to ensure that DeepSpeed with multiple GPUs works and deepspeed distributed is initialized correctly.""" model = BoringModel() trainer = Trainer( default_root_dir=tmpdir, @@ -930,8 +929,8 @@ def test_deepspeed_multigpu_partial_partition_parameters(tmpdir): @RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True) def test_deepspeed_multigpu_test_rnn(tmpdir): - """Test to ensure that turning off explicit partitioning of the entire module for ZeRO Stage 3 works when - training with certain layers which will crash with explicit partitioning.""" + """Test to ensure that turning off explicit partitioning of the entire module for ZeRO Stage 3 works when training + with certain layers which will crash with explicit partitioning.""" class TestModel(BoringModel): def __init__(self): @@ -962,6 +961,7 @@ def test_deepspeed_strategy_env_variables(mock_deepspeed_distributed, tmpdir, pl """Test to ensure that we set up distributed communication correctly. When using Windows, rank environment variables should not be set, and deepspeed should handle this. + """ trainer = Trainer(default_root_dir=tmpdir, strategy=DeepSpeedStrategy(stage=3)) strategy = trainer.strategy @@ -1087,8 +1087,8 @@ def test_deepspeed_setup_train_dataloader(tmpdir): @pytest.mark.parametrize("limit_train_batches", [2]) @RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True) def test_scheduler_step_count(mock_step, tmpdir, max_epoch, limit_train_batches, interval): - """Test to ensure that the scheduler is called the correct amount of times during training when scheduler is - set to step or epoch.""" + """Test to ensure that the scheduler is called the correct amount of times during training when scheduler is set to + step or epoch.""" class TestModel(BoringModel): def configure_optimizers(self): @@ -1122,8 +1122,8 @@ def test_scheduler_step_count(mock_step, tmpdir, max_epoch, limit_train_batches, @RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True) def test_deepspeed_configure_gradient_clipping(tmpdir): - """Test to ensure that a warning is raised when `LightningModule.configure_gradient_clipping` is overridden in - case of deepspeed.""" + """Test to ensure that a warning is raised when `LightningModule.configure_gradient_clipping` is overridden in case + of deepspeed.""" class TestModel(BoringModel): def configure_gradient_clipping(self, optimizer, gradient_clip_val, gradient_clip_algorithm): @@ -1162,8 +1162,8 @@ def test_deepspeed_gradient_clip_by_value(tmpdir): @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True) def test_deepspeed_multi_save_same_filepath(tmpdir): - """Test that verifies that deepspeed saves only latest checkpoint in the specified path and deletes the old - sharded checkpoints.""" + """Test that verifies that deepspeed saves only latest checkpoint in the specified path and deletes the old sharded + checkpoints.""" class CustomModel(BoringModel): def training_step(self, *args, **kwargs): @@ -1278,6 +1278,7 @@ def test_validate_parallel_devices_indices(device_indices): """Test that the strategy validates that it doesn't support selecting specific devices by index.
DeepSpeed doesn't support it and needs the index to match to the local rank of the process. + """ strategy = DeepSpeedStrategy( accelerator=CUDAAccelerator(), parallel_devices=[torch.device("cuda", i) for i in device_indices] diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py index 96273daa58..381f799b08 100644 --- a/tests/tests_pytorch/strategies/test_fsdp.py +++ b/tests/tests_pytorch/strategies/test_fsdp.py @@ -285,6 +285,7 @@ def test_fsdp_strategy_full_state_dict(tmpdir, wrap_min_params): """Test to ensure that the full state dict is extracted when using FSDP strategy. Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, and not wrapped at all. + """ model = TestFSDPModelAutoWrapped(wrap_min_params=wrap_min_params) correct_state_dict = model.state_dict() # State dict before wrapping @@ -547,6 +548,7 @@ def test_fsdp_strategy_save_optimizer_states(tmpdir, wrap_min_params): Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, and not wrapped at all. If the model can be restored to DDP, it means that the optimizer states were saved correctly. + """ model = TestFSDPModelAutoWrapped(wrap_min_params=wrap_min_params) @@ -604,6 +606,7 @@ def test_fsdp_strategy_load_optimizer_states(tmpdir, wrap_min_params): Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, and not wrapped at all. If the DDP model can be restored to FSDP, it means that the optimizer states were restored correctly. + """ # restore model to ddp diff --git a/tests/tests_pytorch/strategies/test_single_device_strategy.py b/tests/tests_pytorch/strategies/test_single_device_strategy.py index 85c92ded50..d4648f55f6 100644 --- a/tests/tests_pytorch/strategies/test_single_device_strategy.py +++ b/tests/tests_pytorch/strategies/test_single_device_strategy.py @@ -45,6 +45,7 @@ def test_single_gpu(): """Tests if device is set correctly when training and after teardown for single GPU strategy. Cannot run this test on MPS due to shared memory not allowing dedicated measurements of GPU memory utilization. 
+ """ trainer = Trainer(accelerator="gpu", devices=1, fast_dev_run=True) # assert training strategy attributes for device setting diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index c9ad6853a1..b074429642 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -903,6 +903,7 @@ def test_lightning_cli_custom_subcommand(): model: A model x: The x y: The y + """ class TestCLI(LightningCLI): diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py index 00dab05ad3..1e70bd0e59 100644 --- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py @@ -587,8 +587,8 @@ def test_error_raised_with_insufficient_float_limit_train_dataloader(): ], ) def test_attach_data_input_validation_with_none_dataloader(trainer_fn_name, dataloader_name, tmpdir): - """Test that passing `Trainer.method(x_dataloader=None)` with no module-method implementations available raises - an error.""" + """Test that passing `Trainer.method(x_dataloader=None)` with no module-method implementations available raises an + error.""" trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) model = BoringModel() datamodule = BoringDataModule() diff --git a/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py b/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py index 13bbb2243c..9019249c37 100644 --- a/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py @@ -26,6 +26,7 @@ class AllRankLogger(Logger): """Logger to test all-rank logging (i.e. not just rank 0). Logs are saved to local variable `logs`. + """ def __init__(self): @@ -102,6 +103,7 @@ def test_first_logger_call_in_subprocess(tmpdir): """Test that the Trainer does not call the logger too early. Only when the worker processes are initialized do we have access to the rank and know which one is the main process. 
+ """ class LoggerCallsObserver(Callback): diff --git a/tests/tests_pytorch/trainer/test_dataloaders.py b/tests/tests_pytorch/trainer/test_dataloaders.py index 0dc0df1f90..2623108e5b 100644 --- a/tests/tests_pytorch/trainer/test_dataloaders.py +++ b/tests/tests_pytorch/trainer/test_dataloaders.py @@ -675,8 +675,7 @@ def test_auto_add_worker_init_fn_distributed(tmpdir, monkeypatch): def test_warning_with_small_dataloader_and_logging_interval(tmpdir): - """Test that a warning message is shown if the dataloader length is too short for the chosen logging - interval.""" + """Test that a warning message is shown if the dataloader length is too short for the chosen logging interval.""" model = BoringModel() dataloader = DataLoader(RandomDataset(32, length=10)) model.train_dataloader = lambda: dataloader @@ -847,8 +846,7 @@ class ModelWithDataLoaderDistributedSampler(BoringModel): @RunIf(min_cuda_gpus=2, skip_windows=True) def test_dataloader_distributed_sampler_already_attached(tmpdir): - """Test DistributedSampler and it's arguments for DDP backend when DistSampler already included on - dataloader.""" + """Test DistributedSampler and it's arguments for DDP backend when DistSampler already included on dataloader.""" seed_everything(123) model = ModelWithDataLoaderDistributedSampler() trainer = Trainer( @@ -1209,8 +1207,8 @@ def test_dataloaders_load_only_once_passed_loaders(tmp_path, monkeypatch, sanity def test_dataloaders_reset_and_attach(tmpdir): - """Test that repeated calls to Trainer.{fit,validate,test,predict} properly reset dataloaders before attaching - the new one.""" + """Test that repeated calls to Trainer.{fit,validate,test,predict} properly reset dataloaders before attaching the + new one.""" # the assertions compare the datasets and not dataloaders since we patch and replace the samplers dataloader_0 = DataLoader(dataset=RandomDataset(32, 64)) dataloader_1 = DataLoader(dataset=RandomDataset(32, 64)) diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py index 346ef9a84e..20438af351 100644 --- a/tests/tests_pytorch/trainer/test_trainer.py +++ b/tests/tests_pytorch/trainer/test_trainer.py @@ -67,8 +67,7 @@ if _OMEGACONF_AVAILABLE: def test_trainer_error_when_input_not_lightning_module(): - """Test that a useful error gets raised when the Trainer methods receive something other than a - LightningModule.""" + """Test that a useful error gets raised when the Trainer methods receive something other than a LightningModule.""" trainer = Trainer() for method in ("fit", "validate", "test", "predict"): @@ -347,8 +346,8 @@ def test_model_checkpoint_options(tmpdir, save_top_k, save_last, expected_files) def test_model_checkpoint_only_weights(tmpdir): - """Tests use case where ModelCheckpoint is configured to save only model weights, and user tries to load - checkpoint to resume training.""" + """Tests use case where ModelCheckpoint is configured to save only model weights, and user tries to load checkpoint + to resume training.""" model = BoringModel() trainer = Trainer( @@ -1447,8 +1446,8 @@ def test_predict_return_predictions_cpu(return_predictions, precision, tmpdir): @pytest.mark.parametrize(("max_steps", "max_epochs", "global_step"), [(10, 5, 10), (20, None, 20)]) def test_repeated_fit_calls_with_max_epochs_and_steps(tmpdir, max_steps, max_epochs, global_step): - """Ensure that the training loop is bound by `max_steps` and `max_epochs` for repeated calls of `trainer.fit`, - and disabled if the limit is reached.""" + """Ensure that the 
training loop is bound by `max_steps` and `max_epochs` for repeated calls of `trainer.fit`, and + disabled if the limit is reached.""" dataset_len = 200 batch_size = 10 diff --git a/tests/tests_pytorch/tuner/test_lr_finder.py b/tests/tests_pytorch/tuner/test_lr_finder.py index 87a7412396..bee5c679ed 100644 --- a/tests/tests_pytorch/tuner/test_lr_finder.py +++ b/tests/tests_pytorch/tuner/test_lr_finder.py @@ -217,8 +217,7 @@ def test_datamodule_parameter(tmpdir): def test_accumulation_and_early_stopping(tmpdir): - """Test that early stopping of learning rate finder works, and that accumulation also works for this - feature.""" + """Test that early stopping of learning rate finder works, and that accumulation also works for this feature.""" seed_everything(1) class TestModel(BoringModel): diff --git a/tests/tests_pytorch/tuner/test_scale_batch_size.py b/tests/tests_pytorch/tuner/test_scale_batch_size.py index 4e88953286..3f32f8231d 100644 --- a/tests/tests_pytorch/tuner/test_scale_batch_size.py +++ b/tests/tests_pytorch/tuner/test_scale_batch_size.py @@ -231,8 +231,7 @@ def test_call_to_trainer_method(tmpdir, scale_method): def test_error_on_dataloader_passed_to_fit(tmpdir): - """Verify that when the auto-scale batch size feature raises an error if a train dataloader is passed to - fit.""" + """Verify that the auto-scale batch size feature raises an error if a train dataloader is passed to fit.""" # only train passed to fit model = BatchSizeModel(batch_size=2) diff --git a/tests/tests_pytorch/utilities/migration/test_utils.py b/tests/tests_pytorch/utilities/migration/test_utils.py index f8f2be9786..c9cf6cd4de 100644 --- a/tests/tests_pytorch/utilities/migration/test_utils.py +++ b/tests/tests_pytorch/utilities/migration/test_utils.py @@ -178,8 +178,8 @@ def _run_simple_migration(monkeypatch, old_checkpoint): def test_migrate_checkpoint_too_new(): - """Test checkpoint migration is a no-op with a warning when attempting to migrate a checkpoint from newer - version of Lightning than installed.""" + """Test checkpoint migration is a no-op with a warning when attempting to migrate a checkpoint from newer version + of Lightning than installed.""" super_new_checkpoint = {"pytorch-lightning_version": "99.0.0", "content": 123} with pytest.warns( PossibleUserWarning, match=f"v99.0.0, which is newer than your current Lightning version: v{pl.__version__}" diff --git a/tests/tests_pytorch/utilities/test_combined_loader.py b/tests/tests_pytorch/utilities/test_combined_loader.py index 7109523b37..9f247616fd 100644 --- a/tests/tests_pytorch/utilities/test_combined_loader.py +++ b/tests/tests_pytorch/utilities/test_combined_loader.py @@ -403,8 +403,7 @@ def test_combined_loader_sequence_with_map_and_iterable(lengths): @pytest.mark.parametrize("use_distributed_sampler", [False, True]) def test_combined_data_loader_validation_test(use_distributed_sampler): - """This test makes sure distributed sampler has been properly injected in dataloaders when using - CombinedLoader.""" + """This test makes sure the distributed sampler has been properly injected in dataloaders when using CombinedLoader.""" class CustomDataset(Dataset): def __init__(self, data): diff --git a/tests/tests_pytorch/utilities/test_data.py b/tests/tests_pytorch/utilities/test_data.py index 3c6d2ababb..930650dcb2 100644 --- a/tests/tests_pytorch/utilities/test_data.py +++ b/tests/tests_pytorch/utilities/test_data.py @@ -139,11 +139,12 @@ def test_update_dataloader_typerror_custom_exception(): @pytest.mark.parametrize("predicting", [True, False])
def test_custom_batch_sampler(predicting): - """This test asserts, that custom `BatchSampler`, with all the arguments, that are required in order to - properly reinstantiate the class, is invoked properly. + """This test asserts that a custom `BatchSampler`, with all the arguments required to properly + reinstantiate the class, is invoked properly. It also asserts that during the reinstantiation, the wrapper of the `__init__` method is not present anymore, therefore not setting `__pl_saved_{args,arg_names,kwargs}` attributes. + """ class MyBatchSampler(BatchSampler): @@ -189,8 +190,8 @@ def test_custom_batch_sampler(predicting): def test_custom_batch_sampler_no_drop_last(): - """Tests whether appropriate warning is raised when the custom `BatchSampler` does not support `drop_last` and - we want to reset it.""" + """Tests whether an appropriate warning is raised when the custom `BatchSampler` does not support `drop_last` and we + want to reset it.""" class MyBatchSampler(BatchSampler): # Custom batch sampler with extra argument, but without `drop_last` @@ -217,8 +218,7 @@ def test_custom_batch_sampler_no_drop_last(): def test_custom_batch_sampler_no_sampler(): - """Tests whether appropriate error is raised when the custom `BatchSampler` does not support sampler - argument.""" + """Tests whether an appropriate error is raised when the custom `BatchSampler` does not support the sampler argument.""" class MyBatchSampler(BatchSampler): # Custom batch sampler, without sampler argument. @@ -269,10 +269,10 @@ def test_dataloader_kwargs_replacement_with_iterable_dataset(mode): def test_dataloader_kwargs_replacement_with_array_default_comparison(): - """Test that the comparison of attributes and default argument values works with arrays (truth value - ambiguous). + """Test that the comparison of attributes and default argument values works with arrays (truth value ambiguous). Regression test for issue #15408. + """ dataset = RandomDataset(5, 100) diff --git a/tests/tests_pytorch/utilities/test_deepspeed_model_summary.py b/tests/tests_pytorch/utilities/test_deepspeed_model_summary.py index 146ab1aa66..919acee7e4 100644 --- a/tests/tests_pytorch/utilities/test_deepspeed_model_summary.py +++ b/tests/tests_pytorch/utilities/test_deepspeed_model_summary.py @@ -22,8 +22,8 @@ from tests_pytorch.helpers.runif import RunIf @RunIf(min_cuda_gpus=2, deepspeed=True, standalone=True) def test_deepspeed_summary(tmpdir): - """Test to ensure that the summary contains the correct values when stage 3 is enabled and that the trainer - enables the `DeepSpeedSummary` when DeepSpeed is used.""" + """Test to ensure that the summary contains the correct values when stage 3 is enabled and that the trainer enables + the `DeepSpeedSummary` when DeepSpeed is used.""" model = BoringModel() total_parameters = sum(x.numel() for x in model.parameters()) diff --git a/tests/tests_pytorch/utilities/test_imports.py b/tests/tests_pytorch/utilities/test_imports.py index 0fb31b59ba..ae3c092907 100644 --- a/tests/tests_pytorch/utilities/test_imports.py +++ b/tests/tests_pytorch/utilities/test_imports.py @@ -63,10 +63,11 @@ def _shortcut_patch(orig_fn, shortcut_case, attr_names=None): @pytest.fixture() def clean_import(): - """This fixture allows test to import {pytorch_}lightning* modules completely cleanly, regardless of the - current state of the imported modules. + """This fixture allows tests to import {pytorch_}lightning* modules completely cleanly, regardless of the current + state of the imported modules.
Afterwards, it restores the original state of the modules. + """ import sys @@ -108,6 +109,7 @@ def test_import_with_unavailable_dependencies(patch_name, new_fn, to_import, cle When the patch is applied and the module is imported, it should not raise any errors. The list of cases to check was compiled by finding else branches of top-level if statements checking for the availability of the module and performing imports. + """ with mock.patch(patch_name, new=new_fn): importlib.import_module(to_import) diff --git a/tests/tests_pytorch/utilities/test_warnings.py b/tests/tests_pytorch/utilities/test_warnings.py index 78f0570ee2..04c4d50ae8 100644 --- a/tests/tests_pytorch/utilities/test_warnings.py +++ b/tests/tests_pytorch/utilities/test_warnings.py @@ -14,6 +14,7 @@ """Test that the warnings actually appear and they have the correct `stacklevel` Needs to be run outside of `pytest` as it captures all the warnings. + """ from contextlib import redirect_stderr from io import StringIO