From 586f6c62ee578afcbab05e85dd99f7d07e671f2d Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 25 Nov 2020 20:16:20 +0000 Subject: [PATCH] Attempt try catch to prevent errors --- pytorch_lightning/overrides/fairscale.py | 30 ++++++++++++++---------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/overrides/fairscale.py b/pytorch_lightning/overrides/fairscale.py index 73d9a6e6fb..4ab1933e4a 100644 --- a/pytorch_lightning/overrides/fairscale.py +++ b/pytorch_lightning/overrides/fairscale.py @@ -11,20 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel +try: + from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel +except (ModuleNotFoundError, ImportError): + FAIRSCALE_SHARDED_AVAILABLE = False +else: + FAIRSCALE_SHARDED_AVAILABLE = True -class LightningShardedDataParallel(ShardedDataParallel): + class LightningShardedDataParallel(ShardedDataParallel): - def forward(self, *inputs, **kwargs): - if self.enable_broadcast_buffers: - self.sync_buffers() + def forward(self, *inputs, **kwargs): + if self.enable_broadcast_buffers: + self.sync_buffers() - if self.module.training: - outputs = self.module.training_step(*inputs, **kwargs) - elif self.module.testing: - outputs = self.module.test_step(*inputs, **kwargs) - else: - outputs = self.module.validation_step(*inputs, **kwargs) - return outputs + if self.module.training: + outputs = self.module.training_step(*inputs, **kwargs) + elif self.module.testing: + outputs = self.module.test_step(*inputs, **kwargs) + else: + outputs = self.module.validation_step(*inputs, **kwargs) + return outputs