Add GPU Acceleration apple silicon examples (#18127)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
a7bbbcea9a
commit
4e3087951d
|
@ -26,7 +26,12 @@ class PyTorchServer(L.app.components.PythonServer):
|
|||
)
|
||||
|
||||
def setup(self):
|
||||
self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
if torch.cuda.is_available():
|
||||
self._device = torch.device("cuda:0")
|
||||
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
self._device = torch.device("mps")
|
||||
else:
|
||||
self._device = torch.device("cpu")
|
||||
self._model = torchvision.models.resnet18(pretrained=True).to(self._device)
|
||||
|
||||
def predict(self, requests: BatchRequestModel):
|
||||
|
|
|
@ -69,7 +69,12 @@ def main():
|
|||
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=workers)
|
||||
|
||||
# Decide which device we want to run on
|
||||
device = torch.device("cuda:0" if (torch.cuda.is_available() and num_gpus > 0) else "cpu")
|
||||
if torch.cuda.is_available() and num_gpus > 0:
|
||||
device = torch.device("cuda:0")
|
||||
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
device = torch.device("mps")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
|
||||
output_dir = Path("outputs-torch", time.strftime("%Y%m%d-%H%M%S"))
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
|
|
@ -56,7 +56,13 @@ def run(hparams):
|
|||
torch.manual_seed(hparams.seed)
|
||||
|
||||
use_cuda = torch.cuda.is_available()
|
||||
device = torch.device("cuda" if use_cuda else "cpu")
|
||||
use_mps = torch.backends.mps.is_available()
|
||||
if use_cuda:
|
||||
device = torch.device("cuda")
|
||||
elif use_mps:
|
||||
device = torch.device("mps")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
|
||||
transform = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])
|
||||
train_dataset = MNIST(DATASETS_PATH, train=True, download=True, transform=transform)
|
||||
|
|
Loading…
Reference in New Issue