Provide backward compatibility for #124 (#400)

* Provide backward compatibility for e681253

* typo fix
tamyiuchau 2019-10-21 14:16:55 +08:00 committed by William Falcon
parent 67f6e7bb19
commit 4103a5ca73
2 changed files with 43 additions and 1 deletion


@@ -7,6 +7,8 @@ from pytorch_lightning.root_module.model_saving import ModelIO
 from pytorch_lightning.root_module.hooks import ModelHooks
 from pytorch_lightning.root_module.decorators import data_loader
+import warnings
 
 
 class LightningModule(GradInformation, ModelIO, ModelHooks):
@@ -110,13 +112,29 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
         # clear gradients
         optimizer.zero_grad()
 
+    @data_loader
+    def tng_dataloader(self):
+        """
+        Implement a PyTorch DataLoader
+
+        * Deprecated in v0.5.0. use train_dataloader instead. *
+        :return:
+        """
+        raise NotImplementedError
+
     @data_loader
     def train_dataloader(self):
         """
         Implement a PyTorch DataLoader
 
         :return:
         """
-        raise NotImplementedError
+        #
+        try:
+            output = self.tng_dataloader()
+            warnings.warn("tng_dataloader has been renamed to train_dataloader since v0.5.0",
+                          DeprecationWarning)
+            return output
+        except NotImplementedError:
+            raise NotImplementedError
 
     @data_loader
     def test_dataloader(self):
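This hunk keeps old user code working: a subclass that still overrides the deprecated tng_dataloader hook is picked up by the new train_dataloader fallback. Below is a minimal standalone sketch of that pattern; the names Model, old_hook, new_hook and LegacyModel are hypothetical and not part of Lightning.

import warnings


class Model:
    # Sketch of the hook-renaming shim used in the diff above (names hypothetical).

    def old_hook(self):
        # Deprecated hook; legacy subclasses may still override this name.
        raise NotImplementedError

    def new_hook(self):
        # New hook: fall back to the deprecated one so legacy subclasses keep working.
        try:
            output = self.old_hook()
            warnings.warn("old_hook has been renamed to new_hook since v0.5.0",
                          DeprecationWarning)
            return output
        except NotImplementedError:
            raise NotImplementedError


class LegacyModel(Model):
    # A pre-rename subclass that only implements the old name.
    def old_hook(self):
        return "loader"


print(LegacyModel().new_hook())  # emits a DeprecationWarning, prints "loader"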


@@ -63,6 +63,7 @@ class Trainer(TrainerIOMixin):
                  early_stop_callback=True,
                  default_save_path=None,
                  gradient_clip_val=0,
+                 gradient_clip=None,  # backward compatible
                  process_position=0,
                  nb_gpu_nodes=1,
                  gpus=None,
@@ -81,6 +82,7 @@ class Trainer(TrainerIOMixin):
                  val_check_interval=1.0,
                  log_save_interval=100,
                  row_log_interval=10,
+                 add_row_log_interval=None,  # backward compatible
                  distributed_backend=None,
                  use_amp=False,
                  print_nan_grads=False,
@@ -95,6 +97,7 @@ class Trainer(TrainerIOMixin):
         :param early_stop_callback: Callback for early stopping
         :param default_save_path: Default path for logs+weights if no logger/ckpt_callback passed
         :param gradient_clip_val: int. 0 means don't clip.
+        :param gradient_clip: int. 0 means don't clip. Deprecated.
         :param process_position: shown in the tqdm bar
         :param nb_gpu_nodes: number of GPU nodes
         :param gpus: int. (ie: 2 gpus) OR list to specify which GPUs [0, 1] or '0,1'
@@ -113,6 +116,7 @@ class Trainer(TrainerIOMixin):
         :param val_check_interval: int. Check val this frequently within a train epoch
         :param log_save_interval: int. Writes logs to disk this often
         :param row_log_interval: int. How often to add logging rows
+        :param add_row_log_interval: int. How often to add logging rows. Deprecated.
         :param distributed_backend: str. Options: 'dp', 'ddp', 'ddp2'.
         :param use_amp: Bool. If true uses apex for 16bit precision
         :param print_nan_grads: Bool. Prints nan gradients
@@ -124,6 +128,11 @@ class Trainer(TrainerIOMixin):
         # Transfer params
         self.nb_gpu_nodes = nb_gpu_nodes
         self.log_gpu_memory = log_gpu_memory
+        if gradient_clip is not None:
+            # Backward compatibility
+            warnings.warn("gradient_clip has been renamed to gradient_clip_val since v0.5.0",
+                          DeprecationWarning)
+            gradient_clip_val = gradient_clip
         self.gradient_clip_val = gradient_clip_val
         self.check_val_every_n_epoch = check_val_every_n_epoch
         self.track_grad_norm = track_grad_norm
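The hunk above (and the add_row_log_interval hunk that follows) renames a constructor keyword: the deprecated keyword defaults to None, and when supplied it overrides the new keyword after emitting a DeprecationWarning. A minimal sketch of that shim in isolation; Runner, clip and clip_val are illustrative names, not the Trainer API.

import warnings


class Runner:
    # Sketch of the keyword-argument renaming shim from the diff (names hypothetical).

    def __init__(self, clip_val=0, clip=None):  # `clip` is the deprecated spelling
        if clip is not None:
            # Backward compatibility: honour the old keyword, but warn about it.
            warnings.warn("clip has been renamed to clip_val since v0.5.0",
                          DeprecationWarning)
            clip_val = clip
        self.clip_val = clip_val


print(Runner(clip=0.5).clip_val)      # 0.5, with a DeprecationWarning
print(Runner(clip_val=1.0).clip_val)  # 1.0, no warning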
@@ -242,6 +251,11 @@ class Trainer(TrainerIOMixin):
         # logging
         self.log_save_interval = log_save_interval
         self.val_check_interval = val_check_interval
+        if add_row_log_interval is not None:
+            # Backward compatibility
+            warnings.warn("add_row_log_interval has been renamed to row_log_interval since v0.5.0",
+                          DeprecationWarning)
+            row_log_interval = add_row_log_interval
         self.row_log_interval = row_log_interval
 
         # how much of the data to use
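Note that Python typically ignores DeprecationWarning raised from library code unless it is triggered directly in __main__ or enabled explicitly, so users migrating off the old names may not see these messages by default. A standard-library way to surface them while testing (purely illustrative, not part of this commit):

import warnings

# Show every DeprecationWarning, e.g. while migrating away from the
# renamed Trainer arguments and LightningModule hooks above.
warnings.simplefilter("always", DeprecationWarning)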
@@ -520,6 +534,16 @@ class Trainer(TrainerIOMixin):
         """
         return self.__training_tqdm_dict
 
+    @property
+    def tng_tqdm_dic(self):
+        """
+        * Deprecated in v0.5.0. use training_tqdm_dict instead. *
+        :return:
+        """
+        warnings.warn("tng_tqdm_dic has been renamed to training_tqdm_dict since v0.5.0",
+                      DeprecationWarning)
+        return self.training_tqdm_dict
+
     def __layout_bookeeping(self):
         # determine number of training batches
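The tng_tqdm_dic property added above applies the same idea to a renamed read-only attribute: the old name stays readable but warns and delegates to the new one. A minimal standalone sketch under the same caveat; Reporter, progress_dic and progress_dict are hypothetical names.

import warnings


class Reporter:
    # Sketch of the deprecated read-only property alias (names hypothetical).

    def __init__(self):
        self._progress = {"loss": 0.0}

    @property
    def progress_dict(self):
        # New canonical name.
        return self._progress

    @property
    def progress_dic(self):
        # Deprecated spelling: warn, then delegate to the new property.
        warnings.warn("progress_dic has been renamed to progress_dict since v0.5.0",
                      DeprecationWarning)
        return self.progress_dict


r = Reporter()
assert r.progress_dic is r.progress_dict  # old name still works, with a warning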