import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange

from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.modules.diffusionmodules.mmdit import (RMSNorm, JointBlock,)


class AdaLayerNorm(nn.Module):
    """
    Adaptive layer norm (adaLN), with an optional adaLN-Zero mode.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        time_embedding_dim (`int`, *optional*): The size of the conditioning embedding;
            defaults to `embedding_dim` when not given.
        mode (`str`): "normal" predicts a shift/scale pair; "zero" predicts the six
            adaLN-Zero parameters (shift/scale/gate for attention and MLP).
    """

    def __init__(self, embedding_dim: int, time_embedding_dim=None, mode="normal"):
        super().__init__()

        self.silu = nn.SiLU()
        num_params_dict = dict(
            zero=6,
            normal=2,
        )
        num_params = num_params_dict[mode]
        self.linear = nn.Linear(
            time_embedding_dim or embedding_dim, num_params * embedding_dim, bias=True
        )
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
        self.mode = mode

    def forward(
        self,
        x,
        hidden_dtype=None,
        emb=None,
    ):
        emb = self.linear(self.silu(emb))
        if self.mode == "normal":
            shift_msa, scale_msa = emb.chunk(2, dim=1)
            x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
            return x

        elif self.mode == "zero":
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(
                6, dim=1
            )
            x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
            return x, gate_msa, shift_mlp, scale_mlp, gate_mlp

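# Illustrative usage sketch (comment only; the sizes below are assumptions, not values
# taken from this repo). In "normal" mode the layer modulates a token tensor with a
# shift/scale pair predicted from the conditioning embedding; in "zero" mode it also
# returns the adaLN-Zero gate and MLP parameters.
#
#   norm = AdaLayerNorm(embedding_dim=1536, time_embedding_dim=1280, mode="normal")
#   x = torch.randn(2, 77, 1536)     # (batch, tokens, dim)
#   t_emb = torch.randn(2, 1280)     # pooled timestep/conditioning embedding
#   x_mod = norm(x, emb=t_emb)       # (2, 77, 1536)
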
class IPAttnProcessor(nn.Module):

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
        ip_hidden_states_dim=None,
        ip_encoder_hidden_states_dim=None,
        head_dim=None,
        timesteps_emb_dim=1280,
    ):
        super().__init__()

        self.norm_ip = AdaLayerNorm(
            ip_hidden_states_dim, time_embedding_dim=timesteps_emb_dim
        )
        self.to_k_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
        self.norm_q = RMSNorm(head_dim, 1e-6)
        self.norm_k = RMSNorm(head_dim, 1e-6)
        self.norm_ip_k = RMSNorm(head_dim, 1e-6)

    def forward(
        self,
        ip_hidden_states,
        img_query,
        img_key=None,
        img_value=None,
        t_emb=None,
        n_heads=1,
    ):
        if ip_hidden_states is None:
            return None

        if not hasattr(self, "to_k_ip") or not hasattr(self, "to_v_ip"):
            return None

        # norm ip input
        norm_ip_hidden_states = self.norm_ip(ip_hidden_states, emb=t_emb)

        # to k and v
        ip_key = self.to_k_ip(norm_ip_hidden_states)
        ip_value = self.to_v_ip(norm_ip_hidden_states)

        # reshape
        img_query = rearrange(img_query, "b l (h d) -> b h l d", h=n_heads)
        img_key = rearrange(img_key, "b l (h d) -> b h l d", h=n_heads)
        # img_value arrives in a different layout (b l h d), so transpose it to b h l d
        # to match the query/key
        img_value = torch.transpose(img_value, 1, 2)
        ip_key = rearrange(ip_key, "b l (h d) -> b h l d", h=n_heads)
        ip_value = rearrange(ip_value, "b l (h d) -> b h l d", h=n_heads)

        # norm
        img_query = self.norm_q(img_query)
        img_key = self.norm_k(img_key)
        ip_key = self.norm_ip_k(ip_key)

        # concatenate image and IP tokens along the sequence dimension
        key = torch.cat([img_key, ip_key], dim=2)
        value = torch.cat([img_value, ip_value], dim=2)

        # attend from the image queries over the combined image + IP keys/values
        ip_hidden_states = F.scaled_dot_product_attention(
            img_query, key, value, dropout_p=0.0, is_causal=False
        )
        ip_hidden_states = rearrange(ip_hidden_states, "b h l d -> b l (h d)")
        ip_hidden_states = ip_hidden_states.to(img_query.dtype)
        return ip_hidden_states

class JointBlockIPWrapper:
    """To be used as a patch_replace with Comfy"""

    def __init__(
        self,
        original_block: JointBlock,
        adapter: IPAttnProcessor,
        ip_options=None,
    ):
        self.original_block = original_block
        self.adapter = adapter
        if ip_options is None:
            ip_options = {}
        self.ip_options = ip_options

    def block_mixing(self, context, x, context_block, x_block, c):
        """
        Comes from mmdit.py. Modified to add ipadapter attention.
        """
        context_qkv, context_intermediates = context_block.pre_attention(context, c)

        if x_block.x_block_self_attn:
            x_qkv, x_qkv2, x_intermediates = x_block.pre_attention_x(x, c)
        else:
            x_qkv, x_intermediates = x_block.pre_attention(x, c)

        qkv = tuple(torch.cat((context_qkv[j], x_qkv[j]), dim=1) for j in range(3))

        attn = optimized_attention(
            qkv[0],
            qkv[1],
            qkv[2],
            heads=x_block.attn.num_heads,
        )
        context_attn, x_attn = (
            attn[:, : context_qkv[0].shape[1]],
            attn[:, context_qkv[0].shape[1] :],
        )
        # if the current timestep is not in the ipadapter enabling range, then the
        # resampler wasn't run and the hidden states will be None
        if (
            self.ip_options["hidden_states"] is not None
            and self.ip_options["t_emb"] is not None
        ):
            # IP-Adapter
            ip_attn = self.adapter(
                self.ip_options["hidden_states"],
                *x_qkv,
                self.ip_options["t_emb"],
                x_block.attn.num_heads,
            )
            x_attn = x_attn + ip_attn * self.ip_options["weight"]

        # Everything else is unchanged
        if not context_block.pre_only:
            context = context_block.post_attention(context_attn, *context_intermediates)
        else:
            context = None
        if x_block.x_block_self_attn:
            attn2 = optimized_attention(
                x_qkv2[0],
                x_qkv2[1],
                x_qkv2[2],
                heads=x_block.attn2.num_heads,
            )
            x = x_block.post_attention_x(x_attn, attn2, *x_intermediates)
        else:
            x = x_block.post_attention(x_attn, *x_intermediates)
        return context, x

    def __call__(self, args, _):
        # Code from mmdit.py:
        # in this case, we're blocks_replace[("double_block", i)]
        # note that although we're passed the original block,
        # we can't actually get it from inside its wrapper
        # (which would simplify the whole code...)
        # ```
        # def block_wrap(args):
        #     out = {}
        #     out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"])
        #     return out
        # out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap})
        # context = out["txt"]
        # x = out["img"]
        # ```
        c, x = self.block_mixing(
            args["txt"],
            args["img"],
            self.original_block.context_block,
            self.original_block.x_block,
            c=args["vec"],
        )
        return {"txt": c, "img": x}
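

# Rough registration sketch (comments only; `adapters` and the surrounding node code
# are assumptions, not part of this module). Each MMDiT joint block can be swapped for
# this wrapper through ComfyUI's patch_replace mechanism, e.g.:
#
#   work_model = model.clone()
#   for i, block in enumerate(work_model.model.diffusion_model.joint_blocks):
#       wrapper = JointBlockIPWrapper(block, adapters[i], ip_options)
#       work_model.set_model_patch_replace(wrapper, "dit", "double_block", i)
#
# mmdit.py then calls the wrapper in place of the original block, which is what
# __call__ above handles.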