make experiment param in trainer optional (#77)

* removed forced exp

* modified test to also run without exp
William Falcon 2019-08-08 10:59:16 -04:00 committed by GitHub
parent aa7245d9db
commit 3d23a56ed2
2 changed files with 22 additions and 19 deletions
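
In usage terms, the change makes the logging experiment opt-in. A minimal sketch of the resulting API surface (the import path and the specific option values are illustrative assumptions; only the `experiment=None` default comes from this commit):

    from pytorch_lightning import Trainer

    # Before: Trainer(experiment) was a required positional argument.
    # After: it defaults to None, and every place the trainer touches the
    # experiment is guarded, so logging and saving are skipped cleanly.
    trainer = Trainer(max_nb_epochs=1)

    # Passing an experiment still works exactly as before, e.g.:
    # trainer = Trainer(experiment=exp, max_nb_epochs=1)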


@@ -50,7 +50,7 @@ def reduce_distributed_output(output, nb_gpus):
 class Trainer(TrainerIO):
 
     def __init__(self,
-                 experiment,
+                 experiment=None,
                  early_stop_callback=None,
                  checkpoint_callback=None,
                  gradient_clip=0,
@@ -122,7 +122,9 @@ class Trainer(TrainerIO):
         self.on_gpu = gpus is not None and torch.cuda.is_available()
         self.progress_bar = progress_bar
         self.experiment = experiment
-        self.exp_save_path = experiment.get_data_path(experiment.name, experiment.version)
+        self.exp_save_path = None
+        if self.experiment is not None:
+            self.exp_save_path = experiment.get_data_path(experiment.name, experiment.version)
         self.cluster = cluster
         self.process_position = process_position
         self.current_gpu_name = current_gpu_name
@@ -312,13 +314,15 @@ class Trainer(TrainerIO):
     @property
     def __tng_tqdm_dic(self):
         # ForkedPdb().set_trace()
         tqdm_dic = {
             'tng_loss': '{0:.3f}'.format(self.avg_loss),
-            'v_nb': '{}'.format(self.experiment.version),
             'epoch': '{}'.format(self.current_epoch),
             'batch_nb': '{}'.format(self.batch_nb),
         }
 
+        if self.experiment is not None:
+            tqdm_dic['v_nb'] = self.experiment.version
+
         tqdm_dic.update(self.tqdm_metrics)
 
         if self.on_gpu:
@@ -462,7 +466,8 @@ dataloader = Dataloader(dataset, sampler=dist_sampler)
         if self.use_ddp:
             # must copy only the meta of the exp so it survives pickle/unpickle
             # when going to new process
-            self.experiment = self.experiment.get_meta_copy()
+            if self.experiment is not None:
+                self.experiment = self.experiment.get_meta_copy()
 
             if self.is_slurm_managing_tasks:
                 task = int(os.environ['SLURM_LOCALID'])
@@ -564,8 +569,9 @@ We recommend you switch to ddp if you want to use amp
 
         # recover original exp before went into process
         # init in write mode only on proc 0
-        self.experiment.debug = self.proc_rank > 0
-        self.experiment = self.experiment.get_non_ddp_exp()
+        if self.experiment is not None:
+            self.experiment.debug = self.proc_rank > 0
+            self.experiment = self.experiment.get_non_ddp_exp()
 
         # show progbar only on prog_rank 0
         self.prog_bar = self.prog_bar and self.node_rank == 0 and gpu_nb == 0
@@ -575,7 +581,8 @@ We recommend you switch to ddp if you want to use amp
         self.world_size = self.nb_gpu_nodes * len(self.data_parallel_device_ids)
 
         # let the exp know the rank to avoid overwriting logs
-        self.experiment.rank = self.proc_rank
+        if self.experiment is not None:
+            self.experiment.rank = self.proc_rank
 
         # set up server using proc 0's ip address
         # try to init for 20 times at max in case ports are taken
@@ -673,10 +680,12 @@ We recommend you switch to ddp if you want to use amp
         # give model convenience properties
         ref_model.trainer = self
-        ref_model.experiment = self.experiment
+        if self.experiment is not None:
+            ref_model.experiment = self.experiment
 
         # save exp to get started
-        if self.proc_rank == 0:
+        if self.proc_rank == 0 and self.experiment is not None:
             self.experiment.save()
 
         # track model now.
@@ -756,7 +765,7 @@ We recommend you switch to ddp if you want to use amp
                 # when batch should be saved
                 if (batch_nb + 1) % self.log_save_interval == 0 or early_stop_epoch:
-                    if self.proc_rank == 0:
+                    if self.proc_rank == 0 and self.experiment is not None:
                         self.experiment.save()
 
                 # when metrics should be logged
@@ -784,7 +793,7 @@ We recommend you switch to ddp if you want to use amp
         # log metrics
         scalar_metrics = self.__metrics_to_scalars(
             metrics, blacklist=self.__log_vals_blacklist())
 
-        if self.proc_rank == 0:
+        if self.proc_rank == 0 and self.experiment is not None:
             self.experiment.log(scalar_metrics, global_step=self.global_step)
             self.experiment.save()
@@ -813,7 +822,7 @@ We recommend you switch to ddp if you want to use amp
             if stop:
                 return
 
-    def __metrics_to_scalars(self, metrics, blacklist=[]):
+    def __metrics_to_scalars(self, metrics, blacklist=set()):
         new_metrics = {}
         for k, v in metrics.items():
             if type(v) is torch.Tensor:
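
Every hunk above applies the same None-guard at each experiment touch-point; condensed into one hypothetical helper (a sketch of the pattern, not code from the commit):

    def save_exp_if_present(trainer):
        # only rank 0 writes, and only when an experiment was passed
        if trainer.proc_rank == 0 and trainer.experiment is not None:
            trainer.experiment.save()

The last hunk also changes the `blacklist` default from `[]` to `set()`; both are still mutable defaults evaluated once at definition time, so the practical gain is O(1) membership tests rather than a fix for the shared-default pitfall.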


@@ -37,16 +37,10 @@ def test_simple_cpu():
     save_dir = init_save_dir()
 
-    # exp file to get meta
-    test_exp_version = 10
-    exp = get_exp(False, version=test_exp_version)
-    exp.argparse(hparams)
-    exp.save()
-
     trainer_options = dict(
         max_nb_epochs=1,
         val_percent_check=0.1,
         train_percent_check=0.1,
-        experiment=exp,
     )
 
     # fit model
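
With the setup block gone, the test exercises the no-experiment path; the with-experiment path can still be covered explicitly, roughly like this (a reconstruction reusing the `get_exp` and `hparams` names from the removed lines, not code from the commit):

    # new path: construct the trainer without any experiment
    trainer = Trainer(**trainer_options)

    # old path, still supported: attach an experiment explicitly
    exp = get_exp(False, version=10)
    exp.argparse(hparams)
    exp.save()
    trainer = Trainer(experiment=exp, **trainer_options)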