docformatter: config with black (#18064)

* docformatter: config with black

* additional_dependencies: [tomli]

* 119

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Jirka Borovec 2023-08-09 16:44:20 +02:00 committed by GitHub
parent e33816ce60
commit efa7b2f9ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
356 changed files with 1343 additions and 479 deletions

View File

@ -86,6 +86,7 @@ class _RequirementWithComment(Requirement):
'arrow>=1.2.0'
>>> _RequirementWithComment("arrow").adjust("major")
'arrow'
"""
out = str(self)
if self.strict:
@ -115,6 +116,7 @@ def _parse_requirements(strs: Union[str, Iterable[str]]) -> Iterator[_Requiremen
>>> txt = '\\n'.join(txt)
>>> [r.adjust('none') for r in _parse_requirements(txt)]
['this', 'example', 'foo # strict', 'thing']
"""
lines = yield_lines(strs)
pip_argument = None
@ -149,6 +151,7 @@ def load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: str
>>> path_req = os.path.join(_PROJECT_ROOT, "requirements")
>>> load_requirements(path_req, "docs.txt", unfreeze="major") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
['sphinx<...]
"""
assert unfreeze in {"none", "major", "all"}
path = Path(path_dir) / file_name
@ -165,6 +168,7 @@ def load_readme_description(path_dir: str, homepage: str, version: str) -> str:
>>> load_readme_description(_PROJECT_ROOT, "", "") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
'...PyTorch Lightning is just organized PyTorch...'
"""
path_readme = os.path.join(path_dir, "README.md")
with open(path_readme, encoding="utf-8") as fo:
@ -244,6 +248,7 @@ def _load_aggregate_requirements(req_dir: str = "requirements", freeze_requireme
"""Load all base requirements from all particular packages and prune duplicates.
>>> _load_aggregate_requirements(os.path.join(_PROJECT_ROOT, "requirements"))
"""
requires = [
load_requirements(d, unfreeze="none" if freeze_requirements else "major")
@ -300,6 +305,7 @@ def _replace_imports(lines: List[str], mapping: List[Tuple[str, str]], lightning
'http://pytorch_lightning.ai', \
'from lightning_fabric import __version__', \
'@lightning.ai']
"""
out = lines[:]
for source_import, target_import in mapping:

View File

@ -60,7 +60,8 @@ repos:
rev: v1.7.3
hooks:
- id: docformatter
args: [--in-place, --wrap-summaries=115, --wrap-descriptions=120]
additional_dependencies: [tomli]
args: ["--in-place"]
- repo: https://github.com/asottile/yesqa
rev: v1.5.0

View File

@ -23,6 +23,7 @@ class FileServer(L.LightningWork):
drive: The drive can share data inside your application.
base_dir: The local directory where the data will be stored.
chunk_size: The quantity of bytes to download/upload at once.
"""
super().__init__(
cloud_build_config=L.BuildConfig(["flask, flask-cors"]),
@ -238,4 +239,5 @@ def test_file_server_in_cloud():
# 2. By calling logs = get_logs_fn(),
# you get all the logs currently on the admin page.
"""

View File

@ -36,6 +36,7 @@ class GithubRepoRunner(TracerPythonScript):
script_args: The arguments to be provided to the script.
requirements: The python requirements tp run the script.
cloud_compute: The object to select the cloud instance.
"""
super().__init__(
script_path=script_path,

View File

@ -10,6 +10,7 @@ class Locust(LightningWork):
Arguments:
num_users: Number of users emulated by Locust
"""
# Note: Using the default port 8089 of Locust.
super().__init__(

View File

@ -18,6 +18,7 @@ class MLServer(LightningWork):
Example: "mlserver_sklearn.SKLearnModel".
Learn more here: $ML_SERVER_URL/tree/master/runtimes
workers: Number of server worker.
"""
def __init__(
@ -51,6 +52,7 @@ class MLServer(LightningWork):
Arguments:
model_path: The path to the trained model.
"""
# 1: Use the host and port at runtime so it works in the cloud.
# $ML_SERVER_URL/blob/master/mlserver/settings.py#L50

View File

@ -15,6 +15,7 @@ def download_data(url: str, path: str = "data/", verbose: bool = False) -> None:
Usage:
download_file('http://web4host.net/5MB.zip')
"""
if url == "NEED_TO_BE_CREATED":
raise NotImplementedError

View File

@ -5,6 +5,7 @@ Run the app:
lightning run app examples/layout/demo.py
This starts one server for each flow that returns a UI. Access the UI at the link printed in the terminal.
"""
import os

View File

@ -137,6 +137,7 @@ class MyCustomTrainer:
If not specified, no validation will run.
ckpt_path: Path to previous checkpoints to resume training from.
If specified, will always look for the latest checkpoint within the given directory.
"""
self.fabric.launch()
@ -207,6 +208,7 @@ class MyCustomTrainer:
If greater then the number of batches in the ``train_loader``, this has no effect.
scheduler_cfg: The learning rate scheduler configuration.
Have a look at :meth:`lightning.pytorch.LightninModule.configure_optimizers` for supported values.
"""
self.fabric.call("on_train_epoch_start")
iterable = self.progbar_wrapper(
@ -268,6 +270,7 @@ class MyCustomTrainer:
val_loader: The dataloader yielding the validation batches.
limit_batches: Limits the batches during this validation epoch.
If greater then the number of batches in the ``val_loader``, this has no effect.
"""
# no validation if val_loader wasn't passed
if val_loader is None:
@ -311,13 +314,14 @@ class MyCustomTrainer:
torch.set_grad_enabled(True)
def training_step(self, model: L.LightningModule, batch: Any, batch_idx: int) -> torch.Tensor:
"""A single training step, running forward and backward. The optimizer step is called separately, as this
is given as a closure to the optimizer step.
"""A single training step, running forward and backward. The optimizer step is called separately, as this is
given as a closure to the optimizer step.
Args:
model: the lightning module to train
batch: the batch to run the forward on
batch_idx: index of the current batch w.r.t the current epoch
"""
outputs: Union[torch.Tensor, Mapping[str, Any]] = model.training_step(batch, batch_idx=batch_idx)
@ -347,6 +351,7 @@ class MyCustomTrainer:
Have a look at :meth:`lightning.pytorch.LightningModule.configure_optimizers` for supported values.
level: whether we are trying to step on epoch- or step-level
current_value: Holds the current_epoch if ``level==epoch``, else holds the ``global_step``
"""
# no scheduler
@ -395,6 +400,7 @@ class MyCustomTrainer:
Args:
iterable: the iterable to wrap with tqdm
total: the total length of the iterable, necessary in case the number of batches was limited.
"""
if self.fabric.is_global_zero:
return tqdm(iterable, total=total, **kwargs)
@ -406,6 +412,7 @@ class MyCustomTrainer:
Args:
state: a mapping contaning model, optimizer and lr scheduler
path: the path to load the checkpoint from
"""
if state is None:
state = {}
@ -458,6 +465,7 @@ class MyCustomTrainer:
Args:
configure_optim_output: The output of ``configure_optimizers``.
For supported values, please refer to :meth:`lightning.pytorch.LightningModule.configure_optimizers`.
"""
_lr_sched_defaults = {"interval": "epoch", "frequency": 1, "monitor": "val_loss"}
@ -511,6 +519,7 @@ class MyCustomTrainer:
prog_bar: a progressbar (on global rank zero) or an iterable (every other rank).
candidates: the values to add as postfix strings to the progressbar.
prefix: the prefix to add to each of these values.
"""
if isinstance(prog_bar, tqdm) and candidates is not None:
postfix_str = ""

View File

@ -25,6 +25,7 @@ and replace ``loss.backward()`` with ``self.backward(loss)``.
Accelerate your training loop by setting the ``--accelerator``, ``--strategy``, ``--devices`` options directly from
the command line. See ``lightning run model --help`` or learn more from the documentation:
https://lightning.ai/docs/fabric.
"""
import argparse

View File

@ -14,6 +14,7 @@
"""MNIST autoencoder example.
To run: python autoencoder.py --trainer.max_epochs=50
"""
from os import path
from typing import Optional, Tuple

View File

@ -14,6 +14,7 @@
"""MNIST backbone image classifier example.
To run: python backbone_image_classifier.py --trainer.max_epochs=50
"""
from os import path
from typing import Optional

View File

@ -20,6 +20,7 @@ visualized in 2 ways:
* With PyTorch Tensorboard Profiler (Instructions are here: https://github.com/pytorch/kineto/tree/master/tb_plugin)
1. pip install tensorboard torch-tb-profiler
2. tensorboard --logdir={FOLDER}
"""
from os import path

View File

@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Computer vision example on Transfer Learning. This computer vision example illustrates how one could fine-tune a
pre-trained network (by default, a ResNet50 is used) using pytorch-lightning. For the sake of this example, the
'cats and dogs dataset' (~60MB, see `DATA_URL` below) and the proposed network (denoted by `TransferLearningModel`,
see below) is trained for 15 epochs.
pre-trained network (by default, a ResNet50 is used) using pytorch-lightning. For the sake of this example, the 'cats
and dogs dataset' (~60MB, see `DATA_URL` below) and the proposed network (denoted by `TransferLearningModel`, see
below) is trained for 15 epochs.
The training consists of three stages.
@ -37,6 +37,7 @@ Note:
To run:
python computer_vision_fine_tuning.py fit
"""
import logging
@ -97,6 +98,7 @@ class CatDogImageDataModule(LightningDataModule):
dl_path: root directory where to download the data
num_workers: number of CPU workers
batch_size: number of sample in a batch
"""
super().__init__()
@ -174,6 +176,7 @@ class TransferLearningModel(LightningModule):
milestones: List of two epochs milestones
lr: Initial learning rate
lr_scheduler_gamma: Factor by which the learning rate is reduced at each milestone
"""
super().__init__()
self.backbone = backbone
@ -209,6 +212,7 @@ class TransferLearningModel(LightningModule):
"""Forward pass.
Returns logits.
"""
# 1. Feature extraction:
x = self.feature_extractor(x)

View File

@ -16,6 +16,7 @@
After a few epochs, launch TensorBoard to see the images being generated at every batch:
tensorboard --logdir default
"""
from argparse import ArgumentParser, Namespace

View File

@ -28,6 +28,7 @@ or show all options you can change:
python imagenet.py --help
python imagenet.py fit --help
"""
import os
from typing import Optional

View File

@ -29,6 +29,7 @@ References
[1] https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-
Second-Edition/blob/master/Chapter06/02_dqn_pong.py
"""
import argparse
@ -54,6 +55,7 @@ class DQN(nn.Module):
DQN(
(net): Sequential(...)
)
"""
def __init__(self, obs_size: int, n_actions: int, hidden_size: int = 128):
@ -79,6 +81,7 @@ class ReplayBuffer:
>>> ReplayBuffer(5) # doctest: +ELLIPSIS
<...reinforce_learn_Qnet.ReplayBuffer object at ...>
"""
def __init__(self, capacity: int) -> None:
@ -96,6 +99,7 @@ class ReplayBuffer:
Args:
experience: tuple (state, action, reward, done, new_state)
"""
self.buffer.append(experience)
@ -117,6 +121,7 @@ class RLDataset(IterableDataset):
>>> RLDataset(ReplayBuffer(5)) # doctest: +ELLIPSIS
<...reinforce_learn_Qnet.RLDataset object at ...>
"""
def __init__(self, buffer: ReplayBuffer, sample_size: int = 200) -> None:
@ -141,6 +146,7 @@ class Agent:
>>> buffer = ReplayBuffer(10)
>>> Agent(env, buffer) # doctest: +ELLIPSIS
<...reinforce_learn_Qnet.Agent object at ...>
"""
def __init__(self, env: gym.Env, replay_buffer: ReplayBuffer) -> None:
@ -168,6 +174,7 @@ class Agent:
Returns:
action
"""
if np.random.random() < epsilon:
action = self.env.action_space.sample()
@ -194,6 +201,7 @@ class Agent:
Returns:
reward, done
"""
action = self.get_action(net, epsilon, device)
@ -222,6 +230,7 @@ class DQNLightning(LightningModule):
(net): Sequential(...)
)
)
"""
def __init__(
@ -270,6 +279,7 @@ class DQNLightning(LightningModule):
Args:
steps: number of random steps to populate the buffer with
"""
for i in range(steps):
self.agent.play_step(self.net, epsilon=1.0)
@ -282,6 +292,7 @@ class DQNLightning(LightningModule):
Returns:
q values
"""
return self.net(x)
@ -293,6 +304,7 @@ class DQNLightning(LightningModule):
Returns:
loss
"""
states, actions, rewards, dones, next_states = batch
@ -308,8 +320,8 @@ class DQNLightning(LightningModule):
return nn.MSELoss()(state_action_values, expected_state_action_values)
def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], nb_batch) -> OrderedDict:
"""Carries out a single step through the environment to update the replay buffer. Then calculates loss
based on the minibatch received.
"""Carries out a single step through the environment to update the replay buffer. Then calculates loss based on
the minibatch received.
Args:
batch: current mini batch of replay data
@ -317,6 +329,7 @@ class DQNLightning(LightningModule):
Returns:
Training loss and log metrics
"""
device = self.get_device(batch)
epsilon = max(self.eps_end, self.eps_start - (self.global_step + 1) / self.eps_last_frame)

View File

@ -26,6 +26,7 @@ References
[1] https://github.com/openai/baselines/blob/master/baselines/ppo2/ppo2.py
[2] https://github.com/openai/spinningup
[3] https://github.com/sid-sundrani/ppo_lightning
"""
import argparse
from typing import Callable, Iterator, List, Tuple
@ -52,8 +53,7 @@ def create_mlp(input_shape: Tuple[int], n_actions: int, hidden_size: int = 128):
class ActorCategorical(nn.Module):
"""Policy network, for discrete action spaces, which returns a distribution and an action given an
observation."""
"""Policy network, for discrete action spaces, which returns a distribution and an action given an observation."""
def __init__(self, actor_net):
"""
@ -81,6 +81,7 @@ class ActorCategorical(nn.Module):
Returns:
log probability of the action under pi
"""
return pi.log_prob(actions)
@ -117,6 +118,7 @@ class ActorContinuous(nn.Module):
Returns:
log probability of the action under pi
"""
return pi.log_prob(actions).sum(axis=-1)
@ -127,6 +129,7 @@ class ExperienceSourceDataset(IterableDataset):
Basic experience source dataset. Takes a generate_batch function that returns an iterator. The logic for the
experience source and how the batch is generated is defined the Lightning model itself
"""
def __init__(self, generate_batch: Callable):
@ -144,6 +147,7 @@ class PPOLightning(LightningModule):
Train:
trainer = Trainer()
trainer.fit(model)
"""
def __init__(
@ -231,6 +235,7 @@ class PPOLightning(LightningModule):
Returns:
Tuple of policy and action
"""
pi, action = self.actor(x)
value = self.critic(x)
@ -245,6 +250,7 @@ class PPOLightning(LightningModule):
Returns:
list of discounted rewards/advantages
"""
assert isinstance(rewards[0], float)
@ -267,6 +273,7 @@ class PPOLightning(LightningModule):
Returns:
list of advantages
"""
rews = rewards + [last_value]
vals = values + [last_value]
@ -373,6 +380,7 @@ class PPOLightning(LightningModule):
Args:
batch: batch of replay buffer/trajectory data
"""
state, action, old_logp, qval, adv = batch
@ -405,8 +413,7 @@ class PPOLightning(LightningModule):
return optimizer_actor, optimizer_critic
def optimizer_step(self, *args, **kwargs):
"""Run 'nb_optim_iters' number of iterations of gradient descent on actor and critic for each data
sample."""
"""Run 'nb_optim_iters' number of iterations of gradient descent on actor and critic for each data sample."""
for _ in range(self.nb_optim_iters):
super().optimizer_step(*args, **kwargs)

View File

@ -31,8 +31,7 @@ DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27
def _create_synth_kitti_dataset(path_dir: str, image_dims: tuple = (1024, 512)):
"""Create synthetic dataset with random images, just to simulate that the dataset have been already
downloaded."""
"""Create synthetic dataset with random images, just to simulate that the dataset have been already downloaded."""
path_dir_images = os.path.join(path_dir, KITTI.IMAGE_PATH)
path_dir_masks = os.path.join(path_dir, KITTI.MASK_PATH)
for p_dir in (path_dir_images, path_dir_masks):
@ -65,6 +64,7 @@ class KITTI(Dataset):
In the `get_item` function, images and masks are resized to the given `img_size`, masks are
encoded using `encode_segmap`, and given `transform` (if any) are applied to the image only
(mask does not usually require transforms, but they can be implemented in a similar way).
"""
IMAGE_PATH = os.path.join("training", "image_2")
@ -154,6 +154,7 @@ class UNet(nn.Module):
(5): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
)
)
"""
def __init__(self, num_classes: int = 19, num_layers: int = 5, features_start: int = 64, bilinear: bool = False):
@ -200,6 +201,7 @@ class DoubleConv(nn.Module):
DoubleConv(
(net): Sequential(...)
)
"""
def __init__(self, in_ch: int, out_ch: int):
@ -229,6 +231,7 @@ class Down(nn.Module):
)
)
)
"""
def __init__(self, in_ch: int, out_ch: int):
@ -240,8 +243,8 @@ class Down(nn.Module):
class Up(nn.Module):
"""Upsampling (by either bilinear interpolation or transpose convolutions) followed by concatenation of feature
map from contracting path, followed by double 3x3 convolution.
"""Upsampling (by either bilinear interpolation or transpose convolutions) followed by concatenation of feature map
from contracting path, followed by double 3x3 convolution.
>>> Up(8, 4) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Up(
@ -250,6 +253,7 @@ class Up(nn.Module):
(net): Sequential(...)
)
)
"""
def __init__(self, in_ch: int, out_ch: int, bilinear: bool = False):
@ -306,6 +310,7 @@ class SegModel(LightningModule):
)
)
)
"""
def __init__(

View File

@ -50,6 +50,13 @@ skip = ["_notebooks"]
line-length = 120
exclude = '(_notebooks/.*)'
[tool.docformatter]
recursive = true
# this need to be shorter as some docstings are r"""...
wrap-summaries = 119
wrap-descriptions = 120
blank = true
[tool.ruff]
line-length = 120

View File

@ -14,6 +14,7 @@
"""Diagnose your system and show basic information.
This server mainly to get detail info for better bug reporting.
"""
import os

View File

@ -38,6 +38,7 @@ There are considered three main scenarios for installing this project:
compared against PyPI registry
b) with a parameterization build desired packages in to standard `dist/` folder
c) validate packages and publish to PyPI
"""
import contextlib
import glob

View File

@ -58,6 +58,7 @@ class _FastApiMockRequest:
def configure_api(self):
return [Post("/api/v1/request", self.request)]
"""
_body: Optional[str] = None
@ -116,6 +117,7 @@ class _HttpMethod:
route: The path used to route the requests
method: The associated flow method
timeout: The time in seconds taken before raising a timeout exception.
"""
self.route = route
self.attached_to_flow = hasattr(method, "__self__")

View File

@ -566,6 +566,7 @@ def _install_app_from_source(
If true, overwrite the app directory without asking if it already exists
git_sha:
The git_sha for checking out the git repo of the app.
"""
if not cwd:

View File

@ -46,6 +46,7 @@ def logs(app_name: str, components: List[str], follow: bool) -> None:
Print logs only from selected works:
$ lightning show logs my-application root.work_a root.work_b
"""
_show_logs(app_name, components, follow)

View File

@ -2,6 +2,7 @@ r"""To test a lightning component:
1. Init the component.
2. call .run()
"""
from placeholdername.component import TemplateComponent

View File

@ -57,6 +57,7 @@ def connect_app(app_name_or_id: str):
\b
# once done, disconnect and go back to the standard lightning CLI commands
lightning disconnect
"""
from lightning.app.utilities.commands.base import _download_command

View File

@ -98,6 +98,7 @@ def delete_app(app_name: str, skip_user_confirm_prompt: bool) -> None:
Deleting an app also deletes all app websites, works, artifacts, and logs. This permanently removes any record of
the app as well as all any of its associated resources and data. This does not affect any resources and data
associated with other Lightning apps on your account.
"""
console = Console()

View File

@ -40,10 +40,11 @@ def launch() -> None:
@click.option("--host", help="Application running host", default=APP_SERVER_HOST, type=str)
@click.option("--port", help="Application running port", default=APP_SERVER_PORT, type=int)
def run_server(file: str, queue_id: str, host: str, port: int) -> None:
"""It takes the application file as input, build the application object and then use that to run the
application server.
"""It takes the application file as input, build the application object and then use that to run the application
server.
This is used by the cloud runners to start the status server for the application
"""
logger.debug(f"Run Server: {file} {queue_id} {host} {port}")
start_application_server(file, host, port, queue_id=queue_id)
@ -54,10 +55,11 @@ def run_server(file: str, queue_id: str, host: str, port: int) -> None:
@click.option("--queue-id", help="ID for identifying queue", default="", type=str)
@click.option("--base-url", help="Base url at which the app server is hosted", default="")
def run_flow(file: str, queue_id: str, base_url: str) -> None:
"""It takes the application file as input, build the application object, proxy all the work components and then
run the application flow defined in the root component.
"""It takes the application file as input, build the application object, proxy all the work components and then run
the application flow defined in the root component.
It does exactly what a singleprocess dispatcher would do but with proxied work components.
"""
logger.debug(f"Run Flow: {file} {queue_id} {base_url}")
run_lightning_flow(file, queue_id=queue_id, base_url=base_url)
@ -68,8 +70,8 @@ def run_flow(file: str, queue_id: str, base_url: str) -> None:
@click.option("--work-name", type=str)
@click.option("--queue-id", help="ID for identifying queue", default="", type=str)
def run_work(file: str, work_name: str, queue_id: str) -> None:
"""Unlike other entrypoints, this command will take the file path or module details for a work component and
run that by fetching the states from the queues."""
"""Unlike other entrypoints, this command will take the file path or module details for a work component and run
that by fetching the states from the queues."""
logger.debug(f"Run Work: {file} {work_name} {queue_id}")
run_lightning_work(
file=file,
@ -109,10 +111,11 @@ def run_flow_and_servers(
port: int,
flow_port: Tuple[Tuple[str, int]],
) -> None:
"""It takes the application file as input, build the application object and then use that to run the
application flow defined in the root component, the application server and all the flow frontends.
"""It takes the application file as input, build the application object and then use that to run the application
flow defined in the root component, the application server and all the flow frontends.
This is used by the cloud runners to start the flow, the status server and all frontends for the application
"""
logger.debug(f"Run Flow: {file} {queue_id} {base_url}")
logger.debug(f"Run Server: {file} {queue_id} {host} {port}.")

View File

@ -13,6 +13,7 @@ class TensorBoard(LightningFlow):
Args:
log_dir: The path to the directory where the TensorBoard log-files will appear.
sync_every_n_seconds: How often to sync the log directory (given as an argument to the run method)
"""
super().__init__()
self.worker = TensorBoardWorker(log_dir=log_dir, sync_every_n_seconds=sync_every_n_seconds)

View File

@ -29,6 +29,7 @@ def _configure_session() -> Session:
"""Configures the session for GET and POST requests.
It enables a generous retrial strategy that waits for the application server to connect.
"""
retry_strategy = Retry(
# wait time between retries increases exponentially according to: backoff_factor * (2 ** (retry - 1))

View File

@ -158,6 +158,7 @@ class Database(LightningWork):
# RIGHT THERE ! You need to use Field and Column with the `pydantic_column_type` utility.
kv: List[KeyValuePair] = Field(..., sa_column=Column(pydantic_column_type(List[KeyValuePair])))
"""
super().__init__(parallel=True, cloud_build_config=BuildConfig(["sqlmodel"]))
self.db_filename = db_filename

View File

@ -52,6 +52,7 @@ def _pydantic_column_type(pydantic_type: Any) -> Any:
class TrialConfig(SQLModel, table=False):
...
params: Dict[str, Union[Dict[str, float]] = Field(sa_column=Column(pydantic_column_type[Dict[str, float]))
"""
class PydanticJSONType(TypeDecorator, Generic[T]):

View File

@ -67,6 +67,7 @@ class MultiNode(LightningFlow):
running locally.
work_args: Arguments to be provided to the work on instantiation.
work_kwargs: Keywords arguments to be provided to the work on instantiation.
"""
super().__init__()
if num_nodes > 1 and not is_running_in_cloud():

View File

@ -70,6 +70,7 @@ class PopenPythonScript(LightningWork):
.. literalinclude:: ../../../../examples/app/components/python/component_popen.py
:language: python
"""
super().__init__(**kwargs)
if not os.path.exists(script_path):

View File

@ -117,6 +117,7 @@ class TracerPythonScript(LightningWork):
.. literalinclude:: ../../../../examples/app/components/python/app.py
:language: python
"""
super().__init__(**kwargs)
self.script_path = str(script_path)

View File

@ -114,8 +114,8 @@ def _create_fastapi(title: str) -> _TrackableFastAPI:
class _LoadBalancer(LightningWork):
r"""The LoadBalancer is a LightningWork component that collects the requests and sends them to the prediciton
API asynchronously using RoundRobin scheduling. It also performs auto batching of the incoming requests.
r"""The LoadBalancer is a LightningWork component that collects the requests and sends them to the prediciton API
asynchronously using RoundRobin scheduling. It also performs auto batching of the incoming requests.
After enabling you will require to send username and password from the request header for the private endpoints.
@ -131,6 +131,7 @@ class _LoadBalancer(LightningWork):
api_name: The name to be displayed on the UI. Normally, it is the name of the work class
cold_start_proxy: The proxy service to use while the work is cold starting.
**kwargs: Arguments passed to :func:`LightningWork.init` like ``CloudCompute``, ``BuildConfig``, etc.
"""
@requires(["aiohttp"])
@ -236,6 +237,7 @@ class _LoadBalancer(LightningWork):
Two instances of this function should not be running with shared `_state_server` as that would create race
conditions
"""
while True:
await asyncio.sleep(0.05)
@ -293,6 +295,7 @@ class _LoadBalancer(LightningWork):
"""This function checks if we have processing capacity for one more request or not.
Depends on the value from here, we decide whether we should proxy the request or not
"""
if not self._fastapi_app:
return False
@ -385,6 +388,7 @@ class _LoadBalancer(LightningWork):
"""Updates works that load balancer distributes requests to.
AutoScaler uses this method to increase/decrease the number of works.
"""
old_server_urls = set(self.servers)
current_server_urls = {
@ -623,6 +627,7 @@ class AutoScaler(LightningFlow):
Returns:
The name of the new work attribute.
"""
work_attribute = uuid.uuid4().hex
work_attribute = f"worker_{self.num_replicas}_{str(work_attribute)}"
@ -665,6 +670,7 @@ class AutoScaler(LightningFlow):
Returns:
The target number of running works. The value will be adjusted after this method runs
so that it satisfies ``min_replicas<=replicas<=max_replicas``.
"""
pending_requests = metrics["pending_requests"]
active_or_pending_works = replicas + metrics["pending_works"]

View File

@ -35,6 +35,7 @@ class ColdStartProxy:
Args:
proxy_url (str): The url of the proxy service
"""
@requires(["aiohttp"])
@ -46,12 +47,13 @@ class ColdStartProxy:
async def handle_request(self, request: BaseModel) -> Any:
"""This method is called when the request is received while the work is cold starting. The default
implementation of this method is to forward the request body to the proxy service with POST method but the
user can override this method to handle the request in any way.
implementation of this method is to forward the request body to the proxy service with POST method but the user
can override this method to handle the request in any way.
Args:
request (BaseModel): The request body, a pydantic model that is being
forwarded by load balancer which is a FastAPI service
"""
try:
async with aiohttp.ClientSession() as session:

View File

@ -77,6 +77,7 @@ class ServeGradio(LightningWork, abc.ABC):
"""Override to instantiate and return your model.
The model would be accessible under self.model
"""
def run(self, *args: Any, **kwargs: Any):

View File

@ -211,6 +211,7 @@ class PythonServer(LightningWork, abc.ABC):
... return {"prediction": self._model(request.image)}
...
>>> app = LightningApp(SimpleServer())
"""
super().__init__(parallel=True, **kwargs)
if not issubclass(input_type, BaseModel):
@ -228,6 +229,7 @@ class PythonServer(LightningWork, abc.ABC):
Note that this will be called exactly once on every work machines. So if you have multiple machines for serving,
this will be called on each of them.
"""
return
@ -243,6 +245,7 @@ class PythonServer(LightningWork, abc.ABC):
This method must be overriden by the user with the prediction logic. The pre/post processing, actual prediction
using the model(s) etc goes here
"""
pass
@ -325,6 +328,7 @@ class PythonServer(LightningWork, abc.ABC):
"""Run method takes care of configuring and setting up a FastAPI server behind the scenes.
Normally, you don't need to override this method.
"""
self.setup(*args, **kwargs)

View File

@ -67,6 +67,7 @@ class ModelInferenceAPI(LightningWork, abc.ABC):
host: Address to be used to serve the model.
port: Port to be used to serve the model.
workers: Number of workers for the uvicorn. Warning, this won't work if your subclass takes more arguments.
"""
super().__init__(parallel=True, host=host, port=port)
if input and input not in _DESERIALIZER:
@ -151,8 +152,8 @@ class ModelInferenceAPI(LightningWork, abc.ABC):
def _maybe_create_instance() -> Optional[ModelInferenceAPI]:
"""This function tries to re-create the user `ModelInferenceAPI` if the environment associated with multi
workers are present."""
"""This function tries to re-create the user `ModelInferenceAPI` if the environment associated with multi workers
are present."""
render_fn_name = os.getenv("LIGHTNING_MODEL_INFERENCE_API_CLASS_NAME", None)
render_fn_module_file = os.getenv("LIGHTNING_MODEL_INFERENCE_API_FILE", None)
if render_fn_name is None or render_fn_module_file is None:

View File

@ -50,6 +50,7 @@ class ServeStreamlit(LightningWork, abc.ABC):
"""Optionally override to instantiate and return your model.
The model will be accessible under ``self.model``.
"""
return None

View File

@ -28,4 +28,5 @@ class BaseType(abc.ABCMeta):
"""Take the inputs from the network and deserilize/convert them them.
Output from this method will go to the exposed method as arguments.
"""

View File

@ -159,6 +159,7 @@ class LightningTrainerScript(LightningFlow):
cloud_compute: The cloud compute object used in the cloud.
sanity_serving: Whether to validate that the model correctly implements
the ServableModule API
"""
super().__init__()
self.script_path = script_path

View File

@ -270,8 +270,8 @@ async def post_delta(
x_lightning_session_uuid: Optional[str] = Header(None), # type: ignore[assignment]
x_lightning_session_id: Optional[str] = Header(None), # type: ignore[assignment]
) -> Optional[Dict]:
"""This endpoint is used to make an update to the app state using delta diff, mainly used by streamlit to
update the state."""
"""This endpoint is used to make an update to the app state using delta diff, mainly used by streamlit to update
the state."""
if x_lightning_session_uuid is None:
raise Exception("Missing X-Lightning-Session-UUID header")

View File

@ -383,8 +383,7 @@ class LightningApp:
return deltas
def maybe_apply_changes(self) -> Optional[bool]:
"""Get the deltas from both the flow queue and the work queue, merge the two deltas and update the
state."""
"""Get the deltas from both the flow queue and the work queue, merge the two deltas and update the state."""
self._send_flow_to_work_deltas(self.state)
if not self.collect_changes:
@ -503,6 +502,7 @@ class LightningApp:
"""Entry point of the LightningApp.
This would be dispatched by the Runtime objects.
"""
self._original_state = deepcopy(self.state)
done = False

View File

@ -396,6 +396,7 @@ class LightningFlow:
.. deprecated:: 1.9.0
This function is deprecated and will be removed in 2.0.0. Use :meth:`stop` instead.
"""
warnings.warn(
DeprecationWarning(
@ -411,6 +412,7 @@ class LightningFlow:
(prefixed by '__') attributes are not.
Exceptions are listed in the `_INTERNAL_STATE_VARS` class variable.
"""
return name in LightningFlow._INTERNAL_STATE_VARS or not name.startswith("_")
@ -487,6 +489,7 @@ class LightningFlow:
</div>
</div>
<br />
"""
if not user_key:
frame = cast(FrameType, inspect.currentframe()).f_back
@ -626,6 +629,7 @@ class LightningFlow:
</div>
</div>
<br />
"""
return [{"name": name, "content": component} for (name, component) in self.flows.items()]
@ -639,6 +643,7 @@ class LightningFlow:
run_once: Whether to run the entire iteration only once.
Otherwise, it would restart from the beginning.
user_key: Key to be used to track the caching mechanism.
"""
if not isinstance(iterable, Iterable):
raise TypeError(f"An iterable should be provided to `self.iterate` method. Found {iterable}")
@ -708,6 +713,7 @@ class LightningFlow:
.. code-block:: bash
lightning my_command_name --args name=my_own_name
"""
raise NotImplementedError
@ -741,6 +747,7 @@ class LightningFlow:
Once the app is running, you can access the Swagger UI of the app
under the ``/docs`` route.
"""
raise NotImplementedError
@ -805,6 +812,7 @@ class LightningFlow:
children_states: The state of the dynamic children of this flow.
strict: Whether to raise an exception if a dynamic
children hasn't been re-created.
"""
self.set_state(flow_state, recurse=False)
direct_children_states = {k: v for k, v in children_states.items() if "." not in k}

View File

@ -182,6 +182,7 @@ class BaseQueue(ABC):
timeout:
Read timeout in seconds, in case of input timeout is 0, the `self.default_timeout` is used.
A timeout of None can be used to block indefinitely.
"""
pass
@ -190,6 +191,7 @@ class BaseQueue(ABC):
"""Returns True if the queue is running, False otherwise.
Child classes should override this property and implement custom logic as required
"""
return True
@ -286,6 +288,7 @@ class RedisQueue(BaseQueue):
timeout:
Read timeout in seconds, in case of input timeout is 0, the `self.default_timeout` is used.
A timeout of None can be used to block indefinitely.
"""
if timeout is None:
# this means it's blocking in redis
@ -464,6 +467,7 @@ class HTTPQueue(BaseQueue):
This can be brittle, as if the queue name creation logic changes, the response values from here wouldn't be
accurate. Remove this eventually and let the Queue class take app id and name of the queue as arguments
"""
if "_" not in queue_name:
return "", queue_name

View File

@ -124,6 +124,7 @@ class LightningWork:
</div>
</div>
<br />
"""
from lightning.app.runners.backends.backend import Backend
@ -212,6 +213,7 @@ class LightningWork:
By default, this attribute returns the empty string and the ip address will only be returned once the work runs.
Locally, the address is 127.0.0.1 and in the cloud it will be determined by the cluster.
"""
return self._internal_ip
@ -221,6 +223,7 @@ class LightningWork:
By default, this attribute returns the empty string and the ip address will only be returned once the work runs.
Locally, this address is undefined (empty string) and in the cloud it will be determined by the cluster.
"""
return self._public_ip
@ -234,6 +237,7 @@ class LightningWork:
(prefixed by '__') attributes are not.
Exceptions are listed in the `_INTERNAL_STATE_VARS` class variable.
"""
return name in LightningWork._INTERNAL_STATE_VARS or not name.startswith("_")
@ -247,6 +251,7 @@ class LightningWork:
"""Returns the display name of the LightningWork in the cloud.
The display name needs to set before the run method of the work is called.
"""
return self._display_name
@ -269,6 +274,7 @@ class LightningWork:
"""Whether to run in parallel mode or not.
When parallel is False, the flow waits for the work to finish.
"""
return self._parallel
@ -325,6 +331,7 @@ class LightningWork:
"""Return the current status of the work.
All statuses are stored in the state.
"""
call_hash = self._calls[CacheCallsKeys.LATEST_CALL_HASH]
if call_hash in self._calls:
@ -628,6 +635,7 @@ class LightningWork:
Raises:
LightningPlatformException: If resource exceeds platform quotas or other constraints.
"""
def on_exception(self, exception: BaseException) -> None:
@ -636,8 +644,7 @@ class LightningWork:
raise exception
def _aggregate_status_timeout(self, statuses: List[Dict]) -> WorkStatus:
"""Method used to return the first request and the total count of timeout after the latest succeeded
status."""
"""Method used to return the first request and the total count of timeout after the latest succeeded status."""
succeeded_statuses = [
status_idx for status_idx, status in enumerate(statuses) if status["stage"] == WorkStageStatus.SUCCEEDED
]
@ -653,6 +660,7 @@ class LightningWork:
"""Override this hook to add your logic when the work is exiting.
Note: This hook is not guaranteed to be called when running in the cloud.
"""
pass
@ -660,6 +668,7 @@ class LightningWork:
"""Stops LightingWork component and shuts down hardware provisioned via L.CloudCompute.
This can only be called from a ``LightningFlow``.
"""
if not self._backend:
raise RuntimeError(f"Only the `LightningFlow` can request this work ({self.name!r}) to stop.")
@ -675,6 +684,7 @@ class LightningWork:
"""Delete LightingWork component and shuts down hardware provisioned via L.CloudCompute.
Locally, the work.delete() behaves as work.stop().
"""
if not self._backend:
raise Exception(
@ -755,4 +765,5 @@ class LightningWork:
returned URL can depend on the state. This is not the case if the work returns a
:class:`~lightning.app.frontend.frontend.Frontend`. These need to be provided at the time of app creation
in order for the runtime to start the server.
"""

View File

@ -23,6 +23,7 @@ class Frontend(ABC):
"""Base class for any frontend that gets exposed by LightningFlows.
The flow attribute will be set by the app while bootstrapping.
"""
def __init__(self) -> None:
@ -48,6 +49,7 @@ class Frontend(ABC):
def start_server(self, host, port, root_path=""):
self._process = subprocess.Popen(["flask", "run" "--host", host, "--port", str(port)])
"""
@abstractmethod
@ -62,4 +64,5 @@ class Frontend(ABC):
def stop_server(self):
self._process.kill()
"""

View File

@ -81,6 +81,7 @@ class JustPyFrontend(Frontend):
app = LightningApp(Flow())
"""
def __init__(self, render_fn: Callable) -> None:

View File

@ -1,5 +1,4 @@
"""The PanelFrontend and AppStateWatcher make it easy to create Lightning Apps with the Panel data app
framework."""
"""The PanelFrontend and AppStateWatcher make it easy to create Lightning Apps with the Panel data app framework."""
from lightning.app.frontend.panel.app_state_watcher import AppStateWatcher
from lightning.app.frontend.panel.panel_frontend import PanelFrontend

View File

@ -94,6 +94,7 @@ def _watch_app_state(callback: Callable):
def handle_state_change():
print("The App State changed.")
watch_app_state(handle_state_change)
"""
_CALLBACKS.append(callback)
_start_websocket()

View File

@ -72,6 +72,7 @@ class AppStateWatcher(Parameterized):
Pydantic which additionally provides powerful and unique features for building reactive apps.
Please note the ``AppStateWatcher`` is a singleton, i.e., only one instance is instantiated
"""
state: AppState = ClassSelector(

View File

@ -35,6 +35,7 @@ def _has_panel_autoreload() -> bool:
"""Returns True if the PANEL_AUTORELOAD environment variable is set to 'yes' or 'true'.
Please note the casing of value does not matter
"""
return os.environ.get("PANEL_AUTORELOAD", "no").lower() in ["yes", "y", "true"]
@ -98,6 +99,7 @@ class PanelFrontend(Frontend):
For development you can get Panel autoreload by setting the ``PANEL_AUTORELOAD``
environment variable to 'yes', i.e. run
``PANEL_AUTORELOAD=yes lightning run app app_basic.py``
"""
@requires("panel")

View File

@ -26,6 +26,7 @@ Example:
.. code-block:: bash
python panel_serve_render_fn
"""
import inspect
import os

View File

@ -61,6 +61,7 @@ class StreamlitFrontend(Frontend):
st.write("Hello from streamlit!")
st.write(state.counter)
"""
@requires("streamlit")

View File

@ -14,6 +14,7 @@
"""This file gets run by streamlit, which we launch within Lightning.
From here, we will call the render function that the user provided in ``configure_layout``.
"""
import os
import pydoc

View File

@ -37,6 +37,7 @@ def _get_flow_state(flow: str) -> AppState:
Returns:
AppState: An AppState scoped to the current Flow.
"""
app_state = AppState()
app_state._request_state() # pylint: disable=protected-access
@ -54,6 +55,7 @@ def _get_frontend_environment(flow: str, render_fn_or_file: Callable | str, port
Returns:
os._Environ: An environment
"""
env = os.environ.copy()
env["LIGHTNING_FLOW_NAME"] = flow

View File

@ -44,6 +44,7 @@ class StaticWebFrontend(Frontend):
def configure_layout(self):
return StaticWebFrontend("path/to/folder/to/serve")
"""
def __init__(self, serve_dir: str) -> None:
@ -102,8 +103,7 @@ def _start_server(
def _get_log_config(log_file: str) -> dict:
"""Returns a logger configuration in the format expected by uvicorn that sends all logs to the given
logfile."""
"""Returns a logger configuration in the format expected by uvicorn that sends all logs to the given logfile."""
# Modified from the default config found in uvicorn.config.LOGGING_CONFIG
return {
"version": 1,

View File

@ -109,6 +109,7 @@ def run_lightning_work(
It is organized under cloud runtime to indicate that it will be used by the cloud runner but otherwise, no cloud
specific logic is being implemented here
"""
logger.debug(f"Run Lightning Work {file} {work_name} {queue_id}")
@ -231,6 +232,7 @@ def serve_frontend(file: str, flow_name: str, host: str, port: int):
It is organized under cloud runtime to indicate that it will be used by the cloud runner but otherwise, no cloud
specific logic is being implemented here.
"""
_set_frontend_context()
logger.debug(f"Run Serve Frontend {file} {flow_name} {host} {port}")
@ -340,15 +342,16 @@ def manage_server_processes(processes: List[Tuple[str, Process]]) -> None:
def _get_frontends_from_app(entrypoint_file):
"""This function is used to get the frontends from the app. It will be used to start the frontends in a
separate process if the backend cannot provide flow_names_and_ports. This is useful if the app cannot be loaded
locally to set the frontend before dispatching to the cloud. The backend exposes by default 10 ports from 8081
if the app.spec.frontends is not set.
"""This function is used to get the frontends from the app. It will be used to start the frontends in a separate
process if the backend cannot provide flow_names_and_ports. This is useful if the app cannot be loaded locally to
set the frontend before dispatching to the cloud. The backend exposes by default 10 ports from 8081 if the
app.spec.frontends is not set.
NOTE: frontend_name are sorted to ensure that they get consistent ports.
:param entrypoint_file: The entrypoint file for the app
:return: A list of tuples of the form (frontend_name, port_number)
"""
app = load_app_from_file(entrypoint_file)

View File

@ -252,6 +252,7 @@ class CloudBackend(Backend):
Normally, the Lightning frameworks communicates statuses through the queues, but while the Work instance is
being provisionied, the queues don't exist yet and hence we need to make API calls directly to the backend to
fetch the status and update it in the states.
"""
if not works:
return
@ -305,6 +306,7 @@ class CloudBackend(Backend):
"""Stop resources for all LightningWorks in this app.
The Works are stopped rather than deleted so that they can be inspected for debugging.
"""
cloud_works = self._get_cloud_work_specs(self.client)

View File

@ -60,6 +60,7 @@ class LightningPlugin:
Returns:
The relative URL of the created job.
"""
from lightning.app.runners.backends.cloud import CloudBackend
from lightning.app.runners.cloud import CloudRuntime

View File

@ -198,8 +198,8 @@ class CloudRuntime(Runtime):
cluster_id: str,
source_app: Optional[str] = None,
) -> str:
"""Slim dispatch for creating runs from a cloudspace. This dispatch avoids resolution of some properties
such as the project and cluster IDs that are instead passed directly.
"""Slim dispatch for creating runs from a cloudspace. This dispatch avoids resolution of some properties such
as the project and cluster IDs that are instead passed directly.
Args:
project_id: The ID of the project.
@ -214,6 +214,7 @@ class CloudRuntime(Runtime):
Returns:
The URL of the created job.
"""
# Dispatch in four phases: resolution, validation, spec creation, API transactions
# Resolution
@ -432,6 +433,7 @@ class CloudRuntime(Runtime):
"""Find and load the config file if it exists (otherwise create an empty config).
Override the name if provided.
"""
config_file = _get_config_file(self.entrypoint)
cloudspace_config = AppConfig.load_from_file(config_file) if config_file.exists() and load else AppConfig()
@ -611,6 +613,7 @@ class CloudRuntime(Runtime):
"""Check if the user likely needs credits to run the app with its hardware.
Returns False if user has 1 or more credits.
"""
balance = project.balance
if balance is None:
@ -698,8 +701,8 @@ class CloudRuntime(Runtime):
raise RuntimeError(f"Unknown mount protocol `{mount.protocol}` for work `{work.name}`.")
def _get_flow_servers(self) -> List[V1Flowserver]:
"""Collect a spec for each flow that contains a frontend so that the backend knows for which flows it needs
to start servers."""
"""Collect a spec for each flow that contains a frontend so that the backend knows for which flows it needs to
start servers."""
flow_servers: List[V1Flowserver] = []
for flow_name in self.app.frontends:
flow_server = V1Flowserver(name=flow_name)
@ -889,8 +892,7 @@ class CloudRuntime(Runtime):
def _get_env_vars(
env_vars: Dict[str, str], secrets: Dict[str, str], run_app_comment_commands: bool
) -> List[V1EnvVar]:
"""Generate the list of environment variable specs for the app, including variables set by the
framework."""
"""Generate the list of environment variable specs for the app, including variables set by the framework."""
v1_env_vars = [V1EnvVar(name=k, value=v) for k, v in env_vars.items()]
if len(secrets.values()) > 0:
@ -929,6 +931,7 @@ class CloudRuntime(Runtime):
"""Create the cloudspace if it doesn't exist.
Return the cloudspace ID.
"""
if existing_cloudspace is None:
cloudspace_body = ProjectIdCloudspacesBody(name=name, can_download_source_code=True)
@ -980,6 +983,7 @@ class CloudRuntime(Runtime):
"""Transfer an existing instance to the given run ID and update its specification.
Return the instance.
"""
run_instance = self.backend.client.lightningapp_instance_service_update_lightningapp_instance_release(
project_id=project_id,

View File

@ -39,6 +39,7 @@ class MultiProcessRuntime(Runtime):
The MultiProcessRuntime will generate 1 process for each :class:`~lightning.app.core.work.LightningWork` and attach
queues to enable communication between the different processes.
"""
backend: Union[str, Backend] = "multiprocessing"

View File

@ -66,6 +66,7 @@ def dispatch(
run_app_comment_commands: whether to parse commands from the entrypoint file and execute them before app startup
enable_basic_auth: whether to enable basic authentication for the app
(use credentials in the format username:password as an argument)
"""
from lightning.app.runners.runtime_type import RuntimeType
from lightning.app.utilities.component import _set_flow_context

View File

@ -34,9 +34,9 @@ def _copytree(
dirs_exist_ok=False,
dry_run=False,
) -> List[str]:
"""Vendor in from `shutil.copytree` to support ignoring files recursively based on `.lightningignore`, like
`git` does with `.gitignore`. Also removed a few checks from the original copytree related to symlink checks.
Differences between original and this function are.
"""Vendor in from `shutil.copytree` to support ignoring files recursively based on `.lightningignore`, like `git`
does with `.gitignore`. Also removed a few checks from the original copytree related to symlink checks. Differences
between original and this function are.
1. It supports a list of ignore function instead of a single one in the
original. We can use this for filtering out files based on nested
@ -66,6 +66,7 @@ def _copytree(
If exception(s) occur, an Error is raised with a list of reasons.
"""
files_copied = []
@ -146,8 +147,8 @@ def _parse_lightningignore(lines: Tuple[str]) -> Set[str]:
def _read_lightningignore(path: Path) -> Set[str]:
"""Reads ignore file and filter and empty lines. This will also remove patterns that start with a `/`. That's
done to allow `glob` to simulate the behavior done by `git` where it interprets that as a root path.
"""Reads ignore file and filter and empty lines. This will also remove patterns that start with a `/`. That's done
to allow `glob` to simulate the behavior done by `git` where it interprets that as a root path.
Parameters
----------
@ -158,6 +159,7 @@ def _read_lightningignore(path: Path) -> Set[str]:
-------
Set[str]
Set of unique lines.
"""
raw_lines = path.open().readlines()
return _parse_lightningignore(raw_lines)

View File

@ -33,6 +33,7 @@ def _get_hash(files: List[str], algorithm: str = "blake2", chunk_num_blocks: int
----------
[1] https://crypto.stackexchange.com/questions/70101/blake2-vs-md5-for-checksum-file-integrity
[2] https://stackoverflow.com/questions/1131220/get-md5-hash-of-big-files-in-python
"""
# validate input
if algorithm == "blake2":

View File

@ -133,6 +133,7 @@ class LocalSourceCodeDir:
packaged repository files which have a size > 2GB.
This limitation should be removed during the datastore upload redesign
"""
if self.package_path.stat().st_size > 2e9:
raise OSError(

View File

@ -36,6 +36,7 @@ def _get_dir_size_and_count(source_dir: str, prefix: Optional[str] = None) -> Tu
-------
Tuple[int, int]
Size in megabytes and file count
"""
size = 0
count = 0
@ -61,6 +62,7 @@ class _TarResults:
The total size of the original directory files in bytes
after_size: int
The total size of the compressed and tarred split files in bytes
"""
before_size: int
@ -70,13 +72,12 @@ class _TarResults:
def _get_split_size(
total_size: int, minimum_split_size: int = 1024 * 1000 * 20, max_split_count: int = MAX_SPLIT_COUNT
) -> int:
"""Calculate the split size we should use to split the multipart upload of an object to a bucket. We are
limited to 1000 max parts as the way we are using ListMultipartUploads. More info
https://github.com/gridai/grid/pull/5267
"""Calculate the split size we should use to split the multipart upload of an object to a bucket. We are limited
to 1000 max parts as the way we are using ListMultipartUploads. More info https://github.com/gridai/grid/pull/5267
https://docs.aws.amazon.com/AmazonS3/latest/userguide/mpuoverview.html#mpu-process
https://docs.aws.amazon.com/AmazonS3/latest/API/API_ListMultipartUploads.html
https://github.com/psf/requests/issues/2717#issuecomment-724725392 Python or requests has a limit of 2**31
bytes for a single file upload.
https://github.com/psf/requests/issues/2717#issuecomment-724725392 Python or requests has a limit of 2**31 bytes
for a single file upload.
Parameters
----------
@ -91,6 +92,7 @@ def _get_split_size(
-------
int
Split size
"""
max_size = max_split_count * (1 << 31) # max size per part limited by Requests or urllib as shown in ref above
if total_size > max_size:
@ -123,6 +125,7 @@ def _tar_path(source_path: str, target_file: str, compression: bool = False) ->
-------
TarResults
Results that holds file counts and sizes
"""
if os.path.isdir(source_path):
before_size, _ = _get_dir_size_and_count(source_path)
@ -149,6 +152,7 @@ def _tar_path_python(source_path: str, target_file: str, compression: bool = Fal
Target tar file
compression: bool, default False
Enable compression, which is disabled by default.
"""
file_mode = "w:gz" if compression else "w:"
@ -172,6 +176,7 @@ def _tar_path_subprocess(source_path: str, target_file: str, compression: bool =
Target tar file
compression: bool, default False
Enable compression, which is disabled by default.
"""
# Only add compression when users explicitly request it.
# We do this because it takes too long to compress

View File

@ -35,6 +35,7 @@ class FileUploader:
Size of all files to upload
name: str
Name of this upload to display progress
"""
workers: int = 8
@ -72,6 +73,7 @@ class FileUploader:
-------
str
ETag from response
"""
disconnect_retries = retries
while disconnect_retries > 0:

View File

@ -52,6 +52,7 @@ class _Copier(Thread):
will send requests to this queue.
copy_response_queue: A queue connecting the central StorageOrchestrator with the Copier. The Copier
will send a response to this queue whenever a requested copy has finished.
"""
def __init__(
@ -116,6 +117,7 @@ def _copy_files(
interpreted as a folder as well. If the source is a file, the destination path is interpreted as a file too.
Files in a folder are copied recursively and efficiently using multiple threads.
"""
if fs is None:
fs = _filesystem()

View File

@ -46,6 +46,7 @@ class Drive:
component_name: The component name which owns this drive.
When not provided, it is automatically inferred by Lightning.
root_folder: This is the folder from where the Drive perceives the data (e.g this acts as a mount dir).
"""
if id.startswith("s3://"):
raise ValueError(
@ -96,6 +97,7 @@ class Drive:
Arguments:
path: The relative path to your files to be added to the Drive.
"""
if not self.component_name:
raise Exception("The component name needs to be known to put a path to the Drive.")
@ -121,6 +123,7 @@ class Drive:
path: The relative path you want to list files from the Drive.
component_name: By default, the Drive lists files across all components.
If you provide a component name, the listing is specific to this component.
"""
if _is_flow_context():
raise Exception("The flow isn't allowed to list files from a Drive.")
@ -165,6 +168,7 @@ class Drive:
If you provide a component name, the matching is specific to this component.
timeout: Whether to wait for the files to be available if not created yet.
overwrite: Whether to override the provided path if it exists.
"""
if _is_flow_context():
raise Exception("The flow isn't allowed to get files from a Drive.")
@ -207,11 +211,12 @@ class Drive:
self._get(self.fs, match, pathlib.Path(os.path.join(self.root_folder, path)).resolve(), overwrite=overwrite)
def delete(self, path: str) -> None:
"""This method enables to delete files under the provided path from the Drive in a blocking fashion. Only
the component which added a file can delete them.
"""This method enables to delete files under the provided path from the Drive in a blocking fashion. Only the
component which added a file can delete them.
Arguments:
path: The relative path you want to delete files from the Drive.
"""
if not self.component_name:
raise Exception("The component name needs to be known to delete a path to the Drive.")

View File

@ -42,6 +42,7 @@ class FileSystem:
src_path: The path to your files locally
dst_path: The path to your files transfered in the shared storage.
put_fn: The method to use to put files in the shared storage.
"""
if not os.path.exists(Path(src_path).resolve()):
raise FileExistsError(f"The provided path {src_path} doesn't exist")
@ -66,6 +67,7 @@ class FileSystem:
src_path: The path to your files in the shared storage
dst_path: The path to your files transfered locally
get_fn: The method to use to put files in the shared storage.
"""
if not src_path.startswith("/"):
raise Exception(f"The provided destination {src_path} needs to start with `/`.")
@ -80,6 +82,7 @@ class FileSystem:
Arguments:
path: The path to files to list.
"""
if not path.startswith("/"):
raise Exception(f"The provided destination {path} needs to start with `/`.")
@ -104,6 +107,7 @@ class FileSystem:
Arguments:
path: The path to files to list.
"""
if not path.startswith("/"):
raise Exception(f"The provided destination {path} needs to start with `/`.")

View File

@ -34,6 +34,7 @@ class Mount:
mount_path: An absolute directory path in the work where external data source should
be mounted as a filesystem. This path should not already exist in your codebase.
If not included, then the root_dir will be set to `/data/<last folder name in the bucket>`
"""
source: str = ""

View File

@ -47,6 +47,7 @@ class StorageOrchestrator(Thread):
put requests on this queue for the file-transfer thread to complete.
copy_response_queues: A dictionary of Queues where each Queue connects to one Work. The queue is expected to
contain the completion response from the file-transfer thread running in the Work process.
"""
def __init__(

View File

@ -53,6 +53,7 @@ class Path(PathlibPath):
Args:
*args: Accepts the same arguments as in :class:`pathlib.Path`
**kwargs: Accepts the same keyword arguments as in :class:`pathlib.Path`
"""
@classmethod
@ -105,6 +106,7 @@ class Path(PathlibPath):
"""The name of the LightningWork where this path was first created.
Attaching a Path to a LightningWork will automatically make it the `origin`.
"""
from lightning.app.core.work import LightningWork
@ -115,6 +117,7 @@ class Path(PathlibPath):
"""The name of the LightningWork where this path is being accessed.
By default, this is the same as the :attr:`origin_name`.
"""
from lightning.app.core.work import LightningWork
@ -125,6 +128,7 @@ class Path(PathlibPath):
"""The hash of this Path uniquely identifies the file path and the associated origin Work.
Returns ``None`` if the origin is not defined, i.e., this Path did not yet get attached to a LightningWork.
"""
if self._origin is None:
return None
@ -152,6 +156,7 @@ class Path(PathlibPath):
If you strictly want to check local existence only, use :meth:`exists_local` instead. If you strictly want
to check existence on the remote (regardless of whether the file exists locally or not), use
:meth:`exists_remote`.
"""
return self.exists_local() or (self._origin and self.exists_remote())
@ -164,6 +169,7 @@ class Path(PathlibPath):
Raises:
RuntimeError: If the path is not attached to any Work (origin undefined).
"""
# Fail early if we need to check the remote but an origin is not defined
if not self._origin or self._request_queue is None or self._response_queue is None:
@ -272,6 +278,7 @@ class Path(PathlibPath):
Args:
work: LightningWork to be attached to this Path.
"""
if self._origin is None:
# Can become an owner only if there is not already one
@ -374,11 +381,11 @@ def _is_lit_path(path: Union[str, Path]) -> bool:
def _shared_local_mount_path() -> pathlib.Path:
"""Returns the shared directory through which the Copier threads move files from one Work filesystem to
another.
"""Returns the shared directory through which the Copier threads move files from one Work filesystem to another.
The shared directory can be set via the environment variable ``SHARED_MOUNT_DIRECTORY`` and should be pointing to a
directory that all Works have mounted (shared filesystem).
"""
path = pathlib.Path(os.environ.get("SHARED_MOUNT_DIRECTORY", ".shared"))
path.mkdir(parents=True, exist_ok=True)
@ -397,6 +404,7 @@ def _shared_storage_path() -> pathlib.Path:
The shared path gets set by the environment. Locally, it is pointing to a directory determined by the
``SHARED_MOUNT_DIRECTORY`` environment variable. In the cloud, the shared path will point to a S3 bucket. All Works
have access to this shared dropbox.
"""
storage_path = os.getenv("LIGHTNING_STORAGE_PATH", "")
if storage_path != "":

View File

@ -61,6 +61,7 @@ class _BasePayload(ABC):
"""The hash of this Payload uniquely identifies the payload and the associated origin Work.
Returns ``None`` if the origin is not defined, i.e., this Path did not yet get attached to a LightningWork.
"""
if self._origin is None:
return None
@ -72,6 +73,7 @@ class _BasePayload(ABC):
"""The name of the LightningWork where this payload was first created.
Attaching a Payload to a LightningWork will automatically make it the `origin`.
"""
from lightning.app.core.work import LightningWork
@ -82,6 +84,7 @@ class _BasePayload(ABC):
"""The name of the LightningWork where this payload is being accessed.
By default, this is the same as the :attr:`origin_name`.
"""
from lightning.app.core.work import LightningWork
@ -107,6 +110,7 @@ class _BasePayload(ABC):
Args:
work: LightningWork to be attached to this Payload.
"""
if self._origin is None:
# Can become an owner only if there is not already one
@ -130,6 +134,7 @@ class _BasePayload(ABC):
Raises:
RuntimeError: If the payload is not attached to any Work (origin undefined).
"""
# Fail early if we need to check the remote but an origin is not defined
if not self._origin or self._request_queue is None or self._response_queue is None:

View File

@ -62,6 +62,7 @@ class _RunIf:
@pytest.mark.parametrize("arg1", [1, 2.0])
def test_wrapper(arg1):
assert arg1 > 0.0
"""
def __new__(
@ -155,6 +156,7 @@ class EmptyFlow(LightningFlow):
"""A LightningFlow that implements all abstract methods to do nothing.
Useful for mocking in tests.
"""
def run(self):
@ -165,6 +167,7 @@ class EmptyWork(LightningWork):
"""A LightningWork that implements all abstract methods to do nothing.
Useful for mocking in tests.
"""
def run(self):

View File

@ -503,6 +503,7 @@ def delete_cloud_lightning_apps(name=None):
"""Cleanup cloud apps that start with the name test-{PR_NUMBER}-{TEST_APP_NAME}.
PR_NUMBER and TEST_APP_NAME are environment variables.
"""
client = LightningClient()

View File

@ -44,6 +44,7 @@ def _extract_commands_from_file(file_name: str) -> CommandLines:
"""Extract all lines at the top of the file which contain commands to execute.
The return struct contains a list of commands to execute with the corresponding line number the command executed on.
"""
cl = CommandLines(
file=file_name,
@ -83,6 +84,7 @@ def _execute_app_commands(cl: CommandLines) -> None:
"""Open a subprocess shell to execute app commands.
The calling app environment is used in the current environment the code is running in
"""
for command, line_number in zip(cl.commands, cl.line_numbers):
logger.info(f"Running app setup command: {command}")
@ -116,6 +118,7 @@ def run_app_commands(file: str) -> None:
foo! bar <--- not a command import lightning <--- not a command, end parsing.
where `echo "hello world" && pip install foo` would be executed in the current running environment.
"""
cl = _extract_commands_from_file(file_name=file)
if len(cl.commands) == 0:

View File

@ -57,8 +57,7 @@ class StateEntry:
class StateStore(ABC):
"""Base class of State store that provides simple key, value store to keep track of app state, served app
state."""
"""Base class of State store that provides simple key, value store to keep track of app state, served app state."""
@abstractmethod
def __init__(self):
@ -352,6 +351,7 @@ def _walk_to_component(
"""Returns a generator that runs through the tree starting from the root down to the given component.
At each node, yields parent and child as a tuple.
"""
from lightning.app.structures import Dict, List
@ -469,6 +469,7 @@ def _load_state_dict(root_flow: "LightningFlow", state: Dict[str, Any], strict:
root_flow: The flow at the top of the component tree.
state: The collected state dict.
strict: Whether to validate all components have been re-created.
"""
# 1: Reload the state of the existing works
for w in root_flow.works():

View File

@ -49,6 +49,7 @@ def _push_log_events_to_read_queue_callback(component_name: str, read_queue: que
"""Pushes _LogEvents from websocket to read_queue.
Returns callback function used with `on_message_callback` of websocket.WebSocketApp.
"""
def callback(ws_app: WebSocketApp, msg: str):

View File

@ -42,6 +42,7 @@ def _credential_string_to_basic_auth_params(credential_string: str) -> Dict[str,
"""Returns the name/ID pair for each given Secret name.
Raises a `ValueError` if any of the given Secret names do not exist.
"""
if credential_string.count(":") != 1:
raise ValueError(

View File

@ -129,6 +129,7 @@ class _LightningAppOpenAPIRetriever:
Arguments:
app_id_or_name_or_url: An identified for the app.
use_cache: Whether to load the openapi spec from the cache.
"""
self.app_id_or_name_or_url = app_id_or_name_or_url
self.url = None

View File

@ -25,6 +25,7 @@ def _get_default_cluster(client: LightningClient, project_id: str) -> str:
"""This utility implements a minimal version of the cluster selection logic used in the cloud.
TODO: This should be requested directly from the platform.
"""
cluster_bindings = client.projects_service_list_project_cluster_bindings(project_id=project_id).clusters

View File

@ -36,6 +36,7 @@ def _convert_paths_after_init(root: "LightningFlow"):
This is necessary because at the time of instantiating the component, its full affiliation is not known and Paths
that get passed to other componenets during ``__init__`` are otherwise not able to reference their origin or
consumer.
"""
from lightning.app import LightningFlow, LightningWork
from lightning.app.storage import Path
@ -52,6 +53,7 @@ def _sanitize_state(state: Dict[str, Any]) -> Dict[str, Any]:
"""Utility function to sanitize the state of a component.
Sanitization enables the state to be deep-copied and hashed.
"""
from lightning.app.storage import Drive, Path
from lightning.app.storage.payload import _BasePayload
@ -132,6 +134,7 @@ def _context(ctx: str) -> Generator[None, None, None]:
The context is used to determine whether the current process is running for a LightningFlow or for a LightningWork.
See also :func:`_get_context`, :func:`_set_context`. For internal use only.
"""
prev = _get_context()
_set_context(ctx)

View File

@ -29,6 +29,7 @@ class AttributeDict(Dict):
"key2": abc
"my-key": 3.14
"new_key": 42
"""
def __getattr__(self, key: str) -> Optional[Any]:

View File

@ -29,6 +29,7 @@ class _ApiExceptionHandler(Group):
However, if the ApiException cannot be decoded, or is not
a 4xx error, the original ApiException will be re-raised.
"""
def invoke(self, ctx: Context) -> Any:
@ -81,6 +82,7 @@ class LightningPlatformException(Exception): # pragma: no cover
It gets raised by the Lightning Launcher on the platform side when the app is running in the cloud, and is useful
when framework or user code needs to catch exceptions specific to the platform, e.g., when resources exceed quotas.
"""

View File

@ -26,6 +26,7 @@ def execute_git_command(args: List[str], cwd=None) -> str:
-------
output: str
String combining stdout and stderr.
"""
process = subprocess.run(["git"] + args, capture_output=True, text=True, cwd=cwd, check=False)
return process.stdout.strip() + process.stderr.strip()
@ -61,6 +62,7 @@ def check_if_remote_head_is_different() -> Union[bool, None]:
This only compares the local SHA to the HEAD commit of a given branch. This check won't be used if user isn't in a
HEAD locally.
"""
# Check SHA values.
local_sha = execute_git_command(["rev-parse", "@"])
@ -78,6 +80,7 @@ def has_uncommitted_files() -> bool:
"""Checks if user has uncommited files in local repository.
If there are uncommited files, then show a prompt indicating that uncommited files exist locally.
"""
files = execute_git_command(["update-index", "--refresh"])
return bool(files)

View File

@ -32,6 +32,7 @@ def _get_extras(extras: str) -> str:
"""Get the given extras as a space delimited string.
Used by the platform to install cloud extras in the cloud.
"""
from lightning.app import __package_name__

View File

@ -22,13 +22,14 @@ if TYPE_CHECKING:
class LightningVisitor(ast.NodeVisitor):
"""Base class for visitor that finds class definitions based on class inheritance. Derived classes are expected
to define class_name and implement the analyze_class_def method.
"""Base class for visitor that finds class definitions based on class inheritance. Derived classes are expected to
define class_name and implement the analyze_class_def method.
Attributes
----------
class_name: str
Name of class to identify, to be defined in subclasses.
"""
class_name: Optional[str] = None
@ -63,6 +64,7 @@ class LightningModuleVisitor(LightningVisitor):
Names of methods that are part of the LightningModule API.
hooks: Set[str]
Names of hooks that are part of the LightningModule API.
"""
class_name: Optional[str] = "LightningModule"
@ -132,6 +134,7 @@ class LightningDataModuleVisitor(LightningVisitor):
Name of class to identify.
methods: Set[str]
Names of methods that are part of the LightningDataModule API.
"""
class_name = "LightningDataModule"
@ -155,6 +158,7 @@ class LightningLoggerVisitor(LightningVisitor):
Name of class to identify.
methods: Set[str]
Names of methods that are part of the Logger API.
"""
class_name = "Logger"
@ -171,6 +175,7 @@ class LightningCallbackVisitor(LightningVisitor):
Name of class to identify.
methods: Set[str]
Names of methods that are part of the Logger API.
"""
class_name = "Callback"
@ -223,6 +228,7 @@ class LightningStrategyVisitor(LightningVisitor):
Name of class to identify.
methods: Set[str]
Names of methods that are part of the Logger API.
"""
class_name = "Strategy"
@ -282,6 +288,7 @@ class Scanner:
glob_pattern: str
Glob pattern to use when looking for files in the path,
applied when path is a directory. Default is "**/*.py".
"""
# TODO: Finalize introspecting the methods from all the discovered methods.
@ -341,6 +348,7 @@ class Scanner:
List[Dict[str, Any]]
List of dicts containing all metadata required
to import modules found.
"""
modules_found: Dict[str, List[Dict[str, Any]]] = {}

View File

@ -26,6 +26,7 @@ def _add_comment_to_literal_code(method, contains, comment):
"""Inspects a method's code and adds a message to it.
This is a nice to have, so if it fails for some reason, it shouldn't affect the program.
"""
try:
lines = inspect.getsource(method)

View File

@ -61,6 +61,7 @@ def _load_objects_from_file(
raise_exception: If ``True`` exceptions will be raised, otherwise exceptions will trigger system exit.
mock_imports: If ``True`` imports of missing packages will be replaced with a mock. This can allow the object to
be loaded without installing dependencies.
"""
# Taken from StreamLit: https://github.com/streamlit/streamlit/blob/develop/lib/streamlit/script_runner.py#L313
@ -110,6 +111,7 @@ def load_app_from_file(
Arguments:
filepath: The path to the file containing the LightningApp.
raise_exception: If True, raise an exception if the app cannot be loaded.
"""
from lightning.app.core.app import LightningApp
@ -142,6 +144,7 @@ def open_python_file(filename):
In Python 3, we would like all files to be opened with utf-8 encoding. However, some author like to specify PEP263
headers in their source files with their own encodings. In that case, we should respect the author's encoding.
"""
import tokenize
@ -204,6 +207,7 @@ def _patch_sys_path(append):
Args:
append: The value to append to the path.
"""
if append in sys.path:
yield

View File

@ -60,6 +60,7 @@ class Auth:
Returns
----------
True if credentials are available.
"""
if not self.secrets_file.exists():
logger.debug("Credentials file not found.")
@ -117,6 +118,7 @@ class Auth:
Returns
----------
authorization header to use when authentication completes.
"""
if not self.load():
# First try to authenticate from env

View File

@ -79,6 +79,7 @@ class _LightningLogsSocketAPI(_AuthTokenGetter):
Returns:
WebSocketApp of the wanted socket
"""
_token = self._get_api_token()
clean_ws_host = urlparse(self.api_client.configuration.host).netloc

View File

@ -1353,6 +1353,7 @@ def get_unique_name():
'meek-ardinghelli-4506'
>>> get_unique_name()
'truthful-dijkstra-2286'
"""
adjective, surname, i = choice(_adjectives), choice(_surnames), randint(0, 9999) # noqa: S311
return f"{adjective}-{surname}-{i}"

View File

@ -95,6 +95,7 @@ def _configure_session() -> Session:
"""Configures the session for GET and POST requests.
It enables a generous retrial strategy that waits for the application server to connect.
"""
retry_strategy = Retry(
# wait time between retries increases exponentially according to: backoff_factor * (2 ** (retry - 1))
@ -124,10 +125,10 @@ def _get_next_backoff_time(num_retries: int, backoff_value: float = 0.5) -> floa
def _retry_wrapper(self, func: Callable, max_tries: Optional[int] = None) -> Callable:
"""Returns the function decorated by a wrapper that retries the call several times if a connection error
occurs.
"""Returns the function decorated by a wrapper that retries the call several times if a connection error occurs.
The retries follow an exponential backoff.
"""
@wraps(func)
@ -175,6 +176,7 @@ class LightningClient(GridRestClient):
Args:
retry: Whether API calls should follow a retry mechanism with exponential backoff.
max_tries: Maximum number of attempts (or -1 to retry forever).
"""
def __init__(self, retry: bool = True, max_tries: Optional[int] = None) -> None:
@ -275,5 +277,6 @@ class HTTPClient:
We enabled customisation here instead of just using `logger.debug` because HTTP logging can be very noisy, but
it is crucial for finding bugs when we have them
"""
pass

View File

@ -49,6 +49,7 @@ def create_openapi_object(json_obj: Dict, target: Any):
Lightning AI uses the target object to make new objects from the given JSON spec so the target must be a valid
object.
"""
if not isinstance(json_obj, dict):
raise TypeError("json_obj must be a dictionary")

View File

@ -29,6 +29,7 @@ class AppConfig:
Args:
name: Optional name of the application. If not provided, auto-generates a new name.
"""
name: str = field(default_factory=get_unique_name)
@ -56,6 +57,7 @@ class AppConfig:
Args:
directory: Path to a folder which contains the '.lightning' config file to load.
"""
return cls.load_from_file(pathlib.Path(directory, _APP_CONFIG_FILENAME))
@ -65,6 +67,7 @@ def _get_config_file(source_path: Union[str, pathlib.Path]) -> pathlib.Path:
Args:
source_path: A path to a folder or a file.
"""
source_path = pathlib.Path(source_path).absolute()
if source_path.is_file():

Some files were not shown because too many files have changed in this diff Show More