diff --git a/pytorch_lightning/plugins/sharded_plugin.py b/pytorch_lightning/plugins/sharded_plugin.py
index cc7220a906..fdb632486c 100644
--- a/pytorch_lightning/plugins/sharded_plugin.py
+++ b/pytorch_lightning/plugins/sharded_plugin.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Any, Optional, Union
+from typing import List, Optional, Union
 
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
@@ -50,11 +50,12 @@ class DDPShardedPlugin(DDPPlugin):
         optimizer.consolidate_state_dict()
         return self._optim_state_dict(optimizer)
 
-    def on_before_forward(self, args: Any, model: LightningModule):
+    def on_before_forward(self, model: LightningModule, *args):
+        args = list(args)
         batch = args[0]
         batch = model.transfer_batch_to_device(batch, model.trainer.root_gpu)
         args[0] = batch
-        return args
+        return tuple(args)
 
     def _check_fairscale(self):
         if not FAIRSCALE_AVAILABLE:
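
For context, the reworked hook now receives the positional forward arguments as a variadic `*args` tuple, moves the first element (the batch) to the model's root GPU, and returns the arguments as a tuple. Below is a minimal, self-contained sketch of that contract; it is not the real plugin, and the `FakeModule`/`FakeTrainer` classes are hypothetical stand-ins for a LightningModule and Trainer, not part of this diff:

# Sketch of the new on_before_forward contract (assumed call shape, not the plugin itself).
def on_before_forward(model, *args):
    args = list(args)  # *args arrives as an immutable tuple; copy to mutate
    # Move the batch (first positional arg) to the model's root GPU.
    args[0] = model.transfer_batch_to_device(args[0], model.trainer.root_gpu)
    return tuple(args)  # hand positional args back as a tuple

class FakeTrainer:
    root_gpu = 0

class FakeModule:
    trainer = FakeTrainer()

    def transfer_batch_to_device(self, batch, device):
        # A real LightningModule moves tensors to `device`; here we just tag the batch.
        return ("moved", batch)

assert on_before_forward(FakeModule(), {"x": 1}, "extra") == (("moved", {"x": 1}), "extra")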