Restructure Lite examples and add GAN (#16240)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
commit b0c272e8b7 (parent 876574695a)
@@ -103,14 +103,19 @@ or use the :meth:`~lightning_fabric.fabric.Fabric.launch` method in a notebook.

 That's it! You can now train on any device at any scale with a switch of a flag.

-DDP with 8 GPUs and `torch.bfloat16 <https://pytorch.org/docs/1.10.0/generated/torch.Tensor.bfloat16.html>`_ precision:
+Check out our examples that use Fabric:
+
+- `Image Classification <https://github.com/Lightning-AI/lightning/blob/master/examples/fabric/image_classifier/README.md>`_
+- `Generative Adversarial Network (GAN) <https://github.com/Lightning-AI/lightning/blob/master/examples/fabric/dcgan/README.md>`_
+
+Here is how you run DDP with 8 GPUs and `torch.bfloat16 <https://pytorch.org/docs/1.10.0/generated/torch.Tensor.bfloat16.html>`_ precision:

 .. code-block:: bash

     lightning run model ./path/to/train.py --strategy=ddp --devices=8 --accelerator=cuda --precision="bf16"

-`DeepSpeed Zero3 <https://www.deepspeed.ai/news/2021/03/07/zero3-offload.html>`_ with mixed precision:
+Or `DeepSpeed Zero3 <https://www.deepspeed.ai/news/2021/03/07/zero3-offload.html>`_ with mixed precision:

 .. code-block:: bash
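The DeepSpeed launch mirrors the DDP command above, swapping the strategy and precision flags. A sketch of that invocation (the exact values `deepspeed_stage_3` and `--precision=16` are assumptions based on the DDP example, not the verbatim docs line):

```bash
# assumed flags: strategy name and 16-bit precision are inferred from the DDP example above
lightning run model ./path/to/train.py --strategy=deepspeed_stage_3 --devices=8 --precision=16
```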
@@ -16,8 +16,8 @@ ______________________________________________________________________

 We show how to accelerate your PyTorch code with [Lightning Fabric](https://pytorch-lightning.readthedocs.io/en/latest/starter/lightning_fabric.html) while making only minimal code changes.
 You stay in full control of the training loop.

-- [MNIST with vanilla PyTorch](fabric/image_classifier_1_pytorch.py)
-- [MNIST with Lightning Fabric](fabric/image_classifier_2_fabric.py)
+- [MNIST: Vanilla PyTorch vs. Fabric](fabric/image_classifier/README.md)
+- [DCGAN: Vanilla PyTorch vs. Fabric](fabric/dcgan/README.md)

 ______________________________________________________________________
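The "minimal code changes" amount to the same handful of Fabric calls in every example linked above. A minimal sketch with a toy model (hypothetical names; based on the `lightning.fabric` API as used in the scripts in this commit):

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from lightning.fabric import Fabric

fabric = Fabric(accelerator="auto", devices=1)  # scale by changing these flags
fabric.launch()

model = nn.Linear(32, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# setup() moves the model/optimizer to the device and wraps them for the strategy
model, optimizer = fabric.setup(model, optimizer)

dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
dataloader = fabric.setup_dataloaders(DataLoader(dataset, batch_size=8))

for x, y in dataloader:  # batches arrive on the right device
    optimizer.zero_grad()
    loss = nn.functional.cross_entropy(model(x), y)
    fabric.backward(loss)  # replaces loss.backward()
    optimizer.step()
```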
@@ -0,0 +1,45 @@
## DCGAN

This is an example of a GAN (Generative Adversarial Network) that learns to generate realistic images of faces.
We show two versions of the code:
the first is implemented in raw PyTorch and isn't easy to scale;
the second uses [Lightning Fabric](https://pytorch-lightning.readthedocs.io/en/stable/starter/lightning_fabric.html) to accelerate and scale the model.

Tip: You can easily inspect the difference between the two files with:

```bash
sdiff train_torch.py train_fabric.py
```

| Real | Generated |
| :---: | :---: |
| ![sample-data](https://user-images.githubusercontent.com/5495193/206484557-2e9e3810-a9c8-4ae0-bc6e-126866fef4f0.png) | ![fake-7914](https://user-images.githubusercontent.com/5495193/206484621-5dc4a9a6-c782-4c71-8e80-27580cdcc7e6.png) |

### Run

**Raw PyTorch:**

```commandline
python train_torch.py
```

**Accelerated using Lightning Fabric:**

```commandline
python train_fabric.py
```

Generated images get saved to a timestamped run folder under _outputs-torch_ or _outputs-fabric_, respectively.

### Notes

The CelebA dataset is hosted by the authors through a Google Drive link, and the download quota is limited.
You may get a message saying that the daily quota was reached. In this case,
[manually download the data](https://drive.google.com/drive/folders/0B7EVK8r0v71pWEZsZE9oNnFzTm8?resourcekey=0-5BR16BdXnb8hVj6CNHKzLg)
through your browser.
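If you download manually, `torchvision.datasets.CelebA` expects the files in a `celeba/` subfolder of the dataset root (so `data/celeba/` with the `dataroot` used by these scripts); this is an assumption about current torchvision behavior worth verifying against your installed version:

```bash
# hypothetical layout after a manual browser download (dataroot = "data/")
mkdir -p data/celeba
mv ~/Downloads/img_align_celeba.zip ~/Downloads/list_*celeba*.txt data/celeba/
```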

### References

- [DCGAN Tutorial](https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html)
- [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](https://arxiv.org/abs/1511.06434)
- [Large-scale CelebFaces Attributes (CelebA) Dataset](https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html)
@@ -0,0 +1,268 @@
"""
DCGAN - Accelerated with Lightning Fabric

Code adapted from the official PyTorch DCGAN tutorial:
https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
"""
import os
import time
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.utils
from torchvision.datasets import CelebA

from lightning.fabric import Fabric, seed_everything

# Root directory for dataset
dataroot = "data/"
# Number of workers for dataloader
workers = os.cpu_count()
# Batch size during training
batch_size = 128
# Spatial size of training images
image_size = 64
# Number of channels in the training images
nc = 3
# Size of z latent vector (i.e. size of generator input)
nz = 100
# Size of feature maps in generator
ngf = 64
# Size of feature maps in discriminator
ndf = 64
# Number of training epochs
num_epochs = 5
# Learning rate for optimizers
lr = 0.0002
# Beta1 hyperparameter for Adam optimizers
beta1 = 0.5
# Number of GPUs to use
num_gpus = 1


def main():
    # Set random seed for reproducibility
    seed_everything(999)

    fabric = Fabric(accelerator="auto", devices=1)
    fabric.launch()
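    # launch() starts worker processes for multi-device strategies; with devices=1 it simply runs inline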

    dataset = CelebA(
        root=dataroot,
        split="all",
        download=True,
        transform=transforms.Compose(
            [
                transforms.Resize(image_size),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        ),
    )

    # Create the dataloader
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=workers)

    output_dir = Path("outputs-fabric", time.strftime("%Y%m%d-%H%M%S"))
    output_dir.mkdir(parents=True, exist_ok=True)

    # Plot some training images
    real_batch = next(iter(dataloader))
    torchvision.utils.save_image(
        real_batch[0][:64],
        output_dir / "sample-data.png",
        padding=2,
        normalize=True,
    )

    # Create the generator
    generator = Generator()

    # Apply the weights_init function to randomly initialize all weights
    generator.apply(weights_init)

    # Create the Discriminator
    discriminator = Discriminator()

    # Apply the weights_init function to randomly initialize all weights
    discriminator.apply(weights_init)

    # Initialize BCELoss function
    criterion = nn.BCELoss()

    # Create batch of latent vectors that we will use to visualize
    # the progression of the generator
    fixed_noise = torch.randn(64, nz, 1, 1, device=fabric.device)

    # Establish convention for real and fake labels during training
    real_label = 1.0
    fake_label = 0.0

    # Set up Adam optimizers for both G and D
    optimizer_d = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizer_g = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
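
    # setup() returns the module and optimizer wrapped for the chosen strategy and moved to the
    # target device; setup_dataloaders() adds a distributed sampler when needed and moves batches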
    discriminator, optimizer_d = fabric.setup(discriminator, optimizer_d)
    generator, optimizer_g = fabric.setup(generator, optimizer_g)
    dataloader = fabric.setup_dataloaders(dataloader)

    # Lists to keep track of progress
    losses_g = []
    losses_d = []
    iteration = 0

    # Training loop
    for epoch in range(num_epochs):
        for i, data in enumerate(dataloader, 0):
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            # (a) Train with all-real batch
            discriminator.zero_grad()
            real = data[0]
            b_size = real.size(0)
            label = torch.full((b_size,), real_label, dtype=torch.float, device=fabric.device)
            # Forward pass real batch through D
            output = discriminator(real).view(-1)
            # Calculate loss on all-real batch
            err_d_real = criterion(output, label)
            # Calculate gradients for D in backward pass
            fabric.backward(err_d_real)
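            # fabric.backward() stands in for err_d_real.backward() and applies any precision/strategy handling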
            d_x = output.mean().item()

            # (b) Train with all-fake batch
            # Generate batch of latent vectors
            noise = torch.randn(b_size, nz, 1, 1, device=fabric.device)
            # Generate fake image batch with G
            fake = generator(noise)
            label.fill_(fake_label)
            # Classify all fake batch with D
            output = discriminator(fake.detach()).view(-1)
            # Calculate D's loss on the all-fake batch
            err_d_fake = criterion(output, label)
            # Calculate the gradients for this batch, accumulated (summed) with previous gradients
            fabric.backward(err_d_fake)
            d_g_z1 = output.mean().item()
            # Compute error of D as sum over the fake and the real batches
            err_d = err_d_real + err_d_fake
            # Update D
            optimizer_d.step()

            # (2) Update G network: maximize log(D(G(z)))
            generator.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            # Since we just updated D, perform another forward pass of all-fake batch through D
            output = discriminator(fake).view(-1)
            # Calculate G's loss based on this output
            err_g = criterion(output, label)
            # Calculate gradients for G
            fabric.backward(err_g)
            d_g_z2 = output.mean().item()
            # Update G
            optimizer_g.step()

            # Output training stats
            if i % 50 == 0:
                fabric.print(
                    f"[{epoch}/{num_epochs}][{i}/{len(dataloader)}]\t"
                    f"Loss_D: {err_d.item():.4f}\t"
                    f"Loss_G: {err_g.item():.4f}\t"
                    f"D(x): {d_x:.4f}\t"
                    f"D(G(z)): {d_g_z1:.4f} / {d_g_z2:.4f}"
                )

            # Save Losses for plotting later
            losses_g.append(err_g.item())
            losses_d.append(err_d.item())

            # Check how the generator is doing by saving G's output on fixed_noise
            if (iteration % 500 == 0) or ((epoch == num_epochs - 1) and (i == len(dataloader) - 1)):
                with torch.no_grad():
                    fake = generator(fixed_noise).detach().cpu()

                if fabric.is_global_zero:
                    torchvision.utils.save_image(
                        fake,
                        output_dir / f"fake-{iteration:04d}.png",
                        padding=2,
                        normalize=True,
                    )
                fabric.barrier()
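                # barrier() keeps the ranks in sync so none runs ahead while rank zero writes the image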

            iteration += 1


def weights_init(m):
    # custom weights initialization called on netG and netD
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)


class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input):
        return self.main(input)


if __name__ == "__main__":
    main()
@@ -0,0 +1,271 @@
"""
DCGAN - Raw PyTorch Implementation

Code adapted from the official PyTorch DCGAN tutorial:
https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
"""
import os
import random
import time
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.utils
from torchvision.datasets import CelebA

# Root directory for dataset
dataroot = "data/"
# Number of workers for dataloader
workers = os.cpu_count()
# Batch size during training
batch_size = 128
# Spatial size of training images
image_size = 64
# Number of channels in the training images
nc = 3
# Size of z latent vector (i.e. size of generator input)
nz = 100
# Size of feature maps in generator
ngf = 64
# Size of feature maps in discriminator
ndf = 64
# Number of training epochs
num_epochs = 5
# Learning rate for optimizers
lr = 0.0002
# Beta1 hyperparameter for Adam optimizers
beta1 = 0.5
# Number of GPUs to use
num_gpus = 1


def main():
    # Set random seed for reproducibility
    seed = 999
    print("Random Seed: ", seed)
    random.seed(seed)
    torch.manual_seed(seed)

    dataset = CelebA(
        root=dataroot,
        split="all",
        download=True,
        transform=transforms.Compose(
            [
                transforms.Resize(image_size),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        ),
    )

    # Create the dataloader
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=workers)

    # Decide which device we want to run on
    device = torch.device("cuda:0" if (torch.cuda.is_available() and num_gpus > 0) else "cpu")

    output_dir = Path("outputs-torch", time.strftime("%Y%m%d-%H%M%S"))
    output_dir.mkdir(parents=True, exist_ok=True)

    # Plot some training images
    real_batch = next(iter(dataloader))
    torchvision.utils.save_image(
        real_batch[0][:64],
        output_dir / "sample-data.png",
        padding=2,
        normalize=True,
    )

    # Create the generator
    generator = Generator().to(device)

    # Handle multi-gpu if desired
    if (device.type == "cuda") and (num_gpus > 1):
        generator = nn.DataParallel(generator, list(range(num_gpus)))
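        # DataParallel replicates the module and splits each batch across visible GPUs in a single process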

    # Apply the weights_init function to randomly initialize all weights
    generator.apply(weights_init)

    # Create the Discriminator
    discriminator = Discriminator().to(device)

    # Handle multi-gpu if desired
    if (device.type == "cuda") and (num_gpus > 1):
        discriminator = nn.DataParallel(discriminator, list(range(num_gpus)))

    # Apply the weights_init function to randomly initialize all weights
    discriminator.apply(weights_init)

    # Initialize BCELoss function
    criterion = nn.BCELoss()

    # Create batch of latent vectors that we will use to visualize
    # the progression of the generator
    fixed_noise = torch.randn(64, nz, 1, 1, device=device)

    # Establish convention for real and fake labels during training
    real_label = 1.0
    fake_label = 0.0

    # Set up Adam optimizers for both G and D
    optimizer_d = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizer_g = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))

    # Lists to keep track of progress
    losses_g = []
    losses_d = []
    iteration = 0

    # Training loop
    for epoch in range(num_epochs):
        for i, data in enumerate(dataloader, 0):
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            # (a) Train with all-real batch
            discriminator.zero_grad()
            real_cpu = data[0].to(device)
            b_size = real_cpu.size(0)
            label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
            # Forward pass real batch through D
            output = discriminator(real_cpu).view(-1)
            # Calculate loss on all-real batch
            err_d_real = criterion(output, label)
            # Calculate gradients for D in backward pass
            err_d_real.backward()
            d_x = output.mean().item()

            # (b) Train with all-fake batch
            # Generate batch of latent vectors
            noise = torch.randn(b_size, nz, 1, 1, device=device)
            # Generate fake image batch with G
            fake = generator(noise)
            label.fill_(fake_label)
            # Classify all fake batch with D
            output = discriminator(fake.detach()).view(-1)
            # Calculate D's loss on the all-fake batch
            err_d_fake = criterion(output, label)
            # Calculate the gradients for this batch, accumulated (summed) with previous gradients
            err_d_fake.backward()
            d_g_z1 = output.mean().item()
            # Compute error of D as sum over the fake and the real batches
            err_d = err_d_real + err_d_fake
            # Update D
            optimizer_d.step()

            # (2) Update G network: maximize log(D(G(z)))
            generator.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            # Since we just updated D, perform another forward pass of all-fake batch through D
            output = discriminator(fake).view(-1)
            # Calculate G's loss based on this output
            err_g = criterion(output, label)
            # Calculate gradients for G
            err_g.backward()
            d_g_z2 = output.mean().item()
            # Update G
            optimizer_g.step()

            # Output training stats
            if i % 50 == 0:
                print(
                    f"[{epoch}/{num_epochs}][{i}/{len(dataloader)}]\t"
                    f"Loss_D: {err_d.item():.4f}\t"
                    f"Loss_G: {err_g.item():.4f}\t"
                    f"D(x): {d_x:.4f}\t"
                    f"D(G(z)): {d_g_z1:.4f} / {d_g_z2:.4f}"
                )

            # Save Losses for plotting later
            losses_g.append(err_g.item())
            losses_d.append(err_d.item())

            # Check how the generator is doing by saving G's output on fixed_noise
            if (iteration % 500 == 0) or ((epoch == num_epochs - 1) and (i == len(dataloader) - 1)):
                with torch.no_grad():
                    fake = generator(fixed_noise).detach().cpu()
                torchvision.utils.save_image(
                    fake,
                    output_dir / f"fake-{iteration:04d}.png",
                    padding=2,
                    normalize=True,
                )

            iteration += 1


def weights_init(m):
    # custom weights initialization called on netG and netD
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)


class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input):
        return self.main(input)


if __name__ == "__main__":
    main()
@@ -4,13 +4,19 @@ Here are two MNIST classifiers implemented in PyTorch.

 The first one is implemented in pure PyTorch, but isn't easy to scale.
 The second one uses [Lightning Fabric](https://pytorch-lightning.readthedocs.io/en/stable/starter/lightning_fabric.html) to accelerate and scale the model.

+Tip: You can easily inspect the difference between the two files with:
+
+```bash
+sdiff train_torch.py train_fabric.py
+```
+
 #### 1. Image Classifier with Vanilla PyTorch

 Trains a simple CNN over MNIST using vanilla PyTorch. It only supports single-GPU training.

 ```bash
 # CPU
-python image_classifier_1_pytorch.py
+python train_torch.py
 ```

 ______________________________________________________________________
@@ -21,11 +27,11 @@ This script shows you how to scale the pure PyTorch code to enable GPU and multi

 ```bash
 # CPU
-lightning run model image_classifier_2_fabric.py
+lightning run model train_fabric.py

 # GPU (CUDA or M1 Mac)
-lightning run model image_classifier_2_fabric.py --accelerator=gpu
+lightning run model train_fabric.py --accelerator=gpu

 # Multiple GPUs
-lightning run model image_classifier_2_fabric.py --accelerator=gpu --devices=4
+lightning run model train_fabric.py --accelerator=gpu --devices=4
 ```
@@ -32,10 +32,10 @@ import argparse
 from os import path

 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
 import torchvision.transforms as T
-from models import Net
 from torch.optim.lr_scheduler import StepLR
 from torchmetrics.classification import Accuracy
 from torchvision.datasets import MNIST
@@ -43,7 +43,33 @@ from torchvision.datasets import MNIST
 from lightning.fabric import Fabric  # import Fabric
 from lightning.fabric import seed_everything

-DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")
+DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "..", "Datasets")
+
+
+class Net(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 32, 3, 1)
+        self.conv2 = nn.Conv2d(32, 64, 3, 1)
+        self.dropout1 = nn.Dropout(0.25)
+        self.dropout2 = nn.Dropout(0.5)
+        self.fc1 = nn.Linear(9216, 128)
+        self.fc2 = nn.Linear(128, 10)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = F.relu(x)
+        x = self.conv2(x)
+        x = F.relu(x)
+        x = F.max_pool2d(x, 2)
+        x = self.dropout1(x)
+        x = torch.flatten(x, 1)
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.dropout2(x)
+        x = self.fc2(x)
+        output = F.log_softmax(x, dim=1)
+        return output


 def run(hparams):
@@ -15,18 +15,44 @@ import argparse
 from os import path

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
 import torchvision.transforms as T
-from models import Net
 from torch.optim.lr_scheduler import StepLR
 from torchvision.datasets import MNIST

-DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")
+DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "..", "Datasets")
+
+
+# Credit to the PyTorch team
+# Taken from https://github.com/pytorch/examples/blob/master/mnist/main.py and slightly adapted.
+class Net(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 32, 3, 1)
+        self.conv2 = nn.Conv2d(32, 64, 3, 1)
+        self.dropout1 = nn.Dropout(0.25)
+        self.dropout2 = nn.Dropout(0.5)
+        self.fc1 = nn.Linear(9216, 128)
+        self.fc2 = nn.Linear(128, 10)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = F.relu(x)
+        x = self.conv2(x)
+        x = F.relu(x)
+        x = F.max_pool2d(x, 2)
+        x = self.dropout1(x)
+        x = torch.flatten(x, 1)
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.dropout2(x)
+        x = self.fc2(x)
+        output = F.log_softmax(x, dim=1)
+        return output


 def run(hparams):
     torch.manual_seed(hparams.seed)
@@ -1,42 +0,0 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output
@@ -18,4 +18,4 @@ export PYTHONPATH="${PYTHONPATH}:$(pwd)"
 dir_path=$(dirname "${BASH_SOURCE[0]}")

 args="--epochs=1"
-python -m lightning_fabric.cli "${dir_path}/fabric/image_classifier_2_fabric.py" ${args} "$@"
+python -m lightning_fabric.cli "${dir_path}/fabric/image_classifier/train_fabric.py" ${args} "$@"