lightning/tests/tests_pytorch/accelerators/test_registry.py

67 lines
2.3 KiB
Python
Raw Normal View History

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators import Accelerator, AcceleratorRegistry
def test_accelerator_registry_with_new_accelerator():
    """Register a custom accelerator, select it through the ``Trainer``, then unregister it.

    ``AcceleratorRegistry`` is module-global state that other tests also inspect (e.g. the exact
    ``available_accelerators()`` comparison), so the registration is undone in a ``finally``
    block: a failing assertion below must not leak ``"custom_accelerator"`` into later tests.
    """
    accelerator_name = "custom_accelerator"
    accelerator_description = "Custom Accelerator"

    class CustomAccelerator(Accelerator):
        def __init__(self, param1, param2):
            self.param1 = param1
            self.param2 = param2
            super().__init__()

        @staticmethod
        def parse_devices(devices):
            return devices

        @staticmethod
        def get_parallel_devices(devices):
            # One placeholder "device" per requested device; checked via the strategy below.
            return ["foo"] * devices

        @staticmethod
        def auto_device_count():
            return 3

        @staticmethod
        def is_available():
            return True

    # Keyword arguments beyond ``description`` are expected to be recorded as "init_params".
    AcceleratorRegistry.register(
        accelerator_name, CustomAccelerator, description=accelerator_description, param1="abc", param2=123
    )
    try:
        assert accelerator_name in AcceleratorRegistry
        assert AcceleratorRegistry[accelerator_name]["description"] == accelerator_description
        assert AcceleratorRegistry[accelerator_name]["init_params"] == {"param1": "abc", "param2": 123}
        assert AcceleratorRegistry[accelerator_name]["accelerator_name"] == accelerator_name
        assert isinstance(AcceleratorRegistry.get(accelerator_name), CustomAccelerator)

        trainer = Trainer(accelerator=accelerator_name, devices="auto")
        assert isinstance(trainer.accelerator, CustomAccelerator)
        # With devices="auto" the test expects auto_device_count() (3) placeholder devices.
        assert trainer.strategy.parallel_devices == ["foo"] * 3
    finally:
        # Always unregister, even on failure, so the global registry is left as we found it.
        AcceleratorRegistry.remove(accelerator_name)
    assert accelerator_name not in AcceleratorRegistry
def test_available_accelerators_in_registry():
    """The registry lists exactly these accelerator names, in this order, by default."""
    # Expected names last updated by "Merge different gpu backends with accelerator='gpu'"
    # (#13642, 2022-07-25), which renamed GPUAccelerator to CUDAAccelerator and removed a
    # temporary registration.
    assert AcceleratorRegistry.available_accelerators() == ["cpu", "cuda", "hpu", "ipu", "mps", "tpu"]