From 0914873bc2f588ae592c896460c0c9e6577324cf Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Mon, 1 Jun 2020 21:08:52 +0530 Subject: [PATCH] Fix domain_template scripts (#2014) * Fix domain_templates * Fix type of fake labels * type * args --- .../computer_vision_fine_tuning.py | 1 - .../generative_adversarial_net.py | 4 +-- pl_examples/domain_templates/imagenet.py | 31 ++++++++++--------- 3 files changed, 19 insertions(+), 17 deletions(-) diff --git a/pl_examples/domain_templates/computer_vision_fine_tuning.py b/pl_examples/domain_templates/computer_vision_fine_tuning.py index 4371c86945..e2db1b98fd 100644 --- a/pl_examples/domain_templates/computer_vision_fine_tuning.py +++ b/pl_examples/domain_templates/computer_vision_fine_tuning.py @@ -450,5 +450,4 @@ def get_args() -> argparse.Namespace: if __name__ == '__main__': - main(get_args()) diff --git a/pl_examples/domain_templates/generative_adversarial_net.py b/pl_examples/domain_templates/generative_adversarial_net.py index 68197e32b8..4417e5e02c 100644 --- a/pl_examples/domain_templates/generative_adversarial_net.py +++ b/pl_examples/domain_templates/generative_adversarial_net.py @@ -7,7 +7,7 @@ After a few epochs, launch TensorBoard to see the images being generated at ever tensorboard --logdir default """ import os -from argparse import ArgumentParser +from argparse import ArgumentParser, Namespace from collections import OrderedDict import numpy as np @@ -183,7 +183,7 @@ class GAN(LightningModule): self.logger.experiment.add_image('generated_images', grid, self.current_epoch) -def main(args): +def main(args: Namespace) -> None: # ------------------------ # 1 INIT LIGHTNING MODEL # ------------------------ diff --git a/pl_examples/domain_templates/imagenet.py b/pl_examples/domain_templates/imagenet.py index e6584d7655..19a85b8794 100644 --- a/pl_examples/domain_templates/imagenet.py +++ b/pl_examples/domain_templates/imagenet.py @@ -1,7 +1,7 @@ """ This example is largely adapted from https://github.com/pytorch/examples/blob/master/imagenet/main.py """ -import argparse +from argparse import ArgumentParser, Namespace import os import random from collections import OrderedDict @@ -183,7 +183,7 @@ class ImageNetLightningModel(LightningModule): @staticmethod def add_model_specific_args(parent_parser): # pragma: no-cover - parser = argparse.ArgumentParser(parents=[parent_parser]) + parser = ArgumentParser(parents=[parent_parser]) parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', choices=MODEL_NAMES, help='model architecture: ' + ' | '.join(MODEL_NAMES) + @@ -210,7 +210,7 @@ class ImageNetLightningModel(LightningModule): def get_args(): - parent_parser = argparse.ArgumentParser(add_help=False) + parent_parser = ArgumentParser(add_help=False) parent_parser.add_argument('--data-path', metavar='DIR', type=str, help='path to dataset') parent_parser.add_argument('--save-path', metavar='DIR', default=".", type=str, @@ -228,20 +228,23 @@ def get_args(): return parser.parse_args() -def main(hparams): - model = ImageNetLightningModel(hparams) - if hparams.seed is not None: - random.seed(hparams.seed) - torch.manual_seed(hparams.seed) +def main(args: Namespace) -> None: + model = ImageNetLightningModel(**vars(args)) + + if args.seed is not None: + random.seed(args.seed) + torch.manual_seed(args.seed) cudnn.deterministic = True + trainer = pl.Trainer( - default_root_dir=hparams.save_path, - gpus=hparams.gpus, - max_epochs=hparams.epochs, - distributed_backend=hparams.distributed_backend, - precision=16 if hparams.use_16bit else 32, + default_root_dir=args.save_path, + gpus=args.gpus, + max_epochs=args.epochs, + distributed_backend=args.distributed_backend, + precision=16 if args.use_16bit else 32, ) - if hparams.evaluate: + + if args.evaluate: trainer.run_evaluation() else: trainer.fit(model)