diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index f81a382041..5846214636 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -10,7 +10,7 @@ TORCHTEXT_AVAILABLE = importlib.util.find_spec("torchtext") is not None if TORCHTEXT_AVAILABLE: from torchtext.data import Batch else: - Batch = None + Batch = type(None) def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args, **kwargs) -> Any: