Remove support for the deprecated torchtext legacy (#14375)

parent 8950613552
commit 3ba0f56b18
@@ -45,7 +45,7 @@ jobs:
           pip --version
           sudo pip uninstall -y lightning pytorch-lightning
           pip install fire
-          python .actions/assistant.py requirements-prune-pkgs torch,torchvision,torchtext
+          python .actions/assistant.py requirements-prune-pkgs torch,torchvision
           pip install ".[extra,test]"
           pip list
       env:
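
The `requirements-prune-pkgs` step drops the listed packages from the requirement files so the versions already installed in the job image are kept. As a rough sketch only — the helper below is hypothetical, not the actual `.actions/assistant.py` implementation — pruning amounts to commenting out the matching requirement lines:

    # Hypothetical pruning helper; the real .actions/assistant.py differs.
    import re
    from pathlib import Path

    def prune_packages(req_file: str, packages: list) -> None:
        """Comment out requirement lines that pin the given packages."""
        path = Path(req_file)
        names = "|".join(re.escape(p) for p in packages)
        pattern = re.compile(rf"^\s*({names})\b")
        lines = [f"# {ln}" if pattern.match(ln) else ln for ln in path.read_text().splitlines()]
        path.write_text("\n".join(lines) + "\n")

    prune_packages("requirements/pytorch/base.txt", ["torch", "torchvision"])
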
@@ -78,11 +78,11 @@ RUN \
    conda update -n base -c defaults conda && \
    CUDA_VERSION_MM=$(python -c "print('.'.join('$CUDA_VERSION'.split('.')[:2]))") && \
    conda create -y --name $CONDA_ENV \
-       python=${PYTHON_VERSION} pytorch=${PYTORCH_VERSION} torchvision torchtext cudatoolkit=${CUDA_VERSION_MM} \
+       python=${PYTHON_VERSION} pytorch=${PYTORCH_VERSION} torchvision cudatoolkit=${CUDA_VERSION_MM} \
        -c nvidia -c pytorch -c pytorch-test && \
    conda init bash && \
    # NOTE: this requires that the channel is presented in the yaml before packages \
-   printf "import re;\nfname = 'environment.yml';\nreq = open(fname).read();\nfor n in ['python', 'pytorch', 'torchtext', 'torchvision']:\n req = re.sub(rf'- {n}[>=]+', f'# - {n}=', req);\nopen(fname, 'w').write(req)" > prune.py && \
+   printf "import re;\nfname = 'environment.yml';\nreq = open(fname).read();\nfor n in ['python', 'pytorch', 'torchvision']:\n req = re.sub(rf'- {n}[>=]+', f'# - {n}=', req);\nopen(fname, 'w').write(req)" > prune.py && \
    python prune.py && \
    rm prune.py && \
    cat environment.yml && \
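
For readability, the printf-generated prune.py above is equivalent to the following plain script, which comments out the pinned conda packages so the versions already baked into the image are used instead:

    # Equivalent of the printf-generated prune.py, written out as a plain script.
    import re

    fname = "environment.yml"
    req = open(fname).read()
    for n in ["python", "pytorch", "torchvision"]:
        # Turn e.g. "- pytorch>=1.9" into "# - pytorch=" so conda skips it.
        req = re.sub(rf"- {n}[>=]+", f"# - {n}=", req)
    open(fname, "w").write(req)
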
@@ -102,7 +102,7 @@ RUN \
    pip list | grep torch && \
    python -c "import torch; print(torch.__version__)" && \
    pip install -q fire && \
-   python assistant.py requirements_prune_pkgs torch,torchvision,torchtext && \
+   python assistant.py requirements_prune_pkgs torch,torchvision && \
    # Install remaining requirements
    pip install --no-cache-dir -r requirements/pytorch/base.txt \
        -r requirements/pytorch/extra.txt \
@@ -41,7 +41,6 @@ dependencies:
   - scikit-learn>=0.20.0
   - matplotlib>=3.1.1
   - omegaconf>=2.0.5
-  - torchtext>=0.10.*

   # Examples
   - torchvision>=0.10.*
@@ -5,14 +5,14 @@ from typing import Dict, Optional

 # IMPORTANT: this list needs to be sorted in reverse
 VERSIONS = [
-    dict(torch="1.12.1", torchvision="0.13.1", torchtext="0.13.1"),  # stable
-    dict(torch="1.12.0", torchvision="0.13.0", torchtext="0.13.0"),
-    dict(torch="1.11.0", torchvision="0.12.0", torchtext="0.12.0"),
-    dict(torch="1.10.2", torchvision="0.11.3", torchtext="0.11.2"),
-    dict(torch="1.10.1", torchvision="0.11.2", torchtext="0.11.1"),
-    dict(torch="1.10.0", torchvision="0.11.1", torchtext="0.11.0"),
-    dict(torch="1.9.1", torchvision="0.10.1", torchtext="0.10.1"),
-    dict(torch="1.9.0", torchvision="0.10.0", torchtext="0.10.0"),
+    dict(torch="1.12.1", torchvision="0.13.1"),  # stable
+    dict(torch="1.12.0", torchvision="0.13.0"),
+    dict(torch="1.11.0", torchvision="0.12.0"),
+    dict(torch="1.10.2", torchvision="0.11.3"),
+    dict(torch="1.10.1", torchvision="0.11.2"),
+    dict(torch="1.10.0", torchvision="0.11.1"),
+    dict(torch="1.9.1", torchvision="0.10.1"),
+    dict(torch="1.9.0", torchvision="0.10.0"),
 ]

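
Because the list is sorted newest-first, a lookup can simply return the first compatible entry. A minimal sketch of such a lookup — the function name is illustrative, not the project's actual helper:

    # Illustrative lookup against VERSIONS; not the project's actual helper.
    from typing import Dict, Optional

    def find_pins(torch_version: str) -> Optional[Dict[str, str]]:
        """Return the pinned package set matching a torch version, newest first."""
        for entry in VERSIONS:
            if torch_version.startswith(entry["torch"]):
                return entry
        return None

    assert find_pins("1.12.1")["torchvision"] == "0.13.1"
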
@@ -4,4 +4,3 @@ import jsonargparse  # noqa: F401
 import matplotlib  # noqa: F401
 import omegaconf  # noqa: F401
 import rich  # noqa: F401
-import torchtext  # noqa: F401
@@ -3,7 +3,6 @@

 # extended list of package dependencies to reach full functionality
 matplotlib>3.1, <3.5.3
-torchtext>=0.10.*, <0.14.0
 omegaconf>=2.0.5, <2.3.0
 hydra-core>=1.0.5, <1.3.0
 jsonargparse[signatures]>=4.12.0, <=4.12.0
@@ -92,6 +92,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed the experimental `pytorch_lightning.utilities.meta` functions in favor of built-in https://github.com/pytorch/torchdistx support ([#13868](https://github.com/Lightning-AI/lightning/pull/13868))


+- Removed deprecated support for old torchtext versions ([#14375](https://github.com/Lightning-AI/lightning/pull/14375))
+
 ### Fixed

 - Fixed an assertion error when using a `ReduceOnPlateau` scheduler with the Horovod strategy ([#14215](https://github.com/Lightning-AI/lightning/pull/14215))
@@ -624,7 +624,6 @@ class DataHooks:
         - :class:`list`
         - :class:`dict`
         - :class:`tuple`
-        - :class:`torchtext.data.batch.Batch`

         For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...).

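
With the `torchtext.data.batch.Batch` entry gone, custom containers fall under the docstring's last sentence: override the `transfer_batch_to_device` hook. A minimal sketch, where `CustomBatch` and `MyModel` are illustrative names, not part of the change:

    # Sketch of the documented override point; CustomBatch/MyModel are illustrative.
    import torch
    from pytorch_lightning import LightningModule

    class CustomBatch:
        def __init__(self, inputs: torch.Tensor) -> None:
            self.inputs = inputs

    class MyModel(LightningModule):
        def transfer_batch_to_device(self, batch, device, dataloader_idx):
            if isinstance(batch, CustomBatch):
                batch.inputs = batch.inputs.to(device)
                return batch
            # Defer to the default handling for tensors and collections.
            return super().transfer_batch_to_device(batch, device, dataloader_idx)
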
@@ -43,7 +43,6 @@ from pytorch_lightning.utilities.imports import (  # noqa: F401
     _TORCH_GREATER_EQUAL_1_11,
     _TORCH_GREATER_EQUAL_1_12,
     _TORCH_QUANTIZE_AVAILABLE,
-    _TORCHTEXT_AVAILABLE,
     _TORCHVISION_AVAILABLE,
     _TPU_AVAILABLE,
     _XLA_AVAILABLE,
@@ -14,10 +14,9 @@
 """Utilities used for collections."""

 import dataclasses
-import operator
 from abc import ABC
 from collections import defaultdict, OrderedDict
-from copy import copy, deepcopy
+from copy import deepcopy
 from functools import partial
 from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union

@@ -26,17 +25,6 @@ import torch
 from torch import Tensor

 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _compare_version, _TORCHTEXT_LEGACY
-from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
-
-if _TORCHTEXT_LEGACY:
-    if _compare_version("torchtext", operator.ge, "0.9.0"):
-        from torchtext.legacy.data import Batch
-    else:
-        from torchtext.data import Batch
-else:
-    Batch = type(None)
-

 _BLOCKING_DEVICE_TYPES = ("cpu", "mps")

@@ -326,23 +314,6 @@ def move_data_to_device(batch: Any, device: Union[str, torch.device]) -> Any:
     device = torch.device(device)

     def batch_to(data: Any) -> Any:
-        # try to move torchtext data first
-        if _TORCHTEXT_LEGACY and isinstance(data, Batch):
-            # TODO: also remove the torchtext dependency with Lightning 1.8
-            rank_zero_deprecation(
-                "The `torchtext.legacy.Batch` object is deprecated and Lightning will remove support for it in v1.8."
-                " We recommend you to migrate away from Batch by following the TorchText README:"
-                " https://github.com/pytorch/text#bc-breaking-legacy"
-            )
-            # Shallow copy because each Batch has a reference to Dataset which contains all examples
-            device_data = copy(data)
-            for field, field_value in data.dataset.fields.items():
-                if field_value is None:
-                    continue
-                device_field = move_data_to_device(getattr(data, field), device)
-                setattr(device_data, field, device_field)
-            return device_data
-
         kwargs = {}
         # Don't issue non-blocking transfers to CPU
         # Same with MPS due to a race condition bug: https://github.com/pytorch/pytorch/issues/83015
@@ -354,8 +325,7 @@ def move_data_to_device(batch: Any, device: Union[str, torch.device]) -> Any:
         # user wrongly implemented the `TransferableDataType` and forgot to return `self`.
         return data

-    dtype = (TransferableDataType, Batch) if _TORCHTEXT_LEGACY else TransferableDataType
-    return apply_to_collection(batch, dtype=dtype, function=batch_to)
+    return apply_to_collection(batch, dtype=TransferableDataType, function=batch_to)


 def convert_to_tensors(data: Any, device: Union[str, torch.device]) -> Any:
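
After this change `move_data_to_device` goes through `apply_to_collection` with `TransferableDataType` only, so tensors and arbitrarily nested collections of tensors are still moved recursively; legacy `Batch` objects are simply no longer special-cased. A usage sketch:

    # Usage sketch: nested collections of tensors are moved recursively.
    import torch
    from pytorch_lightning.utilities.apply_func import move_data_to_device

    batch = {"text": torch.randint(0, 100, (32, 128)), "label": torch.zeros(32)}
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    moved = move_data_to_device(batch, device)
    assert moved["text"].device == device
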
@@ -124,9 +124,7 @@ def has_len(dataloader: Union[DataLoader, Iterable]) -> bool:
                 f"`{dataloader.__class__.__name__}` returned 0 length. Please make sure this was your intention."
             )
         has_len = True
-    except TypeError:
-        has_len = False
-    except NotImplementedError:  # e.g. raised by torchtext if a batch_size_fn is used
+    except (TypeError, NotImplementedError):
         has_len = False

     if has_len and has_iterable_dataset(dataloader):
@@ -170,9 +168,7 @@ def has_len_all_ranks(
         else:
             has_len = True

-    except TypeError:
-        has_len = False
-    except NotImplementedError:  # e.g. raised by torchtext if a batch_size_fn is used
+    except (TypeError, NotImplementedError):
         has_len = False

     if has_len and has_iterable_dataset(dataloader):
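
Folding the two handlers into one `except (TypeError, NotImplementedError)` clause is behavior-preserving: any loader whose `__len__` raises either exception is treated as length-less. A simplified sketch of the pattern with a hypothetical streaming loader (the real check also warns on a 0 length, elided here):

    # Hypothetical loader whose __len__ raises NotImplementedError, as legacy
    # torchtext iterators did when a batch_size_fn was used.
    class StreamingLoader:
        def __iter__(self):
            yield from range(3)

        def __len__(self) -> int:
            raise NotImplementedError("length unknown with dynamic batching")

    def has_len(dataloader) -> bool:
        """Simplified version of the check; the real one also warns on length 0."""
        try:
            len(dataloader)
        except (TypeError, NotImplementedError):
            return False
        return True

    assert has_len([1, 2, 3]) is True
    assert has_len(StreamingLoader()) is False
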
@@ -147,8 +147,6 @@ _POPTORCH_AVAILABLE = _package_available("poptorch")
 _PSUTIL_AVAILABLE = _package_available("psutil")
 _RICH_AVAILABLE = _package_available("rich") and _compare_version("rich", operator.ge, "10.2.2")
 _TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != "none"])
-_TORCHTEXT_AVAILABLE = _package_available("torchtext")
-_TORCHTEXT_LEGACY: bool = _TORCHTEXT_AVAILABLE and _compare_version("torchtext", operator.lt, "0.11.0")
 _TORCHVISION_AVAILABLE = _package_available("torchvision")
 _XLA_AVAILABLE: bool = _package_available("torch_xla")

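
These module-level flags probe optional dependencies once at import time so call sites can branch cheaply. A sketch of the pattern — probing via importlib here is an assumption about the spirit of `_package_available`, not its exact code:

    # Sketch of the availability-flag pattern; probing via importlib is an
    # assumption, not the actual _package_available implementation.
    from importlib.util import find_spec

    def package_available(name: str) -> bool:
        """Report whether a package is importable without importing it."""
        try:
            return find_spec(name) is not None
        except (ModuleNotFoundError, ValueError):
            return False

    _TORCHVISION_AVAILABLE = package_available("torchvision")

    if _TORCHVISION_AVAILABLE:
        from torchvision import transforms  # noqa: F401  # safe: package present
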
@@ -21,8 +21,6 @@ import tests_pytorch.helpers.pipelines as tpipes
 from pytorch_lightning import Trainer
 from pytorch_lightning.accelerators import MPSAccelerator
 from pytorch_lightning.demos.boring_classes import BoringModel
-from pytorch_lightning.utilities.imports import _TORCHTEXT_LEGACY
-from tests_pytorch.helpers.imports import Batch, Dataset, Example, Field, LabelField
 from tests_pytorch.helpers.runif import RunIf

@@ -135,30 +133,3 @@ def test_single_gpu_batch_parse():

     batch = trainer.strategy.batch_to_device(CustomBatchType(), torch.device("mps"))
     assert batch.a.type() == "torch.mps.FloatTensor"
-
-    # torchtext.data.Batch
-    if not _TORCHTEXT_LEGACY:
-        return
-
-    samples = [
-        {"text": "PyTorch Lightning is awesome!", "label": 0},
-        {"text": "Please make it work with torchtext", "label": 1},
-    ]
-
-    text_field = Field()
-    label_field = LabelField()
-    fields = {"text": ("text", text_field), "label": ("label", label_field)}
-
-    examples = [Example.fromdict(sample, fields) for sample in samples]
-    dataset = Dataset(examples=examples, fields=fields.values())
-    # Batch runs field.process() that numericalizes tokens, but it requires to build dictionary first
-    text_field.build_vocab(dataset)
-    label_field.build_vocab(dataset)
-
-    batch = Batch(data=examples, dataset=dataset)
-
-    with pytest.deprecated_call(match="The `torchtext.legacy.Batch` object is deprecated"):
-        batch = trainer.strategy.batch_to_device(batch, torch.device("mps"))
-
-    assert batch.text.type() == "torch.mps.LongTensor"
-    assert batch.label.type() == "torch.mps.LongTensor"
@@ -35,20 +35,8 @@ from pytorch_lightning.strategies.ipu import LightningIPUModule
 from pytorch_lightning.trainer.configuration_validator import _check_datamodule_checkpoint_hooks
 from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.utilities import device_parser
-from pytorch_lightning.utilities.apply_func import move_data_to_device
-from pytorch_lightning.utilities.imports import _TORCHTEXT_LEGACY
 from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
 from tests_pytorch.helpers.runif import RunIf
-from tests_pytorch.helpers.torchtext_utils import get_dummy_torchtext_data_iterator
-
-
-@pytest.mark.skipif(not _TORCHTEXT_LEGACY, reason="torchtext.legacy is deprecated.")
-def test_v1_8_0_deprecated_torchtext_batch():
-    with pytest.deprecated_call(match="is deprecated and Lightning will remove support for it in v1.8"):
-        data_iterator, _ = get_dummy_torchtext_data_iterator(num_samples=3, batch_size=3)
-        batch = next(iter(data_iterator))
-        _ = move_data_to_device(batch=batch, device=torch.device("cpu"))
-


 def test_v1_8_0_on_init_start_end(tmpdir):
@@ -1,16 +0,0 @@
-import operator
-
-from pytorch_lightning.utilities.imports import _compare_version, _TORCHTEXT_LEGACY
-
-if _TORCHTEXT_LEGACY:
-    if _compare_version("torchtext", operator.ge, "0.9.0"):
-        from torchtext.legacy.data import Batch, Dataset, Example, Field, Iterator, LabelField
-    else:
-        from torchtext.data import Batch, Dataset, Example, Field, Iterator, LabelField
-else:
-    Batch = type(None)
-    Dataset = type(None)
-    Example = type(None)
-    Field = type(None)
-    Iterator = type(None)
-    LabelField = type(None)
@@ -1,54 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import random
-import string
-
-from tests_pytorch.helpers.imports import Dataset, Example, Field, Iterator
-
-
-def _generate_random_string(length: int = 10):
-    return "".join(random.choices(string.ascii_letters, k=length))
-
-
-def get_dummy_torchtext_data_iterator(num_samples: int, batch_size: int, include_lengths: bool = False):
-    text_field = Field(
-        sequential=True,
-        pad_first=False,  # nosec
-        init_token="<s>",
-        eos_token="</s>",  # nosec
-        include_lengths=include_lengths,
-    )  # nosec
-
-    dataset = Dataset(
-        [
-            Example.fromdict({"text": _generate_random_string()}, {"text": ("text", text_field)})
-            for _ in range(num_samples)
-        ],
-        {"text": text_field},
-    )
-    text_field.build_vocab(dataset)
-
-    iterator = Iterator(
-        dataset,
-        batch_size=batch_size,
-        sort_key=None,
-        device=None,
-        batch_size_fn=None,
-        train=True,
-        repeat=False,
-        shuffle=None,
-        sort=None,
-        sort_within_batch=None,
-    )
-    return iterator, text_field
@@ -28,9 +28,8 @@ from pytorch_lightning.demos.boring_classes import BoringModel
 from pytorch_lightning.plugins.environments import TorchElasticEnvironment
 from pytorch_lightning.utilities import device_parser
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _compare_version, _TORCHTEXT_LEGACY
+from pytorch_lightning.utilities.imports import _compare_version
 from tests_pytorch.helpers.datamodules import ClassifDataModule
-from tests_pytorch.helpers.imports import Batch, Dataset, Example, Field, LabelField
 from tests_pytorch.helpers.runif import RunIf
 from tests_pytorch.helpers.simple_models import ClassificationModel

@@ -269,33 +268,6 @@ def test_single_gpu_batch_parse():
     batch = trainer.strategy.batch_to_device(CustomBatchType(), torch.device("cuda:0"))
     assert batch.a.type() == "torch.cuda.FloatTensor"
-
-    # torchtext.data.Batch
-    if not _TORCHTEXT_LEGACY:
-        return
-
-    samples = [
-        {"text": "PyTorch Lightning is awesome!", "label": 0},
-        {"text": "Please make it work with torchtext", "label": 1},
-    ]
-
-    text_field = Field()
-    label_field = LabelField()
-    fields = {"text": ("text", text_field), "label": ("label", label_field)}
-
-    examples = [Example.fromdict(sample, fields) for sample in samples]
-    dataset = Dataset(examples=examples, fields=fields.values())
-    # Batch runs field.process() that numericalizes tokens, but it requires to build dictionary first
-    text_field.build_vocab(dataset)
-    label_field.build_vocab(dataset)
-
-    batch = Batch(data=examples, dataset=dataset)
-
-    with pytest.deprecated_call(match="The `torchtext.legacy.Batch` object is deprecated"):
-        batch = trainer.strategy.batch_to_device(batch, torch.device("cuda:0"))
-
-    assert batch.text.type() == "torch.cuda.LongTensor"
-    assert batch.label.type() == "torch.cuda.LongTensor"


 @RunIf(min_cuda_gpus=1)
 def test_non_blocking():
@@ -1,47 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import pytest
-import torch
-
-from pytorch_lightning.utilities.apply_func import move_data_to_device
-from pytorch_lightning.utilities.imports import _TORCHTEXT_LEGACY
-from tests_pytorch.helpers.runif import RunIf
-from tests_pytorch.helpers.torchtext_utils import get_dummy_torchtext_data_iterator
-
-
-@pytest.mark.parametrize("include_lengths", [False, True])
-@pytest.mark.parametrize("device", [torch.device("cuda", 0)])
-@pytest.mark.skipif(not _TORCHTEXT_LEGACY, reason="torchtext.legacy is deprecated.")
-@RunIf(min_cuda_gpus=1)
-def test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, device):
-    data_iterator, _ = get_dummy_torchtext_data_iterator(num_samples=3, batch_size=3, include_lengths=include_lengths)
-    data_iter = iter(data_iterator)
-    batch = next(data_iter)
-
-    with pytest.deprecated_call(match="The `torchtext.legacy.Batch` object is deprecated"):
-        batch_on_device = move_data_to_device(batch, device)
-
-    if include_lengths:
-        # tensor with data
-        assert batch_on_device.text[0].device == device
-        # tensor with length of data
-        assert batch_on_device.text[1].device == device
-    else:
-        assert batch_on_device.text.device == device
-
-
-@pytest.mark.parametrize("include_lengths", [False, True])
-@pytest.mark.skipif(not _TORCHTEXT_LEGACY, reason="torchtext.legacy is deprecated.")
-def test_batch_move_data_to_device_torchtext_include_lengths_cpu(include_lengths):
-    test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, torch.device("cpu"))