fixes ddp bugs (#1819)

* debug

* debug

* debug

* debug

* debug

* debug

* debug

* debug

* debug

* debug

* debug

* debug

* debug

* debug

* debug

* debug

* debug

* debug

* debug

* debug

* debug

* debug

* debug

* debug

* debug

* debug

* debug

* debug

* debug

* debug

* debug

* debug

* debug

* debug

* debug

* debug

* debug
This commit is contained in:
William Falcon 2020-05-13 19:17:04 -04:00 committed by GitHub
parent 648d516668
commit 53d9316a56
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 30 additions and 10 deletions

View File

@ -149,6 +149,10 @@ class ModelCheckpoint(Callback):
return True
if not isinstance(current, torch.Tensor):
rank_zero_warn(
f'{current} is supposed to be a torch.Tensor. Saving checkpoint may not work correctly. '
f'HINT: check the value of {self.monitor} in your validation loop', RuntimeWarning
)
current = torch.tensor(current)
monitor_op = {
@ -223,6 +227,12 @@ class ModelCheckpoint(Callback):
if self.save_top_k != -1:
current = metrics.get(self.monitor)
if not isinstance(current, torch.Tensor):
rank_zero_warn(
f'The metric you returned {current} must be a Torch.Tensor instance, checkpoint not saved '
f'HINT: what is the value of {self.monitor} in validation_end()?', RuntimeWarning
)
if current is None:
rank_zero_warn(
f'Can save best model only with {self.monitor} available, skipping.', RuntimeWarning

View File

@ -111,8 +111,8 @@ class TrainerDataLoadingMixin(ABC):
if not is_dataloader or is_iterable_ds:
return dataloader
need_dist_sampler = (self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu)
if self.replace_sampler_ddp and need_dist_sampler:
if self.replace_sampler_ddp and need_dist_sampler:
skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']
dl_args = {
@ -137,7 +137,7 @@ class TrainerDataLoadingMixin(ABC):
}
sampler = DistributedSampler(
dataloader.dataset,
num_replicas=world_size.get(self.distributed_backend, 0),
num_replicas=world_size[self.distributed_backend],
rank=self.proc_rank,
)

View File

@ -155,6 +155,7 @@ class TrainerDDPMixin(ABC):
default_root_dir: str
use_native_amp: bool
progress_bar_callback: ...
num_processes: int
@property
@abstractmethod
@ -204,14 +205,17 @@ class TrainerDDPMixin(ABC):
rank_zero_warn('You requested multiple GPUs but did not specify a backend, e.g.'
' Trainer(distributed_backend=dp) (or ddp, ddp2).'
' Setting distributed_backend=ddp for you.')
self.use_ddp = True
elif distributed_backend == "dp":
self.distributed_backend = 'ddp'
distributed_backend = 'ddp'
if distributed_backend == "dp":
# do nothing if num_gpus == 0
if self.num_gpus == 1:
self.single_gpu = True
self.use_dp = True
elif self.num_gpus > 1:
self.use_dp = True
elif distributed_backend == "ddp":
if self.num_gpus == 0:
if self.num_nodes > 1 or self.num_processes > 1:
@ -222,6 +226,7 @@ class TrainerDDPMixin(ABC):
elif self.num_gpus > 1:
self.use_ddp = True
self.num_processes = self.num_gpus
elif distributed_backend == "ddp2":
# do nothing if num_gpus == 0
if self.num_gpus >= 1:
@ -314,7 +319,8 @@ class TrainerDDPMixin(ABC):
gpu_str = ','.join([str(x) for x in data_parallel_device_ids])
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_str
log.debug(f'CUDA_VISIBLE_DEVICES: [{os.environ["CUDA_VISIBLE_DEVICES"]}]')
# don't make this debug... this is good UX
log.info(f'CUDA_VISIBLE_DEVICES: [{os.environ["CUDA_VISIBLE_DEVICES"]}]')
def ddp_train(self, process_idx, model):
"""

View File

@ -530,6 +530,8 @@ class TrainerDPMixin(ABC):
# continue training routine
self.run_pretrain_routine(model)
# when training ends on these platforms dump weights to get out of the main process
if self.on_colab_kaggle:
self.save_spawn_weights(model)
def dp_train(self, model):

View File

@ -121,7 +121,7 @@ class Trainer(
print_nan_grads: bool = False, # backward compatible, todo: remove in v0.9.0
weights_summary: Optional[str] = 'full',
weights_save_path: Optional[str] = None,
num_sanity_val_steps: int = 5,
num_sanity_val_steps: int = 2,
truncated_bptt_steps: Optional[int] = None,
resume_from_checkpoint: Optional[str] = None,
profiler: Optional[Union[BaseProfiler, bool]] = None,
@ -526,6 +526,8 @@ class Trainer(
self.amp_level = amp_level
self.init_amp(use_amp)
self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE')
# Callback system
self.on_init_end()
@ -821,7 +823,7 @@ class Trainer(
# train
mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model,))
# load weights if not interrupted
if os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE'):
if self.on_colab_kaggle:
self.load_spawn_weights(model)
self.model = model
@ -840,7 +842,7 @@ class Trainer(
log.info(f'training on {self.num_tpu_cores} TPU cores')
# COLAB_GPU is an env var available by default in Colab environments.
start_method = 'fork' if os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE') else 'spawn'
start_method = 'fork' if self.on_colab_kaggle else 'spawn'
# track for predict
self.model = model

View File

@ -744,7 +744,7 @@ def test_gpu_choice(tmpdir):
),
pytest.param(
dict(distributed_backend=None, gpus=2),
dict(use_dp=False, use_ddp=True, use_ddp2=False, num_gpus=2, on_gpu=True, single_gpu=False, num_processes=1),
dict(use_dp=False, use_ddp=True, use_ddp2=False, num_gpus=2, on_gpu=True, single_gpu=False, num_processes=2),
marks=[pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Multiple GPUs needed")]
),
pytest.param(