Attempt try catch to prevent errors

This commit is contained in:
SeanNaren 2020-11-25 20:16:20 +00:00
parent ba312473f8
commit 586f6c62ee
1 changed files with 17 additions and 13 deletions

View File

@ -11,20 +11,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
try:
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
except (ModuleNotFoundError, ImportError):
FAIRSCALE_SHARDED_AVAILABLE = False
else:
FAIRSCALE_SHARDED_AVAILABLE = True
class LightningShardedDataParallel(ShardedDataParallel):
class LightningShardedDataParallel(ShardedDataParallel):
def forward(self, *inputs, **kwargs):
if self.enable_broadcast_buffers:
self.sync_buffers()
def forward(self, *inputs, **kwargs):
if self.enable_broadcast_buffers:
self.sync_buffers()
if self.module.training:
outputs = self.module.training_step(*inputs, **kwargs)
elif self.module.testing:
outputs = self.module.test_step(*inputs, **kwargs)
else:
outputs = self.module.validation_step(*inputs, **kwargs)
return outputs
if self.module.training:
outputs = self.module.training_step(*inputs, **kwargs)
elif self.module.testing:
outputs = self.module.test_step(*inputs, **kwargs)
else:
outputs = self.module.validation_step(*inputs, **kwargs)
return outputs