Remove support for the deprecated torchtext legacy (#14375)

Carlos Mocholí 2022-08-26 22:01:51 +02:00 committed by GitHub
parent 8950613552
commit 3ba0f56b18
18 changed files with 19 additions and 244 deletions


@ -45,7 +45,7 @@ jobs:
 pip --version
 sudo pip uninstall -y lightning pytorch-lightning
 pip install fire
-python .actions/assistant.py requirements-prune-pkgs torch,torchvision,torchtext
+python .actions/assistant.py requirements-prune-pkgs torch,torchvision
 pip install ".[extra,test]"
 pip list
 env:

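For context, the assistant's `requirements-prune-pkgs` command comments out the listed packages in the requirements files so the torch builds already present in the image are not overwritten by the later `pip install`. A minimal sketch of the idea — the `prune_packages` helper below is hypothetical, not the actual assistant.py code:

import re
from pathlib import Path
from typing import List

def prune_packages(req_file: str, packages: List[str]) -> None:
    # hypothetical helper: comment out any requirement line pinning one of the given packages
    path = Path(req_file)
    lines = path.read_text().splitlines()
    for i, line in enumerate(lines):
        if any(re.match(rf"{re.escape(pkg)}[\s>=<~!]", line) for pkg in packages):
            lines[i] = f"# {line}"
    path.write_text("\n".join(lines) + "\n")

prune_packages("requirements/pytorch/base.txt", ["torch", "torchvision"])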

@ -78,11 +78,11 @@ RUN \
 conda update -n base -c defaults conda && \
 CUDA_VERSION_MM=$(python -c "print('.'.join('$CUDA_VERSION'.split('.')[:2]))") && \
 conda create -y --name $CONDA_ENV \
-    python=${PYTHON_VERSION} pytorch=${PYTORCH_VERSION} torchvision torchtext cudatoolkit=${CUDA_VERSION_MM} \
+    python=${PYTHON_VERSION} pytorch=${PYTORCH_VERSION} torchvision cudatoolkit=${CUDA_VERSION_MM} \
     -c nvidia -c pytorch -c pytorch-test && \
 conda init bash && \
 # NOTE: this requires that the channel is presented in the yaml before packages \
-printf "import re;\nfname = 'environment.yml';\nreq = open(fname).read();\nfor n in ['python', 'pytorch', 'torchtext', 'torchvision']:\n req = re.sub(rf'- {n}[>=]+', f'# - {n}=', req);\nopen(fname, 'w').write(req)" > prune.py && \
+printf "import re;\nfname = 'environment.yml';\nreq = open(fname).read();\nfor n in ['python', 'pytorch', 'torchvision']:\n req = re.sub(rf'- {n}[>=]+', f'# - {n}=', req);\nopen(fname, 'w').write(req)" > prune.py && \
 python prune.py && \
 rm prune.py && \
 cat environment.yml && \
@ -102,7 +102,7 @@ RUN \
 pip list | grep torch && \
 python -c "import torch; print(torch.__version__)" && \
 pip install -q fire && \
-python assistant.py requirements_prune_pkgs torch,torchvision,torchtext && \
+python assistant.py requirements_prune_pkgs torch,torchvision && \
 # Install remaining requirements
 pip install --no-cache-dir -r requirements/pytorch/base.txt \
     -r requirements/pytorch/extra.txt \

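For readability, the printf in the Dockerfile hunk above writes a throwaway prune.py. Expanded into normal Python (with torchtext already dropped from the loop), the generated script is equivalent to:

import re

fname = "environment.yml"
req = open(fname).read()
for n in ["python", "pytorch", "torchvision"]:
    # "- pytorch>=1.12" becomes "# - pytorch=1.12" so conda keeps the preinstalled build
    req = re.sub(rf"- {n}[>=]+", f"# - {n}=", req)
open(fname, "w").write(req)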

@ -41,7 +41,6 @@ dependencies:
   - scikit-learn>=0.20.0
   - matplotlib>=3.1.1
   - omegaconf>=2.0.5
-  - torchtext>=0.10.*

   # Examples
   - torchvision>=0.10.*


@ -5,14 +5,14 @@ from typing import Dict, Optional
 # IMPORTANT: this list needs to be sorted in reverse
 VERSIONS = [
-    dict(torch="1.12.1", torchvision="0.13.1", torchtext="0.13.1"),  # stable
-    dict(torch="1.12.0", torchvision="0.13.0", torchtext="0.13.0"),
-    dict(torch="1.11.0", torchvision="0.12.0", torchtext="0.12.0"),
-    dict(torch="1.10.2", torchvision="0.11.3", torchtext="0.11.2"),
-    dict(torch="1.10.1", torchvision="0.11.2", torchtext="0.11.1"),
-    dict(torch="1.10.0", torchvision="0.11.1", torchtext="0.11.0"),
-    dict(torch="1.9.1", torchvision="0.10.1", torchtext="0.10.1"),
-    dict(torch="1.9.0", torchvision="0.10.0", torchtext="0.10.0"),
+    dict(torch="1.12.1", torchvision="0.13.1"),  # stable
+    dict(torch="1.12.0", torchvision="0.13.0"),
+    dict(torch="1.11.0", torchvision="0.12.0"),
+    dict(torch="1.10.2", torchvision="0.11.3"),
+    dict(torch="1.10.1", torchvision="0.11.2"),
+    dict(torch="1.10.0", torchvision="0.11.1"),
+    dict(torch="1.9.1", torchvision="0.10.1"),
+    dict(torch="1.9.0", torchvision="0.10.0"),
 ]

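This table drives the assistant's version adjustment. Because the list is sorted in reverse, a first-prefix-match lookup returns the newest compatible pin set; a minimal sketch under that assumption — `find_versions` is a hypothetical helper, not the actual assistant function:

from typing import Dict, List, Optional

def find_versions(torch_version: str, versions: List[Dict[str, str]]) -> Optional[Dict[str, str]]:
    # first prefix match wins; prefix matching also covers local builds like "1.12.1+cu116"
    for pins in versions:
        if torch_version.startswith(pins["torch"]):
            return pins
    return None

assert find_versions("1.12.1+cu116", [dict(torch="1.12.1", torchvision="0.13.1")])["torchvision"] == "0.13.1"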

@ -4,4 +4,3 @@ import jsonargparse # noqa: F401
 import matplotlib  # noqa: F401
 import omegaconf  # noqa: F401
 import rich  # noqa: F401
-import torchtext  # noqa: F401


@ -3,7 +3,6 @@
 # extended list of package dependencies to reach full functionality

 matplotlib>3.1, <3.5.3
-torchtext>=0.10.*, <0.14.0
 omegaconf>=2.0.5, <2.3.0
 hydra-core>=1.0.5, <1.3.0
 jsonargparse[signatures]>=4.12.0, <=4.12.0


@ -92,6 +92,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed the experimental `pytorch_lightning.utilities.meta` functions in favor of built-in https://github.com/pytorch/torchdistx support ([#13868](https://github.com/Lightning-AI/lightning/pull/13868))

+- Removed deprecated support for old torchtext versions ([#14375](https://github.com/Lightning-AI/lightning/pull/14375))
+

 ### Fixed

 - Fixed an assertion error when using a `ReduceOnPlateau` scheduler with the Horovod strategy ([#14215](https://github.com/Lightning-AI/lightning/pull/14215))


@ -624,7 +624,6 @@ class DataHooks:
         - :class:`list`
         - :class:`dict`
         - :class:`tuple`
-        - :class:`torchtext.data.batch.Batch`

         For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...).

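For custom batch types, the documented escape hatch is overriding the `transfer_batch_to_device` hook. A minimal sketch — `CustomBatch` and `MyModel` are hypothetical, the hook signature is as in this release line:

from dataclasses import dataclass
from typing import Any

import torch
from pytorch_lightning import LightningModule

@dataclass
class CustomBatch:
    inputs: torch.Tensor
    labels: torch.Tensor

class MyModel(LightningModule):
    def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any:
        if isinstance(batch, CustomBatch):
            # Lightning cannot infer how to move this type, so do it explicitly
            return CustomBatch(batch.inputs.to(device), batch.labels.to(device))
        return super().transfer_batch_to_device(batch, device, dataloader_idx)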

@ -43,7 +43,6 @@ from pytorch_lightning.utilities.imports import ( # noqa: F401
     _TORCH_GREATER_EQUAL_1_11,
     _TORCH_GREATER_EQUAL_1_12,
     _TORCH_QUANTIZE_AVAILABLE,
-    _TORCHTEXT_AVAILABLE,
     _TORCHVISION_AVAILABLE,
     _TPU_AVAILABLE,
     _XLA_AVAILABLE,


@ -14,10 +14,9 @@
 """Utilities used for collections."""

 import dataclasses
-import operator
 from abc import ABC
 from collections import defaultdict, OrderedDict
-from copy import copy, deepcopy
+from copy import deepcopy
 from functools import partial
 from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union
@ -26,17 +25,6 @@ import torch
 from torch import Tensor

 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _compare_version, _TORCHTEXT_LEGACY
-from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
-
-if _TORCHTEXT_LEGACY:
-    if _compare_version("torchtext", operator.ge, "0.9.0"):
-        from torchtext.legacy.data import Batch
-    else:
-        from torchtext.data import Batch
-else:
-    Batch = type(None)
-
 _BLOCKING_DEVICE_TYPES = ("cpu", "mps")
@ -326,23 +314,6 @@ def move_data_to_device(batch: Any, device: Union[str, torch.device]) -> Any:
     device = torch.device(device)

     def batch_to(data: Any) -> Any:
-        # try to move torchtext data first
-        if _TORCHTEXT_LEGACY and isinstance(data, Batch):
-            # TODO: also remove the torchtext dependency with Lightning 1.8
-            rank_zero_deprecation(
-                "The `torchtext.legacy.Batch` object is deprecated and Lightning will remove support for it in v1.8."
-                " We recommend you to migrate away from Batch by following the TorchText README:"
-                " https://github.com/pytorch/text#bc-breaking-legacy"
-            )
-            # Shallow copy because each Batch has a reference to Dataset which contains all examples
-            device_data = copy(data)
-            for field, field_value in data.dataset.fields.items():
-                if field_value is None:
-                    continue
-                device_field = move_data_to_device(getattr(data, field), device)
-                setattr(device_data, field, device_field)
-            return device_data
-
         kwargs = {}
         # Don't issue non-blocking transfers to CPU
         # Same with MPS due to a race condition bug: https://github.com/pytorch/pytorch/issues/83015
@ -354,8 +325,7 @@ def move_data_to_device(batch: Any, device: Union[str, torch.device]) -> Any:
         # user wrongly implemented the `TransferableDataType` and forgot to return `self`.
         return data

-    dtype = (TransferableDataType, Batch) if _TORCHTEXT_LEGACY else TransferableDataType
-    return apply_to_collection(batch, dtype=dtype, function=batch_to)
+    return apply_to_collection(batch, dtype=TransferableDataType, function=batch_to)


 def convert_to_tensors(data: Any, device: Union[str, torch.device]) -> Any:

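With the legacy Batch branch gone, `move_data_to_device` reduces to `apply_to_collection` recursing through the supported containers and calling `batch_to` on anything satisfying the `TransferableDataType` protocol (i.e. exposing `.to()`). A simplified stand-in for that dispatch, not the real implementation (namedtuples, dataclasses, etc. are handled by the real `apply_to_collection`):

from typing import Any, Union

import torch

def move_data_to_device(batch: Any, device: Union[str, torch.device]) -> Any:
    device = torch.device(device)
    if isinstance(batch, (list, tuple)):
        return type(batch)(move_data_to_device(x, device) for x in batch)
    if isinstance(batch, dict):
        return {k: move_data_to_device(v, device) for k, v in batch.items()}
    if hasattr(batch, "to"):
        # non-blocking transfers only pay off device-to-device; CPU and MPS targets
        # are excluded (see the race-condition issue referenced in the hunk above)
        kwargs = {} if device.type in ("cpu", "mps") else {"non_blocking": True}
        return batch.to(device, **kwargs)
    return batch  # everything else is returned unchanged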

@ -124,9 +124,7 @@ def has_len(dataloader: Union[DataLoader, Iterable]) -> bool:
             f"`{dataloader.__class__.__name__}` returned 0 length. Please make sure this was your intention."
         )
         has_len = True
-    except TypeError:
-        has_len = False
-    except NotImplementedError:  # e.g. raised by torchtext if a batch_size_fn is used
+    except (TypeError, NotImplementedError):
         has_len = False

     if has_len and has_iterable_dataset(dataloader):
@ -170,9 +168,7 @@ def has_len_all_ranks(
         else:
             has_len = True
-    except TypeError:
-        has_len = False
-    except NotImplementedError:  # e.g. raised by torchtext if a batch_size_fn is used
+    except (TypeError, NotImplementedError):
         has_len = False

     if has_len and has_iterable_dataset(dataloader):

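The two except clauses merge because the handling is identical: any loader whose `len()` fails is treated as unsized. A toy reproduction of the `NotImplementedError` path, using a hypothetical dataset rather than torchtext:

from torch.utils.data import DataLoader, Dataset

class NoLenDataset(Dataset):
    def __getitem__(self, idx):
        return idx

    def __len__(self):
        # mimics loaders that cannot report a length upfront
        raise NotImplementedError

def has_len(dataloader) -> bool:
    # simplified version of the check above
    try:
        return len(dataloader) > 0
    except (TypeError, NotImplementedError):
        return False

assert not has_len(DataLoader(NoLenDataset()))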

@ -147,8 +147,6 @@ _POPTORCH_AVAILABLE = _package_available("poptorch")
 _PSUTIL_AVAILABLE = _package_available("psutil")
 _RICH_AVAILABLE = _package_available("rich") and _compare_version("rich", operator.ge, "10.2.2")
 _TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != "none"])
-_TORCHTEXT_AVAILABLE = _package_available("torchtext")
-_TORCHTEXT_LEGACY: bool = _TORCHTEXT_AVAILABLE and _compare_version("torchtext", operator.lt, "0.11.0")
 _TORCHVISION_AVAILABLE = _package_available("torchvision")
 _XLA_AVAILABLE: bool = _package_available("torch_xla")

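These flags all follow the same pattern: probe once at import time whether a distribution is importable (and optionally compare its version). A minimal sketch of the probe, simplified relative to the real `_package_available`/`_compare_version` helpers:

import importlib.util

def package_available(name: str) -> bool:
    # True if the top-level module can be located without importing it
    try:
        return importlib.util.find_spec(name) is not None
    except (ImportError, ValueError):
        return False

_TORCHVISION_AVAILABLE = package_available("torchvision")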

@ -21,8 +21,6 @@ import tests_pytorch.helpers.pipelines as tpipes
 from pytorch_lightning import Trainer
 from pytorch_lightning.accelerators import MPSAccelerator
 from pytorch_lightning.demos.boring_classes import BoringModel
-from pytorch_lightning.utilities.imports import _TORCHTEXT_LEGACY
-from tests_pytorch.helpers.imports import Batch, Dataset, Example, Field, LabelField
 from tests_pytorch.helpers.runif import RunIf
@ -135,30 +133,3 @@ def test_single_gpu_batch_parse():
     batch = trainer.strategy.batch_to_device(CustomBatchType(), torch.device("mps"))
     assert batch.a.type() == "torch.mps.FloatTensor"

-    # torchtext.data.Batch
-    if not _TORCHTEXT_LEGACY:
-        return
-
-    samples = [
-        {"text": "PyTorch Lightning is awesome!", "label": 0},
-        {"text": "Please make it work with torchtext", "label": 1},
-    ]
-
-    text_field = Field()
-    label_field = LabelField()
-    fields = {"text": ("text", text_field), "label": ("label", label_field)}
-
-    examples = [Example.fromdict(sample, fields) for sample in samples]
-    dataset = Dataset(examples=examples, fields=fields.values())
-
-    # Batch runs field.process() that numericalizes tokens, but it requires to build dictionary first
-    text_field.build_vocab(dataset)
-    label_field.build_vocab(dataset)
-
-    batch = Batch(data=examples, dataset=dataset)
-
-    with pytest.deprecated_call(match="The `torchtext.legacy.Batch` object is deprecated"):
-        batch = trainer.strategy.batch_to_device(batch, torch.device("mps"))
-
-    assert batch.text.type() == "torch.mps.LongTensor"
-    assert batch.label.type() == "torch.mps.LongTensor"


@ -35,20 +35,8 @@ from pytorch_lightning.strategies.ipu import LightningIPUModule
 from pytorch_lightning.trainer.configuration_validator import _check_datamodule_checkpoint_hooks
 from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.utilities import device_parser
-from pytorch_lightning.utilities.apply_func import move_data_to_device
-from pytorch_lightning.utilities.imports import _TORCHTEXT_LEGACY
 from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
 from tests_pytorch.helpers.runif import RunIf
-from tests_pytorch.helpers.torchtext_utils import get_dummy_torchtext_data_iterator
-
-
-@pytest.mark.skipif(not _TORCHTEXT_LEGACY, reason="torchtext.legacy is deprecated.")
-def test_v1_8_0_deprecated_torchtext_batch():
-    with pytest.deprecated_call(match="is deprecated and Lightning will remove support for it in v1.8"):
-        data_iterator, _ = get_dummy_torchtext_data_iterator(num_samples=3, batch_size=3)
-        batch = next(iter(data_iterator))
-        _ = move_data_to_device(batch=batch, device=torch.device("cpu"))


 def test_v1_8_0_on_init_start_end(tmpdir):


@ -1,16 +0,0 @@
-import operator
-
-from pytorch_lightning.utilities.imports import _compare_version, _TORCHTEXT_LEGACY
-
-if _TORCHTEXT_LEGACY:
-    if _compare_version("torchtext", operator.ge, "0.9.0"):
-        from torchtext.legacy.data import Batch, Dataset, Example, Field, Iterator, LabelField
-    else:
-        from torchtext.data import Batch, Dataset, Example, Field, Iterator, LabelField
-else:
-    Batch = type(None)
-    Dataset = type(None)
-    Example = type(None)
-    Field = type(None)
-    Iterator = type(None)
-    LabelField = type(None)


@ -1,54 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import random
-import string
-
-from tests_pytorch.helpers.imports import Dataset, Example, Field, Iterator
-
-
-def _generate_random_string(length: int = 10):
-    return "".join(random.choices(string.ascii_letters, k=length))
-
-
-def get_dummy_torchtext_data_iterator(num_samples: int, batch_size: int, include_lengths: bool = False):
-    text_field = Field(
-        sequential=True,
-        pad_first=False,  # nosec
-        init_token="<s>",
-        eos_token="</s>",  # nosec
-        include_lengths=include_lengths,
-    )  # nosec
-    dataset = Dataset(
-        [
-            Example.fromdict({"text": _generate_random_string()}, {"text": ("text", text_field)})
-            for _ in range(num_samples)
-        ],
-        {"text": text_field},
-    )
-    text_field.build_vocab(dataset)
-
-    iterator = Iterator(
-        dataset,
-        batch_size=batch_size,
-        sort_key=None,
-        device=None,
-        batch_size_fn=None,
-        train=True,
-        repeat=False,
-        shuffle=None,
-        sort=None,
-        sort_within_batch=None,
-    )
-    return iterator, text_field


@ -28,9 +28,8 @@ from pytorch_lightning.demos.boring_classes import BoringModel
 from pytorch_lightning.plugins.environments import TorchElasticEnvironment
 from pytorch_lightning.utilities import device_parser
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _compare_version, _TORCHTEXT_LEGACY
+from pytorch_lightning.utilities.imports import _compare_version
 from tests_pytorch.helpers.datamodules import ClassifDataModule
-from tests_pytorch.helpers.imports import Batch, Dataset, Example, Field, LabelField
 from tests_pytorch.helpers.runif import RunIf
 from tests_pytorch.helpers.simple_models import ClassificationModel
@ -269,33 +268,6 @@ def test_single_gpu_batch_parse():
     batch = trainer.strategy.batch_to_device(CustomBatchType(), torch.device("cuda:0"))
     assert batch.a.type() == "torch.cuda.FloatTensor"

-    # torchtext.data.Batch
-    if not _TORCHTEXT_LEGACY:
-        return
-
-    samples = [
-        {"text": "PyTorch Lightning is awesome!", "label": 0},
-        {"text": "Please make it work with torchtext", "label": 1},
-    ]
-
-    text_field = Field()
-    label_field = LabelField()
-    fields = {"text": ("text", text_field), "label": ("label", label_field)}
-
-    examples = [Example.fromdict(sample, fields) for sample in samples]
-    dataset = Dataset(examples=examples, fields=fields.values())
-
-    # Batch runs field.process() that numericalizes tokens, but it requires to build dictionary first
-    text_field.build_vocab(dataset)
-    label_field.build_vocab(dataset)
-
-    batch = Batch(data=examples, dataset=dataset)
-
-    with pytest.deprecated_call(match="The `torchtext.legacy.Batch` object is deprecated"):
-        batch = trainer.strategy.batch_to_device(batch, torch.device("cuda:0"))
-
-    assert batch.text.type() == "torch.cuda.LongTensor"
-    assert batch.label.type() == "torch.cuda.LongTensor"


 @RunIf(min_cuda_gpus=1)
 def test_non_blocking():


@ -1,47 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import pytest
-import torch
-
-from pytorch_lightning.utilities.apply_func import move_data_to_device
-from pytorch_lightning.utilities.imports import _TORCHTEXT_LEGACY
-from tests_pytorch.helpers.runif import RunIf
-from tests_pytorch.helpers.torchtext_utils import get_dummy_torchtext_data_iterator
-
-
-@pytest.mark.parametrize("include_lengths", [False, True])
-@pytest.mark.parametrize("device", [torch.device("cuda", 0)])
-@pytest.mark.skipif(not _TORCHTEXT_LEGACY, reason="torchtext.legacy is deprecated.")
-@RunIf(min_cuda_gpus=1)
-def test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, device):
-    data_iterator, _ = get_dummy_torchtext_data_iterator(num_samples=3, batch_size=3, include_lengths=include_lengths)
-    data_iter = iter(data_iterator)
-    batch = next(data_iter)
-
-    with pytest.deprecated_call(match="The `torchtext.legacy.Batch` object is deprecated"):
-        batch_on_device = move_data_to_device(batch, device)
-
-    if include_lengths:
-        # tensor with data
-        assert batch_on_device.text[0].device == device
-        # tensor with length of data
-        assert batch_on_device.text[1].device == device
-    else:
-        assert batch_on_device.text.device == device
-
-
-@pytest.mark.parametrize("include_lengths", [False, True])
-@pytest.mark.skipif(not _TORCHTEXT_LEGACY, reason="torchtext.legacy is deprecated.")
-def test_batch_move_data_to_device_torchtext_include_lengths_cpu(include_lengths):
-    test_batch_move_data_to_device_torchtext_include_lengths(include_lengths, torch.device("cpu"))