diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 2f7425bf3b..27ec0a5389 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -23,9 +23,13 @@ import torch from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _TORCHTEXT_AVAILABLE +from pytorch_lightning.utilities.imports import _module_available if _TORCHTEXT_AVAILABLE: - from torchtext.data import Batch + if _module_available("torchtext.legacy.data"): + from torchtext.legacy.data import Batch + else: + from torchtext.data import Batch else: Batch = type(None)