parent 0ebfb78570
commit 15e268d6df
@@ -512,7 +512,7 @@ class TrainerDPMixin(ABC):
         # check for this bug (amp + dp + !01 doesn't work)
         # https://github.com/NVIDIA/apex/issues/227
         if self.use_dp and self.use_amp:
-            if self.amp_level == 'O2':
+            if self.amp_level == 'O2':  # pragma: no cover
                 m = f"""
                 Amp level {self.amp_level} with DataParallel is not supported.
                 See this note from NVIDIA for more info: https://github.com/NVIDIA/apex/issues/227.
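The `# pragma: no cover` marker added throughout this commit is coverage.py's default exclusion comment: any line or branch carrying it is left out of the coverage report instead of being counted as a miss. A minimal sketch of the effect (not part of the commit; the helper name `check_amp_dp` is illustrative):

import warnings

def check_amp_dp(use_dp: bool, use_amp: bool, amp_level: str) -> None:
    # Excluded branch: coverage.py will not report this line as missed,
    # even though no test exercises the DataParallel + O2 combination.
    if use_dp and use_amp and amp_level == 'O2':  # pragma: no cover
        warnings.warn('Amp level O2 with DataParallel is not supported.')

Running the suite with `coverage run -m pytest` and then `coverage report` lists the guarded line as excluded rather than missed.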
@@ -461,40 +461,6 @@ class Trainer(TrainerIOMixin,
         params = vars(args)
         return cls(**params)
 
-    def __parse_gpu_ids(self, gpus):
-        """Parse GPUs id.
-
-        :param list|str|int gpus: input GPU ids
-        :return list(int):
-        """
-        # if gpus = -1 then use all available devices
-        # otherwise, split the string using commas
-        if gpus is not None:
-            if isinstance(gpus, list):
-                gpus = gpus
-            elif isinstance(gpus, str):
-                if gpus == '-1':
-                    gpus = list(range(0, torch.cuda.device_count()))
-                else:
-                    gpus = [int(x.strip()) for x in gpus.split(',')]
-            elif isinstance(gpus, int):
-                gpus = gpus
-            else:
-                raise ValueError('`gpus` has to be a string, int or list of ints')
-
-        return gpus
-
-    def __set_root_gpu(self, gpus):
-        if gpus is None:
-            return None
-
-        # set root gpu
-        root_gpu = 0
-        if isinstance(gpus, list):
-            root_gpu = gpus[0]
-
-        return root_gpu
-
     @property
     def num_gpus(self) -> int:
         gpus = self.data_parallel_device_ids
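For reference, the removed `__parse_gpu_ids` normalized the `gpus` argument (a list, a comma-separated string, an int, or '-1' for all visible devices). A standalone sketch of that behavior; the module-level name `parse_gpu_ids` is chosen here for illustration and is not introduced by the commit:

import torch

def parse_gpu_ids(gpus):
    # None means CPU; a list is taken as-is; '-1' expands to every visible device;
    # any other string is split on commas; an int is passed through unchanged.
    if gpus is None:
        return None
    if isinstance(gpus, list):
        return gpus
    if isinstance(gpus, str):
        if gpus == '-1':
            return list(range(torch.cuda.device_count()))
        return [int(x.strip()) for x in gpus.split(',')]
    if isinstance(gpus, int):
        return gpus
    raise ValueError('`gpus` has to be a string, int or list of ints')

For example, `parse_gpu_ids('0, 1')` returns `[0, 1]`, while `parse_gpu_ids('-1')` returns one id per device reported by `torch.cuda.device_count()`.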
@@ -617,7 +583,7 @@ class Trainer(TrainerIOMixin,
         elif self.single_gpu:
             self.single_gpu_train(model)
 
-        elif self.use_tpu:
+        elif self.use_tpu:  # pragma: no cover
             log.info(f'training on {self.num_tpu_cores} TPU cores')
 
             # COLAB_GPU is an env var available by default in Colab environments.
@@ -877,7 +843,7 @@ class Trainer(TrainerIOMixin,
         if model is not None:
             self.model = model
             self.fit(model)
-        elif self.use_ddp or self.use_tpu:
+        elif self.use_ddp or self.use_tpu:  # pragma: no cover
             # attempt to load weights from a spawn
             path = os.path.join(self.default_save_path, '__temp_weight_ddp_end.ckpt')
             test_model = self.model
@@ -902,51 +868,3 @@ class _PatchDataLoader(object):
 
     def __call__(self) -> Union[List[DataLoader], DataLoader]:
         return self.dataloader
-
-
-def _set_dataloader(model, dataloader, attribute):
-    r'''
-    Check dataloaders passed to .fit() method if they are pytorch DataLoader
-    objects and whether or not we should overright the corresponding dataloader
-    in the model
-
-    Args:
-        model (LightningModule): The model to check
-
-        dataloader: If a pytorch dataloader (or a list of pytorch dataloaders)
-            is passed, it will be incorporate into the model as model.attribute.
-            If attribute alreay exist it will warn the userpass. If not a
-            dataloader will throw an error
-
-        attribute (str): The attribute to save the dataloader under
-
-    '''
-    # Check if attribute comes directly from base class or
-    # derived in user subclass
-    if LightningModule.__qualname__ in getattr(model, attribute).__qualname__:
-        # Val and test should be list of dataloaders
-        dataloader = dataloader if attribute == 'train_dataloader' or \
-            (attribute != 'train_dataloader' and isinstance(dataloader, list)) else [dataloader]
-
-        # Check we are given valid dataloaders
-        is_dataloader = isinstance(dataloader, torch.utils.data.DataLoader)
-        is_dataloader_list = isinstance(dataloader, list)
-        valid_loaders = None
-        if is_dataloader_list:
-            valid_loaders = all(isinstance(d, torch.utils.data.DataLoader) for d in dataloader)
-        if is_dataloader or is_dataloader_list and valid_loaders:
-
-            # Overwrite abstract methods
-            def dl():
-                return dataloader
-            dl.__name__ = attribute
-            setattr(model, attribute, dl)
-
-        elif dataloader and dataloader != [None]:
-            raise ValueError(f'`{attribute}` needs to be an instance of '
-                             '`torch.utils.data.DataLoader` or a list of '
-                             'DataLoaders, instead got %r`' % dataloader)
-
-        elif dataloader:  # if default (None) is passed, do not warn the user
-            warnings.warn(f'Model has predefined `{attribute}`,'
-                          f' will skip `{attribute}={dataloader}` passed to fit method.')
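The removed `_set_dataloader` decided whether it was safe to attach a dataloader by checking if the model still inherits the base-class hook, comparing `__qualname__` strings. A minimal sketch of that check (the class names here are illustrative, not from the commit):

class Base:
    def train_dataloader(self):
        return None

class UserModel(Base):
    pass

model = UserModel()
attribute = 'train_dataloader'
# If the bound method's qualified name still starts with the base class,
# the user has not overridden the hook, so it is safe to patch it.
inherited = Base.__qualname__ in getattr(model, attribute).__qualname__
print(inherited)  # True -> hook was not overridden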
@@ -206,7 +206,7 @@ class TrainerIOMixin(ABC):
         signal.signal(signal.SIGUSR1, self.sig_handler)
         signal.signal(signal.SIGTERM, self.term_handler)
 
-    def sig_handler(self, signum, frame):
+    def sig_handler(self, signum, frame):  # pragma: no cover
         if self.proc_rank == 0:
             # save weights
             log.info('handling SIGUSR1')
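Signal handlers like `sig_handler` run only when a scheduler delivers SIGUSR1 (typically sent by cluster schedulers such as SLURM before requeueing a job), which makes them hard to exercise in unit tests and is presumably why they are excluded from coverage here. A small sketch of the mechanism, assuming a Unix platform:

import signal

def sig_handler(signum, frame):  # pragma: no cover
    # Handlers run asynchronously on signal delivery, outside normal test
    # control flow, so excluding them from coverage avoids false misses.
    print(f'received signal {signum}, saving a checkpoint before requeue')

# SIGUSR1 is only available on Unix platforms.
signal.signal(signal.SIGUSR1, sig_handler)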