* docs

* docs

* docs

* docs
William Falcon 2020-03-05 19:49:18 -05:00 committed by GitHub
parent 0ebfb78570
commit 15e268d6df
3 changed files with 4 additions and 86 deletions

@@ -512,7 +512,7 @@ class TrainerDPMixin(ABC):
         # check for this bug (amp + dp + !01 doesn't work)
         # https://github.com/NVIDIA/apex/issues/227
         if self.use_dp and self.use_amp:
-            if self.amp_level == 'O2':
+            if self.amp_level == 'O2':  # pragma: no cover
                 m = f"""
                 Amp level {self.amp_level} with DataParallel is not supported.
                 See this note from NVIDIA for more info: https://github.com/NVIDIA/apex/issues/227.
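
For context, the guarded combination above (apex amp level 'O2' together with DataParallel) breaks inside apex itself, so the trainer refuses it up front. A minimal standalone sketch of that guard follows; the helper name and free-function signature are illustrative, not Lightning's API:

def check_amp_dp_compat(use_dp: bool, use_amp: bool, amp_level: str) -> None:
    # refuse the known-broken combination before training starts
    if use_dp and use_amp and amp_level == 'O2':
        raise ValueError(
            f"Amp level {amp_level} with DataParallel is not supported. "
            "See https://github.com/NVIDIA/apex/issues/227."
        )

check_amp_dp_compat(use_dp=True, use_amp=True, amp_level='O1')  # fine
# check_amp_dp_compat(use_dp=True, use_amp=True, amp_level='O2')  # would raise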

@@ -461,40 +461,6 @@ class Trainer(TrainerIOMixin,
         params = vars(args)
         return cls(**params)
 
-    def __parse_gpu_ids(self, gpus):
-        """Parse GPUs id.
-
-        :param list|str|int gpus: input GPU ids
-        :return list(int):
-        """
-        # if gpus = -1 then use all available devices
-        # otherwise, split the string using commas
-        if gpus is not None:
-            if isinstance(gpus, list):
-                gpus = gpus
-            elif isinstance(gpus, str):
-                if gpus == '-1':
-                    gpus = list(range(0, torch.cuda.device_count()))
-                else:
-                    gpus = [int(x.strip()) for x in gpus.split(',')]
-            elif isinstance(gpus, int):
-                gpus = gpus
-            else:
-                raise ValueError('`gpus` has to be a string, int or list of ints')
-
-        return gpus
-
-    def __set_root_gpu(self, gpus):
-        if gpus is None:
-            return None
-
-        # set root gpu
-        root_gpu = 0
-        if isinstance(gpus, list):
-            root_gpu = gpus[0]
-
-        return root_gpu
-
     @property
     def num_gpus(self) -> int:
         gpus = self.data_parallel_device_ids
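
The two private helpers deleted above normalized the gpus argument (int, list, or comma-separated string, with the string '-1' meaning every visible device) and picked the root GPU from the resulting list. A rough standalone sketch of that parsing behaviour, for reference only; it mirrors the deleted code rather than whatever replaces it after this commit:

import torch

def parse_gpu_ids(gpus):
    # ints and lists pass through unchanged; only strings are expanded
    if gpus is None or isinstance(gpus, (list, int)):
        return gpus
    if isinstance(gpus, str):
        if gpus == '-1':
            return list(range(torch.cuda.device_count()))
        return [int(x.strip()) for x in gpus.split(',')]
    raise ValueError('`gpus` has to be a string, int or list of ints')

print(parse_gpu_ids('0, 1'))  # [0, 1]
print(parse_gpu_ids(None))    # None
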
@@ -617,7 +583,7 @@ class Trainer(TrainerIOMixin,
         elif self.single_gpu:
             self.single_gpu_train(model)
 
-        elif self.use_tpu:
+        elif self.use_tpu:  # pragma: no cover
             log.info(f'training on {self.num_tpu_cores} TPU cores')
 
             # COLAB_GPU is an env var available by default in Colab environments.
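
As the comment notes, Colab runtimes export a COLAB_GPU environment variable by default, so its presence is a cheap way to tell a Colab notebook apart from other TPU environments. A small sketch of that check, under that assumption:

import os

def running_in_colab() -> bool:
    # COLAB_GPU is available by default in Colab environments (per the comment above)
    return 'COLAB_GPU' in os.environ

print(running_in_colab())
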
@@ -877,7 +843,7 @@ class Trainer(TrainerIOMixin,
         if model is not None:
             self.model = model
             self.fit(model)
-        elif self.use_ddp or self.use_tpu:
+        elif self.use_ddp or self.use_tpu:  # pragma: no cover
             # attempt to load weights from a spawn
             path = os.path.join(self.default_save_path, '__temp_weight_ddp_end.ckpt')
             test_model = self.model
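
In this branch the fitted weights live in a temporary checkpoint written by the spawned process, so the parent reloads them before running the test loop. An illustrative sketch of that reload, assuming the file holds a plain state_dict; the helper name and exact checkpoint format are assumptions, not Lightning's internals:

import os
import torch

def load_spawn_weights(model, save_dir):
    # hypothetical helper: pull weights saved by a spawned worker back into
    # the parent process, then clean up the temporary file
    path = os.path.join(save_dir, '__temp_weight_ddp_end.ckpt')
    if os.path.exists(path):
        state_dict = torch.load(path, map_location='cpu')
        model.load_state_dict(state_dict)
        os.remove(path)
    return model
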
@@ -902,51 +868,3 @@ class _PatchDataLoader(object):
     def __call__(self) -> Union[List[DataLoader], DataLoader]:
         return self.dataloader
 
-
-def _set_dataloader(model, dataloader, attribute):
-    r'''
-    Check dataloaders passed to .fit() method if they are pytorch DataLoader
-    objects and whether or not we should overright the corresponding dataloader
-    in the model
-
-    Args:
-        model (LightningModule): The model to check
-
-        dataloader: If a pytorch dataloader (or a list of pytorch dataloaders)
-            is passed, it will be incorporate into the model as model.attribute.
-            If attribute alreay exist it will warn the userpass. If not a
-            dataloader will throw an error
-
-        attribute (str): The attribute to save the dataloader under
-    '''
-    # Check if attribute comes directly from base class or
-    # derived in user subclass
-    if LightningModule.__qualname__ in getattr(model, attribute).__qualname__:
-        # Val and test should be list of dataloaders
-        dataloader = dataloader if attribute == 'train_dataloader' or \
-            (attribute != 'train_dataloader' and isinstance(dataloader, list)) else [dataloader]
-
-        # Check we are given valid dataloaders
-        is_dataloader = isinstance(dataloader, torch.utils.data.DataLoader)
-        is_dataloader_list = isinstance(dataloader, list)
-        valid_loaders = None
-        if is_dataloader_list:
-            valid_loaders = all(isinstance(d, torch.utils.data.DataLoader) for d in dataloader)
-
-        if is_dataloader or is_dataloader_list and valid_loaders:
-            # Overwrite abstract methods
-            def dl():
-                return dataloader
-            dl.__name__ = attribute
-            setattr(model, attribute, dl)
-
-        elif dataloader and dataloader != [None]:
-            raise ValueError(f'`{attribute}` needs to be an instance of '
-                             '`torch.utils.data.DataLoader` or a list of '
-                             'DataLoaders, instead got %r`' % dataloader)
-
-        elif dataloader:  # if default (None) is passed, do not warn the user
-            warnings.warn(f'Model has predefined `{attribute}`,'
-                          f' will skip `{attribute}={dataloader}` passed to fit method.')
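
The deleted _set_dataloader helper checked whether the model still inherited the default dataloader hooks from LightningModule and, if so, patched in the loaders passed to .fit() as callables. The _PatchDataLoader class kept in the context above captures the same idea; a minimal standalone sketch of that callable-wrapper pattern follows (the class and variable names here are illustrative, not Lightning's):

from typing import List, Union

import torch
from torch.utils.data import DataLoader, TensorDataset

class PatchDataLoader:
    # wrap a DataLoader in a callable so it can stand in for a
    # train_dataloader/val_dataloader/test_dataloader hook
    def __init__(self, dataloader: Union[List[DataLoader], DataLoader]) -> None:
        self.dataloader = dataloader

    def __call__(self) -> Union[List[DataLoader], DataLoader]:
        return self.dataloader

loader = DataLoader(TensorDataset(torch.randn(8, 2)), batch_size=4)
patched = PatchDataLoader(loader)
assert patched() is loader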

@@ -206,7 +206,7 @@ class TrainerIOMixin(ABC):
             signal.signal(signal.SIGUSR1, self.sig_handler)
             signal.signal(signal.SIGTERM, self.term_handler)
 
-    def sig_handler(self, signum, frame):
+    def sig_handler(self, signum, frame):  # pragma: no cover
         if self.proc_rank == 0:
             # save weights
             log.info('handling SIGUSR1')
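
sig_handler is the callback registered for SIGUSR1 above: when a SLURM job is submitted with something like --signal=SIGUSR1@90, the scheduler sends that signal shortly before the walltime limit, giving rank 0 a chance to save a checkpoint and requeue. A bare-bones sketch of the general pattern (the handler body here is illustrative, not Lightning's):

import signal

def sig_handler(signum, frame):
    # placeholder for "save a checkpoint, then requeue the SLURM job"
    print(f'received signal {signum}; saving and requeueing')

signal.signal(signal.SIGUSR1, sig_handler)
signal.signal(signal.SIGTERM, lambda signum, frame: print('bypassing sigterm'))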