updated args
This commit is contained in:
parent
de0f7fc936
commit
d4ca295762
|
@ -98,6 +98,7 @@ def main(hparams, cluster, results_dict):
|
|||
cluster=cluster,
|
||||
checkpoint_callback=checkpoint,
|
||||
early_stop_callback=early_stop,
|
||||
gpus=hparams.gpus
|
||||
)
|
||||
|
||||
# train model
|
||||
|
|
|
@ -22,7 +22,7 @@ class Trainer(TrainerIO):
|
|||
cluster=None,
|
||||
process_position=0,
|
||||
current_gpu_name=0,
|
||||
on_gpu=False,
|
||||
gpus=None,
|
||||
enable_tqdm=True,
|
||||
overfit_pct=0.0,
|
||||
track_grad_norm=-1,
|
||||
|
@ -43,7 +43,7 @@ class Trainer(TrainerIO):
|
|||
self.enable_early_stop = enable_early_stop
|
||||
self.track_grad_norm = track_grad_norm
|
||||
self.fast_dev_run = fast_dev_run
|
||||
self.on_gpu = on_gpu
|
||||
self.on_gpu = gpus is not None and torch.cuda.is_available()
|
||||
self.enable_tqdm = enable_tqdm
|
||||
self.experiment = experiment
|
||||
self.exp_save_path = experiment.get_data_path(experiment.name, experiment.version)
|
||||
|
@ -63,8 +63,10 @@ class Trainer(TrainerIO):
|
|||
self.lr_schedulers = []
|
||||
self.amp_level = amp_level
|
||||
self.check_grad_nans = check_grad_nans
|
||||
self.data_parallel_device_ids = [0]
|
||||
self.data_parallel = False
|
||||
self.data_parallel_device_ids = gpus
|
||||
self.data_parallel = gpus is not None and len(gpus) > 0
|
||||
|
||||
pdb.set_trace()
|
||||
|
||||
# training state
|
||||
self.optimizers = None
|
||||
|
|
Loading…
Reference in New Issue