Fix conversion in on_before_forward
This commit is contained in:
parent
50ed083fc7
commit
df416f6c78
|
@ -11,7 +11,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import List, Any, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
from pytorch_lightning.core.lightning import LightningModule
|
from pytorch_lightning.core.lightning import LightningModule
|
||||||
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
|
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
|
||||||
|
@ -50,11 +50,12 @@ class DDPShardedPlugin(DDPPlugin):
|
||||||
optimizer.consolidate_state_dict()
|
optimizer.consolidate_state_dict()
|
||||||
return self._optim_state_dict(optimizer)
|
return self._optim_state_dict(optimizer)
|
||||||
|
|
||||||
def on_before_forward(self, args: Any, model: LightningModule):
|
def on_before_forward(self, model: LightningModule, *args):
|
||||||
|
args = list(args)
|
||||||
batch = args[0]
|
batch = args[0]
|
||||||
batch = model.transfer_batch_to_device(batch, model.trainer.root_gpu)
|
batch = model.transfer_batch_to_device(batch, model.trainer.root_gpu)
|
||||||
args[0] = batch
|
args[0] = batch
|
||||||
return args
|
return tuple(args)
|
||||||
|
|
||||||
def _check_fairscale(self):
|
def _check_fairscale(self):
|
||||||
if not FAIRSCALE_AVAILABLE:
|
if not FAIRSCALE_AVAILABLE:
|
||||||
|
|
Loading…
Reference in New Issue