Update README.md
This commit is contained in:
parent
bbbc111d52
commit
0474464c45
21
README.md
21
README.md
|
@ -183,22 +183,30 @@ trainer = pl.Trainer()
|
|||
trainer.fit(autoencoder, DataLoader(train), DataLoader(val))
|
||||
```
|
||||
|
||||
#### And without changing a single line of code, you could run on GPUs
|
||||
#### And without changing a single line of code, you could run on GPUs/TPUs
|
||||
```python
|
||||
# 8 GPUs
|
||||
trainer = Trainer(max_epochs=1, gpus=8)
|
||||
|
||||
# 256 GPUs
|
||||
trainer = Trainer(max_epochs=1, gpus=8, num_nodes=32)
|
||||
|
||||
# TPUs
|
||||
trainer = Trainer(tpu_cores=8)
|
||||
```
|
||||
|
||||
Or TPUs
|
||||
#### And even export for production via onnx or torchscript
|
||||
```python
|
||||
# Distributed TPU core training
|
||||
trainer = Trainer(tpu_cores=8)
|
||||
# torchscript
|
||||
autoencoder = LitAutoEncoder()
|
||||
torch.jit.save(autoencoder.to_torchscript(), "model.pt")
|
||||
|
||||
# Single TPU core training
|
||||
trainer = Trainer(tpu_cores=[1])
|
||||
# onnx
|
||||
with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmpfile:
|
||||
autoencoder = LitAutoEncoder()
|
||||
input_sample = torch.randn((1, 64))
|
||||
autoencoder.to_onnx(tmpfile.name, input_sample, export_params=True)
|
||||
os.path.isfile(tmpfile.name)
|
||||
```
|
||||
|
||||
#### For advanced users, you can still own complex training loops
|
||||
|
@ -218,7 +226,6 @@ class LitAutoEncoder(pl.LightningModule):
|
|||
self.manual_backward(loss_b, opt_b)
|
||||
opt_b.step()
|
||||
opt_b.zero_grad()
|
||||
|
||||
```
|
||||
---
|
||||
|
||||
|
|
Loading…
Reference in New Issue