import from PL
This commit is contained in:
parent
7b047f82c7
commit
eb9c9b7f0d
|
@ -42,7 +42,7 @@ class PyTorchLightningScriptRunner(TracerPythonScript):
|
|||
self.env = env
|
||||
|
||||
def configure_tracer(self):
|
||||
from lightning import Trainer
|
||||
from pytorch_lightning import Trainer
|
||||
|
||||
tracer = super().configure_tracer()
|
||||
tracer.add_traced(Trainer, "__init__", pre_fn=self._trainer_init_pre_middleware)
|
||||
|
@ -70,8 +70,8 @@ class PyTorchLightningScriptRunner(TracerPythonScript):
|
|||
return super().run(**kwargs)
|
||||
|
||||
def on_after_run(self, script_globals):
|
||||
from lightning import Trainer
|
||||
from lightning.pytorch.cli import LightningCLI
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.cli import LightningCLI
|
||||
|
||||
for v in script_globals.values():
|
||||
if isinstance(v, LightningCLI):
|
||||
|
|
Loading…
Reference in New Issue