diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index d59884d2e0..9b451e8ef8 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -3,6 +3,7 @@ import sys
 import warnings
 import logging as log
 from typing import Union, Optional, List, Dict, Tuple, Iterable
+from argparse import ArgumentParser
 
 import torch
 import torch.distributed as dist
@@ -116,6 +117,7 @@ class Trainer(TrainerIOMixin,
                  profiler: Optional[BaseProfiler] = None,
                  benchmark: bool = False,
                  reload_dataloaders_every_epoch: bool = False,
+                 **kwargs
                  ):
         r"""
 
@@ -627,6 +629,7 @@ class Trainer(TrainerIOMixin,
 
         # Transfer params
         # Backward compatibility
+        self.num_nodes = num_nodes
         if nb_gpu_nodes is not None:
             warnings.warn("`nb_gpu_nodes` has renamed to `num_nodes` since v0.5.0"
                           " and this method will be removed in v0.8.0", DeprecationWarning)
@@ -747,10 +750,12 @@ class Trainer(TrainerIOMixin,
         self.weights_save_path = weights_save_path
 
         # accumulated grads
+        self.accumulate_grad_batches = accumulate_grad_batches
         self.configure_accumulated_gradients(accumulate_grad_batches)
 
         # allow int, string and gpu list
-        self.data_parallel_device_ids = parse_gpu_ids(gpus)
+        self.gpus = gpus
+        self.data_parallel_device_ids = parse_gpu_ids(self.gpus)
         self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids)
 
         # tpu state flags
@@ -797,6 +802,7 @@ class Trainer(TrainerIOMixin,
         self.row_log_interval = row_log_interval
 
         # how much of the data to use
+        self.overfit_pct = overfit_pct
         self.determine_data_use_amount(train_percent_check, val_percent_check,
                                        test_percent_check, overfit_pct)
 
@@ -822,6 +828,28 @@ class Trainer(TrainerIOMixin,
             job_id = None
         return job_id
 
+    @classmethod
+    def default_attributes(cls):
+        return vars(cls())
+
+    @classmethod
+    def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
+        """Extend existing argparse by default `Trainer` attributes."""
+        parser = ArgumentParser(parents=[parent_parser])
+
+        trainer_default_params = Trainer.default_attributes()
+
+        for arg in trainer_default_params:
+            parser.add_argument('--{0}'.format(arg), default=trainer_default_params[arg], dest=arg)
+
+        return parser
+
+    @classmethod
+    def from_argparse_args(cls, args):
+
+        params = vars(args)
+        return cls(**params)
+
     def __parse_gpu_ids(self, gpus):
         """Parse GPUs id.
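For context, a minimal usage sketch of the two new classmethods added above. The script layout, the `MyModel` module, and the `--learning_rate` flag are placeholders for illustration, not part of this patch:

```python
# hypothetical train.py -- exercises Trainer.add_argparse_args / from_argparse_args
from argparse import ArgumentParser

from pytorch_lightning import Trainer

from my_project import MyModel  # placeholder LightningModule


def main():
    # add_help=False avoids a duplicate -h/--help when this parser is re-used
    # as a parent inside add_argparse_args (which builds a child parser).
    parser = ArgumentParser(add_help=False)
    parser.add_argument('--learning_rate', type=float, default=0.02)

    # every attribute of a default Trainer instance becomes a --flag
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    model = MyModel(args)
    # keys that are not Trainer arguments (e.g. learning_rate) are absorbed by **kwargs
    trainer = Trainer.from_argparse_args(args)
    trainer.fit(model)


if __name__ == '__main__':
    main()
```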
diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py
index 373de37058..c468e1ba61 100644
--- a/pytorch_lightning/trainer/training_tricks.py
+++ b/pytorch_lightning/trainer/training_tricks.py
@@ -52,8 +52,6 @@ class TrainerTrainingTricksMixin(ABC):
         log.info(param, param.grad)
 
     def configure_accumulated_gradients(self, accumulate_grad_batches):
-        self.accumulate_grad_batches = None
-
         if isinstance(accumulate_grad_batches, dict):
             self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
         elif isinstance(accumulate_grad_batches, int):
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 4c16c92129..d95fce3e5c 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -1,10 +1,11 @@
 import math
 import os
-
 import pytest
 import torch
+import argparse
 
 import tests.models.utils as tutils
+from unittest import mock
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import (
     EarlyStopping,
@@ -600,3 +601,22 @@ def test_testpass_overrides(tmpdir):
     model = LightningTestModel(hparams)
 
     Trainer().test(model)
+
+@mock.patch('argparse.ArgumentParser.parse_args',
+            return_value=argparse.Namespace(**Trainer.default_attributes()))
+def test_default_args(mock_argparse, tmpdir):
+    """Tests default argument parser for Trainer"""
+    tutils.reset_seed()
+
+    # logger file to get meta
+    logger = tutils.get_test_tube_logger(tmpdir, False)
+
+    parser = argparse.ArgumentParser(add_help=False)
+    args = parser.parse_args()
+    args.logger = logger
+
+    args.max_epochs = 5
+    trainer = Trainer.from_argparse_args(args)
+
+    assert isinstance(trainer, Trainer)
+    assert trainer.max_epochs == 5
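A related sketch (not part of the patch) that drives the same round-trip as the test above without mocking, by passing an explicit argv list. Note that because `add_argparse_args` registers each flag without a `type=`, values supplied on the command line arrive as strings and may need casting by the caller:

```python
# sketch: exercise the argparse round-trip with an explicit argv list
import argparse

from pytorch_lightning import Trainer

parser = Trainer.add_argparse_args(argparse.ArgumentParser(add_help=False))
args = parser.parse_args(['--max_epochs', '5'])

# argparse returns the override as a string ('5'); untouched defaults keep their types
args.max_epochs = int(args.max_epochs)

trainer = Trainer.from_argparse_args(args)
assert isinstance(trainer, Trainer)
assert trainer.max_epochs == 5
```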