yapf examples (#5709)

Jirka Borovec 2021-01-30 11:17:12 +01:00 committed by GitHub
parent 07f24d2438
commit 21d313edc5
17 changed files with 243 additions and 212 deletions
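
The hunks below are whitespace-only yapf reflows (trailing commas added to multi-line calls, hanging indents normalized, short wrapped calls collapsed onto a single line); the behaviour of the examples is unchanged. As a rough sketch of how such a pass can be reproduced locally, assuming yapf picks up the project's style configuration (e.g. a [yapf] section in setup.cfg with a column limit around 120, which is only inferred from the reflowed lines and not shown in this diff):

    pip install yapf
    yapf --in-place --recursive pl_examples/

Running yapf with --diff instead of --in-place prints the proposed reformatting without modifying the files, which is handy for checking a pass like this before committing.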

View File

@ -1,8 +1,5 @@
.git/*
# TODO
pl_examples/*
# TODO
pytorch_lightning/*

View File

@ -9,7 +9,6 @@ _DATASETS_PATH = os.path.join(_PACKAGE_ROOT, 'Datasets')
_TORCHVISION_AVAILABLE = _module_available("torchvision")
_DALI_AVAILABLE = _module_available("nvidia.dali")
LIGHTNING_LOGO = """
####
###########

View File

@ -43,12 +43,12 @@ class LitAutoEncoder(pl.LightningModule):
self.encoder = nn.Sequential(
nn.Linear(28 * 28, 64),
nn.ReLU(),
nn.Linear(64, 3)
nn.Linear(64, 3),
)
self.decoder = nn.Sequential(
nn.Linear(3, 64),
nn.ReLU(),
nn.Linear(64, 28 * 28)
nn.Linear(64, 28 * 28),
)
def forward(self, x):

View File

@ -36,6 +36,7 @@ class Backbone(torch.nn.Module):
(l2): Linear(...)
)
"""
def __init__(self, hidden_dim=128):
super().__init__()
self.l1 = torch.nn.Linear(28 * 28, hidden_dim)
@ -55,6 +56,7 @@ class LitClassifier(pl.LightningModule):
(backbone): ...
)
"""
def __init__(self, backbone, learning_rate=1e-3):
super().__init__()
self.save_hyperparameters()

View File

@ -39,16 +39,17 @@ if _BOLTS_AVAILABLE:
import pl_bolts
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
#####################
# Modules #
#####################
class Flatten(nn.Module):
def forward(self, x):
return x.view(x.size(0), -1)
###############################
# LightningModule #
###############################
@ -61,6 +62,7 @@ class LitResnet(pl.LightningModule):
(sequential_module): Sequential(...)
)
"""
def __init__(self, lr=0.05, batch_size=32, manual_optimization=False):
super().__init__()
@ -90,9 +92,7 @@ class LitResnet(pl.LightningModule):
nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
nn.ReLU(inplace=False),
nn.MaxPool2d(kernel_size=2, stride=2),
Flatten(),
nn.Dropout(p=0.1),
nn.Linear(4096, 1024),
nn.ReLU(inplace=False),
@ -159,7 +159,8 @@ class LitResnet(pl.LightningModule):
optimizer,
0.1,
epochs=self.trainer.max_epochs,
steps_per_epoch=math.ceil(45000 / self.hparams.batch_size)),
steps_per_epoch=math.ceil(45000 / self.hparams.batch_size)
),
'interval': 'step',
}
}
@ -173,6 +174,7 @@ class LitResnet(pl.LightningModule):
# Instantiate Data Module #
#################################
def instantiate_datamodule(args):
train_transforms = torchvision.transforms.Compose([
torchvision.transforms.RandomCrop(32, padding=4),

View File

@ -95,22 +95,30 @@ class DALIClassificationLoader(DALIClassificationIterator):
"""
def __init__(
self,
pipelines,
size=-1,
reader_name=None,
auto_reset=False,
fill_last_batch=True,
dynamic_shape=False,
last_batch_padded=False,
self,
pipelines,
size=-1,
reader_name=None,
auto_reset=False,
fill_last_batch=True,
dynamic_shape=False,
last_batch_padded=False,
):
if NEW_DALI_API:
last_batch_policy = LastBatchPolicy.FILL if fill_last_batch else LastBatchPolicy.DROP
super().__init__(pipelines, size, reader_name, auto_reset, dynamic_shape,
last_batch_policy=last_batch_policy, last_batch_padded=last_batch_padded)
super().__init__(
pipelines,
size,
reader_name,
auto_reset,
dynamic_shape,
last_batch_policy=last_batch_policy,
last_batch_padded=last_batch_padded
)
else:
super().__init__(pipelines, size, reader_name, auto_reset, fill_last_batch,
dynamic_shape, last_batch_padded)
super().__init__(
pipelines, size, reader_name, auto_reset, fill_last_batch, dynamic_shape, last_batch_padded
)
self._fill_last_batch = fill_last_batch
def __len__(self):
@ -120,6 +128,7 @@ class DALIClassificationLoader(DALIClassificationIterator):
class LitClassifier(pl.LightningModule):
def __init__(self, hidden_dim=128, learning_rate=1e-3):
super().__init__()
self.save_hyperparameters()

View File

@ -58,8 +58,10 @@ class MNISTDataModule(LightningDataModule):
super().__init__(*args, **kwargs)
if num_workers and platform.system() == "Windows":
# see: https://stackoverflow.com/a/59680818
warn(f"You have requested num_workers={num_workers} on Windows,"
" but currently recommended is 0, so we set it for you")
warn(
f"You have requested num_workers={num_workers} on Windows,"
" but currently recommended is 0, so we set it for you"
)
num_workers = 0
self.dims = (1, 28, 28)
@ -132,9 +134,9 @@ class MNISTDataModule(LightningDataModule):
if not _TORCHVISION_AVAILABLE:
return None
if self.normalize:
mnist_transforms = transform_lib.Compose(
[transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
)
mnist_transforms = transform_lib.Compose([
transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5, ), std=(0.5, ))
])
else:
mnist_transforms = transform_lib.ToTensor()

View File

@ -31,6 +31,7 @@ class LitClassifier(pl.LightningModule):
(l2): Linear(...)
)
"""
def __init__(self, hidden_dim=128, learning_rate=1e-3):
super().__init__()
self.save_hyperparameters()

View File

@ -33,6 +33,7 @@ class RandomDataset(Dataset):
>>> RandomDataset(size=10, length=20) # doctest: +ELLIPSIS
<...bug_report_model.RandomDataset object at ...>
"""
def __init__(self, size, length):
self.len = length
self.data = torch.randn(length, size)
@ -124,9 +125,11 @@ class BoringModel(LightningModule):
# parser = ArgumentParser()
# args = parser.parse_args(opt)
def test_run():
class TestModel(BoringModel):
def on_train_epoch_start(self) -> None:
print('override any method to prove your bug')

View File

@ -1,4 +1,3 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -60,14 +59,12 @@ from pytorch_lightning.callbacks.finetuning import BaseFinetuningCallback
DATA_URL = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip"
# --- Finetuning Callback ---
class MilestonesFinetuningCallback(BaseFinetuningCallback):
def __init__(self,
milestones: tuple = (5, 10),
train_bn: bool = True):
def __init__(self, milestones: tuple = (5, 10), train_bn: bool = True):
self.milestones = milestones
self.train_bn = train_bn
@ -78,17 +75,13 @@ class MilestonesFinetuningCallback(BaseFinetuningCallback):
if epoch == self.milestones[0]:
# unfreeze 5 last layers
self.unfreeze_and_add_param_group(
module=pl_module.feature_extractor[-5:],
optimizer=optimizer,
train_bn=self.train_bn
module=pl_module.feature_extractor[-5:], optimizer=optimizer, train_bn=self.train_bn
)
elif epoch == self.milestones[1]:
# unfreeze remaining layers
self.unfreeze_and_add_param_group(
module=pl_module.feature_extractor[:-5],
optimizer=optimizer,
train_bn=self.train_bn
module=pl_module.feature_extractor[:-5], optimizer=optimizer, train_bn=self.train_bn
)
@ -149,10 +142,12 @@ class TransferLearningModel(pl.LightningModule):
self.feature_extractor = nn.Sequential(*_layers)
# 2. Classifier:
_fc_layers = [nn.Linear(2048, 256),
nn.ReLU(),
nn.Linear(256, 32),
nn.Linear(32, 1)]
_fc_layers = [
nn.Linear(2048, 256),
nn.ReLU(),
nn.Linear(256, 32),
nn.Linear(32, 1),
]
self.fc = nn.Sequential(*_fc_layers)
# 3. Loss:
@ -218,25 +213,21 @@ class TransferLearningModel(pl.LightningModule):
train_dataset = ImageFolder(
root=data_path.joinpath("train"),
transform=transforms.Compose(
[
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]
),
transform=transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]),
)
valid_dataset = ImageFolder(
root=data_path.joinpath("validation"),
transform=transforms.Compose(
[
transforms.Resize((224, 224)),
transforms.ToTensor(),
normalize,
]
),
transform=transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
normalize,
]),
)
self.train_dataset = train_dataset

View File

@ -43,6 +43,7 @@ class Generator(nn.Module):
(model): Sequential(...)
)
"""
def __init__(self, latent_dim: int = 100, img_shape: tuple = (1, 28, 28)):
super().__init__()
self.img_shape = img_shape
@ -60,7 +61,7 @@ class Generator(nn.Module):
*block(256, 512),
*block(512, 1024),
nn.Linear(1024, int(np.prod(img_shape))),
nn.Tanh()
nn.Tanh(),
)
def forward(self, z):
@ -76,6 +77,7 @@ class Discriminator(nn.Module):
(model): Sequential(...)
)
"""
def __init__(self, img_shape):
super().__init__()
@ -106,13 +108,14 @@ class GAN(LightningModule):
)
)
"""
def __init__(
self,
img_shape: tuple = (1, 28, 28),
lr: float = 0.0002,
b1: float = 0.5,
b2: float = 0.999,
latent_dim: int = 100,
self,
img_shape: tuple = (1, 28, 28),
lr: float = 0.0002,
b1: float = 0.5,
b2: float = 0.999,
latent_dim: int = 100,
):
super().__init__()
@ -130,12 +133,9 @@ class GAN(LightningModule):
def add_argparse_args(parent_parser: ArgumentParser):
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5,
help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999,
help="adam: decay of second order momentum of gradient")
parser.add_argument("--latent_dim", type=int, default=100,
help="dimensionality of the latent space")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
return parser
@ -180,8 +180,7 @@ class GAN(LightningModule):
fake = torch.zeros(imgs.size(0), 1)
fake = fake.type_as(imgs)
fake_loss = self.adversarial_loss(
self.discriminator(self(z).detach()), fake)
fake_loss = self.adversarial_loss(self.discriminator(self(z).detach()), fake)
# discriminator loss is the average of these
d_loss = (real_loss + fake_loss) / 2
@ -213,14 +212,14 @@ class MNISTDataModule(LightningDataModule):
>>> MNISTDataModule() # doctest: +ELLIPSIS
<...generative_adversarial_net.MNISTDataModule object at ...>
"""
def __init__(self, batch_size: int = 64, data_path: str = os.getcwd(), num_workers: int = 4):
super().__init__()
self.batch_size = batch_size
self.data_path = data_path
self.num_workers = num_workers
self.transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])])
self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
self.dims = (1, 28, 28)
def prepare_data(self, stage=None):

View File

@ -63,16 +63,16 @@ class ImageNetLightningModel(LightningModule):
)
def __init__(
self,
data_path: str,
arch: str = 'resnet18',
pretrained: bool = False,
lr: float = 0.1,
momentum: float = 0.9,
weight_decay: float = 1e-4,
batch_size: int = 4,
workers: int = 2,
**kwargs,
self,
data_path: str,
arch: str = 'resnet18',
pretrained: bool = False,
lr: float = 0.1,
momentum: float = 0.9,
weight_decay: float = 1e-4,
batch_size: int = 4,
workers: int = 2,
**kwargs,
):
super().__init__()
self.save_hyperparameters()
@ -109,7 +109,7 @@ class ImageNetLightningModel(LightningModule):
self.log('val_acc5', acc5, on_step=True, on_epoch=True)
@staticmethod
def __accuracy(output, target, topk=(1,)):
def __accuracy(output, target, topk=(1, )):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with torch.no_grad():
maxk = max(topk)
@ -126,16 +126,8 @@ class ImageNetLightningModel(LightningModule):
return res
def configure_optimizers(self):
optimizer = optim.SGD(
self.parameters(),
lr=self.lr,
momentum=self.momentum,
weight_decay=self.weight_decay
)
scheduler = lr_scheduler.LambdaLR(
optimizer,
lambda epoch: 0.1 ** (epoch // 30)
)
optimizer = optim.SGD(self.parameters(), lr=self.lr, momentum=self.momentum, weight_decay=self.weight_decay)
scheduler = lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.1**(epoch // 30))
return [optimizer], [scheduler]
def train_dataloader(self):
@ -152,7 +144,8 @@ class ImageNetLightningModel(LightningModule):
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]))
])
)
train_loader = torch.utils.data.DataLoader(
dataset=train_dataset,
@ -169,12 +162,15 @@ class ImageNetLightningModel(LightningModule):
)
val_dir = os.path.join(self.data_path, 'val')
val_loader = torch.utils.data.DataLoader(
datasets.ImageFolder(val_dir, transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
])),
datasets.ImageFolder(
val_dir,
transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
])
),
batch_size=self.batch_size,
shuffle=False,
num_workers=self.workers,
@ -203,26 +199,40 @@ class ImageNetLightningModel(LightningModule):
@staticmethod
def add_model_specific_args(parent_parser): # pragma: no-cover
parser = ArgumentParser(parents=[parent_parser])
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
choices=ImageNetLightningModel.MODEL_NAMES,
help=('model architecture: ' + ' | '.join(ImageNetLightningModel.MODEL_NAMES)
+ ' (default: resnet18)'))
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
help='number of data loading workers (default: 4)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N',
help='mini-batch size (default: 256), this is the total '
'batch size of all GPUs on the current node when '
'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
metavar='W', help='weight decay (default: 1e-4)',
dest='weight_decay')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
help='use pre-trained model')
parser.add_argument(
'-a',
'--arch',
metavar='ARCH',
default='resnet18',
choices=ImageNetLightningModel.MODEL_NAMES,
help=('model architecture: ' + ' | '.join(ImageNetLightningModel.MODEL_NAMES) + ' (default: resnet18)')
)
parser.add_argument(
'-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)'
)
parser.add_argument(
'-b',
'--batch-size',
default=256,
type=int,
metavar='N',
help='mini-batch size (default: 256), this is the total batch size of all GPUs on the current node'
' when using Data Parallel or Distributed Data Parallel'
)
parser.add_argument(
'--lr', '--learning-rate', default=0.1, type=float, metavar='LR', help='initial learning rate', dest='lr'
)
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
parser.add_argument(
'--wd',
'--weight-decay',
default=1e-4,
type=float,
metavar='W',
help='weight decay (default: 1e-4)',
dest='weight_decay'
)
parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model')
return parser
@ -249,12 +259,11 @@ def main(args: Namespace) -> None:
def run_cli():
parent_parser = ArgumentParser(add_help=False)
parent_parser = pl.Trainer.add_argparse_args(parent_parser)
parent_parser.add_argument('--data-path', metavar='DIR', type=str,
help='path to dataset')
parent_parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
help='evaluate model on validation set')
parent_parser.add_argument('--seed', type=int, default=42,
help='seed for initializing training.')
parent_parser.add_argument('--data-path', metavar='DIR', type=str, help='path to dataset')
parent_parser.add_argument(
'-e', '--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set'
)
parent_parser.add_argument('--seed', type=int, default=42, help='seed for initializing training.')
parser = ImageNetLightningModel.add_model_specific_args(parent_parser)
parser.set_defaults(
profiler="simple",

View File

@ -70,7 +70,7 @@ class DQN(nn.Module):
self.net = nn.Sequential(
nn.Linear(obs_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, n_actions)
nn.Linear(hidden_size, n_actions),
)
def forward(self, x):
@ -78,9 +78,7 @@ class DQN(nn.Module):
# Named tuple for storing experience steps gathered in training
Experience = namedtuple(
'Experience', field_names=['state', 'action', 'reward',
'done', 'new_state'])
Experience = namedtuple('Experience', field_names=['state', 'action', 'reward', 'done', 'new_state'])
class ReplayBuffer:
@ -114,8 +112,13 @@ class ReplayBuffer:
indices = np.random.choice(len(self.buffer), batch_size, replace=False)
states, actions, rewards, dones, next_states = zip(*[self.buffer[idx] for idx in indices])
return (np.array(states), np.array(actions), np.array(rewards, dtype=np.float32),
np.array(dones, dtype=np.bool), np.array(next_states))
return (
np.array(states),
np.array(actions),
np.array(rewards, dtype=np.float32),
np.array(dones, dtype=np.bool),
np.array(next_states),
)
class RLDataset(IterableDataset):
@ -236,20 +239,21 @@ class DQNLightning(pl.LightningModule):
)
)
"""
def __init__(
self,
env: str,
replay_size: int = 200,
warm_start_steps: int = 200,
gamma: float = 0.99,
eps_start: float = 1.0,
eps_end: float = 0.01,
eps_last_frame: int = 200,
sync_rate: int = 10,
lr: float = 1e-2,
episode_length: int = 50,
batch_size: int = 4,
**kwargs,
self,
env: str,
replay_size: int = 200,
warm_start_steps: int = 200,
gamma: float = 0.99,
eps_start: float = 1.0,
eps_end: float = 0.01,
eps_last_frame: int = 200,
sync_rate: int = 10,
lr: float = 1e-2,
episode_length: int = 50,
batch_size: int = 4,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.replay_size = replay_size
@ -353,9 +357,11 @@ class DQNLightning(pl.LightningModule):
if self.global_step % self.sync_rate == 0:
self.target_net.load_state_dict(self.net.state_dict())
log = {'total_reward': torch.tensor(self.total_reward).to(device),
'reward': torch.tensor(reward).to(device),
'steps': torch.tensor(self.global_step).to(device)}
log = {
'total_reward': torch.tensor(self.total_reward).to(device),
'reward': torch.tensor(reward).to(device),
'steps': torch.tensor(self.global_step).to(device)
}
return OrderedDict({'loss': loss, 'log': log, 'progress_bar': log})
@ -389,21 +395,20 @@ class DQNLightning(pl.LightningModule):
parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
parser.add_argument("--env", type=str, default="CartPole-v0", help="gym environment tag")
parser.add_argument("--gamma", type=float, default=0.99, help="discount factor")
parser.add_argument("--sync_rate", type=int, default=10,
help="how many frames do we update the target network")
parser.add_argument("--replay_size", type=int, default=1000,
help="capacity of the replay buffer")
parser.add_argument("--warm_start_size", type=int, default=1000,
help="how many samples do we use to fill our buffer at the start of training")
parser.add_argument("--eps_last_frame", type=int, default=1000,
help="what frame should epsilon stop decaying")
parser.add_argument("--sync_rate", type=int, default=10, help="how many frames do we update the target network")
parser.add_argument("--replay_size", type=int, default=1000, help="capacity of the replay buffer")
parser.add_argument(
"--warm_start_size",
type=int,
default=1000,
help="how many samples do we use to fill our buffer at the start of training"
)
parser.add_argument("--eps_last_frame", type=int, default=1000, help="what frame should epsilon stop decaying")
parser.add_argument("--eps_start", type=float, default=1.0, help="starting value of epsilon")
parser.add_argument("--eps_end", type=float, default=0.01, help="final value of epsilon")
parser.add_argument("--episode_length", type=int, default=200, help="max length of an episode")
parser.add_argument("--max_episode_reward", type=int, default=200,
help="max episode reward in the environment")
parser.add_argument("--warm_start_steps", type=int, default=1000,
help="max episode reward in the environment")
parser.add_argument("--max_episode_reward", type=int, default=200, help="max episode reward in the environment")
parser.add_argument("--warm_start_steps", type=int, default=1000, help="max episode reward in the environment")
return parser
@ -413,7 +418,7 @@ def main(args) -> None:
trainer = pl.Trainer(
gpus=1,
accelerator='dp',
val_check_interval=100
val_check_interval=100,
)
trainer.fit(model)

View File

@ -50,7 +50,7 @@ def create_mlp(input_shape: Tuple[int], n_actions: int, hidden_size: int = 128):
nn.ReLU(),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, n_actions)
nn.Linear(hidden_size, n_actions),
)
return network
@ -159,6 +159,7 @@ class PPOLightning(pl.LightningModule):
trainer = Trainer()
trainer.fit(model)
"""
def __init__(
self,
env: str,
@ -173,7 +174,6 @@ class PPOLightning(pl.LightningModule):
clip_ratio: float = 0.2,
**kwargs,
) -> None:
"""
Args:
env: gym environment tag
@ -213,8 +213,10 @@ class PPOLightning(pl.LightningModule):
actor_mlp = create_mlp(self.env.observation_space.shape, self.env.action_space.n)
self.actor = ActorCategorical(actor_mlp)
else:
raise NotImplementedError('Env action space should be of type Box (continuous) or Discrete (categorical). '
f'Got type: {type(self.env.action_space)}')
raise NotImplementedError(
'Env action space should be of type Box (continuous) or Discrete (categorical).'
f' Got type: {type(self.env.action_space)}'
)
self.batch_states = []
self.batch_actions = []
@ -287,9 +289,7 @@ class PPOLightning(pl.LightningModule):
return adv
def generate_trajectory_samples(
self,
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
def generate_trajectory_samples(self) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
"""
Contains the logic for generating trajectory data to train policy and value network
Yield:
@ -345,8 +345,8 @@ class PPOLightning(pl.LightningModule):
if epoch_end:
train_data = zip(
self.batch_states, self.batch_actions, self.batch_logp,
self.batch_qvals, self.batch_adv)
self.batch_states, self.batch_actions, self.batch_logp, self.batch_qvals, self.batch_adv
)
for state, action, logp_old, qval, adv in train_data:
yield state, action, logp_old, qval, adv
@ -454,12 +454,18 @@ class PPOLightning(pl.LightningModule):
parser.add_argument("--lr_critic", type=float, default=1e-3, help="learning rate of critic network")
parser.add_argument("--max_episode_len", type=int, default=1000, help="capacity of the replay buffer")
parser.add_argument("--batch_size", type=int, default=512, help="batch_size when training network")
parser.add_argument("--steps_per_epoch", type=int, default=2048,
help="how many action-state pairs to rollout for trajectory collection per epoch")
parser.add_argument("--nb_optim_iters", type=int, default=4,
help="how many steps of gradient descent to perform on each batch")
parser.add_argument("--clip_ratio", type=float, default=0.2,
help="hyperparameter for clipping in the policy objective")
parser.add_argument(
"--steps_per_epoch",
type=int,
default=2048,
help="how many action-state pairs to rollout for trajectory collection per epoch"
)
parser.add_argument(
"--nb_optim_iters", type=int, default=4, help="how many steps of gradient descent to perform on each batch"
)
parser.add_argument(
"--clip_ratio", type=float, default=0.2, help="hyperparameter for clipping in the policy objective"
)
return parser

View File

@ -178,15 +178,16 @@ class SegModel(pl.LightningModule):
)
)
"""
def __init__(
self,
data_path: str,
batch_size: int = 4,
lr: float = 1e-3,
num_layers: int = 3,
features_start: int = 64,
bilinear: bool = False,
**kwargs,
self,
data_path: str,
batch_size: int = 4,
lr: float = 1e-3,
num_layers: int = 3,
features_start: int = 64,
bilinear: bool = False,
**kwargs,
):
super().__init__(**kwargs)
self.data_path = data_path
@ -196,12 +197,12 @@ class SegModel(pl.LightningModule):
self.features_start = features_start
self.bilinear = bilinear
self.net = UNet(num_classes=19, num_layers=self.num_layers,
features_start=self.features_start, bilinear=self.bilinear)
self.net = UNet(
num_classes=19, num_layers=self.num_layers, features_start=self.features_start, bilinear=self.bilinear
)
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753],
std=[0.32064945, 0.32098866, 0.32325324])
transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], std=[0.32064945, 0.32098866, 0.32325324])
])
self.trainset = KITTI(self.data_path, split='train', transform=self.transform)
self.validset = KITTI(self.data_path, split='valid', transform=self.transform)
@ -250,8 +251,12 @@ class SegModel(pl.LightningModule):
parser.add_argument("--lr", type=float, default=0.001, help="adam: learning rate")
parser.add_argument("--num_layers", type=int, default=5, help="number of layers on u-net")
parser.add_argument("--features_start", type=float, default=64, help="number of features in first layer")
parser.add_argument("--bilinear", action='store_true', default=False,
help="whether to use bilinear interpolation or transposed")
parser.add_argument(
"--bilinear",
action='store_true',
default=False,
help="whether to use bilinear interpolation or transposed"
)
return parser

View File

@ -36,11 +36,11 @@ class UNet(nn.Module):
"""
def __init__(
self,
num_classes: int = 19,
num_layers: int = 5,
features_start: int = 64,
bilinear: bool = False,
self,
num_classes: int = 19,
num_layers: int = 5,
features_start: int = 64,
bilinear: bool = False,
):
"""
Args:
@ -97,7 +97,7 @@ class DoubleConv(nn.Module):
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True)
nn.ReLU(inplace=True),
)
def forward(self, x):
@ -121,10 +121,7 @@ class Down(nn.Module):
def __init__(self, in_ch: int, out_ch: int):
super().__init__()
self.net = nn.Sequential(
nn.MaxPool2d(kernel_size=2, stride=2),
DoubleConv(in_ch, out_ch)
)
self.net = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2), DoubleConv(in_ch, out_ch))
def forward(self, x):
return self.net(x)

View File

@ -53,11 +53,13 @@ ARGS_DDP_AMP = ARGS_DEFAULT + """
"""
@pytest.mark.parametrize('import_cli', [
'pl_examples.basic_examples.simple_image_classifier',
'pl_examples.basic_examples.backbone_image_classifier',
'pl_examples.basic_examples.autoencoder',
])
@pytest.mark.parametrize(
'import_cli', [
'pl_examples.basic_examples.simple_image_classifier',
'pl_examples.basic_examples.backbone_image_classifier',
'pl_examples.basic_examples.autoencoder',
]
)
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.parametrize('cli_args', [ARGS_DP, ARGS_DP_AMP])
def test_examples_dp(tmpdir, import_cli, cli_args):
@ -88,11 +90,13 @@ def test_examples_dp(tmpdir, import_cli, cli_args):
# module.cli_main()
@pytest.mark.parametrize('import_cli', [
'pl_examples.basic_examples.simple_image_classifier',
'pl_examples.basic_examples.backbone_image_classifier',
'pl_examples.basic_examples.autoencoder',
])
@pytest.mark.parametrize(
'import_cli', [
'pl_examples.basic_examples.simple_image_classifier',
'pl_examples.basic_examples.backbone_image_classifier',
'pl_examples.basic_examples.autoencoder',
]
)
@pytest.mark.parametrize('cli_args', [ARGS_DEFAULT])
def test_examples_cpu(tmpdir, import_cli, cli_args):