make experiment param in trainer optional (#77)
* removed forced exp
* modified test to also run without exp
parent aa7245d9db
commit 3d23a56ed2
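With this change the experiment argument of Trainer defaults to None, so the trainer can be built with or without a test-tube experiment. A minimal usage sketch, assuming Trainer is importable from this package and model is a pytorch-lightning model defined elsewhere; the test_tube import and the save_dir value are illustrative, not part of the commit:

# without an experiment: the trainer logs and saves nothing,
# and the progress bar simply omits the version number
trainer = Trainer(max_nb_epochs=1)
trainer.fit(model)

# with an experiment: behaviour is unchanged from before this commit
from test_tube import Experiment
exp = Experiment(save_dir='/some/log/dir')
trainer = Trainer(experiment=exp, max_nb_epochs=1)
trainer.fit(model)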
@@ -50,7 +50,7 @@ def reduce_distributed_output(output, nb_gpus):
 class Trainer(TrainerIO):

     def __init__(self,
-                 experiment,
+                 experiment=None,
                  early_stop_callback=None,
                  checkpoint_callback=None,
                  gradient_clip=0,
@@ -122,7 +122,9 @@ class Trainer(TrainerIO):
         self.on_gpu = gpus is not None and torch.cuda.is_available()
         self.progress_bar = progress_bar
         self.experiment = experiment
-        self.exp_save_path = experiment.get_data_path(experiment.name, experiment.version)
+        self.exp_save_path = None
+        if self.experiment is not None:
+            self.exp_save_path = experiment.get_data_path(experiment.name, experiment.version)
         self.cluster = cluster
         self.process_position = process_position
         self.current_gpu_name = current_gpu_name
@@ -312,13 +314,15 @@ class Trainer(TrainerIO):

     @property
     def __tng_tqdm_dic(self):
-        # ForkedPdb().set_trace()
         tqdm_dic = {
             'tng_loss': '{0:.3f}'.format(self.avg_loss),
-            'v_nb': '{}'.format(self.experiment.version),
             'epoch': '{}'.format(self.current_epoch),
             'batch_nb': '{}'.format(self.batch_nb),
         }
+
+        if self.experiment is not None:
+            tqdm_dic['v_nb'] = self.experiment.version
+
         tqdm_dic.update(self.tqdm_metrics)

         if self.on_gpu:
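The only visible difference in the progress bar is the v_nb entry; a rough illustration of what __tng_tqdm_dic yields, with made-up values and before tqdm_metrics is merged in:

# with an experiment attached (v_nb is now the raw version, no longer '{}'-formatted)
{'tng_loss': '0.123', 'epoch': '2', 'batch_nb': '40', 'v_nb': 7}

# without an experiment
{'tng_loss': '0.123', 'epoch': '2', 'batch_nb': '40'}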
@@ -462,7 +466,8 @@ dataloader = Dataloader(dataset, sampler=dist_sampler)
         if self.use_ddp:
             # must copy only the meta of the exp so it survives pickle/unpickle
             # when going to new process
-            self.experiment = self.experiment.get_meta_copy()
+            if self.experiment is not None:
+                self.experiment = self.experiment.get_meta_copy()

             if self.is_slurm_managing_tasks:
                 task = int(os.environ['SLURM_LOCALID'])
@@ -564,8 +569,9 @@ We recommend you switch to ddp if you want to use amp

         # recover original exp before went into process
         # init in write mode only on proc 0
-        self.experiment.debug = self.proc_rank > 0
-        self.experiment = self.experiment.get_non_ddp_exp()
+        if self.experiment is not None:
+            self.experiment.debug = self.proc_rank > 0
+            self.experiment = self.experiment.get_non_ddp_exp()

         # show progbar only on prog_rank 0
         self.prog_bar = self.prog_bar and self.node_rank == 0 and gpu_nb == 0
@@ -575,7 +581,8 @@ We recommend you switch to ddp if you want to use amp
         self.world_size = self.nb_gpu_nodes * len(self.data_parallel_device_ids)

         # let the exp know the rank to avoid overwriting logs
-        self.experiment.rank = self.proc_rank
+        if self.experiment is not None:
+            self.experiment.rank = self.proc_rank

         # set up server using proc 0's ip address
         # try to init for 20 times at max in case ports are taken
@@ -673,10 +680,12 @@ We recommend you switch to ddp if you want to use amp

         # give model convenience properties
         ref_model.trainer = self
-        ref_model.experiment = self.experiment
+
+        if self.experiment is not None:
+            ref_model.experiment = self.experiment

         # save exp to get started
-        if self.proc_rank == 0:
+        if self.proc_rank == 0 and self.experiment is not None:
             self.experiment.save()

         # track model now.
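Since ref_model.experiment is now assigned only when an experiment exists, model-side code should not assume the attribute is present. A small defensive helper, purely a sketch (maybe_log is hypothetical and not part of this commit):

def maybe_log(model, metrics, step):
    # the trainer only sets model.experiment when one was passed in,
    # so fall back gracefully when it is absent
    exp = getattr(model, 'experiment', None)
    if exp is not None:
        exp.log(metrics, global_step=step)
        exp.save()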
@@ -756,7 +765,7 @@ We recommend you switch to ddp if you want to use amp

         # when batch should be saved
         if (batch_nb + 1) % self.log_save_interval == 0 or early_stop_epoch:
-            if self.proc_rank == 0:
+            if self.proc_rank == 0 and self.experiment is not None:
                 self.experiment.save()

         # when metrics should be logged
@@ -784,7 +793,7 @@ We recommend you switch to ddp if you want to use amp
         # log metrics
         scalar_metrics = self.__metrics_to_scalars(
             metrics, blacklist=self.__log_vals_blacklist())
-        if self.proc_rank == 0:
+        if self.proc_rank == 0 and self.experiment is not None:
             self.experiment.log(scalar_metrics, global_step=self.global_step)
             self.experiment.save()
@@ -813,7 +822,7 @@ We recommend you switch to ddp if you want to use amp
         if stop:
             return

-    def __metrics_to_scalars(self, metrics, blacklist=[]):
+    def __metrics_to_scalars(self, metrics, blacklist=set()):
         new_metrics = {}
         for k, v in metrics.items():
             if type(v) is torch.Tensor:
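The default for blacklist also changes from [] to set(), presumably so membership checks inside __metrics_to_scalars stay O(1). A set() is still a mutable default shared across calls; for reference, the usual None-sentinel idiom looks like this (a sketch of the pattern only, the key filter is an assumption about the method body, and this is not what the commit does):

def metrics_to_scalars(metrics, blacklist=None):
    # create a fresh set per call instead of sharing one default object
    if blacklist is None:
        blacklist = set()
    # keep only metrics whose keys are not blacklisted
    return {k: v for k, v in metrics.items() if k not in blacklist}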
@@ -37,16 +37,10 @@ def test_simple_cpu():
     save_dir = init_save_dir()

     # exp file to get meta
-    test_exp_version = 10
-    exp = get_exp(False, version=test_exp_version)
-    exp.argparse(hparams)
-    exp.save()
-
     trainer_options = dict(
         max_nb_epochs=1,
         val_percent_check=0.1,
         train_percent_check=0.1,
-        experiment=exp,
     )

     # fit model
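With experiment gone from trainer_options, test_simple_cpu now exercises the no-experiment path. One way to cover both paths explicitly, sketched with the helpers already used in this test; the parametrisation, the Trainer construction and the final fit step are placeholders, not part of the commit:

import pytest

@pytest.mark.parametrize('use_exp', [True, False])
def test_simple_cpu_both_paths(use_exp):
    save_dir = init_save_dir()

    exp = None
    if use_exp:
        # same exp setup the test used before this commit
        exp = get_exp(False, version=10)
        exp.argparse(hparams)
        exp.save()

    trainer_options = dict(
        max_nb_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.1,
        experiment=exp,  # None is accepted now that the param is optional
    )

    # assumed: the test passes these options straight to Trainer
    trainer = Trainer(**trainer_options)
    # the existing test then builds the model and calls trainer.fit(model)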