lightning/pytorch_lightning/trainer/supporters.py


from pathlib import Path
from typing import Optional
import torch
from torch import Tensor
class TensorRunningAccum(object):
    """Tracks a running accumulation of values (min, max, mean) without graph
    references.

    Examples:
        >>> accum = TensorRunningAccum(5)
        >>> accum.last(), accum.mean()
        (None, None)
        >>> accum.append(torch.tensor(1.5))
        >>> accum.last(), accum.mean()
        (tensor(1.5000), tensor(1.5000))
        >>> accum.append(torch.tensor(2.5))
        >>> accum.last(), accum.mean()
        (tensor(2.5000), tensor(2.))
        >>> accum.reset()
        >>> _ = [accum.append(torch.tensor(i)) for i in range(13)]
        >>> accum.last(), accum.mean(), accum.min(), accum.max()
        (tensor(12.), tensor(10.), tensor(8.), tensor(12.))
    """

    def __init__(self, window_length: int):
        self.window_length = window_length
        self.memory = torch.Tensor(self.window_length)
        self.current_idx: int = 0
        self.last_idx: Optional[int] = None
        self.rotated: bool = False

    def reset(self) -> None:
"""Empty the accumulator."""
self = TensorRunningAccum(self.window_length)

    def last(self):
"""Get the last added element."""
if self.last_idx is not None:
return self.memory[self.last_idx]

    def append(self, x):
"""Add an element to the accumulator."""
# ensure same device and type
if self.memory.device != x.device or self.memory.type() != x.type():
x = x.to(self.memory)
# store without grads
with torch.no_grad():
self.memory[self.current_idx] = x
self.last_idx = self.current_idx
# increase index
self.current_idx += 1
# reset index when hit limit of tensor
self.current_idx = self.current_idx % self.window_length
if self.current_idx == 0:
self.rotated = True

    def mean(self):
"""Get mean value from stored elements."""
return self._agg_memory('mean')

    def max(self):
        """Get maximal value from stored elements."""
        return self._agg_memory('max')

    def min(self):
        """Get minimal value from stored elements."""
        return self._agg_memory('min')

    def _agg_memory(self, how: str):
        if self.last_idx is not None:
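            # before the first rotation only ``memory[:current_idx]`` holds valid
            # entries; the rest is the uninitialized storage allocated by
            # ``torch.Tensor(window_length)``, so it must be excluded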
            if self.rotated:
                return getattr(self.memory, how)()
            else:
                return getattr(self.memory[:self.current_idx], how)()


class Accumulator(object):
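    """Running sum/count accumulator.

    A minimal usage sketch (illustrative values, not from the original docs):

    >>> acc = Accumulator()
    >>> acc.accumulate(torch.tensor(2.0))
    >>> acc.accumulate(torch.tensor(4.0))
    >>> acc.mean()
    tensor(3.)
    """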

    def __init__(self):
        self.num_values = 0
        self.total = 0

    def accumulate(self, x):
        with torch.no_grad():
            self.total += x
            self.num_values += 1

    def mean(self):
        return self.total / self.num_values


class PredictionCollection(object):
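    """Collects predictions keyed by output filename, merging values across batches.

    A minimal sketch of the merge behavior (hypothetical filename and values):

    >>> pc = PredictionCollection(global_rank=0, world_size=1)
    >>> pc.add({'preds.pt': {'ids': [1, 2]}})
    >>> pc.add({'preds.pt': {'ids': [3]}})
    >>> pc.predictions['preds.pt']['ids']
    [1, 2, 3]
    """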

    def __init__(self, global_rank: int, world_size: int):
        self.global_rank = global_rank
        self.world_size = world_size
        self.predictions = {}
        self.num_predictions = 0

    def _add_prediction(self, name, values, filename):
        if filename not in self.predictions:
            self.predictions[filename] = {name: values}
        elif name not in self.predictions[filename]:
            self.predictions[filename][name] = values
        elif isinstance(values, Tensor):
            self.predictions[filename][name] = torch.cat((self.predictions[filename][name], values))
        elif isinstance(values, list):
            self.predictions[filename][name].extend(values)

    def add(self, predictions):
        if predictions is None:
            return

        for filename, pred_dict in predictions.items():
            for feature_name, values in pred_dict.items():
                self._add_prediction(feature_name, values, filename)

    def to_disk(self):
        """Write predictions to file(s)."""
        for filename, predictions in self.predictions.items():
            # absolute path to the defined prediction file; the rank is appended
            # to the name when running in a multi-GPU environment
            outfile = Path(filename).absolute()
            outfile = outfile.with_name(
                f"{outfile.stem}{f'_rank_{self.global_rank}' if self.world_size > 1 else ''}{outfile.suffix}"
            )
            outfile.parent.mkdir(exist_ok=True, parents=True)

            # convert any tensor values to lists
            predictions = {k: v if not isinstance(v, Tensor) else v.tolist() for k, v in predictions.items()}

            # check that all feature columns for this file have the same length
            feature_lens = {k: len(v) for k, v in predictions.items()}
            if len(set(feature_lens.values())) != 1:
                raise ValueError('Mismatching feature column lengths found in stored EvalResult predictions.')

            # pivot the predictions so each entry gets its own dict
            outputs = []
            for values in zip(*predictions.values()):
                output_element = {k: v for k, v in zip(predictions.keys(), values)}
                outputs.append(output_element)

            # write predictions for the current file to disk
            torch.save(outputs, outfile)
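

# A minimal end-to-end sketch of PredictionCollection (hypothetical filename,
# assuming a single process, i.e. world_size == 1):
#
#     collection = PredictionCollection(global_rank=0, world_size=1)
#     collection.add({"predictions.pt": {"ids": [0, 1], "scores": [0.2, 0.9]}})
#     collection.to_disk()  # writes a list of per-sample dicts via torch.save
#     rows = torch.load("predictions.pt")
#     # rows == [{"ids": 0, "scores": 0.2}, {"ids": 1, "scores": 0.9}]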