Add check to ensure 1.6
This commit is contained in:
parent
a311ee17ab
commit
ba312473f8
|
@ -11,16 +11,21 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from distutils.version import LooseVersion
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
||||
try:
|
||||
from fairscale.optim import OSS
|
||||
from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel
|
||||
IS_TORCH_AT_LEAST_1_6 = LooseVersion(torch.__version__) >= LooseVersion("1.6.0")
|
||||
if IS_TORCH_AT_LEAST_1_6:
|
||||
from fairscale.optim import OSS
|
||||
from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
FAIRSCALE_AVAILABLE = False
|
||||
else:
|
||||
|
@ -59,7 +64,7 @@ class DDPShardedPlugin(DDPPlugin):
|
|||
def _check_fairscale(self):
|
||||
if not FAIRSCALE_AVAILABLE:
|
||||
raise MisconfigurationException(
|
||||
'Sharded DDP Plugin requires Fairscale to be installed.'
|
||||
'Sharded DDP Plugin requires Fairscale to be installed and Pytorch version 1.6 or above.'
|
||||
)
|
||||
|
||||
@rank_zero_only
|
||||
|
|
Loading…
Reference in New Issue