Fix domain_template scripts (#2014)

* Fix domain_templates

* Fix type of fake labels

* type

* args
This commit is contained in:
Rohit Gupta 2020-06-01 21:08:52 +05:30 committed by GitHub
parent 82a20296e3
commit 0914873bc2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 19 additions and 17 deletions

View File

@ -450,5 +450,4 @@ def get_args() -> argparse.Namespace:
if __name__ == '__main__':
main(get_args())

View File

@ -7,7 +7,7 @@ After a few epochs, launch TensorBoard to see the images being generated at ever
tensorboard --logdir default
"""
import os
from argparse import ArgumentParser
from argparse import ArgumentParser, Namespace
from collections import OrderedDict
import numpy as np
@ -183,7 +183,7 @@ class GAN(LightningModule):
self.logger.experiment.add_image('generated_images', grid, self.current_epoch)
def main(args):
def main(args: Namespace) -> None:
# ------------------------
# 1 INIT LIGHTNING MODEL
# ------------------------

View File

@ -1,7 +1,7 @@
"""
This example is largely adapted from https://github.com/pytorch/examples/blob/master/imagenet/main.py
"""
import argparse
from argparse import ArgumentParser, Namespace
import os
import random
from collections import OrderedDict
@ -183,7 +183,7 @@ class ImageNetLightningModel(LightningModule):
@staticmethod
def add_model_specific_args(parent_parser): # pragma: no-cover
parser = argparse.ArgumentParser(parents=[parent_parser])
parser = ArgumentParser(parents=[parent_parser])
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', choices=MODEL_NAMES,
help='model architecture: ' +
' | '.join(MODEL_NAMES) +
@ -210,7 +210,7 @@ class ImageNetLightningModel(LightningModule):
def get_args():
parent_parser = argparse.ArgumentParser(add_help=False)
parent_parser = ArgumentParser(add_help=False)
parent_parser.add_argument('--data-path', metavar='DIR', type=str,
help='path to dataset')
parent_parser.add_argument('--save-path', metavar='DIR', default=".", type=str,
@ -228,20 +228,23 @@ def get_args():
return parser.parse_args()
def main(hparams):
model = ImageNetLightningModel(hparams)
if hparams.seed is not None:
random.seed(hparams.seed)
torch.manual_seed(hparams.seed)
def main(args: Namespace) -> None:
model = ImageNetLightningModel(**vars(args))
if args.seed is not None:
random.seed(args.seed)
torch.manual_seed(args.seed)
cudnn.deterministic = True
trainer = pl.Trainer(
default_root_dir=hparams.save_path,
gpus=hparams.gpus,
max_epochs=hparams.epochs,
distributed_backend=hparams.distributed_backend,
precision=16 if hparams.use_16bit else 32,
default_root_dir=args.save_path,
gpus=args.gpus,
max_epochs=args.epochs,
distributed_backend=args.distributed_backend,
precision=16 if args.use_16bit else 32,
)
if hparams.evaluate:
if args.evaluate:
trainer.run_evaluation()
else:
trainer.fit(model)