From ba312473f85f2269c4491a83cb9d0c6edef24e0b Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Wed, 25 Nov 2020 19:40:58 +0000
Subject: [PATCH] Add check to ensure 1.6

---
 pytorch_lightning/plugins/sharded_plugin.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/plugins/sharded_plugin.py b/pytorch_lightning/plugins/sharded_plugin.py
index 5575baef0c..d6d1186992 100644
--- a/pytorch_lightning/plugins/sharded_plugin.py
+++ b/pytorch_lightning/plugins/sharded_plugin.py
@@ -11,16 +11,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from distutils.version import LooseVersion
 from typing import List, Optional, Union
 
+import torch
+
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
 from pytorch_lightning.utilities import rank_zero_only
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 try:
-    from fairscale.optim import OSS
-    from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel
+    if LooseVersion(torch.__version__) < LooseVersion("1.6.0"):
+        raise ImportError("Sharded DDP Plugin requires PyTorch 1.6 or above.")
+    from fairscale.optim import OSS
+    from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel
 except (ModuleNotFoundError, ImportError):
     FAIRSCALE_AVAILABLE = False
 else:
     FAIRSCALE_AVAILABLE = True
@@ -59,7 +64,7 @@ class DDPShardedPlugin(DDPPlugin):
     def _check_fairscale(self):
         if not FAIRSCALE_AVAILABLE:
             raise MisconfigurationException(
-                'Sharded DDP Plugin requires Fairscale to be installed.'
+                'Sharded DDP Plugin requires Fairscale to be installed and PyTorch version 1.6 or above.'
             )
 
     @rank_zero_only