Simplify GPU and TPU accelerator (#5024)
This commit is contained in:
parent
90d1d9fa73
commit
bcbba3b702
|
@ -11,7 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
@ -66,53 +66,25 @@ class GPUAccelerator(Accelerator):
|
|||
results = self.train_or_test()
|
||||
return results
|
||||
|
||||
def training_step(self, args):
|
||||
def _step(self, model_step: Callable, args):
|
||||
args[0] = self.to_device(args[0])
|
||||
|
||||
if self.trainer.amp_backend == AMPType.NATIVE:
|
||||
with torch.cuda.amp.autocast():
|
||||
output = self.__training_step(args)
|
||||
output = model_step(*args)
|
||||
else:
|
||||
output = self.__training_step(args)
|
||||
output = model_step(*args)
|
||||
|
||||
return output
|
||||
|
||||
def __training_step(self, args):
|
||||
batch = args[0]
|
||||
batch = self.to_device(batch)
|
||||
args[0] = batch
|
||||
output = self.trainer.model.training_step(*args)
|
||||
return output
|
||||
def training_step(self, args):
|
||||
return self._step(self.trainer.model.training_step, args)
|
||||
|
||||
def validation_step(self, args):
|
||||
if self.trainer.amp_backend == AMPType.NATIVE:
|
||||
with torch.cuda.amp.autocast():
|
||||
output = self.__validation_step(args)
|
||||
else:
|
||||
output = self.__validation_step(args)
|
||||
|
||||
return output
|
||||
|
||||
def __validation_step(self, args):
|
||||
batch = args[0]
|
||||
batch = self.to_device(batch)
|
||||
args[0] = batch
|
||||
output = self.trainer.model.validation_step(*args)
|
||||
return output
|
||||
return self._step(self.trainer.model.validation_step, args)
|
||||
|
||||
def test_step(self, args):
|
||||
if self.trainer.amp_backend == AMPType.NATIVE:
|
||||
with torch.cuda.amp.autocast():
|
||||
output = self.__test_step(args)
|
||||
else:
|
||||
output = self.__test_step(args)
|
||||
|
||||
return output
|
||||
|
||||
def __test_step(self, args):
|
||||
batch = args[0]
|
||||
batch = self.to_device(batch)
|
||||
args[0] = batch
|
||||
output = self.trainer.model.test_step(*args)
|
||||
return output
|
||||
return self._step(self.trainer.model.test_step, args)
|
||||
|
||||
def to_device(self, batch):
|
||||
gpu_id = 0
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
import io
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
@ -145,26 +145,18 @@ class TPUAccelerator(Accelerator):
|
|||
# persist info in spawn
|
||||
self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)
|
||||
|
||||
def _step(self, model_step: Callable, args):
|
||||
args[0] = self.to_device(args[0])
|
||||
return model_step(*args)
|
||||
|
||||
def training_step(self, args):
|
||||
batch = args[0]
|
||||
batch = self.to_device(batch)
|
||||
args[0] = batch
|
||||
output = self.trainer.model.training_step(*args)
|
||||
return output
|
||||
return self._step(self.trainer.model.training_step, args)
|
||||
|
||||
def validation_step(self, args):
|
||||
batch = args[0]
|
||||
batch = self.to_device(batch)
|
||||
args[0] = batch
|
||||
output = self.trainer.model.validation_step(*args)
|
||||
return output
|
||||
return self._step(self.trainer.model.validation_step, args)
|
||||
|
||||
def test_step(self, args):
|
||||
batch = args[0]
|
||||
batch = self.to_device(batch)
|
||||
args[0] = batch
|
||||
output = self.trainer.model.test_step(*args)
|
||||
return output
|
||||
return self._step(self.trainer.model.test_step, args)
|
||||
|
||||
def process_dataloader(self, dataloader):
|
||||
device = xm.xla_device(self.trainer.tpu_id)
|
||||
|
|
Loading…
Reference in New Issue