[pre-commit.ci] pre-commit suggestions (#16224)
* [pre-commit.ci] pre-commit suggestions updates:
  - [github.com/pre-commit/pre-commit-hooks: v4.3.0 → v4.4.0](https://github.com/pre-commit/pre-commit-hooks/compare/v4.3.0...v4.4.0)
  - [github.com/asottile/pyupgrade: v2.34.0 → v3.3.1](https://github.com/asottile/pyupgrade/compare/v2.34.0...v3.3.1)
  - https://github.com/myint/docformatter → https://github.com/PyCQA/docformatter
  - [github.com/PyCQA/docformatter: v1.4 → v1.5.1](https://github.com/PyCQA/docformatter/compare/v1.4...v1.5.1)
  - [github.com/asottile/yesqa: v1.3.0 → v1.4.0](https://github.com/asottile/yesqa/compare/v1.3.0...v1.4.0)
  - [github.com/PyCQA/isort: 5.10.1 → 5.11.4](https://github.com/PyCQA/isort/compare/5.10.1...5.11.4)
  - [github.com/psf/black: 22.6.0 → 22.12.0](https://github.com/psf/black/compare/22.6.0...22.12.0)
  - [github.com/executablebooks/mdformat: 0.7.14 → 0.7.16](https://github.com/executablebooks/mdformat/compare/0.7.14...0.7.16)
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Parent: 3326e65bb2
Commit: b59941cc52
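The hook bumps land in `.pre-commit-config.yaml` (first three hunks below); the remaining hunks are the auto fixes those hooks produced. As a minimal sketch, assuming the `pre-commit` CLI is installed locally, the two steps pre-commit.ci performs can be reproduced from the repository root:

```python
# Minimal sketch (not part of this commit): reproduce pre-commit.ci's two steps.
# Assumes the `pre-commit` CLI is installed in the active environment.
import subprocess

# 1. Bump hook revisions in .pre-commit-config.yaml to their latest tags
#    (the "pre-commit suggestions" part of this PR).
subprocess.run(["pre-commit", "autoupdate"], check=True)

# 2. Run every hook against every file (the "auto fixes" part). pre-commit
#    exits non-zero when hooks modify files, so don't treat that as an error.
subprocess.run(["pre-commit", "run", "--all-files"], check=False)
```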
```diff
@@ -23,7 +23,7 @@ ci:

repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.3.0
+    rev: v4.4.0
    hooks:
      - id: end-of-file-fixer
      - id: trailing-whitespace
```

```diff
@@ -49,33 +49,33 @@ repos:
      - id: detect-private-key

  - repo: https://github.com/asottile/pyupgrade
-    rev: v2.34.0
+    rev: v3.3.1
    hooks:
      - id: pyupgrade
        args: [--py37-plus]
        name: Upgrade code

-  - repo: https://github.com/myint/docformatter
+  - repo: https://github.com/PyCQA/docformatter
    rev: v1.4
    hooks:
      - id: docformatter
        args: [--in-place, --wrap-summaries=115, --wrap-descriptions=120]

  - repo: https://github.com/asottile/yesqa
-    rev: v1.3.0
+    rev: v1.4.0
    hooks:
      - id: yesqa
        name: Unused noqa

  - repo: https://github.com/PyCQA/isort
-    rev: 5.10.1
+    rev: 5.11.4
    hooks:
      - id: isort
        name: Format imports
        exclude: docs/source-app

  - repo: https://github.com/psf/black
-    rev: 22.6.0
+    rev: 22.12.0
    hooks:
      - id: black
        name: Format code
```

```diff
@@ -90,7 +90,7 @@ repos:
        exclude: docs/source-app

  - repo: https://github.com/executablebooks/mdformat
-    rev: 0.7.14
+    rev: 0.7.16
    hooks:
      - id: mdformat
        additional_dependencies:
```
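The biggest jump above is pyupgrade v2.34.0 → v3.3.1, still invoked with `--py37-plus`. As a hypothetical before/after (not taken from this diff), these are the kinds of rewrites that hook applies:

```python
# Hypothetical input module (illustrative only, not from this repository).
class Model(object):
    def __init__(self):
        super(Model, self).__init__()
        self.name = u"model"
        self.tags = set(["a", "b"])


# After `pyupgrade --py37-plus`: legacy idioms are rewritten in place.
class Model:
    def __init__(self):
        super().__init__()
        self.name = "model"
        self.tags = {"a", "b"}
```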
```diff
@@ -3,7 +3,6 @@ import lightning as L

# Step 1: Subclass LightningFlow component to define the app flow.
class HelloWorld(L.LightningFlow):
-
    # Step 2: Add the app logic to the LightningFlow run method to
    # ``print("Hello World!")`.
    # The LightningApp executes the run method of the main LightningFlow
```

```diff
@@ -32,7 +32,6 @@ def render_fn(state: AppState):

# Step 4: Implement a Static Web Frontend. This could be react, vue, etc.
class UIStatic(L.LightningFlow):
-
    # Step 5:
    def configure_layout(self):
        return StaticWebFrontend(os.path.join(os.path.dirname(__file__), "ui"))
```
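These two hunks only drop a blank line after the class statements. For context, the comments describe the hello-world pattern; a minimal runnable sketch of that pattern (assuming the `lightning` package is installed and the file is served with `lightning run app app.py`) looks roughly like:

```python
import lightning as L


class HelloWorld(L.LightningFlow):
    def run(self):
        # The LightningApp repeatedly executes the root flow's run method.
        print("Hello World!")


app = L.LightningApp(HelloWorld())
```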
```diff
@@ -36,12 +36,12 @@ First, let's define the component we need:
* A collection of model work to train all models in parallel.

.. literalinclude:: ../../../examples/app_dag/app.py
-    :lines: 55-79
+    :lines: 53-75

And its run method executes the steps described above.

.. literalinclude:: ../../../examples/app_dag/app.py
-    :lines: 80-103
+    :lines: 77-100

----

```

```diff
@@ -50,4 +50,4 @@ Step 2: Define the scheduling
*****************************

.. literalinclude:: ../../../examples/app_dag/app.py
-    :lines: 106-135
+    :lines: 103-132
```
```diff
@@ -90,7 +90,7 @@ class FileServer(L.LightningWork):
            "size": full_size,
            "drive_path": uploaded_file,
        }
-        with open(self.get_filepath(meta_file), "wt") as f:
+        with open(self.get_filepath(meta_file), "w") as f:
            json.dump(meta, f)

        # 5: Put the file to the drive.
```
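The `"wt"` → `"w"` change is purely cosmetic: `open()` uses text mode by default, so the explicit `t` adds nothing. A quick standalone check (hypothetical file path, illustrative values):

```python
import json
import os
import tempfile

meta = {"size": 123, "drive_path": "uploads/example.txt"}  # illustrative values
path = os.path.join(tempfile.gettempdir(), "example.meta.json")  # hypothetical path

# "w" already means write + text mode; "wt" is the same thing spelled out.
with open(path, "w") as f:
    json.dump(meta, f)
```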
```diff
@@ -163,7 +163,6 @@ from lightning import LightningWork


class TestFileServer(LightningWork):
-
    def __init__(self, drive: Drive):
        super().__init__(cache_calls=True)
        self.drive = drive
```

```diff
@@ -188,7 +187,6 @@ from lightning import LightningApp, LightningFlow


class Flow(LightningFlow):
-
    def __init__(self):
        super().__init__()
        # 1: Create a drive to share data between works
```
```diff
@@ -73,7 +73,6 @@ class GithubRepoRunner(TracerPythonScript):


class PyTorchLightningGithubRepoRunner(GithubRepoRunner):
-
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.best_model_path = None
```

```diff
@@ -8,7 +8,6 @@ from lightning.app.storage import Path


class MLServer(LightningWork):
-
    """This components uses SeldonIO MLServer library.

    The model endpoint: /v2/models/{MODEL_NAME}/versions/{VERSION}/infer.
```

```diff
@@ -7,7 +7,6 @@ from lightning.app.storage import Path


class TrainModel(LightningWork):
-
    """This component trains a Sklearn SVC model on digits dataset."""

    def __init__(self):
```

```diff
@@ -15,7 +15,6 @@ class LitDash(L.LightningWork):
        self.selected_year = None

    def run(self):
-
        df = pd.read_csv("https://raw.githubusercontent.com/plotly/datasets/master/gapminderDataFiveYear.csv")
        self.df = Payload(df)

```
```diff
@@ -6,7 +6,6 @@ from lightning.app.structures import Dict


class Flow(L.LightningFlow):
-
    def __init__(self):
        super().__init__()
        self.notebooks = Dict()
```

```diff
@@ -12,7 +12,6 @@ class RunNotebookConfig(BaseModel):


class RunNotebook(ClientCommand):
-
    description = "Run a Notebook."

    def run(self):
```

```diff
@@ -3,7 +3,6 @@ from lightning.app.api import Post


class Flow(L.LightningFlow):
-
    # 1. Define the state
    def __init__(self):
        super().__init__()
```

```diff
@@ -3,7 +3,6 @@ from lightning.app.api import Post


class Flow(L.LightningFlow):
-
    # 1. Define the state
    def __init__(self):
        super().__init__()
```

```diff
@@ -5,7 +5,6 @@ from lightning.app.api import Post


class Flow(L.LightningFlow):
-
    # 1. Define the state
    def __init__(self):
        super().__init__()
```
```diff
@@ -10,7 +10,6 @@ class CustomConfig(BaseModel):


class CustomCommand(ClientCommand):
-
    description = "A command with a client."

    def run(self):
```

```diff
@@ -5,7 +5,6 @@ from pytorch_lightning import Trainer


class PLTracerPythonScript(TracerPythonScript):
-
    """This component can be used for ANY PyTorch Lightning script to track its progress and extract its best model
    path."""

```

```diff
@@ -12,7 +12,6 @@ from lightning.app.components import ServeGradio
# Credit to @akhaliq for his inspiring work.
# Find his original code there: https://huggingface.co/spaces/akhaliq/AnimeGANv2/blob/main/app.py
class AnimeGANv2UI(ServeGradio):
-
    inputs = gr.inputs.Image(type="pil")
    outputs = gr.outputs.Image(type="pil")
    elon = "https://upload.wikimedia.org/wikipedia/commons/3/34/Elon_Musk_Royal_Society_%28crop2%29.jpg"
```
```diff
@@ -17,7 +17,6 @@ def get_path(path):


class GetDataWork(L.LightningWork):
-
    """This component is responsible to download some data and store them with a PayLoad."""

    def __init__(self):
```

```diff
@@ -34,7 +33,6 @@ class GetDataWork(L.LightningWork):


class ModelWork(L.LightningWork):
-
    """This component is receiving some data and train a sklearn model."""

    def __init__(self, model_path: str, parallel: bool):
```

```diff
@@ -53,7 +51,6 @@ class ModelWork(L.LightningWork):


class DAG(L.LightningFlow):
-
    """This component is a DAG."""

    def __init__(self, models_paths: list):
```
```diff
@@ -28,7 +28,6 @@ DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")
# Credit to the PyTorch team
# Taken from https://github.com/pytorch/examples/blob/master/mnist/main.py and slightly adapted.
def run(hparams):
-
    torch.manual_seed(hparams.seed)

    use_cuda = torch.cuda.is_available()
```

```diff
@@ -90,7 +90,6 @@ class ModelToProfile(LightningModule):


class CIFAR10DataModule(LightningDataModule):
-
    transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])

    def train_dataloader(self, *args, **kwargs):
```

```diff
@@ -73,7 +73,6 @@ class BaseKFoldDataModule(LightningDataModule, ABC):

@dataclass
class MNISTKFoldDataModule(BaseKFoldDataModule):
-
    train_dataset: Optional[Dataset] = None
    test_dataset: Optional[Dataset] = None
    train_fold: Optional[Dataset] = None
```
```diff
@@ -43,7 +43,6 @@ class LitModule(LightningModule):


class CIFAR10DataModule(LightningDataModule):
-
    transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])

    def train_dataloader(self, *args, **kwargs):
```

```diff
@@ -57,7 +56,6 @@ class CIFAR10DataModule(LightningDataModule):

@dataclass(unsafe_hash=True)
class Image:
-
    height: Optional[int] = None
    width: Optional[int] = None
    extension: str = "JPEG"
```
```diff
@@ -215,12 +215,10 @@ class Post(_HttpMethod):


class Get(_HttpMethod):
-
    pass


class Put(_HttpMethod):
-
    pass


```
```diff
@@ -9,7 +9,6 @@ logger = Logger(__name__)


def app(app_name: str) -> None:
-
    if app_name is None:
        app_name = _capture_valid_app_component_name(resource_type="app")

```

```diff
@@ -148,7 +148,6 @@ def gallery_component(name: str, yes_arg: bool, version_arg: str, cwd: str = Non


def non_gallery_component(gh_url: str, yes_arg: bool, cwd: str = None) -> None:
-
    # give the user the chance to do a manual install
    git_url = _show_non_gallery_install_component_prompt(gh_url, yes_arg)

```

```diff
@@ -157,7 +156,6 @@ def non_gallery_component(gh_url: str, yes_arg: bool, cwd: str = None) -> None:


def gallery_app(name: str, yes_arg: bool, version_arg: str, cwd: str = None, overwrite: bool = False) -> str:
-
    # make sure org/app-name syntax is correct
    org, app = _validate_name(name, resource_type="app", example="lightning/quick-start")

```

```diff
@@ -179,7 +177,6 @@ def gallery_app(name: str, yes_arg: bool, version_arg: str, cwd: str = None, ove


def non_gallery_app(gh_url: str, yes_arg: bool, cwd: str = None, overwrite: bool = False) -> None:
-
    # give the user the chance to do a manual install
    repo_url, folder_name = _show_non_gallery_install_app_prompt(gh_url, yes_arg)

```

```diff
@@ -39,7 +39,6 @@ def logs(app_name: str, components: List[str], follow: bool) -> None:


def _show_logs(app_name: str, components: List[str], follow: bool) -> None:
-
    client = LightningClient()
    project = _get_project(client)

```
```diff
@@ -1,5 +1,4 @@
-r"""
-To test a lightning component:
+r"""To test a lightning component:

1. Init the component.
2. call .run()
```
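This is docformatter's standard PEP 257 rewrite: the summary line moves up next to the opening quotes. A generic, hypothetical before/after (not from this repository):

```python
# Before docformatter
def add(a: int, b: int) -> int:
    """
    Add two integers.
    """
    return a + b


# After docformatter: the one-line summary sits on the same line as the quotes.
def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b
```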
```diff
@@ -30,7 +30,6 @@ logger = Logger(__name__)
# Required to avoid Uvicorn Server overriding Lightning App signal handlers.
# Discussions: https://github.com/encode/uvicorn/discussions/1708
class _DatabaseUvicornServer(uvicorn.Server):
-
    has_started_queue = None

    def run(self, sockets=None):
```

```diff
@@ -22,7 +22,6 @@ class _PyTorchSpawnWorkProtocol(Protocol):


class _PyTorchSpawnRunExecutor(WorkRunExecutor):
-
    enable_start_observer: bool = False

    def __call__(
```

```diff
@@ -22,7 +22,6 @@ class Code(TypedDict):


class TracerPythonScript(LightningWork):
-
    _start_method = "spawn"

    def on_before_run(self):
```
```diff
@@ -143,8 +143,8 @@ def _create_fastapi(title: str) -> FastAPI:


class _LoadBalancer(LightningWork):
-    r"""The LoadBalancer is a LightningWork component that collects the requests and sends them to the prediciton API
-    asynchronously using RoundRobin scheduling. It also performs auto batching of the incoming requests.
+    r"""The LoadBalancer is a LightningWork component that collects the requests and sends them to the prediciton
+    API asynchronously using RoundRobin scheduling. It also performs auto batching of the incoming requests.

    The LoadBalancer exposes system endpoints with a basic HTTP authentication, in order to activate the authentication
    you need to provide a system password from environment variable::
```
```diff
@@ -13,7 +13,6 @@ else:


class ServeGradio(LightningWork, abc.ABC):
-
    """The ServeGradio Class enables to quickly create a ``gradio`` based UI for your LightningApp.

    In the example below, the ``ServeGradio`` is subclassed to deploy ``AnimeGANv2``.
```

```diff
@@ -142,7 +142,6 @@ class Number(BaseModel):


class PythonServer(LightningWork, abc.ABC):
-
    _start_method = "spawn"

    @requires(["torch"])
```
```diff
@@ -417,7 +417,6 @@ def register_global_routes():


class LightningUvicornServer(uvicorn.Server):
-
    has_started_queue = None

    def run(self, sockets=None):
```

```diff
@@ -18,7 +18,6 @@ from lightning_app.utilities.packaging.cloud_compute import _maybe_create_cloud_


class LightningFlow:
-
    _INTERNAL_STATE_VARS = {
        # Internal protected variables that are still part of the state (even though they are prefixed with "_")
        "_paths",
```

```diff
@@ -38,7 +38,6 @@ if TYPE_CHECKING:


class LightningWork:
-
    _INTERNAL_STATE_VARS = (
        # Internal protected variables that are still part of the state (even though they are prefixed with "_")
        "_paths",
```

```diff
@@ -196,7 +196,6 @@ def _generate_works_json_gallery(filepath: str) -> str:

@dataclass
class CloudRuntime(Runtime):
-
    backend: Union[str, CloudBackend] = "cloud"

    def dispatch(
```
```diff
@@ -21,7 +21,6 @@ from lightning_app.utilities.port import disable_port

@dataclass
class MultiProcessRuntime(Runtime):
-
    """Runtime to launch the LightningApp into multiple processes.

    The MultiProcessRuntime will generate 1 process for each :class:`~lightning_app.core.work.LightningWork` and attach
```

```diff
@@ -156,7 +156,6 @@ class Runtime:
        raise NotImplementedError

    def _add_stopped_status_to_work(self, work: "lightning_app.LightningWork") -> None:
-
        if work.status.stage == WorkStageStatus.STOPPED:
            return

```
```diff
@@ -11,7 +11,6 @@ from lightning_app.utilities.component import _is_flow_context


class Drive:
-
    __IDENTIFIER__ = "__drive__"
    __PROTOCOLS__ = ["lit://"]

```

```diff
@@ -26,7 +26,6 @@ class Mount:
    mount_path: str = ""

    def __post_init__(self) -> None:
-
        for protocol in __MOUNT_PROTOCOLS__:
            if self.source.startswith(protocol):
                protocol = protocol
```

```diff
@@ -243,7 +243,6 @@ class _BasePayload(ABC):


class Payload(_BasePayload):
-
    """The Payload object enables to transfer python objects from one work to another in a similar fashion as
    :class:`~lightning_app.storage.path.Path`."""

```
```diff
@@ -407,7 +407,6 @@ class LightningJSONEncoder(json.JSONEncoder):


class Logger:
-
    """This class is used to improve the debugging experience."""

    def __init__(self, name: str):
```

```diff
@@ -36,7 +36,6 @@ def makedirs(path: str):


class ClientCommand:
-
    description: str = ""
    requirements: List[str] = []

```

```diff
@@ -46,7 +46,6 @@ def _get_extras(extras: str) -> str:


def requires(module_paths: Union[str, List]):
-
    if not isinstance(module_paths, list):
        module_paths = [module_paths]

```

```diff
@@ -269,7 +269,6 @@ class LightningProfilerVisitor(LightningVisitor):


class Scanner:
-
    """
    Finds relevant Lightning objects in files in the file system.
    Attributes
```
```diff
@@ -341,7 +341,6 @@ class ComponentDelta:

@dataclass
class WorkRunExecutor:
-
    work: "LightningWork"
    work_run: Callable
    delta_queue: "BaseQueue"
```

```diff
@@ -35,7 +35,6 @@ def headers_for(context: Dict[str, str]) -> Dict[str, str]:


class AppState:
-
    _APP_PRIVATE_KEYS: Tuple[str, ...] = (
        "_host",
        "_session_id",
```
```diff
@@ -413,8 +413,7 @@ class Fabric:
    def all_gather(
        self, data: Union[Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False
    ) -> Union[Tensor, Dict, List, Tuple]:
-        r"""
-        Gather tensors or collections of tensors from multiple processes.
+        r"""Gather tensors or collections of tensors from multiple processes.

        Args:
            data: int, float, tensor of shape (batch, ...), or a (possibly nested) collection thereof.
```
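Only the docstring layout changes here; the signature shown above is unchanged. A minimal usage sketch, assuming a 2-process CPU launch (illustrative only, not taken from the Lightning docs):

```python
# Minimal sketch: gather one value per process. Assumes `lightning` and `torch`
# are installed and a 2-process CPU launch is acceptable on this machine.
import torch
from lightning.fabric import Fabric


def run(fabric: Fabric) -> None:
    # Each rank contributes a 1-element tensor; all_gather stacks the results
    # along a new leading world-size dimension.
    local = torch.tensor([float(fabric.global_rank)])
    gathered = fabric.all_gather(local)
    fabric.print(gathered.shape)  # expected: torch.Size([2, 1])


if __name__ == "__main__":
    fabric = Fabric(accelerator="cpu", devices=2)
    fabric.launch(run)
```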
```diff
@@ -596,7 +596,6 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
        import deepspeed

        def load(module: torch.nn.Module, prefix: str = "") -> None:
-
            missing_keys: List[str] = []
            unexpected_keys: List[str] = []
            error_msgs: List[str] = []
```
```diff
@@ -25,8 +25,7 @@ _HYDRA_AVAILABLE = RequirementCache("hydra-core")


class _SubprocessScriptLauncher(_Launcher):
-    r"""
-    A process laucher that invokes the current script as many times as desired in a single node.
+    r"""A process laucher that invokes the current script as many times as desired in a single node.

    This launcher needs to be invoked on each node.
    In its default behavior, the main process in each node then spawns N-1 child processes via :func:`subprocess.Popen`,
```
```diff
@@ -1779,7 +1779,7 @@ class LightningModule(
            split_x: Union[Tensor, List[Tensor]]
            if isinstance(x, Tensor):
                split_x = x[:, t : t + split_size]
-            elif isinstance(x, collections.Sequence):
+            elif isinstance(x, collections.abc.Sequence):
                split_x = [x[batch_idx][t : t + split_size] for batch_idx in range(len(x))]

            batch_split.append(split_x)
```
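The alias `collections.Sequence` has been deprecated since Python 3.3 and was removed in Python 3.10, so the check must go through `collections.abc`. A quick standalone illustration:

```python
from collections.abc import Sequence

# On Python 3.10+ the old alias is gone entirely:
#   AttributeError: module 'collections' has no attribute 'Sequence'
# The ABC in collections.abc covers the same isinstance checks.
print(isinstance((1, 2, 3), Sequence))  # True
print(isinstance("abc", Sequence))      # True (str is a sequence too)
print(isinstance({"a": 1}, Sequence))   # False
```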
```diff
@@ -25,7 +25,7 @@ def assert_opt_parameters_on_device(opt, device: str):
        # Not sure there are any global tensors in the state dict
        if isinstance(param, Tensor):
            assert param.data.device.type == device
-        elif isinstance(param, collections.Mapping):
+        elif isinstance(param, collections.abc.Mapping):
            for subparam in param.values():
                if isinstance(subparam, Tensor):
                    assert param.data.device.type == device
```
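Same fix for the Mapping ABC in this test helper; dict-like optimizer state entries still satisfy `collections.abc.Mapping`. A small sketch with hypothetical values (not from the test suite):

```python
from collections.abc import Mapping

# Hypothetical per-parameter optimizer state entry.
param_state = {"step": 10, "exp_avg": [0.1, 0.2, 0.3]}

if isinstance(param_state, Mapping):
    for name, subparam in param_state.items():
        print(name, type(subparam).__name__)
```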