updated args

William Falcon 2019-06-25 18:51:41 -04:00
parent de0f7fc936
commit d4ca295762
2 changed files with 7 additions and 4 deletions


@@ -98,6 +98,7 @@ def main(hparams, cluster, results_dict):
cluster=cluster,
checkpoint_callback=checkpoint,
early_stop_callback=early_stop,
+gpus=hparams.gpus
)
# train model

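For context, a minimal sketch (not part of this commit) of where hparams.gpus might come from. The --gpus flag and its argparse setup are assumptions; the diff only shows that hparams.gpus is forwarded to the Trainer, which later calls len(gpus), so a list of device ids (or None for CPU-only runs) fits.

from argparse import ArgumentParser

parser = ArgumentParser()
# Assumed flag for illustration: a list of device ids, left as None for CPU-only runs.
parser.add_argument('--gpus', nargs='*', type=int, default=None)

hparams = parser.parse_args(['--gpus', '0', '1'])
assert hparams.gpus == [0, 1]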

@@ -22,7 +22,7 @@ class Trainer(TrainerIO):
cluster=None,
process_position=0,
current_gpu_name=0,
-on_gpu=False,
+gpus=None,
enable_tqdm=True,
overfit_pct=0.0,
track_grad_norm=-1,
@@ -43,7 +43,7 @@ class Trainer(TrainerIO):
self.enable_early_stop = enable_early_stop
self.track_grad_norm = track_grad_norm
self.fast_dev_run = fast_dev_run
-self.on_gpu = on_gpu
+self.on_gpu = gpus is not None and torch.cuda.is_available()
self.enable_tqdm = enable_tqdm
self.experiment = experiment
self.exp_save_path = experiment.get_data_path(experiment.name, experiment.version)
@@ -63,8 +63,10 @@ class Trainer(TrainerIO):
self.lr_schedulers = []
self.amp_level = amp_level
self.check_grad_nans = check_grad_nans
-self.data_parallel_device_ids = [0]
-self.data_parallel = False
+self.data_parallel_device_ids = gpus
+self.data_parallel = gpus is not None and len(gpus) > 0
+pdb.set_trace()
# training state
self.optimizers = None
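A self-contained sketch of the bookkeeping these hunks introduce. The helper name derive_gpu_flags and the cuda_available parameter are assumptions for illustration (the committed code checks torch.cuda.is_available() inline on the Trainer); the logic mirrors the added lines above.

def derive_gpu_flags(gpus, cuda_available):
    # `gpus` is either None (CPU run) or a list of device ids.
    on_gpu = gpus is not None and cuda_available
    data_parallel_device_ids = gpus
    data_parallel = gpus is not None and len(gpus) > 0
    return on_gpu, data_parallel_device_ids, data_parallel

# CPU-only run: no devices requested.
assert derive_gpu_flags(None, cuda_available=False) == (False, None, False)
# Two GPUs requested and CUDA present.
assert derive_gpu_flags([0, 1], cuda_available=True) == (True, [0, 1], True)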