Add check to ensure PyTorch version is at least 1.6

This commit is contained in:
SeanNaren 2020-11-25 19:40:58 +00:00
parent a311ee17ab
commit ba312473f8
1 changed file with 8 additions and 3 deletions

View File

@ -11,16 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from distutils.version import LooseVersion
from typing import List, Optional, Union
import torch
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
try:
from fairscale.optim import OSS
from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel
IS_TORCH_AT_LEAST_1_6 = LooseVersion(torch.__version__) >= LooseVersion("1.6.0")
if IS_TORCH_AT_LEAST_1_6:
from fairscale.optim import OSS
from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel
except (ModuleNotFoundError, ImportError):
FAIRSCALE_AVAILABLE = False
else:
@ -59,7 +64,7 @@ class DDPShardedPlugin(DDPPlugin):
def _check_fairscale(self):
if not FAIRSCALE_AVAILABLE:
raise MisconfigurationException(
'Sharded DDP Plugin requires Fairscale to be installed.'
'Sharded DDP Plugin requires Fairscale to be installed and Pytorch version 1.6 or above.'
)
@rank_zero_only