Fix unimplemented type() on TPU (#1396)

* Fix unimplemented type() on TPU

* Add changelog entry

* Add quotation marks
This commit is contained in:
Paweł Rzepiński 2020-04-07 02:29:55 +02:00 committed by GitHub
parent 9754c5da55
commit b8ff9bc1d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 7 deletions

View File

@ -71,7 +71,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed average of incomplete `TensorRunningMean` ([#1309](https://github.com/PyTorchLightning/pytorch-lightning/pull/1309))
- Fixed `WandbLogger.watch` with `wandb.init()` ([#1311](https://github.com/PyTorchLightning/pytorch-lightning/pull/1311))
- Fixed an issue with early stopping that would prevent it from monitoring training metrics when validation is disabled / not implemented ([#1235](https://github.com/PyTorchLightning/pytorch-lightning/pull/1235)).
- Fixed a bug that would cause `trainer.test()` to run on the validation set when overloading `validation_epoch_end ` and `test_end` ([#1353](https://github.com/PyTorchLightning/pytorch-lightning/pull/1353)).
- Fixed a bug that would cause `trainer.test()` to run on the validation set when overloading `validation_epoch_end` and `test_end` ([#1353](https://github.com/PyTorchLightning/pytorch-lightning/pull/1353)).
- Fixed `WandbLogger.watch` - use of the watch method without importing `wandb` ([#1311](https://github.com/PyTorchLightning/pytorch-lightning/pull/1311))
- Fixed `WandbLogger` to be used with 'ddp' - allow reinits in sub-processes ([#1149](https://github.com/PyTorchLightning/pytorch-lightning/pull/1149), [#1360](https://github.com/PyTorchLightning/pytorch-lightning/pull/1360))
- Made `training_epoch_end` behave like `validation_epoch_end` ([#1357](https://github.com/PyTorchLightning/pytorch-lightning/pull/1357))
@ -83,6 +83,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed running `on_validation_end` only on main process in DDP ([#1125](https://github.com/PyTorchLightning/pytorch-lightning/pull/1125))
- Fixes `use_amp` issue ([#1145](https://github.com/PyTorchLightning/pytorch-lightning/pull/1145))
- Fixes using deprecated `use_amp` attribute ([#1145](https://github.com/PyTorchLightning/pytorch-lightning/pull/1145))
- Fixed `Unimplemented backend XLA` error on TPU ([#1387](https://github.com/PyTorchLightning/pytorch-lightning/pull/1387))
## [0.7.1] - 2020-03-07

View File

@ -36,9 +36,9 @@ class TensorRunningMean(object):
return self.memory[self.last_idx]
def append(self, x):
# map proper type for memory if they don't match
if self.memory.type() != x.type():
self.memory.type_as(x)
# ensure same device and type
if self.memory.device != x.device or self.memory.type() != x.type():
x = x.to(self.memory)
# store without grads
with torch.no_grad():