import from PL

This commit is contained in:
otaj 2022-09-28 12:20:43 +02:00
parent 7b047f82c7
commit eb9c9b7f0d
1 changed files with 3 additions and 3 deletions

View File

@ -42,7 +42,7 @@ class PyTorchLightningScriptRunner(TracerPythonScript):
self.env = env
def configure_tracer(self):
from lightning import Trainer
from pytorch_lightning import Trainer
tracer = super().configure_tracer()
tracer.add_traced(Trainer, "__init__", pre_fn=self._trainer_init_pre_middleware)
@ -70,8 +70,8 @@ class PyTorchLightningScriptRunner(TracerPythonScript):
return super().run(**kwargs)
def on_after_run(self, script_globals):
from lightning import Trainer
from lightning.pytorch.cli import LightningCLI
from pytorch_lightning import Trainer
from pytorch_lightning.cli import LightningCLI
for v in script_globals.values():
if isinstance(v, LightningCLI):