diff --git a/pl_examples/basic_examples/cpu_template.py b/pl_examples/basic_examples/cpu_template.py index 0714b1aac0..7781ba3271 100644 --- a/pl_examples/basic_examples/cpu_template.py +++ b/pl_examples/basic_examples/cpu_template.py @@ -28,7 +28,7 @@ def main(hparams): # ------------------------ # 2 INIT TRAINER # ------------------------ - trainer = pl.Trainer() + trainer = pl.Trainer(max_epochs=hparams.epochs) # ------------------------ # 3 START TRAINING diff --git a/pl_examples/basic_examples/gpu_template.py b/pl_examples/basic_examples/gpu_template.py index c661eef65f..090b2adcba 100644 --- a/pl_examples/basic_examples/gpu_template.py +++ b/pl_examples/basic_examples/gpu_template.py @@ -29,6 +29,7 @@ def main(hparams): # 2 INIT TRAINER # ------------------------ trainer = pl.Trainer( + max_epochs=hparams.epochs, gpus=hparams.gpus, distributed_backend=hparams.distributed_backend, use_amp=hparams.use_16bit