Support the LTXV 2 model. (#11632)

This commit is contained in:
comfyanonymous
2026-01-04 22:58:59 -08:00
committed by GitHub
parent 38d0493825
commit f2b002372b
23 changed files with 4214 additions and 185 deletions

View File

@@ -1,13 +1,47 @@
from abc import ABC, abstractmethod
from enum import Enum
import functools
import math
from typing import Dict, Optional, Tuple
from einops import rearrange
import numpy as np
import torch
from torch import nn
import comfy.patcher_extension
import comfy.ldm.modules.attention
import comfy.ldm.common_dit
import math
from typing import Dict, Optional, Tuple
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
from comfy.ldm.flux.math import apply_rope1
def _log_base(x, base):
return np.log(x) / np.log(base)
class LTXRopeType(str, Enum):
INTERLEAVED = "interleaved"
SPLIT = "split"
KEY = "rope_type"
@classmethod
def from_dict(cls, kwargs, default=None):
if default is None:
default = cls.INTERLEAVED
return cls(kwargs.get(cls.KEY, default))
class LTXFrequenciesPrecision(str, Enum):
FLOAT32 = "float32"
FLOAT64 = "float64"
KEY = "frequencies_precision"
@classmethod
def from_dict(cls, kwargs, default=None):
if default is None:
default = cls.FLOAT32
return cls(kwargs.get(cls.KEY, default))
def get_timestep_embedding(
timesteps: torch.Tensor,
@@ -39,9 +73,7 @@ def get_timestep_embedding(
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
)
exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent)
@@ -73,7 +105,9 @@ class TimestepEmbedding(nn.Module):
post_act_fn: Optional[str] = None,
cond_proj_dim=None,
sample_proj_bias=True,
dtype=None, device=None, operations=None,
dtype=None,
device=None,
operations=None,
):
super().__init__()
@@ -90,7 +124,9 @@ class TimestepEmbedding(nn.Module):
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device)
self.linear_2 = operations.Linear(
time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device
)
if post_act_fn is None:
self.post_act = None
@@ -139,12 +175,22 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
"""
def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
def __init__(
self,
embedding_dim,
size_emb_dim,
use_additional_conditions: bool = False,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.outdim = size_emb_dim
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations)
self.timestep_embedder = TimestepEmbedding(
in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations
)
def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
timesteps_proj = self.time_proj(timestep)
@@ -163,15 +209,22 @@ class AdaLayerNormSingle(nn.Module):
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
"""
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
def __init__(
self, embedding_dim: int, embedding_coefficient: int = 6, use_additional_conditions: bool = False, dtype=None, device=None, operations=None
):
super().__init__()
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions, dtype=dtype, device=device, operations=operations
embedding_dim,
size_emb_dim=embedding_dim // 3,
use_additional_conditions=use_additional_conditions,
dtype=dtype,
device=device,
operations=operations,
)
self.silu = nn.SiLU()
self.linear = operations.Linear(embedding_dim, 6 * embedding_dim, bias=True, dtype=dtype, device=device)
self.linear = operations.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True, dtype=dtype, device=device)
def forward(
self,
@@ -185,6 +238,7 @@ class AdaLayerNormSingle(nn.Module):
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
return self.linear(self.silu(embedded_timestep)), embedded_timestep
class PixArtAlphaTextProjection(nn.Module):
"""
Projects caption embeddings. Also handles dropout for classifier-free guidance.
@@ -192,18 +246,24 @@ class PixArtAlphaTextProjection(nn.Module):
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
"""
def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None):
def __init__(
self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None
):
super().__init__()
if out_features is None:
out_features = hidden_size
self.linear_1 = operations.Linear(in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device)
self.linear_1 = operations.Linear(
in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device
)
if act_fn == "gelu_tanh":
self.act_1 = nn.GELU(approximate="tanh")
elif act_fn == "silu":
self.act_1 = nn.SiLU()
else:
raise ValueError(f"Unknown activation function: {act_fn}")
self.linear_2 = operations.Linear(in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device)
self.linear_2 = operations.Linear(
in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device
)
def forward(self, caption):
hidden_states = self.linear_1(caption)
@@ -222,23 +282,68 @@ class GELU_approx(nn.Module):
class FeedForward(nn.Module):
def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=None):
def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0.0, dtype=None, device=None, operations=None):
super().__init__()
inner_dim = int(dim * mult)
project_in = GELU_approx(dim, inner_dim, dtype=dtype, device=device, operations=operations)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
project_in, nn.Dropout(dropout), operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
)
def forward(self, x):
return self.net(x)
def apply_rotary_emb(input_tensor, freqs_cis):
cos_freqs, sin_freqs = freqs_cis[0], freqs_cis[1]
split_pe = freqs_cis[2] if len(freqs_cis) > 2 else False
return (
apply_split_rotary_emb(input_tensor, cos_freqs, sin_freqs)
if split_pe else
apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs)
)
def apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs): # TODO: remove duplicate funcs and pick the best/fastest one
t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
t1, t2 = t_dup.unbind(dim=-1)
t_dup = torch.stack((-t2, t1), dim=-1)
input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
return out
def apply_split_rotary_emb(input_tensor, cos, sin):
needs_reshape = False
if input_tensor.ndim != 4 and cos.ndim == 4:
B, H, T, _ = cos.shape
input_tensor = input_tensor.reshape(B, T, H, -1).swapaxes(1, 2)
needs_reshape = True
split_input = rearrange(input_tensor, "... (d r) -> ... d r", d=2)
first_half_input = split_input[..., :1, :]
second_half_input = split_input[..., 1:, :]
output = split_input * cos.unsqueeze(-2)
first_half_output = output[..., :1, :]
second_half_output = output[..., 1:, :]
first_half_output.addcmul_(-sin.unsqueeze(-2), second_half_input)
second_half_output.addcmul_(sin.unsqueeze(-2), first_half_input)
output = rearrange(output, "... d r -> ... (d r)")
return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
def __init__(
self,
query_dim,
context_dim=None,
heads=8,
dim_head=64,
dropout=0.0,
attn_precision=None,
dtype=None,
device=None,
operations=None,
):
super().__init__()
inner_dim = dim_head * heads
context_dim = query_dim if context_dim is None else context_dim
@@ -254,9 +359,11 @@ class CrossAttention(nn.Module):
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
self.to_out = nn.Sequential(
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)
)
def forward(self, x, context=None, mask=None, pe=None, transformer_options={}):
def forward(self, x, context=None, mask=None, pe=None, k_pe=None, transformer_options={}):
q = self.to_q(x)
context = x if context is None else context
k = self.to_k(context)
@@ -266,8 +373,8 @@ class CrossAttention(nn.Module):
k = self.k_norm(k)
if pe is not None:
q = apply_rope1(q.unsqueeze(1), pe).squeeze(1)
k = apply_rope1(k.unsqueeze(1), pe).squeeze(1)
q = apply_rotary_emb(q, pe)
k = apply_rotary_emb(k, pe if k_pe is None else k_pe)
if mask is None:
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
@@ -277,14 +384,34 @@ class CrossAttention(nn.Module):
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None):
def __init__(
self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None
):
super().__init__()
self.attn_precision = attn_precision
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, context_dim=None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
self.attn1 = CrossAttention(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
context_dim=None,
attn_precision=self.attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
self.attn2 = CrossAttention(
query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
attn_precision=self.attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
@@ -306,116 +433,446 @@ class BasicTransformerBlock(nn.Module):
return x
def get_fractional_positions(indices_grid, max_pos):
n_pos_dims = indices_grid.shape[1]
assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})'
fractional_positions = torch.stack(
[
indices_grid[:, i] / max_pos[i]
for i in range(3)
],
dim=-1,
[indices_grid[:, i] / max_pos[i] for i in range(n_pos_dims)],
axis=-1,
)
return fractional_positions
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
dtype = torch.float32
device = indices_grid.device
@functools.lru_cache(maxsize=5)
def generate_freq_grid_np(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, _ = None):
theta = positional_embedding_theta
start = 1
end = theta
n_elem = 2 * positional_embedding_max_pos_count
pow_indices = np.power(
theta,
np.linspace(
_log_base(start, theta),
_log_base(end, theta),
inner_dim // n_elem,
dtype=np.float64,
),
)
return torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32)
def generate_freq_grid_pytorch(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, device):
theta = positional_embedding_theta
start = 1
end = theta
n_elem = 2 * positional_embedding_max_pos_count
indices = theta ** (
torch.linspace(
math.log(start, theta),
math.log(end, theta),
inner_dim // n_elem,
device=device,
dtype=torch.float32,
)
)
indices = indices.to(dtype=torch.float32)
indices = indices * math.pi / 2
return indices
def generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid):
if use_middle_indices_grid:
assert(len(indices_grid.shape) == 4 and indices_grid.shape[-1] ==2)
indices_grid_start, indices_grid_end = indices_grid[..., 0], indices_grid[..., 1]
indices_grid = (indices_grid_start + indices_grid_end) / 2.0
elif len(indices_grid.shape) == 4:
indices_grid = indices_grid[..., 0]
# Get fractional positions and compute frequency indices
fractional_positions = get_fractional_positions(indices_grid, max_pos)
indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2
indices = indices.to(device=fractional_positions.device)
# Compute frequencies and apply cos/sin
freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2)
cos_vals = freqs.cos().repeat_interleave(2, dim=-1)
sin_vals = freqs.sin().repeat_interleave(2, dim=-1)
freqs = (
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
.transpose(-1, -2)
.flatten(2)
)
return freqs
# Pad if dim is not divisible by 6
if dim % 6 != 0:
padding_size = dim % 6
cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1)
sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1)
def interleaved_freqs_cis(freqs, pad_size):
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
if pad_size != 0:
cos_padding = torch.ones_like(cos_freq[:, :, : pad_size])
sin_padding = torch.zeros_like(cos_freq[:, :, : pad_size])
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
return cos_freq, sin_freq
# Reshape and extract one value per pair (since repeat_interleave duplicates each value)
cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
def split_freqs_cis(freqs, pad_size, num_attention_heads):
cos_freq = freqs.cos()
sin_freq = freqs.sin()
# Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension
freqs_cis = torch.stack([
torch.stack([cos_vals, -sin_vals], dim=-1),
torch.stack([sin_vals, cos_vals], dim=-1)
], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2]
if pad_size != 0:
cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])
sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size])
return freqs_cis
cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1)
sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1)
# Reshape freqs to be compatible with multi-head attention
B , T, half_HD = cos_freq.shape
class LTXVModel(torch.nn.Module):
def __init__(self,
in_channels=128,
cross_attention_dim=2048,
attention_head_dim=64,
num_attention_heads=32,
cos_freq = cos_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads)
sin_freq = sin_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads)
caption_channels=4096,
num_layers=28,
cos_freq = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2)
sin_freq = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2)
return cos_freq, sin_freq
class LTXBaseModel(torch.nn.Module, ABC):
"""
Abstract base class for LTX models (Lightricks Transformer models).
positional_embedding_theta=10000.0,
positional_embedding_max_pos=[20, 2048, 2048],
causal_temporal_positioning=False,
vae_scale_factors=(8, 32, 32),
dtype=None, device=None, operations=None, **kwargs):
This class defines the common interface and shared functionality for all LTX models,
including LTXV (video) and LTXAV (audio-video) variants.
"""
def __init__(
self,
in_channels: int,
cross_attention_dim: int,
attention_head_dim: int,
num_attention_heads: int,
caption_channels: int,
num_layers: int,
positional_embedding_theta: float = 10000.0,
positional_embedding_max_pos: list = [20, 2048, 2048],
causal_temporal_positioning: bool = False,
vae_scale_factors: tuple = (8, 32, 32),
use_middle_indices_grid=False,
timestep_scale_multiplier = 1000.0,
dtype=None,
device=None,
operations=None,
**kwargs,
):
super().__init__()
self.generator = None
self.vae_scale_factors = vae_scale_factors
self.use_middle_indices_grid = use_middle_indices_grid
self.dtype = dtype
self.out_channels = in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.in_channels = in_channels
self.cross_attention_dim = cross_attention_dim
self.attention_head_dim = attention_head_dim
self.num_attention_heads = num_attention_heads
self.caption_channels = caption_channels
self.num_layers = num_layers
self.positional_embedding_theta = positional_embedding_theta
self.positional_embedding_max_pos = positional_embedding_max_pos
self.split_positional_embedding = LTXRopeType.from_dict(kwargs)
self.freq_grid_generator = (
generate_freq_grid_np if LTXFrequenciesPrecision.from_dict(kwargs) == LTXFrequenciesPrecision.FLOAT64
else generate_freq_grid_pytorch
)
self.causal_temporal_positioning = causal_temporal_positioning
self.operations = operations
self.timestep_scale_multiplier = timestep_scale_multiplier
self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
# Common dimensions
self.inner_dim = num_attention_heads * attention_head_dim
self.out_channels = in_channels
# Initialize common components
self._init_common_components(device, dtype)
# Initialize model-specific components
self._init_model_components(device, dtype, **kwargs)
# Initialize transformer blocks
self._init_transformer_blocks(device, dtype, **kwargs)
# Initialize output components
self._init_output_components(device, dtype)
def _init_common_components(self, device, dtype):
"""Initialize components common to all LTX models
- patchify_proj: Linear projection for patchifying input
- adaln_single: AdaLN layer for timestep embedding
- caption_projection: Linear projection for caption embedding
"""
self.patchify_proj = self.operations.Linear(
self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device
)
self.adaln_single = AdaLayerNormSingle(
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=operations
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
)
# self.adaln_single.linear = operations.Linear(self.inner_dim, 4 * self.inner_dim, bias=True, dtype=dtype, device=device)
self.caption_projection = PixArtAlphaTextProjection(
in_features=caption_channels, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations
in_features=self.caption_channels,
hidden_size=self.inner_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
@abstractmethod
def _init_model_components(self, device, dtype, **kwargs):
"""Initialize model-specific components. Must be implemented by subclasses."""
pass
@abstractmethod
def _init_transformer_blocks(self, device, dtype, **kwargs):
"""Initialize transformer blocks. Must be implemented by subclasses."""
pass
@abstractmethod
def _init_output_components(self, device, dtype):
"""Initialize output components. Must be implemented by subclasses."""
pass
@abstractmethod
def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
"""Process input data. Must be implemented by subclasses."""
pass
@abstractmethod
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs):
"""Process transformer blocks. Must be implemented by subclasses."""
pass
@abstractmethod
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
"""Process output data. Must be implemented by subclasses."""
pass
def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
"""Prepare timestep embeddings."""
grid_mask = kwargs.get("grid_mask", None)
if grid_mask is not None:
timestep = timestep[:, grid_mask]
timestep = timestep * self.timestep_scale_multiplier
timestep, embedded_timestep = self.adaln_single(
timestep.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
# Second dimension is 1 or number of tokens (if timestep_per_token)
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])
return timestep, embedded_timestep
def _prepare_context(self, context, batch_size, x, attention_mask=None):
"""Prepare context for transformer blocks."""
if self.caption_projection is not None:
context = self.caption_projection(context)
context = context.view(batch_size, -1, x.shape[-1])
return context, attention_mask
def _precompute_freqs_cis(
self,
indices_grid,
dim,
out_dtype,
theta=10000.0,
max_pos=[20, 2048, 2048],
use_middle_indices_grid=False,
num_attention_heads=32,
):
split_mode = self.split_positional_embedding == LTXRopeType.SPLIT
indices = self.freq_grid_generator(theta, indices_grid.shape[1], dim, indices_grid.device)
freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid)
if split_mode:
expected_freqs = dim // 2
current_freqs = freqs.shape[-1]
pad_size = expected_freqs - current_freqs
cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads)
else:
# 2 because of cos and sin by 3 for (t, x, y), 1 for temporal only
n_elem = 2 * indices_grid.shape[1]
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
return cos_freq.to(out_dtype), sin_freq.to(out_dtype), split_mode
def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype):
"""Prepare positional embeddings."""
fractional_coords = pixel_coords.to(torch.float32)
fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
pe = self._precompute_freqs_cis(
fractional_coords,
dim=self.inner_dim,
out_dtype=x_dtype,
max_pos=self.positional_embedding_max_pos,
use_middle_indices_grid=self.use_middle_indices_grid,
num_attention_heads=self.num_attention_heads,
)
return pe
def _prepare_attention_mask(self, attention_mask, x_dtype):
"""Prepare attention mask."""
if attention_mask is not None and not torch.is_floating_point(attention_mask):
attention_mask = (attention_mask - 1).to(x_dtype).reshape(
(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
) * torch.finfo(x_dtype).max
return attention_mask
def forward(
self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs
):
"""
Forward pass for LTX models.
Args:
x: Input tensor
timestep: Timestep tensor
context: Context tensor (e.g., text embeddings)
attention_mask: Attention mask tensor
frame_rate: Frame rate for temporal processing
transformer_options: Additional options for transformer blocks
keyframe_idxs: Keyframe indices for temporal processing
**kwargs: Additional keyword arguments
Returns:
Processed output tensor
"""
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(
comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options
),
).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, denoise_mask=denoise_mask, **kwargs)
def _forward(
self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs
):
"""
Internal forward pass for LTX models.
Args:
x: Input tensor
timestep: Timestep tensor
context: Context tensor (e.g., text embeddings)
attention_mask: Attention mask tensor
frame_rate: Frame rate for temporal processing
transformer_options: Additional options for transformer blocks
keyframe_idxs: Keyframe indices for temporal processing
**kwargs: Additional keyword arguments
Returns:
Processed output tensor
"""
if isinstance(x, list):
input_dtype = x[0].dtype
batch_size = x[0].shape[0]
else:
input_dtype = x.dtype
batch_size = x.shape[0]
# Process input
merged_args = {**transformer_options, **kwargs}
x, pixel_coords, additional_args = self._process_input(x, keyframe_idxs, denoise_mask, **merged_args)
merged_args.update(additional_args)
# Prepare timestep and context
timestep, embedded_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask)
# Prepare attention mask and positional embeddings
attention_mask = self._prepare_attention_mask(attention_mask, input_dtype)
pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype)
# Process transformer blocks
x = self._process_transformer_blocks(
x, context, attention_mask, timestep, pe, transformer_options=transformer_options, **merged_args
)
# Process output
x = self._process_output(x, embedded_timestep, keyframe_idxs, **merged_args)
return x
class LTXVModel(LTXBaseModel):
"""LTXV model for video generation."""
def __init__(
self,
in_channels=128,
cross_attention_dim=2048,
attention_head_dim=64,
num_attention_heads=32,
caption_channels=4096,
num_layers=28,
positional_embedding_theta=10000.0,
positional_embedding_max_pos=[20, 2048, 2048],
causal_temporal_positioning=False,
vae_scale_factors=(8, 32, 32),
use_middle_indices_grid=False,
timestep_scale_multiplier = 1000.0,
dtype=None,
device=None,
operations=None,
**kwargs,
):
super().__init__(
in_channels=in_channels,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim,
num_attention_heads=num_attention_heads,
caption_channels=caption_channels,
num_layers=num_layers,
positional_embedding_theta=positional_embedding_theta,
positional_embedding_max_pos=positional_embedding_max_pos,
causal_temporal_positioning=causal_temporal_positioning,
vae_scale_factors=vae_scale_factors,
use_middle_indices_grid=use_middle_indices_grid,
timestep_scale_multiplier=timestep_scale_multiplier,
dtype=dtype,
device=device,
operations=operations,
**kwargs,
)
def _init_model_components(self, device, dtype, **kwargs):
"""Initialize LTXV-specific components."""
# No additional components needed for LTXV beyond base class
pass
def _init_transformer_blocks(self, device, dtype, **kwargs):
"""Initialize transformer blocks for LTXV."""
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
self.inner_dim,
num_attention_heads,
attention_head_dim,
context_dim=cross_attention_dim,
# attn_precision=attn_precision,
dtype=dtype, device=device, operations=operations
self.num_attention_heads,
self.attention_head_dim,
context_dim=self.cross_attention_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
for d in range(num_layers)
for _ in range(self.num_layers)
]
)
def _init_output_components(self, device, dtype):
"""Initialize output components for LTXV."""
self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device))
self.norm_out = operations.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.proj_out = operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)
self.patchifier = SymmetricPatchifier(1)
def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, **kwargs)
def _forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
orig_shape = list(x.shape)
self.norm_out = self.operations.LayerNorm(
self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
)
self.proj_out = self.operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)
self.patchifier = SymmetricPatchifier(1, start_end=True)
def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
"""Process input for LTXV."""
additional_args = {"orig_shape": list(x.shape)}
x, latent_coords = self.patchifier.patchify(x)
pixel_coords = latent_to_pixel_coords(
latent_coords=latent_coords,
@@ -423,44 +880,30 @@ class LTXVModel(torch.nn.Module):
causal_fix=self.causal_temporal_positioning,
)
grid_mask = None
if keyframe_idxs is not None:
pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs
additional_args.update({ "orig_patchified_shape": list(x.shape)})
denoise_mask = self.patchifier.patchify(denoise_mask)[0]
grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0]
additional_args.update({"grid_mask": grid_mask})
x = x[:, grid_mask, :]
pixel_coords = pixel_coords[:, :, grid_mask, ...]
fractional_coords = pixel_coords.to(torch.float32)
fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:]
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
x = self.patchify_proj(x)
timestep = timestep * 1000.0
if attention_mask is not None and not torch.is_floating_point(attention_mask):
attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max
pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype)
batch_size = x.shape[0]
timestep, embedded_timestep = self.adaln_single(
timestep.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=x.dtype,
)
# Second dimension is 1 or number of tokens (if timestep_per_token)
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
embedded_timestep = embedded_timestep.view(
batch_size, -1, embedded_timestep.shape[-1]
)
# 2. Blocks
if self.caption_projection is not None:
batch_size = x.shape[0]
context = self.caption_projection(context)
context = context.view(
batch_size, -1, x.shape[-1]
)
return x, pixel_coords, additional_args
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs):
"""Process transformer blocks for LTXV."""
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.transformer_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
@@ -478,16 +921,28 @@ class LTXVModel(torch.nn.Module):
transformer_options=transformer_options,
)
# 3. Output
return x
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
"""Process output for LTXV."""
# Apply scale-shift modulation
scale_shift_values = (
self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
)
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
x = self.norm_out(x)
# Modulation
x = torch.addcmul(x, x, scale).add_(shift)
x = x * (1 + scale) + shift
x = self.proj_out(x)
if keyframe_idxs is not None:
grid_mask = kwargs["grid_mask"]
orig_patchified_shape = kwargs["orig_patchified_shape"]
full_x = torch.zeros(orig_patchified_shape, dtype=x.dtype, device=x.device)
full_x[:, grid_mask, :] = x
x = full_x
# Unpatchify to restore original dimensions
orig_shape = kwargs["orig_shape"]
x = self.patchifier.unpatchify(
latents=x,
output_height=orig_shape[3],