Fix domain_template scripts (#2014)
* Fix domain_templates * Fix type of fake labels * type * args
This commit is contained in:
parent
82a20296e3
commit
0914873bc2
|
@ -450,5 +450,4 @@ def get_args() -> argparse.Namespace:
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
main(get_args())
|
||||
|
|
|
@ -7,7 +7,7 @@ After a few epochs, launch TensorBoard to see the images being generated at ever
|
|||
tensorboard --logdir default
|
||||
"""
|
||||
import os
|
||||
from argparse import ArgumentParser
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from collections import OrderedDict
|
||||
|
||||
import numpy as np
|
||||
|
@ -183,7 +183,7 @@ class GAN(LightningModule):
|
|||
self.logger.experiment.add_image('generated_images', grid, self.current_epoch)
|
||||
|
||||
|
||||
def main(args):
|
||||
def main(args: Namespace) -> None:
|
||||
# ------------------------
|
||||
# 1 INIT LIGHTNING MODEL
|
||||
# ------------------------
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
"""
|
||||
This example is largely adapted from https://github.com/pytorch/examples/blob/master/imagenet/main.py
|
||||
"""
|
||||
import argparse
|
||||
from argparse import ArgumentParser, Namespace
|
||||
import os
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
|
@ -183,7 +183,7 @@ class ImageNetLightningModel(LightningModule):
|
|||
|
||||
@staticmethod
|
||||
def add_model_specific_args(parent_parser): # pragma: no-cover
|
||||
parser = argparse.ArgumentParser(parents=[parent_parser])
|
||||
parser = ArgumentParser(parents=[parent_parser])
|
||||
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', choices=MODEL_NAMES,
|
||||
help='model architecture: ' +
|
||||
' | '.join(MODEL_NAMES) +
|
||||
|
@ -210,7 +210,7 @@ class ImageNetLightningModel(LightningModule):
|
|||
|
||||
|
||||
def get_args():
|
||||
parent_parser = argparse.ArgumentParser(add_help=False)
|
||||
parent_parser = ArgumentParser(add_help=False)
|
||||
parent_parser.add_argument('--data-path', metavar='DIR', type=str,
|
||||
help='path to dataset')
|
||||
parent_parser.add_argument('--save-path', metavar='DIR', default=".", type=str,
|
||||
|
@ -228,20 +228,23 @@ def get_args():
|
|||
return parser.parse_args()
|
||||
|
||||
|
||||
def main(hparams):
|
||||
model = ImageNetLightningModel(hparams)
|
||||
if hparams.seed is not None:
|
||||
random.seed(hparams.seed)
|
||||
torch.manual_seed(hparams.seed)
|
||||
def main(args: Namespace) -> None:
|
||||
model = ImageNetLightningModel(**vars(args))
|
||||
|
||||
if args.seed is not None:
|
||||
random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
cudnn.deterministic = True
|
||||
|
||||
trainer = pl.Trainer(
|
||||
default_root_dir=hparams.save_path,
|
||||
gpus=hparams.gpus,
|
||||
max_epochs=hparams.epochs,
|
||||
distributed_backend=hparams.distributed_backend,
|
||||
precision=16 if hparams.use_16bit else 32,
|
||||
default_root_dir=args.save_path,
|
||||
gpus=args.gpus,
|
||||
max_epochs=args.epochs,
|
||||
distributed_backend=args.distributed_backend,
|
||||
precision=16 if args.use_16bit else 32,
|
||||
)
|
||||
if hparams.evaluate:
|
||||
|
||||
if args.evaluate:
|
||||
trainer.run_evaluation()
|
||||
else:
|
||||
trainer.fit(model)
|
||||
|
|
Loading…
Reference in New Issue