yapf examples (#5709)
commit 21d313edc5 (parent 07f24d2438)
@@ -1,8 +1,5 @@
 .git/*
 
-# TODO
-pl_examples/*
-
 # TODO
 pytorch_lightning/*
 
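The hunk above edits what looks like a format-exclusion list: dropping the pl_examples/* entry is what brings the example code under the formatter, while pytorch_lightning/* stays excluded for now. As a rough sketch of the tooling involved (the style knobs below are assumptions, not taken from this diff), the same reformatting can be driven from Python through yapf's public API:

    # Sketch only: format a snippet the way this PR's tooling would (assumed style).
    from yapf.yapflib.yapf_api import FormatCode

    source = "nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))\n"

    # style_config accepts an inline style-dict string; these values are assumptions.
    formatted, changed = FormatCode(source, style_config='{based_on_style: pep8, column_limit: 120}')
    print(formatted, changed)
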
@@ -9,7 +9,6 @@ _DATASETS_PATH = os.path.join(_PACKAGE_ROOT, 'Datasets')
 _TORCHVISION_AVAILABLE = _module_available("torchvision")
 _DALI_AVAILABLE = _module_available("nvidia.dali")
 
-
 LIGHTNING_LOGO = """
                     ####
                 ###########
@@ -43,12 +43,12 @@ class LitAutoEncoder(pl.LightningModule):
         self.encoder = nn.Sequential(
             nn.Linear(28 * 28, 64),
             nn.ReLU(),
-            nn.Linear(64, 3)
+            nn.Linear(64, 3),
         )
         self.decoder = nn.Sequential(
             nn.Linear(3, 64),
             nn.ReLU(),
-            nn.Linear(64, 28 * 28)
+            nn.Linear(64, 28 * 28),
         )
 
     def forward(self, x):
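The trailing commas added in this hunk are more than cosmetic: with yapf's split_arguments_when_comma_terminated option enabled, a comma before the closing bracket tells the formatter to keep the elements split one per line rather than repacking the call. A minimal sketch of the effect (whether this repository enables that knob is an assumption):

    # Sketch only: how a magic trailing comma changes yapf's output (assumed style).
    from yapf.yapflib.yapf_api import FormatCode

    style = '{based_on_style: pep8, split_arguments_when_comma_terminated: true}'

    packed, _ = FormatCode("f(1, 2, 3)\n", style_config=style)   # stays on one line
    split, _ = FormatCode("f(1, 2, 3,)\n", style_config=style)   # trailing comma forces one arg per line
    print(packed)
    print(split)
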
@@ -36,6 +36,7 @@ class Backbone(torch.nn.Module):
       (l2): Linear(...)
     )
     """
+
     def __init__(self, hidden_dim=128):
         super().__init__()
         self.l1 = torch.nn.Linear(28 * 28, hidden_dim)
@@ -55,6 +56,7 @@ class LitClassifier(pl.LightningModule):
       (backbone): ...
     )
     """
+
     def __init__(self, backbone, learning_rate=1e-3):
         super().__init__()
         self.save_hyperparameters()
@@ -39,16 +39,17 @@ if _BOLTS_AVAILABLE:
     import pl_bolts
     from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
 
 #####################
 #      Modules      #
 #####################
 
 
 class Flatten(nn.Module):
+
     def forward(self, x):
         return x.view(x.size(0), -1)
 
 
 ###############################
 #       LightningModule       #
 ###############################
@@ -61,6 +62,7 @@ class LitResnet(pl.LightningModule):
       (sequential_module): Sequential(...)
     )
     """
+
     def __init__(self, lr=0.05, batch_size=32, manual_optimization=False):
         super().__init__()
 
@@ -90,9 +92,7 @@ class LitResnet(pl.LightningModule):
             nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
             nn.ReLU(inplace=False),
             nn.MaxPool2d(kernel_size=2, stride=2),
-
             Flatten(),
-
             nn.Dropout(p=0.1),
             nn.Linear(4096, 1024),
             nn.ReLU(inplace=False),
@@ -159,7 +159,8 @@ class LitResnet(pl.LightningModule):
                     optimizer,
                     0.1,
                     epochs=self.trainer.max_epochs,
-                    steps_per_epoch=math.ceil(45000 / self.hparams.batch_size)),
+                    steps_per_epoch=math.ceil(45000 / self.hparams.batch_size)
+                ),
                 'interval': 'step',
             }
         }
@@ -173,6 +174,7 @@ class LitResnet(pl.LightningModule):
 #  Instantiate Data Module  #
 #################################
 
+
 def instantiate_datamodule(args):
     train_transforms = torchvision.transforms.Compose([
         torchvision.transforms.RandomCrop(32, padding=4),
@@ -95,22 +95,30 @@ class DALIClassificationLoader(DALIClassificationIterator):
     """
 
     def __init__(
-            self,
-            pipelines,
-            size=-1,
-            reader_name=None,
-            auto_reset=False,
-            fill_last_batch=True,
-            dynamic_shape=False,
-            last_batch_padded=False,
+        self,
+        pipelines,
+        size=-1,
+        reader_name=None,
+        auto_reset=False,
+        fill_last_batch=True,
+        dynamic_shape=False,
+        last_batch_padded=False,
     ):
         if NEW_DALI_API:
             last_batch_policy = LastBatchPolicy.FILL if fill_last_batch else LastBatchPolicy.DROP
-            super().__init__(pipelines, size, reader_name, auto_reset, dynamic_shape,
-                             last_batch_policy=last_batch_policy, last_batch_padded=last_batch_padded)
+            super().__init__(
+                pipelines,
+                size,
+                reader_name,
+                auto_reset,
+                dynamic_shape,
+                last_batch_policy=last_batch_policy,
+                last_batch_padded=last_batch_padded
+            )
         else:
-            super().__init__(pipelines, size, reader_name, auto_reset, fill_last_batch,
-                             dynamic_shape, last_batch_padded)
+            super().__init__(
+                pipelines, size, reader_name, auto_reset, fill_last_batch, dynamic_shape, last_batch_padded
+            )
         self._fill_last_batch = fill_last_batch
 
     def __len__(self):
@@ -120,6 +128,7 @@ class DALIClassificationLoader(DALIClassificationIterator):
 
 
 class LitClassifier(pl.LightningModule):
+
     def __init__(self, hidden_dim=128, learning_rate=1e-3):
         super().__init__()
         self.save_hyperparameters()
@@ -58,8 +58,10 @@ class MNISTDataModule(LightningDataModule):
         super().__init__(*args, **kwargs)
         if num_workers and platform.system() == "Windows":
             # see: https://stackoverflow.com/a/59680818
-            warn(f"You have requested num_workers={num_workers} on Windows,"
-                 " but currently recommended is 0, so we set it for you")
+            warn(
+                f"You have requested num_workers={num_workers} on Windows,"
+                " but currently recommended is 0, so we set it for you"
+            )
             num_workers = 0
 
         self.dims = (1, 28, 28)
@@ -132,9 +134,9 @@ class MNISTDataModule(LightningDataModule):
         if not _TORCHVISION_AVAILABLE:
             return None
         if self.normalize:
-            mnist_transforms = transform_lib.Compose(
-                [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
-            )
+            mnist_transforms = transform_lib.Compose([
+                transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5, ), std=(0.5, ))
+            ])
         else:
             mnist_transforms = transform_lib.ToTensor()
 
@@ -31,6 +31,7 @@ class LitClassifier(pl.LightningModule):
       (l2): Linear(...)
     )
     """
+
     def __init__(self, hidden_dim=128, learning_rate=1e-3):
         super().__init__()
         self.save_hyperparameters()
@@ -33,6 +33,7 @@ class RandomDataset(Dataset):
     >>> RandomDataset(size=10, length=20)  # doctest: +ELLIPSIS
     <...bug_report_model.RandomDataset object at ...>
     """
+
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)
@@ -124,9 +125,11 @@ class BoringModel(LightningModule):
     # parser = ArgumentParser()
     # args = parser.parse_args(opt)
 
 
 def test_run():
+
     class TestModel(BoringModel):
+
         def on_train_epoch_start(self) -> None:
             print('override any method to prove your bug')
 
@@ -1,4 +1,3 @@
-
 # Copyright The PyTorch Lightning team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -60,14 +59,12 @@ from pytorch_lightning.callbacks.finetuning import BaseFinetuningCallback
 
 DATA_URL = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip"
 
 
 # --- Finetunning Callback ---
 
 
 class MilestonesFinetuningCallback(BaseFinetuningCallback):
 
-    def __init__(self,
-                 milestones: tuple = (5, 10),
-                 train_bn: bool = True):
+    def __init__(self, milestones: tuple = (5, 10), train_bn: bool = True):
         self.milestones = milestones
         self.train_bn = train_bn
@@ -78,17 +75,13 @@ class MilestonesFinetuningCallback(BaseFinetuningCallback):
         if epoch == self.milestones[0]:
             # unfreeze 5 last layers
             self.unfreeze_and_add_param_group(
-                module=pl_module.feature_extractor[-5:],
-                optimizer=optimizer,
-                train_bn=self.train_bn
+                module=pl_module.feature_extractor[-5:], optimizer=optimizer, train_bn=self.train_bn
             )
 
         elif epoch == self.milestones[1]:
             # unfreeze remaing layers
             self.unfreeze_and_add_param_group(
-                module=pl_module.feature_extractor[:-5],
-                optimizer=optimizer,
-                train_bn=self.train_bn
+                module=pl_module.feature_extractor[:-5], optimizer=optimizer, train_bn=self.train_bn
             )
 
 
@@ -149,10 +142,12 @@ class TransferLearningModel(pl.LightningModule):
         self.feature_extractor = nn.Sequential(*_layers)
 
         # 2. Classifier:
-        _fc_layers = [nn.Linear(2048, 256),
-                      nn.ReLU(),
-                      nn.Linear(256, 32),
-                      nn.Linear(32, 1)]
+        _fc_layers = [
+            nn.Linear(2048, 256),
+            nn.ReLU(),
+            nn.Linear(256, 32),
+            nn.Linear(32, 1),
+        ]
         self.fc = nn.Sequential(*_fc_layers)
 
         # 3. Loss:
@@ -218,25 +213,21 @@ class TransferLearningModel(pl.LightningModule):
 
         train_dataset = ImageFolder(
             root=data_path.joinpath("train"),
-            transform=transforms.Compose(
-                [
-                    transforms.Resize((224, 224)),
-                    transforms.RandomHorizontalFlip(),
-                    transforms.ToTensor(),
-                    normalize,
-                ]
-            ),
+            transform=transforms.Compose([
+                transforms.Resize((224, 224)),
+                transforms.RandomHorizontalFlip(),
+                transforms.ToTensor(),
+                normalize,
+            ]),
         )
 
         valid_dataset = ImageFolder(
             root=data_path.joinpath("validation"),
-            transform=transforms.Compose(
-                [
-                    transforms.Resize((224, 224)),
-                    transforms.ToTensor(),
-                    normalize,
-                ]
-            ),
+            transform=transforms.Compose([
+                transforms.Resize((224, 224)),
+                transforms.ToTensor(),
+                normalize,
+            ]),
         )
 
         self.train_dataset = train_dataset
@@ -43,6 +43,7 @@ class Generator(nn.Module):
       (model): Sequential(...)
     )
     """
+
     def __init__(self, latent_dim: int = 100, img_shape: tuple = (1, 28, 28)):
         super().__init__()
         self.img_shape = img_shape
@@ -60,7 +61,7 @@ class Generator(nn.Module):
             *block(256, 512),
             *block(512, 1024),
             nn.Linear(1024, int(np.prod(img_shape))),
-            nn.Tanh()
+            nn.Tanh(),
         )
 
     def forward(self, z):
@@ -76,6 +77,7 @@ class Discriminator(nn.Module):
       (model): Sequential(...)
     )
     """
+
     def __init__(self, img_shape):
         super().__init__()
 
@@ -106,13 +108,14 @@ class GAN(LightningModule):
       )
     )
     """
+
     def __init__(
-            self,
-            img_shape: tuple = (1, 28, 28),
-            lr: float = 0.0002,
-            b1: float = 0.5,
-            b2: float = 0.999,
-            latent_dim: int = 100,
+        self,
+        img_shape: tuple = (1, 28, 28),
+        lr: float = 0.0002,
+        b1: float = 0.5,
+        b2: float = 0.999,
+        latent_dim: int = 100,
     ):
         super().__init__()
 
@@ -130,12 +133,9 @@ class GAN(LightningModule):
     def add_argparse_args(parent_parser: ArgumentParser):
         parser = ArgumentParser(parents=[parent_parser], add_help=False)
         parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
-        parser.add_argument("--b1", type=float, default=0.5,
-                            help="adam: decay of first order momentum of gradient")
-        parser.add_argument("--b2", type=float, default=0.999,
-                            help="adam: decay of second order momentum of gradient")
-        parser.add_argument("--latent_dim", type=int, default=100,
-                            help="dimensionality of the latent space")
+        parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
+        parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
+        parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
 
         return parser
 
@@ -180,8 +180,7 @@ class GAN(LightningModule):
         fake = torch.zeros(imgs.size(0), 1)
         fake = fake.type_as(imgs)
 
-        fake_loss = self.adversarial_loss(
-            self.discriminator(self(z).detach()), fake)
+        fake_loss = self.adversarial_loss(self.discriminator(self(z).detach()), fake)
 
         # discriminator loss is the average of these
         d_loss = (real_loss + fake_loss) / 2
@@ -213,14 +212,14 @@ class MNISTDataModule(LightningDataModule):
     >>> MNISTDataModule()  # doctest: +ELLIPSIS
     <...generative_adversarial_net.MNISTDataModule object at ...>
     """
+
     def __init__(self, batch_size: int = 64, data_path: str = os.getcwd(), num_workers: int = 4):
         super().__init__()
         self.batch_size = batch_size
         self.data_path = data_path
         self.num_workers = num_workers
 
-        self.transform = transforms.Compose([transforms.ToTensor(),
-                                             transforms.Normalize([0.5], [0.5])])
+        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
         self.dims = (1, 28, 28)
 
     def prepare_data(self, stage=None):
@@ -63,16 +63,16 @@ class ImageNetLightningModel(LightningModule):
     )
 
     def __init__(
-            self,
-            data_path: str,
-            arch: str = 'resnet18',
-            pretrained: bool = False,
-            lr: float = 0.1,
-            momentum: float = 0.9,
-            weight_decay: float = 1e-4,
-            batch_size: int = 4,
-            workers: int = 2,
-            **kwargs,
+        self,
+        data_path: str,
+        arch: str = 'resnet18',
+        pretrained: bool = False,
+        lr: float = 0.1,
+        momentum: float = 0.9,
+        weight_decay: float = 1e-4,
+        batch_size: int = 4,
+        workers: int = 2,
+        **kwargs,
     ):
         super().__init__()
         self.save_hyperparameters()
@@ -109,7 +109,7 @@ class ImageNetLightningModel(LightningModule):
         self.log('val_acc5', acc5, on_step=True, on_epoch=True)
 
     @staticmethod
-    def __accuracy(output, target, topk=(1,)):
+    def __accuracy(output, target, topk=(1, )):
         """Computes the accuracy over the k top predictions for the specified values of k"""
         with torch.no_grad():
             maxk = max(topk)
@@ -126,16 +126,8 @@ class ImageNetLightningModel(LightningModule):
         return res
 
     def configure_optimizers(self):
-        optimizer = optim.SGD(
-            self.parameters(),
-            lr=self.lr,
-            momentum=self.momentum,
-            weight_decay=self.weight_decay
-        )
-        scheduler = lr_scheduler.LambdaLR(
-            optimizer,
-            lambda epoch: 0.1 ** (epoch // 30)
-        )
+        optimizer = optim.SGD(self.parameters(), lr=self.lr, momentum=self.momentum, weight_decay=self.weight_decay)
+        scheduler = lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.1**(epoch // 30))
         return [optimizer], [scheduler]
 
     def train_dataloader(self):
@@ -152,7 +144,8 @@ class ImageNetLightningModel(LightningModule):
                 transforms.RandomHorizontalFlip(),
                 transforms.ToTensor(),
                 normalize,
-            ]))
+            ])
+        )
 
         train_loader = torch.utils.data.DataLoader(
             dataset=train_dataset,
@@ -169,12 +162,15 @@ class ImageNetLightningModel(LightningModule):
         )
         val_dir = os.path.join(self.data_path, 'val')
         val_loader = torch.utils.data.DataLoader(
-            datasets.ImageFolder(val_dir, transforms.Compose([
-                transforms.Resize(256),
-                transforms.CenterCrop(224),
-                transforms.ToTensor(),
-                normalize,
-            ])),
+            datasets.ImageFolder(
+                val_dir,
+                transforms.Compose([
+                    transforms.Resize(256),
+                    transforms.CenterCrop(224),
+                    transforms.ToTensor(),
+                    normalize,
+                ])
+            ),
             batch_size=self.batch_size,
             shuffle=False,
             num_workers=self.workers,
@@ -203,26 +199,40 @@ class ImageNetLightningModel(LightningModule):
     @staticmethod
     def add_model_specific_args(parent_parser):  # pragma: no-cover
         parser = ArgumentParser(parents=[parent_parser])
-        parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
-                            choices=ImageNetLightningModel.MODEL_NAMES,
-                            help=('model architecture: ' + ' | '.join(ImageNetLightningModel.MODEL_NAMES)
-                                  + ' (default: resnet18)'))
-        parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
-                            help='number of data loading workers (default: 4)')
-        parser.add_argument('-b', '--batch-size', default=256, type=int,
-                            metavar='N',
-                            help='mini-batch size (default: 256), this is the total '
-                                 'batch size of all GPUs on the current node when '
-                                 'using Data Parallel or Distributed Data Parallel')
-        parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
-                            metavar='LR', help='initial learning rate', dest='lr')
-        parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
-                            help='momentum')
-        parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
-                            metavar='W', help='weight decay (default: 1e-4)',
-                            dest='weight_decay')
-        parser.add_argument('--pretrained', dest='pretrained', action='store_true',
-                            help='use pre-trained model')
+        parser.add_argument(
+            '-a',
+            '--arch',
+            metavar='ARCH',
+            default='resnet18',
+            choices=ImageNetLightningModel.MODEL_NAMES,
+            help=('model architecture: ' + ' | '.join(ImageNetLightningModel.MODEL_NAMES) + ' (default: resnet18)')
+        )
+        parser.add_argument(
+            '-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)'
+        )
+        parser.add_argument(
+            '-b',
+            '--batch-size',
+            default=256,
+            type=int,
+            metavar='N',
+            help='mini-batch size (default: 256), this is the total batch size of all GPUs on the current node'
+            ' when using Data Parallel or Distributed Data Parallel'
+        )
+        parser.add_argument(
+            '--lr', '--learning-rate', default=0.1, type=float, metavar='LR', help='initial learning rate', dest='lr'
+        )
+        parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
+        parser.add_argument(
+            '--wd',
+            '--weight-decay',
+            default=1e-4,
+            type=float,
+            metavar='W',
+            help='weight decay (default: 1e-4)',
+            dest='weight_decay'
+        )
+        parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model')
         return parser
 
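The rewrapped add_argument calls above show the shape yapf gives long calls: break after the opening parenthesis, one keyword per line, closing bracket dedented back to the statement indent. A sketch of settings that produce this shape (these knobs are assumptions; the repository's actual configuration is not part of this diff):

    # Sketch only: reproducing the wrapped-call shape seen above (assumed knobs).
    from yapf.yapflib.yapf_api import FormatCode

    style = '{based_on_style: pep8, column_limit: 120, dedent_closing_brackets: true, coalesce_brackets: true}'

    src = ("parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N', "
           "help='mini-batch size (default: 256), total batch size of all GPUs on the current node')\n")
    print(FormatCode(src, style_config=style)[0])
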
@@ -249,12 +259,11 @@ def main(args: Namespace) -> None:
 def run_cli():
     parent_parser = ArgumentParser(add_help=False)
     parent_parser = pl.Trainer.add_argparse_args(parent_parser)
-    parent_parser.add_argument('--data-path', metavar='DIR', type=str,
-                               help='path to dataset')
-    parent_parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
-                               help='evaluate model on validation set')
-    parent_parser.add_argument('--seed', type=int, default=42,
-                               help='seed for initializing training.')
+    parent_parser.add_argument('--data-path', metavar='DIR', type=str, help='path to dataset')
+    parent_parser.add_argument(
+        '-e', '--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set'
+    )
+    parent_parser.add_argument('--seed', type=int, default=42, help='seed for initializing training.')
     parser = ImageNetLightningModel.add_model_specific_args(parent_parser)
     parser.set_defaults(
         profiler="simple",
@@ -70,7 +70,7 @@ class DQN(nn.Module):
         self.net = nn.Sequential(
             nn.Linear(obs_size, hidden_size),
             nn.ReLU(),
-            nn.Linear(hidden_size, n_actions)
+            nn.Linear(hidden_size, n_actions),
         )
 
     def forward(self, x):
@@ -78,9 +78,7 @@ class DQN(nn.Module):
 
 
 # Named tuple for storing experience steps gathered in training
-Experience = namedtuple(
-    'Experience', field_names=['state', 'action', 'reward',
-                               'done', 'new_state'])
+Experience = namedtuple('Experience', field_names=['state', 'action', 'reward', 'done', 'new_state'])
 
 
 class ReplayBuffer:
@@ -114,8 +112,13 @@ class ReplayBuffer:
         indices = np.random.choice(len(self.buffer), batch_size, replace=False)
         states, actions, rewards, dones, next_states = zip(*[self.buffer[idx] for idx in indices])
 
-        return (np.array(states), np.array(actions), np.array(rewards, dtype=np.float32),
-                np.array(dones, dtype=np.bool), np.array(next_states))
+        return (
+            np.array(states),
+            np.array(actions),
+            np.array(rewards, dtype=np.float32),
+            np.array(dones, dtype=np.bool),
+            np.array(next_states),
+        )
 
 
 class RLDataset(IterableDataset):
@@ -236,20 +239,21 @@ class DQNLightning(pl.LightningModule):
       )
     )
     """
+
     def __init__(
-            self,
-            env: str,
-            replay_size: int = 200,
-            warm_start_steps: int = 200,
-            gamma: float = 0.99,
-            eps_start: float = 1.0,
-            eps_end: float = 0.01,
-            eps_last_frame: int = 200,
-            sync_rate: int = 10,
-            lr: float = 1e-2,
-            episode_length: int = 50,
-            batch_size: int = 4,
-            **kwargs,
+        self,
+        env: str,
+        replay_size: int = 200,
+        warm_start_steps: int = 200,
+        gamma: float = 0.99,
+        eps_start: float = 1.0,
+        eps_end: float = 0.01,
+        eps_last_frame: int = 200,
+        sync_rate: int = 10,
+        lr: float = 1e-2,
+        episode_length: int = 50,
+        batch_size: int = 4,
+        **kwargs,
     ) -> None:
         super().__init__(**kwargs)
         self.replay_size = replay_size
@@ -353,9 +357,11 @@ class DQNLightning(pl.LightningModule):
         if self.global_step % self.sync_rate == 0:
             self.target_net.load_state_dict(self.net.state_dict())
 
-        log = {'total_reward': torch.tensor(self.total_reward).to(device),
-               'reward': torch.tensor(reward).to(device),
-               'steps': torch.tensor(self.global_step).to(device)}
+        log = {
+            'total_reward': torch.tensor(self.total_reward).to(device),
+            'reward': torch.tensor(reward).to(device),
+            'steps': torch.tensor(self.global_step).to(device)
+        }
 
         return OrderedDict({'loss': loss, 'log': log, 'progress_bar': log})
 
@@ -389,21 +395,20 @@ class DQNLightning(pl.LightningModule):
         parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
         parser.add_argument("--env", type=str, default="CartPole-v0", help="gym environment tag")
         parser.add_argument("--gamma", type=float, default=0.99, help="discount factor")
-        parser.add_argument("--sync_rate", type=int, default=10,
-                            help="how many frames do we update the target network")
-        parser.add_argument("--replay_size", type=int, default=1000,
-                            help="capacity of the replay buffer")
-        parser.add_argument("--warm_start_size", type=int, default=1000,
-                            help="how many samples do we use to fill our buffer at the start of training")
-        parser.add_argument("--eps_last_frame", type=int, default=1000,
-                            help="what frame should epsilon stop decaying")
+        parser.add_argument("--sync_rate", type=int, default=10, help="how many frames do we update the target network")
+        parser.add_argument("--replay_size", type=int, default=1000, help="capacity of the replay buffer")
+        parser.add_argument(
+            "--warm_start_size",
+            type=int,
+            default=1000,
+            help="how many samples do we use to fill our buffer at the start of training"
+        )
+        parser.add_argument("--eps_last_frame", type=int, default=1000, help="what frame should epsilon stop decaying")
         parser.add_argument("--eps_start", type=float, default=1.0, help="starting value of epsilon")
         parser.add_argument("--eps_end", type=float, default=0.01, help="final value of epsilon")
         parser.add_argument("--episode_length", type=int, default=200, help="max length of an episode")
-        parser.add_argument("--max_episode_reward", type=int, default=200,
-                            help="max episode reward in the environment")
-        parser.add_argument("--warm_start_steps", type=int, default=1000,
-                            help="max episode reward in the environment")
+        parser.add_argument("--max_episode_reward", type=int, default=200, help="max episode reward in the environment")
+        parser.add_argument("--warm_start_steps", type=int, default=1000, help="max episode reward in the environment")
         return parser
 
 
@@ -413,7 +418,7 @@ def main(args) -> None:
     trainer = pl.Trainer(
         gpus=1,
         accelerator='dp',
-        val_check_interval=100
+        val_check_interval=100,
     )
 
     trainer.fit(model)
@@ -50,7 +50,7 @@ def create_mlp(input_shape: Tuple[int], n_actions: int, hidden_size: int = 128):
         nn.ReLU(),
         nn.Linear(hidden_size, hidden_size),
         nn.ReLU(),
-        nn.Linear(hidden_size, n_actions)
+        nn.Linear(hidden_size, n_actions),
     )
 
     return network
@@ -159,6 +159,7 @@ class PPOLightning(pl.LightningModule):
         trainer = Trainer()
         trainer.fit(model)
     """
+
     def __init__(
         self,
         env: str,
@@ -173,7 +174,6 @@ class PPOLightning(pl.LightningModule):
         clip_ratio: float = 0.2,
         **kwargs,
     ) -> None:
-
         """
         Args:
             env: gym environment tag
@@ -213,8 +213,10 @@ class PPOLightning(pl.LightningModule):
             actor_mlp = create_mlp(self.env.observation_space.shape, self.env.action_space.n)
             self.actor = ActorCategorical(actor_mlp)
         else:
-            raise NotImplementedError('Env action space should be of type Box (continous) or Discrete (categorical). '
-                                      f'Got type: {type(self.env.action_space)}')
+            raise NotImplementedError(
+                'Env action space should be of type Box (continous) or Discrete (categorical).'
+                f' Got type: {type(self.env.action_space)}'
+            )
 
         self.batch_states = []
         self.batch_actions = []
@@ -287,9 +289,7 @@ class PPOLightning(pl.LightningModule):
 
         return adv
 
-    def generate_trajectory_samples(
-        self,
-    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
+    def generate_trajectory_samples(self) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
         """
         Contains the logic for generating trajectory data to train policy and value network
         Yield:
@@ -345,8 +345,8 @@ class PPOLightning(pl.LightningModule):
 
         if epoch_end:
             train_data = zip(
-                self.batch_states, self.batch_actions, self.batch_logp,
-                self.batch_qvals, self.batch_adv)
+                self.batch_states, self.batch_actions, self.batch_logp, self.batch_qvals, self.batch_adv
+            )
 
             for state, action, logp_old, qval, adv in train_data:
                 yield state, action, logp_old, qval, adv
@@ -454,12 +454,18 @@ class PPOLightning(pl.LightningModule):
         parser.add_argument("--lr_critic", type=float, default=1e-3, help="learning rate of critic network")
         parser.add_argument("--max_episode_len", type=int, default=1000, help="capacity of the replay buffer")
         parser.add_argument("--batch_size", type=int, default=512, help="batch_size when training network")
-        parser.add_argument("--steps_per_epoch", type=int, default=2048,
-                            help="how many action-state pairs to rollout for trajectory collection per epoch")
-        parser.add_argument("--nb_optim_iters", type=int, default=4,
-                            help="how many steps of gradient descent to perform on each batch")
-        parser.add_argument("--clip_ratio", type=float, default=0.2,
-                            help="hyperparameter for clipping in the policy objective")
+        parser.add_argument(
+            "--steps_per_epoch",
+            type=int,
+            default=2048,
+            help="how many action-state pairs to rollout for trajectory collection per epoch"
+        )
+        parser.add_argument(
+            "--nb_optim_iters", type=int, default=4, help="how many steps of gradient descent to perform on each batch"
+        )
+        parser.add_argument(
+            "--clip_ratio", type=float, default=0.2, help="hyperparameter for clipping in the policy objective"
+        )
 
         return parser
 
@@ -178,15 +178,16 @@ class SegModel(pl.LightningModule):
       )
     )
    """
+
     def __init__(
-            self,
-            data_path: str,
-            batch_size: int = 4,
-            lr: float = 1e-3,
-            num_layers: int = 3,
-            features_start: int = 64,
-            bilinear: bool = False,
-            **kwargs,
+        self,
+        data_path: str,
+        batch_size: int = 4,
+        lr: float = 1e-3,
+        num_layers: int = 3,
+        features_start: int = 64,
+        bilinear: bool = False,
+        **kwargs,
     ):
         super().__init__(**kwargs)
         self.data_path = data_path
@@ -196,12 +197,12 @@ class SegModel(pl.LightningModule):
         self.features_start = features_start
         self.bilinear = bilinear
 
-        self.net = UNet(num_classes=19, num_layers=self.num_layers,
-                        features_start=self.features_start, bilinear=self.bilinear)
+        self.net = UNet(
+            num_classes=19, num_layers=self.num_layers, features_start=self.features_start, bilinear=self.bilinear
+        )
         self.transform = transforms.Compose([
             transforms.ToTensor(),
-            transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753],
-                                 std=[0.32064945, 0.32098866, 0.32325324])
+            transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], std=[0.32064945, 0.32098866, 0.32325324])
         ])
         self.trainset = KITTI(self.data_path, split='train', transform=self.transform)
         self.validset = KITTI(self.data_path, split='valid', transform=self.transform)
@@ -250,8 +251,12 @@ class SegModel(pl.LightningModule):
         parser.add_argument("--lr", type=float, default=0.001, help="adam: learning rate")
         parser.add_argument("--num_layers", type=int, default=5, help="number of layers on u-net")
         parser.add_argument("--features_start", type=float, default=64, help="number of features in first layer")
-        parser.add_argument("--bilinear", action='store_true', default=False,
-                            help="whether to use bilinear interpolation or transposed")
+        parser.add_argument(
+            "--bilinear",
+            action='store_true',
+            default=False,
+            help="whether to use bilinear interpolation or transposed"
+        )
         return parser
 
 
@@ -36,11 +36,11 @@ class UNet(nn.Module):
     """
 
     def __init__(
-            self,
-            num_classes: int = 19,
-            num_layers: int = 5,
-            features_start: int = 64,
-            bilinear: bool = False,
+        self,
+        num_classes: int = 19,
+        num_layers: int = 5,
+        features_start: int = 64,
+        bilinear: bool = False,
     ):
         """
         Args:
@@ -97,7 +97,7 @@ class DoubleConv(nn.Module):
             nn.ReLU(inplace=True),
             nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
             nn.BatchNorm2d(out_ch),
-            nn.ReLU(inplace=True)
+            nn.ReLU(inplace=True),
         )
 
     def forward(self, x):
@@ -121,10 +121,7 @@ class Down(nn.Module):
 
     def __init__(self, in_ch: int, out_ch: int):
         super().__init__()
-        self.net = nn.Sequential(
-            nn.MaxPool2d(kernel_size=2, stride=2),
-            DoubleConv(in_ch, out_ch)
-        )
+        self.net = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2), DoubleConv(in_ch, out_ch))
 
     def forward(self, x):
         return self.net(x)
@@ -53,11 +53,13 @@ ARGS_DDP_AMP = ARGS_DEFAULT + """
 """
 
 
-@pytest.mark.parametrize('import_cli', [
-    'pl_examples.basic_examples.simple_image_classifier',
-    'pl_examples.basic_examples.backbone_image_classifier',
-    'pl_examples.basic_examples.autoencoder',
-])
+@pytest.mark.parametrize(
+    'import_cli', [
+        'pl_examples.basic_examples.simple_image_classifier',
+        'pl_examples.basic_examples.backbone_image_classifier',
+        'pl_examples.basic_examples.autoencoder',
+    ]
+)
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 @pytest.mark.parametrize('cli_args', [ARGS_DP, ARGS_DP_AMP])
 def test_examples_dp(tmpdir, import_cli, cli_args):
@@ -88,11 +90,13 @@ def test_examples_dp(tmpdir, import_cli, cli_args):
     # module.cli_main()
 
 
-@pytest.mark.parametrize('import_cli', [
-    'pl_examples.basic_examples.simple_image_classifier',
-    'pl_examples.basic_examples.backbone_image_classifier',
-    'pl_examples.basic_examples.autoencoder',
-])
+@pytest.mark.parametrize(
+    'import_cli', [
+        'pl_examples.basic_examples.simple_image_classifier',
+        'pl_examples.basic_examples.backbone_image_classifier',
+        'pl_examples.basic_examples.autoencoder',
+    ]
+)
 @pytest.mark.parametrize('cli_args', [ARGS_DEFAULT])
 def test_examples_cpu(tmpdir, import_cli, cli_args):
 
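Taken as a whole, the diff is a mechanical re-format of pl_examples rather than a behavior change. A sketch of a local check that would confirm nothing is left to reformat (paths and style are assumptions; the repository's real yapf settings are not shown in this diff):

    # Sketch only: report which example files yapf would still rewrite (assumed paths/style).
    import pathlib

    from yapf.yapflib.yapf_api import FormatFile

    for path in sorted(pathlib.Path('pl_examples').rglob('*.py')):
        _diff, _encoding, changed = FormatFile(
            str(path),
            style_config='{based_on_style: pep8, column_limit: 120}',  # assumed knobs
            print_diff=True,
        )
        if changed:
            print(f'would reformat: {path}')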