Skepticleo trainer argparser (#1023)
* Added default parser for trainer and class method to construct trainer from default args * Removed print statement * Added test for constructing Trainer from command line args * Removed extra line * Removed redundant imports, removed whitespace from empty lines * Fixed typo * Updated default parser creation to get class attributes automatically * Updated default parser creation to get class attributes automatically * Added method to get default args for trainer * Trimmed trainer get default args method * Updated from argparse method to not return trainer with static arguments * Update trainer get default args to classmethod * adjustment * fix * Fixed variable name * Update trainer.py * Update test_trainer.py * Update trainer.py * Update tests/trainer/test_trainer.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update trainer.py * Update test_trainer.py * Update trainer.py * Update test_trainer.py * Update tests/trainer/test_trainer.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/trainer/trainer.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update trainer.py * Update test_trainer.py Co-authored-by: Mudit Tanwani <mudittanwani@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
parent
05676de2d9
commit
4c5e82c065
|
@ -3,6 +3,7 @@ import sys
|
|||
import warnings
|
||||
import logging as log
|
||||
from typing import Union, Optional, List, Dict, Tuple, Iterable
|
||||
from argparse import ArgumentParser
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -116,6 +117,7 @@ class Trainer(TrainerIOMixin,
|
|||
profiler: Optional[BaseProfiler] = None,
|
||||
benchmark: bool = False,
|
||||
reload_dataloaders_every_epoch: bool = False,
|
||||
**kwargs
|
||||
):
|
||||
r"""
|
||||
|
||||
|
@ -627,6 +629,7 @@ class Trainer(TrainerIOMixin,
|
|||
|
||||
# Transfer params
|
||||
# Backward compatibility
|
||||
self.num_nodes = num_nodes
|
||||
if nb_gpu_nodes is not None:
|
||||
warnings.warn("`nb_gpu_nodes` has renamed to `num_nodes` since v0.5.0"
|
||||
" and this method will be removed in v0.8.0", DeprecationWarning)
|
||||
|
@ -747,10 +750,12 @@ class Trainer(TrainerIOMixin,
|
|||
self.weights_save_path = weights_save_path
|
||||
|
||||
# accumulated grads
|
||||
self.accumulate_grad_batches = accumulate_grad_batches
|
||||
self.configure_accumulated_gradients(accumulate_grad_batches)
|
||||
|
||||
# allow int, string and gpu list
|
||||
self.data_parallel_device_ids = parse_gpu_ids(gpus)
|
||||
self.gpus = gpus
|
||||
self.data_parallel_device_ids = parse_gpu_ids(self.gpus)
|
||||
self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids)
|
||||
|
||||
# tpu state flags
|
||||
|
@ -797,6 +802,7 @@ class Trainer(TrainerIOMixin,
|
|||
self.row_log_interval = row_log_interval
|
||||
|
||||
# how much of the data to use
|
||||
self.overfit_pct = overfit_pct
|
||||
self.determine_data_use_amount(train_percent_check, val_percent_check,
|
||||
test_percent_check, overfit_pct)
|
||||
|
||||
|
@ -822,6 +828,28 @@ class Trainer(TrainerIOMixin,
|
|||
job_id = None
|
||||
return job_id
|
||||
|
||||
@classmethod
|
||||
def default_attributes(cls):
|
||||
return vars(cls())
|
||||
|
||||
@classmethod
|
||||
def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
|
||||
"""Extend existing argparse by default `Trainer` attributes."""
|
||||
parser = ArgumentParser(parents=[parent_parser])
|
||||
|
||||
trainer_default_params = Trainer.default_attributes()
|
||||
|
||||
for arg in trainer_default_params:
|
||||
parser.add_argument('--{0}'.format(arg), default=trainer_default_params[arg], dest=arg)
|
||||
|
||||
return parser
|
||||
|
||||
@classmethod
|
||||
def from_argparse_args(cls, args):
|
||||
|
||||
params = vars(args)
|
||||
return cls(**params)
|
||||
|
||||
def __parse_gpu_ids(self, gpus):
|
||||
"""Parse GPUs id.
|
||||
|
||||
|
|
|
@ -52,8 +52,6 @@ class TrainerTrainingTricksMixin(ABC):
|
|||
log.info(param, param.grad)
|
||||
|
||||
def configure_accumulated_gradients(self, accumulate_grad_batches):
|
||||
self.accumulate_grad_batches = None
|
||||
|
||||
if isinstance(accumulate_grad_batches, dict):
|
||||
self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
|
||||
elif isinstance(accumulate_grad_batches, int):
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
import math
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import argparse
|
||||
|
||||
import tests.models.utils as tutils
|
||||
from unittest import mock
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import (
|
||||
EarlyStopping,
|
||||
|
@ -600,3 +601,22 @@ def test_testpass_overrides(tmpdir):
|
|||
|
||||
model = LightningTestModel(hparams)
|
||||
Trainer().test(model)
|
||||
|
||||
@mock.patch('argparse.ArgumentParser.parse_args',
|
||||
return_value=argparse.Namespace(**Trainer.default_attributes()))
|
||||
def test_default_args(tmpdir):
|
||||
"""Tests default argument parser for Trainer"""
|
||||
tutils.reset_seed()
|
||||
|
||||
# logger file to get meta
|
||||
logger = tutils.get_test_tube_logger(tmpdir, False)
|
||||
|
||||
parser = argparse.ArgumentParser(add_help=False)
|
||||
args = parser.parse_args()
|
||||
args.logger = logger
|
||||
|
||||
args.max_epochs = 5
|
||||
trainer = Trainer.from_argparse_args(args)
|
||||
|
||||
assert isinstance(trainer, Trainer)
|
||||
assert trainer.max_epochs == 5
|
||||
|
|
Loading…
Reference in New Issue