Restructure Lite examples and add GAN (#16240)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Adrian Wälchli 2023-01-05 15:07:43 +01:00 committed by GitHub
parent 876574695a
commit b0c272e8b7
10 changed files with 660 additions and 55 deletions


@ -103,14 +103,19 @@ or use the :meth:`~lightning_fabric.fabric.Fabric.launch` method in a notebook.
|
That's it! You can now train on any device at any scale with a switch of a flag.
Check out our examples that use Fabric:
DDP with 8 GPUs and `torch.bfloat16 <https://pytorch.org/docs/1.10.0/generated/torch.Tensor.bfloat16.html>`_ precision:
- `Image Classification <https://github.com/Lightning-AI/lightning/blob/master/examples/fabric/image_classifier/README.md>`_
- `Generative Adversarial Network (GAN) <https://github.com/Lightning-AI/lightning/blob/master/examples/fabric/dcgan/README.md>`_
Here is how you run DDP with 8 GPUs and `torch.bfloat16 <https://pytorch.org/docs/1.10.0/generated/torch.Tensor.bfloat16.html>`_ precision:
.. code-block:: bash
lightning run model ./path/to/train.py --strategy=ddp --devices=8 --accelerator=cuda --precision="bf16"
`DeepSpeed Zero3 <https://www.deepspeed.ai/news/2021/03/07/zero3-offload.html>`_ with mixed precision:
Or `DeepSpeed Zero3 <https://www.deepspeed.ai/news/2021/03/07/zero3-offload.html>`_ with mixed precision:
.. code-block:: bash


@ -16,8 +16,8 @@ ______________________________________________________________________
We show how to accelerate your PyTorch code with [Lightning Fabric](https://pytorch-lightning.readthedocs.io/en/latest/starter/lightning_fabric.html) with minimal code changes.
You stay in full control of the training loop.
- [MNIST with vanilla PyTorch](fabric/image_classifier_1_pytorch.py)
- [MNIST with Lightning Fabric](fabric/image_classifier_2_fabric.py)
- [MNIST: Vanilla PyTorch vs. Fabric](fabric/image_classifier/README.md)
- [DCGAN: Vanilla PyTorch vs. Fabric](fabric/dcgan/README.md)
______________________________________________________________________


@ -0,0 +1,45 @@
## DCGAN
This is an example of a GAN (Generative Adversarial Network) that learns to generate realistic images of faces.
We show two code versions:
The first one is implemented in raw PyTorch, but isn't easy to scale.
The second one uses [Lightning Fabric](https://pytorch-lightning.readthedocs.io/en/stable/starter/lightning_fabric.html) to accelerate and scale the model.
Tip: You can easily inspect the difference between the two files with:
```bash
sdiff train_torch.py train_fabric.py
```
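If `sdiff` is not available on your system, Python's standard `difflib` module can produce a comparable unified diff. The following is a minimal sketch: the helper name `unified_diff_text` is hypothetical, and the two short snippets only stand in for the real files.

```python
import difflib


def unified_diff_text(a_text, b_text, a_name="train_torch.py", b_name="train_fabric.py"):
    """Return the unified diff between two strings as a list of lines."""
    return list(
        difflib.unified_diff(
            a_text.splitlines(),
            b_text.splitlines(),
            fromfile=a_name,
            tofile=b_name,
            lineterm="",
        )
    )


# Stand-in snippets mirroring the generator update in the two scripts
torch_snippet = "err_g = criterion(output, label)\nerr_g.backward()\noptimizer_g.step()"
fabric_snippet = "err_g = criterion(output, label)\nfabric.backward(err_g)\noptimizer_g.step()"

for line in unified_diff_text(torch_snippet, fabric_snippet):
    print(line)
```

To compare the actual scripts, read them with `pathlib.Path("train_torch.py").read_text()` and pass the resulting strings to the helper.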
| Real | Generated |
| :------------------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------: |
| ![sample-data](https://user-images.githubusercontent.com/5495193/206484557-2e9e3810-a9c8-4ae0-bc6e-126866fef4f0.png) | ![fake-7914](https://user-images.githubusercontent.com/5495193/206484621-5dc4a9a6-c782-4c71-8e80-27580cdcc7e6.png) |
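For readers new to GANs, the objective noted in the comments of both training scripts (the discriminator maximizes log(D(x)) + log(1 - D(G(z))), the generator maximizes log(D(G(z)))) can be illustrated with a plain-Python binary cross-entropy. This is a scalar sketch only: the scripts apply `nn.BCELoss` to whole batches, and the helper names below are hypothetical.

```python
import math


def bce(prediction, target):
    """Binary cross-entropy for one probability strictly between 0 and 1."""
    return -(target * math.log(prediction) + (1 - target) * math.log(1 - prediction))


def discriminator_loss(d_real, d_fake):
    # Real images are labeled 1.0 and generated images 0.0, so minimizing this
    # sum is the same as maximizing log(D(x)) + log(1 - D(G(z))).
    return bce(d_real, 1.0) + bce(d_fake, 0.0)


def generator_loss(d_fake):
    # The generator labels its own samples as real (1.0), so minimizing this
    # is the same as maximizing log(D(G(z))).
    return bce(d_fake, 1.0)


# A discriminator that separates real from fake well has a low loss,
# while the generator's loss is high because its samples are rejected.
print(f"Loss_D: {discriminator_loss(0.9, 0.1):.4f}")
print(f"Loss_G: {generator_loss(0.1):.4f}")
```

As the generator improves, D(G(z)) moves toward 1 and `generator_loss` falls, which is the trend to look for in the `Loss_G` values the scripts print.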
### Run
**Raw PyTorch:**
```commandline
python train_torch.py
```
**Accelerated using Lightning Fabric:**
```commandline
python train_fabric.py
```
Generated images are saved to a timestamped subfolder of `outputs-torch` or `outputs-fabric`, depending on the script you run.
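To locate the samples from the most recent run, a small helper can be handy. The sketch below assumes the layout used by the training scripts: a run folder named with a `%Y%m%d-%H%M%S` timestamp under `outputs-fabric` or `outputs-torch`, containing `fake-<iteration>.png` files. The function name `latest_run_images` is hypothetical.

```python
from pathlib import Path


def latest_run_images(root):
    """Return the generated images of the newest run, ordered by iteration."""
    root = Path(root)
    if not root.is_dir():
        return []
    # Timestamped folder names sort chronologically when sorted as strings
    runs = sorted(p for p in root.iterdir() if p.is_dir())
    if not runs:
        return []
    # Sort numerically: "fake-10000.png" must come after "fake-0500.png"
    return sorted(runs[-1].glob("fake-*.png"), key=lambda p: int(p.stem.split("-")[1]))


for image_path in latest_run_images("outputs-fabric"):
    print(image_path)
```

The numeric sort key matters because the iteration counter is zero-padded to only four digits, so longer numbers would otherwise sort out of order.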
### Notes
The CelebA dataset is hosted on Google Drive by its authors, and downloads are rate-limited.
If you see a message saying that the daily quota has been reached,
[manually download the data](https://drive.google.com/drive/folders/0B7EVK8r0v71pWEZsZE9oNnFzTm8?resourcekey=0-5BR16BdXnb8hVj6CNHKzLg)
through your browser.
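After a manual download, it can help to check that the files are where the dataset class will look for them. The sketch below relies on an assumption about torchvision's `CelebA` implementation, namely that it expects the files listed in `EXPECTED_FILES` inside a `celeba` folder under the dataset root (`data/` in the training scripts). Verify the names against your installed torchvision version. The helper name is hypothetical.

```python
from pathlib import Path

# Assumption: file names torchvision's CelebA dataset looks for in <root>/celeba
EXPECTED_FILES = (
    "img_align_celeba.zip",
    "list_attr_celeba.txt",
    "identity_CelebA.txt",
    "list_bbox_celeba.txt",
    "list_landmarks_align_celeba.txt",
    "list_eval_partition.txt",
)


def missing_celeba_files(root="data/"):
    """Return the expected CelebA files that are not present under root/celeba."""
    folder = Path(root) / "celeba"
    return [name for name in EXPECTED_FILES if not (folder / name).is_file()]


missing = missing_celeba_files("data/")
if missing:
    print("Missing files:", ", ".join(missing))
else:
    print("All expected CelebA files found.")
```

The check only confirms that the files exist. It does not verify checksums, so a truncated download would still pass.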
### References
- [DCGAN Tutorial](https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html)
- [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](https://arxiv.org/abs/1511.06434)
- [Large-scale CelebFaces Attributes (CelebA) Dataset](https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html)


@ -0,0 +1,268 @@
"""
DCGAN - Accelerated with Lightning Fabric

Code adapted from the official PyTorch DCGAN tutorial:
https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
"""
import os
import time
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.utils
from torchvision.datasets import CelebA

from lightning.fabric import Fabric, seed_everything

# Root directory for dataset
dataroot = "data/"
# Number of workers for dataloader
workers = os.cpu_count()
# Batch size during training
batch_size = 128
# Spatial size of training images
image_size = 64
# Number of channels in the training images
nc = 3
# Size of z latent vector (i.e. size of generator input)
nz = 100
# Size of feature maps in generator
ngf = 64
# Size of feature maps in discriminator
ndf = 64
# Number of training epochs
num_epochs = 5
# Learning rate for optimizers
lr = 0.0002
# Beta1 hyperparameter for Adam optimizers
beta1 = 0.5
# Number of GPUs to use
num_gpus = 1


def main():
    # Set random seed for reproducibility
    seed_everything(999)

    fabric = Fabric(accelerator="auto", devices=1)
    fabric.launch()

    dataset = CelebA(
        root=dataroot,
        split="all",
        download=True,
        transform=transforms.Compose(
            [
                transforms.Resize(image_size),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        ),
    )

    # Create the dataloader
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=workers)

    output_dir = Path("outputs-fabric", time.strftime("%Y%m%d-%H%M%S"))
    output_dir.mkdir(parents=True, exist_ok=True)

    # Plot some training images
    real_batch = next(iter(dataloader))
    torchvision.utils.save_image(
        real_batch[0][:64],
        output_dir / "sample-data.png",
        padding=2,
        normalize=True,
    )

    # Create the generator
    generator = Generator()

    # Apply the weights_init function to randomly initialize all weights
    generator.apply(weights_init)

    # Create the Discriminator
    discriminator = Discriminator()

    # Apply the weights_init function to randomly initialize all weights
    discriminator.apply(weights_init)

    # Initialize BCELoss function
    criterion = nn.BCELoss()

    # Create batch of latent vectors that we will use to visualize
    # the progression of the generator
    fixed_noise = torch.randn(64, nz, 1, 1, device=fabric.device)

    # Establish convention for real and fake labels during training
    real_label = 1.0
    fake_label = 0.0

    # Set up Adam optimizers for both G and D
    optimizer_d = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizer_g = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))

    discriminator, optimizer_d = fabric.setup(discriminator, optimizer_d)
    generator, optimizer_g = fabric.setup(generator, optimizer_g)
    dataloader = fabric.setup_dataloaders(dataloader)

    # Lists to keep track of progress
    losses_g = []
    losses_d = []
    iteration = 0

    # Training loop
    for epoch in range(num_epochs):
        for i, data in enumerate(dataloader, 0):
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            # (a) Train with all-real batch
            discriminator.zero_grad()
            real = data[0]
            b_size = real.size(0)
            label = torch.full((b_size,), real_label, dtype=torch.float, device=fabric.device)
            # Forward pass real batch through D
            output = discriminator(real).view(-1)
            # Calculate loss on all-real batch
            err_d_real = criterion(output, label)
            # Calculate gradients for D in backward pass
            fabric.backward(err_d_real)
            d_x = output.mean().item()

            # (b) Train with all-fake batch
            # Generate batch of latent vectors
            noise = torch.randn(b_size, nz, 1, 1, device=fabric.device)
            # Generate fake image batch with G
            fake = generator(noise)
            label.fill_(fake_label)
            # Classify all fake batch with D
            output = discriminator(fake.detach()).view(-1)
            # Calculate D's loss on the all-fake batch
            err_d_fake = criterion(output, label)
            # Calculate the gradients for this batch, accumulated (summed) with previous gradients
            fabric.backward(err_d_fake)
            d_g_z1 = output.mean().item()
            # Compute error of D as sum over the fake and the real batches
            err_d = err_d_real + err_d_fake
            # Update D
            optimizer_d.step()

            # (2) Update G network: maximize log(D(G(z)))
            generator.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            # Since we just updated D, perform another forward pass of all-fake batch through D
            output = discriminator(fake).view(-1)
            # Calculate G's loss based on this output
            err_g = criterion(output, label)
            # Calculate gradients for G
            fabric.backward(err_g)
            d_g_z2 = output.mean().item()
            # Update G
            optimizer_g.step()

            # Output training stats
            if i % 50 == 0:
                fabric.print(
                    f"[{epoch}/{num_epochs}][{i}/{len(dataloader)}]\t"
                    f"Loss_D: {err_d.item():.4f}\t"
                    f"Loss_G: {err_g.item():.4f}\t"
                    f"D(x): {d_x:.4f}\t"
                    f"D(G(z)): {d_g_z1:.4f} / {d_g_z2:.4f}"
                )

            # Save Losses for plotting later
            losses_g.append(err_g.item())
            losses_d.append(err_d.item())

            # Check how the generator is doing by saving G's output on fixed_noise
            if (iteration % 500 == 0) or ((epoch == num_epochs - 1) and (i == len(dataloader) - 1)):
                with torch.no_grad():
                    fake = generator(fixed_noise).detach().cpu()

                if fabric.is_global_zero:
                    torchvision.utils.save_image(
                        fake,
                        output_dir / f"fake-{iteration:04d}.png",
                        padding=2,
                        normalize=True,
                    )
                fabric.barrier()

            iteration += 1


def weights_init(m):
    # custom weights initialization called on netG and netD
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)


class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input):
        return self.main(input)


if __name__ == "__main__":
    main()


@ -0,0 +1,271 @@
"""
DCGAN - Raw PyTorch Implementation

Code adapted from the official PyTorch DCGAN tutorial:
https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
"""
import os
import random
import time
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.utils
from torchvision.datasets import CelebA

# Root directory for dataset
dataroot = "data/"
# Number of workers for dataloader
workers = os.cpu_count()
# Batch size during training
batch_size = 128
# Spatial size of training images
image_size = 64
# Number of channels in the training images
nc = 3
# Size of z latent vector (i.e. size of generator input)
nz = 100
# Size of feature maps in generator
ngf = 64
# Size of feature maps in discriminator
ndf = 64
# Number of training epochs
num_epochs = 5
# Learning rate for optimizers
lr = 0.0002
# Beta1 hyperparameter for Adam optimizers
beta1 = 0.5
# Number of GPUs to use
num_gpus = 1


def main():
    # Set random seed for reproducibility
    seed = 999
    print("Random Seed: ", seed)
    random.seed(seed)
    torch.manual_seed(seed)

    dataset = CelebA(
        root=dataroot,
        split="all",
        download=True,
        transform=transforms.Compose(
            [
                transforms.Resize(image_size),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        ),
    )

    # Create the dataloader
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=workers)

    # Decide which device we want to run on
    device = torch.device("cuda:0" if (torch.cuda.is_available() and num_gpus > 0) else "cpu")

    output_dir = Path("outputs-torch", time.strftime("%Y%m%d-%H%M%S"))
    output_dir.mkdir(parents=True, exist_ok=True)

    # Plot some training images
    real_batch = next(iter(dataloader))
    torchvision.utils.save_image(
        real_batch[0][:64],
        output_dir / "sample-data.png",
        padding=2,
        normalize=True,
    )

    # Create the generator
    generator = Generator().to(device)

    # Handle multi-gpu if desired
    if (device.type == "cuda") and (num_gpus > 1):
        generator = nn.DataParallel(generator, list(range(num_gpus)))

    # Apply the weights_init function to randomly initialize all weights
    generator.apply(weights_init)

    # Create the Discriminator
    discriminator = Discriminator().to(device)

    # Handle multi-gpu if desired
    if (device.type == "cuda") and (num_gpus > 1):
        discriminator = nn.DataParallel(discriminator, list(range(num_gpus)))

    # Apply the weights_init function to randomly initialize all weights
    discriminator.apply(weights_init)

    # Initialize BCELoss function
    criterion = nn.BCELoss()

    # Create batch of latent vectors that we will use to visualize
    # the progression of the generator
    fixed_noise = torch.randn(64, nz, 1, 1, device=device)

    # Establish convention for real and fake labels during training
    real_label = 1.0
    fake_label = 0.0

    # Set up Adam optimizers for both G and D
    optimizer_d = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizer_g = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))

    # Lists to keep track of progress
    losses_g = []
    losses_d = []
    iteration = 0

    # Training loop
    for epoch in range(num_epochs):
        for i, data in enumerate(dataloader, 0):
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            # (a) Train with all-real batch
            discriminator.zero_grad()
            real_cpu = data[0].to(device)
            b_size = real_cpu.size(0)
            label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
            # Forward pass real batch through D
            output = discriminator(real_cpu).view(-1)
            # Calculate loss on all-real batch
            err_d_real = criterion(output, label)
            # Calculate gradients for D in backward pass
            err_d_real.backward()
            d_x = output.mean().item()

            # (b) Train with all-fake batch
            # Generate batch of latent vectors
            noise = torch.randn(b_size, nz, 1, 1, device=device)
            # Generate fake image batch with G
            fake = generator(noise)
            label.fill_(fake_label)
            # Classify all fake batch with D
            output = discriminator(fake.detach()).view(-1)
            # Calculate D's loss on the all-fake batch
            err_d_fake = criterion(output, label)
            # Calculate the gradients for this batch, accumulated (summed) with previous gradients
            err_d_fake.backward()
            d_g_z1 = output.mean().item()
            # Compute error of D as sum over the fake and the real batches
            err_d = err_d_real + err_d_fake
            # Update D
            optimizer_d.step()

            # (2) Update G network: maximize log(D(G(z)))
            generator.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            # Since we just updated D, perform another forward pass of all-fake batch through D
            output = discriminator(fake).view(-1)
            # Calculate G's loss based on this output
            err_g = criterion(output, label)
            # Calculate gradients for G
            err_g.backward()
            d_g_z2 = output.mean().item()
            # Update G
            optimizer_g.step()

            # Output training stats
            if i % 50 == 0:
                print(
                    f"[{epoch}/{num_epochs}][{i}/{len(dataloader)}]\t"
                    f"Loss_D: {err_d.item():.4f}\t"
                    f"Loss_G: {err_g.item():.4f}\t"
                    f"D(x): {d_x:.4f}\t"
                    f"D(G(z)): {d_g_z1:.4f} / {d_g_z2:.4f}"
                )

            # Save Losses for plotting later
            losses_g.append(err_g.item())
            losses_d.append(err_d.item())

            # Check how the generator is doing by saving G's output on fixed_noise
            if (iteration % 500 == 0) or ((epoch == num_epochs - 1) and (i == len(dataloader) - 1)):
                with torch.no_grad():
                    fake = generator(fixed_noise).detach().cpu()
                torchvision.utils.save_image(
                    fake,
                    output_dir / f"fake-{iteration:04d}.png",
                    padding=2,
                    normalize=True,
                )

            iteration += 1


def weights_init(m):
    # custom weights initialization called on netG and netD
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)


class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input):
        return self.main(input)


if __name__ == "__main__":
    main()


@ -4,13 +4,19 @@ Here are two MNIST classifiers implemented in PyTorch.
The first one is implemented in pure PyTorch, but isn't easy to scale.
The second one uses [Lightning Fabric](https://pytorch-lightning.readthedocs.io/en/stable/starter/lightning_fabric.html) to accelerate and scale the model.
Tip: You can easily inspect the difference between the two files with:
```bash
sdiff train_torch.py train_fabric.py
```
#### 1. Image Classifier with Vanilla PyTorch
Trains a simple CNN on MNIST using vanilla PyTorch. It only supports single-GPU training.
```bash
# CPU
python image_classifier_1_pytorch.py
python train_torch.py
```
______________________________________________________________________
@ -21,11 +27,11 @@ This script shows you how to scale the pure PyTorch code to enable GPU and multi
```bash
# CPU
lightning run model image_classifier_2_fabric.py
lightning run model train_fabric.py
# GPU (CUDA or M1 Mac)
lightning run model image_classifier_2_fabric.py --accelerator=gpu
lightning run model train_fabric.py --accelerator=gpu
# Multiple GPUs
lightning run model image_classifier_2_fabric.py --accelerator=gpu --devices=4
lightning run model train_fabric.py --accelerator=gpu --devices=4
```


@ -32,10 +32,10 @@ import argparse
from os import path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as T
from models import Net
from torch.optim.lr_scheduler import StepLR
from torchmetrics.classification import Accuracy
from torchvision.datasets import MNIST
@ -43,7 +43,33 @@ from torchvision.datasets import MNIST
from lightning.fabric import Fabric # import Fabric
from lightning.fabric import seed_everything
DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")
DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "..", "Datasets")
class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output
def run(hparams):


@ -15,18 +15,44 @@ import argparse
from os import path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as T
from models import Net
from torch.optim.lr_scheduler import StepLR
from torchvision.datasets import MNIST
DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")
DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "..", "Datasets")
# Credit to the PyTorch team
# Taken from https://github.com/pytorch/examples/blob/master/mnist/main.py and slightly adapted.
class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output
def run(hparams):
    torch.manual_seed(hparams.seed)


@ -1,42 +0,0 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


@ -18,4 +18,4 @@ export PYTHONPATH="${PYTHONPATH}:$(pwd)"
dir_path=$(dirname "${BASH_SOURCE[0]}")
args="--epochs=1"
python -m lightning_fabric.cli "${dir_path}/fabric/image_classifier_2_fabric.py" ${args} "$@"
python -m lightning_fabric.cli "${dir_path}/fabric/image_classifier/train_fabric.py" ${args} "$@"