diff --git a/CHANGELOG.md b/CHANGELOG.md index d18fac2668..9ab690429b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,10 +9,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- Added hooks to metric module interface ([#2528](https://github.com/PyTorchLightning/pytorch-lightning/pull/2528/)) +- Added hooks to metric module interface ([#2528](https://github.com/PyTorchLightning/pytorch-lightning/pull/2528)) ### Changed +- Used `fsspec` instead of `gfile` for all IO ([#3320](https://github.com/PyTorchLightning/pytorch-lightning/pull/3320)) ### Deprecated diff --git a/environment.yml b/environment.yml index 2c93fd594d..b326ed3440 100644 --- a/environment.yml +++ b/environment.yml @@ -31,6 +31,7 @@ dependencies: - future>=0.17.1 - PyYAML>=5.1 - tqdm>=4.41.0 + - fsspec>=0.8.0 - nvidia-apex # For dev and testing diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 63dddd1aa1..71a502fe1e 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -30,7 +30,7 @@ import torch from pytorch_lightning import _logger as log from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.utilities import rank_zero_warn, rank_zero_only -from pytorch_lightning.utilities.cloud_io import gfile, makedirs, is_remote_path +from pytorch_lightning.utilities.cloud_io import get_filesystem class ModelCheckpoint(Callback): @@ -119,9 +119,11 @@ class ModelCheckpoint(Callback): save_last: bool = False, save_top_k: int = 1, save_weights_only: bool = False, mode: str = 'auto', period: int = 1, prefix: str = ''): super().__init__() - if(filepath): - filepath = str(filepath) # the tests pass in a py.path.local but we want a str - if save_top_k > 0 and filepath is not None and gfile.isdir(filepath) and len(gfile.listdir(filepath)) > 0: + if filepath: + self._fs = get_filesystem(filepath) + else: + self._fs = get_filesystem("") # will give local fileystem + if save_top_k > 0 and filepath is not None and self._fs.isdir(filepath) and len(self._fs.ls(filepath)) > 0: rank_zero_warn( f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0." "All files in this directory will be deleted when a checkpoint is saved!" @@ -133,13 +135,13 @@ class ModelCheckpoint(Callback): if filepath is None: # will be determined by trainer at runtime self.dirpath, self.filename = None, None else: - if gfile.isdir(filepath): - self.dirpath, self.filename = filepath, '{epoch}' + if self._fs.isdir(filepath): + self.dirpath, self.filename = filepath, "{epoch}" else: - if not is_remote_path(filepath): # dont normalize remote paths + if self._fs.protocol == "file": # dont normalize remote paths filepath = os.path.realpath(filepath) self.dirpath, self.filename = os.path.split(filepath) - makedirs(self.dirpath) # calls with exist_ok + self._fs.makedirs(self.dirpath, exist_ok=True) self.save_last = save_last self.save_top_k = save_top_k self.save_weights_only = save_weights_only @@ -182,19 +184,8 @@ class ModelCheckpoint(Callback): return self.kth_best_model_path def _del_model(self, filepath): - if gfile.exists(filepath): - try: - # in compat mode, remove is not implemented so if running this - # against an actual remove file system and the correct remote - # dependencies exist then this will work fine. - gfile.remove(filepath) - except AttributeError: - if is_remote_path(filepath): - log.warning("Unable to remove stale checkpoints due to running gfile in compatibility mode." - " Please install tensorflow to run gfile in full mode" - " if writing checkpoints to remote locations") - else: - os.remove(filepath) + if self._fs.exists(filepath): + self._fs.rm(filepath) def _save_model(self, filepath, trainer, pl_module): @@ -202,8 +193,7 @@ class ModelCheckpoint(Callback): trainer.dev_debugger.track_checkpointing_history(filepath) # make paths - if not gfile.exists(os.path.dirname(filepath)): - makedirs(os.path.dirname(filepath)) + self._fs.makedirs(os.path.dirname(filepath), exist_ok=True) # delegate the saving to the model if self.save_function is not None: @@ -308,9 +298,8 @@ class ModelCheckpoint(Callback): self.dirpath = ckpt_path - assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0' - if not gfile.exists(self.dirpath): - makedirs(self.dirpath) + assert trainer.global_rank == 0, "tried to make a checkpoint from non global_rank=0" + self._fs.makedirs(self.dirpath, exist_ok=True) def __warn_deprecated_monitor_key(self): using_result_obj = os.environ.get('PL_USING_RESULT_OBJ', None) @@ -359,7 +348,7 @@ class ModelCheckpoint(Callback): ckpt_name_metrics = trainer.logged_metrics filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics) version_cnt = 0 - while gfile.exists(filepath): + while self._fs.exists(filepath): filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics, ver=version_cnt) # this epoch called before version_cnt += 1 @@ -435,4 +424,4 @@ class ModelCheckpoint(Callback): def on_load_checkpoint(self, checkpointed_state): self.best_model_score = checkpointed_state['best_model_score'] - self.best_model_path = checkpointed_state['best_model_path'] \ No newline at end of file + self.best_model_path = checkpointed_state['best_model_path'] diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index 68cd105214..e012799b69 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -19,13 +19,15 @@ import os from argparse import Namespace from typing import Union, Dict, Any, Optional, Callable, MutableMapping +import fsspec import torch import yaml from pytorch_lightning import _logger as log from pytorch_lightning.utilities import rank_zero_warn, AttributeDict from pytorch_lightning.utilities.cloud_io import load as pl_load -from pytorch_lightning.utilities.cloud_io import gfile, cloud_open +from pytorch_lightning.utilities.cloud_io import get_filesystem + PRIMITIVE_TYPES = (bool, int, float, str) ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace) @@ -290,11 +292,12 @@ def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]: True >>> os.remove(path_csv) """ - if not gfile.exists(tags_csv): + fs = get_filesystem(tags_csv) + if not fs.exists(tags_csv): rank_zero_warn(f"Missing Tags: {tags_csv}.", RuntimeWarning) return {} - with cloud_open(tags_csv, "r", newline="") as fp: + with fs.open(tags_csv, "r", newline="") as fp: csv_reader = csv.reader(fp, delimiter=",") tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]} @@ -302,13 +305,14 @@ def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]: def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) -> None: - if not gfile.isdir(os.path.dirname(tags_csv)): + fs = get_filesystem(tags_csv) + if not fs.isdir(os.path.dirname(tags_csv)): raise RuntimeError(f"Missing folder: {os.path.dirname(tags_csv)}.") if isinstance(hparams, Namespace): hparams = vars(hparams) - with cloud_open(tags_csv, "w", newline="") as fp: + with fs.open(tags_csv, "w", newline="") as fp: fieldnames = ["key", "value"] writer = csv.DictWriter(fp, fieldnames=fieldnames) writer.writerow({"key": "key", "value": "value"}) @@ -327,11 +331,12 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]: True >>> os.remove(path_yaml) """ - if not gfile.exists(config_yaml): + fs = get_filesystem(config_yaml) + if not fs.exists(config_yaml): rank_zero_warn(f"Missing Tags: {config_yaml}.", RuntimeWarning) return {} - with cloud_open(config_yaml, "r") as fp: + with fs.open(config_yaml, "r") as fp: tags = yaml.load(fp) return tags @@ -343,7 +348,8 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None: config_yaml: path to new YAML file hparams: parameters to be saved """ - if not gfile.isdir(os.path.dirname(config_yaml)): + fs = get_filesystem(config_yaml) + if not fs.isdir(os.path.dirname(config_yaml)): raise RuntimeError(f"Missing folder: {os.path.dirname(config_yaml)}.") # convert Namespace or AD to dict @@ -364,7 +370,7 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None: # saving the standard way assert isinstance(hparams, dict) - with cloud_open(config_yaml, 'w', newline='') as fp: + with fs.open(config_yaml, "w", newline="") as fp: yaml.dump(hparams, fp) diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index b6b1bd85f1..d9c5125173 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -30,7 +30,7 @@ from pytorch_lightning import _logger as log from pytorch_lightning.core.saving import save_hparams_to_yaml from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn -from pytorch_lightning.utilities.cloud_io import gfile, makedirs +from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.core.lightning import LightningModule try: @@ -87,6 +87,7 @@ class TensorBoardLogger(LightningLoggerBase): self._version = version self._log_graph = log_graph self._default_hp_metric = default_hp_metric + self._fs = get_filesystem(save_dir) self._experiment = None self.hparams = {} @@ -136,8 +137,8 @@ class TensorBoardLogger(LightningLoggerBase): return self._experiment assert rank_zero_only.rank == 0, 'tried to init log dirs in non global_rank=0' - if self.root_dir and not gfile.exists(str(self.root_dir)): - makedirs(self.root_dir) + if self.root_dir: + self._fs.makedirs(self.root_dir, exist_ok=True) self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs) return self._experiment @@ -207,7 +208,7 @@ class TensorBoardLogger(LightningLoggerBase): def save(self) -> None: super().save() dir_path = self.log_dir - if not gfile.isdir(dir_path): + if not self._fs.isdir(dir_path): dir_path = self.save_dir # prepare the file path @@ -233,16 +234,16 @@ class TensorBoardLogger(LightningLoggerBase): def _get_next_version(self): root_dir = os.path.join(self.save_dir, self.name) - if not gfile.isdir(root_dir): + if not self._fs.isdir(root_dir): log.warning('Missing logger folder: %s', root_dir) return 0 existing_versions = [] - for d in gfile.listdir(root_dir): - if gfile.isdir(os.path.join(root_dir, d)) and d.startswith("version_"): - dir_ver = d.split("_")[1].replace('/', '') + for d in self._fs.ls(root_dir): + bn = os.path.basename(d) + if self._fs.isdir(d) and bn.startswith("version_"): + dir_ver = bn.split("_")[1].replace('/', '') existing_versions.append(int(dir_ver)) - if len(existing_versions) == 0: return 0 diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 50cd439674..3554a9ae78 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -51,7 +51,7 @@ from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only, rank_zero_warn, AMPType from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.cloud_io import is_remote_path +from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop from pytorch_lightning.trainer.data_connector import DataConnector from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector @@ -915,10 +915,9 @@ class Trainer( The default location to save artifacts of loggers, checkpoints etc. It is used as a fallback if logger or checkpoint callback do not define specific save paths. """ - if is_remote_path(self._default_root_dir): - # it is a remote uri, use as is - return self._default_root_dir - return os.path.normpath(self._default_root_dir) + if get_filesystem(self._default_root_dir).protocol == "file": + return os.path.normpath(self._default_root_dir) + return self._default_root_dir @property def weights_save_path(self) -> str: @@ -926,10 +925,9 @@ class Trainer( The default root location to save weights (checkpoints), e.g., when the :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` does not define a file path. """ - if is_remote_path(self._weights_save_path): - # it is a remote uri, use as is - return self._weights_save_path - return os.path.normpath(self._weights_save_path) + if get_filesystem(self._weights_save_path).protocol == "file": + return os.path.normpath(self._weights_save_path) + return self._weights_save_path def tune( self, diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py index 0f7df23e71..1bc530f5ec 100644 --- a/pytorch_lightning/trainer/training_io.py +++ b/pytorch_lightning/trainer/training_io.py @@ -114,9 +114,8 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.overrides.data_parallel import LightningDataParallel, LightningDistributedDataParallel from pytorch_lightning.utilities import AMPType, rank_zero_warn -from pytorch_lightning.utilities.cloud_io import atomic_save, gfile +from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem from pytorch_lightning.utilities.cloud_io import load as pl_load -from pytorch_lightning.utilities.cloud_io import makedirs from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS try: @@ -391,8 +390,9 @@ class TrainerIOMixin(ABC): # look for hpc weights folderpath = str(self.weights_save_path) - if gfile.exists(folderpath): - files = gfile.listdir(folderpath) + fs = get_filesystem(folderpath) + if fs.exists(folderpath): + files = [os.path.basename(f) for f in fs.ls(folderpath)] hpc_weight_paths = [x for x in files if 'hpc_ckpt' in x] # if hpc weights exist restore model @@ -463,16 +463,15 @@ class TrainerIOMixin(ABC): def hpc_save(self, folderpath: str, logger): # make sure the checkpoint folder exists folderpath = str(folderpath) # because the tests pass a path object - if not gfile.exists(folderpath): - makedirs(folderpath) + fs = get_filesystem(folderpath) + fs.makedirs(folderpath, exist_ok=True) # save logger to make sure we get all the metrics logger.save() ckpt_number = self.max_ckpt_in_folder(folderpath) + 1 - if not gfile.exists(folderpath): - makedirs(folderpath) + fs.makedirs(folderpath, exist_ok=True) filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt') # give model a chance to do something on hpc_save @@ -525,7 +524,8 @@ class TrainerIOMixin(ABC): log.info(f'restored hpc model from: {filepath}') def max_ckpt_in_folder(self, path, name_key='ckpt_'): - files = gfile.listdir(str(path)) + fs = get_filesystem(path) + files = [os.path.basename(f) for f in fs.ls(path)] files = [x for x in files if name_key in x] if len(files) == 0: return 0 diff --git a/pytorch_lightning/utilities/cloud_io.py b/pytorch_lightning/utilities/cloud_io.py index 26d906a794..2c6771c5d6 100644 --- a/pytorch_lightning/utilities/cloud_io.py +++ b/pytorch_lightning/utilities/cloud_io.py @@ -13,81 +13,31 @@ # limitations under the License. import io -import platform -import os from distutils.version import LooseVersion from typing import Union from pathlib import Path from urllib.parse import urlparse import torch +import fsspec -import tensorboard -from pytorch_lightning import _logger as log - -# we want this for tf.io.gfile, which if tf is installed gives full tf, -# otherwise gives a pruned down version which works for some file backends but -# not all -from tensorboard.compat import tf - -gfile = tf.io.gfile pathlike = Union[Path, str] -# older version of tensorboard had buggy gfile compatibility layers -# only support remote cloud paths if newer - def load(path_or_url: str, map_location=None): - if urlparse(path_or_url).scheme == '' or Path(path_or_url).drive: # no scheme or with a drive letter + if urlparse(path_or_url).scheme == "" or Path(path_or_url).drive: # no scheme or with a drive letter return torch.load(path_or_url, map_location=map_location) return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location) -def is_remote_path(path: pathlike): - """Determine if a path is a local path or a remote path like s3://bucket/path - - This should catch paths like s3:// hdfs:// and gcs:// - """ - return "://" in str(path) - - -def modern_gfile(): - """Check the version number of tensorboard. - - Cheking to see if it has the gfile compatibility layers needed for remote - file operations - """ - tb_version = LooseVersion(tensorboard.version.VERSION) - modern_gfile = tb_version >= LooseVersion("2.0") - return modern_gfile - - -def cloud_open(path: pathlike, mode: str, newline: str = None): - if platform.system() == "Windows": - log.debug( - "gfile does not handle newlines correctly on windows so remote files are not" - " supported falling back to normal local file open." - ) - return open(path, mode, newline=newline) - if not modern_gfile(): - log.debug( - "tenosrboard.compat gfile does not work on older versions " - "of tensorboard for remote files, using normal local file open." - ) - return open(path, mode, newline=newline) - try: - return gfile.GFile(path, mode) - except NotImplementedError as e: - # minimal dependencies are installed and only local files will work - return open(path, mode, newline=newline) - - -def makedirs(path: pathlike): - if hasattr(gfile, "makedirs") and modern_gfile(): - if not gfile.exists(str(path)): - return gfile.makedirs(str(path)) - # otherwise minimal dependencies are installed and only local files will work - return os.makedirs(path, exist_ok=True) +def get_filesystem(path: pathlike): + path = str(path) + if "://" in path: + # use the fileystem from the protocol specified + return fsspec.filesystem(path.split(":", 1)[0]) + else: + # use local filesystem + return fsspec.filesystem("file") def atomic_save(checkpoint, filepath: str): @@ -108,5 +58,5 @@ def atomic_save(checkpoint, filepath: str): torch.save(checkpoint, bytesbuffer, _use_new_zipfile_serialization=False) else: torch.save(checkpoint, bytesbuffer) - with cloud_open(filepath, 'wb') as f: + with fsspec.open(filepath, "wb") as f: f.write(bytesbuffer.getvalue()) diff --git a/requirements/base.txt b/requirements/base.txt index 83b0341ce4..3e30d2b93d 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -7,3 +7,4 @@ future>=0.17.1 # required for builtins in setup.py # pyyaml>=3.13 PyYAML>=5.1 # OmegaConf requirement >=5.1 tqdm>=4.41.0 +fsspec>=0.8.0