Skepticleo trainer argparser (#1023)

* Added default parser for trainer and class method to construct trainer from default args

* Removed print statement

* Added test for constructing Trainer from command line args

* Removed extra line

* Removed redundant imports, removed whitespace from empty lines

* Fixed typo

* Updated default parser creation to get class attributes automatically

* Updated default parser creation to get class attributes automatically

* Added method to get default args for trainer

* Trimmed trainer get default args method

* Updated from argparse method to not return trainer with static arguments

* Update trainer get default args to classmethod

* adjustment

* fix

* Fixed variable name

* Update trainer.py

* Update test_trainer.py

* Update trainer.py

* Update tests/trainer/test_trainer.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update trainer.py

* Update test_trainer.py

* Update trainer.py

* Update test_trainer.py

* Update tests/trainer/test_trainer.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/trainer/trainer.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update trainer.py

* Update test_trainer.py

Co-authored-by: Mudit Tanwani <mudittanwani@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
William Falcon 2020-03-03 09:32:15 -05:00 committed by GitHub
parent 05676de2d9
commit 4c5e82c065
3 changed files with 50 additions and 4 deletions

pytorch_lightning/trainer/trainer.py

@@ -3,6 +3,7 @@ import sys
 import warnings
 import logging as log
 from typing import Union, Optional, List, Dict, Tuple, Iterable
+from argparse import ArgumentParser
 import torch
 import torch.distributed as dist
@@ -116,6 +117,7 @@ class Trainer(TrainerIOMixin,
                  profiler: Optional[BaseProfiler] = None,
                  benchmark: bool = False,
                  reload_dataloaders_every_epoch: bool = False,
+                 **kwargs
                  ):
         r"""
@@ -627,6 +629,7 @@ class Trainer(TrainerIOMixin,
         # Transfer params
         # Backward compatibility
+        self.num_nodes = num_nodes
         if nb_gpu_nodes is not None:
             warnings.warn("`nb_gpu_nodes` has been renamed to `num_nodes` since v0.5.0"
                           " and this argument will be removed in v0.8.0", DeprecationWarning)
@@ -747,10 +750,12 @@
         self.weights_save_path = weights_save_path
 
         # accumulated grads
+        self.accumulate_grad_batches = accumulate_grad_batches
         self.configure_accumulated_gradients(accumulate_grad_batches)
 
         # allow int, string and gpu list
-        self.data_parallel_device_ids = parse_gpu_ids(gpus)
+        self.gpus = gpus
+        self.data_parallel_device_ids = parse_gpu_ids(self.gpus)
         self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids)
 
         # tpu state flags
@@ -797,6 +802,7 @@
         self.row_log_interval = row_log_interval
 
         # how much of the data to use
+        self.overfit_pct = overfit_pct
         self.determine_data_use_amount(train_percent_check, val_percent_check,
                                        test_percent_check, overfit_pct)
@@ -822,6 +828,28 @@
             job_id = None
         return job_id
 
+    @classmethod
+    def default_attributes(cls):
+        return vars(cls())
+
+    @classmethod
+    def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
+        """Extend an existing argparse parser with the default `Trainer` attributes."""
+        parser = ArgumentParser(parents=[parent_parser])
+
+        trainer_default_params = Trainer.default_attributes()
+
+        for arg in trainer_default_params:
+            parser.add_argument('--{0}'.format(arg), default=trainer_default_params[arg], dest=arg)
+
+        return parser
+
+    @classmethod
+    def from_argparse_args(cls, args):
+        params = vars(args)
+        return cls(**params)
+
     def __parse_gpu_ids(self, gpus):
         """Parse GPU ids.

pytorch_lightning/trainer/training_tricks.py

@@ -52,8 +52,6 @@ class TrainerTrainingTricksMixin(ABC):
             log.info(param, param.grad)
 
     def configure_accumulated_gradients(self, accumulate_grad_batches):
-        self.accumulate_grad_batches = None
-
         if isinstance(accumulate_grad_batches, dict):
             self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
         elif isinstance(accumulate_grad_batches, int):
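
For reference, the two argument forms the branches above accept (the values are illustrative):

    # int: accumulate gradients over 4 batches from the first epoch on
    trainer = Trainer(accumulate_grad_batches=4)
    # dict: change the accumulation factor at the given epochs, handled by
    # GradientAccumulationScheduler
    trainer = Trainer(accumulate_grad_batches={5: 2, 10: 4})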

tests/trainer/test_trainer.py

@@ -1,10 +1,11 @@
 import math
 import os
 import pytest
 import torch
+import argparse
 import tests.models.utils as tutils
+from unittest import mock
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import (
     EarlyStopping,
@@ -600,3 +601,22 @@ def test_testpass_overrides(tmpdir):
     model = LightningTestModel(hparams)
 
     Trainer().test(model)
+
+
+@mock.patch('argparse.ArgumentParser.parse_args',
+            return_value=argparse.Namespace(**Trainer.default_attributes()))
+def test_default_args(mock_argparse, tmpdir):
+    """Tests the default argument parser for Trainer."""
+    tutils.reset_seed()
+
+    # logger file to get meta
+    logger = tutils.get_test_tube_logger(tmpdir, False)
+
+    parser = argparse.ArgumentParser(add_help=False)
+    args = parser.parse_args()
+    args.logger = logger
+    args.max_epochs = 5
+
+    trainer = Trainer.from_argparse_args(args)
+
+    assert isinstance(trainer, Trainer)
+    assert trainer.max_epochs == 5
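
For context, a standalone sketch of what the `mock.patch` decorator above does: every `parse_args()` call inside the test returns a `Namespace` pre-filled with the Trainer defaults, so the test never reads pytest's own command line:

    import argparse
    from unittest import mock
    from pytorch_lightning import Trainer

    defaults = argparse.Namespace(**Trainer.default_attributes())
    with mock.patch('argparse.ArgumentParser.parse_args', return_value=defaults):
        args = argparse.ArgumentParser(add_help=False).parse_args()

    assert args is defaults  # the canned Namespace is returned unchanged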