Fix weights path (#1445)

* renamed default path to actual root_dir

* added default weights path

* added default weights path

* added default weights path
This commit is contained in:
William Falcon 2020-04-10 12:02:59 -04:00 committed by GitHub
parent 7ac1580a31
commit b78c3d4da8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 119 additions and 101 deletions

View File

@ -222,7 +222,7 @@ def main(hparams):
torch.manual_seed(hparams.seed)
cudnn.deterministic = True
trainer = pl.Trainer(
default_save_path=hparams.save_path,
default_root_dir=hparams.save_path,
gpus=hparams.gpus,
max_epochs=hparams.epochs,
distributed_backend=hparams.distributed_backend,

View File

@ -21,7 +21,7 @@ or ``tuple`` of loggers.
trainer = Trainer(logger=[tb_logger, comet_logger])
.. note:: All loggers log by default to ``os.getcwd()``. To change the path without creating a logger set
``Trainer(default_save_path='/your/path/to/save/checkpoints')``
``Trainer(default_root_dir='/your/path/to/save/checkpoints')``
Custom logger
-------------

View File

@ -211,7 +211,7 @@ Example::
prefix=''
)
default_save_path
default_root_dir
^^^^^^^^^^^^^^^^^
Default path for logs and weights when no logger
@ -222,7 +222,7 @@ are stored. If you don't then use this method for convenience.
Example::
# default used by the Trainer
trainer = Trainer(default_save_path=os.getcwd())
trainer = Trainer(default_root_path=os.getcwd())
distributed_backend
^^^^^^^^^^^^^^^^^^^

View File

@ -10,7 +10,7 @@ class TrainerCallbackConfigMixin(ABC):
# this is just a summary on variables used in this abstract class,
# the proper values/initialisation should be done in child class
default_save_path: str
default_root_dir: str
logger: Union[LightningLoggerBase, bool]
weights_save_path: str
ckpt_path: str
@ -32,13 +32,18 @@ class TrainerCallbackConfigMixin(ABC):
User provided weights_saved_path
Otherwise use os.getcwd()
"""
ckpt_path = self.default_save_path
ckpt_path = self.default_root_dir
if self.checkpoint_callback is True:
# init a default one
if self.logger is not None:
save_dir = (getattr(self.logger, 'save_dir', None) or
getattr(self.logger, '_save_dir', None) or
self.default_save_path)
self.default_root_dir)
# weights_save_path overrides anything
if self.weights_save_path is not None:
save_dir = self.weights_save_path
ckpt_path = os.path.join(
save_dir,
self.logger.name,
@ -46,7 +51,7 @@ class TrainerCallbackConfigMixin(ABC):
"checkpoints"
)
else:
ckpt_path = os.path.join(self.default_save_path, "checkpoints")
ckpt_path = os.path.join(self.default_root_dir, "checkpoints")
# when no val step is defined, use 'loss' otherwise 'val_loss'
train_step_only = not self.is_overriden('validation_step')
@ -72,7 +77,7 @@ class TrainerCallbackConfigMixin(ABC):
# if weights_save_path is still none here, set to current working dir
if self.weights_save_path is None:
self.weights_save_path = self.default_save_path
self.weights_save_path = self.default_root_dir
def configure_early_stopping(self, early_stop_callback):
if early_stop_callback is True or None:

View File

@ -143,7 +143,7 @@ class TrainerDDPMixin(ABC):
distributed_backend: str
amp_level: str
use_tpu: bool
default_save_path: str
default_root_dir: str
@property
@abstractmethod
@ -354,7 +354,7 @@ class TrainerDDPMixin(ABC):
:return:
"""
if self.proc_rank == 0:
path = os.path.join(self.default_save_path, '__temp_weight_ddp_end.ckpt')
path = os.path.join(self.default_root_dir, '__temp_weight_ddp_end.ckpt')
self.save_checkpoint(path)
def load_spawn_weights(self, original_model):
@ -369,7 +369,7 @@ class TrainerDDPMixin(ABC):
if self.proc_rank == 0:
# load weights saved in ddp
path = os.path.join(self.default_save_path, '__temp_weight_ddp_end.ckpt')
path = os.path.join(self.default_root_dir, '__temp_weight_ddp_end.ckpt')
loaded_model = original_model.__class__.load_from_checkpoint(path)
# copy loaded weights to old model

View File

@ -20,7 +20,7 @@ class TrainerLoggingMixin(ABC):
proc_rank: int
use_dp: bool
use_ddp2: bool
default_save_path: str
default_root_dir: str
slurm_job_id: int
num_gpus: int
@ -28,7 +28,7 @@ class TrainerLoggingMixin(ABC):
if logger is True:
# default logger
self.logger = TensorBoardLogger(
save_dir=self.default_save_path,
save_dir=self.default_root_dir,
version=self.slurm_job_id,
name='lightning_logs'
)

View File

@ -85,7 +85,7 @@ class Trainer(
checkpoint_callback: Union[ModelCheckpoint, bool] = True,
early_stop_callback: Optional[Union[EarlyStopping, bool]] = False,
callbacks: List[Callback] = [],
default_save_path: Optional[str] = None,
default_root_dir: Optional[str] = None,
gradient_clip_val: float = 0,
process_position: int = 0,
num_nodes: int = 1,
@ -122,6 +122,7 @@ class Trainer(
profiler: Optional[BaseProfiler] = None,
benchmark: bool = False,
reload_dataloaders_every_epoch: bool = False,
default_save_path=None, # backward compatible, todo: remove in v0.8.0
gradient_clip=None, # backward compatible, todo: remove in v0.8.0
nb_gpu_nodes=None, # backward compatible, todo: remove in v0.8.0
max_nb_epochs=None, # backward compatible, todo: remove in v0.8.0
@ -144,7 +145,12 @@ class Trainer(
callbacks: Add a list of callbacks.
default_save_path: Default path for logs and weights when no logger/ckpt_callback passed
default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed
default_save_path:
.. warning:: .. deprecated:: 0.7.3
Use `default_root_dir` instead. Will remove 0.9.0.
gradient_clip_val: 0 means don't clip.
@ -244,7 +250,9 @@ class Trainer(
weights_summary: Prints a summary of the weights when training begins.
weights_save_path: Where to save weights if specified.
weights_save_path: Where to save weights if specified. Will override default_root_dir
for checkpoints only. Use this if for whatever reason you need the checkpoints
stored in a different place than the logs written in `default_root_dir`.
amp_level: The optimization level to use (O1, O2, etc...).
@ -348,9 +356,14 @@ class Trainer(
' val and test loop using a single batch')
# set default save path if user didn't provide one
self.default_save_path = default_save_path
if self.default_save_path is None:
self.default_save_path = os.getcwd()
self.default_root_dir = default_root_dir
# Backward compatibility, TODO: remove in v0.8.0
if default_save_path is not None:
self.default_root_dir = default_save_path
if self.default_root_dir is None:
self.default_root_dir = os.getcwd()
# training bookeeping
self.total_batch_idx = 0
@ -917,7 +930,7 @@ class Trainer(
self.fit(model)
elif self.use_ddp or self.use_tpu: # pragma: no-cover
# attempt to load weights from a spawn
path = os.path.join(self.default_save_path, '__temp_weight_ddp_end.ckpt')
path = os.path.join(self.default_root_dir, '__temp_weight_ddp_end.ckpt')
test_model = self.model
if os.path.exists(path):
test_model = self.load_spawn_weights(self.model)

View File

@ -5,7 +5,7 @@ Lightning can automate saving and loading checkpoints
Checkpointing is enabled by default to the current working directory.
To change the checkpoint path pass in::
Trainer(default_save_path='/your/path/to/save/checkpoints')
Trainer(default_root_dir='/your/path/to/save/checkpoints')
To modify the behavior of checkpointing pass in your own callback.

View File

@ -21,7 +21,7 @@ ROOT_PATH = os.path.abspath(os.path.dirname(__file__))
def run_model_test_no_loggers(trainer_options, model, min_acc=0.50):
# save_dir = trainer_options['default_save_path']
# save_dir = trainer_options['default_root_dir']
# fit model
trainer = Trainer(**trainer_options)
@ -33,7 +33,7 @@ def run_model_test_no_loggers(trainer_options, model, min_acc=0.50):
# test model loading
pretrained_model = load_model(trainer.logger,
trainer.checkpoint_callback.dirpath,
path_expt=trainer_options.get('default_save_path'))
path_expt=trainer_options.get('default_root_dir'))
# test new model accuracy
test_loaders = model.test_dataloader()
@ -50,7 +50,7 @@ def run_model_test_no_loggers(trainer_options, model, min_acc=0.50):
def run_model_test(trainer_options, model, on_gpu=True):
save_dir = trainer_options['default_save_path']
save_dir = trainer_options['default_root_dir']
# logger file to get meta
logger = get_default_testtube_logger(save_dir, False)

View File

@ -81,7 +81,7 @@ def test_custom_logger(tmpdir):
max_epochs=1,
train_percent_check=0.05,
logger=logger,
default_save_path=tmpdir
default_root_dir=tmpdir
)
trainer = Trainer(**trainer_options)
@ -103,7 +103,7 @@ def test_multiple_loggers(tmpdir):
max_epochs=1,
train_percent_check=0.05,
logger=[logger1, logger2],
default_save_path=tmpdir
default_root_dir=tmpdir
)
trainer = Trainer(**trainer_options)
@ -162,7 +162,7 @@ def test_adding_step_key(tmpdir):
model.training_epoch_end = _training_epoch_end
trainer_options = dict(
max_epochs=4,
default_save_path=tmpdir,
default_root_dir=tmpdir,
train_percent_check=0.001,
val_percent_check=0.01,
num_sanity_val_steps=0,

View File

@ -35,7 +35,7 @@ def test_comet_logger(tmpdir, monkeypatch):
)
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
train_percent_check=0.05,
logger=logger
@ -145,7 +145,7 @@ def test_comet_pickle(tmpdir, monkeypatch):
)
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
logger=logger
)

View File

@ -25,7 +25,7 @@ def test_mlflow_logger(tmpdir):
logger.log_metrics({'acc': 'test'})
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
train_percent_check=0.05,
logger=logger
@ -43,7 +43,7 @@ def test_mlflow_pickle(tmpdir):
mlflow_dir = os.path.join(tmpdir, 'mlruns')
logger = MLFlowLogger('test', tracking_uri=f'file:{os.sep * 2}{mlflow_dir}')
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
logger=logger
)

View File

@ -18,7 +18,7 @@ def test_neptune_logger(tmpdir):
logger = NeptuneLogger(offline_mode=True)
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
train_percent_check=0.05,
logger=logger
@ -87,7 +87,7 @@ def test_neptune_pickle(tmpdir):
logger = NeptuneLogger(offline_mode=True)
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
logger=logger
)
@ -109,7 +109,7 @@ def test_neptune_leave_open_experiment_after_fit(tmpdir):
logger._experiment = MagicMock()
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
train_percent_check=0.05,
logger=logger

View File

@ -16,7 +16,7 @@ def test_testtube_logger(tmpdir):
assert logger.name == 'lightning_logs'
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
train_percent_check=0.05,
logger=logger
@ -39,7 +39,7 @@ def test_testtube_pickle(tmpdir):
logger.save()
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
train_percent_check=0.05,
logger=logger

View File

@ -19,7 +19,7 @@ def test_trains_logger(tmpdir):
logger = TrainsLogger(project_name="lightning_log", task_name="pytorch lightning test")
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
train_percent_check=0.05,
logger=logger
@ -45,7 +45,7 @@ def test_trains_pickle(tmpdir):
logger = TrainsLogger(project_name="lightning_log", task_name="pytorch lightning test")
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
logger=logger
)

View File

@ -20,7 +20,7 @@ def test_amp_single_gpu(tmpdir):
model = LightningTestModel(hparams)
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
gpus=1,
distributed_backend='ddp',
@ -40,7 +40,7 @@ def test_no_amp_single_gpu(tmpdir):
model = LightningTestModel(hparams)
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
gpus=1,
distributed_backend='dp',
@ -63,7 +63,7 @@ def test_amp_gpu_ddp(tmpdir):
model = LightningTestModel(hparams)
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
gpus=2,
distributed_backend='ddp',
@ -123,7 +123,7 @@ def test_cpu_model_with_amp(tmpdir):
tutils.reset_seed()
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
progress_bar_refresh_rate=0,
logger=tutils.get_default_testtube_logger(tmpdir),
max_epochs=1,
@ -146,7 +146,7 @@ def test_amp_gpu_dp(tmpdir):
model, hparams = tutils.get_default_model()
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
gpus='0, 1', # test init with gpu string
distributed_backend='dp',

View File

@ -22,7 +22,7 @@ def test_early_stopping_cpu_model(tmpdir):
stopping = EarlyStopping(monitor='val_loss', min_delta=0.1)
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
early_stop_callback=stopping,
gradient_clip_val=1.0,
overfit_pct=0.20,
@ -45,7 +45,7 @@ def test_lbfgs_cpu_model(tmpdir):
tutils.reset_seed()
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=2,
progress_bar_refresh_rate=0,
weights_summary='top',
@ -62,7 +62,7 @@ def test_default_logger_callbacks_cpu_model(tmpdir):
tutils.reset_seed()
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
gradient_clip_val=1.0,
overfit_pct=0.20,
@ -93,7 +93,7 @@ def test_running_test_after_fitting(tmpdir):
checkpoint = tutils.init_checkpoint_callback(logger)
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
progress_bar_refresh_rate=0,
max_epochs=8,
train_percent_check=0.4,
@ -205,7 +205,7 @@ def test_simple_cpu(tmpdir):
# logger file to get meta
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.1,
train_percent_check=0.1,
@ -224,7 +224,7 @@ def test_cpu_model(tmpdir):
tutils.reset_seed()
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
progress_bar_refresh_rate=0,
logger=tutils.get_default_testtube_logger(tmpdir),
max_epochs=1,
@ -242,7 +242,7 @@ def test_all_features_cpu_model(tmpdir):
tutils.reset_seed()
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
gradient_clip_val=1.0,
overfit_pct=0.20,
track_grad_norm=2,
@ -308,7 +308,7 @@ def test_tbptt_cpu_model(tmpdir):
)
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
truncated_bptt_steps=truncated_bptt_steps,
val_percent_check=0,
@ -339,7 +339,7 @@ def test_single_gpu_model(tmpdir):
model, hparams = tutils.get_default_model()
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
progress_bar_refresh_rate=0,
max_epochs=1,
train_percent_check=0.1,

View File

@ -26,7 +26,7 @@ def test_multi_gpu_model_ddp2(tmpdir):
model, hparams = tutils.get_default_model()
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
train_percent_check=0.4,
val_percent_check=0.2,
@ -47,7 +47,7 @@ def test_multi_gpu_model_ddp(tmpdir):
model, hparams = tutils.get_default_model()
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
progress_bar_refresh_rate=0,
max_epochs=1,
train_percent_check=0.4,
@ -67,7 +67,7 @@ def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
tutils.set_random_master_port()
model, hparams = tutils.get_default_model()
trainer_options = dict(default_save_path=tmpdir,
trainer_options = dict(default_root_dir=tmpdir,
progress_bar_refresh_rate=0,
max_epochs=1,
train_percent_check=0.4,
@ -163,7 +163,7 @@ def test_multi_gpu_none_backend(tmpdir):
model, hparams = tutils.get_default_model()
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
progress_bar_refresh_rate=0,
max_epochs=1,
train_percent_check=0.1,
@ -182,7 +182,7 @@ def test_multi_gpu_model_dp(tmpdir):
model, hparams = tutils.get_default_model()
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
progress_bar_refresh_rate=0,
distributed_backend='dp',
max_epochs=1,

View File

@ -120,7 +120,7 @@ def test_load_model_from_checkpoint(tmpdir):
val_percent_check=0.2,
checkpoint_callback=ModelCheckpoint(tmpdir, save_top_k=-1),
logger=False,
default_save_path=tmpdir,
default_root_dir=tmpdir,
)
# fit model
@ -331,7 +331,7 @@ def test_load_model_with_missing_hparams(tmpdir):
max_epochs=1,
checkpoint_callback=ModelCheckpoint(tmpdir, save_top_k=-1),
logger=False,
default_save_path=tmpdir,
default_root_dir=tmpdir,
)
# fit model

View File

@ -170,7 +170,7 @@ def test_early_stopping_without_val_step(tmpdir):
stopping = EarlyStopping(monitor='my_train_metric', min_delta=0.1)
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
early_stop_callback=stopping,
overfit_pct=0.20,
max_epochs=5,

View File

@ -23,7 +23,7 @@ def test_error_on_no_train_step(tmpdir):
def forward(self, x):
pass
trainer_options = dict(default_save_path=tmpdir, max_epochs=1)
trainer_options = dict(default_root_dir=tmpdir, max_epochs=1)
trainer = Trainer(**trainer_options)
with pytest.raises(MisconfigurationException):
@ -39,7 +39,7 @@ def test_error_on_no_train_dataloader(tmpdir):
class CurrentTestModel(TestModelBase):
pass
trainer_options = dict(default_save_path=tmpdir, max_epochs=1)
trainer_options = dict(default_root_dir=tmpdir, max_epochs=1)
trainer = Trainer(**trainer_options)
with pytest.raises(MisconfigurationException):
@ -58,7 +58,7 @@ def test_error_on_no_configure_optimizers(tmpdir):
def training_step(self, batch, batch_idx, optimizer_idx=None):
pass
trainer_options = dict(default_save_path=tmpdir, max_epochs=1)
trainer_options = dict(default_root_dir=tmpdir, max_epochs=1)
trainer = Trainer(**trainer_options)
with pytest.raises(MisconfigurationException):
@ -76,7 +76,7 @@ def test_warning_on_wrong_validation_settings(tmpdir):
tutils.reset_seed()
hparams = tutils.get_default_hparams()
trainer_options = dict(default_save_path=tmpdir, max_epochs=1)
trainer_options = dict(default_root_dir=tmpdir, max_epochs=1)
trainer = Trainer(**trainer_options)
class CurrentTestModel(LightTrainDataloader,
@ -120,7 +120,7 @@ def test_warning_on_wrong_test_settigs(tmpdir):
tutils.reset_seed()
hparams = tutils.get_default_hparams()
trainer_options = dict(default_save_path=tmpdir, max_epochs=1)
trainer_options = dict(default_root_dir=tmpdir, max_epochs=1)
trainer = Trainer(**trainer_options)
class CurrentTestModel(LightTrainDataloader,

View File

@ -39,7 +39,7 @@ def test_dataloader_config_errors(tmpdir):
# logger file to get meta
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
train_percent_check=-0.1,
)
@ -54,7 +54,7 @@ def test_dataloader_config_errors(tmpdir):
# logger file to get meta
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
train_percent_check=1.1,
)
@ -69,7 +69,7 @@ def test_dataloader_config_errors(tmpdir):
# logger file to get meta
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
val_check_interval=10000
)
@ -84,7 +84,7 @@ def test_dataloader_config_errors(tmpdir):
# logger file to get meta
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
val_check_interval=1.1
)
@ -112,7 +112,7 @@ def test_multiple_val_dataloader(tmpdir):
# logger file to get meta
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.1,
train_percent_check=1.0,
@ -151,7 +151,7 @@ def test_multiple_test_dataloader(tmpdir):
# logger file to get meta
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.1,
train_percent_check=0.2
@ -185,7 +185,7 @@ def test_train_dataloaders_passed_to_fit(tmpdir):
# logger file to get meta
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.1,
train_percent_check=0.2
@ -215,7 +215,7 @@ def test_train_val_dataloaders_passed_to_fit(tmpdir):
# logger file to get meta
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.1,
train_percent_check=0.2
@ -250,7 +250,7 @@ def test_all_dataloaders_passed_to_fit(tmpdir):
# logger file to get meta
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.1,
train_percent_check=0.2
@ -289,7 +289,7 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir):
# logger file to get meta
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.1,
train_percent_check=0.2
@ -330,7 +330,7 @@ def test_mixing_of_dataloader_options(tmpdir):
# logger file to get meta
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.1,
train_percent_check=0.2
@ -371,14 +371,14 @@ def test_inf_train_dataloader(tmpdir):
# fit model
with pytest.raises(MisconfigurationException):
trainer = Trainer(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
val_check_interval=0.5
)
trainer.fit(model)
trainer = Trainer(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
val_check_interval=50
)
@ -388,7 +388,7 @@ def test_inf_train_dataloader(tmpdir):
assert result == 1
trainer = Trainer(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1
)
result = trainer.fit(model)
@ -413,7 +413,7 @@ def test_inf_val_dataloader(tmpdir):
# fit model
with pytest.raises(MisconfigurationException):
trainer = Trainer(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.5
)
@ -421,7 +421,7 @@ def test_inf_val_dataloader(tmpdir):
# logger file to get meta
trainer = Trainer(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1
)
result = trainer.fit(model)
@ -447,7 +447,7 @@ def test_inf_test_dataloader(tmpdir):
# fit model
with pytest.raises(MisconfigurationException):
trainer = Trainer(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
test_percent_check=0.5
)
@ -455,7 +455,7 @@ def test_inf_test_dataloader(tmpdir):
# logger file to get meta
trainer = Trainer(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1
)
result = trainer.fit(model)
@ -481,7 +481,7 @@ def test_error_on_zero_len_dataloader(tmpdir):
# fit model
with pytest.raises(ValueError):
trainer = Trainer(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
test_percent_check=0.5
)
@ -506,7 +506,7 @@ def test_warning_with_few_workers(tmpdir):
# logger file to get meta
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.1,
train_percent_check=0.2

View File

@ -32,7 +32,7 @@ def test_optimizer_with_scheduling(tmpdir):
# logger file to get meta
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.1,
train_percent_check=0.2
@ -71,7 +71,7 @@ def test_multi_optimizer_with_scheduling(tmpdir):
# logger file to get meta
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.1,
train_percent_check=0.2
@ -114,7 +114,7 @@ def test_multi_optimizer_with_scheduling_stepping(tmpdir):
# logger file to get meta
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.1,
train_percent_check=0.2
@ -163,7 +163,7 @@ def test_reduce_lr_on_plateau_scheduling(tmpdir):
# logger file to get meta
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.1,
train_percent_check=0.2
@ -263,7 +263,7 @@ def test_none_optimizer(tmpdir):
# logger file to get meta
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
val_percent_check=0.1,
train_percent_check=0.2

View File

@ -31,7 +31,7 @@ def test_hparams_save_load(tmpdir):
# logger file to get meta
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
)
@ -196,7 +196,7 @@ def test_gradient_accumulation_scheduling(tmpdir):
train_percent_check=0.1,
val_percent_check=0.1,
max_epochs=2,
default_save_path=tmpdir)
default_root_dir=tmpdir)
# for the test
trainer.optimizer_step = _optimizer_step
@ -336,7 +336,7 @@ def test_resume_from_checkpoint_epoch_restored(tmpdir):
val_percent_check=1,
checkpoint_callback=ModelCheckpoint(tmpdir, save_top_k=-1),
logger=False,
default_save_path=tmpdir,
default_root_dir=tmpdir,
early_stop_callback=False,
val_check_interval=1.,
)
@ -386,7 +386,7 @@ def test_trainer_max_steps_and_epochs(tmpdir):
# define less train steps than epochs
trainer_options.update(dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=3,
max_steps=num_train_samples + 10
))
@ -421,7 +421,7 @@ def test_trainer_min_steps_and_epochs(tmpdir):
# define callback for stopping the model and default epochs
trainer_options.update(dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
early_stop_callback=EarlyStopping(monitor='val_loss', min_delta=1.0),
val_check_interval=2,
min_epochs=1,
@ -472,7 +472,7 @@ def test_benchmark_option(tmpdir):
# logger file to get meta
trainer_options = dict(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_epochs=1,
benchmark=True,
)
@ -591,7 +591,7 @@ def test_nan_loss_detection(tmpdir):
# fit model
trainer = Trainer(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_steps=(test_step + 1),
)
@ -617,7 +617,7 @@ def test_nan_params_detection(tmpdir):
model = NanParamModel(hparams)
trainer = Trainer(
default_save_path=tmpdir,
default_root_dir=tmpdir,
max_steps=(test_step + 1),
)
@ -651,7 +651,7 @@ def test_trainer_interrupted_flag(tmpdir):
'train_percent_check': 0.2,
'progress_bar_refresh_rate': 0,
'logger': False,
'default_save_path': tmpdir,
'default_root_dir': tmpdir,
}
trainer = Trainer(**trainer_options)
@ -678,7 +678,7 @@ def test_gradient_clipping(tmpdir):
trainer = Trainer(max_steps=1,
max_epochs=1,
gradient_clip_val=1.0,
default_save_path=tmpdir)
default_root_dir=tmpdir)
# for the test
model.optimizer_step = _optimizer_step