Files
jaidaken f09734b0ee
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Add custom nodes, Civitai loras (LFS), and vast.ai setup script
Includes 30 custom nodes committed directly, 7 Civitai-exclusive
loras stored via Git LFS, and a setup script that installs all
dependencies and downloads HuggingFace-hosted models on vast.ai.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-09 00:56:42 +00:00

1605 lines
75 KiB
Python

from comfy_extras.nodes_lt import get_noise_mask, LTXVAddGuide
import types
import math
from typing import Tuple
import comfy
from comfy_api.latest import io
import numpy as np
import torch
import logging
import comfy.model_management as mm
device = mm.get_torch_device()
import latent_preview
class LTXVAddGuideMulti(LTXVAddGuide):
    # Adds a user-selected number (1-20) of guide images to an LTXV video
    # latent, each at its own pixel-frame index and strength.
    @classmethod
    def define_schema(cls):
        # Build one DynamicCombo option per possible guide count; selecting
        # "N" exposes image/frame_idx/strength inputs for guides 1..N.
        options = []
        for num_guides in range(1, 21): # 1 to 20 guides
            guide_inputs = []
            for i in range(1, num_guides + 1):
                guide_inputs.extend([
                    io.Image.Input(f"image_{i}"),
                    io.Int.Input(
                        f"frame_idx_{i}",
                        default=0,
                        min=-9999,
                        max=9999,
                        tooltip=f"Frame index for guide {i}.",
                    ),
                    io.Float.Input(f"strength_{i}", default=1.0, min=0.0, max=1.0, step=0.01, tooltip=f"Strength for guide {i}."),
                ])
            options.append(io.DynamicCombo.Option(
                key=str(num_guides),
                inputs=guide_inputs
            ))
        return io.Schema(
            node_id="LTXVAddGuideMulti",
            category="KJNodes/ltxv",
            description="Add multiple guide images at specified frame indices with strengths, uses DynamicCombo which requires ComfyUI 0.8.1 and frontend 1.33.4 or later.",
            inputs=[
                io.Conditioning.Input("positive", tooltip="Positive conditioning to which guide keyframe info will be added"),
                io.Conditioning.Input("negative", tooltip="Negative conditioning to which guide keyframe info will be added"),
                io.Vae.Input("vae", tooltip="Video VAE used to encode the guide images"),
                io.Latent.Input("latent", tooltip="Video latent, guides are added to the end of this latent"),
                io.DynamicCombo.Input(
                    "num_guides",
                    options=options,
                    display_name="Number of Guides",
                    tooltip="Select how many guide images to use",
                ),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
                io.Latent.Output(display_name="latent", tooltip="Video latent with added guides"),
            ],
        )

    @classmethod
    def execute(cls, positive, negative, vae, latent, num_guides) -> io.NodeOutput:
        """Encode each selected guide image and append it as a keyframe to the
        latent, updating positive/negative conditioning and the noise mask.

        Raises AssertionError when a guide's conditioning frames would extend
        past the end of the latent sequence.
        """
        scale_factors = vae.downscale_index_formula
        latent_image = latent["samples"]
        noise_mask = get_noise_mask(latent)
        _, _, latent_length, latent_height, latent_width = latent_image.shape
        # num_guides is a dict containing the inputs from the selected option
        # e.g., {'image_1': tensor, 'frame_idx_1': 0, 'strength_1': 1.0, ...}
        # FIX: sort numerically on the suffix. A plain lexicographic sort
        # orders 'image_10' before 'image_2', applying guides out of order
        # whenever more than 9 guides are selected.
        image_keys = sorted(
            (k for k in num_guides if k.startswith('image_')),
            key=lambda k: int(k.split('_')[1]),
        )
        for img_key in image_keys:
            i = img_key.split('_')[1]
            img = num_guides[f"image_{i}"]
            f_idx = num_guides[f"frame_idx_{i}"]
            strength = num_guides[f"strength_{i}"]
            image_1, t = cls.encode(vae, latent_width, latent_height, img, scale_factors)
            frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image_1), f_idx, scale_factors)
            assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
            positive, negative, latent_image, noise_mask = cls.append_keyframe(
                positive,
                negative,
                frame_idx,
                latent_image,
                noise_mask,
                t,
                strength,
                scale_factors,
            )
        return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
class LTXVAddGuidesFromBatch(LTXVAddGuide):
    # Treats each non-black image in a batch as a guide keyframe whose pixel
    # frame index equals its position in the batch.
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LTXVAddGuidesFromBatch",
            category="conditioning/ltxv",
            description="Adds multiple guide images from a batch to the latent at corresponding frame indices. Non-black images in the batch are used as guides.",
            inputs=[
                io.Conditioning.Input("positive"),
                io.Conditioning.Input("negative"),
                io.Vae.Input("vae"),
                io.Latent.Input("latent"),
                io.Image.Input("images", tooltip="Batch of images - non-black images will be used as guides"),
                io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
                io.Latent.Output(display_name="latent"),
            ],
        )

    @classmethod
    def execute(cls, positive, negative, vae, latent, images, strength) -> io.NodeOutput:
        """Scan the image batch and append every non-black frame as a guide
        keyframe at its batch index; guides that would overrun the latent
        sequence are skipped with a warning instead of raising.
        """
        scale_factors = vae.downscale_index_formula
        latent_image = latent["samples"]
        noise_mask = get_noise_mask(latent)
        _, _, latent_length, latent_height, latent_width = latent_image.shape
        # Process each image in the batch
        batch_size = images.shape[0]
        for i in range(batch_size):
            img = images[i:i+1]
            # Near-black frames mark "no guide here"; skip them.
            if img.max() > 0.001:
                f_idx = i  # batch index doubles as the pixel frame index
                image_1, t = cls.encode(vae, latent_width, latent_height, img, scale_factors)
                frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image_1), f_idx, scale_factors)
                if latent_idx + t.shape[2] <= latent_length:
                    positive, negative, latent_image, noise_mask = cls.append_keyframe(
                        positive,
                        negative,
                        frame_idx,
                        latent_image,
                        noise_mask,
                        t,
                        strength,
                        scale_factors,
                    )
                else:
                    # Use the logging module (already imported at file level)
                    # instead of a bare print so the message respects the
                    # host application's log configuration.
                    logging.warning("Skipping guide at index %s - conditioning frames exceed latent sequence length", i)
        return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
class LTXVAudioVideoMask(io.ComfyNode):
    # Builds noise masks over time ranges so that new content is generated
    # only inside the masked region of the video and/or audio latents.
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LTXVAudioVideoMask",
            category="KJNodes/ltxv",
            description="Creates noise masks for video and audio latents based on specified time ranges. New content is generated within these masked regions",
            inputs=[
                io.Latent.Input("video_latent", optional=True),
                io.Latent.Input("audio_latent", optional=True),
                io.Float.Input("video_fps", default=25, min=0.0, max=100.0, step=0.1),
                io.Float.Input("video_start_time", default=0.0, min=0.0, max=10000.0, step=0.1, tooltip="Start time in seconds for the video mask."),
                io.Float.Input("video_end_time", default=5.0, min=0.0, max=10000.0, step=0.1, tooltip="End time in seconds for the video mask."),
                io.Float.Input("audio_start_time", default=0.0, min=0.0, max=10000.0, step=0.1, tooltip="Start time in seconds for the audio mask."),
                io.Float.Input("audio_end_time", default=5.0, min=0.0, max=10000.0, step=0.1, tooltip="End time in seconds for the audio mask."),
                io.Combo.Input(
                    "max_length",
                    options=["truncate", "pad", "partial"],
                    default="truncate",
                    tooltip="'truncate': cut latent to end_time length. 'pad': extend latent to end_time. 'partial': mask range within existing latent.",
                ),
            ],
            outputs=[
                io.Latent.Output(display_name="video_latent"),
                io.Latent.Output(display_name="audio_latent"),
            ],
        )

    @classmethod
    def execute(cls, video_fps, video_start_time, video_end_time, audio_start_time, audio_end_time, max_length="truncate", video_latent=None, audio_latent=None) -> io.NodeOutput:
        """Convert the given time ranges into latent-frame indices and write a
        1.0 noise mask over that range, optionally padding or truncating the
        latents to match the end time. Either latent may be None.
        """
        # Fixed LTXV latent timing constants:
        # one video latent frame covers 8 pixel frames (first frame maps 1:1).
        time_scale_factor = 8
        mel_hop_length = 160
        sampling_rate = 16000
        latent_downsample_factor = 4
        audio_latents_per_second = (sampling_rate / mel_hop_length / latent_downsample_factor) # 25
        if video_latent is not None:
            video_latent_frame_count = video_latent["samples"].shape[2]
            video_pixel_frame_start_raw = int(round(video_start_time * video_fps))
            video_pixel_frame_end_raw = int(round(video_end_time * video_fps))
            # Calculate required latent frames based on end time
            required_latent_frames = (video_pixel_frame_end_raw - 1) // time_scale_factor + 1
            # Handle different max_length modes
            if max_length == "pad" and required_latent_frames > video_latent_frame_count:
                # Pad video latent if required frames exceed current length
                pad_frames = required_latent_frames - video_latent_frame_count
                padding = torch.zeros(
                    video_latent["samples"].shape[0],
                    video_latent["samples"].shape[1],
                    pad_frames,
                    video_latent["samples"].shape[3],
                    video_latent["samples"].shape[4],
                    dtype=video_latent["samples"].dtype,
                    device=video_latent["samples"].device
                )
                video_samples = torch.cat([video_latent["samples"], padding], dim=2)
                video_latent_frame_count = video_samples.shape[2]
            elif max_length == "truncate":
                # Truncate to the end_time (a no-op if the latent is already
                # shorter, thanks to Python slice semantics)
                video_samples = video_latent["samples"][:, :, :required_latent_frames]
                video_latent_frame_count = video_samples.shape[2]
            else: # partial
                video_samples = video_latent["samples"]
            # Now calculate indices based on potentially padded latent.
            # xp maps latent frame index -> first pixel frame it covers:
            # [0, 1, 9, 17, ...]
            video_pixel_frame_count = (video_latent_frame_count - 1) * time_scale_factor + 1
            xp = np.array([0] + list(range(1, video_pixel_frame_count + time_scale_factor, time_scale_factor)))
            # video_frame_index_start = index of the value in xp rounding up
            video_latent_frame_index_start = np.searchsorted(xp, video_pixel_frame_start_raw, side="left")
            # video_frame_index_end = index of the value in xp rounding down
            video_latent_frame_index_end = np.searchsorted(xp, video_pixel_frame_end_raw, side="right") - 1
            video_latent_frame_index_start = max(0, video_latent_frame_index_start)
            video_latent_frame_index_end = min(video_latent_frame_index_end, video_latent_frame_count)
            # Get existing noise mask if present, otherwise create new one
            if "noise_mask" in video_latent:
                video_mask = video_latent["noise_mask"].clone()
                # Adjust mask size based on mode
                if max_length == "pad" and video_samples.shape[2] > video_latent["samples"].shape[2]:
                    # Pad the mask if we padded the samples
                    mask_padding = torch.zeros(
                        video_mask.shape[0],
                        video_mask.shape[1],
                        video_samples.shape[2] - video_mask.shape[2],
                        video_mask.shape[3],
                        video_mask.shape[4],
                        dtype=video_mask.dtype,
                        device=video_mask.device
                    )
                    video_mask = torch.cat([video_mask, mask_padding], dim=2)
                elif max_length == "truncate":
                    # Truncate the mask to match truncated samples
                    video_mask = video_mask[:, :, :video_samples.shape[2]]
            else:
                # New mask is single-channel (broadcast over latent channels)
                video_mask = torch.zeros_like(video_samples)[:, :1]
            video_mask[:, :, video_latent_frame_index_start:video_latent_frame_index_end] = 1.0
            # ensure all padded frames are also masked
            if max_length == "pad" and video_samples.shape[2] > video_latent["samples"].shape[2]:
                video_mask[:, :, video_latent["samples"].shape[2]:] = 1.0
            # Shallow-copy the dict so the caller's latent is not mutated
            video_latent = video_latent.copy()
            video_latent["samples"] = video_samples
            video_latent["noise_mask"] = video_mask
        if audio_latent is not None:
            audio_latent_frame_count = audio_latent["samples"].shape[2]
            audio_latent_frame_index_start = int(round(audio_start_time * audio_latents_per_second))
            audio_latent_frame_index_end = int(round(audio_end_time * audio_latents_per_second)) + 1
            # Handle different max_length modes
            if max_length == "pad" and audio_latent_frame_index_end > audio_latent_frame_count:
                # Pad audio latent if end index exceeds current length
                pad_frames = audio_latent_frame_index_end - audio_latent_frame_count
                padding = torch.zeros(
                    audio_latent["samples"].shape[0],
                    audio_latent["samples"].shape[1],
                    pad_frames,
                    audio_latent["samples"].shape[3],
                    dtype=audio_latent["samples"].dtype,
                    device=audio_latent["samples"].device
                )
                audio_samples = torch.cat([audio_latent["samples"], padding], dim=2)
                audio_latent_frame_count = audio_samples.shape[2]
            elif max_length == "truncate":
                # Truncate to the end_time
                audio_samples = audio_latent["samples"][:, :, :audio_latent_frame_index_end]
                audio_latent_frame_count = audio_samples.shape[2]
            else: # partial
                audio_samples = audio_latent["samples"]
            audio_latent_frame_index_start = max(0, audio_latent_frame_index_start)
            audio_latent_frame_index_end = min(audio_latent_frame_index_end, audio_latent_frame_count)
            # Get existing noise mask if present, otherwise create new one
            if "noise_mask" in audio_latent:
                audio_mask = audio_latent["noise_mask"].clone()
                # Adjust mask size based on mode
                if max_length == "pad" and audio_samples.shape[2] > audio_latent["samples"].shape[2]:
                    # Pad the mask if we padded the samples
                    mask_padding = torch.zeros(
                        audio_mask.shape[0],
                        audio_mask.shape[1],
                        audio_samples.shape[2] - audio_mask.shape[2],
                        audio_mask.shape[3],
                        dtype=audio_mask.dtype,
                        device=audio_mask.device
                    )
                    audio_mask = torch.cat([audio_mask, mask_padding], dim=2)
                elif max_length == "truncate":
                    # Truncate the mask to match truncated samples
                    audio_mask = audio_mask[:, :, :audio_samples.shape[2]]
            else:
                # Unlike video, the new audio mask keeps all channels
                audio_mask = torch.zeros_like(audio_samples)
            audio_mask[:, :, audio_latent_frame_index_start:audio_latent_frame_index_end] = 1.0
            # ensure all padded frames are also masked
            if max_length == "pad" and audio_samples.shape[2] > audio_latent["samples"].shape[2]:
                audio_mask[:, :, audio_latent["samples"].shape[2]:] = 1.0
            audio_latent = audio_latent.copy()
            audio_latent["samples"] = audio_samples
            audio_latent["noise_mask"] = audio_mask
        return io.NodeOutput(video_latent, audio_latent)
def _compute_attention(self, query, context, transformer_options={}):
    """Run optimized cross-attention of `query` against `context`.

    Keys and values are projected from `context` (keys additionally
    normalized), cast to the query dtype, and freed eagerly after use to
    keep peak memory low. Returns the attention output flattened from dim 2.
    """
    keys = self.k_norm(self.to_k(context)).to(query.dtype)
    values = self.to_v(context).to(query.dtype)
    attn_out = comfy.ldm.modules.attention.optimized_attention(
        query, keys, values, heads=self.heads, transformer_options=transformer_options
    )
    result = attn_out.flatten(2)
    del keys, values
    return result
def nag_attention(self, query, context_positive, nag_context, transformer_options={}):
    """Return the pair of attention outputs (positive context, NAG context)."""
    positive_out = _compute_attention(self, query, context_positive, transformer_options)
    negative_out = _compute_attention(self, query, nag_context, transformer_options)
    return positive_out, negative_out
def normalized_attention_guidance(self, x_positive, x_negative):
    """Combine positive/negative attention via Normalized Attention Guidance.

    Extrapolates positive past negative by `nag_scale`, rescales the result
    wherever its L1 norm exceeds `nag_tau` times the positive norm, then
    blends back toward the positive branch by `nag_alpha`. When
    `self.inplace` is set, x_negative's storage is reused for the guidance
    tensor (numerically equivalent, lower peak memory).
    """
    if self.inplace:
        # -(neg * (scale - 1)) + scale * pos, computed in x_negative's buffer
        guidance = x_negative.mul_(self.nag_scale - 1).neg_().add_(x_positive, alpha=self.nag_scale)
    else:
        guidance = x_positive * self.nag_scale - x_negative * (self.nag_scale - 1)
    del x_negative
    pos_norm = torch.norm(x_positive, p=1, dim=-1, keepdim=True)
    guide_norm = torch.norm(guidance, p=1, dim=-1, keepdim=True)
    ratio = guide_norm / pos_norm
    # Zero positive norm yields NaN; treat it as "way over threshold"
    torch.nan_to_num_(ratio, nan=10.0)
    over_tau = ratio > self.nag_tau
    del ratio
    rescale = (pos_norm * self.nag_tau) / (guide_norm + 1e-7)
    del pos_norm, guide_norm
    guidance.mul_(torch.where(over_tau, rescale, 1.0))
    del over_tau, rescale
    if self.inplace:
        guidance.sub_(x_positive).mul_(self.nag_alpha).add_(x_positive)
    else:
        guidance = guidance * self.nag_alpha + x_positive * (1 - self.nag_alpha)
    del x_positive
    return guidance
#region NAG
def ltxv_crossattn_forward_nag(self, x, context, mask=None, transformer_options={}, **kwargs):
    """Cross-attention forward with NAG applied to the positive half.

    `x`/`context` are either a single positive sample or a [positive,
    negative] pair stacked along the batch dimension.
    """
    has_pair = context.shape[0] != 1
    if has_pair:
        x_pos, x_neg = torch.chunk(x, 2, dim=0)
        context_pos, context_neg = torch.chunk(context, 2, dim=0)
    else:
        x_pos, context_pos = x, context
        x_neg = context_neg = None
    # Positive half: attend to both the real context and the NAG context,
    # then merge the two with normalized attention guidance.
    q_pos = self.q_norm(self.to_q(x_pos))
    del x_pos
    attn_pos, attn_nag = nag_attention(self, q_pos, context_pos, self.nag_context, transformer_options=transformer_options)
    del context_pos, q_pos
    out_pos = normalized_attention_guidance(self, attn_pos, attn_nag)
    del attn_pos, attn_nag
    if x_neg is not None and context_neg is not None:
        # Negative half: plain cross-attention, no guidance.
        q_neg = self.q_norm(self.to_q(x_neg))
        k_neg = self.k_norm(self.to_k(context_neg))
        v_neg = self.to_v(context_neg)
        out_neg = comfy.ldm.modules.attention.optimized_attention(q_neg, k_neg, v_neg, heads=self.heads, transformer_options=transformer_options)
        combined = torch.cat([out_pos, out_neg], dim=0)
    else:
        combined = out_pos
    return self.to_out(combined)
class LTXVCrossAttentionPatch:
    """Descriptor that rebinds a cross-attention module's forward to the NAG
    variant, stamping the NAG parameters onto the module at call time."""

    def __init__(self, context, nag_scale, nag_alpha, nag_tau, inplace=True):
        self.nag_context = context
        self.nag_scale = nag_scale
        self.nag_alpha = nag_alpha
        self.nag_tau = nag_tau
        self.inplace = inplace

    def __get__(self, obj, objtype=None):
        """Return `ltxv_crossattn_forward_nag` bound to `obj` with this
        patch's parameters copied onto the module before each call."""
        patch = self

        def forward_with_nag(module, *args, **kwargs):
            module.nag_context = patch.nag_context
            module.nag_scale = patch.nag_scale
            module.nag_alpha = patch.nag_alpha
            module.nag_tau = patch.nag_tau
            module.inplace = patch.inplace
            return ltxv_crossattn_forward_nag(module, *args, **kwargs)

        return types.MethodType(forward_with_nag, obj)
class LTX2_NAG(io.ComfyNode):
    # Patches the LTX2 model's video/audio cross-attention layers with
    # Normalized Attention Guidance (NAG) conditioning.
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LTX2_NAG",
            display_name="LTX2 NAG",
            category="KJNodes/ltxv",
            description="https://github.com/ChenDarYen/Normalized-Attention-Guidance",
            is_experimental=True,
            inputs=[
                io.Model.Input("model"),
                io.Float.Input("nag_scale", default=11.0, min=0.0, max=100.0, step=0.001, tooltip="Strength of negative guidance effect"),
                io.Float.Input("nag_alpha", default=0.25, min=0.0, max=1.0, step=0.001, tooltip="Mixing coefficient in that controls the balance between the normalized guided representation and the original positive representation."),
                io.Float.Input("nag_tau", default=2.5, min=0.0, max=10.0, step=0.001, tooltip="Clipping threshold that controls how much the guided attention can deviate from the positive attention."),
                io.Conditioning.Input("nag_cond_video", optional=True),
                io.Conditioning.Input("nag_cond_audio", optional=True),
                io.Boolean.Input("inplace", default=True, optional=True, tooltip="If true, modifies tensors in place to save memory. Leads to different numerical results which may change the output slightly."),
            ],
            outputs=[
                io.Model.Output(display_name="model"),
            ],
        )

    @classmethod
    def execute(cls, model, nag_scale, nag_alpha, nag_tau, nag_cond_video=None, nag_cond_audio=None, inplace=True) -> io.NodeOutput:
        """Clone the model and object-patch every transformer block's attn2
        (and audio_attn2, if audio conditioning is given) forward with the
        NAG variant carrying the projected NAG context.
        """
        # A scale of 0 disables NAG entirely; return the model untouched.
        if nag_scale == 0:
            return io.NodeOutput(model)
        # Local variable shadows the module-level `device`
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        dtype = model.model.manual_cast_dtype
        if dtype is None:
            dtype = model.model.diffusion_model.dtype
        model_clone = model.clone()
        diffusion_model = model_clone.get_model_object("diffusion_model")
        img_dim = diffusion_model.inner_dim
        audio_dim = diffusion_model.audio_inner_dim
        context_video = context_audio = None
        if nag_cond_video is not None:
            # Run the caption projection on the compute device, then return
            # it to the offload device once the context is projected.
            diffusion_model.caption_projection.to(device)
            context_video = nag_cond_video[0][0].to(device, dtype)
            # The conditioning's last dim is split in half; the first half is
            # used as the video context here (second half for audio below).
            v_context, _ = torch.split(context_video, int(context_video.shape[-1] / 2), len(context_video.shape) - 1)
            context_video = diffusion_model.caption_projection(v_context)
            diffusion_model.caption_projection.to(offload_device)
            context_video = context_video.view(1, -1, img_dim)
            for idx, block in enumerate(diffusion_model.transformer_blocks):
                patched_attn2 = LTXVCrossAttentionPatch(context_video, nag_scale, nag_alpha, nag_tau, inplace=inplace).__get__(block.attn2, block.__class__)
                model_clone.add_object_patch(f"diffusion_model.transformer_blocks.{idx}.attn2.forward", patched_attn2)
        if nag_cond_audio is not None and diffusion_model.audio_caption_projection is not None:
            diffusion_model.audio_caption_projection.to(device)
            context_audio = nag_cond_audio[0][0].to(device, dtype)
            # Second half of the split conditioning is the audio context.
            _, a_context = torch.split(context_audio, int(context_audio.shape[-1] / 2), len(context_audio.shape) - 1)
            context_audio = diffusion_model.audio_caption_projection(a_context)
            diffusion_model.audio_caption_projection.to(offload_device)
            context_audio = context_audio.view(1, -1, audio_dim)
            for idx, block in enumerate(diffusion_model.transformer_blocks):
                patched_audio_attn2 = LTXVCrossAttentionPatch(context_audio, nag_scale, nag_alpha, nag_tau, inplace=inplace).__get__(block.audio_attn2, block.__class__)
                model_clone.add_object_patch(f"diffusion_model.transformer_blocks.{idx}.audio_attn2.forward", patched_audio_attn2)
        return io.NodeOutput(model_clone)
def ffn_chunked_forward(self, x):
    """Feed-forward that processes long sequences chunk-by-chunk.

    When the sequence length exceeds `self.dim_threshold`, the feed-forward
    `self.net` is applied to `self.num_chunks` slices of dim 1, writing each
    result back into `x` in place so only one chunk's activations are live at
    a time. Shorter inputs go through `self.net` in a single call.
    """
    seq_len = x.shape[1]
    if seq_len <= self.dim_threshold:
        return self.net(x)
    chunk = seq_len // self.num_chunks
    for ci in range(self.num_chunks):
        lo = ci * chunk
        # The last chunk absorbs any remainder from the integer division.
        hi = seq_len if ci == self.num_chunks - 1 else (ci + 1) * chunk
        x[:, lo:hi] = self.net(x[:, lo:hi])
    return x
class LTXVffnChunkPatch:
    """Descriptor producing a bound, chunked replacement for a feed-forward
    module's forward; chunking parameters are stamped onto the module."""

    def __init__(self, num_chunks, dim_threshold=4096):
        self.num_chunks = num_chunks
        self.dim_threshold = dim_threshold

    def __get__(self, obj, objtype=None):
        """Bind `ffn_chunked_forward` to `obj` with this patch's settings."""
        patch = self

        def chunked(module, *args, **kwargs):
            module.num_chunks = patch.num_chunks
            module.dim_threshold = patch.dim_threshold
            return ffn_chunked_forward(module, *args, **kwargs)

        return types.MethodType(chunked, obj)
class LTXVChunkFeedForward(io.ComfyNode):
    """Patches every transformer block's feed-forward to run in chunks,
    lowering peak VRAM at the cost of potentially altered output."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LTXVChunkFeedForward",
            display_name="LTXV Chunk FeedForward",
            category="KJNodes/ltxv",
            description="EXPERIMENTAL AND MAY CHANGE THE MODEL OUTPUT!! Chunks feedforward activations to reduce peak VRAM usage.",
            is_experimental=True,
            inputs=[
                io.Model.Input("model"),
                io.Int.Input("chunks", default=2, min=1, max=100, step=1, tooltip="Number of chunks to split the feedforward activations into to reduce peak VRAM usage."),
                io.Int.Input("dim_threshold", default=4096, min=0, max=16384, step=256, tooltip="Dimension threshold above which to apply chunking."),
            ],
            outputs=[
                io.Model.Output(display_name="model"),
            ],
        )

    @classmethod
    def execute(cls, model, chunks, dim_threshold) -> io.NodeOutput:
        """Clone the model and replace each block's `ff.forward` with the
        chunked variant. A chunk count of 1 is a no-op."""
        if chunks == 1:
            return io.NodeOutput(model)
        patched_model = model.clone()
        transformer = patched_model.get_model_object("diffusion_model")
        for block_idx, block in enumerate(transformer.transformer_blocks):
            chunked_forward = LTXVffnChunkPatch(chunks, dim_threshold).__get__(block.ff, block.__class__)
            patched_model.add_object_patch(f"diffusion_model.transformer_blocks.{block_idx}.ff.forward", chunked_forward)
        return io.NodeOutput(patched_model)
#borrowed VideoHelperSuite https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite/blob/main/videohelpersuite/latent_preview.py
import server
from threading import Thread
import torch.nn.functional as F
import time
import struct
from PIL import Image
from io import BytesIO
# Shared PromptServer instance used to push preview frames to the frontend.
serv = server.PromptServer.instance
class WrappedPreviewer():
    """Streams latent previews to the VHS frontend during sampling.

    Frames are decoded either via a cheap linear RGB projection of the latent
    channels or, when `taeltx` is supplied, a full TAE decode.
    """
    def __init__(self, latent_rgb_factors, latent_rgb_factors_bias, rate=8, taeltx=None):
        self.first_preview = True   # first call announces the preview stream to the frontend
        self.taeltx = taeltx        # optional TAE VAE used for full decode
        self.last_time = 0
        self.c_index = 0            # rolling cursor into the frame sequence
        self.rate = rate            # target preview frames per second
        self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1)
        self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu") if latent_rgb_factors_bias is not None else None

    def decode_latent_to_preview_image(self, preview_format, x0):
        """Throttle to `rate` previews/sec and dispatch the next window of
        frames for decoding/sending. Always returns None (frames are pushed
        out-of-band via the server)."""
        if x0.ndim == 5:
            #Keep batch major
            x0 = x0.movedim(2,1)
            x0 = x0.reshape((-1,)+x0.shape[-3:])
        num_images = x0.size(0)
        new_time = time.time()
        # Emit only as many previews as elapsed wall time allows.
        num_previews = int((new_time - self.last_time) * self.rate)
        self.last_time = self.last_time + num_previews/self.rate
        if num_previews > num_images:
            num_previews = num_images
        elif num_previews <= 0:
            return None
        if self.first_preview:
            self.first_preview = False
            serv.send_sync('VHS_latentpreview', {'length':num_images, 'rate': self.rate, 'id': serv.last_node_id})
            self.last_time = new_time + 1/self.rate
        # Select the next window of frames, wrapping around the sequence.
        if self.c_index + num_previews > num_images:
            x0 = x0.roll(-self.c_index, 0)[:num_previews]
        else:
            x0 = x0[self.c_index:self.c_index + num_previews]
        # NOTE(review): Thread(...).run() executes the target synchronously in
        # the calling thread (unlike .start()) — confirm this is intentional.
        Thread(target=self.process_previews, args=(x0, self.c_index,
                                                   num_images)).run()
        self.c_index = (self.c_index + num_previews) % num_images
        return None

    def process_previews(self, image_tensor, ind, leng):
        """Decode a window of frames, downscale to <=256px, and push each as a
        JPEG binary event to the connected client."""
        max_size = 256
        image_tensor = self.decode_latent_to_preview(image_tensor)
        # Downscale so the longest side is at most max_size.
        if image_tensor.size(1) > max_size or image_tensor.size(2) > max_size:
            image_tensor = image_tensor.movedim(-1,0)
            if image_tensor.size(2) < image_tensor.size(3):
                height = (max_size * image_tensor.size(2)) // image_tensor.size(3)
                image_tensor = F.interpolate(image_tensor, (height,max_size), mode='bilinear')
            else:
                width = (max_size * image_tensor.size(3)) // image_tensor.size(2)
                image_tensor = F.interpolate(image_tensor, (max_size, width), mode='bilinear')
            image_tensor = image_tensor.movedim(0,-1)
        previews_ubyte = (image_tensor.clamp(0, 1)
                          .mul(0xFF)  # to 0..255
                          ).to(device="cpu", dtype=torch.uint8)
        # Send VHS preview
        for preview in previews_ubyte:
            i = Image.fromarray(preview.numpy())
            message = BytesIO()
            # Binary header: two big-endian 4-byte 1s (presumably event/format
            # tags — confirm against the frontend protocol), the frame index,
            # then the node id as a fixed 16-byte pascal string.
            message.write((1).to_bytes(length=4, byteorder='big')*2)
            message.write(ind.to_bytes(length=4, byteorder='big'))
            message.write(struct.pack('16p', serv.last_node_id.encode('ascii')))
            i.save(message, format="JPEG", quality=95, compress_level=1)
            #NOTE: send sync already uses call_soon_threadsafe
            serv.send_sync(server.BinaryEventTypes.PREVIEW_IMAGE,
                           message.getvalue(), serv.client_id)
            # Advance the frame index; with a TAE decoder each latent frame
            # expands to 8 pixel frames (the first maps 1:1), so wrap on the
            # expanded length.
            if self.taeltx is not None:
                ind = (ind + 1) % ((leng-1) * 8 + 1)
            else:
                ind = (ind + 1) % leng

    def decode_latent_to_preview(self, x0):
        """Decode latent frames to channel-last RGB float images."""
        if self.taeltx is not None:
            # Full TAE decode on the compute device (module-level `device`).
            x0 = x0.unsqueeze(0).to(dtype=self.taeltx.vae_dtype, device=device)
            x_sample = self.taeltx.first_stage_model.decode(x0)[0].permute(1, 2, 3, 0)
            return x_sample
        else:
            # Cheap linear projection from latent channels to RGB.
            self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
            if self.latent_rgb_factors_bias is not None:
                self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)
            latent_image = F.linear(x0.movedim(1, -1), self.latent_rgb_factors,
                                    bias=self.latent_rgb_factors_bias)
            latent_image = (latent_image + 1.0) / 2.0  # map roughly [-1, 1] -> [0, 1]
            return latent_image
def prepare_callback(model, steps, x0_output_dict=None, shape=None, latent_upscale_model=None, vae=None, rate=8, taeltx=False, num_keyframes=0):
    """Build a sampling step callback that renders LTXV latent previews.

    Args:
        model: model patcher (not used inside this function).
        steps: total step count for the progress bar.
        x0_output_dict: NOTE(review): accepted but never written to here —
            confirm whether callers still rely on it.
        shape: optional latent shape used to un-flatten a packed x0 tensor.
        latent_upscale_model: optional upscaler applied to x0 before preview.
        vae: VAE providing normalization stats (and the decoder when taeltx).
        rate: preview frames per second.
        taeltx: if True, `vae` is a TAE-style decoder used directly.
        num_keyframes: trailing keyframe latents to strip from previews.
    """
    # RGB projection factors for the LTXV latent channels, used by the
    # fast (non-TAE) preview path.
    latent_rgb_factors = [
        [ 0.0350, 0.0159, 0.0132],
        [ 0.0025, -0.0021, -0.0003],
        [ 0.0286, 0.0028, 0.0020],
        [ 0.0280, -0.0114, -0.0202],
        [-0.0186, 0.0073, 0.0092],
        [ 0.0027, 0.0097, -0.0113],
        [-0.0069, -0.0032, -0.0024],
        [-0.0323, -0.0370, -0.0457],
        [ 0.0174, 0.0164, 0.0106],
        [-0.0097, 0.0061, 0.0035],
        [-0.0130, -0.0042, -0.0012],
        [-0.0102, -0.0002, -0.0091],
        [-0.0025, 0.0063, 0.0161],
        [ 0.0003, 0.0037, 0.0108],
        [ 0.0152, 0.0082, 0.0143],
        [ 0.0317, 0.0203, 0.0312],
        [-0.0092, -0.0233, -0.0119],
        [-0.0405, -0.0226, -0.0023],
        [ 0.0376, 0.0397, 0.0352],
        [ 0.0171, -0.0043, -0.0095],
        [ 0.0482, 0.0341, 0.0213],
        [ 0.0031, -0.0046, -0.0018],
        [-0.0486, -0.0383, -0.0294],
        [-0.0071, -0.0272, -0.0123],
        [ 0.0320, 0.0218, 0.0289],
        [ 0.0327, 0.0088, -0.0116],
        [-0.0098, -0.0240, -0.0111],
        [ 0.0094, -0.0116, 0.0021],
        [ 0.0309, 0.0092, 0.0165],
        [-0.0065, -0.0077, -0.0107],
        [ 0.0179, 0.0114, 0.0038],
        [-0.0018, -0.0030, -0.0026],
        [-0.0002, 0.0076, -0.0029],
        [-0.0131, -0.0059, -0.0170],
        [ 0.0055, 0.0066, -0.0038],
        [ 0.0154, 0.0063, 0.0090],
        [ 0.0186, 0.0175, 0.0188],
        [-0.0166, -0.0381, -0.0428],
        [ 0.0121, 0.0015, -0.0153],
        [ 0.0118, 0.0050, 0.0019],
        [ 0.0125, 0.0259, 0.0231],
        [ 0.0046, 0.0130, 0.0081],
        [ 0.0271, 0.0250, 0.0250],
        [-0.0054, -0.0347, -0.0326],
        [-0.0438, -0.0262, -0.0228],
        [-0.0191, -0.0256, -0.0173],
        [-0.0205, -0.0058, 0.0042],
        [ 0.0404, 0.0434, 0.0346],
        [-0.0242, -0.0177, -0.0146],
        [ 0.0161, 0.0223, 0.0168],
        [-0.0240, -0.0320, -0.0299],
        [-0.0019, 0.0043, 0.0008],
        [-0.0060, -0.0133, -0.0244],
        [-0.0048, -0.0225, -0.0167],
        [ 0.0267, 0.0133, 0.0152],
        [ 0.0222, 0.0167, 0.0028],
        [ 0.0015, -0.0062, 0.0013],
        [-0.0241, -0.0178, -0.0079],
        [ 0.0040, -0.0081, -0.0097],
        [-0.0064, 0.0133, -0.0011],
        [-0.0204, -0.0231, -0.0304],
        [ 0.0011, -0.0011, 0.0145],
        [-0.0283, -0.0259, -0.0260],
        [ 0.0038, 0.0171, -0.0029],
        [ 0.0637, 0.0424, 0.0409],
        [ 0.0092, 0.0163, 0.0188],
        [ 0.0082, 0.0055, -0.0179],
        [-0.0177, -0.0286, -0.0147],
        [ 0.0171, 0.0242, 0.0398],
        [-0.0129, 0.0095, -0.0071],
        [-0.0154, 0.0036, 0.0128],
        [-0.0081, -0.0009, 0.0118],
        [-0.0067, -0.0178, -0.0230],
        [-0.0022, -0.0125, -0.0003],
        [-0.0032, -0.0039, -0.0022],
        [-0.0005, -0.0127, -0.0131],
        [-0.0143, -0.0157, -0.0165],
        [-0.0262, -0.0263, -0.0270],
        [ 0.0063, 0.0127, 0.0178],
        [ 0.0092, 0.0133, 0.0150],
        [-0.0106, -0.0068, 0.0032],
        [-0.0214, -0.0022, 0.0171],
        [-0.0104, -0.0266, -0.0362],
        [ 0.0021, 0.0048, -0.0005],
        [ 0.0345, 0.0431, 0.0402],
        [-0.0275, -0.0110, -0.0195],
        [ 0.0203, 0.0251, 0.0224],
        [ 0.0016, -0.0037, -0.0094],
        [ 0.0241, 0.0198, 0.0114],
        [-0.0003, 0.0027, 0.0141],
        [ 0.0012, -0.0052, -0.0084],
        [ 0.0057, -0.0028, -0.0163],
        [-0.0488, -0.0545, -0.0509],
        [-0.0076, -0.0025, -0.0014],
        [-0.0249, -0.0142, -0.0367],
        [ 0.0136, 0.0041, 0.0135],
        [ 0.0007, 0.0034, -0.0053],
        [-0.0068, -0.0109, 0.0029],
        [ 0.0006, -0.0237, -0.0094],
        [-0.0149, -0.0177, -0.0131],
        [-0.0105, 0.0039, 0.0216],
        [ 0.0242, 0.0200, 0.0180],
        [-0.0339, -0.0153, -0.0195],
        [ 0.0104, 0.0151, 0.0120],
        [-0.0043, 0.0089, 0.0047],
        [ 0.0157, -0.0030, 0.0008],
        [ 0.0126, 0.0102, -0.0040],
        [ 0.0040, 0.0114, 0.0137],
        [ 0.0423, 0.0473, 0.0436],
        [-0.0128, -0.0066, -0.0152],
        [-0.0337, -0.0087, -0.0026],
        [-0.0052, 0.0235, 0.0291],
        [ 0.0079, 0.0154, 0.0260],
        [-0.0539, -0.0377, -0.0358],
        [-0.0188, 0.0062, -0.0035],
        [-0.0186, 0.0041, -0.0083],
        [ 0.0045, -0.0049, 0.0053],
        [ 0.0172, 0.0071, 0.0042],
        [-0.0003, -0.0078, -0.0096],
        [-0.0209, -0.0132, -0.0135],
        [-0.0074, 0.0017, 0.0099],
        [-0.0038, 0.0070, 0.0014],
        [-0.0013, -0.0017, 0.0073],
        [ 0.0030, 0.0105, 0.0105],
        [ 0.0154, -0.0168, -0.0235],
        [-0.0108, -0.0038, 0.0047],
        [-0.0298, -0.0347, -0.0436],
        [-0.0206, -0.0189, -0.0139]
    ]
    latent_rgb_factors_bias = [0.2796, 0.1101, -0.0047]
    preview_format = "JPEG"
    if preview_format not in ["JPEG", "PNG"]:
        preview_format = "JPEG"
    previewer = WrappedPreviewer(latent_rgb_factors, latent_rgb_factors_bias, rate=rate, taeltx=vae if taeltx else None)
    pbar = comfy.utils.ProgressBar(steps)
    def callback(step, x0, x, total_steps):
        if x0 is not None and shape is not None:
            # x0 arrives flattened along dim 2; keep only the portion
            # belonging to this latent and restore its spatial shape.
            cut = math.prod(shape[1:])
            x0 = x0[:, :, :cut].reshape([x0.shape[0]] + list(shape)[1:])
        if num_keyframes > 0:
            # Drop trailing keyframe latents so they don't show in previews.
            x0 = x0[:, :, :-num_keyframes]
        if latent_upscale_model is not None:
            # Upscale in un-normalized latent space, then re-normalize.
            x0 = vae.first_stage_model.per_channel_statistics.un_normalize(x0)
            x0 = latent_upscale_model(x0.to(torch.bfloat16))
            x0 = vae.first_stage_model.per_channel_statistics.normalize(x0)
        preview_bytes = None
        if previewer:
            preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
        pbar.update_absolute(step + 1, total_steps, preview_bytes)
    return callback
class OuterSampleCallbackWrapper:
    """OUTER_SAMPLE wrapper that injects the LTXV preview callback into the
    sampling loop while preserving any existing callback."""
    def __init__(self, latent_upscale_model=None, vae=None, preview_rate=8, taeltx=False):
        self.latent_upscale_model = latent_upscale_model  # optional latent upscaler for higher-res previews
        self.vae = vae
        self.preview_rate = preview_rate
        self.taeltx = taeltx  # True when `vae` is a TAE-style decoder
        self.x0_output = {}

    def __call__(self, executor, noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes):
        guider = executor.class_obj
        original_callback = callback
        # Move helper models onto the compute device for the sampling run.
        if self.latent_upscale_model is not None:
            self.latent_upscale_model.to(device)
        if self.vae is not None and self.taeltx:
            self.vae.first_stage_model.to(device)
        # Count unique keyframe indices in the positive conditioning so the
        # preview can strip the appended keyframe latents.
        num_keyframes = 0
        if 'positive' in guider.conds and len(guider.conds['positive']) > 0:
            keyframe_idxs = guider.conds['positive'][0].get('keyframe_idxs')
            if keyframe_idxs is not None:
                num_keyframes = len(torch.unique(keyframe_idxs[0, 0, :, 0]))
        # NOTE(review): shape is only forwarded when more than one latent
        # shape is present (presumably the packed audio+video case) — confirm.
        new_callback = prepare_callback(guider.model_patcher, len(sigmas) -1, shape=latent_shapes[0] if len(latent_shapes) > 1 else None,
                                        x0_output_dict=self.x0_output, latent_upscale_model=self.latent_upscale_model, vae=self.vae, rate=self.preview_rate, taeltx=self.taeltx, num_keyframes=num_keyframes)
        # Wrapper that calls both callbacks
        def combined_callback(step, x0, x, total_steps):
            new_callback(step, x0, x, total_steps)
            if original_callback is not None:
                original_callback(step, x0, x, total_steps)
        out = executor(noise, latent_image, sampler, sigmas, denoise_mask, combined_callback, disable_pbar, seed, latent_shapes=latent_shapes)
        # Offload the upscale model again once sampling finishes.
        if self.latent_upscale_model is not None:
            self.latent_upscale_model.to(mm.unet_offload_device())
        return out
class LTX2SamplingPreviewOverride(io.ComfyNode):
    """Attaches an OUTER_SAMPLE wrapper that renders latent previews during
    LTX2 sampling (stop-gap until previews land in comfy core)."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="LTX2SamplingPreviewOverride",
            display_name="LTX2 Sampling Preview Override",
            description="Overrides the LTX2 preview sampling preview function, temporary measure until previews are in comfy core",
            category="KJNodes/experimental",
            is_experimental=True,
            inputs=[
                io.Model.Input("model", tooltip="The model to add preview override to."),
                io.Int.Input("preview_rate", default=8, min=1, max=60, step=1, tooltip="Preview frame rate."),
                io.LatentUpscaleModel.Input("latent_upscale_model", optional=True, tooltip="Optional upscale model to use for higher resolution previews."),
                io.Vae.Input("vae", optional=True, tooltip="VAE model to use normalizing the latents for the upscale model."),
            ],
            outputs=[
                io.Model.Output(tooltip="The model with Sampling Preview Override."),
            ],
        )

    @classmethod
    def execute(cls, model, preview_rate, latent_upscale_model=None, vae=None) -> io.NodeOutput:
        """Clone the model and register the preview OUTER_SAMPLE wrapper."""
        patched = model.clone()
        # A TAEHV VAE is a tiny autoencoder that can decode previews directly,
        # so a separate latent upscale model is unnecessary in that case.
        use_taeltx = vae is not None and vae.first_stage_model.__class__.__name__ == "TAEHV"
        if use_taeltx:
            latent_upscale_model = None
        wrapper = OuterSampleCallbackWrapper(latent_upscale_model, vae, preview_rate, use_taeltx)
        patched.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "sampling_preview", wrapper)
        return io.NodeOutput(patched)
# based on https://github.com/Lightricks/ComfyUI-LTXVideo/blob/cd5d371518afb07d6b3641be8012f644f25269fc/easy_samplers.py#L916
class OuterSampleAudioNormalizationWrapper:
    """OUTER_SAMPLE wrapper that renormalizes LTX2 audio latents mid-sampling.

    Splits the sigma schedule at every step whose normalization factor is not
    1.0, runs the sampler chunk by chunk, and between chunks scales the audio
    portion of the combined audio/video latent by the configured factor.
    """

    def __init__(self, audio_normalization_factors):
        # Comma-separated string of per-step factors; parsed in __call__.
        self.audio_normalization_factors = audio_normalization_factors

    def __call__(self, executor, noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes):
        guider = executor.class_obj
        ltxav = guider.model_patcher.model.diffusion_model
        x0_output = {}
        self.total_steps = sigmas.shape[-1] - 1
        pbar = comfy.utils.ProgressBar(self.total_steps)
        # Running step counter across all sigma chunks so the progress bar and
        # previews stay continuous even though sampling is restarted per chunk.
        self.full_step = 0
        previewer = latent_preview.get_previewer(guider.model_patcher.load_device, guider.model_patcher.model.latent_format)

        def custom_callback(step, x0, x, total_steps):
            # Capture the latest x0 prediction so it can be normalized after
            # the current chunk finishes. (x0_output is a local dict, so the
            # guard below always holds.)
            if x0_output is not None:
                x0_output["x0"] = x0
            preview_bytes = None
            if previewer:
                preview_bytes = previewer.decode_latent_to_preview_image("JPEG", x0)
            self.full_step += 1
            pbar.update_absolute(self.full_step, self.total_steps, preview_bytes)

        # Replace the incoming callback entirely with the chunk-aware one.
        callback = custom_callback
        audio_normalization_factors = self.audio_normalization_factors.strip().split(",")
        audio_normalization_factors = [float(factor) for factor in audio_normalization_factors]
        # Extend normalization factors to the number of sampling steps by
        # repeating the last given factor.
        sigmas_len = self.total_steps
        if len(audio_normalization_factors) < sigmas_len and len(audio_normalization_factors) > 0:
            audio_normalization_factors.extend([audio_normalization_factors[-1]] * (sigmas_len - len(audio_normalization_factors)))
        # 1-based step boundaries where the normalization factor is not 1.0.
        sampling_split_indices = [i + 1 for i, a in enumerate(audio_normalization_factors) if a != 1.0]

        # Split sigmas according to sampling_split_indices
        def split_by_indices(arr, indices):
            """
            Splits arr into chunks according to indices (split points).
            Indices are treated as starting a new chunk at each index in the list.
            Adjacent chunks share their boundary element (arr[prev:idx + 1]),
            so every chunk keeps a valid starting sigma.
            """
            if not indices:
                return [arr]
            split_points = sorted(set(indices))
            chunks = []
            prev = 0
            for idx in split_points:
                if prev < idx:
                    chunks.append(arr[prev : idx + 1])
                prev = idx
            if prev < len(arr):
                chunks.append(arr[prev:])
            return chunks

        sigmas_chunks = split_by_indices(sigmas, sampling_split_indices)
        i = 0  # cumulative number of completed sampling steps
        for sigmas_chunk in sigmas_chunks:
            i += len(sigmas_chunk) - 1
            latent_image = executor(noise, latent_image, sampler, sigmas_chunk, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
            # Continue from the denoised prediction captured by the callback.
            if "x0" in x0_output:
                latent_image = guider.model_patcher.model.process_latent_out(x0_output["x0"])
            if i - 1 < len(audio_normalization_factors):
                # Separate the audio latents, scale them, then recombine.
                vx, ax = ltxav.separate_audio_and_video_latents(comfy.utils.unpack_latents(latent_image, latent_shapes), None)
                if denoise_mask is not None:
                    # Only normalize the masked (denoised) audio region; keep
                    # the unmasked region untouched.
                    audio_mask = ltxav.separate_audio_and_video_latents(comfy.utils.unpack_latents(denoise_mask, latent_shapes), None)[1]
                    ax = ax * audio_mask * audio_normalization_factors[i - 1] + ax * (1 - audio_mask)
                else:
                    ax = ax * audio_normalization_factors[i - 1]
                latent_image = comfy.utils.pack_latents(ltxav.recombine_audio_and_video_latents(vx, ax))[0]
                print("After %d steps, the audio latent was normalized by %f" % (i, audio_normalization_factors[i - 1]))
        return latent_image
class LTX2AudioLatentNormalizingSampling(io.ComfyNode):
    """Node that wraps outer sampling with audio-latent normalization steps."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="LTX2AudioLatentNormalizingSampling",
            display_name="LTX2 Audio Latent Normalizing Sampling",
            description="Improves LTX2 generated audio quality by normalizing audio latents at specified sampling steps.",
            category="KJNodes/experimental",
            is_experimental=True,
            inputs=[
                io.Model.Input("model", tooltip="The model to add preview override to."),
                io.String.Input("audio_normalization_factors", default="1,1,0.25,1,1,0.25,1,1", tooltip="Comma-separated list of audio normalization factors to apply at each sampling step. For example, '1,1,0.25,1,1,0.25,1,1' will apply a factor of 0.25 at the 3rd and 6th steps."),
            ],
            outputs=[
                io.Model.Output(tooltip="The model with Audio Latent Normalizing Sampling."),
            ],
        )

    @classmethod
    def execute(cls, model, audio_normalization_factors) -> io.NodeOutput:
        patched = model.clone()
        normalizer = OuterSampleAudioNormalizationWrapper(audio_normalization_factors)
        patched.add_wrapper_with_key(
            comfy.patcher_extension.WrappersMP.OUTER_SAMPLE,
            "ltx2_audio_normalization",
            normalizer,
        )
        return io.NodeOutput(patched)
class LTXVImgToVideoInplaceKJ(io.ComfyNode):
    """Inserts VAE-encoded reference images into a video latent at given frames.

    Each selected image is encoded with the video VAE and written over the
    corresponding latent frames; the noise mask is updated so those frames
    are (partially) protected from re-noising according to the per-image
    strength (mask value = 1.0 - strength).
    """

    @classmethod
    def define_schema(cls):
        options = []
        for num_images in range(1, 21):  # 1 to 20 images
            image_inputs = []
            for i in range(1, num_images + 1):
                image_inputs.extend([
                    io.Image.Input(f"image_{i}", optional=True, tooltip=f"Image {i} to insert into the video latent."),
                    io.Int.Input(
                        f"index_{i}",
                        default=0,
                        min=-9999,
                        max=9999,
                        step=1,
                        tooltip=f"Frame index for image {i} (in pixel space).",
                        optional=True,
                    ),
                    io.Float.Input(f"strength_{i}", default=1.0, min=0.0, max=1.0, step=0.01, tooltip=f"Strength for image {i}."),
                ])
            options.append(io.DynamicCombo.Option(
                key=str(num_images),
                inputs=image_inputs
            ))
        return io.Schema(
            node_id="LTXVImgToVideoInplaceKJ",
            category="KJNodes/ltxv",
            description="Replaces video latent frames with the encoded input images, uses DynamicCombo which requires ComfyUI 0.8.1 and frontend 1.33.4 or later.",
            inputs=[
                io.Vae.Input("vae", tooltip="Video VAE used to encode the images"),
                io.Latent.Input("latent", tooltip="Video latent to insert images into"),
                io.DynamicCombo.Input(
                    "num_images",
                    options=options,
                    display_name="Number of Images",
                    tooltip="Select how many images to insert",
                ),
            ],
            outputs=[
                io.Latent.Output(display_name="latent", tooltip="The video latent with the images inserted and latent noise mask updated."),
            ],
        )

    @classmethod
    def execute(cls, vae, latent, num_images) -> io.NodeOutput:
        samples = latent["samples"].clone()
        # (time, height, width) downscale factors between pixel and latent space.
        scale_factors = vae.downscale_index_formula
        time_scale_factor, height_scale_factor, width_scale_factor = scale_factors
        batch, _, latent_frames, latent_height, latent_width = samples.shape
        width = latent_width * width_scale_factor
        height = latent_height * height_scale_factor

        # Get existing noise mask if present, otherwise create a new one.
        if "noise_mask" in latent:
            conditioning_latent_frames_mask = latent["noise_mask"].clone()
        else:
            conditioning_latent_frames_mask = torch.ones(
                (batch, 1, latent_frames, 1, 1),
                dtype=torch.float32,
                device=samples.device,
            )

        # num_images is a dict containing the inputs from the selected option,
        # e.g. {'image_1': tensor, 'index_1': 0, 'strength_1': 1.0, ...}.
        # BUGFIX: sort keys numerically, not lexicographically -- plain
        # sorted() would process image_10 before image_2, which changes the
        # overwrite order when frame ranges overlap.
        image_keys = sorted(
            (k for k in num_images if k.startswith('image_')),
            key=lambda k: int(k.split('_')[1]),
        )
        for img_key in image_keys:
            i = img_key.split('_')[1]
            image = num_images[f"image_{i}"]
            if image is None:
                continue
            index = num_images.get(f"index_{i}")
            if index is None:
                continue
            strength = num_images[f"strength_{i}"]
            # Resize to the latent's pixel resolution when needed.
            if image.shape[1] != height or image.shape[2] != width:
                pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
            else:
                pixels = image
            encode_pixels = pixels[:, :, :, :3]
            t = vae.encode(encode_pixels)
            # Handle negative indexing in pixel space.
            pixel_frame_count = (latent_frames - 1) * time_scale_factor + 1
            if index < 0:
                index = pixel_frame_count + index
            # Convert the pixel frame index to a latent index and clamp it.
            latent_idx = index // time_scale_factor
            latent_idx = max(0, min(latent_idx, latent_frames - 1))
            # End of the replaced range, never exceeding the latent length.
            end_index = min(latent_idx + t.shape[2], latent_frames)
            # Write the encoded frames and mark them as conditioning frames.
            samples[:, :, latent_idx:end_index] = t[:, :, :end_index - latent_idx]
            conditioning_latent_frames_mask[:, :, latent_idx:end_index] = 1.0 - strength
        return io.NodeOutput({"samples": samples, "noise_mask": conditioning_latent_frames_mask})
def ltx2_forward(
    self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
    v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
    v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Memory-lean LTX2 transformer block forward with per-modality scaling.

    Runs the (video, audio) latent pair through self-attention, text
    cross-attention, audio<->video cross-attention and feed-forward stages,
    multiplying each stage's residual contribution by scale attributes set
    on the block (``video_scale``, ``audio_scale``, ``audio_to_video_scale``,
    ``video_to_audio_scale``; all default to 1.0 when unpatched).
    Intermediates are deleted eagerly to reduce peak VRAM.

    ``transformer_options`` flags ``run_vx`` / ``run_ax`` /
    ``a2v_cross_attn`` / ``v2a_cross_attn`` can disable whole branches.
    Returns the updated ``(vx, ax)`` tuple; both tensors are updated via
    in-place ops, so callers should not rely on the inputs being unchanged.
    """
    run_vx = transformer_options.get("run_vx", True)
    run_ax = transformer_options.get("run_ax", True)
    # Scales injected by LTX2ForwardPatch; 1.0 when the block is unpatched.
    video_scale = getattr(self, 'video_scale', 1.0)
    audio_scale = getattr(self, 'audio_scale', 1.0)
    audio_to_video_scale = getattr(self, 'audio_to_video_scale', 1.0)
    video_to_audio_scale = getattr(self, 'video_to_audio_scale', 1.0)
    vx, ax = x
    # Skip branches that would be no-ops (no audio tokens or zero scale).
    run_ax = run_ax and ax.numel() > 0 and audio_scale != 0.0
    run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0 and audio_to_video_scale != 0.0
    run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True) and video_to_audio_scale != 0.0
    # Video self-attention.
    if run_vx:
        # adaLN shift/scale for the pre-attention norm.
        vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2)))
        norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
        del vshift_msa, vscale_msa
        attn1_out = self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options)
        del norm_vx
        # Gate is fetched separately (after attention) to keep peak memory low.
        vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
        vx += attn1_out * vgate_msa * video_scale
        del vgate_msa, attn1_out
        # Text cross-attention; scaled via the in-place add's alpha.
        vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options), alpha=video_scale)
    # Audio self-attention.
    if run_ax:
        ashift_msa, ascale_msa = (self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 2)))
        norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa
        del ashift_msa, ascale_msa
        attn1_out = self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
        del norm_ax
        agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0]
        ax += attn1_out * agate_msa * audio_scale
        del agate_msa, attn1_out
        ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options), alpha=audio_scale)
    # Audio - Video cross attention.
    if run_a2v or run_v2a:
        # Shared norms for both cross-attention directions.
        vx_norm3 = comfy.ldm.common_dit.rms_norm(vx)
        ax_norm3 = comfy.ldm.common_dit.rms_norm(ax)
        # audio to video cross attention
        if run_a2v:
            scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v = self.get_ada_values(
                self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep, slice(0, 2))
            vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_a2v_v) + shift_ca_video_hidden_states_a2v_v
            del scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v
            scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v = self.get_ada_values(
                self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep, slice(0, 2))
            ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + shift_ca_audio_hidden_states_a2v
            del scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v
            # Video queries attend to audio keys/values.
            a2v_out = self.audio_to_video_attn(vx_scaled, context=ax_scaled, pe=v_cross_pe, k_pe=a_cross_pe, transformer_options=transformer_options)
            del vx_scaled, ax_scaled
            gate_out_a2v = self.get_ada_values(self.scale_shift_table_a2v_ca_video[4:, :], vx.shape[0], v_cross_gate_timestep, slice(0, 1))[0]
            vx += a2v_out * gate_out_a2v * audio_to_video_scale
            del gate_out_a2v, a2v_out
        # video to audio cross attention
        if run_v2a:
            scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a = self.get_ada_values(
                self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep, slice(2, 4))
            vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + shift_ca_video_hidden_states_v2a
            del scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a, vx_norm3
            scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a = self.get_ada_values(
                self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep, slice(2, 4))
            ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + shift_ca_audio_hidden_states_v2a
            del scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a, ax_norm3
            # Audio queries attend to video keys/values.
            v2a_out = self.video_to_audio_attn(ax_scaled, context=vx_scaled, pe=a_cross_pe, k_pe=v_cross_pe, transformer_options=transformer_options)
            del ax_scaled, vx_scaled
            gate_out_v2a = self.get_ada_values(self.scale_shift_table_a2v_ca_audio[4:, :], ax.shape[0], a_cross_gate_timestep, slice(0, 1))[0]
            ax += v2a_out * gate_out_v2a * video_to_audio_scale
            del gate_out_v2a, v2a_out
    # video feedforward
    if run_vx:
        vshift_mlp, vscale_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, 5))
        vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp
        del vshift_mlp, vscale_mlp
        ff_out = self.ff(vx_scaled)
        del vx_scaled
        vgate_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(5, 6))[0]
        vx += ff_out * vgate_mlp * video_scale
        del vgate_mlp, ff_out
    # audio feedforward
    if run_ax:
        ashift_mlp, ascale_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, 5))
        ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp
        del ashift_mlp, ascale_mlp
        ff_out = self.audio_ff(ax_scaled)
        del ax_scaled
        agate_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(5, 6))[0]
        ax += ff_out * agate_mlp * audio_scale
        del agate_mlp, ff_out
    return vx, ax
class LTX2ForwardPatch:
    """Descriptor that produces a bound forward with per-modality scales.

    When accessed on a transformer block instance it returns a bound method
    that first copies the configured attention scales onto the block and then
    delegates to `ltx2_forward`.
    """

    def __init__(self, video, audio, audio_to_video, video_to_audio):
        self.video_scale = video
        self.audio_scale = audio
        self.video_to_audio_scale = video_to_audio
        self.audio_to_video_scale = audio_to_video

    def __get__(self, obj, objtype=None):
        scales = (
            self.video_scale,
            self.audio_scale,
            self.video_to_audio_scale,
            self.audio_to_video_scale,
        )

        def patched_forward(module, *args, **kwargs):
            # Stash the scales on the module so ltx2_forward can read them.
            (module.video_scale,
             module.audio_scale,
             module.video_to_audio_scale,
             module.audio_to_video_scale) = scales
            return ltx2_forward(module, *args, **kwargs)

        return types.MethodType(patched_forward, obj)
class LTX2AttentionTunerPatch(io.ComfyNode):
    """Patches LTX2 transformer blocks with per-modality attention scaling."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LTX2AttentionTunerPatch",
            display_name="LTX2 Attention Tuner Patch",
            category="KJNodes/ltxv",
            description="EXPERIMENTAL! Custom LTX2 forward pass with attention scaling factors per modality, also reduces peak VRAM usage.",
            is_experimental=True,
            inputs=[
                io.Model.Input("model"),
                io.String.Input("blocks", default="", tooltip="Comma separated list of transformer block indices to apply the patch to. Leave empty to apply to all blocks."),
                io.Float.Input("video_scale", default=1.0, min=0.0, max=100, step=0.01, tooltip="Scaling factor for video attention."),
                io.Float.Input("audio_scale", default=1.0, min=0.0, max=100, step=0.01, tooltip="Scaling factor for audio attention."),
                # Fixed copy-pasted tooltips: these two describe the cross-attention scales.
                io.Float.Input("audio_to_video_scale", default=1.0, min=0.0, max=100, step=0.01, tooltip="Scaling factor for audio to video cross-attention."),
                io.Float.Input("video_to_audio_scale", default=1.0, min=0.0, max=100, step=0.01, tooltip="Scaling factor for video to audio cross-attention."),
            ],
            outputs=[
                io.Model.Output(display_name="model"),
            ],
        )

    @classmethod
    def execute(cls, model, blocks, video_scale, audio_scale, audio_to_video_scale, video_to_audio_scale) -> io.NodeOutput:
        model_clone = model.clone()
        diffusion_model = model_clone.get_model_object("diffusion_model")
        num_blocks = len(diffusion_model.transformer_blocks)
        # Parse selected block indices; empty string selects every block.
        # Empty segments are ignored so trailing commas / stray spaces don't crash.
        if blocks.strip() == "":
            selected_blocks = set(range(num_blocks))
        else:
            selected_blocks = {int(idx) for idx in blocks.split(",") if idx.strip()}
        logging.info(f"Applying LTX2 Attention Tuner Patch with custom scales to blocks: {sorted(selected_blocks)}")
        # Patch every block; non-selected blocks get neutral (1.0) scales so
        # the memory-efficient forward is used uniformly.
        for idx in range(num_blocks):
            block = diffusion_model.transformer_blocks[idx]
            if idx in selected_blocks:
                patched_forward = LTX2ForwardPatch(video_scale, audio_scale, audio_to_video_scale, video_to_audio_scale).__get__(block, block.__class__)
            else:
                patched_forward = LTX2ForwardPatch(1.0, 1.0, 1.0, 1.0).__get__(block, block.__class__)
            model_clone.add_object_patch(f"diffusion_model.transformer_blocks.{idx}.forward", patched_forward)
        return io.NodeOutput(model_clone)
class LTX2MemoryEfficientSageAttentionPatch(io.ComfyNode):
    """Rebinds every block's self-attention to the SageAttention forward."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LTX2MemoryEfficientSageAttentionPatch",
            display_name="LTX2 Mem Eff Sage Attention Patch",
            category="KJNodes/ltxv",
            description="EXPERIMENTAL! Activates custom sageattention to reduce peak VRAM usage, overrides the attention mode. Requires latest sageattention version.",
            is_experimental=True,
            inputs=[
                io.Model.Input("model"),
            ],
            outputs=[
                io.Model.Output(display_name="model"),
            ],
        )

    @classmethod
    def execute(cls, model) -> io.NodeOutput:
        # _cuda_archs is only populated when sageattention imported successfully.
        if _cuda_archs is None:
            raise RuntimeError("sageattention is not new enough version or could not determine CUDA architecture, cannot apply LTX2 Memory Efficient Sage Attention Patch.")
        patched = model.clone()
        dit = patched.get_model_object("diffusion_model")
        logging.info("Applying LTX2 Memory Efficient Sage Attention Patch to all transformer blocks")
        # Bind ltx2_sageattn_forward to each block's attn1 module and register
        # it as an object patch so it survives model cloning/offloading.
        for idx, blk in enumerate(dit.transformer_blocks):
            patched.add_object_patch(
                f"diffusion_model.transformer_blocks.{idx}.attn1.forward",
                ltx2_sageattn_forward.__get__(blk.attn1, blk.attn1.__class__),
            )
        return io.NodeOutput(patched)
def get_cuda_version():
    """Return the CUDA toolkit version torch was built with as (major, minor).

    Falls back to (0, 0) when torch is a CPU-only build or the version
    string cannot be parsed.
    """
    try:
        cuda_str = torch.version.cuda
        if cuda_str is None:
            return 0, 0
        major_str, minor_str = cuda_str.split('.')
        return int(major_str), int(minor_str)
    except Exception:
        return 0, 0
# Optional SageAttention support: probe for the quantized attention kernels.
# Every import is best-effort -- when sageattention (or a new enough version
# of it) is missing, the module handles stay None / flags stay False and the
# patch node raises a clear error at execution time instead.
sageplus_sm89_available = False
_cuda_archs = None
try:
    from sageattention.core import per_thread_int8_triton, per_warp_int8_cuda, per_block_int8_triton, per_channel_fp8, get_cuda_arch_versions, attn_false
    _cuda_archs = get_cuda_arch_versions()
except Exception:
    # Narrowed from a bare `except:`: still swallow any probe failure
    # (ImportError, CUDA init errors from get_cuda_arch_versions), but let
    # KeyboardInterrupt / SystemExit propagate.
    pass
# sm89 kernels: the "plus" path needs the fused fp16-accum kernel and CUDA >= 12.8.
try:
    from sageattention.core import _qattn_sm89
    cuda_version = get_cuda_version()
    sageplus_sm89_available = hasattr(_qattn_sm89, 'qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf') and cuda_version >= (12, 8)
except ImportError:
    try:
        from sageattention.core import sm89_compile as _qattn_sm89
    except ImportError:
        _qattn_sm89 = None
# sm80 kernels (compiled or JIT fallback).
try:
    from sageattention.core import _qattn_sm80
except ImportError:
    try:
        from sageattention.core import sm80_compile as _qattn_sm80
    except ImportError:
        _qattn_sm80 = None
# sm90 kernels (compiled or JIT fallback).
try:
    from sageattention.core import _qattn_sm90
except ImportError:
    try:
        from sageattention.core import sm90_compile as _qattn_sm90
    except ImportError:
        _qattn_sm90 = None
from comfy.ldm.lightricks.model import apply_rotary_emb
def ltx2_sageattn_forward(self, x, context=None, mask=None, pe=None, k_pe=None, transformer_options={}):
    """Attention forward using SageAttention INT8/FP8 quantized kernels.

    Drop-in replacement for the LTX2 attention module's ``forward`` that
    quantizes Q/K (and V on fp8-capable architectures) to reduce peak VRAM.
    Dispatches on the CUDA arch detected at import time
    (sm75 / sm80 / sm86 / sm89 / sm90 / sm120).

    NOTE(review): ``mask`` is accepted but not used by any quantized path;
    an architecture outside the handled set leaves ``o`` unbound and would
    raise NameError at the final ``del``/return -- confirm intended coverage.
    """
    dtype = x.dtype
    # Self-attention unless an explicit cross-attention context is given.
    context = x if context is None else context
    # query
    q = self.to_q(x)
    q = self.q_norm(q)
    if pe is not None:
        q = apply_rotary_emb(q, pe)
    # key
    k = self.to_k(context)
    k = self.k_norm(k)
    if pe is not None:
        # Keys may carry a separate rotary embedding (cross-attention case).
        k = apply_rotary_emb(k, pe if k_pe is None else k_pe)
    # value
    v = self.to_v(context)
    # Reshape from [batch, seq_len, total_dim] to [batch, seq_len, num_heads, head_dim]
    batch_size, seq_len, _ = q.shape
    head_dim_og = self.dim_head
    q = q.view(batch_size, seq_len, self.heads, head_dim_og)
    k = k.view(batch_size, k.shape[1], self.heads, head_dim_og)
    v = v.view(batch_size, v.shape[1], self.heads, head_dim_og)
    # Kernel parameters shared by all architecture branches.
    tensor_layout="NHD"
    _tensor_layout = 0 # NHD
    _is_caual = 0  # non-causal attention (local name typo preserved)
    _qk_quant_gran = 3  # per-thread QK quantization granularity
    _return_lse = 0
    sm_scale = head_dim_og**-0.5
    quant_v_scale_max = 448.0  # fp8 e4m3 max; lowered on fp16-accum paths below
    if _cuda_archs[0] in {"sm80", "sm86"}:
        # Ampere: INT8 QK with fp16 V / fp32 accumulation. Inputs are deleted
        # as soon as their quantized copies exist to keep peak VRAM down.
        q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(q, k, km=k.mean(dim=1, keepdim=True), tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64)
        del q, k
        o = torch.empty(q_int8.size(), dtype=dtype, device=q_int8.device)
        v_fp16 = v.to(torch.float16).contiguous()
        del v
        _qattn_sm80.qk_int8_sv_f16_accum_f32_attn(q_int8, k_int8, v_fp16, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
    elif _cuda_archs[0] == "sm75":
        # Turing: per-block INT8 quantization with the generic attn_false kernel.
        q_int8, q_scale, k_int8, k_scale = per_block_int8_triton(q, k, km=k.mean(dim=1, keepdim=True), sm_scale=sm_scale, tensor_layout=tensor_layout)
        del q, k
        o, _ = attn_false(q_int8, k_int8, v, q_scale, k_scale, tensor_layout=tensor_layout, output_dtype=dtype, attn_mask=None, return_lse=False)
        del v
    elif _cuda_archs[0] == "sm89":
        # Ada: fp8 V; prefer fp16 inner accumulation when the newer fused
        # kernel is available (sageplus_sm89_available set at import time).
        if not sageplus_sm89_available:
            pv_accum_dtype = "fp32+fp32"
        else:
            pv_accum_dtype = "fp32+fp16"
            quant_v_scale_max = 2.25
        q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(q, k, km=k.mean(dim=1, keepdim=True), tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64)
        del q, k
        v_fp8, v_scale, _ = per_channel_fp8(v, tensor_layout=tensor_layout, scale_max=quant_v_scale_max, smooth_v=False)
        del v
        o = torch.empty(q_int8.size(), dtype=dtype, device=q_int8.device)
        if pv_accum_dtype == "fp32+fp16":
            _qattn_sm89.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
        elif pv_accum_dtype == "fp32+fp32":
            _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
        del v_fp8, v_scale
    elif _cuda_archs[0] == "sm90":
        # Hopper: different tiling; fp8 V with fp32 accumulation.
        q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(q, k, km=k.mean(dim=1, keepdim=True), tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128, WARPK=128)
        del q, k,
        v_fp8, v_scale, _ = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=False)
        del v
        o = torch.empty(q_int8.size(), dtype=dtype, device=q_int8.device)
        _qattn_sm90.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
        del v_fp8, v_scale
    elif _cuda_archs[0] == "sm120":
        # Blackwell: reuses the sm89 kernels with per-warp quantization.
        if not sageplus_sm89_available:
            pv_accum_dtype = "fp32"
        else:
            pv_accum_dtype = "fp32+fp16"
            quant_v_scale_max = 2.25
        _qk_quant_gran = 2 # per warp
        q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(q, k, km=k.mean(dim=1, keepdim=True), tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64)
        del q, k
        v_fp8, v_scale, _ = per_channel_fp8(v, tensor_layout=tensor_layout, scale_max=quant_v_scale_max, smooth_v=False)
        del v
        o = torch.empty(q_int8.size(), dtype=dtype, device=q_int8.device)
        if pv_accum_dtype == "fp32":
            _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
        elif pv_accum_dtype == "fp32+fp16":
            _qattn_sm89.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
        del v_fp8, v_scale
    del q_int8, q_scale, k_int8, k_scale
    # Merge heads back to [batch, seq_len, total_dim] and project out.
    return self.to_out(o.view(batch_size, seq_len, -1))
import folder_paths
class LTX2LoraLoaderAdvanced(io.ComfyNode):
    """Advanced LoRA loader with per-block and per-layer-type strength control."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LTX2LoraLoaderAdvanced",
            display_name="LTX2 LoRA Loader Advanced",
            category="KJNodes/ltxv",
            description="Advanced LoRA loader with per-block strength control for LTX2 models",
            is_experimental=True,
            inputs=[
                io.Combo.Input("lora_name", options=folder_paths.get_filename_list("loras"), tooltip="The name of the LoRA."),
                io.Model.Input("model", tooltip="The diffusion model the LoRA will be applied to."),
                io.Float.Input("strength_model", default=1.0, min=-100.0, max=100.0, step=0.01, tooltip="How strongly to modify the diffusion model. This value can be negative."),
                io.String.Input("opt_lora_path", optional=True, force_input=True,tooltip="Absolute path of the LoRA."),
                io.Custom("SELECTEDDITBLOCKS").Input("blocks", optional=True, tooltip="Selected DiT blocks configuration"),
                io.Float.Input("video", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="Strength for video attention layers."),
                io.Float.Input("video_to_audio", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="Strength for video to audio cross-attention layers."),
                io.Float.Input("audio", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="Strength for audio attention layers."),
                io.Float.Input("audio_to_video", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="Strength for audio to video cross-attention layers."),
                io.Float.Input("other", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="Strength for layers not caught by other layer filters."),
            ],
            outputs=[
                io.Model.Output(display_name="model", tooltip="The modified diffusion model."),
                io.String.Output(display_name="rank", tooltip="Possible rank of the LoRA."),
                io.String.Output(display_name="loaded_keys_info", tooltip="List of loaded keys and their alpha values."),
            ],
        )

    @classmethod
    def execute(cls, model, lora_name, strength_model, video, video_to_audio, audio, audio_to_video, other, opt_lora_path=None, blocks=None) -> io.NodeOutput:
        """Load a LoRA and apply it with per-block / per-layer-type strengths.

        Layer-type strengths scale the LoRA alpha (weights[2]) of matching
        keys; a strength of 0 removes the key entirely. Returns the patched
        model, the guessed rank, and a textual report of loaded keys.
        """
        from comfy.utils import load_torch_file
        import comfy.lora
        # An explicit path input overrides the combo selection.
        if opt_lora_path:
            lora_path = opt_lora_path
        else:
            lora_path = folder_paths.get_full_path("loras", lora_name)
        lora = load_torch_file(lora_path, safe_load=True)
        # Guess the rank from the first key that ends with "weight".
        rank = "unknown"
        weight_key = next((key for key in lora.keys() if key.endswith('weight')), None)
        # Print the shape of the value corresponding to the key
        if weight_key:
            print(f"Shape of the first 'weight' key ({weight_key}): {lora[weight_key].shape}")
            rank = str(lora[weight_key].shape[0])
        else:
            print("No key ending with 'weight' found.")
            rank = "Couldn't find rank"
        key_map = {}
        if model is not None:
            key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
        loaded = comfy.lora.load_lora(lora, key_map)
        keys_to_delete = []
        # First apply blocks filtering if provided
        if blocks is not None:
            for block in blocks:
                for key in list(loaded.keys()):
                    # Keys may be plain strings or tuples of strings.
                    match = False
                    if isinstance(key, str) and block in key:
                        match = True
                    elif isinstance(key, tuple):
                        for k in key:
                            if block in k:
                                match = True
                                break
                    if match:
                        ratio = blocks[block]
                        if ratio == 0:
                            keys_to_delete.append(key)
                        else:
                            # Only modify LoRA adapters, skip diff tuples
                            value = loaded[key]
                            if hasattr(value, 'weights'):
                                # weights[2] is the LoRA alpha.
                                weights_list = list(value.weights)
                                weights_list[2] = ratio
                                loaded[key].weights = tuple(weights_list)
        # Then apply layer-based attention strength filtering (takes priority)
        for key in list(loaded.keys()):
            if key in keys_to_delete:
                continue
            key_str = key if isinstance(key, str) else (key[0] if isinstance(key, tuple) else str(key))
            # Determine the strength multiplier based on layer name
            # Check more specific patterns first
            strength_multiplier = None
            # Video to audio cross-attention (check first - most specific)
            if "video_to_audio_attn" in key_str:
                strength_multiplier = video_to_audio
            # Audio to video cross-attention
            elif "audio_to_video_attn" in key_str:
                strength_multiplier = audio_to_video
            # Audio layers
            elif "audio_attn" in key_str or "audio_ff.net" in key_str:
                strength_multiplier = audio
            # Video layers (check last - most general)
            elif "attn" in key_str or "ff.net" in key_str:
                strength_multiplier = video
            # Everything else not caught by above filters
            else:
                strength_multiplier = other
            # Apply strength or mark for deletion
            if strength_multiplier is not None:
                if strength_multiplier == 0:
                    keys_to_delete.append(key)
                elif strength_multiplier != 1.0:
                    value = loaded[key]
                    if hasattr(value, 'weights'):
                        weights_list = list(value.weights)
                        # Handle case where alpha (weights[2]) might be None
                        current_alpha = weights_list[2] if weights_list[2] is not None else 1.0
                        weights_list[2] = current_alpha * strength_multiplier
                        loaded[key].weights = tuple(weights_list)
        for key in keys_to_delete:
            if key in loaded:
                del loaded[key]
        # Build list of loaded keys and their alphas
        loaded_keys_list = []
        for key, value in loaded.items():
            if hasattr(value, 'weights'):
                key_str = key if isinstance(key, str) else str(key)
                alpha = value.weights[2] if value.weights[2] is not None else "None"
                loaded_keys_list.append(f"{key_str}: alpha={alpha}")
            else:
                key_str = key if isinstance(key, str) else str(key)
                loaded_keys_list.append(f"{key_str}: type={type(value).__name__}")
        if model is not None:
            new_modelpatcher = model.clone()
            k = new_modelpatcher.add_patches(loaded, strength_model)
            # Add not loaded keys to the info
            k = set(k)
            not_loaded = []
            for x in loaded:
                if x not in k:
                    key_str = x if isinstance(x, str) else str(x)
                    not_loaded.append(f"NOT LOADED: {key_str}")
            if not_loaded:
                loaded_keys_list.extend(not_loaded)
            loaded_keys_info = "\n".join(loaded_keys_list)
            return io.NodeOutput(new_modelpatcher, rank, loaded_keys_info)