flake8 fixes (#3064)
* flake8 fixes
* fix pep8
* fix pep8

Co-authored-by: William Falcon <waf2107@columbia.edu>
parent 9a605642a4
commit cee5eaf659
@@ -80,9 +80,15 @@ class GpuUsageLogger(Callback):
     """

-    def __init__(self, memory_utilisation: bool = True, gpu_utilisation: bool = True,
-                 intra_step_time: bool = False, inter_step_time: bool = False,
-                 fan_speed: bool = False, temperature: bool = False):
+    def __init__(
+        self,
+        memory_utilisation: bool = True,
+        gpu_utilisation: bool = True,
+        intra_step_time: bool = False,
+        inter_step_time: bool = False,
+        fan_speed: bool = False,
+        temperature: bool = False,
+    ):
         super(GpuUsageLogger).__init__()
         self.memory_utilisation = memory_utilisation
         self.gpu_utilisation = gpu_utilisation
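A note on the unchanged `super(GpuUsageLogger).__init__()` line: `super` called with only a class produces an unbound super object, so `Callback.__init__` is never actually invoked. This commit only reflows the code; a hypothetical fix (not part of this diff) would be the zero-argument form:

    class GpuUsageLogger(Callback):
        def __init__(self, memory_utilisation: bool = True):
            # super() binds to the current instance and really runs the parent
            # initializer; equivalent to super(GpuUsageLogger, self).__init__()
            super().__init__()
            self.memory_utilisation = memory_utilisation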
@@ -102,9 +108,10 @@ class GpuUsageLogger(Callback):
         if self.inter_step_time:
             # First log at beginning of second step
             if self.snap_inter_step_time:
-                trainer.logger.log_metrics({'Batch_Time/inter_step (ms)':
-                                            (time.time() - self.snap_inter_step_time) * 1000},
-                                           step=trainer.global_step)
+                trainer.logger.log_metrics(
+                    {'Batch_Time/inter_step (ms)': (time.time() - self.snap_inter_step_time) * 1000},
+                    step=trainer.global_step,
+                )
         if self.intra_step_time:
             self.snap_intra_step_time = time.time()
@@ -125,9 +132,10 @@ class GpuUsageLogger(Callback):

         if self.intra_step_time:
             if self.snap_intra_step_time:
-                trainer.logger.log_metrics({'Batch_Time/intra_step (ms)':
-                                            (time.time() - self.snap_intra_step_time) * 1000},
-                                           step=trainer.global_step)
+                trainer.logger.log_metrics(
+                    {'Batch_Time/intra_step (ms)': (time.time() - self.snap_intra_step_time) * 1000},
+                    step=trainer.global_step,
+                )

     def on_train_epoch_start(self, trainer, pl_module):
         self.snap_intra_step_time = None
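The two timers above interlock across hooks: the intra-step snapshot is taken when a batch starts and read when it ends, while the inter-step snapshot is taken when a batch ends and read when the next one starts. A minimal standalone sketch of that pattern (hook and metric names follow the diff; the logger is abstracted to a plain callable):

    import time

    class StepTimer:
        """Sketch of the intra-/inter-step timing used by GpuUsageLogger."""

        def __init__(self):
            self.snap_intra_step_time = None  # set at batch start, read at batch end
            self.snap_inter_step_time = None  # set at batch end, read at next start

        def on_train_batch_start(self, log):
            if self.snap_inter_step_time is not None:
                log('Batch_Time/inter_step (ms)', (time.time() - self.snap_inter_step_time) * 1000)
            self.snap_intra_step_time = time.time()

        def on_train_batch_end(self, log):
            if self.snap_intra_step_time is not None:
                log('Batch_Time/intra_step (ms)', (time.time() - self.snap_intra_step_time) * 1000)
            self.snap_inter_step_time = time.time()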
@@ -135,10 +143,13 @@ class GpuUsageLogger(Callback):

     @staticmethod
     def _get_gpu_stat(pitem: str, unit: str):
-        result = subprocess.run(["/bin/nvidia-smi", f"--query-gpu={pitem}", "--format=csv,nounits,noheader"],
-                                encoding="utf-8", stdout=subprocess.PIPE,
-                                stderr=subprocess.PIPE,  # for backward compatibility with python version 3.6
-                                check=True)
+        result = subprocess.run(
+            ["/bin/nvidia-smi", f"--query-gpu={pitem}", "--format=csv,nounits,noheader"],
+            encoding="utf-8",
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,  # for backward compatibility with python version 3.6
+            check=True,
+        )
         try:
             gpu_usage = [float(x) for x in result.stdout.strip().split(os.linesep)]
         except ValueError:
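The inline comment refers to `subprocess.run(capture_output=True)` only existing from Python 3.7, so both streams are passed explicitly for 3.6 compatibility. A self-contained sketch of the same query (assuming `nvidia-smi` is on `PATH` rather than hard-coded at `/bin/nvidia-smi`):

    import subprocess

    def query_gpu_stat(field: str) -> list:
        # nvidia-smi prints one CSV value per GPU, e.g. field='memory.used' -> '1024\n2048'
        result = subprocess.run(
            ['nvidia-smi', f'--query-gpu={field}', '--format=csv,nounits,noheader'],
            encoding='utf-8',
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,  # capture_output=True would need Python >= 3.7
            check=True,
        )
        return [float(x) for x in result.stdout.strip().splitlines()]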
@@ -152,4 +163,4 @@ class GpuUsageLogger(Callback):

     def _log_memory(self, trainer):
         trainer.logger.log_metrics(self._get_gpu_stat("memory.used", "MB"), step=trainer.global_step)
         trainer.logger.log_metrics(self._get_gpu_stat("memory.free", "MB"), step=trainer.global_step)
-        trainer.logger.log_metrics(self._get_gpu_stat("utilization.memory", "%"), step=trainer.global_step)
+        trainer.logger.log_metrics(self._get_gpu_stat("utilization.memory", "%"), step=trainer.global_step)
@@ -222,7 +222,7 @@ class ModelSummary(object):
         input_ = model.transfer_batch_to_device(input_, model.device)

         if trainer is not None and trainer.amp_backend == AMPType.NATIVE and not trainer.use_tpu:
-            model.forward = torch.cuda.amp.autocast()(model.forward)
+            model.forward = torch.cuda.amp.autocast()(model.forward)

         mode = model.training
         model.eval()
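Context for the line this hunk touches: `torch.cuda.amp.autocast()` works as a decorator as well as a context manager, so re-binding `model.forward` makes the summary's dummy forward pass run under native AMP. A rough sketch of the idiom in isolation (assumes a CUDA-capable model; not the library's own code):

    import torch

    def summary_forward(model, example_input):
        # autocast() as a decorator: the wrapped forward runs in mixed precision
        model.forward = torch.cuda.amp.autocast()(model.forward)
        with torch.no_grad():
            return model(example_input)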
@@ -24,13 +24,12 @@ from pytorch_lightning.metrics.converters import _sync_ddp_if_available


 class Result(Dict):

     def __init__(
-            self,
-            minimize: Optional[Tensor] = None,
-            early_stop_on: Optional[Tensor] = None,
-            checkpoint_on: Union[Tensor, bool, None] = None,
-            hiddens: Optional[Tensor] = None,
+        self,
+        minimize: Optional[Tensor] = None,
+        early_stop_on: Optional[Tensor] = None,
+        checkpoint_on: Union[Tensor, bool, None] = None,
+        hiddens: Optional[Tensor] = None,
     ):

         super().__init__()
@@ -52,12 +51,7 @@ class Result(Dict):
         if minimize is not None and checkpoint_on is None:
             self.checkpoint_on = minimize.detach()

-        self['meta'] = {
-            '_internal': {
-                '_reduce_on_epoch': False,
-                'batch_sizes': []
-            }
-        }
+        self['meta'] = {'_internal': {'_reduce_on_epoch': False, 'batch_sizes': []}}

     def __getitem__(self, key: Union[str, Any]) -> Any:
         try:
@@ -109,20 +103,20 @@ class Result(Dict):
         assert x.grad_fn is not None, m

     def log(
-            self,
-            name: str,
-            value: Any,
-            prog_bar: bool = False,
-            logger: bool = True,
-            on_step: bool = False,
-            on_epoch: bool = True,
-            reduce_fx: Callable = torch.mean,
-            tbptt_reduce_fx: Callable = torch.mean,
-            tbptt_pad_token: int = 0,
-            enable_graph: bool = False,
-            sync_dist: bool = False,
-            sync_dist_op: Union[Any, str] = 'mean',
-            sync_dist_group: Optional[Any] = None
+        self,
+        name: str,
+        value: Any,
+        prog_bar: bool = False,
+        logger: bool = True,
+        on_step: bool = False,
+        on_epoch: bool = True,
+        reduce_fx: Callable = torch.mean,
+        tbptt_reduce_fx: Callable = torch.mean,
+        tbptt_pad_token: int = 0,
+        enable_graph: bool = False,
+        sync_dist: bool = False,
+        sync_dist_op: Union[Any, str] = 'mean',
+        sync_dist_group: Optional[Any] = None,
     ):
         # no metrics should be logged with graphs
         if not enable_graph and isinstance(value, torch.Tensor):
@@ -140,37 +134,60 @@ class Result(Dict):
         if on_step and on_epoch:
             # set step version
             step_name = f'step_{name}'
-            self.__set_meta(step_name, value, prog_bar, logger,
-                            on_step=True, on_epoch=False,
-                            reduce_fx=reduce_fx, tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token)
+            self.__set_meta(
+                step_name,
+                value,
+                prog_bar,
+                logger,
+                on_step=True,
+                on_epoch=False,
+                reduce_fx=reduce_fx,
+                tbptt_reduce_fx=tbptt_reduce_fx,
+                tbptt_pad_token=tbptt_pad_token,
+            )
             self.__setitem__(step_name, value)

             # set epoch version
             epoch_name = f'epoch_{name}'
-            self.__set_meta(epoch_name, value, prog_bar, logger, on_step=False, on_epoch=True,
-                            reduce_fx=reduce_fx, tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token)
+            self.__set_meta(
+                epoch_name,
+                value,
+                prog_bar,
+                logger,
+                on_step=False,
+                on_epoch=True,
+                reduce_fx=reduce_fx,
+                tbptt_reduce_fx=tbptt_reduce_fx,
+                tbptt_pad_token=tbptt_pad_token,
+            )
             self.__setitem__(epoch_name, value)
         else:
-            self.__set_meta(name, value,
-                            prog_bar, logger,
-                            on_step, on_epoch,
-                            reduce_fx,
-                            tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token)
+            self.__set_meta(
+                name,
+                value,
+                prog_bar,
+                logger,
+                on_step,
+                on_epoch,
+                reduce_fx,
+                tbptt_reduce_fx=tbptt_reduce_fx,
+                tbptt_pad_token=tbptt_pad_token,
+            )

         # set the value
         self.__setitem__(name, value)

     def __set_meta(
-            self,
-            name: str,
-            value: Any,
-            prog_bar: bool,
-            logger: bool,
-            on_step: bool,
-            on_epoch: bool,
-            reduce_fx: Callable,
-            tbptt_pad_token: int,
-            tbptt_reduce_fx: Callable
+        self,
+        name: str,
+        value: Any,
+        prog_bar: bool,
+        logger: bool,
+        on_step: bool,
+        on_epoch: bool,
+        reduce_fx: Callable,
+        tbptt_pad_token: int,
+        tbptt_reduce_fx: Callable,
     ):
         # set the meta for the item
         meta_value = value
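The branch reflowed above is also the behavioural summary of `Result.log`: with both `on_step` and `on_epoch` set, the value is registered twice under `step_<name>` and `epoch_<name>` with complementary flags, and once under the plain name. A rough usage sketch (import path assumed from the 0.9-era module layout shown in this diff):

    import torch
    from pytorch_lightning.core.step_result import Result

    result = Result()
    result.log('loss', torch.tensor(0.5), on_step=True, on_epoch=True)
    # the step and epoch aliases are now present alongside the plain key
    assert 'step_loss' in result and 'epoch_loss' in result and 'loss' in result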
@@ -182,7 +199,7 @@ class Result(Dict):
             reduce_fx=reduce_fx,
             value=meta_value,
             tbptt_reduce_fx=tbptt_reduce_fx,
-            tbptt_pad_token=tbptt_pad_token
+            tbptt_pad_token=tbptt_pad_token,
         )

         self['meta'][name] = meta
@@ -200,10 +217,7 @@ class Result(Dict):
         return torch.tensor(meta['_internal']['batch_sizes'])

     def get_callback_metrics(self) -> dict:
-        result = {
-            'early_stop_on': self.early_stop_on,
-            'checkpoint_on': self.checkpoint_on
-        }
+        result = {'early_stop_on': self.early_stop_on, 'checkpoint_on': self.checkpoint_on}

         return result
@@ -342,7 +356,6 @@ class Result(Dict):
         result = recursive_gather(outputs, result)
         recursive_stack(result)

-
         for k, option in meta.items():
             if k == '_internal':
                 continue
@@ -457,13 +470,12 @@ def collate_tensors(items: Union[List, Tuple]) -> Union[Tensor, List, Tuple]:


 class TrainResult(Result):

     def __init__(
-            self,
-            minimize: Optional[Tensor] = None,
-            early_stop_on: Tensor = None,
-            checkpoint_on: Union[Tensor, bool] = None,
-            hiddens: Optional[Tensor] = None,
+        self,
+        minimize: Optional[Tensor] = None,
+        early_stop_on: Tensor = None,
+        checkpoint_on: Union[Tensor, bool] = None,
+        hiddens: Optional[Tensor] = None,
     ):
         """
         Used in train loop to auto-log to a logger or progress bar without needing to define
@@ -493,20 +505,20 @@ class TrainResult(Result):
         super().__init__(minimize, early_stop_on, checkpoint_on, hiddens)

     def log(
-            self,
-            name,
-            value,
-            prog_bar: bool = False,
-            logger: bool = True,
-            on_step: bool = True,
-            on_epoch: bool = False,
-            reduce_fx: Callable = torch.mean,
-            tbptt_reduce_fx: Callable = torch.mean,
-            tbptt_pad_token: int = 0,
-            enable_graph: bool = False,
-            sync_dist: bool = False,
-            sync_dist_op: Union[Any, str] = 'mean',
-            sync_dist_group: Optional[Any] = None
+        self,
+        name,
+        value,
+        prog_bar: bool = False,
+        logger: bool = True,
+        on_step: bool = True,
+        on_epoch: bool = False,
+        reduce_fx: Callable = torch.mean,
+        tbptt_reduce_fx: Callable = torch.mean,
+        tbptt_pad_token: int = 0,
+        enable_graph: bool = False,
+        sync_dist: bool = False,
+        sync_dist_op: Union[Any, str] = 'mean',
+        sync_dist_group: Optional[Any] = None,
     ):
         """
         Log a key, value
@@ -543,34 +555,36 @@ class TrainResult(Result):
             sync_dist_op: the op to sync across
             sync_dist_group: the ddp group
         """
-        super().log(name=name,
-                    value=value,
-                    prog_bar=prog_bar,
-                    logger=logger,
-                    on_step=on_step,
-                    on_epoch=on_epoch,
-                    reduce_fx=reduce_fx,
-                    enable_graph=enable_graph,
-                    sync_dist=sync_dist,
-                    sync_dist_group=sync_dist_group,
-                    sync_dist_op=sync_dist_op,
-                    tbptt_pad_token=tbptt_pad_token,
-                    tbptt_reduce_fx=tbptt_reduce_fx)
+        super().log(
+            name=name,
+            value=value,
+            prog_bar=prog_bar,
+            logger=logger,
+            on_step=on_step,
+            on_epoch=on_epoch,
+            reduce_fx=reduce_fx,
+            enable_graph=enable_graph,
+            sync_dist=sync_dist,
+            sync_dist_group=sync_dist_group,
+            sync_dist_op=sync_dist_op,
+            tbptt_pad_token=tbptt_pad_token,
+            tbptt_reduce_fx=tbptt_reduce_fx,
+        )

     def log_dict(
-            self,
-            dictionary: dict,
-            prog_bar: bool = False,
-            logger: bool = True,
-            on_step: bool = False,
-            on_epoch: bool = True,
-            reduce_fx: Callable = torch.mean,
-            tbptt_reduce_fx: Callable = torch.mean,
-            tbptt_pad_token: int = 0,
-            enable_graph: bool = False,
-            sync_dist: bool = False,
-            sync_dist_op: Union[Any, str] = 'mean',
-            sync_dist_group: Optional[Any] = None
+        self,
+        dictionary: dict,
+        prog_bar: bool = False,
+        logger: bool = True,
+        on_step: bool = False,
+        on_epoch: bool = True,
+        reduce_fx: Callable = torch.mean,
+        tbptt_reduce_fx: Callable = torch.mean,
+        tbptt_pad_token: int = 0,
+        enable_graph: bool = False,
+        sync_dist: bool = False,
+        sync_dist_op: Union[Any, str] = 'mean',
+        sync_dist_group: Optional[Any] = None,
     ):
         """
         Log a dictonary of values at once
@@ -595,28 +609,29 @@ class TrainResult(Result):
             sync_dist_group: the ddp group:
         """
         for k, v in dictionary.items():
-            self.log(name=k,
-                     value=v,
-                     prog_bar=prog_bar,
-                     logger=logger,
-                     on_step=on_step,
-                     on_epoch=on_epoch,
-                     reduce_fx=reduce_fx,
-                     enable_graph=enable_graph,
-                     sync_dist=sync_dist,
-                     sync_dist_group=sync_dist_group,
-                     sync_dist_op=sync_dist_op,
-                     tbptt_pad_token=tbptt_pad_token,
-                     tbptt_reduce_fx=tbptt_reduce_fx)
+            self.log(
+                name=k,
+                value=v,
+                prog_bar=prog_bar,
+                logger=logger,
+                on_step=on_step,
+                on_epoch=on_epoch,
+                reduce_fx=reduce_fx,
+                enable_graph=enable_graph,
+                sync_dist=sync_dist,
+                sync_dist_group=sync_dist_group,
+                sync_dist_op=sync_dist_op,
+                tbptt_pad_token=tbptt_pad_token,
+                tbptt_reduce_fx=tbptt_reduce_fx,
+            )


 class EvalResult(Result):

     def __init__(
-            self,
-            early_stop_on: Optional[Tensor] = None,
-            checkpoint_on: Optional[Tensor] = None,
-            hiddens: Optional[Tensor] = None,
+        self,
+        early_stop_on: Optional[Tensor] = None,
+        checkpoint_on: Optional[Tensor] = None,
+        hiddens: Optional[Tensor] = None,
     ):
         """
         Used in val/train loop to auto-log to a logger or progress bar without needing to define
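`log_dict` is a thin loop that forwards every key/value pair to `self.log` with the same flags, so one call fans out to one `log` per key. A hedged usage sketch (values are placeholders):

    import torch
    from pytorch_lightning.core.step_result import EvalResult

    loss, acc = torch.tensor(0.25), torch.tensor(0.9)
    result = EvalResult()
    # fans out to result.log('val_loss', ...) and result.log('val_acc', ...)
    result.log_dict({'val_loss': loss, 'val_acc': acc}, prog_bar=True)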
@@ -645,20 +660,20 @@ class EvalResult(Result):
         super().__init__(None, early_stop_on, checkpoint_on, hiddens)

     def log(
-            self,
-            name,
-            value,
-            prog_bar: bool = False,
-            logger: bool = True,
-            on_step: bool = False,
-            on_epoch: bool = True,
-            reduce_fx: Callable = torch.mean,
-            tbptt_reduce_fx: Callable = torch.mean,
-            tbptt_pad_token: int = 0,
-            enable_graph: bool = False,
-            sync_dist: bool = False,
-            sync_dist_op: Union[Any, str] = 'mean',
-            sync_dist_group: Optional[Any] = None
+        self,
+        name,
+        value,
+        prog_bar: bool = False,
+        logger: bool = True,
+        on_step: bool = False,
+        on_epoch: bool = True,
+        reduce_fx: Callable = torch.mean,
+        tbptt_reduce_fx: Callable = torch.mean,
+        tbptt_pad_token: int = 0,
+        enable_graph: bool = False,
+        sync_dist: bool = False,
+        sync_dist_op: Union[Any, str] = 'mean',
+        sync_dist_group: Optional[Any] = None,
     ):
         """
         Log a key, value
@@ -694,34 +709,36 @@ class EvalResult(Result):
             sync_dist_op: the op to sync across
             sync_dist_group: the ddp group
         """
-        super().log(name=name,
-                    value=value,
-                    prog_bar=prog_bar,
-                    logger=logger,
-                    on_step=on_step,
-                    on_epoch=on_epoch,
-                    reduce_fx=reduce_fx,
-                    enable_graph=enable_graph,
-                    sync_dist=sync_dist,
-                    sync_dist_group=sync_dist_group,
-                    sync_dist_op=sync_dist_op,
-                    tbptt_pad_token=tbptt_pad_token,
-                    tbptt_reduce_fx=tbptt_reduce_fx)
+        super().log(
+            name=name,
+            value=value,
+            prog_bar=prog_bar,
+            logger=logger,
+            on_step=on_step,
+            on_epoch=on_epoch,
+            reduce_fx=reduce_fx,
+            enable_graph=enable_graph,
+            sync_dist=sync_dist,
+            sync_dist_group=sync_dist_group,
+            sync_dist_op=sync_dist_op,
+            tbptt_pad_token=tbptt_pad_token,
+            tbptt_reduce_fx=tbptt_reduce_fx,
+        )

     def log_dict(
-            self,
-            dictionary: dict,
-            prog_bar: bool = False,
-            logger: bool = True,
-            on_step: bool = False,
-            on_epoch: bool = True,
-            reduce_fx: Callable = torch.mean,
-            tbptt_reduce_fx: Callable = torch.mean,
-            tbptt_pad_token: int = 0,
-            enable_graph: bool = False,
-            sync_dist: bool = False,
-            sync_dist_op: Union[Any, str] = 'mean',
-            sync_dist_group: Optional[Any] = None
+        self,
+        dictionary: dict,
+        prog_bar: bool = False,
+        logger: bool = True,
+        on_step: bool = False,
+        on_epoch: bool = True,
+        reduce_fx: Callable = torch.mean,
+        tbptt_reduce_fx: Callable = torch.mean,
+        tbptt_pad_token: int = 0,
+        enable_graph: bool = False,
+        sync_dist: bool = False,
+        sync_dist_op: Union[Any, str] = 'mean',
+        sync_dist_group: Optional[Any] = None,
     ):
         """
         Log a dictonary of values at once
@@ -746,25 +763,24 @@ class EvalResult(Result):
             sync_dist_group: the ddp group
         """
         for k, v in dictionary.items():
-            self.log(name=k,
-                     value=v,
-                     prog_bar=prog_bar,
-                     logger=logger,
-                     on_step=on_step,
-                     on_epoch=on_epoch,
-                     reduce_fx=reduce_fx,
-                     enable_graph=enable_graph,
-                     sync_dist=sync_dist,
-                     sync_dist_group=sync_dist_group,
-                     sync_dist_op=sync_dist_op,
-                     tbptt_pad_token=tbptt_pad_token,
-                     tbptt_reduce_fx=tbptt_reduce_fx)
+            self.log(
+                name=k,
+                value=v,
+                prog_bar=prog_bar,
+                logger=logger,
+                on_step=on_step,
+                on_epoch=on_epoch,
+                reduce_fx=reduce_fx,
+                enable_graph=enable_graph,
+                sync_dist=sync_dist,
+                sync_dist_group=sync_dist_group,
+                sync_dist_op=sync_dist_op,
+                tbptt_pad_token=tbptt_pad_token,
+                tbptt_reduce_fx=tbptt_reduce_fx,
+            )

     def get_callback_metrics(self) -> dict:
-        result = {
-            'val_early_stop_on': self.early_stop_on,
-            'val_checkpoint_on': self.checkpoint_on
-        }
+        result = {'val_early_stop_on': self.early_stop_on, 'val_checkpoint_on': self.checkpoint_on}

         return result
@@ -255,7 +255,7 @@ class TrainerEvaluationLoopMixin(ABC):
         model: LightningModule,
         dataloaders: List[DataLoader],
         max_batches: Union[int, List[int]],
-        test_mode: bool = False
+        test_mode: bool = False,
     ):
         """Run evaluation code.
|
|||
if self.is_overridden('on_test_batch_start'):
|
||||
model_ref = self.get_model()
|
||||
with self.profiler.profile('on_test_batch_start'):
|
||||
model_ref.on_test_batch_start(output)
|
||||
model_ref.on_test_batch_start(batch, batch_idx, dataloader_idx)
|
||||
else:
|
||||
self.on_validation_batch_start(batch, batch_idx, dataloader_idx)
|
||||
if self.is_overridden('on_validation_batch_start'):
|
||||
model_ref = self.get_model()
|
||||
with self.profiler.profile('on_validation_batch_start'):
|
||||
model_ref.on_validation_batch_start(output)
|
||||
model_ref.on_validation_batch_start(batch, batch_idx, dataloader_idx)
|
||||
# -----------------
|
||||
# RUN EVALUATION STEP
|
||||
# -----------------
|
||||
|
@@ -364,13 +364,13 @@ class TrainerEvaluationLoopMixin(ABC):
                 if self.is_overridden('on_test_batch_end'):
                     model_ref = self.get_model()
                     with self.profiler.profile('on_test_batch_end'):
-                        model_ref.on_test_batch_end(output)
+                        model_ref.on_test_batch_end(batch, batch_idx, dataloader_idx)
             else:
                 self.on_validation_batch_end(batch, batch_idx, dataloader_idx)
                 if self.is_overridden('on_validation_batch_end'):
                     model_ref = self.get_model()
                     with self.profiler.profile('on_validation_batch_end'):
-                        model_ref.on_validation_batch_end(output)
+                        model_ref.on_validation_batch_end(batch, batch_idx, dataloader_idx)

             # track outputs for collation
             if output is not None:
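Unlike the surrounding formatting-only hunks, these two change what the trainer passes to the hooks: the `on_test_batch_start/end` and `on_validation_batch_start/end` model hooks now receive `(batch, batch_idx, dataloader_idx)` instead of the step `output`. A sketch of a user override matching the new call sites (hypothetical module, not from this commit):

    from pytorch_lightning import LightningModule

    class MyModule(LightningModule):
        def on_validation_batch_end(self, batch, batch_idx, dataloader_idx):
            # previously this hook was called with only the step output
            if batch_idx % 100 == 0:
                print(f'finished val batch {batch_idx} of dataloader {dataloader_idx}')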
@@ -456,8 +456,11 @@ class TrainerEvaluationLoopMixin(ABC):

                 eval_results = model.test_end(eval_results)
                 user_reduced = True
-                rank_zero_warn('Method `test_end` was deprecated in v0.7 and will be removed in v1.0.'
-                               ' Use `test_epoch_end` instead.', DeprecationWarning)
+                rank_zero_warn(
+                    'Method `test_end` was deprecated in v0.7 and will be removed in v1.0.'
+                    ' Use `test_epoch_end` instead.',
+                    DeprecationWarning,
+                )

             elif self.is_overridden('test_epoch_end', model=model):
                 if using_eval_result:
|
|||
|
||||
eval_results = model.validation_end(eval_results)
|
||||
user_reduced = True
|
||||
rank_zero_warn('Method `validation_end` was deprecated in v0.7 and will be removed in v1.0.'
|
||||
' Use `validation_epoch_end` instead.', DeprecationWarning)
|
||||
rank_zero_warn(
|
||||
'Method `validation_end` was deprecated in v0.7 and will be removed in v1.0.'
|
||||
' Use `validation_epoch_end` instead.',
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
elif self.is_overridden('validation_epoch_end', model=model):
|
||||
if using_eval_result:
|
||||
|
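Both hunks reflow the same deprecation idiom: adjacent string literals concatenate implicitly, and the warning category rides along as the second positional argument. The same pattern with the stdlib `warnings` module standing in for Lightning's `rank_zero_warn`:

    import warnings

    def warn_validation_end_deprecated():
        warnings.warn(
            'Method `validation_end` was deprecated in v0.7 and will be removed in v1.0.'
            ' Use `validation_epoch_end` instead.',  # implicit literal concatenation
            DeprecationWarning,
        )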
@@ -647,8 +653,7 @@ class TrainerEvaluationLoopMixin(ABC):
         # make dataloader_idx arg in validation_step optional
         args = [batch, batch_idx]

-        if (test_mode and len(self.test_dataloaders) > 1) \
-                or (not test_mode and len(self.val_dataloaders) > 1):
+        if (test_mode and len(self.test_dataloaders) > 1) or (not test_mode and len(self.val_dataloaders) > 1):
             args.append(dataloader_idx)

         # handle DP, DDP forward
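This hunk joins a backslash continuation onto one line; PEP 8 prefers parentheses over backslashes for wrapping conditions. Had the joined line still exceeded the length limit, a parenthesized split would be the usual flake8-friendly alternative (a sketch, not what the commit chose):

    def needs_dataloader_idx(test_mode: bool, test_dataloaders, val_dataloaders) -> bool:
        # parentheses, not backslashes, carry the line break
        return (
            (test_mode and len(test_dataloaders) > 1)
            or (not test_mode and len(val_dataloaders) > 1)
        )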
@@ -33,20 +33,15 @@ class TestStepVariations(ABC):

         # alternate possible outputs to test
         if batch_idx % 1 == 0:
-            output = OrderedDict({
-                'test_loss': loss_test,
-                'test_acc': test_acc,
-            })
+            output = OrderedDict({'test_loss': loss_test, 'test_acc': test_acc})
             return output
         if batch_idx % 2 == 0:
             return test_acc

         if batch_idx % 3 == 0:
-            output = OrderedDict({
-                'test_loss': loss_test,
-                'test_acc': test_acc,
-                'test_dic': {'test_loss_a': loss_test}
-            })
+            output = OrderedDict({'test_loss': loss_test,
+                                  'test_acc': test_acc,
+                                  'test_dic': {'test_loss_a': loss_test}})
             return output

     def test_step_result_obj(self, batch, batch_idx, *args, **kwargs):
@@ -71,19 +66,13 @@ class TestStepVariations(ABC):
         result = EvalResult()
         # alternate possible outputs to test
         if batch_idx % 1 == 0:
-            result.log_dict({
-                'test_loss': loss_test,
-                'test_acc': test_acc,
-            })
+            result.log_dict({'test_loss': loss_test, 'test_acc': test_acc})
             return result
         if batch_idx % 2 == 0:
             return test_acc

         if batch_idx % 3 == 0:
-            result.log_dict({
-                'test_loss': loss_test,
-                'test_acc': test_acc,
-            })
+            result.log_dict({'test_loss': loss_test, 'test_acc': test_acc})
             result.test_dic = {'test_loss_a': loss_test}
             return result
@@ -108,10 +97,7 @@ class TestStepVariations(ABC):

         # alternate possible outputs to test
         if batch_idx % 1 == 0:
-            output = OrderedDict({
-                'test_loss': loss_test,
-                'test_acc': test_acc,
-            })
+            output = OrderedDict({'test_loss': loss_test, 'test_acc': test_acc})
             return output
         if batch_idx % 2 == 0:
             return test_acc
@@ -124,16 +110,12 @@ class TestStepVariations(ABC):
             })
             return output
         if batch_idx % 5 == 0:
-            output = OrderedDict({
-                f'test_loss_{dataloader_idx}': loss_test,
-                f'test_acc_{dataloader_idx}': test_acc,
-            })
+            output = OrderedDict({f'test_loss_{dataloader_idx}': loss_test, f'test_acc_{dataloader_idx}': test_acc})
             return output

     def test_step__empty(self, batch, batch_idx, *args, **kwargs):
         return {}

     def test_step_result_preds(self, batch, batch_idx, optimizer_idx=None):
         x, y = batch
         x = x.view(x.size(0), -1)
@@ -174,13 +156,13 @@ class TestStepVariations(ABC):
         elif option == 1:
             result.write('idxs', torch.cat((lazy_ids, lazy_ids)), prediction_file)
             result.write('preds', labels_hat, prediction_file)

         # write multi-dimension
         elif option == 2:
             result.write('idxs', lazy_ids, prediction_file)
             result.write('preds', labels_hat, prediction_file)
             result.write('x', x, prediction_file)

         # write str list
         elif option == 3:
             result.write('idxs', lazy_ids, prediction_file)