make experiment param in trainer optional (#77)

* removed forced exp

* modified test to also run without exp
This commit is contained in:
William Falcon 2019-08-08 10:59:16 -04:00 committed by GitHub
parent aa7245d9db
commit 3d23a56ed2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 22 additions and 19 deletions

View File

@ -50,7 +50,7 @@ def reduce_distributed_output(output, nb_gpus):
class Trainer(TrainerIO): class Trainer(TrainerIO):
def __init__(self, def __init__(self,
experiment, experiment=None,
early_stop_callback=None, early_stop_callback=None,
checkpoint_callback=None, checkpoint_callback=None,
gradient_clip=0, gradient_clip=0,
@ -122,7 +122,9 @@ class Trainer(TrainerIO):
self.on_gpu = gpus is not None and torch.cuda.is_available() self.on_gpu = gpus is not None and torch.cuda.is_available()
self.progress_bar = progress_bar self.progress_bar = progress_bar
self.experiment = experiment self.experiment = experiment
self.exp_save_path = experiment.get_data_path(experiment.name, experiment.version) self.exp_save_path = None
if self.experiment is not None:
self.exp_save_path = experiment.get_data_path(experiment.name, experiment.version)
self.cluster = cluster self.cluster = cluster
self.process_position = process_position self.process_position = process_position
self.current_gpu_name = current_gpu_name self.current_gpu_name = current_gpu_name
@ -312,13 +314,15 @@ class Trainer(TrainerIO):
@property @property
def __tng_tqdm_dic(self): def __tng_tqdm_dic(self):
# ForkedPdb().set_trace()
tqdm_dic = { tqdm_dic = {
'tng_loss': '{0:.3f}'.format(self.avg_loss), 'tng_loss': '{0:.3f}'.format(self.avg_loss),
'v_nb': '{}'.format(self.experiment.version),
'epoch': '{}'.format(self.current_epoch), 'epoch': '{}'.format(self.current_epoch),
'batch_nb': '{}'.format(self.batch_nb), 'batch_nb': '{}'.format(self.batch_nb),
} }
if self.experiment is not None:
tqdm_dic['v_nb'] = self.experiment.version
tqdm_dic.update(self.tqdm_metrics) tqdm_dic.update(self.tqdm_metrics)
if self.on_gpu: if self.on_gpu:
@ -462,7 +466,8 @@ dataloader = Dataloader(dataset, sampler=dist_sampler)
if self.use_ddp: if self.use_ddp:
# must copy only the meta of the exp so it survives pickle/unpickle # must copy only the meta of the exp so it survives pickle/unpickle
# when going to new process # when going to new process
self.experiment = self.experiment.get_meta_copy() if self.experiment is not None:
self.experiment = self.experiment.get_meta_copy()
if self.is_slurm_managing_tasks: if self.is_slurm_managing_tasks:
task = int(os.environ['SLURM_LOCALID']) task = int(os.environ['SLURM_LOCALID'])
@ -564,8 +569,9 @@ We recommend you switch to ddp if you want to use amp
# recover original exp before went into process # recover original exp before went into process
# init in write mode only on proc 0 # init in write mode only on proc 0
self.experiment.debug = self.proc_rank > 0 if self.experiment is not None:
self.experiment = self.experiment.get_non_ddp_exp() self.experiment.debug = self.proc_rank > 0
self.experiment = self.experiment.get_non_ddp_exp()
# show progbar only on prog_rank 0 # show progbar only on prog_rank 0
self.prog_bar = self.prog_bar and self.node_rank == 0 and gpu_nb == 0 self.prog_bar = self.prog_bar and self.node_rank == 0 and gpu_nb == 0
@ -575,7 +581,8 @@ We recommend you switch to ddp if you want to use amp
self.world_size = self.nb_gpu_nodes * len(self.data_parallel_device_ids) self.world_size = self.nb_gpu_nodes * len(self.data_parallel_device_ids)
# let the exp know the rank to avoid overwriting logs # let the exp know the rank to avoid overwriting logs
self.experiment.rank = self.proc_rank if self.experiment is not None:
self.experiment.rank = self.proc_rank
# set up server using proc 0's ip address # set up server using proc 0's ip address
# try to init for 20 times at max in case ports are taken # try to init for 20 times at max in case ports are taken
@ -673,10 +680,12 @@ We recommend you switch to ddp if you want to use amp
# give model convenience properties # give model convenience properties
ref_model.trainer = self ref_model.trainer = self
ref_model.experiment = self.experiment
if self.experiment is not None:
ref_model.experiment = self.experiment
# save exp to get started # save exp to get started
if self.proc_rank == 0: if self.proc_rank == 0 and self.experiment is not None:
self.experiment.save() self.experiment.save()
# track model now. # track model now.
@ -756,7 +765,7 @@ We recommend you switch to ddp if you want to use amp
# when batch should be saved # when batch should be saved
if (batch_nb + 1) % self.log_save_interval == 0 or early_stop_epoch: if (batch_nb + 1) % self.log_save_interval == 0 or early_stop_epoch:
if self.proc_rank == 0: if self.proc_rank == 0 and self.experiment is not None:
self.experiment.save() self.experiment.save()
# when metrics should be logged # when metrics should be logged
@ -784,7 +793,7 @@ We recommend you switch to ddp if you want to use amp
# log metrics # log metrics
scalar_metrics = self.__metrics_to_scalars( scalar_metrics = self.__metrics_to_scalars(
metrics, blacklist=self.__log_vals_blacklist()) metrics, blacklist=self.__log_vals_blacklist())
if self.proc_rank == 0: if self.proc_rank == 0 and self.experiment is not None:
self.experiment.log(scalar_metrics, global_step=self.global_step) self.experiment.log(scalar_metrics, global_step=self.global_step)
self.experiment.save() self.experiment.save()
@ -813,7 +822,7 @@ We recommend you switch to ddp if you want to use amp
if stop: if stop:
return return
def __metrics_to_scalars(self, metrics, blacklist=[]): def __metrics_to_scalars(self, metrics, blacklist=set()):
new_metrics = {} new_metrics = {}
for k, v in metrics.items(): for k, v in metrics.items():
if type(v) is torch.Tensor: if type(v) is torch.Tensor:

View File

@ -37,16 +37,10 @@ def test_simple_cpu():
save_dir = init_save_dir() save_dir = init_save_dir()
# exp file to get meta # exp file to get meta
test_exp_version = 10
exp = get_exp(False, version=test_exp_version)
exp.argparse(hparams)
exp.save()
trainer_options = dict( trainer_options = dict(
max_nb_epochs=1, max_nb_epochs=1,
val_percent_check=0.1, val_percent_check=0.1,
train_percent_check=0.1, train_percent_check=0.1,
experiment=exp,
) )
# fit model # fit model