Remove `model_connector.py` (#10111)

This commit is contained in:
Danielle Pintz 2021-10-26 02:52:14 -07:00 committed by GitHub
parent 871a96701a
commit a5235d5b01
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 10 additions and 30 deletions

View File

@ -15,6 +15,7 @@ import os
from dataclasses import dataclass
from functools import partial
from typing import Iterable, Optional, Union
from weakref import proxy
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_deprecation
@ -186,7 +187,15 @@ class DataConnector:
)
self.attach_datamodule(model, datamodule=datamodule)
# set local properties on the model
self.trainer.model_connector.copy_trainer_model_properties(model)
self._copy_trainer_model_properties(model)
def _copy_trainer_model_properties(self, model):
ref_model = self.trainer.lightning_module or model
for m in [model, ref_model]:
m.trainer = proxy(self.trainer)
m.use_amp = self.trainer.amp_backend is not None
m.precision = self.trainer.precision
def attach_dataloaders(
self,

View File

@ -1,27 +0,0 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from weakref import proxy
class ModelConnector:
def __init__(self, trainer):
self.trainer = trainer
def copy_trainer_model_properties(self, model):
ref_model = self.trainer.lightning_module or model
for m in [model, ref_model]:
m.trainer = proxy(self.trainer)
m.use_amp = self.trainer.amp_backend is not None
m.precision = self.trainer.precision

View File

@ -57,7 +57,6 @@ from pytorch_lightning.trainer.connectors.debugging_connector import DebuggingCo
from pytorch_lightning.trainer.connectors.env_vars_connector import _defaults_from_env_vars
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
@ -450,7 +449,6 @@ class Trainer(
plugins,
)
self.logger_connector = LoggerConnector(self, log_gpu_memory)
self.model_connector = ModelConnector(self)
self.callback_connector = CallbackConnector(self)
self.debugging_connector = DebuggingConnector(self)
self.training_tricks_connector = TrainingTricksConnector(self)