make experiment param in trainer optional (#77)
* removed forced exp * modified test to also run without exp
This commit is contained in:
parent
aa7245d9db
commit
3d23a56ed2
|
@ -50,7 +50,7 @@ def reduce_distributed_output(output, nb_gpus):
|
|||
class Trainer(TrainerIO):
|
||||
|
||||
def __init__(self,
|
||||
experiment,
|
||||
experiment=None,
|
||||
early_stop_callback=None,
|
||||
checkpoint_callback=None,
|
||||
gradient_clip=0,
|
||||
|
@ -122,7 +122,9 @@ class Trainer(TrainerIO):
|
|||
self.on_gpu = gpus is not None and torch.cuda.is_available()
|
||||
self.progress_bar = progress_bar
|
||||
self.experiment = experiment
|
||||
self.exp_save_path = experiment.get_data_path(experiment.name, experiment.version)
|
||||
self.exp_save_path = None
|
||||
if self.experiment is not None:
|
||||
self.exp_save_path = experiment.get_data_path(experiment.name, experiment.version)
|
||||
self.cluster = cluster
|
||||
self.process_position = process_position
|
||||
self.current_gpu_name = current_gpu_name
|
||||
|
@ -312,13 +314,15 @@ class Trainer(TrainerIO):
|
|||
|
||||
@property
|
||||
def __tng_tqdm_dic(self):
|
||||
# ForkedPdb().set_trace()
|
||||
tqdm_dic = {
|
||||
'tng_loss': '{0:.3f}'.format(self.avg_loss),
|
||||
'v_nb': '{}'.format(self.experiment.version),
|
||||
'epoch': '{}'.format(self.current_epoch),
|
||||
'batch_nb': '{}'.format(self.batch_nb),
|
||||
}
|
||||
|
||||
if self.experiment is not None:
|
||||
tqdm_dic['v_nb'] = self.experiment.version
|
||||
|
||||
tqdm_dic.update(self.tqdm_metrics)
|
||||
|
||||
if self.on_gpu:
|
||||
|
@ -462,7 +466,8 @@ dataloader = Dataloader(dataset, sampler=dist_sampler)
|
|||
if self.use_ddp:
|
||||
# must copy only the meta of the exp so it survives pickle/unpickle
|
||||
# when going to new process
|
||||
self.experiment = self.experiment.get_meta_copy()
|
||||
if self.experiment is not None:
|
||||
self.experiment = self.experiment.get_meta_copy()
|
||||
|
||||
if self.is_slurm_managing_tasks:
|
||||
task = int(os.environ['SLURM_LOCALID'])
|
||||
|
@ -564,8 +569,9 @@ We recommend you switch to ddp if you want to use amp
|
|||
|
||||
# recover original exp before went into process
|
||||
# init in write mode only on proc 0
|
||||
self.experiment.debug = self.proc_rank > 0
|
||||
self.experiment = self.experiment.get_non_ddp_exp()
|
||||
if self.experiment is not None:
|
||||
self.experiment.debug = self.proc_rank > 0
|
||||
self.experiment = self.experiment.get_non_ddp_exp()
|
||||
|
||||
# show progbar only on prog_rank 0
|
||||
self.prog_bar = self.prog_bar and self.node_rank == 0 and gpu_nb == 0
|
||||
|
@ -575,7 +581,8 @@ We recommend you switch to ddp if you want to use amp
|
|||
self.world_size = self.nb_gpu_nodes * len(self.data_parallel_device_ids)
|
||||
|
||||
# let the exp know the rank to avoid overwriting logs
|
||||
self.experiment.rank = self.proc_rank
|
||||
if self.experiment is not None:
|
||||
self.experiment.rank = self.proc_rank
|
||||
|
||||
# set up server using proc 0's ip address
|
||||
# try to init for 20 times at max in case ports are taken
|
||||
|
@ -673,10 +680,12 @@ We recommend you switch to ddp if you want to use amp
|
|||
|
||||
# give model convenience properties
|
||||
ref_model.trainer = self
|
||||
ref_model.experiment = self.experiment
|
||||
|
||||
if self.experiment is not None:
|
||||
ref_model.experiment = self.experiment
|
||||
|
||||
# save exp to get started
|
||||
if self.proc_rank == 0:
|
||||
if self.proc_rank == 0 and self.experiment is not None:
|
||||
self.experiment.save()
|
||||
|
||||
# track model now.
|
||||
|
@ -756,7 +765,7 @@ We recommend you switch to ddp if you want to use amp
|
|||
|
||||
# when batch should be saved
|
||||
if (batch_nb + 1) % self.log_save_interval == 0 or early_stop_epoch:
|
||||
if self.proc_rank == 0:
|
||||
if self.proc_rank == 0 and self.experiment is not None:
|
||||
self.experiment.save()
|
||||
|
||||
# when metrics should be logged
|
||||
|
@ -784,7 +793,7 @@ We recommend you switch to ddp if you want to use amp
|
|||
# log metrics
|
||||
scalar_metrics = self.__metrics_to_scalars(
|
||||
metrics, blacklist=self.__log_vals_blacklist())
|
||||
if self.proc_rank == 0:
|
||||
if self.proc_rank == 0 and self.experiment is not None:
|
||||
self.experiment.log(scalar_metrics, global_step=self.global_step)
|
||||
self.experiment.save()
|
||||
|
||||
|
@ -813,7 +822,7 @@ We recommend you switch to ddp if you want to use amp
|
|||
if stop:
|
||||
return
|
||||
|
||||
def __metrics_to_scalars(self, metrics, blacklist=[]):
|
||||
def __metrics_to_scalars(self, metrics, blacklist=set()):
|
||||
new_metrics = {}
|
||||
for k, v in metrics.items():
|
||||
if type(v) is torch.Tensor:
|
||||
|
|
|
@ -37,16 +37,10 @@ def test_simple_cpu():
|
|||
save_dir = init_save_dir()
|
||||
|
||||
# exp file to get meta
|
||||
test_exp_version = 10
|
||||
exp = get_exp(False, version=test_exp_version)
|
||||
exp.argparse(hparams)
|
||||
exp.save()
|
||||
|
||||
trainer_options = dict(
|
||||
max_nb_epochs=1,
|
||||
val_percent_check=0.1,
|
||||
train_percent_check=0.1,
|
||||
experiment=exp,
|
||||
)
|
||||
|
||||
# fit model
|
||||
|
|
Loading…
Reference in New Issue