import math
from typing import Callable, List

import numpy as np
import torch
from einops import rearrange, repeat
from torch import Tensor
from tqdm.auto import tqdm

# from .modules.conditioner import HFEmbedder
from .layers import DoubleStreamMixerProcessor, timestep_embedding
from .utils import ControlNetContainer


def model_forward(
    model,
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    timesteps: Tensor,
    y: Tensor,
    block_controlnet_hidden_states=None,
    guidance: Tensor | None = None,
    neg_mode: bool = False,
) -> Tensor:
    if img.ndim != 3 or txt.ndim != 3:
        raise ValueError("Input img and txt tensors must have 3 dimensions.")

    # project the packed image tokens and build the modulation vector
    img = model.img_in(img)
    vec = model.time_in(timestep_embedding(timesteps, 256))
    if model.params.guidance_embed:
        if guidance is None:
            raise ValueError("Didn't get guidance strength for guidance distilled model.")
        vec = vec + model.guidance_in(timestep_embedding(guidance, 256))
    vec = vec + model.vector_in(y)
    txt = model.txt_in(txt)

    # rotary position embeddings over the concatenated text+image token ids
    ids = torch.cat((txt_ids, img_ids), dim=1)
    pe = model.pe_embedder(ids)

    controlnet_depth = len(block_controlnet_hidden_states) if block_controlnet_hidden_states is not None else 0

    for index_block, block in enumerate(model.double_blocks):
        # point any IP-adapter processors at the positive or negative image embeds
        if hasattr(block, "processor") and isinstance(block.processor, DoubleStreamMixerProcessor):
            if neg_mode:
                for ip in block.processor.ip_adapters:
                    ip.ip_hidden_states = ip.in_hidden_states_neg
            else:
                for ip in block.processor.ip_adapters:
                    ip.ip_hidden_states = ip.in_hidden_states_pos

        img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        # controlnet residual, cycling through the available residual tensors
        if block_controlnet_hidden_states is not None:
            img = img + block_controlnet_hidden_states[index_block % controlnet_depth]

    img = torch.cat((txt, img), 1)
    for block in model.single_blocks:
        img = block(img, vec=vec, pe=pe)
    img = img[:, txt.shape[1]:, ...]

    img = model.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
    return img


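# Shape sketch (hedged; the concrete numbers assume a Flux-style checkpoint and
# are illustrative, not read from this module): for a 1024x1024 image the packed
# img is (B, 4096, 64) and txt from T5 is (B, 512, 4096); ids then has
# 512 + 4096 rows per batch element, and the returned tensor matches the packed
# img layout (B, 4096, 64), ready for unpack().

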
def get_noise(
    num_samples: int,
    height: int,
    width: int,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
):
    return torch.randn(
        num_samples,
        16,
        # allow for packing
        2 * math.ceil(height / 16),
        2 * math.ceil(width / 16),
        device=device,
        dtype=dtype,
        generator=torch.Generator(device=device).manual_seed(seed),
    )


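# Quick check (a sketch, not part of the original file): the spatial dims are
# doubled ceil-divisions by 16 so the noise can later be packed into 2x2 patches.
#
#   x = get_noise(1, 1024, 1024, torch.device("cpu"), torch.float32, seed=0)
#   assert x.shape == (1, 16, 128, 128)  # 2 * ceil(1024 / 16) = 128

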
def prepare(txt_t5, vec_clip, img: Tensor) -> dict[str, Tensor]:
    txt = txt_t5
    vec = vec_clip
    bs, c, h, w = img.shape

    # pack 2x2 latent patches into tokens: (b, c, h, w) -> (b, h/2 * w/2, c*4)
    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    if img.shape[0] == 1 and bs > 1:
        img = repeat(img, "1 ... -> bs ...", bs=bs)

    # per-token (0, row, col) position ids for the rotary embeddings
    img_ids = torch.zeros(h // 2, w // 2, 3)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

    if txt.shape[0] == 1 and bs > 1:
        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
    # text tokens all share position id zero
    txt_ids = torch.zeros(bs, txt.shape[1], 3)

    if vec.shape[0] == 1 and bs > 1:
        vec = repeat(vec, "1 ... -> bs ...", bs=bs)

    return {
        "img": img,
        "img_ids": img_ids.to(img.device, dtype=img.dtype),
        "txt": txt.to(img.device, dtype=img.dtype),
        "txt_ids": txt_ids.to(img.device, dtype=img.dtype),
        "vec": vec.to(img.device, dtype=img.dtype),
    }


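# Packing round-trip sketch (hypothetical tensors; the T5/CLIP widths below are
# typical but depend on your encoders): prepare() flattens 2x2 latent patches
# into tokens and unpack() at the bottom of this file inverts it exactly.
#
#   img = get_noise(1, 1024, 1024, torch.device("cpu"), torch.float32, seed=0)
#   d = prepare(torch.zeros(1, 512, 4096), torch.zeros(1, 768), img)
#   assert d["img"].shape == (1, 64 * 64, 64)  # (h/2)*(w/2) tokens of c*2*2 channels
#   assert torch.equal(unpack(d["img"], 1024, 1024), img)

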
def time_shift(mu: float, sigma: float, t: Tensor):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


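# Worked example: with sigma = 1 this is a logistic remap of t. For mu = ln(3)
# and t = 0.5: exp(mu) / (exp(mu) + 1/0.5 - 1) = 3 / (3 + 1) = 0.75, so larger
# mu pushes mid-schedule timesteps toward 1 (more time spent at high noise).

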
def get_lin_function(
    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b


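# With the defaults, m = (1.15 - 0.5) / (4096 - 256) ≈ 1.693e-4 and b ≈ 0.4567,
# so a 256-token image maps to mu = 0.5, a 4096-token image to mu = 1.15, and
# longer sequences receive a stronger schedule shift.

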
def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    # extra step for zero
    timesteps = torch.linspace(1, 0, num_steps + 1)

    # shifting the schedule to favor high timesteps for higher signal images
    if shift:
        # estimate mu via linear interpolation between two reference points
        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
        timesteps = time_shift(mu, 1.0, timesteps)

    return timesteps.tolist()


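# Example: get_schedule(4, image_seq_len=4096, shift=False) returns
# [1.0, 0.75, 0.5, 0.25, 0.0]; with shift=True (mu = 1.15) the same call yields
# roughly [1.0, 0.90, 0.76, 0.51, 0.0], concentrating steps at high noise.

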
def denoise(
    model,
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    neg_txt: Tensor,
    neg_txt_ids: Tensor,
    neg_vec: Tensor,
    # sampling parameters
    timesteps: list[float],
    guidance: float = 4.0,
    true_gs=1,
    timestep_to_start_cfg=0,
    image2image_strength=None,
    orig_image=None,
    callback=None,
    width=512,
    height=512,
):
    i = 0

    if image2image_strength is not None and orig_image is not None:
        # skip the earliest (noisiest) steps and start from a noised copy of the input
        t_idx = np.clip(int((1 - np.clip(image2image_strength, 0.0, 1.0)) * len(timesteps)), 0, len(timesteps) - 1)
        t = timesteps[t_idx]
        timesteps = timesteps[t_idx:]
        orig_image = rearrange(orig_image, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2).to(img.device, dtype=img.dtype)
        img = t * img + (1.0 - t) * orig_image

    img_ids = img_ids.to(img.device, dtype=img.dtype)
    txt = txt.to(img.device, dtype=img.dtype)
    txt_ids = txt_ids.to(img.device, dtype=img.dtype)
    vec = vec.to(img.device, dtype=img.dtype)

    if hasattr(model, "guidance_in"):
        guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
    else:
        # this is ignored for schnell
        guidance_vec = None

    for t_curr, t_prev in tqdm(zip(timesteps[:-1], timesteps[1:]), desc="Sampling", total=len(timesteps) - 1):
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        pred = model_forward(
            model,
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
        )
        if i >= timestep_to_start_cfg:
            neg_pred = model_forward(
                model,
                img=img,
                img_ids=img_ids,
                txt=neg_txt,
                txt_ids=neg_txt_ids,
                y=neg_vec,
                timesteps=t_vec,
                guidance=guidance_vec,
                neg_mode=True,
            )
            # classifier-free guidance on the velocity prediction
            pred = neg_pred + true_gs * (pred - neg_pred)
        # Euler step of the rectified-flow ODE
        img = img + (t_prev - t_curr) * pred

        if callback is not None:
            unpacked = unpack(img.float(), height, width)
            callback(step=i, x=img, x0=unpacked, total_steps=len(timesteps) - 1)
        i += 1

    return img


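# End-to-end usage sketch (hedged: flux_model, t5_embeds, clip_vec, neg_t5 and
# neg_clip are hypothetical placeholders for a loaded model and encoder outputs):
#
#   x = get_noise(1, 1024, 1024, torch.device("cuda"), torch.bfloat16, seed=0)
#   inputs = prepare(t5_embeds, clip_vec, x)
#   schedule = get_schedule(20, inputs["img"].shape[1])
#   latents = denoise(flux_model, **inputs, neg_txt=neg_t5,
#                     neg_txt_ids=inputs["txt_ids"], neg_vec=neg_clip,
#                     timesteps=schedule, guidance=4.0, true_gs=3.5,
#                     width=1024, height=1024)
#   latents = unpack(latents.float(), 1024, 1024)  # (1, 16, 128, 128) for the VAE

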
def denoise_controlnet(
    model,
    controlnets_container: None | List[ControlNetContainer],
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    neg_txt: Tensor,
    neg_txt_ids: Tensor,
    neg_vec: Tensor,
    # sampling parameters
    timesteps: list[float],
    guidance: float = 4.0,
    true_gs=1,
    timestep_to_start_cfg=0,
    image2image_strength=None,
    orig_image=None,
    callback=None,
    width=512,
    height=512,
):
    i = 0
    controlnets_container = controlnets_container or []

    if image2image_strength is not None and orig_image is not None:
        t_idx = np.clip(int((1 - np.clip(image2image_strength, 0.0, 1.0)) * len(timesteps)), 0, len(timesteps) - 1)
        t = timesteps[t_idx]
        timesteps = timesteps[t_idx:]
        orig_image = rearrange(orig_image, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2).to(img.device, dtype=img.dtype)
        img = t * img + (1.0 - t) * orig_image

    img_ids = img_ids.to(img.device, dtype=img.dtype)
    txt = txt.to(img.device, dtype=img.dtype)
    txt_ids = txt_ids.to(img.device, dtype=img.dtype)
    vec = vec.to(img.device, dtype=img.dtype)
    for container in controlnets_container:
        container.controlnet_cond = container.controlnet_cond.to(img.device, dtype=img.dtype)
        container.controlnet.to(img.device, dtype=img.dtype)

    if hasattr(model, "guidance_in"):
        guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
    else:
        # this is ignored for schnell
        guidance_vec = None

    for t_curr, t_prev in tqdm(zip(timesteps[:-1], timesteps[1:]), desc="Sampling", total=len(timesteps) - 1):
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)

        # accumulate the (scaled) residuals from every controlnet active at this step
        controlnet_hidden_states = None
        for container in controlnets_container:
            if container.controlnet_start_step <= i <= container.controlnet_end_step:
                block_res_samples = container.controlnet(
                    img=img,
                    img_ids=img_ids,
                    controlnet_cond=container.controlnet_cond,
                    txt=txt,
                    txt_ids=txt_ids,
                    y=vec,
                    timesteps=t_vec,
                    guidance=guidance_vec,
                )
                if controlnet_hidden_states is None:
                    controlnet_hidden_states = [sample * container.controlnet_gs for sample in block_res_samples]
                elif len(controlnet_hidden_states) == len(block_res_samples):
                    for j in range(len(controlnet_hidden_states)):
                        controlnet_hidden_states[j] += block_res_samples[j] * container.controlnet_gs

        pred = model_forward(
            model,
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
            block_controlnet_hidden_states=controlnet_hidden_states,
        )

        if i >= timestep_to_start_cfg:
            # same residual accumulation for the negative (unconditional) pass
            neg_controlnet_hidden_states = None
            for container in controlnets_container:
                if container.controlnet_start_step <= i <= container.controlnet_end_step:
                    neg_block_res_samples = container.controlnet(
                        img=img,
                        img_ids=img_ids,
                        controlnet_cond=container.controlnet_cond,
                        txt=neg_txt,
                        txt_ids=neg_txt_ids,
                        y=neg_vec,
                        timesteps=t_vec,
                        guidance=guidance_vec,
                    )
                    if neg_controlnet_hidden_states is None:
                        neg_controlnet_hidden_states = [sample * container.controlnet_gs for sample in neg_block_res_samples]
                    elif len(neg_controlnet_hidden_states) == len(neg_block_res_samples):
                        for j in range(len(neg_controlnet_hidden_states)):
                            neg_controlnet_hidden_states[j] += neg_block_res_samples[j] * container.controlnet_gs

            neg_pred = model_forward(
                model,
                img=img,
                img_ids=img_ids,
                txt=neg_txt,
                txt_ids=neg_txt_ids,
                y=neg_vec,
                timesteps=t_vec,
                guidance=guidance_vec,
                block_controlnet_hidden_states=neg_controlnet_hidden_states,
                neg_mode=True,
            )
            pred = neg_pred + true_gs * (pred - neg_pred)

        img = img + (t_prev - t_curr) * pred

        if callback is not None:
            unpacked = unpack(img.float(), height, width)
            callback(step=i, x=img, x0=unpacked, total_steps=len(timesteps) - 1)
        i += 1
    return img


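# Note on the container interface (inferred from its use above): each
# ControlNetContainer is expected to carry a `controlnet` module, a
# `controlnet_cond` image tensor, a `controlnet_gs` residual scale, and
# `controlnet_start_step` / `controlnet_end_step` bounds on the steps where it
# is applied; multiple containers simply sum their scaled residuals.

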
def unpack(x: Tensor, height: int, width: int) -> Tensor:
    # inverse of the packing in prepare(): tokens back to a (b, c, h, w) latent
    return rearrange(
        x,
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=math.ceil(height / 16),
        w=math.ceil(width / 16),
        ph=2,
        pw=2,
    )