Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Includes 30 custom nodes committed directly, 7 Civitai-exclusive loras stored via Git LFS, and a setup script that installs all dependencies and downloads HuggingFace-hosted models on vast.ai. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
268 lines
11 KiB
Python
268 lines
11 KiB
Python
# Credit to Shakker-Labs and InstantX for this module.
|
|
#from https://github.com/Shakker-Labs/ComfyUI-IPAdapter-Flux
|
|
import torch
|
|
from PIL import Image
|
|
import numpy as np
|
|
from .attention_processor import IPAFluxAttnProcessor2_0
|
|
from .utils import is_model_pathched, FluxUpdateModules
|
|
from .sd3.resampler import TimeResampler
|
|
from .sd3.joinblock import JointBlockIPWrapper, IPAttnProcessor
|
|
|
|
# Module-level Flux image-projection model, shared through `global` statements.
# It starts as None; InstantXFluxIpadapterApply.apply_ipadapter creates an
# MLPProjModel here the first time it finds this still unset, and
# load_ip_adapter / get_image_embeds read it afterwards.
image_proj_model = None
|
|
class MLPProjModel(torch.nn.Module):
    """Project a flat embedding into a sequence of layer-normalised tokens.

    A two-layer MLP (Linear -> GELU -> Linear) widens the input to
    ``cross_attention_dim * num_tokens`` features; the result is reshaped to
    ``(-1, num_tokens, cross_attention_dim)`` and passed through a LayerNorm
    over the last dimension.
    """

    def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.num_tokens = num_tokens

        hidden_dim = id_embeddings_dim * 2
        output_dim = cross_attention_dim * num_tokens
        # Attribute names and layer order determine the state-dict keys
        # (proj.0.*, proj.2.*, norm.*), so they must stay exactly like this.
        layers = [
            torch.nn.Linear(id_embeddings_dim, hidden_dim),
            torch.nn.GELU(),
            torch.nn.Linear(hidden_dim, output_dim),
        ]
        self.proj = torch.nn.Sequential(*layers)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, id_embeds):
        """Return normalised tokens shaped ``(-1, num_tokens, cross_attention_dim)``."""
        tokens = self.proj(id_embeds).reshape(
            -1, self.num_tokens, self.cross_attention_dim
        )
        return self.norm(tokens)
|
|
|
|
class InstantXFluxIpadapterApply:
    """Attach an InstantX IP-Adapter to a Flux model.

    The instance keeps the image encoder, the image processor and the adapter
    state dict handed to :meth:`apply_ipadapter`. The image projection model
    lives in the module-level ``image_proj_model`` global and is therefore
    shared by every instance.
    """

    def __init__(self, num_tokens=128):
        self.device = None
        self.dtype = torch.float16
        self.num_tokens = num_tokens
        self.ip_ckpt = None
        self.clip_vision = None
        self.image_encoder = None
        self.clip_image_processor = None
        # Adapter weights; indexed with "image_proj" and "ip_adapter".
        self.state_dict = None
        self.joint_attention_dim = 4096
        self.hidden_size = 3072

    @staticmethod
    def _resolve_device(provider):
        """Return a lowercase device string, auto-detecting when none is given."""
        if provider:
            return provider.lower()
        # `provider=None` is the declared default of apply_ipadapter, so it
        # must not crash; fall back to whatever torch can use.
        return "cuda" if torch.cuda.is_available() else "cpu"

    @staticmethod
    def _first_image_to_pil(image):
        """Convert the first image of a batch tensor to a PIL image.

        NOTE(review): assumes a float tensor with values in [0, 1] laid out as
        (batch, height, width, channels) - confirm against the caller.
        """
        # detach/cpu/float make the conversion work for CUDA, grad-tracking and
        # half-precision tensors, which `Tensor.numpy()` alone rejects.
        pixels = image[0].detach().cpu().float().numpy() * 255.0
        # Clamp first: casting out-of-range floats to uint8 wraps around.
        pixels = np.clip(pixels, 0.0, 255.0).astype(np.uint8)
        return Image.fromarray(pixels)

    def set_ip_adapter(self, flux_model, weight, timestep_percent_range=(0.0, 1.0)):
        """Create one attention processor per double and single block.

        Args:
            flux_model: object exposing ``model_sampling.percent_to_sigma`` and
                ``diffusion_model.double_blocks`` / ``single_blocks``.
            weight: passed to every processor as ``scale``.
            timestep_percent_range: (start, end) percentages; each is converted
                with ``percent_to_sigma`` before being given to the processors.

        Returns:
            dict mapping ``"double_blocks.{i}"`` then ``"single_blocks.{i}"``
            to a processor. The insertion order matters: load_ip_adapter loads
            weights into the values positionally.
        """
        sampling = flux_model.model_sampling
        timestep_range = (
            sampling.percent_to_sigma(timestep_percent_range[0]),
            sampling.percent_to_sigma(timestep_percent_range[1]),
        )

        diffusion_model = flux_model.diffusion_model
        block_names = [
            f"double_blocks.{i}" for i in range(len(diffusion_model.double_blocks))
        ]
        block_names += [
            f"single_blocks.{i}" for i in range(len(diffusion_model.single_blocks))
        ]

        return {
            name: IPAFluxAttnProcessor2_0(
                hidden_size=self.hidden_size,
                cross_attention_dim=self.joint_attention_dim,
                num_tokens=self.num_tokens,
                scale=weight,
                timestep_range=timestep_range,
            ).to(self.device, dtype=self.dtype)
            for name in block_names
        }

    def load_ip_adapter(self, flux_model, weight, timestep_percent_range=(0.0, 1.0)):
        """Load the adapter weights and return the attention processors.

        Loads ``state_dict["image_proj"]`` into the shared projection model and
        ``state_dict["ip_adapter"]`` into the processors built by
        :meth:`set_ip_adapter`, both with ``strict=True``.
        """
        global image_proj_model
        image_proj_model.load_state_dict(self.state_dict["image_proj"], strict=True)
        ip_attn_procs = self.set_ip_adapter(flux_model, weight, timestep_percent_range)
        # A ModuleList keys its parameters by position, matching the dict order.
        ip_layers = torch.nn.ModuleList(ip_attn_procs.values())
        ip_layers.load_state_dict(self.state_dict["ip_adapter"], strict=True)
        return ip_attn_procs

    def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
        """Return projected image prompt embeddings.

        When ``pil_image`` is given (a single PIL image or a list of them) it is
        run through the image processor and encoder and the encoder's
        ``pooler_output`` is used; otherwise ``clip_image_embeds`` is used as
        supplied. Either way the embeddings are cast to ``self.dtype`` and fed
        to the shared projection model.
        """
        global image_proj_model
        if pil_image is not None:
            if isinstance(pil_image, Image.Image):
                pil_image = [pil_image]
            clip_image = self.clip_image_processor(
                images=pil_image, return_tensors="pt"
            ).pixel_values
            clip_image_embeds = self.image_encoder(
                clip_image.to(self.device, dtype=self.image_encoder.dtype)
            ).pooler_output
            clip_image_embeds = clip_image_embeds.to(dtype=self.dtype)
        else:
            clip_image_embeds = clip_image_embeds.to(self.device, dtype=self.dtype)
        return image_proj_model(clip_image_embeds)

    def apply_ipadapter(self, model, ipadapter, image, weight, start_at, end_at, provider=None, use_tiled=False):
        """Return a patched clone of ``model`` together with the input image.

        Args:
            model: patcher object; ``model.model`` is handed to load_ip_adapter
                and ``model.clone()`` is the object that gets patched.
            ipadapter: dict that may carry ``"clipvision"`` (image encoder and
                processor) and ``"ipadapter"`` (checkpoint file and weights).
            image: batch tensor; only the first image is used.
            weight: adapter strength.
            start_at, end_at: active range as sampling percentages.
            provider: device name such as "CUDA" or "CPU"; auto-detected when
                omitted.
            use_tiled: accepted for interface compatibility; not used here.

        Returns:
            ``(patched_model, image)``.

        Raises:
            ValueError: if no adapter weights or no CLIP vision components are
                available, neither from ``ipadapter`` nor from an earlier call.
        """
        global image_proj_model

        self.device = self._resolve_device(provider)
        if "clipvision" in ipadapter:
            clip_vision = ipadapter["clipvision"]['model']
            self.image_encoder = clip_vision['image_encoder'].to(self.device, dtype=self.dtype)
            self.clip_image_processor = clip_vision['clip_image_processor']
        if "ipadapter" in ipadapter:
            self.ip_ckpt = ipadapter["ipadapter"]['file']
            self.state_dict = ipadapter["ipadapter"]['model']

        # Fail early with a readable message instead of a NoneType error later.
        if self.state_dict is None:
            raise ValueError("ipadapter does not provide IP-Adapter weights ('ipadapter' entry missing)")
        if self.image_encoder is None or self.clip_image_processor is None:
            raise ValueError("ipadapter does not provide a CLIP vision model ('clipvision' entry missing)")

        pil_image = self._first_image_to_pil(image)

        # Build the shared projection model once; always move it so a cached
        # model follows a changed device.
        if image_proj_model is None:
            image_proj_model = MLPProjModel(
                cross_attention_dim=self.joint_attention_dim,  # 4096
                id_embeddings_dim=1152,
                num_tokens=self.num_tokens,
            )
        image_proj_model.to(self.device, dtype=self.dtype)

        ip_attn_procs = self.load_ip_adapter(model.model, weight, (start_at, end_at))
        image_prompt_embeds = self.get_image_embeds(pil_image=pil_image, clip_image_embeds=None)

        # Patch a clone so the caller's model is left untouched.
        bi = model.clone()
        FluxUpdateModules(bi, ip_attn_procs, image_prompt_embeds)

        return (bi, image)
|
|
|
|
|
|
def patch_sd3(
    patcher,
    ip_procs,
    resampler: TimeResampler,
    clip_embeds,
    weight=1.0,
    start=0.0,
    end=1.0,
):
    """
    Install the IP-Adapter hooks on an SD3 model patcher.

    A unet function wrapper is registered that, on every call, refreshes a
    shared options dict with the resampled image embeddings (or clears it when
    the current step lies outside ``[start, end]``). Every entry of the
    diffusion model's ``joint_blocks`` is then replaced by a
    ``JointBlockIPWrapper`` that receives the matching processor from
    ``ip_procs`` and that same dict.
    """
    diffusion_model = patcher.model.diffusion_model
    sampling_settings = patcher.model.model_config.sampling_settings
    max_timesteps = sampling_settings.get("timesteps", 1000)

    # Shared, mutable state: written by the wrapper below, handed to every
    # block wrapper at construction time.
    shared_state = {
        "hidden_states": None,
        "t_emb": None,
        "weight": weight,
    }

    def ddit_wrapper(forward, args):
        raw_timestep = args["timestep"]
        # Invert the first timestep value so it can be compared with start/end.
        progress = 1 - raw_timestep.flatten()[0].cpu().item()

        image_emb = t_emb = None
        if start <= progress <= end:
            branches = args["cond_or_uncond"]
            per_branch = args["input"].shape[0] // len(branches)
            # Pick only the embeddings for the branches present in this call,
            # then repeat each one to cover its share of the batch.
            selected = torch.repeat_interleave(
                clip_embeds[branches], per_branch, dim=0
            )
            # The resampler is fed the timestep scaled up by max_timesteps.
            image_emb, t_emb = resampler(
                selected, raw_timestep * max_timesteps, need_temb=True
            )

        shared_state["hidden_states"] = image_emb
        shared_state["t_emb"] = t_emb

        return forward(args["input"], raw_timestep, **args["c"])

    patcher.set_model_unet_function_wrapper(ddit_wrapper)

    # Replace every joint block with its IP-aware wrapper.
    for index, joint_block in enumerate(diffusion_model.joint_blocks):
        patcher.set_model_patch_replace(
            JointBlockIPWrapper(joint_block, ip_procs[index], shared_state),
            "dit",
            "double_block",
            index,
        )
|
|
|
|
class InstantXSD3IpadapterApply:
    """Attach an InstantX IP-Adapter to an SD3 model.

    Builds a ``TimeResampler`` and one ``IPAttnProcessor`` per adapter layer
    from the supplied weights, encodes the reference image and hands everything
    to ``patch_sd3``.
    """

    def __init__(self):
        self.device = None
        self.dtype = torch.float16
        # Initialised here so a call without an "ipadapter" entry is detected
        # explicitly instead of raising AttributeError.
        self.ip_ckpt = None
        self.state_dict = None
        self.clip_image_processor = None
        self.image_encoder = None
        self.resampler = None
        self.procs = None

    @staticmethod
    def _resolve_device(provider):
        """Return a lowercase device string, auto-detecting when none is given."""
        if provider:
            return provider.lower()
        # `provider=None` is the declared default of apply_ipadapter, so it
        # must not crash; fall back to whatever torch can use.
        return "cuda" if torch.cuda.is_available() else "cpu"

    @torch.inference_mode()
    def encode(self, image):
        """Encode ``image`` and append an all-zero copy along dim 0.

        Uses the second-to-last hidden state of the image encoder. The zero
        block doubles the first dimension; patch_sd3 later indexes the result
        with ``cond_or_uncond`` (presumably real embeddings for cond and zeros
        for uncond - confirm). The result is cast to ``self.dtype``.
        """
        clip_image = self.clip_image_processor.image_processor(
            image, return_tensors="pt", do_rescale=False
        ).pixel_values
        clip_image_embeds = self.image_encoder(
            clip_image.to(self.device, dtype=self.image_encoder.dtype),
            output_hidden_states=True,
        ).hidden_states[-2]
        clip_image_embeds = torch.cat(
            [clip_image_embeds, torch.zeros_like(clip_image_embeds)], dim=0
        )
        return clip_image_embeds.to(dtype=self.dtype)

    def apply_ipadapter(self, model, ipadapter, image, weight, start_at, end_at, provider=None, use_tiled=False):
        """Return a patched clone of ``model`` together with the input image.

        Args:
            model: patcher object; ``model.clone()`` is the object that gets
                patched.
            ipadapter: dict that may carry ``"clipvision"`` (image encoder and
                processor) and ``"ipadapter"`` (checkpoint file and weights).
            image: reference image passed to :meth:`encode`.
            weight: adapter strength.
            start_at, end_at: active range forwarded to ``patch_sd3``.
            provider: device name such as "CUDA" or "CPU"; auto-detected when
                omitted.
            use_tiled: accepted for interface compatibility; not used here.

        Returns:
            ``(patched_model, image)``.

        Raises:
            ValueError: if no adapter weights or no CLIP vision components are
                available, neither from ``ipadapter`` nor from an earlier call.
        """
        self.device = self._resolve_device(provider)
        if "clipvision" in ipadapter:
            clip_vision = ipadapter["clipvision"]['model']
            self.image_encoder = clip_vision['image_encoder'].to(self.device, dtype=self.dtype)
            self.clip_image_processor = clip_vision['clip_image_processor']
        if "ipadapter" in ipadapter:
            self.ip_ckpt = ipadapter["ipadapter"]['file']
            self.state_dict = ipadapter["ipadapter"]['model']

        # Fail early with a readable message instead of a NoneType error later.
        if self.state_dict is None:
            raise ValueError("ipadapter does not provide IP-Adapter weights ('ipadapter' entry missing)")
        if self.image_encoder is None or self.clip_image_processor is None:
            raise ValueError("ipadapter does not provide a CLIP vision model ('clipvision' entry missing)")

        self.resampler = TimeResampler(
            dim=1280,
            depth=4,
            dim_head=64,
            heads=20,
            num_queries=64,
            embedding_dim=1152,
            output_dim=2432,
            ff_mult=4,
            timestep_in_dim=320,
            timestep_flip_sin_to_cos=True,
            timestep_freq_shift=0,
        )
        self.resampler.eval()
        self.resampler.to(self.device, dtype=self.dtype)
        self.resampler.load_state_dict(self.state_dict["image_proj"])

        # now we'll create the attention processors
        # ip_adapter.keys looks like [0.proj, 0.to_k, ..., 1.proj, 1.to_k, ...]
        # so the number of distinct leading indices is the processor count.
        n_procs = len({key.split(".")[0] for key in self.state_dict["ip_adapter"]})
        self.procs = torch.nn.ModuleList(
            [
                # this is hardcoded for SD3.5L
                IPAttnProcessor(
                    hidden_size=2432,
                    cross_attention_dim=2432,
                    ip_hidden_states_dim=2432,
                    ip_encoder_hidden_states_dim=2432,
                    head_dim=64,
                    timesteps_emb_dim=1280,
                ).to(self.device, dtype=self.dtype)
                for _ in range(n_procs)
            ]
        )
        self.procs.load_state_dict(self.state_dict["ip_adapter"])

        # Patch a clone so the caller's model is left untouched.
        work_model = model.clone()
        embeds = self.encode(image)

        patch_sd3(
            work_model,
            self.procs,
            self.resampler,
            embeds,
            weight,
            start_at,
            end_at,
        )

        return (work_model, image)