Files
jaidaken f09734b0ee
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Add custom nodes, Civitai loras (LFS), and vast.ai setup script
Includes 30 custom nodes committed directly, 7 Civitai-exclusive
loras stored via Git LFS, and a setup script that installs all
dependencies and downloads HuggingFace-hosted models on vast.ai.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-09 00:56:42 +00:00

268 lines
11 KiB
Python

#credit to shakker-labs and instantX for this module
#from https://github.com/Shakker-Labs/ComfyUI-IPAdapter-Flux
import torch
from PIL import Image
import numpy as np
from .attention_processor import IPAFluxAttnProcessor2_0
from .utils import is_model_pathched, FluxUpdateModules
from .sd3.resampler import TimeResampler
from .sd3.joinblock import JointBlockIPWrapper, IPAttnProcessor
image_proj_model = None
class MLPProjModel(torch.nn.Module):
def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
super().__init__()
self.cross_attention_dim = cross_attention_dim
self.num_tokens = num_tokens
self.proj = torch.nn.Sequential(
torch.nn.Linear(id_embeddings_dim, id_embeddings_dim * 2),
torch.nn.GELU(),
torch.nn.Linear(id_embeddings_dim * 2, cross_attention_dim * num_tokens),
)
self.norm = torch.nn.LayerNorm(cross_attention_dim)
def forward(self, id_embeds):
x = self.proj(id_embeds)
x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
x = self.norm(x)
return x
class InstantXFluxIpadapterApply:
def __init__(self, num_tokens=128):
self.device = None
self.dtype = torch.float16
self.num_tokens = num_tokens
self.ip_ckpt = None
self.clip_vision = None
self.image_encoder = None
self.clip_image_processor = None
# state_dict
self.state_dict = None
self.joint_attention_dim = 4096
self.hidden_size = 3072
def set_ip_adapter(self, flux_model, weight, timestep_percent_range=(0.0, 1.0)):
s = flux_model.model_sampling
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
timestep_range = (percent_to_timestep_function(timestep_percent_range[0]),
percent_to_timestep_function(timestep_percent_range[1]))
ip_attn_procs = {} # 19+38=57
dsb_count = len(flux_model.diffusion_model.double_blocks)
for i in range(dsb_count):
name = f"double_blocks.{i}"
ip_attn_procs[name] = IPAFluxAttnProcessor2_0(
hidden_size=self.hidden_size,
cross_attention_dim=self.joint_attention_dim,
num_tokens=self.num_tokens,
scale=weight,
timestep_range=timestep_range
).to(self.device, dtype=self.dtype)
ssb_count = len(flux_model.diffusion_model.single_blocks)
for i in range(ssb_count):
name = f"single_blocks.{i}"
ip_attn_procs[name] = IPAFluxAttnProcessor2_0(
hidden_size=self.hidden_size,
cross_attention_dim=self.joint_attention_dim,
num_tokens=self.num_tokens,
scale=weight,
timestep_range=timestep_range
).to(self.device, dtype=self.dtype)
return ip_attn_procs
def load_ip_adapter(self, flux_model, weight, timestep_percent_range=(0.0, 1.0)):
global image_proj_model
image_proj_model.load_state_dict(self.state_dict["image_proj"], strict=True)
ip_attn_procs = self.set_ip_adapter(flux_model, weight, timestep_percent_range)
ip_layers = torch.nn.ModuleList(ip_attn_procs.values())
ip_layers.load_state_dict(self.state_dict["ip_adapter"], strict=True)
return ip_attn_procs
def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
# outputs = self.clip_vision.encode_image(pil_image)
# clip_image_embeds = outputs['image_embeds']
# clip_image_embeds = clip_image_embeds.to(self.device, dtype=self.dtype)
# image_prompt_embeds = self.image_proj_model(clip_image_embeds)
if pil_image is not None:
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image_embeds = self.image_encoder(
clip_image.to(self.device, dtype=self.image_encoder.dtype)).pooler_output
clip_image_embeds = clip_image_embeds.to(dtype=self.dtype)
else:
clip_image_embeds = clip_image_embeds.to(self.device, dtype=self.dtype)
global image_proj_model
image_prompt_embeds = image_proj_model(clip_image_embeds)
return image_prompt_embeds
def apply_ipadapter(self, model, ipadapter, image, weight, start_at, end_at, provider=None, use_tiled=False):
self.device = provider.lower()
if "clipvision" in ipadapter:
# self.clip_vision = ipadapter["clipvision"]['model']
self.image_encoder = ipadapter["clipvision"]['model']['image_encoder'].to(self.device, dtype=self.dtype)
self.clip_image_processor = ipadapter["clipvision"]['model']['clip_image_processor']
if "ipadapter" in ipadapter:
self.ip_ckpt = ipadapter["ipadapter"]['file']
self.state_dict = ipadapter["ipadapter"]['model']
# process image
pil_image = image.numpy()[0] * 255.0
pil_image = Image.fromarray(pil_image.astype(np.uint8))
# initialize ipadapter
global image_proj_model
if image_proj_model is None:
image_proj_model = MLPProjModel(
cross_attention_dim=self.joint_attention_dim, # 4096
id_embeddings_dim=1152,
num_tokens=self.num_tokens,
)
image_proj_model.to(self.device, dtype=self.dtype)
ip_attn_procs = self.load_ip_adapter(model.model, weight, (start_at, end_at))
# process control image
image_prompt_embeds = self.get_image_embeds(pil_image=pil_image, clip_image_embeds=None)
# set model
# is_patched = is_model_pathched(model.model)
bi = model.clone()
FluxUpdateModules(bi, ip_attn_procs, image_prompt_embeds)
return (bi, image)
def patch_sd3(
patcher,
ip_procs,
resampler: TimeResampler,
clip_embeds,
weight=1.0,
start=0.0,
end=1.0,
):
"""
Patches a model_sampler to add the ipadapter
"""
mmdit = patcher.model.diffusion_model
timestep_schedule_max = patcher.model.model_config.sampling_settings.get(
"timesteps", 1000
)
# hook the model's forward function
# so that when it gets called, we can grab the timestep and send it to the resampler
ip_options = {
"hidden_states": None,
"t_emb": None,
"weight": weight,
}
def ddit_wrapper(forward, args):
# this is between 0 and 1, so the adapters can calculate start_point and end_point
# actually, do we need to get the sigma value instead?
t_percent = 1 - args["timestep"].flatten()[0].cpu().item()
if start <= t_percent <= end:
batch_size = args["input"].shape[0] // len(args["cond_or_uncond"])
# if we're only doing cond or only doing uncond, only pass one of them through the resampler
embeds = clip_embeds[args["cond_or_uncond"]]
# slight efficiency optimization todo: pass the embeds through and then afterwards
# repeat to the batch size
embeds = torch.repeat_interleave(embeds, batch_size, dim=0)
# the resampler wants between 0 and MAX_STEPS
timestep = args["timestep"] * timestep_schedule_max
image_emb, t_emb = resampler(embeds, timestep, need_temb=True)
# these will need to be accessible to the IPAdapters
ip_options["hidden_states"] = image_emb
ip_options["t_emb"] = t_emb
else:
ip_options["hidden_states"] = None
ip_options["t_emb"] = None
return forward(args["input"], args["timestep"], **args["c"])
patcher.set_model_unet_function_wrapper(ddit_wrapper)
# patch each dit block
for i, block in enumerate(mmdit.joint_blocks):
wrapper = JointBlockIPWrapper(block, ip_procs[i], ip_options)
patcher.set_model_patch_replace(wrapper, "dit", "double_block", i)
class InstantXSD3IpadapterApply:
def __init__(self):
self.device = None
self.dtype = torch.float16
self.clip_image_processor = None
self.image_encoder = None
self.resampler = None
self.procs = None
@torch.inference_mode()
def encode(self, image):
clip_image = self.clip_image_processor.image_processor(image, return_tensors="pt", do_rescale=False).pixel_values
clip_image_embeds = self.image_encoder(
clip_image.to(self.device, dtype=self.image_encoder.dtype),
output_hidden_states=True,
).hidden_states[-2]
clip_image_embeds = torch.cat(
[clip_image_embeds, torch.zeros_like(clip_image_embeds)], dim=0
)
clip_image_embeds = clip_image_embeds.to(dtype=torch.float16)
return clip_image_embeds
def apply_ipadapter(self, model, ipadapter, image, weight, start_at, end_at, provider=None, use_tiled=False):
self.device = provider.lower()
if "clipvision" in ipadapter:
self.image_encoder = ipadapter["clipvision"]['model']['image_encoder'].to(self.device, dtype=self.dtype)
self.clip_image_processor = ipadapter["clipvision"]['model']['clip_image_processor']
if "ipadapter" in ipadapter:
self.ip_ckpt = ipadapter["ipadapter"]['file']
self.state_dict = ipadapter["ipadapter"]['model']
self.resampler = TimeResampler(
dim=1280,
depth=4,
dim_head=64,
heads=20,
num_queries=64,
embedding_dim=1152,
output_dim=2432,
ff_mult=4,
timestep_in_dim=320,
timestep_flip_sin_to_cos=True,
timestep_freq_shift=0,
)
self.resampler.eval()
self.resampler.to(self.device, dtype=self.dtype)
self.resampler.load_state_dict(self.state_dict["image_proj"])
# now we'll create the attention processors
# ip_adapter.keys looks like [0.proj, 0.to_k, ..., 1.proj, 1.to_k, ...]
n_procs = len(
set(x.split(".")[0] for x in self.state_dict["ip_adapter"].keys())
)
self.procs = torch.nn.ModuleList(
[
# this is hardcoded for SD3.5L
IPAttnProcessor(
hidden_size=2432,
cross_attention_dim=2432,
ip_hidden_states_dim=2432,
ip_encoder_hidden_states_dim=2432,
head_dim=64,
timesteps_emb_dim=1280,
).to(self.device, dtype=torch.float16)
for _ in range(n_procs)
]
)
self.procs.load_state_dict(self.state_dict["ip_adapter"])
work_model = model.clone()
embeds = self.encode(image)
patch_sd3(
work_model,
self.procs,
self.resampler,
embeds,
weight,
start_at,
end_at,
)
return (work_model, image)