import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint  # checkpoint_sequential is redefined below
from collections.abc import Iterable
from itertools import repeat


def _ntuple(n):
    # Broadcast a scalar to an n-tuple; pass non-string iterables through unchanged.
    def parse(x):
        if isinstance(x, Iterable) and not isinstance(x, str):
            return x
        return tuple(repeat(x, n))
    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
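# e.g. to_2tuple(7) -> (7, 7), while to_2tuple((3, 5)) -> (3, 5); strings count
# as scalars, so to_2tuple("ab") -> ("ab", "ab").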


def set_grad_checkpoint(model, use_fp32_attention=False, gc_step=1):
    # Tag every submodule so auto_grad_checkpoint knows to wrap its forward
    # pass in gradient checkpointing.
    assert isinstance(model, nn.Module)

    def set_attr(module):
        module.grad_checkpointing = True
        module.fp32_attention = use_fp32_attention
        module.grad_checkpointing_step = gc_step

    model.apply(set_attr)


def auto_grad_checkpoint(module, *args, **kwargs):
    # Dispatch: checkpoint a tagged module, splitting iterable containers
    # (e.g. nn.ModuleList) into segments of grad_checkpointing_step blocks;
    # untagged modules run a plain forward pass.
    if getattr(module, 'grad_checkpointing', False):
        if isinstance(module, Iterable):
            gc_step = module[0].grad_checkpointing_step
            return checkpoint_sequential(module, gc_step, *args, **kwargs)
        else:
            return checkpoint(module, *args, **kwargs)
    return module(*args, **kwargs)
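
# Usage sketch (illustrative; the layer sizes and tensor shapes below are
# arbitrary assumptions, not part of this module):
#
#   blocks = nn.ModuleList([nn.Linear(64, 64) for _ in range(4)])
#   set_grad_checkpoint(blocks, gc_step=2)
#   x = torch.randn(8, 64, requires_grad=True)
#   x = auto_grad_checkpoint(blocks, x)  # checkpoints segments of 2 blocks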


def checkpoint_sequential(functions, step, input, *args, **kwargs):
    # Local variant of torch.utils.checkpoint.checkpoint_sequential that also
    # forwards extra positional args to every function in the chain.

    # Hack for keyword-only parameter in a python 2.7-compliant way
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))

    def run_function(start, end, functions):
        def forward(input):
            for j in range(start, end + 1):
                input = functions[j](input, *args)
            return input
        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = list(functions.children())

    # the last chunk has to be non-volatile
    end = -1
    segment = len(functions) // step
    for start in range(0, step * (segment - 1), step):
        end = start + step - 1
        input = checkpoint(run_function(start, end, functions), input, preserve_rng_state=preserve)
    return run_function(end + 1, len(functions) - 1, functions)(input)
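
# Segmentation example (illustrative): with 5 functions and step=2,
# segment = 5 // 2 = 2, so only functions [0, 1] are checkpointed and the
# remainder [2, 3, 4] runs as the final non-checkpointed chunk.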


def get_rel_pos(q_size, k_size, rel_pos):
    """
    Get relative positional embeddings according to the relative positions of
    query and key sizes.

    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Extracted positional embeddings according to relative positions.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    # Interpolate rel pos if needed.
    if rel_pos.shape[0] != max_rel_dist:
        # Interpolate rel pos.
        rel_pos_resized = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode="linear",
        )
        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
    else:
        rel_pos_resized = rel_pos

    # Scale the coords with short length if shapes for q and k are different.
    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

    return rel_pos_resized[relative_coords.long()]
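
# Shape sketch (illustrative): for q_size = k_size = 14, the table needs
# max_rel_dist = 2 * 14 - 1 = 27 rows; a (27, C) rel_pos is indexed directly,
# any other length is first linearly interpolated to 27 rows. The result has
# shape (q_size, k_size, C) = (14, 14, C).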


def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
    """
    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py  # noqa B950

    Args:
        attn (Tensor): attention map.
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        attn (Tensor): attention map with added relative positional embeddings.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)

    attn = (
        attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
    ).view(B, q_h * q_w, k_h * k_w)

    return attn
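

# Minimal smoke test (illustrative; all shapes below are arbitrary assumptions,
# not part of any calling code):
if __name__ == "__main__":
    B, q_h, q_w, k_h, k_w, C = 2, 7, 7, 7, 7, 16
    q = torch.randn(B, q_h * q_w, C)
    attn = torch.randn(B, q_h * q_w, k_h * k_w)
    rel_pos_h = torch.randn(2 * q_h - 1, C)  # already at max_rel_dist, so no interpolation
    rel_pos_w = torch.randn(2 * q_w - 1, C)
    out = add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, (q_h, q_w), (k_h, k_w))
    assert out.shape == (B, q_h * q_w, k_h * k_w)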