Add custom nodes, Civitai loras (LFS), and vast.ai setup script
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled

Includes 30 custom nodes committed directly, 7 Civitai-exclusive
loras stored via Git LFS, and a setup script that installs all
dependencies and downloads HuggingFace-hosted models on vast.ai.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-09 00:55:26 +00:00
parent 2b70ab9ad0
commit f09734b0ee
2274 changed files with 748556 additions and 3 deletions

View File

@@ -0,0 +1,268 @@
#credit to shakker-labs and instantX for this module
#from https://github.com/Shakker-Labs/ComfyUI-IPAdapter-Flux
import torch
from PIL import Image
import numpy as np
from .attention_processor import IPAFluxAttnProcessor2_0
from .utils import is_model_pathched, FluxUpdateModules
from .sd3.resampler import TimeResampler
from .sd3.joinblock import JointBlockIPWrapper, IPAttnProcessor
# Module-level cache for the image projection model; created lazily by
# InstantXFluxIpadapterApply.apply_ipadapter and shared across calls.
image_proj_model = None


class MLPProjModel(torch.nn.Module):
    """MLP that projects pooled image-ID embeddings into context tokens.

    Maps an (N, id_embeddings_dim) embedding to (N, num_tokens,
    cross_attention_dim) LayerNorm-normalized tokens for cross-attention.
    """

    def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
        super().__init__()
        self.cross_attention_dim = cross_attention_dim
        self.num_tokens = num_tokens
        # two-layer MLP with a GELU, widening once before fanning out to all tokens
        self.proj = torch.nn.Sequential(
            torch.nn.Linear(id_embeddings_dim, id_embeddings_dim * 2),
            torch.nn.GELU(),
            torch.nn.Linear(id_embeddings_dim * 2, cross_attention_dim * num_tokens),
        )
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, id_embeds):
        """Return (N, num_tokens, cross_attention_dim) normalized tokens."""
        projected = self.proj(id_embeds)
        tokens = projected.reshape(-1, self.num_tokens, self.cross_attention_dim)
        return self.norm(tokens)
class InstantXFluxIpadapterApply:
    """ComfyUI node backend that patches an InstantX IP-Adapter into a Flux model.

    Flow (see apply_ipadapter): encode the reference image with the bundled
    vision encoder, project it to `num_tokens` context tokens via the
    module-global MLPProjModel, then install one IPAFluxAttnProcessor2_0 per
    Flux double/single block via FluxUpdateModules on a clone of the model.
    """

    def __init__(self, num_tokens=128):
        # Target device is resolved from the `provider` argument at apply time.
        self.device = None
        self.dtype = torch.float16
        # Number of image context tokens produced by the projection model.
        self.num_tokens = num_tokens
        # Filled in by apply_ipadapter() from the `ipadapter` dict.
        self.ip_ckpt = None
        self.clip_vision = None
        self.image_encoder = None
        self.clip_image_processor = None
        # state_dict with "image_proj" and "ip_adapter" sub-dicts.
        self.state_dict = None
        # Flux constants: context (cross-attention) dim and transformer hidden size.
        self.joint_attention_dim = 4096
        self.hidden_size = 3072

    def set_ip_adapter(self, flux_model, weight, timestep_percent_range=(0.0, 1.0)):
        """Build one freshly-initialized IPAFluxAttnProcessor2_0 per Flux block.

        The (start, end) percents are mapped through the model's sampling
        schedule to sigma values so each processor can gate itself per step.
        Returns a dict keyed "double_blocks.{i}" / "single_blocks.{i}".
        """
        s = flux_model.model_sampling
        percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
        timestep_range = (percent_to_timestep_function(timestep_percent_range[0]),
                          percent_to_timestep_function(timestep_percent_range[1]))
        ip_attn_procs = {} # 19+38=57 processors for standard Flux
        dsb_count = len(flux_model.diffusion_model.double_blocks)
        for i in range(dsb_count):
            name = f"double_blocks.{i}"
            ip_attn_procs[name] = IPAFluxAttnProcessor2_0(
                hidden_size=self.hidden_size,
                cross_attention_dim=self.joint_attention_dim,
                num_tokens=self.num_tokens,
                scale=weight,
                timestep_range=timestep_range
            ).to(self.device, dtype=self.dtype)
        ssb_count = len(flux_model.diffusion_model.single_blocks)
        for i in range(ssb_count):
            name = f"single_blocks.{i}"
            ip_attn_procs[name] = IPAFluxAttnProcessor2_0(
                hidden_size=self.hidden_size,
                cross_attention_dim=self.joint_attention_dim,
                num_tokens=self.num_tokens,
                scale=weight,
                timestep_range=timestep_range
            ).to(self.device, dtype=self.dtype)
        return ip_attn_procs

    def load_ip_adapter(self, flux_model, weight, timestep_percent_range=(0.0, 1.0)):
        """Create the attention processors and load checkpoint weights into them.

        NOTE(review): relies on the module-level `image_proj_model` having been
        created by apply_ipadapter() first; calling this method directly would
        raise AttributeError on the None global.
        """
        global image_proj_model
        image_proj_model.load_state_dict(self.state_dict["image_proj"], strict=True)
        ip_attn_procs = self.set_ip_adapter(flux_model, weight, timestep_percent_range)
        # ModuleList keys ("0", "1", ...) line up with the checkpoint's
        # "ip_adapter" state dict layout.
        ip_layers = torch.nn.ModuleList(ip_attn_procs.values())
        ip_layers.load_state_dict(self.state_dict["ip_adapter"], strict=True)
        return ip_attn_procs

    def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
        """Encode a PIL image (or pre-computed CLIP embeds) into IP context tokens.

        Returns the projection-model output; exactly one of `pil_image` /
        `clip_image_embeds` should be provided.
        """
        # outputs = self.clip_vision.encode_image(pil_image)
        # clip_image_embeds = outputs['image_embeds']
        # clip_image_embeds = clip_image_embeds.to(self.device, dtype=self.dtype)
        # image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        if pil_image is not None:
            if isinstance(pil_image, Image.Image):
                pil_image = [pil_image]
            clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
            # pooled embedding from the vision tower
            clip_image_embeds = self.image_encoder(
                clip_image.to(self.device, dtype=self.image_encoder.dtype)).pooler_output
            clip_image_embeds = clip_image_embeds.to(dtype=self.dtype)
        else:
            clip_image_embeds = clip_image_embeds.to(self.device, dtype=self.dtype)
        global image_proj_model
        image_prompt_embeds = image_proj_model(clip_image_embeds)
        return image_prompt_embeds

    def apply_ipadapter(self, model, ipadapter, image, weight, start_at, end_at, provider=None, use_tiled=False):
        """Entry point: encode `image`, load adapter weights, patch a model clone.

        Returns (patched_model, image). `provider` names the target device
        (lower-cased before use); NOTE(review): it must not be None here.
        """
        self.device = provider.lower()
        if "clipvision" in ipadapter:
            # self.clip_vision = ipadapter["clipvision"]['model']
            self.image_encoder = ipadapter["clipvision"]['model']['image_encoder'].to(self.device, dtype=self.dtype)
            self.clip_image_processor = ipadapter["clipvision"]['model']['clip_image_processor']
        if "ipadapter" in ipadapter:
            self.ip_ckpt = ipadapter["ipadapter"]['file']
            self.state_dict = ipadapter["ipadapter"]['model']
        # process image: ComfyUI IMAGE tensor (values in 0..1) -> uint8 PIL,
        # first batch item only
        pil_image = image.numpy()[0] * 255.0
        pil_image = Image.fromarray(pil_image.astype(np.uint8))
        # initialize ipadapter
        # NOTE(review): module-global cache — a later instance constructed with a
        # different num_tokens would reuse this model and fail in load_state_dict.
        global image_proj_model
        if image_proj_model is None:
            image_proj_model = MLPProjModel(
                cross_attention_dim=self.joint_attention_dim, # 4096
                id_embeddings_dim=1152,
                num_tokens=self.num_tokens,
            )
        image_proj_model.to(self.device, dtype=self.dtype)
        ip_attn_procs = self.load_ip_adapter(model.model, weight, (start_at, end_at))
        # process control image
        image_prompt_embeds = self.get_image_embeds(pil_image=pil_image, clip_image_embeds=None)
        # set model
        # is_patched = is_model_pathched(model.model)
        bi = model.clone()
        FluxUpdateModules(bi, ip_attn_procs, image_prompt_embeds)
        return (bi, image)
def patch_sd3(
    patcher,
    ip_procs,
    resampler: TimeResampler,
    clip_embeds,
    weight=1.0,
    start=0.0,
    end=1.0,
):
    """
    Patches a model_sampler to add the ipadapter.

    Two patches are installed on `patcher`:
      * a unet-function wrapper that, once per step, runs the TimeResampler on
        `clip_embeds` + the current timestep and stashes the result in a dict
        shared with the per-block wrappers;
      * a JointBlockIPWrapper patch_replace for every joint block of the MMDiT.
    `weight`/`start`/`end` control the adapter strength and the fraction of the
    sampling trajectory during which it is active.
    """
    mmdit = patcher.model.diffusion_model
    timestep_schedule_max = patcher.model.model_config.sampling_settings.get(
        "timesteps", 1000
    )
    # hook the model's forward function
    # so that when it gets called, we can grab the timestep and send it to the resampler
    # (shared, mutated in place each step; read by the block wrappers)
    ip_options = {
        "hidden_states": None,
        "t_emb": None,
        "weight": weight,
    }

    def ddit_wrapper(forward, args):
        # this is between 0 and 1, so the adapters can calculate start_point and end_point
        # NOTE(review): assumes args["timestep"] is normalized to [0, 1] — confirm
        # (original author also wondered whether the sigma value is needed instead)
        t_percent = 1 - args["timestep"].flatten()[0].cpu().item()
        if start <= t_percent <= end:
            batch_size = args["input"].shape[0] // len(args["cond_or_uncond"])
            # if we're only doing cond or only doing uncond, only pass one of them through the resampler
            embeds = clip_embeds[args["cond_or_uncond"]]
            # slight efficiency optimization todo: pass the embeds through and then afterwards
            # repeat to the batch size
            embeds = torch.repeat_interleave(embeds, batch_size, dim=0)
            # the resampler wants between 0 and MAX_STEPS
            timestep = args["timestep"] * timestep_schedule_max
            image_emb, t_emb = resampler(embeds, timestep, need_temb=True)
            # these will need to be accessible to the IPAdapters
            ip_options["hidden_states"] = image_emb
            ip_options["t_emb"] = t_emb
        else:
            # outside the active window: block wrappers see None and skip the adapter
            ip_options["hidden_states"] = None
            ip_options["t_emb"] = None
        return forward(args["input"], args["timestep"], **args["c"])

    patcher.set_model_unet_function_wrapper(ddit_wrapper)
    # patch each dit block
    for i, block in enumerate(mmdit.joint_blocks):
        wrapper = JointBlockIPWrapper(block, ip_procs[i], ip_options)
        patcher.set_model_patch_replace(wrapper, "dit", "double_block", i)
class InstantXSD3IpadapterApply:
    """Applies an InstantX IP-Adapter to an SD3.5-Large model in ComfyUI.

    Unlike the Flux variant, conditioning is time-dependent: a TimeResampler
    turns CLIP image features + the current timestep into per-step hidden
    states consumed by per-block IPAttnProcessors (wired up by patch_sd3).
    """

    def __init__(self):
        # Target device is resolved from the `provider` argument at apply time.
        self.device = None
        self.dtype = torch.float16
        self.clip_image_processor = None
        self.image_encoder = None
        # TimeResampler producing per-step image tokens; built in apply_ipadapter.
        self.resampler = None
        # Per-joint-block IPAttnProcessor modules.
        self.procs = None

    @torch.inference_mode()
    def encode(self, image):
        """Encode an image tensor into penultimate-layer CLIP hidden states.

        A zero tensor is concatenated on the batch axis as the "uncond" image
        embedding, so the result can later be indexed by cond_or_uncond.
        """
        clip_image = self.clip_image_processor.image_processor(image, return_tensors="pt", do_rescale=False).pixel_values
        # second-to-last hidden layer, the usual IP-Adapter feature choice
        clip_image_embeds = self.image_encoder(
            clip_image.to(self.device, dtype=self.image_encoder.dtype),
            output_hidden_states=True,
        ).hidden_states[-2]
        clip_image_embeds = torch.cat(
            [clip_image_embeds, torch.zeros_like(clip_image_embeds)], dim=0
        )
        clip_image_embeds = clip_image_embeds.to(dtype=torch.float16)
        return clip_image_embeds

    def apply_ipadapter(self, model, ipadapter, image, weight, start_at, end_at, provider=None, use_tiled=False):
        """Entry point: returns (patched_model_clone, image).

        NOTE(review): `provider` must not be None; it names the target device.
        """
        self.device = provider.lower()
        if "clipvision" in ipadapter:
            self.image_encoder = ipadapter["clipvision"]['model']['image_encoder'].to(self.device, dtype=self.dtype)
            self.clip_image_processor = ipadapter["clipvision"]['model']['clip_image_processor']
        if "ipadapter" in ipadapter:
            self.ip_ckpt = ipadapter["ipadapter"]['file']
            self.state_dict = ipadapter["ipadapter"]['model']
        # Resampler hyperparameters match the InstantX SD3.5-Large checkpoint.
        self.resampler = TimeResampler(
            dim=1280,
            depth=4,
            dim_head=64,
            heads=20,
            num_queries=64,
            embedding_dim=1152,
            output_dim=2432,
            ff_mult=4,
            timestep_in_dim=320,
            timestep_flip_sin_to_cos=True,
            timestep_freq_shift=0,
        )
        self.resampler.eval()
        self.resampler.to(self.device, dtype=self.dtype)
        self.resampler.load_state_dict(self.state_dict["image_proj"])
        # now we'll create the attention processors
        # ip_adapter.keys looks like [0.proj, 0.to_k, ..., 1.proj, 1.to_k, ...]
        # so the number of processors is the number of distinct leading indices
        n_procs = len(
            set(x.split(".")[0] for x in self.state_dict["ip_adapter"].keys())
        )
        self.procs = torch.nn.ModuleList(
            [
                # this is hardcoded for SD3.5L
                IPAttnProcessor(
                    hidden_size=2432,
                    cross_attention_dim=2432,
                    ip_hidden_states_dim=2432,
                    ip_encoder_hidden_states_dim=2432,
                    head_dim=64,
                    timesteps_emb_dim=1280,
                ).to(self.device, dtype=torch.float16)
                for _ in range(n_procs)
            ]
        )
        self.procs.load_state_dict(self.state_dict["ip_adapter"])
        work_model = model.clone()
        embeds = self.encode(image)
        patch_sd3(
            work_model,
            self.procs,
            self.resampler,
            embeds,
            weight,
            start_at,
            end_at,
        )
        return (work_model, image)

View File

@@ -0,0 +1,87 @@
import numbers
from typing import Dict, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class RMSNorm(nn.Module):
    """Root-mean-square normalization (no mean subtraction), optionally affine.

    Statistics are computed in float32 for stability; the result is cast back
    to the input dtype only when no affine weight is applied, or to the weight
    dtype when the weight is half precision.
    """

    def __init__(self, dim, eps: float, elementwise_affine: bool = True):
        super().__init__()
        self.eps = eps
        # accept either an int or a shape tuple for the normalized dimensions
        if isinstance(dim, numbers.Integral):
            dim = (dim,)
        self.dim = torch.Size(dim)
        self.weight = nn.Parameter(torch.ones(dim)) if elementwise_affine else None

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        mean_square = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        normed = hidden_states * torch.rsqrt(mean_square + self.eps)
        if self.weight is None:
            return normed.to(orig_dtype)
        # convert into half-precision if necessary
        if self.weight.dtype in (torch.float16, torch.bfloat16):
            normed = normed.to(self.weight.dtype)
        return normed * self.weight
class IPAFluxAttnProcessor2_0(nn.Module):
    """Attention processor used typically in processing the SD3-like self-attention projections."""

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, timestep_range=None):
        super().__init__()
        self.hidden_size = hidden_size  # 3072 for Flux
        self.cross_attention_dim = cross_attention_dim  # 4096 for Flux
        self.scale = scale
        self.num_tokens = num_tokens
        in_dim = cross_attention_dim or hidden_size
        self.to_k_ip = nn.Linear(in_dim, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(in_dim, hidden_size, bias=False)
        # per-head RMS norms; head dim 128 is hardcoded here (Flux: 3072 / 24 heads)
        self.norm_added_k = RMSNorm(128, eps=1e-5, elementwise_affine=False)
        self.norm_added_v = RMSNorm(128, eps=1e-5, elementwise_affine=False)
        # (sigma_at_start, sigma_at_end) window in which the adapter is active
        self.timestep_range = timestep_range

    def __call__(
        self,
        num_heads,
        query,
        image_emb: torch.FloatTensor,
        t: torch.FloatTensor
    ) -> torch.FloatTensor:
        """Return scaled IP-attention output, or None when t is outside the window."""
        if self.timestep_range is not None:
            upper, lower = self.timestep_range
            # sigmas decrease over sampling, so "active" means lower <= t <= upper
            if t[0] > upper or t[0] < lower:
                return None
        # `ip-adapter` projections of the image tokens
        ip_key = self.to_k_ip(image_emb)
        ip_value = self.to_v_ip(image_emb)
        ip_key = rearrange(ip_key, 'B L (H D) -> B H L D', H=num_heads)
        ip_value = rearrange(ip_value, 'B L (H D) -> B H L D', H=num_heads)
        ip_key = self.norm_added_k(ip_key)
        ip_value = self.norm_added_v(ip_value)
        # query attends over the image tokens only; match the image embedding's
        # device/dtype before the fused attention call
        attended = F.scaled_dot_product_attention(
            query.to(image_emb.device).to(image_emb.dtype),
            ip_key,
            ip_value,
            dropout_p=0.0, is_causal=False)
        attended = rearrange(attended, "B H L D -> B L (H D)", H=num_heads)
        attended = attended.to(query.dtype).to(query.device)
        return self.scale * attended

View File

@@ -0,0 +1,153 @@
import torch
from torch import Tensor, nn
from .math import attention
from ..attention_processor import IPAFluxAttnProcessor2_0
from comfy.ldm.flux.layers import DoubleStreamBlock, SingleStreamBlock
from comfy import model_management as mm
class DoubleStreamBlockIPA(nn.Module):
    """Flux DoubleStreamBlock clone that adds IP-Adapter attention to the image stream.

    All submodules are shared (by reference) with `original_block`; only the
    extra adapter attention is new. `ip_adapter` and `image_emb` are parallel
    sequences so several adapters can be stacked on one block.
    """

    def __init__(self, original_block: DoubleStreamBlock, ip_adapter, image_emb):
        super().__init__()
        # recover mlp_ratio from the original layer sizes; the round-trip
        # leaves mlp_hidden_dim unchanged (kept for parity with upstream code)
        mlp_hidden_dim = original_block.img_mlp[0].out_features
        mlp_ratio = mlp_hidden_dim / original_block.hidden_size
        mlp_hidden_dim = int(original_block.hidden_size * mlp_ratio)
        self.num_heads = original_block.num_heads
        self.hidden_size = original_block.hidden_size
        # reuse every submodule of the wrapped block
        self.img_mod = original_block.img_mod
        self.img_norm1 = original_block.img_norm1
        self.img_attn = original_block.img_attn
        self.img_norm2 = original_block.img_norm2
        self.img_mlp = original_block.img_mlp
        self.txt_mod = original_block.txt_mod
        self.txt_norm1 = original_block.txt_norm1
        self.txt_attn = original_block.txt_attn
        self.txt_norm2 = original_block.txt_norm2
        self.txt_mlp = original_block.txt_mlp
        self.flipped_img_txt = original_block.flipped_img_txt
        self.ip_adapter = ip_adapter
        self.image_emb = image_emb
        self.device = mm.get_torch_device()

    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, t: Tensor, attn_mask=None):
        """Same contract as DoubleStreamBlock.forward plus `t` (current sigma) for adapter gating."""
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)
        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3,
                                                                                                              1, 4)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3,
                                                                                                              1, 4)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
        if self.flipped_img_txt:
            # run actual attention (image tokens first in the joint sequence)
            attn = attention(torch.cat((img_q, txt_q), dim=2),
                             torch.cat((img_k, txt_k), dim=2),
                             torch.cat((img_v, txt_v), dim=2),
                             pe=pe, mask=attn_mask)
            img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
        else:
            # run actual attention (text tokens first in the joint sequence)
            attn = attention(torch.cat((txt_q, img_q), dim=2),
                             torch.cat((txt_k, img_k), dim=2),
                             torch.cat((txt_v, img_v), dim=2),
                             pe=pe, mask=attn_mask)
            txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
        for adapter, image in zip(self.ip_adapter, self.image_emb):
            # this does a separate attention for each adapter; an adapter
            # returns None when `t` is outside its active window
            ip_hidden_states = adapter(self.num_heads, img_q, image, t)
            if ip_hidden_states is not None:
                ip_hidden_states = ip_hidden_states.to(self.device)
                img_attn = img_attn + ip_hidden_states
        # calculate the img blocks
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
        # calculate the txt blocks (NOTE: in-place += mutates the caller's tensor)
        txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
        if txt.dtype == torch.float16:
            # clamp fp16 overflow so inf/nan does not propagate downstream
            txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
        return img, txt
class SingleStreamBlockIPA(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.

    Wraps a Flux SingleStreamBlock (sharing its weights by reference) and adds
    IP-Adapter attention from the paired `ip_adapter`/`image_emb` sequences.
    """

    def __init__(self, original_block: SingleStreamBlock, ip_adapter, image_emb):
        super().__init__()
        self.hidden_dim = original_block.hidden_size
        self.num_heads = original_block.num_heads
        self.scale = original_block.scale
        self.mlp_hidden_dim = original_block.mlp_hidden_dim
        # qkv and mlp_in
        self.linear1 = original_block.linear1
        # proj and mlp_out
        self.linear2 = original_block.linear2
        self.norm = original_block.norm
        self.hidden_size = original_block.hidden_size
        self.pre_norm = original_block.pre_norm
        self.mlp_act = original_block.mlp_act
        self.modulation = original_block.modulation
        self.ip_adapter = ip_adapter
        self.image_emb = image_emb
        self.device = mm.get_torch_device()

    def add_adapter(self, ip_adapter: IPAFluxAttnProcessor2_0, image_emb):
        # stack one more adapter (and its image embedding) onto this block
        self.ip_adapter.append(ip_adapter)
        self.image_emb.append(image_emb)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor, t: Tensor, attn_mask=None) -> Tensor:
        """Same contract as SingleStreamBlock.forward plus `t` (current sigma) for adapter gating."""
        mod, _ = self.modulation(vec)
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k = self.norm(q, k, v)
        # compute attention
        attn = attention(q, k, v, pe=pe, mask=attn_mask)
        for adapter, image in zip(self.ip_adapter, self.image_emb):
            # this does a separate attention for each adapter
            # maybe we want a single joint attention call for all adapters?
            ip_hidden_states = adapter(self.num_heads, q, image, t)
            if ip_hidden_states is not None:
                ip_hidden_states = ip_hidden_states.to(self.device)
                attn = attn + ip_hidden_states
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        # NOTE: in-place += mutates the caller's tensor
        x += mod.gate * output
        if x.dtype == torch.float16:
            # clamp fp16 overflow so inf/nan does not propagate downstream
            x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
        return x

View File

@@ -0,0 +1,35 @@
import torch
from einops import rearrange
from torch import Tensor
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
    """Rotary-embedded attention: apply RoPE to q/k, then run Comfy's optimized attention."""
    q_rot, k_rot = apply_rope(q, k, pe)
    num_heads = q_rot.shape[1]
    return optimized_attention(q_rot, k_rot, v, num_heads, skip_reshape=True, mask=mask)
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    """Build per-position 2x2 rotation matrices for rotary position embedding.

    Returns a float32 tensor of shape (..., n, dim // 2, 2, 2) on `pos.device`.
    """
    assert dim % 2 == 0
    # MPS / Intel XPU lack the float64 support needed below, so fall back to CPU there
    needs_cpu = comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu()
    device = torch.device("cpu") if needs_cpu else pos.device
    scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
    omega = 1.0 / (theta**scale)
    angles = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
    # pack [cos, -sin, sin, cos] and fold into 2x2 rotation matrices
    mats = torch.stack([torch.cos(angles), -torch.sin(angles), torch.sin(angles), torch.cos(angles)], dim=-1)
    mats = rearrange(mats, "b n d (i j) -> b n d i j", i=2, j=2)
    return mats.to(dtype=torch.float32, device=pos.device)
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
    """Rotate query/key pairs by the precomputed RoPE matrices.

    `freqs_cis` holds 2x2 rotation matrices (last two dims); each consecutive
    (even, odd) channel pair of q/k is rotated, preserving shape and dtype.
    """
    def _rotate(t: Tensor) -> Tensor:
        # group channels into pairs: (..., d) -> (..., d/2, 1, 2)
        pairs = t.float().reshape(*t.shape[:-1], -1, 1, 2)
        rotated = freqs_cis[..., 0] * pairs[..., 0] + freqs_cis[..., 1] * pairs[..., 1]
        return rotated.reshape(*t.shape).type_as(t)

    return _rotate(xq), _rotate(xk)

View File

@@ -0,0 +1,219 @@
import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.modules.diffusionmodules.mmdit import (RMSNorm, JointBlock,)
class AdaLayerNorm(nn.Module):
    """
    Norm layer adaptive layer norm zero (adaLN-Zero).

    A LayerNorm (no learned affine) whose shift/scale are produced from a
    conditioning embedding `emb` by a SiLU + Linear head.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        time_embedding_dim: input size of the conditioning head (defaults to
            `embedding_dim` when None).
        mode: "normal" emits (shift, scale) and returns the modulated tensor;
            "zero" emits the six adaLN-Zero parameters and also returns the
            MSA/MLP gates and MLP shift/scale.
    """

    def __init__(self, embedding_dim: int, time_embedding_dim=None, mode="normal"):
        super().__init__()
        self.silu = nn.SiLU()
        # number of (embedding_dim-sized) parameter groups the head must emit
        num_params = {"zero": 6, "normal": 2}[mode]
        self.linear = nn.Linear(
            time_embedding_dim or embedding_dim, num_params * embedding_dim, bias=True
        )
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
        self.mode = mode

    def forward(
        self,
        x,
        hidden_dtype=None,
        emb=None,
    ):
        params = self.linear(self.silu(emb))
        if self.mode == "normal":
            shift_msa, scale_msa = params.chunk(2, dim=1)
            return self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        elif self.mode == "zero":
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = params.chunk(
                6, dim=1
            )
            x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
            return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
class IPAttnProcessor(nn.Module):
    """Per-block IP-Adapter attention for SD3-style MMDiT joint blocks.

    Given the resampler's image tokens and the block's image-stream q/k/v,
    runs an extra attention whose key/value set is the concatenation of the
    image-stream k/v and the projected IP k/v, and returns the result for the
    caller (JointBlockIPWrapper) to blend into the block output.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
        ip_hidden_states_dim=None,
        ip_encoder_hidden_states_dim=None,
        head_dim=None,
        timesteps_emb_dim=1280,
    ):
        super().__init__()
        # timestep-conditioned normalization of the IP tokens
        self.norm_ip = AdaLayerNorm(
            ip_hidden_states_dim, time_embedding_dim=timesteps_emb_dim
        )
        self.to_k_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
        # per-head RMS norms (QK-norm style)
        self.norm_q = RMSNorm(head_dim, 1e-6)
        self.norm_k = RMSNorm(head_dim, 1e-6)
        self.norm_ip_k = RMSNorm(head_dim, 1e-6)

    def forward(
        self,
        ip_hidden_states,
        img_query,
        img_key=None,
        img_value=None,
        t_emb=None,
        n_heads=1,
    ):
        """Return the IP-attention output of shape (b, l, h*d), or None when inactive."""
        # resampler was not run this step (timestep outside the active window)
        if ip_hidden_states is None:
            return None
        # defensive check; in practice both linears are always created in __init__
        if not hasattr(self, "to_k_ip") or not hasattr(self, "to_v_ip"):
            return None
        # norm ip input
        norm_ip_hidden_states = self.norm_ip(ip_hidden_states, emb=t_emb)
        # to k and v
        ip_key = self.to_k_ip(norm_ip_hidden_states)
        ip_value = self.to_v_ip(norm_ip_hidden_states)
        # reshape flattened (b, l, h*d) projections to (b, h, l, d)
        img_query = rearrange(img_query, "b l (h d) -> b h l d", h=n_heads)
        img_key = rearrange(img_key, "b l (h d) -> b h l d", h=n_heads)
        # note that the image is in a different shape: b l h d
        # so we transpose to b h l d
        # NOTE(review): assumes img_value arrives as (b, l, h, d) — confirm
        # against the caller's pre_attention output layout
        img_value = torch.transpose(img_value, 1, 2)
        ip_key = rearrange(ip_key, "b l (h d) -> b h l d", h=n_heads)
        ip_value = rearrange(ip_value, "b l (h d) -> b h l d", h=n_heads)
        # norm
        img_query = self.norm_q(img_query)
        img_key = self.norm_k(img_key)
        ip_key = self.norm_ip_k(ip_key)
        # cat img and ip keys/values along the token axis
        key = torch.cat([img_key, ip_key], dim=2)
        value = torch.cat([img_value, ip_value], dim=2)
        # fused attention over the combined token set
        ip_hidden_states = F.scaled_dot_product_attention(
            img_query, key, value, dropout_p=0.0, is_causal=False
        )
        ip_hidden_states = rearrange(ip_hidden_states, "b h l d -> b l (h d)")
        ip_hidden_states = ip_hidden_states.to(img_query.dtype)
        return ip_hidden_states
class JointBlockIPWrapper:
    """To be used as a patch_replace with Comfy.

    Replaces one MMDiT joint block: re-implements the block's mixing step and
    adds the IP-Adapter attention on the image stream when active.
    """

    def __init__(
        self,
        original_block: JointBlock,
        adapter: IPAttnProcessor,
        ip_options=None,
    ):
        self.original_block = original_block
        self.adapter = adapter
        if ip_options is None:
            ip_options = {}
        # shared dict mutated each step by patch_sd3's sampler wrapper; carries
        # the resampler's hidden states / t_emb and the adapter weight
        self.ip_options = ip_options

    def block_mixing(self, context, x, context_block, x_block, c):
        """
        Comes from mmdit.py. Modified to add ipadapter attention.
        """
        context_qkv, context_intermediates = context_block.pre_attention(context, c)
        if x_block.x_block_self_attn:
            x_qkv, x_qkv2, x_intermediates = x_block.pre_attention_x(x, c)
        else:
            x_qkv, x_intermediates = x_block.pre_attention(x, c)
        # joint attention over concatenated [context, x] tokens
        qkv = tuple(torch.cat((context_qkv[j], x_qkv[j]), dim=1) for j in range(3))
        attn = optimized_attention(
            qkv[0],
            qkv[1],
            qkv[2],
            heads=x_block.attn.num_heads,
        )
        # split the joint attention result back into context / image parts
        context_attn, x_attn = (
            attn[:, : context_qkv[0].shape[1]],
            attn[:, context_qkv[0].shape[1] :],
        )
        # if the current timestep is not in the ipadapter enabling range, then the resampler wasn't run
        # and the hidden states will be None
        if (
            self.ip_options["hidden_states"] is not None
            and self.ip_options["t_emb"] is not None
        ):
            # IP-Adapter: extra image-conditioned attention, blended by weight
            ip_attn = self.adapter(
                self.ip_options["hidden_states"],
                *x_qkv,
                self.ip_options["t_emb"],
                x_block.attn.num_heads,
            )
            x_attn = x_attn + ip_attn * self.ip_options["weight"]
        # Everything else is unchanged
        if not context_block.pre_only:
            context = context_block.post_attention(context_attn, *context_intermediates)
        else:
            context = None
        if x_block.x_block_self_attn:
            # SD3.5-style extra self-attention on the image stream
            attn2 = optimized_attention(
                x_qkv2[0],
                x_qkv2[1],
                x_qkv2[2],
                heads=x_block.attn2.num_heads,
            )
            x = x_block.post_attention_x(x_attn, attn2, *x_intermediates)
        else:
            x = x_block.post_attention(x_attn, *x_intermediates)
        return context, x

    def __call__(self, args, _):
        # Code from mmdit.py:
        # in this case, we're blocks_replace[("double_block", i)]
        # note that although we're passed the original block,
        # we can't actually get it from inside its wrapper
        # (which would simplify the whole code...)
        # ```
        # def block_wrap(args):
        #     out = {}
        #     out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"])
        #     return out
        # out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap})
        # context = out["txt"]
        # x = out["img"]
        # ```
        c, x = self.block_mixing(
            args["txt"],
            args["img"],
            self.original_block.context_block,
            self.original_block.x_block,
            c=args["vec"],
        )
        return {"txt": c, "img": x}

View File

@@ -0,0 +1,385 @@
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
import math
import torch
import torch.nn as nn
from typing import Optional
# Shared singleton activation modules, keyed by lower-case name.
ACTIVATION_FUNCTIONS = {
    "swish": nn.SiLU(),
    "silu": nn.SiLU(),
    "mish": nn.Mish(),
    "gelu": nn.GELU(),
    "relu": nn.ReLU(),
}


def get_activation(act_fn: str) -> nn.Module:
    """Helper function to get activation function from string.

    Args:
        act_fn (str): Name of activation function (case-insensitive).

    Returns:
        nn.Module: Activation function (a shared module instance).

    Raises:
        ValueError: If the name is not a supported activation.
    """
    act_fn = act_fn.lower()
    if act_fn not in ACTIVATION_FUNCTIONS:
        raise ValueError(f"Unsupported activation function: {act_fn}")
    return ACTIVATION_FUNCTIONS[act_fn]
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

    Args
        timesteps (torch.Tensor):
            a 1-D Tensor of N indices, one per batch element. These may be fractional.
        embedding_dim (int):
            the dimension of the output.
        flip_sin_to_cos (bool):
            Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
        downscale_freq_shift (float):
            Controls the delta between frequencies between dimensions
        scale (float):
            Scaling factor applied to the embeddings.
        max_period (int):
            Controls the maximum frequency of the embeddings
    Returns
        torch.Tensor: an [N x dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
    half_dim = embedding_dim // 2
    # geometric frequency ladder: exp(-log(max_period) * k / (half_dim - shift))
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)
    freqs = torch.exp(exponent)
    angles = scale * (timesteps[:, None].float() * freqs[None, :])
    # sin half followed by cos half
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
    if flip_sin_to_cos:
        # reorder to cos-first
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
    if embedding_dim % 2 == 1:
        # odd target dim: zero-pad the last channel
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
class Timesteps(nn.Module):
    """Module wrapper around get_timestep_embedding with fixed configuration."""

    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
        super().__init__()
        # stored verbatim and forwarded on every call
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift
        self.scale = scale

    def forward(self, timesteps):
        """Return [N, num_channels] sinusoidal embeddings for the 1-D `timesteps`."""
        return get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            scale=self.scale,
        )
class TimestepEmbedding(nn.Module):
    """Two-layer MLP lifting a sinusoidal timestep embedding to `time_embed_dim`.

    Optionally projects an extra `condition` into the input space first
    (`cond_proj_dim`) and applies a post-activation (`post_act_fn`).
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
    ):
        super().__init__()
        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
        # optional conditioning projection (no bias), added to the input sample
        self.cond_proj = (
            nn.Linear(cond_proj_dim, in_channels, bias=False)
            if cond_proj_dim is not None
            else None
        )
        self.act = get_activation(act_fn)
        self.linear_2 = nn.Linear(
            time_embed_dim,
            out_dim if out_dim is not None else time_embed_dim,
            sample_proj_bias,
        )
        self.post_act = None if post_act_fn is None else get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        """Return the embedded sample; mixes in `condition` when provided."""
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)
        if self.act is not None:
            sample = self.act(sample)
        sample = self.linear_2(sample)
        return sample if self.post_act is None else self.post_act(sample)
# FFN
def FeedForward(dim, mult=4):
    """Bias-free transformer FFN: LayerNorm -> Linear(dim, dim*mult) -> GELU -> Linear(-> dim)."""
    hidden = int(dim * mult)
    layers = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden, bias=False),
        nn.GELU(),
        nn.Linear(hidden, dim, bias=False),
    ]
    return nn.Sequential(*layers)
def reshape_tensor(x, heads):
    """Split the channel dim into heads: (bs, length, heads*dim) -> (bs, heads, length, dim)."""
    bs, length, _ = x.shape
    # expose the head axis, then move it in front of the sequence axis
    per_head = x.view(bs, length, heads, -1)
    swapped = per_head.transpose(1, 2)
    return swapped.reshape(bs, heads, length, -1)
class PerceiverAttention(nn.Module):
    """Perceiver-style cross-attention: latent queries attend over [x, latents]."""

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        # keys and values come from the concatenation of inputs and latents
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents, shift=None, scale=None):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latents (torch.Tensor): latent features
                shape (b, n2, D)
            shift/scale: optional adaLN modulation applied to the latents.
        """
        x = self.norm1(x)
        latents = self.norm2(latents)
        if shift is not None and scale is not None:
            latents = latents * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        b, num_latents, _ = latents.shape
        q = reshape_tensor(self.to_q(latents), self.heads)
        kv_source = torch.cat((x, latents), dim=-2)
        k, v = (reshape_tensor(t, self.heads) for t in self.to_kv(kv_source).chunk(2, dim=-1))
        # scale q and k symmetrically before the matmul: more stable with f16
        # than dividing the product afterwards
        sym_scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * sym_scale) @ (k * sym_scale).transpose(-2, -1)
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = (weight @ v).permute(0, 2, 1, 3).reshape(b, num_latents, -1)
        return self.to_out(out)
class Resampler(nn.Module):
    """Perceiver resampler: learned latent queries repeatedly attend over the input.

    Projects `embedding_dim` features into `dim`, runs `depth` attention+FFN
    stages with residuals, then projects to `output_dim` and layer-norms.
    """

    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
        *args,
        **kwargs,
    ):
        super().__init__()
        # learned queries, scaled down for stable initialization
        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
        self.proj_in = nn.Linear(embedding_dim, dim)
        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)
        self.layers = nn.ModuleList(
            nn.ModuleList(
                [
                    PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                    FeedForward(dim=dim, mult=ff_mult),
                ]
            )
            for _ in range(depth)
        )

    def forward(self, x):
        """Resample (b, n, embedding_dim) features into (b, num_queries, output_dim)."""
        latents = self.latents.repeat(x.size(0), 1, 1)
        x = self.proj_in(x)
        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents
        return self.norm_out(self.proj_out(latents))
class TimeResampler(nn.Module):
    """Perceiver-style resampler with diffusion-timestep conditioning.

    Compresses input embeddings (b, n, embedding_dim) into ``num_queries``
    latent tokens of size ``output_dim``; each attention and feed-forward
    sub-layer is modulated with shift/scale values derived from the
    timestep embedding (adaLN).
    """
    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
        timestep_in_dim=320,
        timestep_flip_sin_to_cos=True,
        timestep_freq_shift=0,
    ):
        super().__init__()
        # Learned latent queries, scaled down for stable initialization.
        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
        self.proj_in = nn.Linear(embedding_dim, dim)
        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        # msa
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        # ff
                        FeedForward(dim=dim, mult=ff_mult),
                        # adaLN: emits 4*dim values -> (shift, scale) for both
                        # the attention and the feed-forward sub-layers
                        nn.Sequential(nn.SiLU(), nn.Linear(dim, 4 * dim, bias=True)),
                    ]
                )
            )
        # time: sinusoidal projection followed by a small MLP embedding
        self.time_proj = Timesteps(
            timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift
        )
        self.time_embedding = TimestepEmbedding(timestep_in_dim, dim, act_fn="silu")
        # adaLN
        # self.adaLN_modulation = nn.Sequential(
        #     nn.SiLU(),
        #     nn.Linear(timestep_out_dim, 6 * timestep_out_dim, bias=True)
        # )
    def forward(self, x, timestep, need_temb=False):
        """Resample ``x`` into latent tokens conditioned on ``timestep``.

        Args:
            x (torch.Tensor): input embeddings, shape (b, n, embedding_dim).
            timestep: diffusion timestep(s); scalar, float, or tensor.
            need_temb (bool): when True, also return the timestep embedding.

        Returns:
            (b, num_queries, output_dim) latents, plus the (b, dim)
            timestep embedding when ``need_temb`` is True.
        """
        timestep_emb = self.embedding_time(x, timestep)  # bs, dim
        latents = self.latents.repeat(x.size(0), 1, 1)
        x = self.proj_in(x)
        # Timestep info is injected both additively here and via the adaLN
        # shift/scale inside each layer below.
        x = x + timestep_emb[:, None]
        for attn, ff, adaLN_modulation in self.layers:
            shift_msa, scale_msa, shift_mlp, scale_mlp = adaLN_modulation(
                timestep_emb
            ).chunk(4, dim=1)
            latents = attn(x, latents, shift_msa, scale_msa) + latents
            res = latents
            # Walk the feed-forward's sub-layers manually so the MLP adaLN
            # shift/scale can be applied right after its leading LayerNorm.
            # NOTE(review): assumes FeedForward is indexable (Sequential-like)
            # with a LayerNorm as its first sub-layer — confirm against the
            # FeedForward definition.
            for idx_ff in range(len(ff)):
                layer_ff = ff[idx_ff]
                latents = layer_ff(latents)
                if idx_ff == 0 and isinstance(layer_ff, nn.LayerNorm):  # adaLN
                    latents = latents * (
                        1 + scale_mlp.unsqueeze(1)
                    ) + shift_mlp.unsqueeze(1)
            latents = latents + res
            # latents = ff(latents) + latents
        latents = self.proj_out(latents)
        latents = self.norm_out(latents)
        if need_temb:
            return latents, timestep_emb
        else:
            return latents
    def embedding_time(self, sample, timestep):
        """Project ``timestep`` to a (b, dim) embedding matching ``sample``'s
        batch size, device and dtype."""
        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])
        t_emb = self.time_proj(timesteps)
        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)
        emb = self.time_embedding(t_emb, None)
        return emb

View File

@@ -0,0 +1,136 @@
import torch
from torch import Tensor
from .flux.layers import DoubleStreamBlockIPA, SingleStreamBlockIPA
from comfy.ldm.flux.layers import timestep_embedding
from types import MethodType
def _wrap_stream_blocks(bi, blocks, prefix, wrapper_cls, ip_attn_procs, image_emb):
    """Wrap each transformer block in `blocks` with an IP-Adapter layer.

    For every block, looks up any patch already registered on the model
    patcher; if the block was wrapped by a previous adapter, the new
    processor/embedding are stacked on top of the existing ones instead of
    replacing them.

    Args:
        bi: ComfyUI ModelPatcher the patches are registered on.
        blocks: iterable of original transformer blocks.
        prefix: attribute path prefix, e.g. "double_blocks".
        wrapper_cls: DoubleStreamBlockIPA or SingleStreamBlockIPA.
        ip_attn_procs: mapping "prefix.i" -> attention processor.
        image_emb: image embedding shared by all wrapped blocks.
    """
    for i, original in enumerate(blocks):
        patch_name = f"{prefix}.{i}"
        maybe_patched_layer = bi.get_model_object(f"diffusion_model.{patch_name}")
        procs = [ip_attn_procs[patch_name]]
        embs = [image_emb]
        # If there's already an IPA patch there, collect its adapters so
        # multiple IP-Adapters can be chained on the same block.
        if isinstance(maybe_patched_layer, wrapper_cls):
            procs = maybe_patched_layer.ip_adapter + procs
            embs = maybe_patched_layer.image_emb + embs
        # add_object_patch is the same mechanism ComfyUI uses internally
        # (e.g. model.add_patches for loras), so patches compose cleanly.
        bi.add_object_patch(f"diffusion_model.{patch_name}", wrapper_cls(original, procs, embs))


def FluxUpdateModules(bi, ip_attn_procs, image_emb):
    """Patch a Flux model so its forward pass runs the IP-Adapter layers.

    Replaces `forward_orig` with the IPA-aware variant (which threads the
    timestep into the wrapped blocks) and wraps every double- and
    single-stream block with its adapter. All changes are registered as
    object patches on `bi`, so they are applied/reverted by ComfyUI's
    model-patching machinery rather than mutating the base model.
    """
    flux_model = bi.model
    bi.add_object_patch("diffusion_model.forward_orig", MethodType(forward_orig_ipa, flux_model.diffusion_model))
    _wrap_stream_blocks(bi, flux_model.diffusion_model.double_blocks, "double_blocks",
                        DoubleStreamBlockIPA, ip_attn_procs, image_emb)
    _wrap_stream_blocks(bi, flux_model.diffusion_model.single_blocks, "single_blocks",
                        SingleStreamBlockIPA, ip_attn_procs, image_emb)
def is_model_pathched(model):
    """Return True if `model` (or any nested child module) is already a
    DoubleStreamBlockIPA wrapper, i.e. the IP-Adapter patch was applied.

    The function name (with its historical typo) is part of the public API
    imported elsewhere, so it is kept unchanged.
    """
    def contains_ipa(module):
        if isinstance(module, DoubleStreamBlockIPA):
            return True
        return any(contains_ipa(child) for child in module.children())

    return contains_ipa(model)
def forward_orig_ipa(
    self,
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    timesteps: Tensor,
    y: Tensor,
    guidance: Tensor|None = None,
    control=None,
    transformer_options={},
    attn_mask: Tensor|None = None,
) -> Tensor:
    """IPA-aware replacement for Flux's `forward_orig`.

    Bound onto the diffusion model via MethodType by FluxUpdateModules. It
    mirrors the stock Flux forward pass, with one difference: blocks that
    are IPA wrappers (DoubleStreamBlockIPA / SingleStreamBlockIPA) receive
    the raw `timesteps` as `t` so the adapter can scale its contribution.

    Args:
        img: packed image latents, (b, img_seq, in_dim).
        img_ids / txt_ids: positional ids for rotary embeddings.
        txt: text conditioning tokens, (b, txt_seq, dim).
        timesteps: diffusion timestep(s).
        y: pooled conditioning vector; truncated to params.vec_in_dim.
        guidance: guidance strength, required when params.guidance_embed.
        control: optional ControlNet residuals ({"input": [...], "output": [...]}).
        transformer_options: ComfyUI options dict; "patches_replace"/"dit"
            entries can substitute individual blocks.
            NOTE: the mutable default is only read here, never mutated.
        attn_mask: optional attention mask forwarded to every block.
    """
    patches_replace = transformer_options.get("patches_replace", {})
    if img.ndim != 3 or txt.ndim != 3:
        raise ValueError("Input img and txt tensors must have 3 dimensions.")
    # running on sequences img
    img = self.img_in(img)
    vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
    if self.params.guidance_embed:
        if guidance is None:
            raise ValueError("Didn't get guidance strength for guidance distilled model.")
        vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
    vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
    txt = self.txt_in(txt)
    # Rotary position embeddings over the concatenated txt+img id sequence.
    ids = torch.cat((txt_ids, img_ids), dim=1)
    pe = self.pe_embedder(ids)
    blocks_replace = patches_replace.get("dit", {})
    for i, block in enumerate(self.double_blocks):
        if ("double_block", i) in blocks_replace:
            # A patch replaces this block; hand it a wrapper that calls the
            # original block so the patch can run code around it.
            def block_wrap(args):
                out = {}
                if isinstance(block, DoubleStreamBlockIPA): # IPA wrapper also needs the timestep
                    out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], t=args["timesteps"], attn_mask=args.get("attn_mask"))
                else:
                    out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args.get("attn_mask"))
                return out
            out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "timesteps": timesteps, "attn_mask": attn_mask}, {"original_block": block_wrap})
            txt = out["txt"]
            img = out["img"]
        else:
            if isinstance(block, DoubleStreamBlockIPA): # IPA wrapper also needs the timestep
                img, txt = block(img=img, txt=txt, vec=vec, pe=pe, t=timesteps, attn_mask=attn_mask)
            else:
                img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask)
        if control is not None: # Controlnet
            control_i = control.get("input")
            if i < len(control_i):
                add = control_i[i]
                if add is not None:
                    img += add
    # Single-stream blocks operate on the fused txt+img sequence.
    img = torch.cat((txt, img), 1)
    for i, block in enumerate(self.single_blocks):
        if ("single_block", i) in blocks_replace:
            def block_wrap(args):
                out = {}
                if isinstance(block, SingleStreamBlockIPA): # IPA wrapper also needs the timestep
                    out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], t=args["timesteps"], attn_mask=args.get("attn_mask"))
                else:
                    out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args.get("attn_mask"))
                return out
            out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "timesteps": timesteps, "attn_mask": attn_mask}, {"original_block": block_wrap})
            img = out["img"]
        else:
            if isinstance(block, SingleStreamBlockIPA): # IPA wrapper also needs the timestep
                img = block(img, vec=vec, pe=pe, t=timesteps, attn_mask=attn_mask)
            else:
                img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
        if control is not None: # Controlnet
            control_o = control.get("output")
            if i < len(control_o):
                add = control_o[i]
                if add is not None:
                    # ControlNet residuals apply only to the image part of
                    # the fused sequence (past the txt tokens).
                    img[:, txt.shape[1] :, ...] += add
    # Drop the txt tokens; only image tokens go to the output head.
    img = img[:, txt.shape[1] :, ...]
    img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
    return img