diff --git a/CHANGELOG.md b/CHANGELOG.md index 5cb32aaf34..6e4436b393 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for `IterableDataset` when `val_check_interval=1.0` (default), this will trigger validation at the end of each epoch. ([#1283](https://github.com/PyTorchLightning/pytorch-lightning/pull/1283)) - Added `summary` method to Profilers. ([#1259](https://github.com/PyTorchLightning/pytorch-lightning/pull/1259)) - Added informative errors if user defined dataloader has zero length ([#1280](https://github.com/PyTorchLightning/pytorch-lightning/pull/1280)) +- Allow to upload models on W&B ([#1339](https://github.com/PyTorchLightning/pytorch-lightning/pull/1339)) ### Changed diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index e11f5139ca..5890873cdb 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -33,6 +33,7 @@ class WandbLogger(LightningLoggerBase): anonymous (bool): enables or explicitly disables anonymous logging. project (str): the name of the project to which this run will belong. tags (list of str): tags associated with this run. + log_model (bool): save checkpoints in wandb dir to upload on W&B servers. Example -------- @@ -48,7 +49,8 @@ class WandbLogger(LightningLoggerBase): def __init__(self, name: Optional[str] = None, save_dir: Optional[str] = None, offline: bool = False, id: Optional[str] = None, anonymous: bool = False, version: Optional[str] = None, project: Optional[str] = None, - tags: Optional[List[str]] = None, experiment=None, entity=None): + tags: Optional[List[str]] = None, log_model: bool = False, + experiment=None, entity=None): super().__init__() self._name = name self._save_dir = save_dir @@ -59,6 +61,7 @@ class WandbLogger(LightningLoggerBase): self._experiment = experiment self._offline = offline self._entity = entity + self._log_model = log_model def __getstate__(self): state = self.__dict__.copy() @@ -85,6 +88,9 @@ class WandbLogger(LightningLoggerBase): self._experiment = wandb.init( name=self._name, dir=self._save_dir, project=self._project, anonymous=self._anonymous, id=self._id, resume='allow', tags=self._tags, entity=self._entity) + # save checkpoints in wandb dir to upload on W&B servers + if self._log_model: + self.save_dir = self._experiment.dir return self._experiment def watch(self, model: nn.Module, log: str = 'gradients', log_freq: int = 100):