New Year ruff cleanup. (#11595)
This commit is contained in:
@@ -154,7 +154,8 @@ class TAEHV(nn.Module):
|
||||
self._show_progress_bar = value
|
||||
|
||||
def encode(self, x, **kwargs):
|
||||
if self.patch_size > 1: x = F.pixel_unshuffle(x, self.patch_size)
|
||||
if self.patch_size > 1:
|
||||
x = F.pixel_unshuffle(x, self.patch_size)
|
||||
x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
||||
if x.shape[1] % 4 != 0:
|
||||
# pad at end to multiple of 4
|
||||
@@ -167,5 +168,6 @@ class TAEHV(nn.Module):
|
||||
def decode(self, x, **kwargs):
|
||||
x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
||||
x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar)
|
||||
if self.patch_size > 1: x = F.pixel_shuffle(x, self.patch_size)
|
||||
if self.patch_size > 1:
|
||||
x = F.pixel_shuffle(x, self.patch_size)
|
||||
return x[:, self.frames_to_trim:].movedim(2, 1)
|
||||
|
||||
Reference in New Issue
Block a user