Simplify GPU and TPU accelerator (#5024)

Rohit Gupta 2020-12-10 00:42:44 +05:30 committed by GitHub
parent 90d1d9fa73
commit bcbba3b702
2 changed files with 18 additions and 54 deletions


@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union
 
 import torch
@@ -66,53 +66,25 @@ class GPUAccelerator(Accelerator):
         results = self.train_or_test()
         return results
 
-    def training_step(self, args):
+    def _step(self, model_step: Callable, args):
+        args[0] = self.to_device(args[0])
         if self.trainer.amp_backend == AMPType.NATIVE:
             with torch.cuda.amp.autocast():
-                output = self.__training_step(args)
+                output = model_step(*args)
         else:
-            output = self.__training_step(args)
+            output = model_step(*args)
         return output
 
-    def __training_step(self, args):
-        batch = args[0]
-        batch = self.to_device(batch)
-        args[0] = batch
-        output = self.trainer.model.training_step(*args)
-        return output
+    def training_step(self, args):
+        return self._step(self.trainer.model.training_step, args)
 
     def validation_step(self, args):
-        if self.trainer.amp_backend == AMPType.NATIVE:
-            with torch.cuda.amp.autocast():
-                output = self.__validation_step(args)
-        else:
-            output = self.__validation_step(args)
-        return output
-
-    def __validation_step(self, args):
-        batch = args[0]
-        batch = self.to_device(batch)
-        args[0] = batch
-        output = self.trainer.model.validation_step(*args)
-        return output
+        return self._step(self.trainer.model.validation_step, args)
 
     def test_step(self, args):
-        if self.trainer.amp_backend == AMPType.NATIVE:
-            with torch.cuda.amp.autocast():
-                output = self.__test_step(args)
-        else:
-            output = self.__test_step(args)
-        return output
-
-    def __test_step(self, args):
-        batch = args[0]
-        batch = self.to_device(batch)
-        args[0] = batch
-        output = self.trainer.model.test_step(*args)
-        return output
+        return self._step(self.trainer.model.test_step, args)
 
     def to_device(self, batch):
         gpu_id = 0
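
Note on the GPU change above: the three duplicated step methods collapse into a single dispatcher that moves the batch to the device and runs the given step, wrapping it in torch.cuda.amp.autocast only when the native AMP backend is active. The following is a minimal, self-contained sketch of that pattern outside of Lightning; ToyModel, ToyGPUAccelerator, and the use_native_amp flag are illustrative stand-ins, not part of the actual accelerator.

import torch

class ToyModel(torch.nn.Module):
    # Illustrative stand-in for the LightningModule the accelerator drives.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def validation_step(self, batch, batch_idx):
        return self.layer(batch).mean()

class ToyGPUAccelerator:
    # Mirrors the refactored dispatch: one _step helper owns the device
    # transfer and the optional native-AMP autocast context.
    def __init__(self, model, device="cpu", use_native_amp=True):
        self.device = device
        self.model = model.to(device)
        self.use_native_amp = use_native_amp and device == "cuda"

    def to_device(self, batch):
        return batch.to(self.device)

    def _step(self, model_step, args):
        args[0] = self.to_device(args[0])
        if self.use_native_amp:
            with torch.cuda.amp.autocast():
                return model_step(*args)
        return model_step(*args)

    def training_step(self, args):
        return self._step(self.model.training_step, args)

    def validation_step(self, args):
        return self._step(self.model.validation_step, args)

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    acc = ToyGPUAccelerator(ToyModel(), device=device)
    batch, batch_idx = torch.randn(8, 4), 0
    print(acc.training_step([batch, batch_idx]))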


@@ -14,7 +14,7 @@
 import io
 import os
 import re
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union
 
 import torch
 import torch.multiprocessing as mp
@@ -145,26 +145,18 @@ class TPUAccelerator(Accelerator):
         # persist info in spawn
         self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)
 
+    def _step(self, model_step: Callable, args):
+        args[0] = self.to_device(args[0])
+        return model_step(*args)
+
     def training_step(self, args):
-        batch = args[0]
-        batch = self.to_device(batch)
-        args[0] = batch
-        output = self.trainer.model.training_step(*args)
-        return output
+        return self._step(self.trainer.model.training_step, args)
 
     def validation_step(self, args):
-        batch = args[0]
-        batch = self.to_device(batch)
-        args[0] = batch
-        output = self.trainer.model.validation_step(*args)
-        return output
+        return self._step(self.trainer.model.validation_step, args)
 
     def test_step(self, args):
-        batch = args[0]
-        batch = self.to_device(batch)
-        args[0] = batch
-        output = self.trainer.model.test_step(*args)
-        return output
+        return self._step(self.trainer.model.test_step, args)
 
     def process_dataloader(self, dataloader):
         device = xm.xla_device(self.trainer.tpu_id)
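
The TPU change is the same consolidation minus the autocast branch: _step only moves args[0] onto the device before dispatching to the given step. A tiny sketch of that dispatch follows, using a CPU device as a stand-in for xm.xla_device(...) so it runs without torch_xla; toy_test_step is a hypothetical stand-in for the model's test_step.

import torch

def _step(model_step, args, device):
    # args[0] is the batch; move it to the target device, then call the step.
    args[0] = args[0].to(device)
    return model_step(*args)

def toy_test_step(batch, batch_idx):
    # Hypothetical stand-in for LightningModule.test_step.
    return {"loss": batch.float().mean()}

device = torch.device("cpu")  # stand-in; with torch_xla this would be xm.xla_device(...)
print(_step(toy_test_step, [torch.arange(6), 0], device))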