from typing import Optional

import torch

ENTMAX15_FUNC = "entmax1.5"  # sparse attention with alpha=1.5
SPARSEMAX_FUNC = "sparsemax"  # sparse attention with alpha=2

SPARSE_FUNCTIONS: list = [ENTMAX15_FUNC, SPARSEMAX_FUNC]


def pladis_attention_wrapper(pladis_scale=2.0, sparse_func=SPARSE_FUNCTIONS[0]):
    # Simplified version of ComfyUI's attention_basic that blends a sparse
    # attention map (entmax) with the usual softmax map, per PLADIS.
    def _pladis_sparse_attention(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        extra_options: dict,
    ):
        heads = extra_options["n_heads"]
        # Kept for parity with attention_basic; unused in this simplified path.
        attn_precision = extra_options.get("attn_precision")

        b, _, dim_head = q.shape
        dim_head //= heads

        scale: float = dim_head**-0.5

        # (b, n, heads * dim_head) -> (b * heads, n, dim_head)
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(b, -1, heads, dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * heads, -1, dim_head)
            .contiguous(),
            (q, k, v),
        )

        sim = q @ k.transpose(-2, -1) * scale

        del q, k

        dense_sim = torch.softmax(sim, dim=-1)
        if sparse_func == ENTMAX15_FUNC:
            sparse_sim = Entmax.entmax15(sim, dim=-1)
        elif sparse_func == SPARSEMAX_FUNC:
            sparse_sim = Entmax.sparsemax(sim, dim=-1)
        else:  # fall back to the paper's default, entmax-1.5
            sparse_sim = Entmax.entmax15(sim, dim=-1)

        # PLADIS extrapolation: with pladis_scale > 1 this pushes the map past
        # the sparse distribution, away from the dense softmax one.
        pladis_sim = pladis_scale * sparse_sim + (1 - pladis_scale) * dense_sim

        out = pladis_sim.to(v.dtype) @ v

        # (b * heads, n, dim_head) -> (b, n, heads * dim_head)
        out = out.unsqueeze(0).reshape(b, heads, -1, dim_head).permute(0, 2, 1, 3).reshape(b, -1, heads * dim_head)
        return out

    return _pladis_sparse_attention
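

# --- Usage sketch (illustrative; not part of the original module) -----------
# One plausible way to wire the wrapper into a ComfyUI ModelPatcher, using
# ComfyUI's set_model_attn2_replace to swap the cross-attention computation.
# The helper name and the (block_name, number) pairs below are assumptions
# for illustration: valid indices depend on the concrete UNet (SD1.5, SDXL,
# ...), so treat this as a sketch rather than a drop-in node implementation.
def _example_apply_pladis(model, pladis_scale=2.0, sparse_func=ENTMAX15_FUNC):
    patched = model.clone()  # ComfyUI convention: patch a copy, not the input
    attn_fn = pladis_attention_wrapper(pladis_scale, sparse_func)
    # Hypothetical block/index pairs; enumerate the real ones for your model.
    for block_name, number in (("input", 1), ("middle", 0), ("output", 1)):
        patched.set_model_attn2_replace(attn_fn, block_name, number)
    return patched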


class Entmax:
    """
    Activations from the `entmax` package, converted to a static class.

    Both sparsemax and entmax15, together with all of their helper functions,
    are taken from https://github.com/deep-spin/entmax/blob/c2bec6d5e7d649cba7766c2172d89123ec2a6d70/entmax/activations.py
    (the implementation recommended by the PLADIS paper).

    Author: Ben Peters
    Author: Vlad Niculae <vlad@vene.ro>
    License: MIT
    """

    @staticmethod
    def entmax15(X: torch.Tensor, dim=-1, k: Optional[int] = None):
        max_val, _ = X.max(dim=dim, keepdim=True)
        X = X - max_val  # same numerical-stability trick as for softmax
        X = X / 2  # for alpha=1.5 the solution is a squared threshold of X / 2

        tau_star, _ = Entmax._entmax_threshold_and_support(X, dim=dim, k=k)

        Y = torch.clamp(X - tau_star, min=0) ** 2
        return Y

    @staticmethod
    def sparsemax(X: torch.Tensor, dim=-1, k: Optional[int] = None):
        max_val, _ = X.max(dim=dim, keepdim=True)
        X = X - max_val  # same numerical-stability trick as for softmax

        tau, _ = Entmax._sparsemax_threshold_and_support(X, dim=dim, k=k)

        output = torch.clamp(X - tau, min=0)
        return output

    @staticmethod
    def _entmax_threshold_and_support(X, dim=-1, k=None):
        if k is None or k >= X.shape[dim]:  # do a full sort
            Xsrt, _ = torch.sort(X, dim=dim, descending=True)
        else:
            Xsrt, _ = torch.topk(X, k=k, dim=dim)

        # Running mean and mean-square over the sorted prefix of size rho.
        rho = Entmax._make_ix_like(Xsrt, dim)
        mean = Xsrt.cumsum(dim) / rho
        mean_sq = (Xsrt**2).cumsum(dim) / rho
        ss = rho * (mean_sq - mean**2)
        delta = (1 - ss) / rho

        delta_nz = torch.clamp(delta, 0)  # keep sqrt's argument non-negative
        tau = mean - torch.sqrt(delta_nz)

        # The support is the largest prefix whose candidate threshold still
        # lies below the smallest entry kept so far.
        support_size = (tau <= Xsrt).sum(dim).unsqueeze(dim)
        tau_star = tau.gather(dim, support_size - 1)

        if k is not None and k < X.shape[dim]:
            unsolved = (support_size == k).squeeze(dim)

            # Rows whose support hit the top-k budget are re-solved with 2k.
            if torch.any(unsolved):
                X_ = Entmax._roll_last(X, dim)[unsolved]
                tau_, ss_ = Entmax._entmax_threshold_and_support(X_, dim=-1, k=2 * k)
                Entmax._roll_last(tau_star, dim)[unsolved] = tau_
                Entmax._roll_last(support_size, dim)[unsolved] = ss_

        return tau_star, support_size

    @staticmethod
    def _sparsemax_threshold_and_support(X: torch.Tensor, dim=-1, k=None):
        if k is None or k >= X.shape[dim]:  # do a full sort
            topk, _ = torch.sort(X, dim=dim, descending=True)
        else:
            topk, _ = torch.topk(X, k=k, dim=dim)

        topk_cumsum = topk.cumsum(dim) - 1
        rhos = Entmax._make_ix_like(topk, dim)
        support = rhos * topk > topk_cumsum

        support_size = support.sum(dim=dim).unsqueeze(dim)
        tau = topk_cumsum.gather(dim, support_size - 1)
        tau /= support_size.to(X.dtype)

        if k is not None and k < X.shape[dim]:
            unsolved = (support_size == k).squeeze(dim)

            # Same budget-doubling trick as in the entmax threshold search.
            if torch.any(unsolved):
                in_ = Entmax._roll_last(X, dim)[unsolved]
                tau_, ss_ = Entmax._sparsemax_threshold_and_support(in_, dim=-1, k=2 * k)
                Entmax._roll_last(tau, dim)[unsolved] = tau_
                Entmax._roll_last(support_size, dim)[unsolved] = ss_

        return tau, support_size

    @staticmethod
    def _make_ix_like(X: torch.Tensor, dim=-1):
        # Index tensor [1, 2, ..., d] shaped to broadcast along `dim`.
        d = X.size(dim)
        rho = torch.arange(1, d + 1, device=X.device, dtype=X.dtype)
        view = [1] * X.dim()
        view[0] = -1
        return rho.view(view).transpose(0, dim)

    @staticmethod
    def _roll_last(X: torch.Tensor, dim=-1):
        # Move `dim` to the last position so masked row indexing stays aligned.
        if dim == -1:
            return X
        elif dim < 0:
            # Normalize a negative dim; the upstream copy subtracts here,
            # which mis-indexes for any dim < -1.
            dim = X.dim() + dim

        perm = [i for i in range(X.dim()) if i != dim] + [dim]
        return X.permute(perm)
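

# --- Sanity-check sketch (not part of the original file) ---------------------
# Illustrates the properties PLADIS relies on: both sparse maps are valid
# probability distributions (rows sum to 1) and, unlike softmax, typically
# contain exact zeros. Run the module directly to print the check.
if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(2, 8)
    for name, fn in ((ENTMAX15_FUNC, Entmax.entmax15), (SPARSEMAX_FUNC, Entmax.sparsemax)):
        probs = fn(logits, dim=-1)
        print(f"{name}: row sums={probs.sum(-1).tolist()}, exact zeros={(probs == 0).sum().item()}")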