fixes ddp bugs (#1819)
* debug * debug * debug * debug * debug * debug * debug * debug * debug * debug * debug * debug * debug * debug * debug * debug * debug * debug * debug * debug * debug * debug * debug * debug * debug * debug * debug * debug * debug * debug * debug * debug * debug * debug * debug * debug * debug
This commit is contained in:
parent
648d516668
commit
53d9316a56
|
@ -149,6 +149,10 @@ class ModelCheckpoint(Callback):
|
|||
return True
|
||||
|
||||
if not isinstance(current, torch.Tensor):
|
||||
rank_zero_warn(
|
||||
f'{current} is supposed to be a torch.Tensor. Saving checkpoint may not work correctly. '
|
||||
f'HINT: check the value of {self.monitor} in your validation loop', RuntimeWarning
|
||||
)
|
||||
current = torch.tensor(current)
|
||||
|
||||
monitor_op = {
|
||||
|
@ -223,6 +227,12 @@ class ModelCheckpoint(Callback):
|
|||
if self.save_top_k != -1:
|
||||
current = metrics.get(self.monitor)
|
||||
|
||||
if not isinstance(current, torch.Tensor):
|
||||
rank_zero_warn(
|
||||
f'The metric you returned {current} must be a Torch.Tensor instance, checkpoint not saved '
|
||||
f'HINT: what is the value of {self.monitor} in validation_end()?', RuntimeWarning
|
||||
)
|
||||
|
||||
if current is None:
|
||||
rank_zero_warn(
|
||||
f'Can save best model only with {self.monitor} available, skipping.', RuntimeWarning
|
||||
|
|
|
@ -111,8 +111,8 @@ class TrainerDataLoadingMixin(ABC):
|
|||
if not is_dataloader or is_iterable_ds:
|
||||
return dataloader
|
||||
need_dist_sampler = (self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu)
|
||||
if self.replace_sampler_ddp and need_dist_sampler:
|
||||
|
||||
if self.replace_sampler_ddp and need_dist_sampler:
|
||||
skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']
|
||||
|
||||
dl_args = {
|
||||
|
@ -137,7 +137,7 @@ class TrainerDataLoadingMixin(ABC):
|
|||
}
|
||||
sampler = DistributedSampler(
|
||||
dataloader.dataset,
|
||||
num_replicas=world_size.get(self.distributed_backend, 0),
|
||||
num_replicas=world_size[self.distributed_backend],
|
||||
rank=self.proc_rank,
|
||||
)
|
||||
|
||||
|
|
|
@ -155,6 +155,7 @@ class TrainerDDPMixin(ABC):
|
|||
default_root_dir: str
|
||||
use_native_amp: bool
|
||||
progress_bar_callback: ...
|
||||
num_processes: int
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
|
@ -204,14 +205,17 @@ class TrainerDDPMixin(ABC):
|
|||
rank_zero_warn('You requested multiple GPUs but did not specify a backend, e.g.'
|
||||
' Trainer(distributed_backend=dp) (or ddp, ddp2).'
|
||||
' Setting distributed_backend=ddp for you.')
|
||||
self.use_ddp = True
|
||||
elif distributed_backend == "dp":
|
||||
self.distributed_backend = 'ddp'
|
||||
distributed_backend = 'ddp'
|
||||
|
||||
if distributed_backend == "dp":
|
||||
# do nothing if num_gpus == 0
|
||||
if self.num_gpus == 1:
|
||||
self.single_gpu = True
|
||||
self.use_dp = True
|
||||
elif self.num_gpus > 1:
|
||||
self.use_dp = True
|
||||
|
||||
elif distributed_backend == "ddp":
|
||||
if self.num_gpus == 0:
|
||||
if self.num_nodes > 1 or self.num_processes > 1:
|
||||
|
@ -222,6 +226,7 @@ class TrainerDDPMixin(ABC):
|
|||
elif self.num_gpus > 1:
|
||||
self.use_ddp = True
|
||||
self.num_processes = self.num_gpus
|
||||
|
||||
elif distributed_backend == "ddp2":
|
||||
# do nothing if num_gpus == 0
|
||||
if self.num_gpus >= 1:
|
||||
|
@ -314,7 +319,8 @@ class TrainerDDPMixin(ABC):
|
|||
gpu_str = ','.join([str(x) for x in data_parallel_device_ids])
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_str
|
||||
|
||||
log.debug(f'CUDA_VISIBLE_DEVICES: [{os.environ["CUDA_VISIBLE_DEVICES"]}]')
|
||||
# don't make this debug... this is good UX
|
||||
log.info(f'CUDA_VISIBLE_DEVICES: [{os.environ["CUDA_VISIBLE_DEVICES"]}]')
|
||||
|
||||
def ddp_train(self, process_idx, model):
|
||||
"""
|
||||
|
|
|
@ -530,7 +530,9 @@ class TrainerDPMixin(ABC):
|
|||
# continue training routine
|
||||
self.run_pretrain_routine(model)
|
||||
|
||||
self.save_spawn_weights(model)
|
||||
# when training ends on these platforms dump weights to get out of the main process
|
||||
if self.on_colab_kaggle:
|
||||
self.save_spawn_weights(model)
|
||||
|
||||
def dp_train(self, model):
|
||||
|
||||
|
|
|
@ -121,7 +121,7 @@ class Trainer(
|
|||
print_nan_grads: bool = False, # backward compatible, todo: remove in v0.9.0
|
||||
weights_summary: Optional[str] = 'full',
|
||||
weights_save_path: Optional[str] = None,
|
||||
num_sanity_val_steps: int = 5,
|
||||
num_sanity_val_steps: int = 2,
|
||||
truncated_bptt_steps: Optional[int] = None,
|
||||
resume_from_checkpoint: Optional[str] = None,
|
||||
profiler: Optional[Union[BaseProfiler, bool]] = None,
|
||||
|
@ -526,6 +526,8 @@ class Trainer(
|
|||
self.amp_level = amp_level
|
||||
self.init_amp(use_amp)
|
||||
|
||||
self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE')
|
||||
|
||||
# Callback system
|
||||
self.on_init_end()
|
||||
|
||||
|
@ -821,7 +823,7 @@ class Trainer(
|
|||
# train
|
||||
mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model,))
|
||||
# load weights if not interrupted
|
||||
if os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE'):
|
||||
if self.on_colab_kaggle:
|
||||
self.load_spawn_weights(model)
|
||||
self.model = model
|
||||
|
||||
|
@ -840,7 +842,7 @@ class Trainer(
|
|||
log.info(f'training on {self.num_tpu_cores} TPU cores')
|
||||
|
||||
# COLAB_GPU is an env var available by default in Colab environments.
|
||||
start_method = 'fork' if os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE') else 'spawn'
|
||||
start_method = 'fork' if self.on_colab_kaggle else 'spawn'
|
||||
|
||||
# track for predict
|
||||
self.model = model
|
||||
|
|
|
@ -744,7 +744,7 @@ def test_gpu_choice(tmpdir):
|
|||
),
|
||||
pytest.param(
|
||||
dict(distributed_backend=None, gpus=2),
|
||||
dict(use_dp=False, use_ddp=True, use_ddp2=False, num_gpus=2, on_gpu=True, single_gpu=False, num_processes=1),
|
||||
dict(use_dp=False, use_ddp=True, use_ddp2=False, num_gpus=2, on_gpu=True, single_gpu=False, num_processes=2),
|
||||
marks=[pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Multiple GPUs needed")]
|
||||
),
|
||||
pytest.param(
|
||||
|
|
Loading…
Reference in New Issue