use fsspec instead of gfile for all IO (#3320)

* use fsspec instead of gfile for all IO

This better supports remote (and local) file operations with a dedicated package (see the fsspec usage sketch below)

* Apply suggestions from code review

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>

* chlog

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: Jirka Borovec <jirka@pytorchlightning.ai>
Brendan Fahy 2020-09-03 12:19:20 +00:00 committed by GitHub
parent d521c1b178
commit 2d8c1b7c54
9 changed files with 73 additions and 126 deletions
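
As context for the diff below, here is a minimal sketch of the uniform IO that fsspec provides: the same filesystem API covers local paths and remote URIs. This is an illustration, not code from the commit; the s3:// bucket is hypothetical and writing to it would require the optional s3fs package and credentials.

```python
import fsspec

# the same filesystem interface serves local paths ...
fs = fsspec.filesystem("file")
fs.makedirs("/tmp/ckpts", exist_ok=True)
with fs.open("/tmp/ckpts/last.txt", "w") as f:
    f.write("epoch=3")
print(fs.exists("/tmp/ckpts/last.txt"))  # True

# ... and remote URIs, where the backend is picked from the protocol prefix
# (hypothetical bucket; needs the s3fs extra and valid credentials)
with fsspec.open("s3://my-bucket/ckpts/last.txt", "w") as f:
    f.write("epoch=3")
```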

View File

@ -9,10 +9,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
- Added hooks to metric module interface ([#2528](https://github.com/PyTorchLightning/pytorch-lightning/pull/2528/))
- Added hooks to metric module interface ([#2528](https://github.com/PyTorchLightning/pytorch-lightning/pull/2528))
### Changed
- Used `fsspec` instead of `gfile` for all IO ([#3320](https://github.com/PyTorchLightning/pytorch-lightning/pull/3320))
### Deprecated

View File

@ -31,6 +31,7 @@ dependencies:
- future>=0.17.1
- PyYAML>=5.1
- tqdm>=4.41.0
- fsspec>=0.8.0
- nvidia-apex
# For dev and testing

View File

@ -30,7 +30,7 @@ import torch
from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn, rank_zero_only
from pytorch_lightning.utilities.cloud_io import gfile, makedirs, is_remote_path
from pytorch_lightning.utilities.cloud_io import get_filesystem
class ModelCheckpoint(Callback):
@ -119,9 +119,11 @@ class ModelCheckpoint(Callback):
save_last: bool = False, save_top_k: int = 1, save_weights_only: bool = False,
mode: str = 'auto', period: int = 1, prefix: str = ''):
super().__init__()
if(filepath):
filepath = str(filepath) # the tests pass in a py.path.local but we want a str
if save_top_k > 0 and filepath is not None and gfile.isdir(filepath) and len(gfile.listdir(filepath)) > 0:
if filepath:
self._fs = get_filesystem(filepath)
else:
self._fs = get_filesystem("") # will give local filesystem
if save_top_k > 0 and filepath is not None and self._fs.isdir(filepath) and len(self._fs.ls(filepath)) > 0:
rank_zero_warn(
f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
"All files in this directory will be deleted when a checkpoint is saved!"
@ -133,13 +135,13 @@ class ModelCheckpoint(Callback):
if filepath is None: # will be determined by trainer at runtime
self.dirpath, self.filename = None, None
else:
if gfile.isdir(filepath):
self.dirpath, self.filename = filepath, '{epoch}'
if self._fs.isdir(filepath):
self.dirpath, self.filename = filepath, "{epoch}"
else:
if not is_remote_path(filepath): # don't normalize remote paths
if self._fs.protocol == "file": # don't normalize remote paths
filepath = os.path.realpath(filepath)
self.dirpath, self.filename = os.path.split(filepath)
makedirs(self.dirpath) # calls with exist_ok
self._fs.makedirs(self.dirpath, exist_ok=True)
self.save_last = save_last
self.save_top_k = save_top_k
self.save_weights_only = save_weights_only
@ -182,19 +184,8 @@ class ModelCheckpoint(Callback):
return self.kth_best_model_path
def _del_model(self, filepath):
if gfile.exists(filepath):
try:
# in compat mode, remove is not implemented so if running this
# against an actual remote file system and the correct remote
# dependencies exist then this will work fine.
gfile.remove(filepath)
except AttributeError:
if is_remote_path(filepath):
log.warning("Unable to remove stale checkpoints due to running gfile in compatibility mode."
" Please install tensorflow to run gfile in full mode"
" if writing checkpoints to remote locations")
else:
os.remove(filepath)
if self._fs.exists(filepath):
self._fs.rm(filepath)
def _save_model(self, filepath, trainer, pl_module):
@ -202,8 +193,7 @@ class ModelCheckpoint(Callback):
trainer.dev_debugger.track_checkpointing_history(filepath)
# make paths
if not gfile.exists(os.path.dirname(filepath)):
makedirs(os.path.dirname(filepath))
self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)
# delegate the saving to the model
if self.save_function is not None:
@ -308,9 +298,8 @@ class ModelCheckpoint(Callback):
self.dirpath = ckpt_path
assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0'
if not gfile.exists(self.dirpath):
makedirs(self.dirpath)
assert trainer.global_rank == 0, "tried to make a checkpoint from non global_rank=0"
self._fs.makedirs(self.dirpath, exist_ok=True)
def __warn_deprecated_monitor_key(self):
using_result_obj = os.environ.get('PL_USING_RESULT_OBJ', None)
@ -359,7 +348,7 @@ class ModelCheckpoint(Callback):
ckpt_name_metrics = trainer.logged_metrics
filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics)
version_cnt = 0
while gfile.exists(filepath):
while self._fs.exists(filepath):
filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics, ver=version_cnt)
# this epoch called before
version_cnt += 1
@ -435,4 +424,4 @@ class ModelCheckpoint(Callback):
def on_load_checkpoint(self, checkpointed_state):
self.best_model_score = checkpointed_state['best_model_score']
self.best_model_path = checkpointed_state['best_model_path']
self.best_model_path = checkpointed_state['best_model_path']

View File

@ -19,13 +19,15 @@ import os
from argparse import Namespace
from typing import Union, Dict, Any, Optional, Callable, MutableMapping
import fsspec
import torch
import yaml
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import rank_zero_warn, AttributeDict
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.cloud_io import gfile, cloud_open
from pytorch_lightning.utilities.cloud_io import get_filesystem
PRIMITIVE_TYPES = (bool, int, float, str)
ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)
@ -290,11 +292,12 @@ def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]:
True
>>> os.remove(path_csv)
"""
if not gfile.exists(tags_csv):
fs = get_filesystem(tags_csv)
if not fs.exists(tags_csv):
rank_zero_warn(f"Missing Tags: {tags_csv}.", RuntimeWarning)
return {}
with cloud_open(tags_csv, "r", newline="") as fp:
with fs.open(tags_csv, "r", newline="") as fp:
csv_reader = csv.reader(fp, delimiter=",")
tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]}
@ -302,13 +305,14 @@ def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]:
def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) -> None:
if not gfile.isdir(os.path.dirname(tags_csv)):
fs = get_filesystem(tags_csv)
if not fs.isdir(os.path.dirname(tags_csv)):
raise RuntimeError(f"Missing folder: {os.path.dirname(tags_csv)}.")
if isinstance(hparams, Namespace):
hparams = vars(hparams)
with cloud_open(tags_csv, "w", newline="") as fp:
with fs.open(tags_csv, "w", newline="") as fp:
fieldnames = ["key", "value"]
writer = csv.DictWriter(fp, fieldnames=fieldnames)
writer.writerow({"key": "key", "value": "value"})
@ -327,11 +331,12 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
True
>>> os.remove(path_yaml)
"""
if not gfile.exists(config_yaml):
fs = get_filesystem(config_yaml)
if not fs.exists(config_yaml):
rank_zero_warn(f"Missing Tags: {config_yaml}.", RuntimeWarning)
return {}
with cloud_open(config_yaml, "r") as fp:
with fs.open(config_yaml, "r") as fp:
tags = yaml.load(fp)
return tags
@ -343,7 +348,8 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
config_yaml: path to new YAML file
hparams: parameters to be saved
"""
if not gfile.isdir(os.path.dirname(config_yaml)):
fs = get_filesystem(config_yaml)
if not fs.isdir(os.path.dirname(config_yaml)):
raise RuntimeError(f"Missing folder: {os.path.dirname(config_yaml)}.")
# convert Namespace or AD to dict
@ -364,7 +370,7 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
# saving the standard way
assert isinstance(hparams, dict)
with cloud_open(config_yaml, 'w', newline='') as fp:
with fs.open(config_yaml, "w", newline="") as fp:
yaml.dump(hparams, fp)

View File

@ -30,7 +30,7 @@ from pytorch_lightning import _logger as log
from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import gfile, makedirs
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.core.lightning import LightningModule
try:
@ -87,6 +87,7 @@ class TensorBoardLogger(LightningLoggerBase):
self._version = version
self._log_graph = log_graph
self._default_hp_metric = default_hp_metric
self._fs = get_filesystem(save_dir)
self._experiment = None
self.hparams = {}
@ -136,8 +137,8 @@ class TensorBoardLogger(LightningLoggerBase):
return self._experiment
assert rank_zero_only.rank == 0, 'tried to init log dirs in non global_rank=0'
if self.root_dir and not gfile.exists(str(self.root_dir)):
makedirs(self.root_dir)
if self.root_dir:
self._fs.makedirs(self.root_dir, exist_ok=True)
self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
return self._experiment
@ -207,7 +208,7 @@ class TensorBoardLogger(LightningLoggerBase):
def save(self) -> None:
super().save()
dir_path = self.log_dir
if not gfile.isdir(dir_path):
if not self._fs.isdir(dir_path):
dir_path = self.save_dir
# prepare the file path
@ -233,16 +234,16 @@ class TensorBoardLogger(LightningLoggerBase):
def _get_next_version(self):
root_dir = os.path.join(self.save_dir, self.name)
if not gfile.isdir(root_dir):
if not self._fs.isdir(root_dir):
log.warning('Missing logger folder: %s', root_dir)
return 0
existing_versions = []
for d in gfile.listdir(root_dir):
if gfile.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
dir_ver = d.split("_")[1].replace('/', '')
for d in self._fs.ls(root_dir):
bn = os.path.basename(d)
if self._fs.isdir(d) and bn.startswith("version_"):
dir_ver = bn.split("_")[1].replace('/', '')
existing_versions.append(int(dir_ver))
if len(existing_versions) == 0:
return 0
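
One subtlety in the `_get_next_version` rewrite above: `gfile.listdir` returned bare entry names, while fsspec's `ls` returns full paths, hence the added `os.path.basename` call. A small standalone sketch of that behaviour, with an illustrative log directory:

```python
import os
import fsspec

fs = fsspec.filesystem("file")
root_dir = "/tmp/tb_logs/default"  # illustrative logger root
fs.makedirs(os.path.join(root_dir, "version_0"), exist_ok=True)
fs.makedirs(os.path.join(root_dir, "version_1"), exist_ok=True)

# fs.ls yields full paths such as "/tmp/tb_logs/default/version_1",
# so the directory part must be stripped before parsing the version number
versions = []
for d in fs.ls(root_dir):
    bn = os.path.basename(d)
    if fs.isdir(d) and bn.startswith("version_"):
        versions.append(int(bn.split("_")[1]))

print(max(versions) + 1 if versions else 0)  # next free version number
```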

View File

@ -51,7 +51,7 @@ from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only, rank_zero_warn, AMPType
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.cloud_io import is_remote_path
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop
from pytorch_lightning.trainer.data_connector import DataConnector
from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector
@ -915,10 +915,9 @@ class Trainer(
The default location to save artifacts of loggers, checkpoints etc.
It is used as a fallback if logger or checkpoint callback do not define specific save paths.
"""
if is_remote_path(self._default_root_dir):
# it is a remote uri, use as is
return self._default_root_dir
return os.path.normpath(self._default_root_dir)
if get_filesystem(self._default_root_dir).protocol == "file":
return os.path.normpath(self._default_root_dir)
return self._default_root_dir
@property
def weights_save_path(self) -> str:
@ -926,10 +925,9 @@ class Trainer(
The default root location to save weights (checkpoints), e.g., when the
:class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` does not define a file path.
"""
if is_remote_path(self._weights_save_path):
# it is a remote uri, use as is
return self._weights_save_path
return os.path.normpath(self._weights_save_path)
if get_filesystem(self._weights_save_path).protocol == "file":
return os.path.normpath(self._weights_save_path)
return self._weights_save_path
def tune(
self,
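
The `default_root_dir` and `weights_save_path` properties above now decide whether to normalize the path by asking the resolved filesystem for its protocol instead of searching the string for "://". A rough standalone sketch of that logic follows; the helper names are mine, and since newer fsspec releases report the local filesystem's protocol as a tuple such as ("file", "local"), the membership check is deliberately defensive:

```python
import os
import fsspec


def _protocols(fs) -> tuple:
    # fsspec filesystems expose .protocol as a string or, in newer releases, a tuple
    proto = fs.protocol
    return tuple(proto) if isinstance(proto, (tuple, list)) else (proto,)


def normalize_if_local(path: str) -> str:
    """Normalize local paths, but return remote URIs (s3://, gs://, hdfs://, ...) unchanged."""
    protocol = path.split(":", 1)[0] if "://" in path else "file"
    fs = fsspec.filesystem(protocol)
    if "file" in _protocols(fs):
        return os.path.normpath(path)
    return path


print(normalize_if_local("lightning_logs/../lightning_logs"))  # "lightning_logs"
print(normalize_if_local("memory://logs/run1"))                # returned as-is (non-local protocol)
```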

View File

@ -114,9 +114,8 @@ from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.overrides.data_parallel import LightningDataParallel, LightningDistributedDataParallel
from pytorch_lightning.utilities import AMPType, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save, gfile
from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.cloud_io import makedirs
from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
try:
@ -391,8 +390,9 @@ class TrainerIOMixin(ABC):
# look for hpc weights
folderpath = str(self.weights_save_path)
if gfile.exists(folderpath):
files = gfile.listdir(folderpath)
fs = get_filesystem(folderpath)
if fs.exists(folderpath):
files = [os.path.basename(f) for f in fs.ls(folderpath)]
hpc_weight_paths = [x for x in files if 'hpc_ckpt' in x]
# if hpc weights exist restore model
@ -463,16 +463,15 @@ class TrainerIOMixin(ABC):
def hpc_save(self, folderpath: str, logger):
# make sure the checkpoint folder exists
folderpath = str(folderpath) # because the tests pass a path object
if not gfile.exists(folderpath):
makedirs(folderpath)
fs = get_filesystem(folderpath)
fs.makedirs(folderpath, exist_ok=True)
# save logger to make sure we get all the metrics
logger.save()
ckpt_number = self.max_ckpt_in_folder(folderpath) + 1
if not gfile.exists(folderpath):
makedirs(folderpath)
fs.makedirs(folderpath, exist_ok=True)
filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt')
# give model a chance to do something on hpc_save
@ -525,7 +524,8 @@ class TrainerIOMixin(ABC):
log.info(f'restored hpc model from: {filepath}')
def max_ckpt_in_folder(self, path, name_key='ckpt_'):
files = gfile.listdir(str(path))
fs = get_filesystem(path)
files = [os.path.basename(f) for f in fs.ls(path)]
files = [x for x in files if name_key in x]
if len(files) == 0:
return 0

View File

@ -13,81 +13,31 @@
# limitations under the License.
import io
import platform
import os
from distutils.version import LooseVersion
from typing import Union
from pathlib import Path
from urllib.parse import urlparse
import torch
import fsspec
import tensorboard
from pytorch_lightning import _logger as log
# we want this for tf.io.gfile, which if tf is installed gives full tf,
# otherwise gives a pruned down version which works for some file backends but
# not all
from tensorboard.compat import tf
gfile = tf.io.gfile
pathlike = Union[Path, str]
# older version of tensorboard had buggy gfile compatibility layers
# only support remote cloud paths if newer
def load(path_or_url: str, map_location=None):
if urlparse(path_or_url).scheme == '' or Path(path_or_url).drive: # no scheme or with a drive letter
if urlparse(path_or_url).scheme == "" or Path(path_or_url).drive: # no scheme or with a drive letter
return torch.load(path_or_url, map_location=map_location)
return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location)
def is_remote_path(path: pathlike):
"""Determine if a path is a local path or a remote path like s3://bucket/path
This should catch paths like s3:// hdfs:// and gcs://
"""
return "://" in str(path)
def modern_gfile():
"""Check the version number of tensorboard.
Checking to see if it has the gfile compatibility layers needed for remote
file operations
"""
tb_version = LooseVersion(tensorboard.version.VERSION)
modern_gfile = tb_version >= LooseVersion("2.0")
return modern_gfile
def cloud_open(path: pathlike, mode: str, newline: str = None):
if platform.system() == "Windows":
log.debug(
"gfile does not handle newlines correctly on windows so remote files are not"
" supported falling back to normal local file open."
)
return open(path, mode, newline=newline)
if not modern_gfile():
log.debug(
"tenosrboard.compat gfile does not work on older versions "
"of tensorboard for remote files, using normal local file open."
)
return open(path, mode, newline=newline)
try:
return gfile.GFile(path, mode)
except NotImplementedError as e:
# minimal dependencies are installed and only local files will work
return open(path, mode, newline=newline)
def makedirs(path: pathlike):
if hasattr(gfile, "makedirs") and modern_gfile():
if not gfile.exists(str(path)):
return gfile.makedirs(str(path))
# otherwise minimal dependencies are installed and only local files will work
return os.makedirs(path, exist_ok=True)
def get_filesystem(path: pathlike):
path = str(path)
if "://" in path:
# use the filesystem from the protocol specified
return fsspec.filesystem(path.split(":", 1)[0])
else:
# use local filesystem
return fsspec.filesystem("file")
def atomic_save(checkpoint, filepath: str):
@ -108,5 +58,5 @@ def atomic_save(checkpoint, filepath: str):
torch.save(checkpoint, bytesbuffer, _use_new_zipfile_serialization=False)
else:
torch.save(checkpoint, bytesbuffer)
with cloud_open(filepath, 'wb') as f:
with fsspec.open(filepath, "wb") as f:
f.write(bytesbuffer.getvalue())
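
Taken together, `get_filesystem` plus `fsspec.open` replace the whole gfile compatibility layer removed above. A short usage sketch of the new helper; the paths are illustrative, and "memory" stands in for a remote backend because it ships with fsspec, whereas s3://, gs:// or hdfs:// need their optional extras:

```python
from pytorch_lightning.utilities.cloud_io import get_filesystem

# a plain path with no "://" resolves to the local filesystem
local_fs = get_filesystem("lightning_logs/checkpoints")
print(local_fs.protocol)  # "file" (a tuple containing "file" on newer fsspec)

# a URI resolves to the filesystem registered for its protocol
remote_fs = get_filesystem("memory://checkpoints/epoch=3.ckpt")
remote_fs.makedirs("memory://checkpoints", exist_ok=True)
with remote_fs.open("memory://checkpoints/epoch=3.ckpt", "wb") as f:
    f.write(b"fake checkpoint bytes")
print(remote_fs.exists("memory://checkpoints/epoch=3.ckpt"))  # True
```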

View File

@ -7,3 +7,4 @@ future>=0.17.1 # required for builtins in setup.py
# pyyaml>=3.13
PyYAML>=5.1 # OmegaConf requirement >=5.1
tqdm>=4.41.0
fsspec>=0.8.0