diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 9b6f97f1eb..8f28696494 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -15,6 +15,7 @@ import os from dataclasses import dataclass from functools import partial from typing import Iterable, Optional, Union +from weakref import proxy import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_deprecation @@ -186,7 +187,15 @@ class DataConnector: ) self.attach_datamodule(model, datamodule=datamodule) # set local properties on the model - self.trainer.model_connector.copy_trainer_model_properties(model) + self._copy_trainer_model_properties(model) + + def _copy_trainer_model_properties(self, model): + ref_model = self.trainer.lightning_module or model + + for m in [model, ref_model]: + m.trainer = proxy(self.trainer) + m.use_amp = self.trainer.amp_backend is not None + m.precision = self.trainer.precision def attach_dataloaders( self, diff --git a/pytorch_lightning/trainer/connectors/model_connector.py b/pytorch_lightning/trainer/connectors/model_connector.py deleted file mode 100644 index c4a249a86e..0000000000 --- a/pytorch_lightning/trainer/connectors/model_connector.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from weakref import proxy - - -class ModelConnector: - def __init__(self, trainer): - self.trainer = trainer - - def copy_trainer_model_properties(self, model): - ref_model = self.trainer.lightning_module or model - - for m in [model, ref_model]: - m.trainer = proxy(self.trainer) - m.use_amp = self.trainer.amp_backend is not None - m.precision = self.trainer.precision diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7d66266c28..4beb93196a 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -57,7 +57,6 @@ from pytorch_lightning.trainer.connectors.debugging_connector import DebuggingCo from pytorch_lightning.trainer.connectors.env_vars_connector import _defaults_from_env_vars from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection -from pytorch_lightning.trainer.connectors.model_connector import ModelConnector from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin @@ -450,7 +449,6 @@ class Trainer( plugins, ) self.logger_connector = LoggerConnector(self, log_gpu_memory) - self.model_connector = ModelConnector(self) self.callback_connector = CallbackConnector(self) self.debugging_connector = DebuggingConnector(self) self.training_tricks_connector = TrainingTricksConnector(self)