Fix conversion in on_before_forward

This commit is contained in:
SeanNaren 2020-11-22 15:06:11 +00:00
parent 50ed083fc7
commit df416f6c78
1 changed files with 4 additions and 3 deletions

View File

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Any, Optional, Union
from typing import List, Optional, Union
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
@ -50,11 +50,12 @@ class DDPShardedPlugin(DDPPlugin):
optimizer.consolidate_state_dict()
return self._optim_state_dict(optimizer)
def on_before_forward(self, args: Any, model: LightningModule):
def on_before_forward(self, model: LightningModule, *args):
args = list(args)
batch = args[0]
batch = model.transfer_batch_to_device(batch, model.trainer.root_gpu)
args[0] = batch
return args
return tuple(args)
def _check_fairscale(self):
if not FAIRSCALE_AVAILABLE: