use fsspec instead of gfile for all IO (#3320)
* use fsspec instead of gfile for all IO

  This better supports remote (and local) file operations with a dedicated package.

* Apply suggestions from code review

  Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>

* chlog

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: Jirka Borovec <jirka@pytorchlightning.ai>
Parent: d521c1b178
Commit: 2d8c1b7c54
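At its core, the change maps every path to an fsspec filesystem object chosen by the URL protocol, falling back to the local filesystem when the path has no scheme. A minimal standalone sketch of that idea (the helper name below is illustrative; the helper this commit actually adds is get_filesystem in pytorch_lightning.utilities.cloud_io):

import fsspec

def filesystem_for(path: str):
    # pick the fsspec implementation from the URL scheme, e.g. "s3://bucket/ckpts"
    protocol = path.split("://", 1)[0] if "://" in path else "file"
    return fsspec.filesystem(protocol)

filesystem_for("/tmp/checkpoints")         # LocalFileSystem
filesystem_for("s3://bucket/checkpoints")  # S3FileSystem, provided s3fs is installed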
@@ -9,10 +9,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Added
-- Added hooks to metric module interface ([#2528](https://github.com/PyTorchLightning/pytorch-lightning/pull/2528/))
+- Added hooks to metric module interface ([#2528](https://github.com/PyTorchLightning/pytorch-lightning/pull/2528))
 ### Changed
+- Used `fsspec` instead of `gfile` for all IO ([#3320](https://github.com/PyTorchLightning/pytorch-lightning/pull/3320))
 ### Deprecated
@@ -31,6 +31,7 @@ dependencies:
 - future>=0.17.1
 - PyYAML>=5.1
 - tqdm>=4.41.0
+- fsspec>=0.8.0
 - nvidia-apex
 # For dev and testing
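fsspec itself only bundles the local and a handful of built-in implementations; a remote protocol is resolved lazily and needs its backend package installed at runtime, which is why only the core fsspec pin is added here. A hedged illustration:

import fsspec

fsspec.filesystem("file")  # always available
try:
    # each remote scheme needs its backend package, e.g. s3:// -> s3fs, gs:// -> gcsfs
    fsspec.filesystem("s3")
except ImportError as err:
    print(f"backend for s3:// is not installed: {err}")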
@@ -30,7 +30,7 @@ import torch
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import rank_zero_warn, rank_zero_only
-from pytorch_lightning.utilities.cloud_io import gfile, makedirs, is_remote_path
+from pytorch_lightning.utilities.cloud_io import get_filesystem


 class ModelCheckpoint(Callback):
@@ -119,9 +119,11 @@ class ModelCheckpoint(Callback):
                  save_last: bool = False, save_top_k: int = 1, save_weights_only: bool = False,
                  mode: str = 'auto', period: int = 1, prefix: str = ''):
         super().__init__()
-        if(filepath):
-            filepath = str(filepath)  # the tests pass in a py.path.local but we want a str
-        if save_top_k > 0 and filepath is not None and gfile.isdir(filepath) and len(gfile.listdir(filepath)) > 0:
+        if filepath:
+            self._fs = get_filesystem(filepath)
+        else:
+            self._fs = get_filesystem("")  # will give local fileystem
+        if save_top_k > 0 and filepath is not None and self._fs.isdir(filepath) and len(self._fs.ls(filepath)) > 0:
             rank_zero_warn(
                 f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
                 "All files in this directory will be deleted when a checkpoint is saved!"
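The callback now resolves one filesystem handle from the configured filepath up front and reuses it for every existence and listing check. A rough standalone equivalent of the new guard, using fsspec directly instead of the get_filesystem helper:

import fsspec

def checkpoint_dir_not_empty(filepath: str, save_top_k: int = 1) -> bool:
    # same check as above: warn when top-k saving targets a non-empty directory
    protocol = filepath.split("://", 1)[0] if "://" in filepath else "file"
    fs = fsspec.filesystem(protocol)
    return save_top_k > 0 and fs.isdir(filepath) and len(fs.ls(filepath)) > 0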
@@ -133,13 +135,13 @@ class ModelCheckpoint(Callback):
         if filepath is None:  # will be determined by trainer at runtime
             self.dirpath, self.filename = None, None
         else:
-            if gfile.isdir(filepath):
-                self.dirpath, self.filename = filepath, '{epoch}'
+            if self._fs.isdir(filepath):
+                self.dirpath, self.filename = filepath, "{epoch}"
             else:
-                if not is_remote_path(filepath):  # dont normalize remote paths
+                if self._fs.protocol == "file":  # dont normalize remote paths
                     filepath = os.path.realpath(filepath)
                 self.dirpath, self.filename = os.path.split(filepath)
-            makedirs(self.dirpath)  # calls with exist_ok
+            self._fs.makedirs(self.dirpath, exist_ok=True)
         self.save_last = save_last
         self.save_top_k = save_top_k
         self.save_weights_only = save_weights_only
@@ -182,19 +184,8 @@ class ModelCheckpoint(Callback):
         return self.kth_best_model_path

     def _del_model(self, filepath):
-        if gfile.exists(filepath):
-            try:
-                # in compat mode, remove is not implemented so if running this
-                # against an actual remove file system and the correct remote
-                # dependencies exist then this will work fine.
-                gfile.remove(filepath)
-            except AttributeError:
-                if is_remote_path(filepath):
-                    log.warning("Unable to remove stale checkpoints due to running gfile in compatibility mode."
-                                " Please install tensorflow to run gfile in full mode"
-                                " if writing checkpoints to remote locations")
-                else:
-                    os.remove(filepath)
+        if self._fs.exists(filepath):
+            self._fs.rm(filepath)

     def _save_model(self, filepath, trainer, pl_module):
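Because fsspec backends implement deletion uniformly, the whole gfile compatibility-mode try/except collapses into an existence check plus rm(). A minimal sketch of the same behaviour with a plain fsspec filesystem object:

from fsspec import AbstractFileSystem

def delete_stale_checkpoint(fs: AbstractFileSystem, filepath: str) -> None:
    # rm() works the same for local paths and remote URIs, so no fallback is needed
    if fs.exists(filepath):
        fs.rm(filepath)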
@@ -202,8 +193,7 @@ class ModelCheckpoint(Callback):
         trainer.dev_debugger.track_checkpointing_history(filepath)

         # make paths
-        if not gfile.exists(os.path.dirname(filepath)):
-            makedirs(os.path.dirname(filepath))
+        self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)

         # delegate the saving to the model
         if self.save_function is not None:
@@ -308,9 +298,8 @@ class ModelCheckpoint(Callback):
         self.dirpath = ckpt_path

-        assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0'
-        if not gfile.exists(self.dirpath):
-            makedirs(self.dirpath)
+        assert trainer.global_rank == 0, "tried to make a checkpoint from non global_rank=0"
+        self._fs.makedirs(self.dirpath, exist_ok=True)

     def __warn_deprecated_monitor_key(self):
         using_result_obj = os.environ.get('PL_USING_RESULT_OBJ', None)
@@ -359,7 +348,7 @@ class ModelCheckpoint(Callback):
         ckpt_name_metrics = trainer.logged_metrics
         filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics)
         version_cnt = 0
-        while gfile.exists(filepath):
+        while self._fs.exists(filepath):
             filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics, ver=version_cnt)
             # this epoch called before
             version_cnt += 1
@@ -435,4 +424,4 @@ class ModelCheckpoint(Callback):

     def on_load_checkpoint(self, checkpointed_state):
         self.best_model_score = checkpointed_state['best_model_score']
         self.best_model_path = checkpointed_state['best_model_path']
@@ -19,13 +19,15 @@ import os
 from argparse import Namespace
 from typing import Union, Dict, Any, Optional, Callable, MutableMapping

+import fsspec
 import torch
 import yaml

 from pytorch_lightning import _logger as log
 from pytorch_lightning.utilities import rank_zero_warn, AttributeDict
 from pytorch_lightning.utilities.cloud_io import load as pl_load
-from pytorch_lightning.utilities.cloud_io import gfile, cloud_open
+from pytorch_lightning.utilities.cloud_io import get_filesystem

 PRIMITIVE_TYPES = (bool, int, float, str)
 ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)
@@ -290,11 +292,12 @@ def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]:
     True
     >>> os.remove(path_csv)
     """
-    if not gfile.exists(tags_csv):
+    fs = get_filesystem(tags_csv)
+    if not fs.exists(tags_csv):
         rank_zero_warn(f"Missing Tags: {tags_csv}.", RuntimeWarning)
         return {}

-    with cloud_open(tags_csv, "r", newline="") as fp:
+    with fs.open(tags_csv, "r", newline="") as fp:
         csv_reader = csv.reader(fp, delimiter=",")
         tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]}
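fsspec file objects support text modes, so the old cloud_open wrapper is replaced by resolving the filesystem per call and opening with mode "r" and an explicit newline, just like the builtin open. A self-contained sketch of the same read path (the convert() type coercion from the original is omitted here):

import csv
import fsspec

def read_tags_csv(tags_csv: str) -> dict:
    protocol = tags_csv.split("://", 1)[0] if "://" in tags_csv else "file"
    fs = fsspec.filesystem(protocol)
    if not fs.exists(tags_csv):
        return {}
    with fs.open(tags_csv, "r", newline="") as fp:
        rows = list(csv.reader(fp, delimiter=","))
    # skip the header row; values stay plain strings in this sketch
    return {key: value for key, value in rows[1:]}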
@@ -302,13 +305,14 @@ def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]:

 def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) -> None:
-    if not gfile.isdir(os.path.dirname(tags_csv)):
+    fs = get_filesystem(tags_csv)
+    if not fs.isdir(os.path.dirname(tags_csv)):
         raise RuntimeError(f"Missing folder: {os.path.dirname(tags_csv)}.")

     if isinstance(hparams, Namespace):
         hparams = vars(hparams)

-    with cloud_open(tags_csv, "w", newline="") as fp:
+    with fs.open(tags_csv, "w", newline="") as fp:
         fieldnames = ["key", "value"]
         writer = csv.DictWriter(fp, fieldnames=fieldnames)
         writer.writerow({"key": "key", "value": "value"})
@@ -327,11 +331,12 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
     True
     >>> os.remove(path_yaml)
     """
-    if not gfile.exists(config_yaml):
+    fs = get_filesystem(config_yaml)
+    if not fs.exists(config_yaml):
         rank_zero_warn(f"Missing Tags: {config_yaml}.", RuntimeWarning)
         return {}

-    with cloud_open(config_yaml, "r") as fp:
+    with fs.open(config_yaml, "r") as fp:
         tags = yaml.load(fp)

     return tags
@@ -343,7 +348,8 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
         config_yaml: path to new YAML file
         hparams: parameters to be saved
     """
-    if not gfile.isdir(os.path.dirname(config_yaml)):
+    fs = get_filesystem(config_yaml)
+    if not fs.isdir(os.path.dirname(config_yaml)):
         raise RuntimeError(f"Missing folder: {os.path.dirname(config_yaml)}.")

     # convert Namespace or AD to dict
@@ -364,7 +370,7 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:

     # saving the standard way
     assert isinstance(hparams, dict)
-    with cloud_open(config_yaml, 'w', newline='') as fp:
+    with fs.open(config_yaml, "w", newline="") as fp:
         yaml.dump(hparams, fp)
@@ -30,7 +30,7 @@ from pytorch_lightning import _logger as log
 from pytorch_lightning.core.saving import save_hparams_to_yaml
 from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
 from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn
-from pytorch_lightning.utilities.cloud_io import gfile, makedirs
+from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.core.lightning import LightningModule

 try:
@@ -87,6 +87,7 @@ class TensorBoardLogger(LightningLoggerBase):
         self._version = version
         self._log_graph = log_graph
         self._default_hp_metric = default_hp_metric
+        self._fs = get_filesystem(save_dir)

         self._experiment = None
         self.hparams = {}
@@ -136,8 +137,8 @@ class TensorBoardLogger(LightningLoggerBase):
             return self._experiment

         assert rank_zero_only.rank == 0, 'tried to init log dirs in non global_rank=0'
-        if self.root_dir and not gfile.exists(str(self.root_dir)):
-            makedirs(self.root_dir)
+        if self.root_dir:
+            self._fs.makedirs(self.root_dir, exist_ok=True)
         self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
         return self._experiment
@@ -207,7 +208,7 @@ class TensorBoardLogger(LightningLoggerBase):
     def save(self) -> None:
         super().save()
         dir_path = self.log_dir
-        if not gfile.isdir(dir_path):
+        if not self._fs.isdir(dir_path):
             dir_path = self.save_dir

         # prepare the file path
@@ -233,16 +234,16 @@ class TensorBoardLogger(LightningLoggerBase):
     def _get_next_version(self):
         root_dir = os.path.join(self.save_dir, self.name)

-        if not gfile.isdir(root_dir):
+        if not self._fs.isdir(root_dir):
             log.warning('Missing logger folder: %s', root_dir)
             return 0

         existing_versions = []
-        for d in gfile.listdir(root_dir):
-            if gfile.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
-                dir_ver = d.split("_")[1].replace('/', '')
+        for d in self._fs.ls(root_dir):
+            bn = os.path.basename(d)
+            if self._fs.isdir(d) and bn.startswith("version_"):
+                dir_ver = bn.split("_")[1].replace('/', '')
                 existing_versions.append(int(dir_ver))

         if len(existing_versions) == 0:
             return 0
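One behavioural difference worth noting: gfile.listdir() returned bare entry names, while fsspec's ls() returns full paths, hence the added os.path.basename() call before matching the "version_" prefix. A standalone sketch of the updated version scan, assuming a local save_dir:

import os
import fsspec

def next_version(save_dir: str, name: str) -> int:
    fs = fsspec.filesystem("file")  # a remote filesystem would work the same way
    root_dir = os.path.join(save_dir, name)
    if not fs.isdir(root_dir):
        return 0
    versions = []
    for d in fs.ls(root_dir):        # ls() yields full paths ...
        bn = os.path.basename(d)     # ... so reduce them back to directory names
        if fs.isdir(d) and bn.startswith("version_"):
            versions.append(int(bn.split("_")[1]))
    return max(versions) + 1 if versions else 0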
@@ -51,7 +51,7 @@ from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
 from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only, rank_zero_warn, AMPType
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.cloud_io import is_remote_path
+from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop
 from pytorch_lightning.trainer.data_connector import DataConnector
 from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector
@@ -915,10 +915,9 @@ class Trainer(
         The default location to save artifacts of loggers, checkpoints etc.
         It is used as a fallback if logger or checkpoint callback do not define specific save paths.
         """
-        if is_remote_path(self._default_root_dir):
-            # it is a remote uri, use as is
-            return self._default_root_dir
-        return os.path.normpath(self._default_root_dir)
+        if get_filesystem(self._default_root_dir).protocol == "file":
+            return os.path.normpath(self._default_root_dir)
+        return self._default_root_dir

     @property
     def weights_save_path(self) -> str:
@@ -926,10 +925,9 @@ class Trainer(
         The default root location to save weights (checkpoints), e.g., when the
         :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` does not define a file path.
         """
-        if is_remote_path(self._weights_save_path):
-            # it is a remote uri, use as is
-            return self._weights_save_path
-        return os.path.normpath(self._weights_save_path)
+        if get_filesystem(self._weights_save_path).protocol == "file":
+            return os.path.normpath(self._weights_save_path)
+        return self._weights_save_path

     def tune(
         self,
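Instead of the old "://" string heuristic, the two properties now ask the resolved filesystem for its protocol and only normalize genuinely local paths; remote URIs are returned untouched because os.path.normpath would mangle the double slash in s3:// and friends. A sketch of the pattern, mirroring the == "file" comparison used above (fsspec 0.8-era LocalFileSystem reports the single protocol string "file"):

import os
import fsspec

def normalized_dir(path: str) -> str:
    protocol = path.split("://", 1)[0] if "://" in path else "file"
    if fsspec.filesystem(protocol).protocol == "file":
        return os.path.normpath(path)  # local path: safe to normalize
    return path                        # remote URI: leave as-is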
@@ -114,9 +114,8 @@ from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.overrides.data_parallel import LightningDataParallel, LightningDistributedDataParallel
 from pytorch_lightning.utilities import AMPType, rank_zero_warn
-from pytorch_lightning.utilities.cloud_io import atomic_save, gfile
+from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
 from pytorch_lightning.utilities.cloud_io import load as pl_load
-from pytorch_lightning.utilities.cloud_io import makedirs
 from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS

 try:
@@ -391,8 +390,9 @@ class TrainerIOMixin(ABC):

         # look for hpc weights
         folderpath = str(self.weights_save_path)
-        if gfile.exists(folderpath):
-            files = gfile.listdir(folderpath)
+        fs = get_filesystem(folderpath)
+        if fs.exists(folderpath):
+            files = [os.path.basename(f) for f in fs.ls(folderpath)]
             hpc_weight_paths = [x for x in files if 'hpc_ckpt' in x]

             # if hpc weights exist restore model
@@ -463,16 +463,15 @@ class TrainerIOMixin(ABC):
     def hpc_save(self, folderpath: str, logger):
         # make sure the checkpoint folder exists
         folderpath = str(folderpath)  # because the tests pass a path object
-        if not gfile.exists(folderpath):
-            makedirs(folderpath)
+        fs = get_filesystem(folderpath)
+        fs.makedirs(folderpath, exist_ok=True)

         # save logger to make sure we get all the metrics
         logger.save()

         ckpt_number = self.max_ckpt_in_folder(folderpath) + 1

-        if not gfile.exists(folderpath):
-            makedirs(folderpath)
+        fs.makedirs(folderpath, exist_ok=True)

         filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt')

         # give model a chance to do something on hpc_save
@@ -525,7 +524,8 @@ class TrainerIOMixin(ABC):
         log.info(f'restored hpc model from: {filepath}')

     def max_ckpt_in_folder(self, path, name_key='ckpt_'):
-        files = gfile.listdir(str(path))
+        fs = get_filesystem(path)
+        files = [os.path.basename(f) for f in fs.ls(path)]
         files = [x for x in files if name_key in x]
         if len(files) == 0:
             return 0
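The same ls()-returns-full-paths caveat applies to the HPC helpers, which is why entries are reduced to basenames before filtering on the checkpoint name key. A hedged sketch of the counting logic (the exact number parsing in the original may differ slightly):

import os
import re
from fsspec import AbstractFileSystem

def max_ckpt_in_folder(fs: AbstractFileSystem, path: str, name_key: str = "ckpt_") -> int:
    names = [os.path.basename(f) for f in fs.ls(path)]
    numbers = []
    for name in names:
        m = re.search(rf"{name_key}(\d+)", name)  # e.g. "hpc_ckpt_3.ckpt" -> 3
        if m:
            numbers.append(int(m.group(1)))
    return max(numbers) if numbers else 0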
@@ -13,81 +13,31 @@
 # limitations under the License.

 import io
-import platform
-import os
 from distutils.version import LooseVersion
 from typing import Union
 from pathlib import Path
 from urllib.parse import urlparse

 import torch
+import fsspec

-import tensorboard
-from pytorch_lightning import _logger as log
-
-# we want this for tf.io.gfile, which if tf is installed gives full tf,
-# otherwise gives a pruned down version which works for some file backends but
-# not all
-from tensorboard.compat import tf
-
-gfile = tf.io.gfile

 pathlike = Union[Path, str]

-# older version of tensorboard had buggy gfile compatibility layers
-# only support remote cloud paths if newer
-

 def load(path_or_url: str, map_location=None):
-    if urlparse(path_or_url).scheme == '' or Path(path_or_url).drive:  # no scheme or with a drive letter
+    if urlparse(path_or_url).scheme == "" or Path(path_or_url).drive:  # no scheme or with a drive letter
         return torch.load(path_or_url, map_location=map_location)
     return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location)


-def is_remote_path(path: pathlike):
-    """Determine if a path is a local path or a remote path like s3://bucket/path
-
-    This should catch paths like s3:// hdfs:// and gcs://
-    """
-    return "://" in str(path)
-
-
-def modern_gfile():
-    """Check the version number of tensorboard.
-
-    Cheking to see if it has the gfile compatibility layers needed for remote
-    file operations
-    """
-    tb_version = LooseVersion(tensorboard.version.VERSION)
-    modern_gfile = tb_version >= LooseVersion("2.0")
-    return modern_gfile
-
-
-def cloud_open(path: pathlike, mode: str, newline: str = None):
-    if platform.system() == "Windows":
-        log.debug(
-            "gfile does not handle newlines correctly on windows so remote files are not"
-            " supported falling back to normal local file open."
-        )
-        return open(path, mode, newline=newline)
-    if not modern_gfile():
-        log.debug(
-            "tenosrboard.compat gfile does not work on older versions "
-            "of tensorboard for remote files, using normal local file open."
-        )
-        return open(path, mode, newline=newline)
-    try:
-        return gfile.GFile(path, mode)
-    except NotImplementedError as e:
-        # minimal dependencies are installed and only local files will work
-        return open(path, mode, newline=newline)
-
-
-def makedirs(path: pathlike):
-    if hasattr(gfile, "makedirs") and modern_gfile():
-        if not gfile.exists(str(path)):
-            return gfile.makedirs(str(path))
-    # otherwise minimal dependencies are installed and only local files will work
-    return os.makedirs(path, exist_ok=True)
+def get_filesystem(path: pathlike):
+    path = str(path)
+    if "://" in path:
+        # use the fileystem from the protocol specified
+        return fsspec.filesystem(path.split(":", 1)[0])
+    else:
+        # use local filesystem
+        return fsspec.filesystem("file")


 def atomic_save(checkpoint, filepath: str):
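Usage of the new helper is straightforward: it returns an fsspec filesystem instance keyed by the path's protocol, so callers can mix local paths and remote URIs freely. A small hedged example (the s3 lines assume s3fs is installed):

from pytorch_lightning.utilities.cloud_io import get_filesystem

fs = get_filesystem("/tmp/run/checkpoints")           # LocalFileSystem
fs.makedirs("/tmp/run/checkpoints", exist_ok=True)

# fs = get_filesystem("s3://bucket/run/checkpoints")  # S3FileSystem (needs s3fs)
# fs.ls("s3://bucket/run/checkpoints")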
@@ -108,5 +58,5 @@ def atomic_save(checkpoint, filepath: str):
         torch.save(checkpoint, bytesbuffer, _use_new_zipfile_serialization=False)
     else:
         torch.save(checkpoint, bytesbuffer)
-    with cloud_open(filepath, 'wb') as f:
+    with fsspec.open(filepath, "wb") as f:
         f.write(bytesbuffer.getvalue())
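The write path keeps the same two-step shape: serialize the checkpoint fully into an in-memory buffer, then hand the bytes to fsspec.open(), which resolves local paths and remote URIs alike. A simplified sketch without the torch-version branch shown above:

import io

import fsspec
import torch

def save_checkpoint(checkpoint: dict, filepath: str) -> None:
    buffer = io.BytesIO()
    torch.save(checkpoint, buffer)       # serialize in memory first
    with fsspec.open(filepath, "wb") as f:
        f.write(buffer.getvalue())       # single write to a local or remote target

save_checkpoint({"state_dict": {}}, "/tmp/example.ckpt")
# save_checkpoint({"state_dict": {}}, "s3://bucket/example.ckpt")  # needs s3fs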
@@ -7,3 +7,4 @@ future>=0.17.1 # required for builtins in setup.py
 # pyyaml>=3.13
 PyYAML>=5.1 # OmegaConf requirement >=5.1
 tqdm>=4.41.0
+fsspec>=0.8.0