diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 6e0db93e2e..46f6a11b95 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1439,11 +1439,12 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod def load_from_checkpoint( cls, checkpoint_path: str, + *args, map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None, hparams_file: Optional[str] = None, tags_csv: Optional[str] = None, # backward compatible, todo: remove in v0.9.0 hparam_overrides: Optional[Dict] = None, - *args, **kwargs + **kwargs ) -> 'LightningModule': r""" Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint