Update auto_encoder.py to accomodate torchvision breaking change (#18996)

Co-authored-by: thomas <thomas@thomass-MacBook-Pro.local>
This commit is contained in:
thomas chaton 2023-11-13 15:32:16 -05:00 committed by GitHub
parent 7288302186
commit cb23fc2dd4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 4 deletions

View File

@ -44,7 +44,7 @@ class ImageSampler(callbacks.Callback):
nrow: int = 8, nrow: int = 8,
padding: int = 2, padding: int = 2,
normalize: bool = True, normalize: bool = True,
norm_range: Optional[Tuple[int, int]] = None, value_range: Optional[Tuple[int, int]] = None,
scale_each: bool = False, scale_each: bool = False,
pad_value: int = 0, pad_value: int = 0,
) -> None: ) -> None:
@ -56,7 +56,7 @@ class ImageSampler(callbacks.Callback):
padding: Amount of padding. Default: ``2``. padding: Amount of padding. Default: ``2``.
normalize: If ``True``, shift the image to the range (0, 1), normalize: If ``True``, shift the image to the range (0, 1),
by the min and max values specified by :attr:`range`. Default: ``False``. by the min and max values specified by :attr:`range`. Default: ``False``.
norm_range: Tuple (min, max) where min and max are numbers, value_range: Tuple (min, max) where min and max are numbers,
then these numbers are used to normalize the image. By default, min and max then these numbers are used to normalize the image. By default, min and max
are computed from the tensor. are computed from the tensor.
scale_each: If ``True``, scale each image in the batch of scale_each: If ``True``, scale each image in the batch of
@ -71,7 +71,7 @@ class ImageSampler(callbacks.Callback):
self.nrow = nrow self.nrow = nrow
self.padding = padding self.padding = padding
self.normalize = normalize self.normalize = normalize
self.norm_range = norm_range self.value_range = value_range
self.scale_each = scale_each self.scale_each = scale_each
self.pad_value = pad_value self.pad_value = pad_value
@ -81,7 +81,7 @@ class ImageSampler(callbacks.Callback):
nrow=self.nrow, nrow=self.nrow,
padding=self.padding, padding=self.padding,
normalize=self.normalize, normalize=self.normalize,
value_range=self.norm_range, value_range=self.value_range,
scale_each=self.scale_each, scale_each=self.scale_each,
pad_value=self.pad_value, pad_value=self.pad_value,
) )