# Source: lightning/tests/models/data/ddp/train_test_variations.py (Python; 45 lines, 1006 B as originally published)

"""
Runs either `.fit()` or `.test()` on a single node across multiple gpus.
"""
from argparse import ArgumentParser
from pytorch_lightning import Trainer, seed_everything
from tests.base import EvalModelTemplate
def variation_fit(trainer, model):
    """Run the "fit" variation by calling ``trainer.fit(model)``.

    Args:
        trainer: object exposing a ``fit`` method (``main`` passes a ``Trainer``).
        model: the model forwarded unchanged to ``trainer.fit``.

    Returns:
        None. Whatever ``trainer.fit`` returns is discarded.
    """
    trainer.fit(model)
def variation_test(trainer, model):
    """Run the "test" variation by calling ``trainer.test(model)``.

    Args:
        trainer: object exposing a ``test`` method (``main`` passes a ``Trainer``).
        model: the model forwarded unchanged to ``trainer.test``.

    Returns:
        None. Whatever ``trainer.test`` returns is discarded.
    """
    trainer.test(model)
def get_variations():
    """Return the names of the variations this script can run.

    Each entry is the name of a module-level function taking
    ``(trainer, model)``; ``main`` resolves the name through ``globals()``.

    Returns:
        A new list of variation function names on every call.
    """
    return [
        "variation_fit",
        "variation_test",
    ]
def main():
    """Parse the command line, build a ``Trainer`` and run one variation.

    Seeds everything with 1234, extends the parser with the ``Trainer``
    arguments plus ``--variation``, and sets parser defaults of ``gpus=2``
    and ``distributed_backend="ddp"``. The variation defaults to
    ``variation_fit``.
    """
    seed_everything(1234)

    parser = ArgumentParser(add_help=False)
    parser = Trainer.add_argparse_args(parser)
    # The value is used below to look up a callable in globals(), so only
    # the known variation names are accepted; anything else is rejected by
    # argparse with a clear usage error instead of a KeyError (or a call to
    # some unrelated module-level object).
    parser.add_argument(
        '--variation',
        default=variation_fit.__name__,
        choices=get_variations(),
    )
    parser.set_defaults(gpus=2, distributed_backend="ddp")
    args = parser.parse_args()

    model = EvalModelTemplate()
    trainer = Trainer.from_argparse_args(args)

    # run the chosen variation
    run_variation = globals()[args.variation]
    run_variation(trainer, model)
# Script entry point: only run when executed directly, not on import.
if __name__ == '__main__':
    main()