diff --git a/pytorch_lightning/overrides/fairscale.py b/pytorch_lightning/overrides/fairscale.py index 73d9a6e6fb..4ab1933e4a 100644 --- a/pytorch_lightning/overrides/fairscale.py +++ b/pytorch_lightning/overrides/fairscale.py @@ -11,20 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel +try: + from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel +except (ModuleNotFoundError, ImportError): + FAIRSCALE_SHARDED_AVAILABLE = False +else: + FAIRSCALE_SHARDED_AVAILABLE = True -class LightningShardedDataParallel(ShardedDataParallel): + class LightningShardedDataParallel(ShardedDataParallel): - def forward(self, *inputs, **kwargs): - if self.enable_broadcast_buffers: - self.sync_buffers() + def forward(self, *inputs, **kwargs): + if self.enable_broadcast_buffers: + self.sync_buffers() - if self.module.training: - outputs = self.module.training_step(*inputs, **kwargs) - elif self.module.testing: - outputs = self.module.test_step(*inputs, **kwargs) - else: - outputs = self.module.validation_step(*inputs, **kwargs) - return outputs + if self.module.training: + outputs = self.module.training_step(*inputs, **kwargs) + elif self.module.testing: + outputs = self.module.test_step(*inputs, **kwargs) + else: + outputs = self.module.validation_step(*inputs, **kwargs) + return outputs