From 3d23a56ed23aa701a056c3f12568c0b7867f40c9 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Thu, 8 Aug 2019 10:59:16 -0400 Subject: [PATCH] make experiment param in trainer optional (#77) * removed forced exp * modified test to also run without exp --- pytorch_lightning/models/trainer.py | 35 ++++++++++++++++++----------- tests/test_models.py | 6 ----- 2 files changed, 22 insertions(+), 19 deletions(-) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 605f2c78b6..e67a90cbd7 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -50,7 +50,7 @@ def reduce_distributed_output(output, nb_gpus): class Trainer(TrainerIO): def __init__(self, - experiment, + experiment=None, early_stop_callback=None, checkpoint_callback=None, gradient_clip=0, @@ -122,7 +122,9 @@ class Trainer(TrainerIO): self.on_gpu = gpus is not None and torch.cuda.is_available() self.progress_bar = progress_bar self.experiment = experiment - self.exp_save_path = experiment.get_data_path(experiment.name, experiment.version) + self.exp_save_path = None + if self.experiment is not None: + self.exp_save_path = experiment.get_data_path(experiment.name, experiment.version) self.cluster = cluster self.process_position = process_position self.current_gpu_name = current_gpu_name @@ -312,13 +314,15 @@ class Trainer(TrainerIO): @property def __tng_tqdm_dic(self): - # ForkedPdb().set_trace() tqdm_dic = { 'tng_loss': '{0:.3f}'.format(self.avg_loss), - 'v_nb': '{}'.format(self.experiment.version), 'epoch': '{}'.format(self.current_epoch), 'batch_nb': '{}'.format(self.batch_nb), } + + if self.experiment is not None: + tqdm_dic['v_nb'] = self.experiment.version + tqdm_dic.update(self.tqdm_metrics) if self.on_gpu: @@ -462,7 +466,8 @@ dataloader = Dataloader(dataset, sampler=dist_sampler) if self.use_ddp: # must copy only the meta of the exp so it survives pickle/unpickle # when going to new process - self.experiment = self.experiment.get_meta_copy() + if self.experiment is not None: + self.experiment = self.experiment.get_meta_copy() if self.is_slurm_managing_tasks: task = int(os.environ['SLURM_LOCALID']) @@ -564,8 +569,9 @@ We recommend you switch to ddp if you want to use amp # recover original exp before went into process # init in write mode only on proc 0 - self.experiment.debug = self.proc_rank > 0 - self.experiment = self.experiment.get_non_ddp_exp() + if self.experiment is not None: + self.experiment.debug = self.proc_rank > 0 + self.experiment = self.experiment.get_non_ddp_exp() # show progbar only on prog_rank 0 self.prog_bar = self.prog_bar and self.node_rank == 0 and gpu_nb == 0 @@ -575,7 +581,8 @@ We recommend you switch to ddp if you want to use amp self.world_size = self.nb_gpu_nodes * len(self.data_parallel_device_ids) # let the exp know the rank to avoid overwriting logs - self.experiment.rank = self.proc_rank + if self.experiment is not None: + self.experiment.rank = self.proc_rank # set up server using proc 0's ip address # try to init for 20 times at max in case ports are taken @@ -673,10 +680,12 @@ We recommend you switch to ddp if you want to use amp # give model convenience properties ref_model.trainer = self - ref_model.experiment = self.experiment + + if self.experiment is not None: + ref_model.experiment = self.experiment # save exp to get started - if self.proc_rank == 0: + if self.proc_rank == 0 and self.experiment is not None: self.experiment.save() # track model now. @@ -756,7 +765,7 @@ We recommend you switch to ddp if you want to use amp # when batch should be saved if (batch_nb + 1) % self.log_save_interval == 0 or early_stop_epoch: - if self.proc_rank == 0: + if self.proc_rank == 0 and self.experiment is not None: self.experiment.save() # when metrics should be logged @@ -784,7 +793,7 @@ We recommend you switch to ddp if you want to use amp # log metrics scalar_metrics = self.__metrics_to_scalars( metrics, blacklist=self.__log_vals_blacklist()) - if self.proc_rank == 0: + if self.proc_rank == 0 and self.experiment is not None: self.experiment.log(scalar_metrics, global_step=self.global_step) self.experiment.save() @@ -813,7 +822,7 @@ We recommend you switch to ddp if you want to use amp if stop: return - def __metrics_to_scalars(self, metrics, blacklist=[]): + def __metrics_to_scalars(self, metrics, blacklist=set()): new_metrics = {} for k, v in metrics.items(): if type(v) is torch.Tensor: diff --git a/tests/test_models.py b/tests/test_models.py index c71d302392..cad55ff372 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -37,16 +37,10 @@ def test_simple_cpu(): save_dir = init_save_dir() # exp file to get meta - test_exp_version = 10 - exp = get_exp(False, version=test_exp_version) - exp.argparse(hparams) - exp.save() - trainer_options = dict( max_nb_epochs=1, val_percent_check=0.1, train_percent_check=0.1, - experiment=exp, ) # fit model