From 4e3087951df3d986bceaa0b9703bd6fabff047f3 Mon Sep 17 00:00:00 2001 From: Nripesh Niketan <86844847+NripeshN@users.noreply.github.com> Date: Fri, 28 Jul 2023 05:22:51 +0530 Subject: [PATCH] Add GPU Acceleration apple silicon examples (#18127) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Adrian Wälchli --- examples/app/server_with_auto_scaler/app.py | 7 ++++++- examples/fabric/dcgan/train_torch.py | 7 ++++++- examples/fabric/image_classifier/train_torch.py | 8 +++++++- 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/examples/app/server_with_auto_scaler/app.py b/examples/app/server_with_auto_scaler/app.py index 8e0907b8f2..8761119634 100644 --- a/examples/app/server_with_auto_scaler/app.py +++ b/examples/app/server_with_auto_scaler/app.py @@ -26,7 +26,12 @@ class PyTorchServer(L.app.components.PythonServer): ) def setup(self): - self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + if torch.cuda.is_available(): + self._device = torch.device("cuda:0") + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + self._device = torch.device("mps") + else: + self._device = torch.device("cpu") self._model = torchvision.models.resnet18(pretrained=True).to(self._device) def predict(self, requests: BatchRequestModel): diff --git a/examples/fabric/dcgan/train_torch.py b/examples/fabric/dcgan/train_torch.py index 6362736107..d7ce548b16 100644 --- a/examples/fabric/dcgan/train_torch.py +++ b/examples/fabric/dcgan/train_torch.py @@ -69,7 +69,12 @@ def main(): dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=workers) # Decide which device we want to run on - device = torch.device("cuda:0" if (torch.cuda.is_available() and num_gpus > 0) else "cpu") + if torch.cuda.is_available() and num_gpus > 0: + device = torch.device("cuda:0") + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + device = torch.device("mps") + else: + device = torch.device("cpu") output_dir = Path("outputs-torch", time.strftime("%Y%m%d-%H%M%S")) output_dir.mkdir(parents=True, exist_ok=True) diff --git a/examples/fabric/image_classifier/train_torch.py b/examples/fabric/image_classifier/train_torch.py index e2bfd750f7..43f3128aba 100644 --- a/examples/fabric/image_classifier/train_torch.py +++ b/examples/fabric/image_classifier/train_torch.py @@ -56,7 +56,13 @@ def run(hparams): torch.manual_seed(hparams.seed) use_cuda = torch.cuda.is_available() - device = torch.device("cuda" if use_cuda else "cpu") + use_mps = torch.backends.mps.is_available() + if use_cuda: + device = torch.device("cuda") + elif use_mps: + device = torch.device("mps") + else: + device = torch.device("cpu") transform = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))]) train_dataset = MNIST(DATASETS_PATH, train=True, download=True, transform=transform)