docformatter: config with black (#18064)
* docformatter: config with black
* additional_dependencies: [tomli]
* 119
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

This commit is contained in:
parent e33816ce60
commit efa7b2f9ef
@@ -86,6 +86,7 @@ class _RequirementWithComment(Requirement):
'arrow>=1.2.0'
>>> _RequirementWithComment("arrow").adjust("major")
'arrow'

"""
out = str(self)
if self.strict:
@@ -115,6 +116,7 @@ def _parse_requirements(strs: Union[str, Iterable[str]]) -> Iterator[_Requiremen
>>> txt = '\\n'.join(txt)
>>> [r.adjust('none') for r in _parse_requirements(txt)]
['this', 'example', 'foo # strict', 'thing']

"""
lines = yield_lines(strs)
pip_argument = None
@@ -149,6 +151,7 @@ def load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: str
>>> path_req = os.path.join(_PROJECT_ROOT, "requirements")
>>> load_requirements(path_req, "docs.txt", unfreeze="major") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
['sphinx<...]

"""
assert unfreeze in {"none", "major", "all"}
path = Path(path_dir) / file_name
@@ -165,6 +168,7 @@ def load_readme_description(path_dir: str, homepage: str, version: str) -> str:

>>> load_readme_description(_PROJECT_ROOT, "", "") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
'...PyTorch Lightning is just organized PyTorch...'

"""
path_readme = os.path.join(path_dir, "README.md")
with open(path_readme, encoding="utf-8") as fo:
@@ -244,6 +248,7 @@ def _load_aggregate_requirements(req_dir: str = "requirements", freeze_requireme
"""Load all base requirements from all particular packages and prune duplicates.

>>> _load_aggregate_requirements(os.path.join(_PROJECT_ROOT, "requirements"))

"""
requires = [
load_requirements(d, unfreeze="none" if freeze_requirements else "major")
@@ -300,6 +305,7 @@ def _replace_imports(lines: List[str], mapping: List[Tuple[str, str]], lightning
'http://pytorch_lightning.ai', \
'from lightning_fabric import __version__', \
'@lightning.ai']

"""
out = lines[:]
for source_import, target_import in mapping:
@@ -60,7 +60,8 @@ repos:
rev: v1.7.3
hooks:
- id: docformatter
args: [--in-place, --wrap-summaries=115, --wrap-descriptions=120]
additional_dependencies: [tomli]
args: ["--in-place"]

- repo: https://github.com/asottile/yesqa
rev: v1.5.0
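The new `tomli` dependency is what lets the hook pick up the `[tool.docformatter]` table added to `pyproject.toml` later in this commit: docformatter reads its configuration from that file, and parsing TOML requires a third-party package on interpreters older than Python 3.11. A minimal sketch of that loading mechanism (illustrative only, not docformatter's actual code):

```python
# Sketch: loading [tool.docformatter] from pyproject.toml. tomllib is in the
# standard library from Python 3.11; older interpreters need the third-party
# tomli package, which is why the hook declares it in additional_dependencies.
import sys

if sys.version_info >= (3, 11):
    import tomllib as toml_reader
else:
    import tomli as toml_reader

with open("pyproject.toml", "rb") as f:
    settings = toml_reader.load(f).get("tool", {}).get("docformatter", {})

print(settings)  # expected here: {'recursive': True, 'wrap-summaries': 119, ...}
```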
@@ -23,6 +23,7 @@ class FileServer(L.LightningWork):
drive: The drive can share data inside your application.
base_dir: The local directory where the data will be stored.
chunk_size: The quantity of bytes to download/upload at once.

"""
super().__init__(
cloud_build_config=L.BuildConfig(["flask, flask-cors"]),
@@ -238,4 +239,5 @@ def test_file_server_in_cloud():

# 2. By calling logs = get_logs_fn(),
# you get all the logs currently on the admin page.

"""
@@ -36,6 +36,7 @@ class GithubRepoRunner(TracerPythonScript):
script_args: The arguments to be provided to the script.
requirements: The python requirements tp run the script.
cloud_compute: The object to select the cloud instance.

"""
super().__init__(
script_path=script_path,
@@ -10,6 +10,7 @@ class Locust(LightningWork):

Arguments:
num_users: Number of users emulated by Locust

"""
# Note: Using the default port 8089 of Locust.
super().__init__(
@@ -18,6 +18,7 @@ class MLServer(LightningWork):
Example: "mlserver_sklearn.SKLearnModel".
Learn more here: $ML_SERVER_URL/tree/master/runtimes
workers: Number of server worker.

"""

def __init__(
@@ -51,6 +52,7 @@ class MLServer(LightningWork):

Arguments:
model_path: The path to the trained model.

"""
# 1: Use the host and port at runtime so it works in the cloud.
# $ML_SERVER_URL/blob/master/mlserver/settings.py#L50
@@ -15,6 +15,7 @@ def download_data(url: str, path: str = "data/", verbose: bool = False) -> None:

Usage:
download_file('http://web4host.net/5MB.zip')

"""
if url == "NEED_TO_BE_CREATED":
raise NotImplementedError
@@ -5,6 +5,7 @@ Run the app:
lightning run app examples/layout/demo.py

This starts one server for each flow that returns a UI. Access the UI at the link printed in the terminal.

"""

import os
@@ -137,6 +137,7 @@ class MyCustomTrainer:
If not specified, no validation will run.
ckpt_path: Path to previous checkpoints to resume training from.
If specified, will always look for the latest checkpoint within the given directory.

"""
self.fabric.launch()

@@ -207,6 +208,7 @@ class MyCustomTrainer:
If greater then the number of batches in the ``train_loader``, this has no effect.
scheduler_cfg: The learning rate scheduler configuration.
Have a look at :meth:`lightning.pytorch.LightninModule.configure_optimizers` for supported values.

"""
self.fabric.call("on_train_epoch_start")
iterable = self.progbar_wrapper(
@@ -268,6 +270,7 @@ class MyCustomTrainer:
val_loader: The dataloader yielding the validation batches.
limit_batches: Limits the batches during this validation epoch.
If greater then the number of batches in the ``val_loader``, this has no effect.

"""
# no validation if val_loader wasn't passed
if val_loader is None:
@@ -311,13 +314,14 @@ class MyCustomTrainer:
torch.set_grad_enabled(True)

def training_step(self, model: L.LightningModule, batch: Any, batch_idx: int) -> torch.Tensor:
"""A single training step, running forward and backward. The optimizer step is called separately, as this
is given as a closure to the optimizer step.
"""A single training step, running forward and backward. The optimizer step is called separately, as this is
given as a closure to the optimizer step.

Args:
model: the lightning module to train
batch: the batch to run the forward on
batch_idx: index of the current batch w.r.t the current epoch

"""
outputs: Union[torch.Tensor, Mapping[str, Any]] = model.training_step(batch, batch_idx=batch_idx)

@@ -347,6 +351,7 @@ class MyCustomTrainer:
Have a look at :meth:`lightning.pytorch.LightningModule.configure_optimizers` for supported values.
level: whether we are trying to step on epoch- or step-level
current_value: Holds the current_epoch if ``level==epoch``, else holds the ``global_step``

"""

# no scheduler
@@ -395,6 +400,7 @@ class MyCustomTrainer:
Args:
iterable: the iterable to wrap with tqdm
total: the total length of the iterable, necessary in case the number of batches was limited.

"""
if self.fabric.is_global_zero:
return tqdm(iterable, total=total, **kwargs)
@@ -406,6 +412,7 @@ class MyCustomTrainer:
Args:
state: a mapping contaning model, optimizer and lr scheduler
path: the path to load the checkpoint from

"""
if state is None:
state = {}
@@ -458,6 +465,7 @@ class MyCustomTrainer:
Args:
configure_optim_output: The output of ``configure_optimizers``.
For supported values, please refer to :meth:`lightning.pytorch.LightningModule.configure_optimizers`.

"""
_lr_sched_defaults = {"interval": "epoch", "frequency": 1, "monitor": "val_loss"}

@@ -511,6 +519,7 @@ class MyCustomTrainer:
prog_bar: a progressbar (on global rank zero) or an iterable (every other rank).
candidates: the values to add as postfix strings to the progressbar.
prefix: the prefix to add to each of these values.

"""
if isinstance(prog_bar, tqdm) and candidates is not None:
postfix_str = ""
@@ -25,6 +25,7 @@ and replace ``loss.backward()`` with ``self.backward(loss)``.
Accelerate your training loop by setting the ``--accelerator``, ``--strategy``, ``--devices`` options directly from
the command line. See ``lightning run model --help`` or learn more from the documentation:
https://lightning.ai/docs/fabric.

"""

import argparse
@@ -14,6 +14,7 @@
"""MNIST autoencoder example.

To run: python autoencoder.py --trainer.max_epochs=50

"""
from os import path
from typing import Optional, Tuple
@@ -14,6 +14,7 @@
"""MNIST backbone image classifier example.

To run: python backbone_image_classifier.py --trainer.max_epochs=50

"""
from os import path
from typing import Optional
@@ -20,6 +20,7 @@ visualized in 2 ways:
* With PyTorch Tensorboard Profiler (Instructions are here: https://github.com/pytorch/kineto/tree/master/tb_plugin)
1. pip install tensorboard torch-tb-profiler
2. tensorboard --logdir={FOLDER}

"""

from os import path
@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Computer vision example on Transfer Learning. This computer vision example illustrates how one could fine-tune a
pre-trained network (by default, a ResNet50 is used) using pytorch-lightning. For the sake of this example, the
'cats and dogs dataset' (~60MB, see `DATA_URL` below) and the proposed network (denoted by `TransferLearningModel`,
see below) is trained for 15 epochs.
pre-trained network (by default, a ResNet50 is used) using pytorch-lightning. For the sake of this example, the 'cats
and dogs dataset' (~60MB, see `DATA_URL` below) and the proposed network (denoted by `TransferLearningModel`, see
below) is trained for 15 epochs.

The training consists of three stages.

@@ -37,6 +37,7 @@ Note:

To run:
python computer_vision_fine_tuning.py fit

"""

import logging
@@ -97,6 +98,7 @@ class CatDogImageDataModule(LightningDataModule):
dl_path: root directory where to download the data
num_workers: number of CPU workers
batch_size: number of sample in a batch

"""
super().__init__()

@@ -174,6 +176,7 @@ class TransferLearningModel(LightningModule):
milestones: List of two epochs milestones
lr: Initial learning rate
lr_scheduler_gamma: Factor by which the learning rate is reduced at each milestone

"""
super().__init__()
self.backbone = backbone
@@ -209,6 +212,7 @@ class TransferLearningModel(LightningModule):
"""Forward pass.

Returns logits.

"""
# 1. Feature extraction:
x = self.feature_extractor(x)
@@ -16,6 +16,7 @@
After a few epochs, launch TensorBoard to see the images being generated at every batch:

tensorboard --logdir default

"""
from argparse import ArgumentParser, Namespace

@@ -28,6 +28,7 @@ or show all options you can change:

python imagenet.py --help
python imagenet.py fit --help

"""
import os
from typing import Optional
@@ -29,6 +29,7 @@ References

[1] https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-
Second-Edition/blob/master/Chapter06/02_dqn_pong.py

"""

import argparse
@@ -54,6 +55,7 @@ class DQN(nn.Module):
DQN(
(net): Sequential(...)
)

"""

def __init__(self, obs_size: int, n_actions: int, hidden_size: int = 128):
@@ -79,6 +81,7 @@ class ReplayBuffer:

>>> ReplayBuffer(5) # doctest: +ELLIPSIS
<...reinforce_learn_Qnet.ReplayBuffer object at ...>

"""

def __init__(self, capacity: int) -> None:
@@ -96,6 +99,7 @@ class ReplayBuffer:

Args:
experience: tuple (state, action, reward, done, new_state)

"""
self.buffer.append(experience)

@@ -117,6 +121,7 @@ class RLDataset(IterableDataset):

>>> RLDataset(ReplayBuffer(5)) # doctest: +ELLIPSIS
<...reinforce_learn_Qnet.RLDataset object at ...>

"""

def __init__(self, buffer: ReplayBuffer, sample_size: int = 200) -> None:
@@ -141,6 +146,7 @@ class Agent:
>>> buffer = ReplayBuffer(10)
>>> Agent(env, buffer) # doctest: +ELLIPSIS
<...reinforce_learn_Qnet.Agent object at ...>

"""

def __init__(self, env: gym.Env, replay_buffer: ReplayBuffer) -> None:
@@ -168,6 +174,7 @@ class Agent:

Returns:
action

"""
if np.random.random() < epsilon:
action = self.env.action_space.sample()
@@ -194,6 +201,7 @@ class Agent:

Returns:
reward, done

"""
action = self.get_action(net, epsilon, device)

@@ -222,6 +230,7 @@ class DQNLightning(LightningModule):
(net): Sequential(...)
)
)

"""

def __init__(
@@ -270,6 +279,7 @@ class DQNLightning(LightningModule):

Args:
steps: number of random steps to populate the buffer with

"""
for i in range(steps):
self.agent.play_step(self.net, epsilon=1.0)
@@ -282,6 +292,7 @@ class DQNLightning(LightningModule):

Returns:
q values

"""
return self.net(x)

@@ -293,6 +304,7 @@ class DQNLightning(LightningModule):

Returns:
loss

"""
states, actions, rewards, dones, next_states = batch

@@ -308,8 +320,8 @@ class DQNLightning(LightningModule):
return nn.MSELoss()(state_action_values, expected_state_action_values)

def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], nb_batch) -> OrderedDict:
"""Carries out a single step through the environment to update the replay buffer. Then calculates loss
based on the minibatch received.
"""Carries out a single step through the environment to update the replay buffer. Then calculates loss based on
the minibatch received.

Args:
batch: current mini batch of replay data
@@ -317,6 +329,7 @@ class DQNLightning(LightningModule):

Returns:
Training loss and log metrics

"""
device = self.get_device(batch)
epsilon = max(self.eps_end, self.eps_start - (self.global_step + 1) / self.eps_last_frame)
@@ -26,6 +26,7 @@ References
[1] https://github.com/openai/baselines/blob/master/baselines/ppo2/ppo2.py
[2] https://github.com/openai/spinningup
[3] https://github.com/sid-sundrani/ppo_lightning

"""
import argparse
from typing import Callable, Iterator, List, Tuple
@@ -52,8 +53,7 @@ def create_mlp(input_shape: Tuple[int], n_actions: int, hidden_size: int = 128):


class ActorCategorical(nn.Module):
"""Policy network, for discrete action spaces, which returns a distribution and an action given an
observation."""
"""Policy network, for discrete action spaces, which returns a distribution and an action given an observation."""

def __init__(self, actor_net):
"""
@@ -81,6 +81,7 @@ class ActorCategorical(nn.Module):

Returns:
log probability of the action under pi

"""
return pi.log_prob(actions)

@@ -117,6 +118,7 @@ class ActorContinuous(nn.Module):

Returns:
log probability of the action under pi

"""
return pi.log_prob(actions).sum(axis=-1)

@@ -127,6 +129,7 @@ class ExperienceSourceDataset(IterableDataset):

Basic experience source dataset. Takes a generate_batch function that returns an iterator. The logic for the
experience source and how the batch is generated is defined the Lightning model itself

"""

def __init__(self, generate_batch: Callable):
@@ -144,6 +147,7 @@ class PPOLightning(LightningModule):
Train:
trainer = Trainer()
trainer.fit(model)

"""

def __init__(
@@ -231,6 +235,7 @@ class PPOLightning(LightningModule):

Returns:
Tuple of policy and action

"""
pi, action = self.actor(x)
value = self.critic(x)
@@ -245,6 +250,7 @@ class PPOLightning(LightningModule):

Returns:
list of discounted rewards/advantages

"""
assert isinstance(rewards[0], float)

@@ -267,6 +273,7 @@ class PPOLightning(LightningModule):

Returns:
list of advantages

"""
rews = rewards + [last_value]
vals = values + [last_value]
@@ -373,6 +380,7 @@ class PPOLightning(LightningModule):

Args:
batch: batch of replay buffer/trajectory data

"""
state, action, old_logp, qval, adv = batch

@@ -405,8 +413,7 @@ class PPOLightning(LightningModule):
return optimizer_actor, optimizer_critic

def optimizer_step(self, *args, **kwargs):
"""Run 'nb_optim_iters' number of iterations of gradient descent on actor and critic for each data
sample."""
"""Run 'nb_optim_iters' number of iterations of gradient descent on actor and critic for each data sample."""
for _ in range(self.nb_optim_iters):
super().optimizer_step(*args, **kwargs)

@@ -31,8 +31,7 @@ DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27


def _create_synth_kitti_dataset(path_dir: str, image_dims: tuple = (1024, 512)):
"""Create synthetic dataset with random images, just to simulate that the dataset have been already
downloaded."""
"""Create synthetic dataset with random images, just to simulate that the dataset have been already downloaded."""
path_dir_images = os.path.join(path_dir, KITTI.IMAGE_PATH)
path_dir_masks = os.path.join(path_dir, KITTI.MASK_PATH)
for p_dir in (path_dir_images, path_dir_masks):
@@ -65,6 +64,7 @@ class KITTI(Dataset):
In the `get_item` function, images and masks are resized to the given `img_size`, masks are
encoded using `encode_segmap`, and given `transform` (if any) are applied to the image only
(mask does not usually require transforms, but they can be implemented in a similar way).

"""

IMAGE_PATH = os.path.join("training", "image_2")
@@ -154,6 +154,7 @@ class UNet(nn.Module):
(5): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
)
)

"""

def __init__(self, num_classes: int = 19, num_layers: int = 5, features_start: int = 64, bilinear: bool = False):
@@ -200,6 +201,7 @@ class DoubleConv(nn.Module):
DoubleConv(
(net): Sequential(...)
)

"""

def __init__(self, in_ch: int, out_ch: int):
@@ -229,6 +231,7 @@ class Down(nn.Module):
)
)
)

"""

def __init__(self, in_ch: int, out_ch: int):
@@ -240,8 +243,8 @@ class Down(nn.Module):


class Up(nn.Module):
"""Upsampling (by either bilinear interpolation or transpose convolutions) followed by concatenation of feature
map from contracting path, followed by double 3x3 convolution.
"""Upsampling (by either bilinear interpolation or transpose convolutions) followed by concatenation of feature map
from contracting path, followed by double 3x3 convolution.

>>> Up(8, 4) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Up(
@@ -250,6 +253,7 @@ class Up(nn.Module):
(net): Sequential(...)
)
)

"""

def __init__(self, in_ch: int, out_ch: int, bilinear: bool = False):
@@ -306,6 +310,7 @@ class SegModel(LightningModule):
)
)
)

"""

def __init__(
@@ -50,6 +50,13 @@ skip = ["_notebooks"]
line-length = 120
exclude = '(_notebooks/.*)'

[tool.docformatter]
recursive = true
# this need to be shorter as some docstings are r"""...
wrap-summaries = 119
wrap-descriptions = 120
blank = true


[tool.ruff]
line-length = 120
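These options account for the bulk of this diff: `wrap-summaries = 119` and `wrap-descriptions = 120` rewrap docstring text (summaries use 119 because, per the comment above, some docstrings start with `r"""`, which is one character longer), while `blank = true` requires a blank line before the closing quotes, which is the single added line seen in most hunks. A hypothetical function (not from this repository) showing the resulting style:

```python
# Hypothetical example of a docstring after docformatter runs with the
# [tool.docformatter] settings above; note the blank line that blank = true
# enforces before the closing triple quotes.


def scale(value: float, factor: float = 2.0) -> float:
    """Scale ``value`` by ``factor``.

    Returns the scaled value.

    """
    return value * factor
```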
@@ -14,6 +14,7 @@
"""Diagnose your system and show basic information.

This server mainly to get detail info for better bug reporting.

"""

import os
setup.py
@@ -38,6 +38,7 @@ There are considered three main scenarios for installing this project:
compared against PyPI registry
b) with a parameterization build desired packages in to standard `dist/` folder
c) validate packages and publish to PyPI

"""
import contextlib
import glob
@@ -58,6 +58,7 @@ class _FastApiMockRequest:

def configure_api(self):
return [Post("/api/v1/request", self.request)]

"""

_body: Optional[str] = None
@@ -116,6 +117,7 @@ class _HttpMethod:
route: The path used to route the requests
method: The associated flow method
timeout: The time in seconds taken before raising a timeout exception.

"""
self.route = route
self.attached_to_flow = hasattr(method, "__self__")
@@ -566,6 +566,7 @@ def _install_app_from_source(
If true, overwrite the app directory without asking if it already exists
git_sha:
The git_sha for checking out the git repo of the app.

"""

if not cwd:
@@ -46,6 +46,7 @@ def logs(app_name: str, components: List[str], follow: bool) -> None:
Print logs only from selected works:

$ lightning show logs my-application root.work_a root.work_b

"""
_show_logs(app_name, components, follow)

@@ -2,6 +2,7 @@ r"""To test a lightning component:

1. Init the component.
2. call .run()

"""
from placeholdername.component import TemplateComponent

@@ -57,6 +57,7 @@ def connect_app(app_name_or_id: str):
\b
# once done, disconnect and go back to the standard lightning CLI commands
lightning disconnect

"""
from lightning.app.utilities.commands.base import _download_command

@@ -98,6 +98,7 @@ def delete_app(app_name: str, skip_user_confirm_prompt: bool) -> None:
Deleting an app also deletes all app websites, works, artifacts, and logs. This permanently removes any record of
the app as well as all any of its associated resources and data. This does not affect any resources and data
associated with other Lightning apps on your account.

"""
console = Console()

@@ -40,10 +40,11 @@ def launch() -> None:
@click.option("--host", help="Application running host", default=APP_SERVER_HOST, type=str)
@click.option("--port", help="Application running port", default=APP_SERVER_PORT, type=int)
def run_server(file: str, queue_id: str, host: str, port: int) -> None:
"""It takes the application file as input, build the application object and then use that to run the
application server.
"""It takes the application file as input, build the application object and then use that to run the application
server.

This is used by the cloud runners to start the status server for the application

"""
logger.debug(f"Run Server: {file} {queue_id} {host} {port}")
start_application_server(file, host, port, queue_id=queue_id)
@@ -54,10 +55,11 @@ def run_server(file: str, queue_id: str, host: str, port: int) -> None:
@click.option("--queue-id", help="ID for identifying queue", default="", type=str)
@click.option("--base-url", help="Base url at which the app server is hosted", default="")
def run_flow(file: str, queue_id: str, base_url: str) -> None:
"""It takes the application file as input, build the application object, proxy all the work components and then
run the application flow defined in the root component.
"""It takes the application file as input, build the application object, proxy all the work components and then run
the application flow defined in the root component.

It does exactly what a singleprocess dispatcher would do but with proxied work components.

"""
logger.debug(f"Run Flow: {file} {queue_id} {base_url}")
run_lightning_flow(file, queue_id=queue_id, base_url=base_url)
@@ -68,8 +70,8 @@ def run_flow(file: str, queue_id: str, base_url: str) -> None:
@click.option("--work-name", type=str)
@click.option("--queue-id", help="ID for identifying queue", default="", type=str)
def run_work(file: str, work_name: str, queue_id: str) -> None:
"""Unlike other entrypoints, this command will take the file path or module details for a work component and
run that by fetching the states from the queues."""
"""Unlike other entrypoints, this command will take the file path or module details for a work component and run
that by fetching the states from the queues."""
logger.debug(f"Run Work: {file} {work_name} {queue_id}")
run_lightning_work(
file=file,
@@ -109,10 +111,11 @@ def run_flow_and_servers(
port: int,
flow_port: Tuple[Tuple[str, int]],
) -> None:
"""It takes the application file as input, build the application object and then use that to run the
application flow defined in the root component, the application server and all the flow frontends.
"""It takes the application file as input, build the application object and then use that to run the application
flow defined in the root component, the application server and all the flow frontends.

This is used by the cloud runners to start the flow, the status server and all frontends for the application

"""
logger.debug(f"Run Flow: {file} {queue_id} {base_url}")
logger.debug(f"Run Server: {file} {queue_id} {host} {port}.")
@@ -13,6 +13,7 @@ class TensorBoard(LightningFlow):
Args:
log_dir: The path to the directory where the TensorBoard log-files will appear.
sync_every_n_seconds: How often to sync the log directory (given as an argument to the run method)

"""
super().__init__()
self.worker = TensorBoardWorker(log_dir=log_dir, sync_every_n_seconds=sync_every_n_seconds)
@@ -29,6 +29,7 @@ def _configure_session() -> Session:
"""Configures the session for GET and POST requests.

It enables a generous retrial strategy that waits for the application server to connect.

"""
retry_strategy = Retry(
# wait time between retries increases exponentially according to: backoff_factor * (2 ** (retry - 1))
@@ -158,6 +158,7 @@ class Database(LightningWork):

# RIGHT THERE ! You need to use Field and Column with the `pydantic_column_type` utility.
kv: List[KeyValuePair] = Field(..., sa_column=Column(pydantic_column_type(List[KeyValuePair])))

"""
super().__init__(parallel=True, cloud_build_config=BuildConfig(["sqlmodel"]))
self.db_filename = db_filename
@@ -52,6 +52,7 @@ def _pydantic_column_type(pydantic_type: Any) -> Any:
class TrialConfig(SQLModel, table=False):
...
params: Dict[str, Union[Dict[str, float]] = Field(sa_column=Column(pydantic_column_type[Dict[str, float]))

"""

class PydanticJSONType(TypeDecorator, Generic[T]):
@@ -67,6 +67,7 @@ class MultiNode(LightningFlow):
running locally.
work_args: Arguments to be provided to the work on instantiation.
work_kwargs: Keywords arguments to be provided to the work on instantiation.

"""
super().__init__()
if num_nodes > 1 and not is_running_in_cloud():
@@ -70,6 +70,7 @@ class PopenPythonScript(LightningWork):

.. literalinclude:: ../../../../examples/app/components/python/component_popen.py
:language: python

"""
super().__init__(**kwargs)
if not os.path.exists(script_path):
@@ -117,6 +117,7 @@ class TracerPythonScript(LightningWork):

.. literalinclude:: ../../../../examples/app/components/python/app.py
:language: python

"""
super().__init__(**kwargs)
self.script_path = str(script_path)
@@ -114,8 +114,8 @@ def _create_fastapi(title: str) -> _TrackableFastAPI:


class _LoadBalancer(LightningWork):
r"""The LoadBalancer is a LightningWork component that collects the requests and sends them to the prediciton
API asynchronously using RoundRobin scheduling. It also performs auto batching of the incoming requests.
r"""The LoadBalancer is a LightningWork component that collects the requests and sends them to the prediciton API
asynchronously using RoundRobin scheduling. It also performs auto batching of the incoming requests.

After enabling you will require to send username and password from the request header for the private endpoints.

@@ -131,6 +131,7 @@ class _LoadBalancer(LightningWork):
api_name: The name to be displayed on the UI. Normally, it is the name of the work class
cold_start_proxy: The proxy service to use while the work is cold starting.
**kwargs: Arguments passed to :func:`LightningWork.init` like ``CloudCompute``, ``BuildConfig``, etc.

"""

@requires(["aiohttp"])
@@ -236,6 +237,7 @@ class _LoadBalancer(LightningWork):

Two instances of this function should not be running with shared `_state_server` as that would create race
conditions

"""
while True:
await asyncio.sleep(0.05)
@@ -293,6 +295,7 @@ class _LoadBalancer(LightningWork):
"""This function checks if we have processing capacity for one more request or not.

Depends on the value from here, we decide whether we should proxy the request or not

"""
if not self._fastapi_app:
return False
@@ -385,6 +388,7 @@ class _LoadBalancer(LightningWork):
"""Updates works that load balancer distributes requests to.

AutoScaler uses this method to increase/decrease the number of works.

"""
old_server_urls = set(self.servers)
current_server_urls = {
@@ -623,6 +627,7 @@ class AutoScaler(LightningFlow):

Returns:
The name of the new work attribute.

"""
work_attribute = uuid.uuid4().hex
work_attribute = f"worker_{self.num_replicas}_{str(work_attribute)}"
@@ -665,6 +670,7 @@ class AutoScaler(LightningFlow):
Returns:
The target number of running works. The value will be adjusted after this method runs
so that it satisfies ``min_replicas<=replicas<=max_replicas``.

"""
pending_requests = metrics["pending_requests"]
active_or_pending_works = replicas + metrics["pending_works"]
@@ -35,6 +35,7 @@ class ColdStartProxy:

Args:
proxy_url (str): The url of the proxy service

"""

@requires(["aiohttp"])
@@ -46,12 +47,13 @@ class ColdStartProxy:

async def handle_request(self, request: BaseModel) -> Any:
"""This method is called when the request is received while the work is cold starting. The default
implementation of this method is to forward the request body to the proxy service with POST method but the
user can override this method to handle the request in any way.
implementation of this method is to forward the request body to the proxy service with POST method but the user
can override this method to handle the request in any way.

Args:
request (BaseModel): The request body, a pydantic model that is being
forwarded by load balancer which is a FastAPI service

"""
try:
async with aiohttp.ClientSession() as session:
@@ -77,6 +77,7 @@ class ServeGradio(LightningWork, abc.ABC):
"""Override to instantiate and return your model.

The model would be accessible under self.model

"""

def run(self, *args: Any, **kwargs: Any):
@@ -211,6 +211,7 @@ class PythonServer(LightningWork, abc.ABC):
... return {"prediction": self._model(request.image)}
...
>>> app = LightningApp(SimpleServer())

"""
super().__init__(parallel=True, **kwargs)
if not issubclass(input_type, BaseModel):
@@ -228,6 +229,7 @@ class PythonServer(LightningWork, abc.ABC):

Note that this will be called exactly once on every work machines. So if you have multiple machines for serving,
this will be called on each of them.

"""
return

@@ -243,6 +245,7 @@ class PythonServer(LightningWork, abc.ABC):

This method must be overriden by the user with the prediction logic. The pre/post processing, actual prediction
using the model(s) etc goes here

"""
pass

@@ -325,6 +328,7 @@ class PythonServer(LightningWork, abc.ABC):
"""Run method takes care of configuring and setting up a FastAPI server behind the scenes.

Normally, you don't need to override this method.

"""
self.setup(*args, **kwargs)

@@ -67,6 +67,7 @@ class ModelInferenceAPI(LightningWork, abc.ABC):
host: Address to be used to serve the model.
port: Port to be used to serve the model.
workers: Number of workers for the uvicorn. Warning, this won't work if your subclass takes more arguments.

"""
super().__init__(parallel=True, host=host, port=port)
if input and input not in _DESERIALIZER:
@@ -151,8 +152,8 @@ class ModelInferenceAPI(LightningWork, abc.ABC):


def _maybe_create_instance() -> Optional[ModelInferenceAPI]:
"""This function tries to re-create the user `ModelInferenceAPI` if the environment associated with multi
workers are present."""
"""This function tries to re-create the user `ModelInferenceAPI` if the environment associated with multi workers
are present."""
render_fn_name = os.getenv("LIGHTNING_MODEL_INFERENCE_API_CLASS_NAME", None)
render_fn_module_file = os.getenv("LIGHTNING_MODEL_INFERENCE_API_FILE", None)
if render_fn_name is None or render_fn_module_file is None:
@@ -50,6 +50,7 @@ class ServeStreamlit(LightningWork, abc.ABC):
"""Optionally override to instantiate and return your model.

The model will be accessible under ``self.model``.

"""
return None

@@ -28,4 +28,5 @@ class BaseType(abc.ABCMeta):
"""Take the inputs from the network and deserilize/convert them them.

Output from this method will go to the exposed method as arguments.

"""
@@ -159,6 +159,7 @@ class LightningTrainerScript(LightningFlow):
cloud_compute: The cloud compute object used in the cloud.
sanity_serving: Whether to validate that the model correctly implements
the ServableModule API

"""
super().__init__()
self.script_path = script_path
@@ -270,8 +270,8 @@ async def post_delta(
x_lightning_session_uuid: Optional[str] = Header(None), # type: ignore[assignment]
x_lightning_session_id: Optional[str] = Header(None), # type: ignore[assignment]
) -> Optional[Dict]:
"""This endpoint is used to make an update to the app state using delta diff, mainly used by streamlit to
update the state."""
"""This endpoint is used to make an update to the app state using delta diff, mainly used by streamlit to update
the state."""

if x_lightning_session_uuid is None:
raise Exception("Missing X-Lightning-Session-UUID header")
@@ -383,8 +383,7 @@ class LightningApp:
return deltas

def maybe_apply_changes(self) -> Optional[bool]:
"""Get the deltas from both the flow queue and the work queue, merge the two deltas and update the
state."""
"""Get the deltas from both the flow queue and the work queue, merge the two deltas and update the state."""
self._send_flow_to_work_deltas(self.state)

if not self.collect_changes:
@@ -503,6 +502,7 @@ class LightningApp:
"""Entry point of the LightningApp.

This would be dispatched by the Runtime objects.

"""
self._original_state = deepcopy(self.state)
done = False
|
@ -396,6 +396,7 @@ class LightningFlow:
|
|||
|
||||
.. deprecated:: 1.9.0
|
||||
This function is deprecated and will be removed in 2.0.0. Use :meth:`stop` instead.
|
||||
|
||||
"""
|
||||
warnings.warn(
|
||||
DeprecationWarning(
|
||||
|
@ -411,6 +412,7 @@ class LightningFlow:
|
|||
(prefixed by '__') attributes are not.
|
||||
|
||||
Exceptions are listed in the `_INTERNAL_STATE_VARS` class variable.
|
||||
|
||||
"""
|
||||
return name in LightningFlow._INTERNAL_STATE_VARS or not name.startswith("_")
|
||||
|
||||
|
@ -487,6 +489,7 @@ class LightningFlow:
|
|||
</div>
|
||||
</div>
|
||||
<br />
|
||||
|
||||
"""
|
||||
if not user_key:
|
||||
frame = cast(FrameType, inspect.currentframe()).f_back
|
||||
|
@ -626,6 +629,7 @@ class LightningFlow:
|
|||
</div>
|
||||
</div>
|
||||
<br />
|
||||
|
||||
"""
|
||||
return [{"name": name, "content": component} for (name, component) in self.flows.items()]
|
||||
|
||||
|
@ -639,6 +643,7 @@ class LightningFlow:
|
|||
run_once: Whether to run the entire iteration only once.
|
||||
Otherwise, it would restart from the beginning.
|
||||
user_key: Key to be used to track the caching mechanism.
|
||||
|
||||
"""
|
||||
if not isinstance(iterable, Iterable):
|
||||
raise TypeError(f"An iterable should be provided to `self.iterate` method. Found {iterable}")
|
||||
|
@ -708,6 +713,7 @@ class LightningFlow:
|
|||
.. code-block:: bash
|
||||
|
||||
lightning my_command_name --args name=my_own_name
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -741,6 +747,7 @@ class LightningFlow:
|
|||
|
||||
Once the app is running, you can access the Swagger UI of the app
|
||||
under the ``/docs`` route.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -805,6 +812,7 @@ class LightningFlow:
|
|||
children_states: The state of the dynamic children of this flow.
|
||||
strict: Whether to raise an exception if a dynamic
|
||||
children hasn't been re-created.
|
||||
|
||||
"""
|
||||
self.set_state(flow_state, recurse=False)
|
||||
direct_children_states = {k: v for k, v in children_states.items() if "." not in k}
|
||||
|
|
|
@ -182,6 +182,7 @@ class BaseQueue(ABC):
|
|||
timeout:
|
||||
Read timeout in seconds, in case of input timeout is 0, the `self.default_timeout` is used.
|
||||
A timeout of None can be used to block indefinitely.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
|
@ -190,6 +191,7 @@ class BaseQueue(ABC):
|
|||
"""Returns True if the queue is running, False otherwise.
|
||||
|
||||
Child classes should override this property and implement custom logic as required
|
||||
|
||||
"""
|
||||
return True
|
||||
|
||||
|
@ -286,6 +288,7 @@ class RedisQueue(BaseQueue):
|
|||
timeout:
|
||||
Read timeout in seconds, in case of input timeout is 0, the `self.default_timeout` is used.
|
||||
A timeout of None can be used to block indefinitely.
|
||||
|
||||
"""
|
||||
if timeout is None:
|
||||
# this means it's blocking in redis
|
||||
|
@ -464,6 +467,7 @@ class HTTPQueue(BaseQueue):
|
|||
|
||||
This can be brittle, as if the queue name creation logic changes, the response values from here wouldn't be
|
||||
accurate. Remove this eventually and let the Queue class take app id and name of the queue as arguments
|
||||
|
||||
"""
|
||||
if "_" not in queue_name:
|
||||
return "", queue_name
|
||||
|
|
|
@ -124,6 +124,7 @@ class LightningWork:
|
|||
</div>
|
||||
</div>
|
||||
<br />
|
||||
|
||||
"""
|
||||
from lightning.app.runners.backends.backend import Backend
|
||||
|
||||
|
@ -212,6 +213,7 @@ class LightningWork:
|
|||
|
||||
By default, this attribute returns the empty string and the ip address will only be returned once the work runs.
|
||||
Locally, the address is 127.0.0.1 and in the cloud it will be determined by the cluster.
|
||||
|
||||
"""
|
||||
return self._internal_ip
|
||||
|
||||
|
@ -221,6 +223,7 @@ class LightningWork:
|
|||
|
||||
By default, this attribute returns the empty string and the ip address will only be returned once the work runs.
|
||||
Locally, this address is undefined (empty string) and in the cloud it will be determined by the cluster.
|
||||
|
||||
"""
|
||||
return self._public_ip
|
||||
|
||||
|
@ -234,6 +237,7 @@ class LightningWork:
|
|||
(prefixed by '__') attributes are not.
|
||||
|
||||
Exceptions are listed in the `_INTERNAL_STATE_VARS` class variable.
|
||||
|
||||
"""
|
||||
return name in LightningWork._INTERNAL_STATE_VARS or not name.startswith("_")
|
||||
|
||||
|
@ -247,6 +251,7 @@ class LightningWork:
|
|||
"""Returns the display name of the LightningWork in the cloud.
|
||||
|
||||
The display name needs to set before the run method of the work is called.
|
||||
|
||||
"""
|
||||
return self._display_name
|
||||
|
||||
|
@ -269,6 +274,7 @@ class LightningWork:
|
|||
"""Whether to run in parallel mode or not.
|
||||
|
||||
When parallel is False, the flow waits for the work to finish.
|
||||
|
||||
"""
|
||||
return self._parallel
|
||||
|
||||
|
@ -325,6 +331,7 @@ class LightningWork:
|
|||
"""Return the current status of the work.
|
||||
|
||||
All statuses are stored in the state.
|
||||
|
||||
"""
|
||||
call_hash = self._calls[CacheCallsKeys.LATEST_CALL_HASH]
|
||||
if call_hash in self._calls:
|
||||
|
@ -628,6 +635,7 @@ class LightningWork:
|
|||
|
||||
Raises:
|
||||
LightningPlatformException: If resource exceeds platform quotas or other constraints.
|
||||
|
||||
"""
|
||||
|
||||
def on_exception(self, exception: BaseException) -> None:
|
||||
|
@ -636,8 +644,7 @@ class LightningWork:
|
|||
raise exception
|
||||
|
||||
def _aggregate_status_timeout(self, statuses: List[Dict]) -> WorkStatus:
|
||||
"""Method used to return the first request and the total count of timeout after the latest succeeded
|
||||
status."""
|
||||
"""Method used to return the first request and the total count of timeout after the latest succeeded status."""
|
||||
succeeded_statuses = [
|
||||
status_idx for status_idx, status in enumerate(statuses) if status["stage"] == WorkStageStatus.SUCCEEDED
|
||||
]
|
||||
|
@ -653,6 +660,7 @@ class LightningWork:
|
|||
"""Override this hook to add your logic when the work is exiting.
|
||||
|
||||
Note: This hook is not guaranteed to be called when running in the cloud.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
|
@ -660,6 +668,7 @@ class LightningWork:
|
|||
"""Stops LightingWork component and shuts down hardware provisioned via L.CloudCompute.
|
||||
|
||||
This can only be called from a ``LightningFlow``.
|
||||
|
||||
"""
|
||||
if not self._backend:
|
||||
raise RuntimeError(f"Only the `LightningFlow` can request this work ({self.name!r}) to stop.")
|
||||
|
@ -675,6 +684,7 @@ class LightningWork:
|
|||
"""Delete LightingWork component and shuts down hardware provisioned via L.CloudCompute.
|
||||
|
||||
Locally, the work.delete() behaves as work.stop().
|
||||
|
||||
"""
|
||||
if not self._backend:
|
||||
raise Exception(
|
||||
|
@ -755,4 +765,5 @@ class LightningWork:
|
|||
returned URL can depend on the state. This is not the case if the work returns a
|
||||
:class:`~lightning.app.frontend.frontend.Frontend`. These need to be provided at the time of app creation
|
||||
in order for the runtime to start the server.
|
||||
|
||||
"""
|
||||
|
|
|
@@ -23,6 +23,7 @@ class Frontend(ABC):
"""Base class for any frontend that gets exposed by LightningFlows.

The flow attribute will be set by the app while bootstrapping.

"""

def __init__(self) -> None:
@@ -48,6 +49,7 @@ class Frontend(ABC):

def start_server(self, host, port, root_path=""):
self._process = subprocess.Popen(["flask", "run" "--host", host, "--port", str(port)])

"""

@abstractmethod
@@ -62,4 +64,5 @@ class Frontend(ABC):

def stop_server(self):
self._process.kill()

"""
@@ -81,6 +81,7 @@ class JustPyFrontend(Frontend):


app = LightningApp(Flow())

"""

def __init__(self, render_fn: Callable) -> None:
@@ -1,5 +1,4 @@
"""The PanelFrontend and AppStateWatcher make it easy to create Lightning Apps with the Panel data app
framework."""
"""The PanelFrontend and AppStateWatcher make it easy to create Lightning Apps with the Panel data app framework."""
from lightning.app.frontend.panel.app_state_watcher import AppStateWatcher
from lightning.app.frontend.panel.panel_frontend import PanelFrontend

@@ -94,6 +94,7 @@ def _watch_app_state(callback: Callable):
def handle_state_change():
print("The App State changed.")
watch_app_state(handle_state_change)

"""
_CALLBACKS.append(callback)
_start_websocket()
@@ -72,6 +72,7 @@ class AppStateWatcher(Parameterized):
Pydantic which additionally provides powerful and unique features for building reactive apps.

Please note the ``AppStateWatcher`` is a singleton, i.e., only one instance is instantiated

"""

state: AppState = ClassSelector(
@@ -35,6 +35,7 @@ def _has_panel_autoreload() -> bool:
"""Returns True if the PANEL_AUTORELOAD environment variable is set to 'yes' or 'true'.

Please note the casing of value does not matter

"""
return os.environ.get("PANEL_AUTORELOAD", "no").lower() in ["yes", "y", "true"]

@@ -98,6 +99,7 @@ class PanelFrontend(Frontend):
For development you can get Panel autoreload by setting the ``PANEL_AUTORELOAD``
environment variable to 'yes', i.e. run
``PANEL_AUTORELOAD=yes lightning run app app_basic.py``

"""

@requires("panel")
@@ -26,6 +26,7 @@ Example:
.. code-block:: bash

python panel_serve_render_fn

"""
import inspect
import os
@@ -61,6 +61,7 @@ class StreamlitFrontend(Frontend):

st.write("Hello from streamlit!")
st.write(state.counter)

"""

@requires("streamlit")
@@ -14,6 +14,7 @@
"""This file gets run by streamlit, which we launch within Lightning.

From here, we will call the render function that the user provided in ``configure_layout``.

"""
import os
import pydoc
@@ -37,6 +37,7 @@ def _get_flow_state(flow: str) -> AppState:

Returns:
AppState: An AppState scoped to the current Flow.

"""
app_state = AppState()
app_state._request_state() # pylint: disable=protected-access
@@ -54,6 +55,7 @@ def _get_frontend_environment(flow: str, render_fn_or_file: Callable | str, port

Returns:
os._Environ: An environment

"""
env = os.environ.copy()
env["LIGHTNING_FLOW_NAME"] = flow
@@ -44,6 +44,7 @@ class StaticWebFrontend(Frontend):

def configure_layout(self):
return StaticWebFrontend("path/to/folder/to/serve")

"""

def __init__(self, serve_dir: str) -> None:
@@ -102,8 +103,7 @@ def _start_server(


def _get_log_config(log_file: str) -> dict:
"""Returns a logger configuration in the format expected by uvicorn that sends all logs to the given
logfile."""
"""Returns a logger configuration in the format expected by uvicorn that sends all logs to the given logfile."""
# Modified from the default config found in uvicorn.config.LOGGING_CONFIG
return {
"version": 1,
@@ -109,6 +109,7 @@ def run_lightning_work(

It is organized under cloud runtime to indicate that it will be used by the cloud runner but otherwise, no cloud
specific logic is being implemented here

"""
logger.debug(f"Run Lightning Work {file} {work_name} {queue_id}")

@@ -231,6 +232,7 @@ def serve_frontend(file: str, flow_name: str, host: str, port: int):

It is organized under cloud runtime to indicate that it will be used by the cloud runner but otherwise, no cloud
specific logic is being implemented here.

"""
_set_frontend_context()
logger.debug(f"Run Serve Frontend {file} {flow_name} {host} {port}")
@@ -340,15 +342,16 @@ def manage_server_processes(processes: List[Tuple[str, Process]]) -> None:


def _get_frontends_from_app(entrypoint_file):
"""This function is used to get the frontends from the app. It will be used to start the frontends in a
separate process if the backend cannot provide flow_names_and_ports. This is useful if the app cannot be loaded
locally to set the frontend before dispatching to the cloud. The backend exposes by default 10 ports from 8081
if the app.spec.frontends is not set.
"""This function is used to get the frontends from the app. It will be used to start the frontends in a separate
process if the backend cannot provide flow_names_and_ports. This is useful if the app cannot be loaded locally to
set the frontend before dispatching to the cloud. The backend exposes by default 10 ports from 8081 if the
app.spec.frontends is not set.

NOTE: frontend_name are sorted to ensure that they get consistent ports.

:param entrypoint_file: The entrypoint file for the app
:return: A list of tuples of the form (frontend_name, port_number)

"""
app = load_app_from_file(entrypoint_file)

@@ -252,6 +252,7 @@ class CloudBackend(Backend):
Normally, the Lightning frameworks communicates statuses through the queues, but while the Work instance is
being provisionied, the queues don't exist yet and hence we need to make API calls directly to the backend to
fetch the status and update it in the states.

"""
if not works:
return
@@ -305,6 +306,7 @@ class CloudBackend(Backend):
"""Stop resources for all LightningWorks in this app.

The Works are stopped rather than deleted so that they can be inspected for debugging.

"""
cloud_works = self._get_cloud_work_specs(self.client)

@@ -60,6 +60,7 @@ class LightningPlugin:

Returns:
The relative URL of the created job.

"""
from lightning.app.runners.backends.cloud import CloudBackend
from lightning.app.runners.cloud import CloudRuntime
@@ -198,8 +198,8 @@ class CloudRuntime(Runtime):
cluster_id: str,
source_app: Optional[str] = None,
) -> str:
"""Slim dispatch for creating runs from a cloudspace. This dispatch avoids resolution of some properties
such as the project and cluster IDs that are instead passed directly.
"""Slim dispatch for creating runs from a cloudspace. This dispatch avoids resolution of some properties such
as the project and cluster IDs that are instead passed directly.

Args:
project_id: The ID of the project.
@@ -214,6 +214,7 @@ class CloudRuntime(Runtime):

Returns:
The URL of the created job.

"""
# Dispatch in four phases: resolution, validation, spec creation, API transactions
# Resolution
@@ -432,6 +433,7 @@ class CloudRuntime(Runtime):
"""Find and load the config file if it exists (otherwise create an empty config).

Override the name if provided.

"""
config_file = _get_config_file(self.entrypoint)
cloudspace_config = AppConfig.load_from_file(config_file) if config_file.exists() and load else AppConfig()
@@ -611,6 +613,7 @@ class CloudRuntime(Runtime):
"""Check if the user likely needs credits to run the app with its hardware.

Returns False if user has 1 or more credits.

"""
balance = project.balance
if balance is None:
@@ -698,8 +701,8 @@ class CloudRuntime(Runtime):
raise RuntimeError(f"Unknown mount protocol `{mount.protocol}` for work `{work.name}`.")

def _get_flow_servers(self) -> List[V1Flowserver]:
"""Collect a spec for each flow that contains a frontend so that the backend knows for which flows it needs
to start servers."""
"""Collect a spec for each flow that contains a frontend so that the backend knows for which flows it needs to
start servers."""
flow_servers: List[V1Flowserver] = []
for flow_name in self.app.frontends:
flow_server = V1Flowserver(name=flow_name)
@@ -889,8 +892,7 @@ class CloudRuntime(Runtime):
def _get_env_vars(
env_vars: Dict[str, str], secrets: Dict[str, str], run_app_comment_commands: bool
) -> List[V1EnvVar]:
"""Generate the list of environment variable specs for the app, including variables set by the
framework."""
"""Generate the list of environment variable specs for the app, including variables set by the framework."""
v1_env_vars = [V1EnvVar(name=k, value=v) for k, v in env_vars.items()]

if len(secrets.values()) > 0:
@@ -929,6 +931,7 @@ class CloudRuntime(Runtime):
"""Create the cloudspace if it doesn't exist.

Return the cloudspace ID.

"""
if existing_cloudspace is None:
cloudspace_body = ProjectIdCloudspacesBody(name=name, can_download_source_code=True)
@@ -980,6 +983,7 @@ class CloudRuntime(Runtime):
"""Transfer an existing instance to the given run ID and update its specification.

Return the instance.

"""
run_instance = self.backend.client.lightningapp_instance_service_update_lightningapp_instance_release(
project_id=project_id,
@@ -39,6 +39,7 @@ class MultiProcessRuntime(Runtime):

The MultiProcessRuntime will generate 1 process for each :class:`~lightning.app.core.work.LightningWork` and attach
queues to enable communication between the different processes.

"""

backend: Union[str, Backend] = "multiprocessing"
@ -66,6 +66,7 @@ def dispatch(
|
|||
run_app_comment_commands: whether to parse commands from the entrypoint file and execute them before app startup
|
||||
enable_basic_auth: whether to enable basic authentication for the app
|
||||
(use credentials in the format username:password as an argument)
|
||||
|
||||
"""
|
||||
from lightning.app.runners.runtime_type import RuntimeType
|
||||
from lightning.app.utilities.component import _set_flow_context
|
||||
|
|
|
@ -34,9 +34,9 @@ def _copytree(
dirs_exist_ok=False,
dry_run=False,
) -> List[str]:
"""Vendor in from `shutil.copytree` to support ignoring files recursively based on `.lightningignore`, like
`git` does with `.gitignore`. Also removed a few checks from the original copytree related to symlink checks.
Differences between original and this function are:
"""Vendor in from `shutil.copytree` to support ignoring files recursively based on `.lightningignore`, like `git`
does with `.gitignore`. Also removed a few checks from the original copytree related to symlink checks. Differences
between original and this function are:

1. It supports a list of ignore functions instead of a single one in the
original. We can use this for filtering out files based on nested

@ -66,6 +66,7 @@ def _copytree(

If exception(s) occur, an Error is raised with a list of reasons.

"""
files_copied = []
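
For reference, a minimal sketch of the recursive-ignore behaviour described above, assuming a plain shutil/os implementation; copy_tree_with_ignores and ignore_fns are illustrative names, not the Lightning internals:

import os
import shutil
from typing import Callable, List, Set

def copy_tree_with_ignores(src: str, dst: str, ignore_fns: List[Callable[[str, List[str]], Set[str]]]) -> List[str]:
    """Copy `src` to `dst`, skipping entries that any ignore function rejects."""
    copied: List[str] = []
    os.makedirs(dst, exist_ok=True)
    entries = os.listdir(src)
    ignored: Set[str] = set()
    for fn in ignore_fns:
        # each function can veto entries, e.g. one per nested .lightningignore
        ignored |= fn(src, entries)
    for name in entries:
        if name in ignored:
            continue
        src_path, dst_path = os.path.join(src, name), os.path.join(dst, name)
        if os.path.isdir(src_path):
            copied += copy_tree_with_ignores(src_path, dst_path, ignore_fns)
        else:
            shutil.copy2(src_path, dst_path)
            copied.append(dst_path)
    return copied
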
@ -146,8 +147,8 @@ def _parse_lightningignore(lines: Tuple[str]) -> Set[str]:

def _read_lightningignore(path: Path) -> Set[str]:
"""Reads the ignore file and filters out empty lines. This will also remove patterns that start with a `/`. That's
done to allow `glob` to simulate the behavior done by `git` where it interprets that as a root path.
"""Reads the ignore file and filters out empty lines. This will also remove patterns that start with a `/`. That's done
to allow `glob` to simulate the behavior done by `git` where it interprets that as a root path.

Parameters
----------

@ -158,6 +159,7 @@ def _read_lightningignore(path: Path) -> Set[str]:
-------
Set[str]
Set of unique lines.

"""
raw_lines = path.open().readlines()
return _parse_lightningignore(raw_lines)
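
A rough sketch of the filtering described here — drop blank lines and a leading `/` so glob-style matching treats the pattern as relative; an illustration, not the exact Lightning code:

from pathlib import Path
from typing import Set

def read_ignore_patterns(path: Path) -> Set[str]:
    patterns: Set[str] = set()
    for raw in path.read_text().splitlines():
        line = raw.strip()
        if not line:  # filter out empty lines
            continue
        # a leading `/` is interpreted as a root path, like git does
        patterns.add(line.lstrip("/"))
    return patterns
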
@ -33,6 +33,7 @@ def _get_hash(files: List[str], algorithm: str = "blake2", chunk_num_blocks: int
----------
[1] https://crypto.stackexchange.com/questions/70101/blake2-vs-md5-for-checksum-file-integrity
[2] https://stackoverflow.com/questions/1131220/get-md5-hash-of-big-files-in-python

"""
# validate input
if algorithm == "blake2":
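
A small sketch of the chunked hashing referenced in [1] and [2]: hash large files in fixed-size blocks so memory use stays constant. The helper name and default chunk count are assumptions:

import hashlib
from typing import List

def hash_files(files: List[str], chunk_num_blocks: int = 128) -> str:
    h = hashlib.blake2b()
    for file in sorted(files):
        with open(file, "rb") as fp:
            # read block-sized chunks so multi-GB files never load fully into memory
            for chunk in iter(lambda: fp.read(chunk_num_blocks * h.block_size), b""):
                h.update(chunk)
    return h.hexdigest()
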
@ -133,6 +133,7 @@ class LocalSourceCodeDir:
packaged repository files which have a size > 2GB.

This limitation should be removed during the datastore upload redesign

"""
if self.package_path.stat().st_size > 2e9:
raise OSError(
@ -36,6 +36,7 @@ def _get_dir_size_and_count(source_dir: str, prefix: Optional[str] = None) -> Tu
-------
Tuple[int, int]
Size in megabytes and file count

"""
size = 0
count = 0

@ -61,6 +62,7 @@ class _TarResults:
The total size of the original directory files in bytes
after_size: int
The total size of the compressed and tarred split files in bytes

"""

before_size: int
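
A hedged sketch of such a size-and-count walk (the megabyte conversion and the omission of the `prefix` filter are simplifications):

import os
from typing import Tuple

def dir_size_and_count(source_dir: str) -> Tuple[int, int]:
    size, count = 0, 0
    for root, _, files in os.walk(source_dir):
        for name in files:
            size += os.path.getsize(os.path.join(root, name))
            count += 1
    return size // (1024 * 1024), count  # megabytes, file count
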
@ -70,13 +72,12 @@ class _TarResults:
def _get_split_size(
total_size: int, minimum_split_size: int = 1024 * 1000 * 20, max_split_count: int = MAX_SPLIT_COUNT
) -> int:
"""Calculate the split size we should use to split the multipart upload of an object to a bucket. We are
limited to 1000 max parts as the way we are using ListMultipartUploads. More info
https://github.com/gridai/grid/pull/5267
"""Calculate the split size we should use to split the multipart upload of an object to a bucket. We are limited
to 1000 max parts as the way we are using ListMultipartUploads. More info https://github.com/gridai/grid/pull/5267
https://docs.aws.amazon.com/AmazonS3/latest/userguide/mpuoverview.html#mpu-process
https://docs.aws.amazon.com/AmazonS3/latest/API/API_ListMultipartUploads.html
https://github.com/psf/requests/issues/2717#issuecomment-724725392 Python or requests has a limit of 2**31
bytes for a single file upload.
https://github.com/psf/requests/issues/2717#issuecomment-724725392 Python or requests has a limit of 2**31 bytes
for a single file upload.

Parameters
----------

@ -91,6 +92,7 @@ def _get_split_size(
-------
int
Split size

"""
max_size = max_split_count * (1 << 31)  # max size per part limited by Requests or urllib as shown in ref above
if total_size > max_size:
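
A sketch of the arithmetic this implies: stay under max_split_count parts (ListMultipartUploads caps out at 1000) and under the 2**31-byte per-part ceiling, never dropping below the minimum split size. The exact fallback behaviour is an assumption:

MAX_SPLIT_COUNT = 1000

def split_size(total_size: int, minimum_split_size: int = 1024 * 1000 * 20, max_split_count: int = MAX_SPLIT_COUNT) -> int:
    max_size = max_split_count * (1 << 31)  # ceiling with 1000 parts of 2**31 bytes each
    if total_size > max_size:
        raise ValueError(f"Cannot upload more than {max_size} bytes in {max_split_count} parts")
    if total_size // minimum_split_size >= max_split_count:
        # grow the part size once the minimum would need more than max_split_count parts
        return -(-total_size // max_split_count)  # ceiling division
    return minimum_split_size
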
@ -123,6 +125,7 @@ def _tar_path(source_path: str, target_file: str, compression: bool = False) ->
-------
TarResults
Results that holds file counts and sizes

"""
if os.path.isdir(source_path):
before_size, _ = _get_dir_size_and_count(source_path)

@ -149,6 +152,7 @@ def _tar_path_python(source_path: str, target_file: str, compression: bool = Fal
Target tar file
compression: bool, default False
Enable compression, which is disabled by default.

"""
file_mode = "w:gz" if compression else "w:"

@ -172,6 +176,7 @@ def _tar_path_subprocess(source_path: str, target_file: str, compression: bool =
Target tar file
compression: bool, default False
Enable compression, which is disabled by default.

"""
# Only add compression when users explicitly request it.
# We do this because it takes too long to compress
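
The pure-Python variant reduces to the standard tarfile module; a minimal sketch (the arcname choice is illustrative):

import os
import tarfile

def tar_path_python(source_path: str, target_file: str, compression: bool = False) -> None:
    file_mode = "w:gz" if compression else "w:"  # compress only on explicit request
    with tarfile.open(target_file, file_mode) as tar:
        tar.add(source_path, arcname=os.path.basename(source_path))
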
@ -35,6 +35,7 @@ class FileUploader:
Size of all files to upload
name: str
Name of this upload to display progress

"""

workers: int = 8

@ -72,6 +73,7 @@ class FileUploader:
-------
str
ETag from response

"""
disconnect_retries = retries
while disconnect_retries > 0:
@ -52,6 +52,7 @@ class _Copier(Thread):
will send requests to this queue.
copy_response_queue: A queue connecting the central StorageOrchestrator with the Copier. The Copier
will send a response to this queue whenever a requested copy has finished.

"""

def __init__(

@ -116,6 +117,7 @@ def _copy_files(
interpreted as a folder as well. If the source is a file, the destination path is interpreted as a file too.

Files in a folder are copied recursively and efficiently using multiple threads.

"""
if fs is None:
fs = _filesystem()
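
A hedged sketch of the threaded recursive copy mentioned here, fanning individual file copies out to a pool (the worker count is an assumption):

import shutil
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

def copy_folder(src: Path, dst: Path, max_workers: int = 8) -> None:
    files = [p for p in src.rglob("*") if p.is_file()]
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        for f in files:
            target = dst / f.relative_to(src)
            target.parent.mkdir(parents=True, exist_ok=True)  # prepare dirs on the main thread
            pool.submit(shutil.copy2, f, target)
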
@ -46,6 +46,7 @@ class Drive:
component_name: The component name which owns this drive.
When not provided, it is automatically inferred by Lightning.
root_folder: This is the folder from where the Drive perceives the data (e.g. this acts as a mount dir).

"""
if id.startswith("s3://"):
raise ValueError(

@ -96,6 +97,7 @@ class Drive:
Arguments:
path: The relative path to your files to be added to the Drive.

"""
if not self.component_name:
raise Exception("The component name needs to be known to put a path to the Drive.")

@ -121,6 +123,7 @@ class Drive:
path: The relative path you want to list files from the Drive.
component_name: By default, the Drive lists files across all components.
If you provide a component name, the listing is specific to this component.

"""
if _is_flow_context():
raise Exception("The flow isn't allowed to list files from a Drive.")

@ -165,6 +168,7 @@ class Drive:
If you provide a component name, the matching is specific to this component.
timeout: Whether to wait for the files to be available if not created yet.
overwrite: Whether to override the provided path if it exists.

"""
if _is_flow_context():
raise Exception("The flow isn't allowed to get files from a Drive.")

@ -207,11 +211,12 @@ class Drive:
self._get(self.fs, match, pathlib.Path(os.path.join(self.root_folder, path)).resolve(), overwrite=overwrite)

def delete(self, path: str) -> None:
"""This method enables deleting files under the provided path from the Drive in a blocking fashion. Only
the component which added a file can delete them.
"""This method enables deleting files under the provided path from the Drive in a blocking fashion. Only the
component which added a file can delete them.

Arguments:
path: The relative path you want to delete files from the Drive.

"""
if not self.component_name:
raise Exception("The component name needs to be known to delete a path to the Drive.")
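
Taken together, a hedged usage sketch of this Drive API from inside a LightningWork (the flow context is not allowed to list or get files); the paths are illustrative:

from lightning.app.storage import Drive

drive = Drive("lit://artifacts", root_folder="./artifacts")
drive.put("weights.ckpt")                  # share a file owned by this component
files = drive.list(".")                    # list files across all components
drive.get("weights.ckpt", overwrite=True)  # fetch it back, overwriting locally
drive.delete("weights.ckpt")               # only the component that put it may delete
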
@ -42,6 +42,7 @@ class FileSystem:
src_path: The path to your files locally
dst_path: The path to your files transferred in the shared storage.
put_fn: The method to use to put files in the shared storage.

"""
if not os.path.exists(Path(src_path).resolve()):
raise FileExistsError(f"The provided path {src_path} doesn't exist")

@ -66,6 +67,7 @@ class FileSystem:
src_path: The path to your files in the shared storage
dst_path: The path to your files transferred locally
get_fn: The method to use to get files from the shared storage.

"""
if not src_path.startswith("/"):
raise Exception(f"The provided destination {src_path} needs to start with `/`.")

@ -80,6 +82,7 @@ class FileSystem:

Arguments:
path: The path to files to list.

"""
if not path.startswith("/"):
raise Exception(f"The provided destination {path} needs to start with `/`.")

@ -104,6 +107,7 @@ class FileSystem:

Arguments:
path: The path to files to list.

"""
if not path.startswith("/"):
raise Exception(f"The provided destination {path} needs to start with `/`.")
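
A hedged usage sketch of this FileSystem helper, assuming it is exported as lightning.app.storage.FileSystem; note shared-storage paths must start with `/`:

from lightning.app.storage import FileSystem

fs = FileSystem()
fs.put("model.ckpt", "/checkpoints/model.ckpt")     # local -> shared storage
print(fs.listdir("/checkpoints"))                   # shared paths start with `/`
fs.get("/checkpoints/model.ckpt", "restored.ckpt")  # shared storage -> local
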
@ -34,6 +34,7 @@ class Mount:
mount_path: An absolute directory path in the work where external data source should
be mounted as a filesystem. This path should not already exist in your codebase.
If not included, then the root_dir will be set to `/data/<last folder name in the bucket>`

"""

source: str = ""

@ -47,6 +47,7 @@ class StorageOrchestrator(Thread):
put requests on this queue for the file-transfer thread to complete.
copy_response_queues: A dictionary of Queues where each Queue connects to one Work. The queue is expected to
contain the completion response from the file-transfer thread running in the Work process.

"""

def __init__(
@ -53,6 +53,7 @@ class Path(PathlibPath):
Args:
*args: Accepts the same arguments as in :class:`pathlib.Path`
**kwargs: Accepts the same keyword arguments as in :class:`pathlib.Path`

"""

@classmethod

@ -105,6 +106,7 @@ class Path(PathlibPath):
"""The name of the LightningWork where this path was first created.

Attaching a Path to a LightningWork will automatically make it the `origin`.

"""
from lightning.app.core.work import LightningWork

@ -115,6 +117,7 @@ class Path(PathlibPath):
"""The name of the LightningWork where this path is being accessed.

By default, this is the same as the :attr:`origin_name`.

"""
from lightning.app.core.work import LightningWork

@ -125,6 +128,7 @@ class Path(PathlibPath):
"""The hash of this Path uniquely identifies the file path and the associated origin Work.

Returns ``None`` if the origin is not defined, i.e., this Path did not yet get attached to a LightningWork.

"""
if self._origin is None:
return None

@ -152,6 +156,7 @@ class Path(PathlibPath):
If you strictly want to check local existence only, use :meth:`exists_local` instead. If you strictly want
to check existence on the remote (regardless of whether the file exists locally or not), use
:meth:`exists_remote`.

"""
return self.exists_local() or (self._origin and self.exists_remote())

@ -164,6 +169,7 @@ class Path(PathlibPath):

Raises:
RuntimeError: If the path is not attached to any Work (origin undefined).

"""
# Fail early if we need to check the remote but an origin is not defined
if not self._origin or self._request_queue is None or self._response_queue is None:

@ -272,6 +278,7 @@ class Path(PathlibPath):

Args:
work: LightningWork to be attached to this Path.

"""
if self._origin is None:
# Can become an owner only if there is not already one
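
The existence semantics above condense to one rule: a Path exists if it exists locally, or if it has an origin Work and exists remotely. A toy sketch with stand-in methods:

from typing import Optional

class SharedPath:
    def __init__(self, local: bool, origin: Optional[str], remote: bool):
        self._local, self._origin, self._remote = local, origin, remote

    def exists_local(self) -> bool:
        return self._local

    def exists_remote(self) -> bool:
        if self._origin is None:
            raise RuntimeError("Path is not attached to any Work")
        return self._remote

    def exists(self) -> bool:
        return self.exists_local() or (self._origin is not None and self.exists_remote())
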
@ -374,11 +381,11 @@ def _is_lit_path(path: Union[str, Path]) -> bool:

def _shared_local_mount_path() -> pathlib.Path:
"""Returns the shared directory through which the Copier threads move files from one Work filesystem to
another.
"""Returns the shared directory through which the Copier threads move files from one Work filesystem to another.

The shared directory can be set via the environment variable ``SHARED_MOUNT_DIRECTORY`` and should be pointing to a
directory that all Works have mounted (shared filesystem).

"""
path = pathlib.Path(os.environ.get("SHARED_MOUNT_DIRECTORY", ".shared"))
path.mkdir(parents=True, exist_ok=True)

@ -397,6 +404,7 @@ def _shared_storage_path() -> pathlib.Path:
The shared path gets set by the environment. Locally, it is pointing to a directory determined by the
``SHARED_MOUNT_DIRECTORY`` environment variable. In the cloud, the shared path will point to a S3 bucket. All Works
have access to this shared dropbox.

"""
storage_path = os.getenv("LIGHTNING_STORAGE_PATH", "")
if storage_path != "":
@ -61,6 +61,7 @@ class _BasePayload(ABC):
"""The hash of this Payload uniquely identifies the payload and the associated origin Work.

Returns ``None`` if the origin is not defined, i.e., this Payload did not yet get attached to a LightningWork.

"""
if self._origin is None:
return None

@ -72,6 +73,7 @@ class _BasePayload(ABC):
"""The name of the LightningWork where this payload was first created.

Attaching a Payload to a LightningWork will automatically make it the `origin`.

"""
from lightning.app.core.work import LightningWork

@ -82,6 +84,7 @@ class _BasePayload(ABC):
"""The name of the LightningWork where this payload is being accessed.

By default, this is the same as the :attr:`origin_name`.

"""
from lightning.app.core.work import LightningWork

@ -107,6 +110,7 @@ class _BasePayload(ABC):

Args:
work: LightningWork to be attached to this Payload.

"""
if self._origin is None:
# Can become an owner only if there is not already one

@ -130,6 +134,7 @@ class _BasePayload(ABC):

Raises:
RuntimeError: If the payload is not attached to any Work (origin undefined).

"""
# Fail early if we need to check the remote but an origin is not defined
if not self._origin or self._request_queue is None or self._response_queue is None:
@ -62,6 +62,7 @@ class _RunIf:
@pytest.mark.parametrize("arg1", [1, 2.0])
def test_wrapper(arg1):
assert arg1 > 0.0

"""

def __new__(

@ -155,6 +156,7 @@ class EmptyFlow(LightningFlow):
"""A LightningFlow that implements all abstract methods to do nothing.

Useful for mocking in tests.

"""

def run(self):

@ -165,6 +167,7 @@ class EmptyWork(LightningWork):
"""A LightningWork that implements all abstract methods to do nothing.

Useful for mocking in tests.

"""

def run(self):
@ -503,6 +503,7 @@ def delete_cloud_lightning_apps(name=None):
"""Cleanup cloud apps that start with the name test-{PR_NUMBER}-{TEST_APP_NAME}.

PR_NUMBER and TEST_APP_NAME are environment variables.

"""

client = LightningClient()
@ -44,6 +44,7 @@ def _extract_commands_from_file(file_name: str) -> CommandLines:
"""Extract all lines at the top of the file which contain commands to execute.

The return struct contains a list of commands to execute with the corresponding line number each command appears on.

"""
cl = CommandLines(
file=file_name,

@ -83,6 +84,7 @@ def _execute_app_commands(cl: CommandLines) -> None:
"""Open a subprocess shell to execute app commands.

The commands are executed in the same environment the calling app code is running in.

"""
for command, line_number in zip(cl.commands, cl.line_numbers):
logger.info(f"Running app setup command: {command}")

@ -116,6 +118,7 @@ def run_app_commands(file: str) -> None:
foo! bar <--- not a command
import lightning <--- not a command, end parsing.

where `echo "hello world" && pip install foo` would be executed in the current running environment.

"""
cl = _extract_commands_from_file(file_name=file)
if len(cl.commands) == 0:
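
A hedged sketch of this comment-command parsing: read the top of the entrypoint, collect `# !<command>` lines until the first non-comment line, then run them in the current environment. The exact marker syntax is an assumption:

import subprocess
from typing import List, Tuple

def extract_commands(file_name: str) -> List[Tuple[int, str]]:
    commands: List[Tuple[int, str]] = []
    with open(file_name, encoding="utf-8") as f:
        for line_number, line in enumerate(f, start=1):
            stripped = line.strip()
            if not stripped.startswith("#"):
                break  # e.g. `import lightning` ends parsing
            if "!" in stripped:
                commands.append((line_number, stripped.split("!", 1)[1].strip()))
    return commands

def run_commands(file_name: str) -> None:
    for line_number, command in extract_commands(file_name):
        print(f"Running app setup command: {command}")
        subprocess.run(command, shell=True, check=True)
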
@ -57,8 +57,7 @@ class StateEntry:

class StateStore(ABC):
"""Base class of State store that provides a simple key-value store to keep track of app state, served app
state."""
"""Base class of State store that provides a simple key-value store to keep track of app state, served app state."""

@abstractmethod
def __init__(self):

@ -352,6 +351,7 @@ def _walk_to_component(
"""Returns a generator that runs through the tree starting from the root down to the given component.

At each node, yields parent and child as a tuple.

"""
from lightning.app.structures import Dict, List

@ -469,6 +469,7 @@ def _load_state_dict(root_flow: "LightningFlow", state: Dict[str, Any], strict:
root_flow: The flow at the top of the component tree.
state: The collected state dict.
strict: Whether to validate all components have been re-created.

"""
# 1: Reload the state of the existing works
for w in root_flow.works():
@ -49,6 +49,7 @@ def _push_log_events_to_read_queue_callback(component_name: str, read_queue: que
"""Pushes _LogEvents from websocket to read_queue.

Returns callback function used with `on_message_callback` of websocket.WebSocketApp.

"""

def callback(ws_app: WebSocketApp, msg: str):

@ -42,6 +42,7 @@ def _credential_string_to_basic_auth_params(credential_string: str) -> Dict[str,
"""Returns the name/ID pair for each given Secret name.

Raises a `ValueError` if any of the given Secret names do not exist.

"""
if credential_string.count(":") != 1:
raise ValueError(
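
A short sketch of the `username:password` validation implied above (the error wording is illustrative):

from typing import Dict

def basic_auth_params(credential_string: str) -> Dict[str, str]:
    if credential_string.count(":") != 1:
        raise ValueError(
            "Credential string must follow the format username:password; "
            f"the provided one ('{credential_string}') does not."
        )
    username, password = credential_string.split(":")
    return {"username": username, "password": password}
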
@ -129,6 +129,7 @@ class _LightningAppOpenAPIRetriever:
Arguments:
app_id_or_name_or_url: An identifier for the app.
use_cache: Whether to load the openapi spec from the cache.

"""
self.app_id_or_name_or_url = app_id_or_name_or_url
self.url = None
@ -25,6 +25,7 @@ def _get_default_cluster(client: LightningClient, project_id: str) -> str:
"""This utility implements a minimal version of the cluster selection logic used in the cloud.

TODO: This should be requested directly from the platform.

"""
cluster_bindings = client.projects_service_list_project_cluster_bindings(project_id=project_id).clusters
@ -36,6 +36,7 @@ def _convert_paths_after_init(root: "LightningFlow"):
This is necessary because at the time of instantiating the component, its full affiliation is not known and Paths
that get passed to other components during ``__init__`` are otherwise not able to reference their origin or
consumer.

"""
from lightning.app import LightningFlow, LightningWork
from lightning.app.storage import Path

@ -52,6 +53,7 @@ def _sanitize_state(state: Dict[str, Any]) -> Dict[str, Any]:
"""Utility function to sanitize the state of a component.

Sanitization enables the state to be deep-copied and hashed.

"""
from lightning.app.storage import Drive, Path
from lightning.app.storage.payload import _BasePayload

@ -132,6 +134,7 @@ def _context(ctx: str) -> Generator[None, None, None]:

The context is used to determine whether the current process is running for a LightningFlow or for a LightningWork.
See also :func:`_get_context`, :func:`_set_context`. For internal use only.

"""
prev = _get_context()
_set_context(ctx)
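
The context switch above follows the standard set-yield-restore pattern; a self-contained sketch where `_CONTEXT` stands in for Lightning's internal state:

from contextlib import contextmanager
from typing import Generator, Optional

_CONTEXT: Optional[str] = None

@contextmanager
def context(ctx: str) -> Generator[None, None, None]:
    global _CONTEXT
    prev = _CONTEXT
    _CONTEXT = ctx  # e.g. "flow" or "work"
    try:
        yield
    finally:
        _CONTEXT = prev  # always restore, even if the body raised
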
@ -29,6 +29,7 @@ class AttributeDict(Dict):
"key2": abc
"my-key": 3.14
"new_key": 42

"""

def __getattr__(self, key: str) -> Optional[Any]:
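
A compact sketch of the idea: a dict whose keys are also readable and writable as attributes (a minimal version, without the pretty-printing shown in the docstring):

from typing import Any, Optional

class AttrDict(dict):
    def __getattr__(self, key: str) -> Optional[Any]:
        try:
            return self[key]
        except KeyError as exc:
            raise AttributeError(key) from exc

    def __setattr__(self, key: str, value: Any) -> None:
        self[key] = value

d = AttrDict(key1=1)
d.new_key = 42  # attribute writes land in the dict
assert d["new_key"] == 42 and d.key1 == 1
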
@ -29,6 +29,7 @@ class _ApiExceptionHandler(Group):

However, if the ApiException cannot be decoded, or is not
a 4xx error, the original ApiException will be re-raised.

"""

def invoke(self, ctx: Context) -> Any:

@ -81,6 +82,7 @@ class LightningPlatformException(Exception):  # pragma: no cover

It gets raised by the Lightning Launcher on the platform side when the app is running in the cloud, and is useful
when framework or user code needs to catch exceptions specific to the platform, e.g., when resources exceed quotas.

"""
@ -26,6 +26,7 @@ def execute_git_command(args: List[str], cwd=None) -> str:
-------
output: str
String combining stdout and stderr.

"""
process = subprocess.run(["git"] + args, capture_output=True, text=True, cwd=cwd, check=False)
return process.stdout.strip() + process.stderr.strip()

@ -61,6 +62,7 @@ def check_if_remote_head_is_different() -> Union[bool, None]:

This only compares the local SHA to the HEAD commit of a given branch. This check won't be used if the user isn't at
a HEAD locally.

"""
# Check SHA values.
local_sha = execute_git_command(["rev-parse", "@"])

@ -78,6 +80,7 @@ def has_uncommitted_files() -> bool:
"""Checks if the user has uncommitted files in the local repository.

If there are uncommitted files, then show a prompt indicating that uncommitted files exist locally.

"""
files = execute_git_command(["update-index", "--refresh"])
return bool(files)
@ -32,6 +32,7 @@ def _get_extras(extras: str) -> str:
"""Get the given extras as a space delimited string.

Used by the platform to install cloud extras in the cloud.

"""
from lightning.app import __package_name__
@ -22,13 +22,14 @@ if TYPE_CHECKING:

class LightningVisitor(ast.NodeVisitor):
"""Base class for visitor that finds class definitions based on class inheritance. Derived classes are expected
to define class_name and implement the analyze_class_def method.
"""Base class for visitor that finds class definitions based on class inheritance. Derived classes are expected to
define class_name and implement the analyze_class_def method.

Attributes
----------
class_name: str
Name of class to identify, to be defined in subclasses.

"""

class_name: Optional[str] = None
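
A self-contained sketch of this visitor pattern: an ast.NodeVisitor subclass that records classes inheriting from a named base (the sample source is illustrative):

import ast
from typing import List

class BaseClassVisitor(ast.NodeVisitor):
    class_name = "LightningModule"  # name of the base class to identify

    def __init__(self) -> None:
        self.found: List[str] = []

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        bases = [b.id for b in node.bases if isinstance(b, ast.Name)]
        if self.class_name in bases:
            self.found.append(node.name)
        self.generic_visit(node)

source = "class MyModel(LightningModule):\n    pass\n"
visitor = BaseClassVisitor()
visitor.visit(ast.parse(source))
assert visitor.found == ["MyModel"]
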
@ -63,6 +64,7 @@ class LightningModuleVisitor(LightningVisitor):
Names of methods that are part of the LightningModule API.
hooks: Set[str]
Names of hooks that are part of the LightningModule API.

"""

class_name: Optional[str] = "LightningModule"

@ -132,6 +134,7 @@ class LightningDataModuleVisitor(LightningVisitor):
Name of class to identify.
methods: Set[str]
Names of methods that are part of the LightningDataModule API.

"""

class_name = "LightningDataModule"

@ -155,6 +158,7 @@ class LightningLoggerVisitor(LightningVisitor):
Name of class to identify.
methods: Set[str]
Names of methods that are part of the Logger API.

"""

class_name = "Logger"

@ -171,6 +175,7 @@ class LightningCallbackVisitor(LightningVisitor):
Name of class to identify.
methods: Set[str]
Names of methods that are part of the Callback API.

"""

class_name = "Callback"

@ -223,6 +228,7 @@ class LightningStrategyVisitor(LightningVisitor):
Name of class to identify.
methods: Set[str]
Names of methods that are part of the Strategy API.

"""

class_name = "Strategy"

@ -282,6 +288,7 @@ class Scanner:
glob_pattern: str
Glob pattern to use when looking for files in the path,
applied when path is a directory. Default is "**/*.py".

"""

# TODO: Finalize introspecting the methods from all the discovered methods.

@ -341,6 +348,7 @@ class Scanner:
List[Dict[str, Any]]
List of dicts containing all metadata required
to import modules found.

"""
modules_found: Dict[str, List[Dict[str, Any]]] = {}
@ -26,6 +26,7 @@ def _add_comment_to_literal_code(method, contains, comment):
"""Inspects a method's code and adds a message to it.

This is a nice-to-have, so if it fails for some reason, it shouldn't affect the program.

"""
try:
lines = inspect.getsource(method)
@ -61,6 +61,7 @@ def _load_objects_from_file(
raise_exception: If ``True`` exceptions will be raised, otherwise exceptions will trigger system exit.
mock_imports: If ``True`` imports of missing packages will be replaced with a mock. This can allow the object to
be loaded without installing dependencies.

"""

# Taken from StreamLit: https://github.com/streamlit/streamlit/blob/develop/lib/streamlit/script_runner.py#L313

@ -110,6 +111,7 @@ def load_app_from_file(
Arguments:
filepath: The path to the file containing the LightningApp.
raise_exception: If True, raise an exception if the app cannot be loaded.

"""
from lightning.app.core.app import LightningApp

@ -142,6 +144,7 @@ def open_python_file(filename):

In Python 3, we would like all files to be opened with utf-8 encoding. However, some authors like to specify PEP263
headers in their source files with their own encodings. In that case, we should respect the author's encoding.

"""
import tokenize
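
The standard library already covers this case; a sketch using tokenize.open, which honors a PEP 263 encoding comment and falls back to utf-8:

import tokenize

def open_python_file(filename: str):
    return tokenize.open(filename)  # respects `# -*- coding: ... -*-` headers
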
@ -204,6 +207,7 @@ def _patch_sys_path(append):

Args:
append: The value to append to the path.

"""
if append in sys.path:
yield
@ -60,6 +60,7 @@ class Auth:
Returns
----------
True if credentials are available.

"""
if not self.secrets_file.exists():
logger.debug("Credentials file not found.")

@ -117,6 +118,7 @@ class Auth:
Returns
----------
authorization header to use when authentication completes.

"""
if not self.load():
# First try to authenticate from env
@ -79,6 +79,7 @@ class _LightningLogsSocketAPI(_AuthTokenGetter):

Returns:
WebSocketApp of the wanted socket

"""
_token = self._get_api_token()
clean_ws_host = urlparse(self.api_client.configuration.host).netloc
@ -1353,6 +1353,7 @@ def get_unique_name():
'meek-ardinghelli-4506'
>>> get_unique_name()
'truthful-dijkstra-2286'

"""
adjective, surname, i = choice(_adjectives), choice(_surnames), randint(0, 9999)  # noqa: S311
return f"{adjective}-{surname}-{i}"
@ -95,6 +95,7 @@ def _configure_session() -> Session:
"""Configures the session for GET and POST requests.

It enables a generous retry strategy that waits for the application server to connect.

"""
retry_strategy = Retry(
# wait time between retries increases exponentially according to: backoff_factor * (2 ** (retry - 1))

@ -124,10 +125,10 @@ def _get_next_backoff_time(num_retries: int, backoff_value: float = 0.5) -> floa

def _retry_wrapper(self, func: Callable, max_tries: Optional[int] = None) -> Callable:
"""Returns the function decorated by a wrapper that retries the call several times if a connection error
occurs.
"""Returns the function decorated by a wrapper that retries the call several times if a connection error occurs.

The retries follow an exponential backoff.

"""

@wraps(func)
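
A hedged sketch of such a wrapper, using the backoff formula from the comment above (backoff_factor * (2 ** (retry - 1))); the retried exception type is an assumption:

import time
from functools import wraps
from typing import Callable, Optional

def retry_wrapper(func: Callable, max_tries: Optional[int] = None, backoff_value: float = 0.5) -> Callable:
    @wraps(func)
    def wrapped(*args, **kwargs):
        tries = 0
        while True:
            tries += 1
            try:
                return func(*args, **kwargs)
            except ConnectionError:
                if max_tries is not None and tries >= max_tries:
                    raise
                time.sleep(backoff_value * (2 ** (tries - 1)))  # exponential backoff
    return wrapped
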
@ -175,6 +176,7 @@ class LightningClient(GridRestClient):
Args:
retry: Whether API calls should follow a retry mechanism with exponential backoff.
max_tries: Maximum number of attempts (or -1 to retry forever).

"""

def __init__(self, retry: bool = True, max_tries: Optional[int] = None) -> None:

@ -275,5 +277,6 @@ class HTTPClient:

We enabled customisation here instead of just using `logger.debug` because HTTP logging can be very noisy, but
it is crucial for finding bugs when we have them.

"""
pass
@ -49,6 +49,7 @@ def create_openapi_object(json_obj: Dict, target: Any):

Lightning AI uses the target object to make new objects from the given JSON spec so the target must be a valid
object.

"""
if not isinstance(json_obj, dict):
raise TypeError("json_obj must be a dictionary")
@ -29,6 +29,7 @@ class AppConfig:

Args:
name: Optional name of the application. If not provided, auto-generates a new name.

"""

name: str = field(default_factory=get_unique_name)

@ -56,6 +57,7 @@ class AppConfig:

Args:
directory: Path to a folder which contains the '.lightning' config file to load.

"""
return cls.load_from_file(pathlib.Path(directory, _APP_CONFIG_FILENAME))

@ -65,6 +67,7 @@ def _get_config_file(source_path: Union[str, pathlib.Path]) -> pathlib.Path:

Args:
source_path: A path to a folder or a file.

"""
source_path = pathlib.Path(source_path).absolute()
if source_path.is_file():
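
A small sketch of this config pattern: a dataclass whose name falls back to a generated unique name via default_factory (the generator here mirrors the get_unique_name doctest shown earlier):

from dataclasses import dataclass, field
from random import choice, randint

def get_unique_name() -> str:
    adjectives, surnames = ["meek", "truthful"], ["ardinghelli", "dijkstra"]
    return f"{choice(adjectives)}-{choice(surnames)}-{randint(0, 9999)}"

@dataclass
class AppConfig:
    name: str = field(default_factory=get_unique_name)

print(AppConfig())        # auto-generated name, e.g. AppConfig(name='meek-dijkstra-4506')
print(AppConfig("demo"))  # explicit name wins
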