WIP way to support multi multi dimensional latents. (#10456)
This commit is contained in:
@@ -1106,3 +1106,25 @@ def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out):
|
||||
dim=1
|
||||
)
|
||||
return out
|
||||
|
||||
def pack_latents(latents):
|
||||
latent_shapes = []
|
||||
tensors = []
|
||||
for tensor in latents:
|
||||
latent_shapes.append(tensor.shape)
|
||||
tensors.append(tensor.reshape(tensor.shape[0], 1, -1))
|
||||
|
||||
latent = torch.cat(tensors, dim=-1)
|
||||
return latent, latent_shapes
|
||||
|
||||
def unpack_latents(combined_latent, latent_shapes):
|
||||
if len(latent_shapes) > 1:
|
||||
output_tensors = []
|
||||
for shape in latent_shapes:
|
||||
cut = math.prod(shape[1:])
|
||||
tens = combined_latent[:, :, :cut]
|
||||
combined_latent = combined_latent[:, :, cut:]
|
||||
output_tensors.append(tens.reshape([tens.shape[0]] + list(shape)[1:]))
|
||||
else:
|
||||
output_tensors = combined_latent
|
||||
return output_tensors
|
||||
|
||||
Reference in New Issue
Block a user