Add custom nodes, Civitai loras (LFS), and vast.ai setup script
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Includes 30 custom nodes committed directly, 7 Civitai-exclusive loras stored via Git LFS, and a setup script that installs all dependencies and downloads HuggingFace-hosted models on vast.ai. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
270
custom_nodes/sd-perturbed-attention/.gitignore
vendored
Normal file
@@ -0,0 +1,270 @@
|
||||
# File created using '.gitignore Generator' for Visual Studio Code: https://bit.ly/vscode-gig
|
||||
# Created by https://www.toptal.com/developers/gitignore/api/windows,visualstudiocode,macos,python,venv
|
||||
# Edit at https://www.toptal.com/developers/gitignore?templates=windows,visualstudiocode,macos,python,venv
|
||||
|
||||
### macOS ###
|
||||
# General
|
||||
.DS_Store
|
||||
.AppleDouble
|
||||
.LSOverride
|
||||
|
||||
# Icon must end with two \r
|
||||
Icon
|
||||
|
||||
|
||||
# Thumbnails
|
||||
._*
|
||||
|
||||
# Files that might appear in the root of a volume
|
||||
.DocumentRevisions-V100
|
||||
.fseventsd
|
||||
.Spotlight-V100
|
||||
.TemporaryItems
|
||||
.Trashes
|
||||
.VolumeIcon.icns
|
||||
.com.apple.timemachine.donotpresent
|
||||
|
||||
# Directories potentially created on remote AFP share
|
||||
.AppleDB
|
||||
.AppleDesktop
|
||||
Network Trash Folder
|
||||
Temporary Items
|
||||
.apdisk
|
||||
|
||||
### macOS Patch ###
|
||||
# iCloud generated files
|
||||
*.icloud
|
||||
|
||||
### Python ###
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
### Python Patch ###
|
||||
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
|
||||
poetry.toml
|
||||
|
||||
# ruff
|
||||
.ruff_cache/
|
||||
|
||||
# LSP config files
|
||||
pyrightconfig.json
|
||||
|
||||
### venv ###
|
||||
# Virtualenv
|
||||
# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
|
||||
[Bb]in
|
||||
[Ii]nclude
|
||||
[Ll]ib
|
||||
[Ll]ib64
|
||||
[Ll]ocal
|
||||
[Ss]cripts
|
||||
pyvenv.cfg
|
||||
pip-selfcheck.json
|
||||
|
||||
### VisualStudioCode ###
|
||||
.vscode/*
|
||||
!.vscode/settings.json
|
||||
!.vscode/tasks.json
|
||||
!.vscode/launch.json
|
||||
!.vscode/extensions.json
|
||||
!.vscode/*.code-snippets
|
||||
|
||||
# Local History for Visual Studio Code
|
||||
.history/
|
||||
|
||||
# Built Visual Studio Code Extensions
|
||||
*.vsix
|
||||
|
||||
### VisualStudioCode Patch ###
|
||||
# Ignore all local history of files
|
||||
.history
|
||||
.ionide
|
||||
|
||||
### Windows ###
|
||||
# Windows thumbnail cache files
|
||||
Thumbs.db
|
||||
Thumbs.db:encryptable
|
||||
ehthumbs.db
|
||||
ehthumbs_vista.db
|
||||
|
||||
# Dump file
|
||||
*.stackdump
|
||||
|
||||
# Folder config file
|
||||
[Dd]esktop.ini
|
||||
|
||||
# Recycle Bin used on file shares
|
||||
$RECYCLE.BIN/
|
||||
|
||||
# Windows Installer files
|
||||
*.cab
|
||||
*.msi
|
||||
*.msix
|
||||
*.msm
|
||||
*.msp
|
||||
|
||||
# Windows shortcuts
|
||||
*.lnk
|
||||
|
||||
# End of https://www.toptal.com/developers/gitignore/api/windows,visualstudiocode,macos,python,venv
|
||||
|
||||
# Custom rules (everything added below won't be overriden by 'Generate .gitignore File' if you use 'Update' option)
|
||||
ref/
|
||||
21
custom_nodes/sd-perturbed-attention/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2024 pamparamm
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
86
custom_nodes/sd-perturbed-attention/README.md
Normal file
@@ -0,0 +1,86 @@
|
||||
# Various Guidance implementations for ComfyUI / SD WebUI (reForge)
|
||||
|
||||
Implementation of
|
||||
|
||||
- Perturbed-Attention Guidance from [Self-Rectifying Diffusion Sampling with Perturbed-Attention Guidance (D. Ahn et al.)](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/)
|
||||
- [Smoothed Energy Guidance: Guiding Diffusion Models with Reduced Energy Curvature of Attention (Susung Hong)](https://arxiv.org/abs/2408.00760)
|
||||
- Sliding Window Guidance from [The Unreasonable Effectiveness of Guidance for Diffusion Models (Kaiser et al.)](https://arxiv.org/abs/2411.10257)
|
||||
- [PLADIS: Pushing the Limits of Attention in Diffusion Models at Inference Time by Leveraging Sparsity](https://cubeyoung.github.io/pladis-proejct/) (ComfyUI-only)
|
||||
- [Normalized Attention Guidance: Universal Negative Guidance for Diffusion Models](https://arxiv.org/abs/2505.21179) (ComfyUI-only, has a description inside ComfyUI)
|
||||
- [Token Perturbation Guidance for Diffusion Models](https://arxiv.org/abs/2506.10036) (ComfyUI-only)
|
||||
|
||||
as an extension for [ComfyUI](https://github.com/comfyanonymous/ComfyUI) and [SD WebUI (reForge)](https://github.com/Panchovix/stable-diffusion-webui-reForge).
|
||||
|
||||
Works with SD1.5 and SDXL.
|
||||
|
||||
## Installation
|
||||
|
||||
### ComfyUI
|
||||
|
||||
You can either:
|
||||
|
||||
- `git clone https://github.com/pamparamm/sd-perturbed-attention.git` into `ComfyUI/custom-nodes/` folder.
|
||||
|
||||
- Install it via [ComfyUI Manager](https://github.com/ltdrdata/ComfyUI-Manager) (search for custom node named "Perturbed-Attention Guidance").
|
||||
|
||||
- Install it via [comfy-cli](https://comfydocs.org/comfy-cli/getting-started) with `comfy node registry-install sd-perturbed-attention`
|
||||
|
||||
### SD WebUI (reForge)
|
||||
|
||||
`git clone https://github.com/pamparamm/sd-perturbed-attention.git` into `stable-diffusion-webui-forge/extensions/` folder.
|
||||
|
||||
### SD WebUI (Auto1111)
|
||||
|
||||
As an alternative for A1111 WebUI you can use PAG implementation from [sd-webui-incantations](https://github.com/v0xie/sd-webui-incantations) extension.
|
||||
|
||||
## Guidance Nodes/Scripts
|
||||
|
||||
### ComfyUI
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
### SD WebUI (reForge)
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
> [!NOTE]
|
||||
> You can override `CFG Scale` and `PAG Scale`/`SEG Scale` for Hires. fix by opening/enabling `Override for Hires. fix` tab.
|
||||
> To disable PAG during Hires. fix, you can set `PAG Scale` under Override to 0.
|
||||
|
||||
### Inputs
|
||||
|
||||
- `scale`: Guidance scale, higher values can both increase structural coherence of an image and oversaturate/fry it entirely.
|
||||
- `adaptive_scale` (PAG only): PAG dampening factor, it penalizes PAG during late denoising stages, resulting in overall speedup: 0.0 means no penalty and 1.0 completely removes PAG.
|
||||
- `blur_sigma` (SEG only): Normal deviation of Gaussian blur, higher values increase "clarity" of an image. Negative values set `blur_sigma` to infinity.
|
||||
- `unet_block`: Part of U-Net to which Guidance is applied, original paper suggests to use `middle`.
|
||||
- `unet_block_id`: Id of U-Net layer in a selected block to which Guidance is applied. Guidance can be applied only to layers containing Self-attention blocks.
|
||||
- `sigma_start` / `sigma_end`: Guidance will be active only between `sigma_start` and `sigma_end`. Set both values to negative to disable this feature.
|
||||
- `rescale`: Acts similar to RescaleCFG node - it prevents over-exposure on high `scale` values. Based on Algorithm 2 from [Common Diffusion Noise Schedules and Sample Steps are Flawed (Lin et al.)](https://arxiv.org/abs/2305.08891). Set to 0 to disable this feature.
|
||||
- `rescale_mode`:
|
||||
- `full` - takes into account both CFG and Guidance.
|
||||
- `partial` - depends only on Guidance.
|
||||
- `snf` - Saliency-adaptive Noise Fusion from [High-fidelity Person-centric Subject-to-Image Synthesis (Wang et al.)](https://arxiv.org/abs/2311.10329). Should increase image quality on high guidance scales. Ignores `rescale` value.
|
||||
- `unet_block_list`: Optional input, replaces both `unet_block` and `unet_block_id` and allows you to select multiple U-Net layers separated with commas. SDXL U-Net has multiple indices for layers, you can specify them by using dot symbol (if not specified, Guidance will be applied to the whole layer). Example value: `m0,u0.4` (it applies Guidance to middle block 0 and to output block 0 with index 4)
|
||||
- In terms of U-Net `d` means `input`, `m` means `middle` and `u` means `output`.
|
||||
- SD1.5 U-Net has layers `d0`-`d5`, `m0`, `u0`-`u8`.
|
||||
- SDXL U-Net has layers `d0`-`d3`, `m0`, `u0`-`u5`. In addition, each block except `d0` and `d1` has `0-9` index values (like `m0.7` or `u0.4`). `d0` and `d1` have `0-1` index values.
|
||||
- Supports block ranges (`d0-d3` corresponds to `d0,d1,d2,d3`) and index value ranges (`d2.2-9` corresponds to all index values of `d2` with the exclusion of `d2.0` and `d2.1`).
|
||||
|
||||
## ComfyUI TensorRT PAG (Experimental)
|
||||
|
||||
To use PAG together with [ComfyUI_TensorRT](https://github.com/comfyanonymous/ComfyUI_TensorRT), you'll need to:
|
||||
|
||||
0. Have 24GB of VRAM.
|
||||
1. Build static/dynamic TRT engine of a desired model.
|
||||
2. Build static/dynamic TRT engine of the same model with the same TRT parameters, but with fixed PAG injection in selected UNET blocks (`TensorRT Attach PAG` node).
|
||||
3. Use `TensorRT Perturbed-Attention Guidance` node with two model inputs: one for base engine and one for PAG engine.
|
||||
|
||||

|
||||
|
||||

|
||||
25
custom_nodes/sd-perturbed-attention/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from . import nag_nodes, tpg_nodes, pladis_nodes
|
||||
from .pag_nodes import PerturbedAttention, SlidingWindowGuidanceAdvanced, SmoothedEnergyGuidanceAdvanced
|
||||
from .pag_trt_nodes import TRTAttachPag, TRTPerturbedAttention
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"PerturbedAttention": PerturbedAttention,
|
||||
"SmoothedEnergyGuidanceAdvanced": SmoothedEnergyGuidanceAdvanced,
|
||||
"SlidingWindowGuidanceAdvanced": SlidingWindowGuidanceAdvanced,
|
||||
"TRTAttachPag": TRTAttachPag,
|
||||
"TRTPerturbedAttention": TRTPerturbedAttention,
|
||||
**nag_nodes.NODE_CLASS_MAPPINGS,
|
||||
**tpg_nodes.NODE_CLASS_MAPPINGS,
|
||||
**pladis_nodes.NODE_CLASS_MAPPINGS,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"PerturbedAttention": "Perturbed-Attention Guidance (Advanced)",
|
||||
"SmoothedEnergyGuidanceAdvanced": "Smoothed Energy Guidance (Advanced)",
|
||||
"SlidingWindowGuidanceAdvanced": "Sliding Window Guidance (Advanced)",
|
||||
"TRTAttachPag": "TensorRT Attach PAG",
|
||||
"TRTPerturbedAttention": "TensorRT Perturbed-Attention Guidance",
|
||||
**nag_nodes.NODE_DISPLAY_NAME_MAPPINGS,
|
||||
**tpg_nodes.NODE_DISPLAY_NAME_MAPPINGS,
|
||||
**pladis_nodes.NODE_DISPLAY_NAME_MAPPINGS,
|
||||
}
|
||||
|
After Width: | Height: | Size: 102 KiB |
@@ -0,0 +1,625 @@
|
||||
{
|
||||
"id": "319b510b-b5ec-46d6-8605-a6a5fd7d6c6c",
|
||||
"revision": 0,
|
||||
"last_node_id": 25,
|
||||
"last_link_id": 54,
|
||||
"nodes": [
|
||||
{
|
||||
"id": 3,
|
||||
"type": "KSampler",
|
||||
"pos": [
|
||||
1100,
|
||||
620
|
||||
],
|
||||
"size": [
|
||||
210,
|
||||
474
|
||||
],
|
||||
"flags": {},
|
||||
"order": 6,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"name": "model",
|
||||
"type": "MODEL",
|
||||
"link": 41
|
||||
},
|
||||
{
|
||||
"name": "positive",
|
||||
"type": "CONDITIONING",
|
||||
"link": 35
|
||||
},
|
||||
{
|
||||
"name": "negative",
|
||||
"type": "CONDITIONING",
|
||||
"link": 36
|
||||
},
|
||||
{
|
||||
"name": "latent_image",
|
||||
"type": "LATENT",
|
||||
"link": 2
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "LATENT",
|
||||
"type": "LATENT",
|
||||
"slot_index": 0,
|
||||
"links": [
|
||||
7
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "KSampler"
|
||||
},
|
||||
"widgets_values": [
|
||||
0,
|
||||
"fixed",
|
||||
25,
|
||||
7,
|
||||
"euler",
|
||||
"sgm_uniform",
|
||||
1
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 5,
|
||||
"type": "EmptyLatentImage",
|
||||
"pos": [
|
||||
580,
|
||||
810
|
||||
],
|
||||
"size": [
|
||||
210,
|
||||
106
|
||||
],
|
||||
"flags": {},
|
||||
"order": 0,
|
||||
"mode": 0,
|
||||
"inputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "LATENT",
|
||||
"type": "LATENT",
|
||||
"slot_index": 0,
|
||||
"links": [
|
||||
2,
|
||||
47
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "EmptyLatentImage"
|
||||
},
|
||||
"widgets_values": [
|
||||
1024,
|
||||
1024,
|
||||
1
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 8,
|
||||
"type": "VAEDecode",
|
||||
"pos": [
|
||||
1320,
|
||||
620
|
||||
],
|
||||
"size": [
|
||||
140,
|
||||
46
|
||||
],
|
||||
"flags": {
|
||||
"collapsed": false
|
||||
},
|
||||
"order": 8,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"name": "samples",
|
||||
"type": "LATENT",
|
||||
"link": 7
|
||||
},
|
||||
{
|
||||
"name": "vae",
|
||||
"type": "VAE",
|
||||
"link": 51
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "IMAGE",
|
||||
"type": "IMAGE",
|
||||
"slot_index": 0,
|
||||
"links": [
|
||||
12
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "VAEDecode"
|
||||
},
|
||||
"widgets_values": []
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"type": "PreviewImage",
|
||||
"pos": [
|
||||
1320,
|
||||
700
|
||||
],
|
||||
"size": [
|
||||
440,
|
||||
480
|
||||
],
|
||||
"flags": {},
|
||||
"order": 10,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"name": "images",
|
||||
"type": "IMAGE",
|
||||
"link": 12
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"properties": {
|
||||
"Node name for S&R": "PreviewImage"
|
||||
},
|
||||
"widgets_values": []
|
||||
},
|
||||
{
|
||||
"id": 18,
|
||||
"type": "NormalizedAttentionGuidance",
|
||||
"pos": [
|
||||
850,
|
||||
620
|
||||
],
|
||||
"size": [
|
||||
233.67147827148438,
|
||||
198
|
||||
],
|
||||
"flags": {},
|
||||
"order": 4,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"name": "model",
|
||||
"type": "MODEL",
|
||||
"link": 53
|
||||
},
|
||||
{
|
||||
"name": "negative",
|
||||
"type": "CONDITIONING",
|
||||
"link": 40
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "MODEL",
|
||||
"type": "MODEL",
|
||||
"links": [
|
||||
41
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "NormalizedAttentionGuidance"
|
||||
},
|
||||
"widgets_values": [
|
||||
4,
|
||||
0.5,
|
||||
1,
|
||||
-1,
|
||||
10.000000000000002,
|
||||
""
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 19,
|
||||
"type": "CLIPTextEncode",
|
||||
"pos": [
|
||||
400,
|
||||
530
|
||||
],
|
||||
"size": [
|
||||
390,
|
||||
100
|
||||
],
|
||||
"flags": {},
|
||||
"order": 2,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"name": "clip",
|
||||
"type": "CLIP",
|
||||
"link": 49
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "CONDITIONING",
|
||||
"type": "CONDITIONING",
|
||||
"links": [
|
||||
35,
|
||||
45
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "CLIPTextEncode"
|
||||
},
|
||||
"widgets_values": [
|
||||
"elsa \\(frozen\\), portrait,"
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 20,
|
||||
"type": "CLIPTextEncode",
|
||||
"pos": [
|
||||
400,
|
||||
670
|
||||
],
|
||||
"size": [
|
||||
390,
|
||||
100
|
||||
],
|
||||
"flags": {},
|
||||
"order": 3,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"name": "clip",
|
||||
"type": "CLIP",
|
||||
"link": 50
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "CONDITIONING",
|
||||
"type": "CONDITIONING",
|
||||
"links": [
|
||||
36,
|
||||
40,
|
||||
46
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "CLIPTextEncode"
|
||||
},
|
||||
"widgets_values": [
|
||||
"ugly, sketch, blurry, collage, blonde hair, blue eyes,"
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 22,
|
||||
"type": "KSampler",
|
||||
"pos": [
|
||||
1100,
|
||||
10
|
||||
],
|
||||
"size": [
|
||||
210,
|
||||
474
|
||||
],
|
||||
"flags": {},
|
||||
"order": 5,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"name": "model",
|
||||
"type": "MODEL",
|
||||
"link": 54
|
||||
},
|
||||
{
|
||||
"name": "positive",
|
||||
"type": "CONDITIONING",
|
||||
"link": 45
|
||||
},
|
||||
{
|
||||
"name": "negative",
|
||||
"type": "CONDITIONING",
|
||||
"link": 46
|
||||
},
|
||||
{
|
||||
"name": "latent_image",
|
||||
"type": "LATENT",
|
||||
"link": 47
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "LATENT",
|
||||
"type": "LATENT",
|
||||
"slot_index": 0,
|
||||
"links": [
|
||||
43
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "KSampler"
|
||||
},
|
||||
"widgets_values": [
|
||||
0,
|
||||
"fixed",
|
||||
25,
|
||||
7,
|
||||
"euler",
|
||||
"sgm_uniform",
|
||||
1
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 23,
|
||||
"type": "PreviewImage",
|
||||
"pos": [
|
||||
1320,
|
||||
90
|
||||
],
|
||||
"size": [
|
||||
440,
|
||||
480
|
||||
],
|
||||
"flags": {},
|
||||
"order": 9,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"name": "images",
|
||||
"type": "IMAGE",
|
||||
"link": 42
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"properties": {
|
||||
"Node name for S&R": "PreviewImage"
|
||||
},
|
||||
"widgets_values": []
|
||||
},
|
||||
{
|
||||
"id": 24,
|
||||
"type": "VAEDecode",
|
||||
"pos": [
|
||||
1320,
|
||||
10
|
||||
],
|
||||
"size": [
|
||||
140,
|
||||
46
|
||||
],
|
||||
"flags": {
|
||||
"collapsed": false
|
||||
},
|
||||
"order": 7,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"name": "samples",
|
||||
"type": "LATENT",
|
||||
"link": 43
|
||||
},
|
||||
{
|
||||
"name": "vae",
|
||||
"type": "VAE",
|
||||
"link": 52
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "IMAGE",
|
||||
"type": "IMAGE",
|
||||
"slot_index": 0,
|
||||
"links": [
|
||||
42
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "VAEDecode"
|
||||
},
|
||||
"widgets_values": []
|
||||
},
|
||||
{
|
||||
"id": 25,
|
||||
"type": "CheckpointLoaderSimple",
|
||||
"pos": [
|
||||
400,
|
||||
390
|
||||
],
|
||||
"size": [
|
||||
390,
|
||||
98
|
||||
],
|
||||
"flags": {},
|
||||
"order": 1,
|
||||
"mode": 0,
|
||||
"inputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "MODEL",
|
||||
"type": "MODEL",
|
||||
"links": [
|
||||
53,
|
||||
54
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "CLIP",
|
||||
"type": "CLIP",
|
||||
"links": [
|
||||
49,
|
||||
50
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "VAE",
|
||||
"type": "VAE",
|
||||
"links": [
|
||||
51,
|
||||
52
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "CheckpointLoaderSimple"
|
||||
},
|
||||
"widgets_values": [
|
||||
"sdxl\\base\\sd_xl_base_1.0.safetensors"
|
||||
]
|
||||
}
|
||||
],
|
||||
"links": [
|
||||
[
|
||||
2,
|
||||
5,
|
||||
0,
|
||||
3,
|
||||
3,
|
||||
"LATENT"
|
||||
],
|
||||
[
|
||||
7,
|
||||
3,
|
||||
0,
|
||||
8,
|
||||
0,
|
||||
"LATENT"
|
||||
],
|
||||
[
|
||||
12,
|
||||
8,
|
||||
0,
|
||||
13,
|
||||
0,
|
||||
"IMAGE"
|
||||
],
|
||||
[
|
||||
35,
|
||||
19,
|
||||
0,
|
||||
3,
|
||||
1,
|
||||
"CONDITIONING"
|
||||
],
|
||||
[
|
||||
36,
|
||||
20,
|
||||
0,
|
||||
3,
|
||||
2,
|
||||
"CONDITIONING"
|
||||
],
|
||||
[
|
||||
40,
|
||||
20,
|
||||
0,
|
||||
18,
|
||||
1,
|
||||
"CONDITIONING"
|
||||
],
|
||||
[
|
||||
41,
|
||||
18,
|
||||
0,
|
||||
3,
|
||||
0,
|
||||
"MODEL"
|
||||
],
|
||||
[
|
||||
42,
|
||||
24,
|
||||
0,
|
||||
23,
|
||||
0,
|
||||
"IMAGE"
|
||||
],
|
||||
[
|
||||
43,
|
||||
22,
|
||||
0,
|
||||
24,
|
||||
0,
|
||||
"LATENT"
|
||||
],
|
||||
[
|
||||
45,
|
||||
19,
|
||||
0,
|
||||
22,
|
||||
1,
|
||||
"CONDITIONING"
|
||||
],
|
||||
[
|
||||
46,
|
||||
20,
|
||||
0,
|
||||
22,
|
||||
2,
|
||||
"CONDITIONING"
|
||||
],
|
||||
[
|
||||
47,
|
||||
5,
|
||||
0,
|
||||
22,
|
||||
3,
|
||||
"LATENT"
|
||||
],
|
||||
[
|
||||
49,
|
||||
25,
|
||||
1,
|
||||
19,
|
||||
0,
|
||||
"CLIP"
|
||||
],
|
||||
[
|
||||
50,
|
||||
25,
|
||||
1,
|
||||
20,
|
||||
0,
|
||||
"CLIP"
|
||||
],
|
||||
[
|
||||
51,
|
||||
25,
|
||||
2,
|
||||
8,
|
||||
1,
|
||||
"VAE"
|
||||
],
|
||||
[
|
||||
52,
|
||||
25,
|
||||
2,
|
||||
24,
|
||||
1,
|
||||
"VAE"
|
||||
],
|
||||
[
|
||||
53,
|
||||
25,
|
||||
0,
|
||||
18,
|
||||
0,
|
||||
"MODEL"
|
||||
],
|
||||
[
|
||||
54,
|
||||
25,
|
||||
0,
|
||||
22,
|
||||
0,
|
||||
"MODEL"
|
||||
]
|
||||
],
|
||||
"groups": [],
|
||||
"config": {},
|
||||
"extra": {
|
||||
"frontendVersion": "1.23.0"
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
241
custom_nodes/sd-perturbed-attention/guidance_utils.py
Normal file
@@ -0,0 +1,241 @@
|
||||
import math
|
||||
from itertools import groupby
|
||||
from typing import Any, Callable, Literal
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def parse_unet_blocks(model, unet_block_list: str, attn: Literal["attn1", "attn2"] | None):
|
||||
output: list[tuple[str, int, int | None]] = []
|
||||
names: list[str] = []
|
||||
|
||||
# Get all Self-attention blocks
|
||||
input_blocks: list[tuple[int, str]] = []
|
||||
middle_blocks: list[tuple[int, str]] = []
|
||||
output_blocks: list[tuple[int, str]] = []
|
||||
for name, module in model.model.diffusion_model.named_modules():
|
||||
if module.__class__.__name__ == "BasicTransformerBlock" and (attn is None or hasattr(module, attn)):
|
||||
parts = name.split(".")
|
||||
unet_part = parts[0]
|
||||
block_id = int(parts[1])
|
||||
if unet_part.startswith("input"):
|
||||
input_blocks.append((block_id, name))
|
||||
elif unet_part.startswith("middle"):
|
||||
middle_blocks.append((block_id - 1, name))
|
||||
elif unet_part.startswith("output"):
|
||||
output_blocks.append((block_id, name))
|
||||
|
||||
def group_blocks(blocks: list[tuple[int, str]]):
|
||||
grouped_blocks = [(i, list(gr)) for i, gr in groupby(blocks, lambda b: b[0])]
|
||||
return [(i, len(gr), list(idx[1] for idx in gr)) for i, gr in grouped_blocks]
|
||||
|
||||
input_blocks_gr, middle_blocks_gr, output_blocks_gr = (
|
||||
group_blocks(input_blocks),
|
||||
group_blocks(middle_blocks),
|
||||
group_blocks(output_blocks),
|
||||
)
|
||||
|
||||
user_inputs = [b.strip() for b in unet_block_list.split(",")]
|
||||
for user_input in user_inputs:
|
||||
unet_part_s, indices = user_input[0], user_input[1:].split(".")
|
||||
match unet_part_s:
|
||||
case "d":
|
||||
unet_part, unet_group = "input", input_blocks_gr
|
||||
case "m":
|
||||
unet_part, unet_group = "middle", middle_blocks_gr
|
||||
case "u":
|
||||
unet_part, unet_group = "output", output_blocks_gr
|
||||
case _:
|
||||
raise ValueError(f"Block {user_input}: Unknown block prefix {unet_part_s}")
|
||||
|
||||
block_index_range = [int(b.strip()) for b in indices[0].split("-")]
|
||||
block_index_range_start = block_index_range[0]
|
||||
block_index_range_end = block_index_range[0] if len(block_index_range) != 2 else block_index_range[1]
|
||||
for block_index in range(block_index_range_start, block_index_range_end + 1):
|
||||
if block_index < 0 or block_index >= len(unet_group):
|
||||
raise ValueError(
|
||||
f"Block {user_input}: Block index in out of range 0 <= {block_index} < {len(unet_group)}"
|
||||
)
|
||||
|
||||
block_group = unet_group[block_index]
|
||||
block_index_real = block_group[0]
|
||||
|
||||
if len(indices) == 1:
|
||||
output.append((unet_part, block_index_real, None))
|
||||
names.extend(block_group[2])
|
||||
else:
|
||||
transformer_index_range = [int(b.strip()) for b in indices[1].split("-")]
|
||||
transformer_index_range_start = transformer_index_range[0]
|
||||
transformer_index_range_end = (
|
||||
transformer_index_range[0] if len(transformer_index_range) != 2 else transformer_index_range[1]
|
||||
)
|
||||
for transformer_index in range(transformer_index_range_start, transformer_index_range_end + 1):
|
||||
if transformer_index is not None and (transformer_index < 0 or transformer_index >= block_group[1]):
|
||||
raise ValueError(
|
||||
f"Block {user_input}: Transformer index in out of range 0 <= {transformer_index} < {block_group[1]}"
|
||||
)
|
||||
|
||||
output.append((unet_part, block_index_real, transformer_index))
|
||||
names.append(block_group[2][transformer_index])
|
||||
|
||||
return output, names
|
||||
|
||||
|
||||
# Copied from https://github.com/comfyanonymous/ComfyUI/blob/719fb2c81d716ce8edd7f1bdc7804ae160a71d3a/comfy/model_patcher.py#L21 for backward compatibility
|
||||
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
|
||||
to = model_options["transformer_options"].copy()
|
||||
|
||||
if "patches_replace" not in to:
|
||||
to["patches_replace"] = {}
|
||||
else:
|
||||
to["patches_replace"] = to["patches_replace"].copy()
|
||||
|
||||
if name not in to["patches_replace"]:
|
||||
to["patches_replace"][name] = {}
|
||||
else:
|
||||
to["patches_replace"][name] = to["patches_replace"][name].copy()
|
||||
|
||||
if transformer_index is not None:
|
||||
block = (block_name, number, transformer_index)
|
||||
else:
|
||||
block = (block_name, number)
|
||||
to["patches_replace"][name][block] = patch
|
||||
model_options["transformer_options"] = to
|
||||
return model_options
|
||||
|
||||
|
||||
def set_model_options_value(model_options, key: str, value: Any):
|
||||
to = model_options["transformer_options"].copy()
|
||||
to[key] = value
|
||||
model_options["transformer_options"] = to
|
||||
return model_options
|
||||
|
||||
|
||||
def perturbed_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, extra_options, mask=None):
|
||||
"""Perturbed self-attention"""
|
||||
return v
|
||||
|
||||
|
||||
# Modified 'Algorithm 2 Classifier-Free Guidance with Rescale' from Common Diffusion Noise Schedules and Sample Steps are Flawed (Lin et al.).
def rescale_guidance(
    guidance: torch.Tensor, cond_pred: torch.Tensor, cfg_result: torch.Tensor, rescale=0.0, rescale_mode="full"
):
    """Rescale a guidance term toward the conditional prediction's std.

    ``rescale`` blends linearly between no rescaling (0.0) and full std
    matching (1.0).  In "full" mode the std is measured on
    ``cfg_result + guidance``; otherwise on ``cond_pred + guidance``.
    """
    if rescale == 0.0:
        return guidance

    if rescale_mode == "full":
        guided = cfg_result + guidance
    else:
        guided = cond_pred + guidance

    # Per-sample std over channel/spatial dims (expects NCHW tensors).
    ref_std = torch.std(cond_pred, dim=(1, 2, 3), keepdim=True)
    guided_std = torch.std(guided, dim=(1, 2, 3), keepdim=True)

    blend = rescale * (ref_std / guided_std) + (1.0 - rescale)
    return guidance * blend
|
||||
|
||||
|
||||
# Gaussian blur
def gaussian_blur_2d(img, kernel_size, sigma):
    """Depthwise 2D Gaussian blur with reflect padding.

    Builds the 2D kernel as an outer product of a normalized 1D Gaussian and
    applies it per channel via a grouped convolution, so channels never mix.
    The kernel is clamped to an odd size that fits the last dimension.
    """
    last_dim = img.shape[-1]
    kernel_size = min(kernel_size, last_dim - (last_dim % 2 - 1))
    half_width = (kernel_size - 1) * 0.5

    coords = torch.linspace(-half_width, half_width, steps=kernel_size)
    gauss = torch.exp(-0.5 * (coords / sigma).pow(2))
    kernel_1d = (gauss / gauss.sum()).to(device=img.device, dtype=img.dtype)

    kernel_2d = torch.mm(kernel_1d[:, None], kernel_1d[None, :])
    kernel_2d = kernel_2d.expand(img.shape[-3], 1, kernel_2d.shape[0], kernel_2d.shape[1])

    pad = kernel_size // 2
    img = F.pad(img, [pad, pad, pad, pad], mode="reflect")
    return F.conv2d(img, kernel_2d, groups=img.shape[-3])
|
||||
|
||||
|
||||
def seg_attention_wrapper(attention, blur_sigma=1.0):
    """Wrap an attention function with Smoothed Energy Guidance (SEG).

    The flattened query tokens are reshaped back to a 2D feature map (using the
    latent aspect ratio) and Gaussian-blurred before attention is computed.
    ``blur_sigma < 0`` means an infinite blur: every query is replaced by the
    spatial mean query.
    """

    def seg_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, extra_options, mask=None):
        """Smoothed Energy Guidance self-attention"""
        heads = extra_options["n_heads"]
        bs, area, inner_dim = q.shape

        # Recover an approximate 2D layout of the flattened tokens from the
        # original latent aspect ratio (rounding assumes near-rectangular maps).
        height_orig, width_orig = extra_options["original_shape"][2:4]
        aspect_ratio = width_orig / height_orig

        if aspect_ratio >= 1.0:
            height = round((area / aspect_ratio) ** 0.5)
            q = q.permute(0, 2, 1).reshape(bs, inner_dim, height, -1)
        else:
            width = round((area * aspect_ratio) ** 0.5)
            q = q.permute(0, 2, 1).reshape(bs, inner_dim, -1, width)

        if blur_sigma >= 0:
            # Odd kernel size covering roughly +/- 3 sigma.
            kernel_size = math.ceil(6 * blur_sigma) + 1 - math.ceil(6 * blur_sigma) % 2
            q = gaussian_blur_2d(q, kernel_size, blur_sigma)
        else:
            # Infinite blur: collapse all queries to their spatial mean.
            q[:] = q.mean(dim=(-2, -1), keepdim=True)

        # Flatten back to (batch, tokens, dim) for the wrapped attention call.
        q = q.reshape(bs, inner_dim, -1).permute(0, 2, 1)

        return attention(q, k, v, heads=heads)

    return seg_attention
|
||||
|
||||
|
||||
# Modified algorithm from 2411.10257 'The Unreasonable Effectiveness of Guidance for Diffusion Models' (Figure 6.)
def swg_pred_calc(
    x: torch.Tensor, tile_width: int, tile_height: int, tile_overlap: int, calc_func: Callable[..., tuple[torch.Tensor]]
):
    """Run ``calc_func`` over overlapping tiles of ``x`` and average the results.

    Tiles are ``tile_width + tile_overlap`` wide (similarly tall) and stride by
    the nominal tile size; per-pixel sums are normalized by how many tiles
    covered each position.  ``calc_func`` must accept ``x_in=`` and return a
    tuple whose first element has the tile's shape.
    """
    _, _, height, width = x.shape
    accum = torch.zeros_like(x)
    coverage = torch.zeros_like(x)

    n_cols = math.ceil(width / (tile_width - tile_overlap))
    n_rows = math.ceil(height / (tile_height - tile_overlap))

    for col in range(n_cols):
        left = tile_width * col
        right = left + tile_width + tile_overlap
        for row in range(n_rows):
            top = tile_height * row
            bottom = top + tile_height + tile_overlap

            tile = x[:, :, top:bottom, left:right]
            # Tiles starting beyond the image come out empty; skip them.
            if tile.shape[-1] == 0 or tile.shape[-2] == 0:
                continue

            pred_tile = calc_func(x_in=tile)[0]
            accum[:, :, top:bottom, left:right] += pred_tile
            coverage[:, :, top:bottom, left:right] += torch.ones_like(pred_tile)

    return accum / coverage
|
||||
|
||||
|
||||
# Saliency-adaptive Noise Fusion based on High-fidelity Person-centric Subject-to-Image Synthesis (Wang et al.)
# https://github.com/CodeGoat24/Face-diffuser/blob/edff1a5178ac9984879d9f5e542c1d0f0059ca5f/facediffuser/pipeline.py#L535-L562
def snf_guidance(t_guidance: torch.Tensor, s_guidance: torch.Tensor):
    """Fuse two guidance terms by per-pixel saliency.

    Each guidance's magnitude is blurred and softmaxed over spatial positions;
    at every position the guidance with the larger saliency weight wins.
    """
    b, c, h, w = t_guidance.shape

    def saliency(g: torch.Tensor) -> torch.Tensor:
        # Blurred magnitude, softmaxed per channel over all spatial positions.
        omega = gaussian_blur_2d(torch.abs(g), 3, 1)
        return torch.softmax(omega.reshape(b * c, h * w), dim=1).reshape(b, c, h, w)

    stacked_guidance = torch.stack([t_guidance, s_guidance], dim=0)
    stacked_saliency = torch.stack([saliency(t_guidance), saliency(s_guidance)], dim=0)

    # Pick, per position, the guidance whose saliency dominates.
    winner = torch.argmax(stacked_saliency, dim=0, keepdim=True)
    return torch.gather(stacked_guidance, dim=0, index=winner).squeeze(0)
|
||||
235
custom_nodes/sd-perturbed-attention/nag_nodes.py
Normal file
@@ -0,0 +1,235 @@
|
||||
from contextlib import suppress
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
|
||||
import comfy.model_management
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
|
||||
from comfy.ldm.modules.attention import BasicTransformerBlock, CrossAttention, optimized_attention
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
|
||||
from .guidance_utils import parse_unet_blocks
|
||||
|
||||
# Indices used in ComfyUI's `cond_or_uncond` batch layout (see extra_options).
COND = 0
UNCOND = 1
|
||||
|
||||
|
||||
def nag_attn2_replace_wrapper(
    nag_scale: float,
    tau: float,
    alpha: float,
    sigma_start: float,
    sigma_end: float,
    k_neg: torch.Tensor,
    v_neg: torch.Tensor,
    prev_attn2_replace: Callable | None = None,
):
    """Build an attn2 replacement implementing Normalized Attention Guidance.

    ``k_neg``/``v_neg`` are the negative prompt's key/value projections for this
    attention layer.  ``prev_attn2_replace`` lets NAG chain after another attn2
    replacer (e.g. IPAdapter) that was installed before this node.
    """

    # Modified Algorithm 1 from 2505.21179 'Normalized Attention Guidance: Universal Negative Guidance for Diffusion Models'
    def nag_attn2_replace(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, extra_options: dict):
        heads = extra_options["n_heads"]
        attn_precision = extra_options.get("attn_precision")
        sigma = extra_options["sigmas"]
        cond_or_uncond: list[int] = extra_options.get("cond_or_uncond")  # type: ignore

        # Perform batched CA
        z = (
            optimized_attention(q, k, v, heads, attn_precision)
            if prev_attn2_replace is None
            else prev_attn2_replace(q, k, v, extra_options)
        )

        # Skip NAG when disabled, outside the sigma window, or when the batch
        # has no conditional part to guide.
        if nag_scale == 0 or not (sigma_end < sigma[0] <= sigma_start) or COND not in cond_or_uncond:
            return z

        # Number of conditional samples in this batch.
        bs = q.shape[0] // len(cond_or_uncond) * cond_or_uncond.count(COND)

        k_neg_, v_neg_ = k_neg.repeat_interleave(bs, dim=0), v_neg.repeat_interleave(bs, dim=0)

        # Get conditional queries for NAG
        # Assume that cond_or_uncond has a layout [1, 1..., 0, 0...]
        q_chunked = q.chunk(len(cond_or_uncond))
        q_pos = torch.cat(q_chunked[cond_or_uncond.index(COND) :])

        # Apply NAG only to conditional parts of batched CA
        z_chunked = z.chunk(len(cond_or_uncond))
        z_pos = torch.cat(z_chunked[cond_or_uncond.index(COND) :])
        z_neg = optimized_attention(q_pos, k_neg_, v_neg_, heads, attn_precision)

        # Extrapolate away from the negative-prompt attention result.
        z_tilde = z_pos + nag_scale * (z_pos - z_neg)

        # L1-normalize: cap how far z_tilde may drift from z_pos (threshold tau).
        norm_pos = torch.norm(z_pos, p=1, dim=-1, keepdim=True)
        norm_tilde = torch.norm(z_tilde, p=1, dim=-1, keepdim=True)
        ratio = norm_tilde / norm_pos

        z_hat = torch.where(ratio > tau, tau, ratio) / ratio * z_tilde

        # Blend guided and original results.
        z_nag = alpha * z_hat + (1 - alpha) * z_pos

        # Prepend unconditional CA result to NAG result
        if UNCOND in cond_or_uncond:
            z_nag = torch.cat(z_chunked[cond_or_uncond.index(UNCOND) : cond_or_uncond.index(COND)] + (z_nag,))

        return z_nag

    return nag_attn2_replace
|
||||
|
||||
|
||||
class NormalizedAttentionGuidance(ComfyNodeABC):
    """ComfyUI node installing NAG (arXiv 2505.21179) on a model's cross-attention."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypeDict:
        return {
            "required": {
                "model": (
                    IO.MODEL,
                    {
                        "tooltip": (
                            "The diffusion model.\n"
                            "If you are using any other attn2 replacer (such as `IPAdapter`), you should place this node after it."
                        )
                    },
                ),
                "negative": (
                    IO.CONDITIONING,
                    {"tooltip": "Negative conditioning: either the one you use for CFG or a completely different one."},
                ),
                "scale": (
                    IO.FLOAT,
                    {
                        "default": 2.0,
                        "min": 0.0,
                        "max": 100.0,
                        "step": 0.1,
                        "round": 0.01,
                        "tooltip": "Scale of NAG, does nothing when `tau=0`.",
                    },
                ),
                "tau": (
                    IO.FLOAT,
                    {
                        "default": 2.5,
                        "min": 0.0,
                        "max": 100.0,
                        "step": 0.1,
                        "round": 0.01,
                        "tooltip": "Normalization threshold, larger value should increase `scale` impact.",
                    },
                ),
                "alpha": (
                    IO.FLOAT,
                    {
                        "default": 0.5,
                        "min": 0.0,
                        "max": 1.0,
                        "step": 0.001,
                        "round": 0.001,
                        "tooltip": "Linear interpolation between original (at `alpha=0`) and NAG (at `alpha=1`) results.",
                    },
                ),
                "sigma_start": (IO.FLOAT, {"default": -1.0, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}),
                "sigma_end": (IO.FLOAT, {"default": -1.0, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}),
            },
            "optional": {
                "unet_block_list": (
                    IO.STRING,
                    {
                        "default": "",
                        "tooltip": (
                            "Comma-separated blocks to which NAG is being applied to. When the list is empty, NAG is being applied to all block.\n"
                            "Read README from sd-perturbed-attention for more details."
                        ),
                    },
                ),
            },
        }

    RETURN_TYPES = (IO.MODEL,)

    FUNCTION = "patch"
    DESCRIPTION = (
        "An additional way to apply negative prompts to the image.\n"
        "It's compatible with CFG, PAG, and other guidances, and can be used with guidance- and step-distilled models as well.\n"
        "It's also compatible with other attn2 replacers (such as `IPAdapter`) - but make sure to place NAG node **after** other model patches!"
    )

    CATEGORY = "model_patches/unet"

    def patch(
        self,
        model: ModelPatcher,
        negative,
        scale=2.0,
        tau=2.5,
        alpha=0.5,
        sigma_start: float = -1.0,
        sigma_end: float = -1.0,
        unet_block_list="",
    ):
        """Clone the model and install a NAG attn2 replacement on matching blocks.

        Returns a one-tuple with the patched ModelPatcher clone; the input model
        is left untouched.
        """
        m = model.clone()
        inner_model: BaseModel = m.model
        dtype = inner_model.get_dtype()
        if inner_model.manual_cast_dtype is not None:
            dtype = inner_model.manual_cast_dtype
        device_model = inner_model.device
        device_infer = comfy.model_management.get_torch_device()

        # sigma_start < 0 means "active from the very first sampling step".
        sigma_start = float("inf") if sigma_start < 0 else sigma_start

        negative_cond = negative[0][0].to(device_model, dtype=dtype)

        blocks, block_names = parse_unet_blocks(m, unet_block_list, "attn2") if unet_block_list else (None, None)

        # Apply NAG only to transformer blocks with cross-attention (attn2)
        for name, module in (
            (n, m)
            for n, m in inner_model.diffusion_model.named_modules()
            if isinstance(m, BasicTransformerBlock) and getattr(m, "attn2", None)
        ):
            attn2: CrossAttention = module.attn2  # type: ignore
            # Module path looks like "input_blocks.1...transformer_blocks.0...";
            # derive (block_name, block_id, transformer_index) from it.
            parts: list[str] = name.split(".")
            block_name: str = parts[0].split("_")[0]
            block_id = int(parts[1])
            if block_name == "middle":
                block_id = block_id - 1

            t_idx = None
            if "transformer_blocks" in parts:
                t_pos = parts.index("transformer_blocks") + 1
                t_idx = int(parts[t_pos])

            # Empty block list means "patch everything".
            if not blocks or (block_name, block_id, t_idx) in blocks or (block_name, block_id, None) in blocks:
                # Project the negative conditioning through this layer's K/V.
                k_neg, v_neg = attn2.to_k(negative_cond), attn2.to_v(negative_cond)

                # Compatibility with other attn2 replaces (such as IPAdapter)
                prev_attn2_replace = None
                with suppress(KeyError):
                    block = (block_name, block_id, t_idx)
                    block_full = (block_name, block_id)
                    attn2_patches = m.model_options["transformer_options"]["patches_replace"]["attn2"]
                    if block_full in attn2_patches:
                        prev_attn2_replace = attn2_patches[block_full]
                    elif block in attn2_patches:
                        prev_attn2_replace = attn2_patches[block]

                nag_attn2_replace = nag_attn2_replace_wrapper(
                    scale,
                    tau,
                    alpha,
                    sigma_start,
                    sigma_end,
                    k_neg.to(device_infer, dtype=dtype),
                    v_neg.to(device_infer, dtype=dtype),
                    prev_attn2_replace,
                )
                m.set_model_attn2_replace(nag_attn2_replace, block_name, block_id, t_idx)

        return (m,)
|
||||
|
||||
|
||||
# Registration tables read by ComfyUI when this module is imported.
NODE_CLASS_MAPPINGS = {
    "NormalizedAttentionGuidance": NormalizedAttentionGuidance,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "NormalizedAttentionGuidance": "Normalized Attention Guidance",
}
|
||||
315
custom_nodes/sd-perturbed-attention/pag_nodes.py
Normal file
@@ -0,0 +1,315 @@
|
||||
from functools import partial

# Which UI backend these nodes are running under; decides whether
# calc_cond_batch (ComfyUI) or calc_cond_uncond_batch (Forge/reForge) is used.
BACKEND = None

try:
    # ComfyUI: installed as a package-relative custom node.
    from comfy.ldm.modules.attention import optimized_attention
    from comfy.model_patcher import ModelPatcher
    from comfy.samplers import calc_cond_batch

    from .guidance_utils import (
        parse_unet_blocks,
        perturbed_attention,
        rescale_guidance,
        seg_attention_wrapper,
        snf_guidance,
        swg_pred_calc,
    )

    # Prefer ComfyUI's own helper; older versions fall back to the vendored copy.
    try:
        from comfy.model_patcher import set_model_options_patch_replace
    except ImportError:
        from .guidance_utils import set_model_options_patch_replace

    BACKEND = "ComfyUI"
except ImportError:
    # Forge/reForge load this file as a top-level script, so imports are absolute.
    from guidance_utils import (
        parse_unet_blocks,
        perturbed_attention,
        rescale_guidance,
        seg_attention_wrapper,
        set_model_options_patch_replace,
        snf_guidance,
        swg_pred_calc,
    )

    try:
        from ldm_patched.ldm.modules.attention import optimized_attention
        from ldm_patched.modules.model_patcher import ModelPatcher
        from ldm_patched.modules.samplers import calc_cond_uncond_batch

        BACKEND = "reForge"
    except ImportError:
        from backend.attention import attention_function as optimized_attention
        from backend.patcher.base import ModelPatcher
        from backend.sampling.sampling_function import calc_cond_uncond_batch

        BACKEND = "Forge"
|
||||
|
||||
|
||||
class PerturbedAttention:
    """Perturbed-Attention Guidance (PAG) node: contrasts the normal prediction
    with one whose chosen self-attention layers are replaced by identity."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}),
                "adaptive_scale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001, "round": 0.0001}),
                "unet_block": (["input", "middle", "output"], {"default": "middle"}),
                "unet_block_id": ("INT", {"default": 0}),
                "sigma_start": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}),
                "sigma_end": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}),
                "rescale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                "rescale_mode": (["full", "partial", "snf"], {"default": "full"}),
            },
            "optional": {
                "unet_block_list": ("STRING", {"default": ""}),
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "model_patches/unet"

    def patch(
        self,
        model: ModelPatcher,
        scale: float = 3.0,
        adaptive_scale: float = 0.0,
        unet_block: str = "middle",
        unet_block_id: int = 0,
        sigma_start: float = -1.0,
        sigma_end: float = -1.0,
        rescale: float = 0.0,
        rescale_mode: str = "full",
        unet_block_list: str = "",
    ):
        """Clone the model and register a post-CFG hook that adds PAG guidance."""
        m = model.clone()

        # sigma_start < 0 means "active from the very first sampling step".
        sigma_start = float("inf") if sigma_start < 0 else sigma_start
        single_block = (unet_block, unet_block_id, None)
        blocks, block_names = (
            parse_unet_blocks(model, unet_block_list, "attn1") if unet_block_list else ([single_block], None)
        )

        def post_cfg_function(args):
            """CFG+PAG"""
            model = args["model"]
            cond_pred = args["cond_denoised"]
            uncond_pred = args["uncond_denoised"]
            cond = args["cond"]
            cfg_result = args["denoised"]
            sigma = args["sigma"]
            model_options = args["model_options"].copy()
            x = args["input"]

            # Adaptive scaling fades PAG out toward the end of sampling
            # (low timesteps) the stronger adaptive_scale is.
            signal_scale = scale
            if adaptive_scale > 0:
                t = 0
                if hasattr(model, "model_sampling"):
                    t = model.model_sampling.timestep(sigma)[0].item()
                else:
                    # Forge backends expose the schedule as `predictor` instead.
                    ts = model.predictor.timestep(sigma)
                    t = ts[0].item()
                signal_scale -= scale * (adaptive_scale**4) * (1000 - t)
                if signal_scale < 0:
                    signal_scale = 0

            if signal_scale == 0 or not (sigma_end < sigma[0] <= sigma_start):
                return cfg_result

            # Replace Self-attention with PAG
            for block in blocks:
                layer, number, index = block
                model_options = set_model_options_patch_replace(
                    model_options, perturbed_attention, "attn1", layer, number, index
                )

            # Extra conditional pass with the perturbed attention installed.
            if BACKEND == "ComfyUI":
                (pag_cond_pred,) = calc_cond_batch(model, [cond], x, sigma, model_options)
            if BACKEND in {"Forge", "reForge"}:
                (pag_cond_pred, _) = calc_cond_uncond_batch(model, cond, None, x, sigma, model_options)

            pag = (cond_pred - pag_cond_pred) * signal_scale

            if rescale_mode == "snf":
                # Saliency fusion needs a real uncond prediction; otherwise plain add.
                if uncond_pred.any():
                    return uncond_pred + snf_guidance(cfg_result - uncond_pred, pag)
                return cfg_result + pag

            return cfg_result + rescale_guidance(pag, cond_pred, cfg_result, rescale, rescale_mode)

        m.set_model_sampler_post_cfg_function(post_cfg_function, rescale_mode == "snf")

        return (m,)
|
||||
|
||||
|
||||
class SmoothedEnergyGuidanceAdvanced:
    """Smoothed Energy Guidance (SEG) node: contrasts the normal prediction with
    one whose self-attention queries are Gaussian-blurred (blur_sigma < 0 means
    an infinite blur, i.e. queries collapsed to their mean)."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}),
                "blur_sigma": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 9999.0, "step": 0.01, "round": 0.001}),
                "unet_block": (["input", "middle", "output"], {"default": "middle"}),
                "unet_block_id": ("INT", {"default": 0}),
                "sigma_start": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}),
                "sigma_end": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}),
                "rescale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                "rescale_mode": (["full", "partial", "snf"], {"default": "full"}),
            },
            "optional": {
                "unet_block_list": ("STRING", {"default": ""}),
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "model_patches/unet"

    def patch(
        self,
        model: ModelPatcher,
        scale: float = 3.0,
        blur_sigma: float = -1.0,
        unet_block: str = "middle",
        unet_block_id: int = 0,
        sigma_start: float = -1.0,
        sigma_end: float = -1.0,
        rescale: float = 0.0,
        rescale_mode: str = "full",
        unet_block_list: str = "",
    ):
        """Clone the model and register a post-CFG hook that adds SEG guidance."""
        m = model.clone()

        # sigma_start < 0 means "active from the very first sampling step".
        sigma_start = float("inf") if sigma_start < 0 else sigma_start
        single_block = (unet_block, unet_block_id, None)
        blocks, block_names = (
            parse_unet_blocks(model, unet_block_list, "attn1") if unet_block_list else ([single_block], None)
        )

        def post_cfg_function(args):
            """CFG+SEG"""
            model = args["model"]
            cond_pred = args["cond_denoised"]
            uncond_pred = args["uncond_denoised"]
            cond = args["cond"]
            cfg_result = args["denoised"]
            sigma = args["sigma"]
            model_options = args["model_options"].copy()
            x = args["input"]

            signal_scale = scale

            if signal_scale == 0 or not (sigma_end < sigma[0] <= sigma_start):
                return cfg_result

            seg_attention = seg_attention_wrapper(optimized_attention, blur_sigma)

            # Replace Self-attention with SEG attention
            for block in blocks:
                layer, number, index = block
                model_options = set_model_options_patch_replace(
                    model_options, seg_attention, "attn1", layer, number, index
                )

            # Extra conditional pass with the blurred-query attention installed.
            if BACKEND == "ComfyUI":
                (seg_cond_pred,) = calc_cond_batch(model, [cond], x, sigma, model_options)
            if BACKEND in {"Forge", "reForge"}:
                (seg_cond_pred, _) = calc_cond_uncond_batch(model, cond, None, x, sigma, model_options)

            seg = (cond_pred - seg_cond_pred) * signal_scale

            if rescale_mode == "snf":
                # Saliency fusion needs a real uncond prediction; otherwise plain add.
                if uncond_pred.any():
                    return uncond_pred + snf_guidance(cfg_result - uncond_pred, seg)
                return cfg_result + seg

            return cfg_result + rescale_guidance(seg, cond_pred, cfg_result, rescale, rescale_mode)

        m.set_model_sampler_post_cfg_function(post_cfg_function, rescale_mode == "snf")

        return (m,)
|
||||
|
||||
|
||||
class SlidingWindowGuidanceAdvanced:
    """Sliding Window Guidance (SWG) node: contrasts the full-frame prediction
    with a tiled prediction assembled from overlapping windows."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "scale": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}),
                "tile_width": ("INT", {"default": 768, "min": 16, "max": 16384, "step": 8}),
                "tile_height": ("INT", {"default": 768, "min": 16, "max": 16384, "step": 8}),
                "tile_overlap": ("INT", {"default": 256, "min": 16, "max": 16384, "step": 8}),
                "sigma_start": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}),
                "sigma_end": ("FLOAT", {"default": 5.42, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}),
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "model_patches/unet"

    def patch(
        self,
        model: ModelPatcher,
        scale: float = 5.0,
        tile_width: int = 768,
        tile_height: int = 768,
        tile_overlap: int = 256,
        sigma_start: float = -1.0,
        sigma_end: float = 5.42,
    ):
        """Clone the model and register a post-CFG hook that adds SWG guidance."""
        m = model.clone()

        # sigma_start < 0 means "active from the very first sampling step".
        sigma_start = float("inf") if sigma_start < 0 else sigma_start
        # Pixel-space tile sizes -> latent-space (latents are 1/8 resolution).
        tile_width, tile_height, tile_overlap = tile_width // 8, tile_height // 8, tile_overlap // 8

        def post_cfg_function(args):
            """CFG+SWG"""
            model = args["model"]
            cond_pred = args["cond_denoised"]
            cond = args["cond"]
            cfg_result = args["denoised"]
            sigma = args["sigma"]
            model_options = args["model_options"].copy()
            x = args["input"]

            signal_scale = scale

            if signal_scale == 0 or not (sigma_end < sigma[0] <= sigma_start):
                return cfg_result

            # Bind everything but x_in so swg_pred_calc can call per tile.
            calc_func = None

            if BACKEND == "ComfyUI":
                calc_func = partial(
                    calc_cond_batch,
                    model=model,
                    conds=[cond],
                    timestep=sigma,
                    model_options=model_options,
                )
            if BACKEND in {"Forge", "reForge"}:
                calc_func = partial(
                    calc_cond_uncond_batch,
                    model=model,
                    cond=cond,
                    uncond=None,
                    timestep=sigma,
                    model_options=model_options,
                )

            swg_pred = swg_pred_calc(x, tile_width, tile_height, tile_overlap, calc_func)
            swg = (cond_pred - swg_pred) * signal_scale

            return cfg_result + swg

        m.set_model_sampler_post_cfg_function(post_cfg_function)

        return (m,)
|
||||
111
custom_nodes/sd-perturbed-attention/pag_trt_nodes.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.samplers import calc_cond_batch
|
||||
|
||||
from .guidance_utils import parse_unet_blocks, perturbed_attention, rescale_guidance
|
||||
|
||||
|
||||
class TRTAttachPag:
    """Bake PAG's perturbed self-attention directly into a model clone, so it
    can be exported as a dedicated TensorRT engine."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "unet_block": (["input", "middle", "output"], {"default": "middle"}),
                "unet_block_id": ("INT", {"default": 0}),
            },
            "optional": {
                "unet_block_list": ("STRING", {"default": ""}),
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "attach"

    CATEGORY = "TensorRT"

    def attach(
        self,
        model: ModelPatcher,
        unet_block: str = "middle",
        unet_block_id: int = 0,
        unet_block_list: str = "",
    ):
        """Return a clone with the identity attn1 replacement installed on the
        selected blocks (or the single unet_block/unet_block_id pair)."""
        m = model.clone()

        single_block = (unet_block, unet_block_id, None)
        blocks, block_names = (
            parse_unet_blocks(model, unet_block_list, "attn1") if unet_block_list else ([single_block], None)
        )

        # Replace Self-attention with PAG
        for block in blocks:
            layer, number, index = block
            m.set_model_attn1_replace(perturbed_attention, layer, number, index)

        return (m,)
|
||||
|
||||
|
||||
class TRTPerturbedAttention:
    """PAG for TensorRT: uses a second, pre-patched engine (``model_pag``) for
    the perturbed pass instead of patching attention at sampling time."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model_base": ("MODEL",),
                "model_pag": ("MODEL",),
                "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}),
                "adaptive_scale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001, "round": 0.0001}),
                "sigma_start": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}),
                "sigma_end": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}),
                "rescale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                "rescale_mode": (["full", "partial"], {"default": "full"}),
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "TensorRT"

    def patch(
        self,
        model_base: ModelPatcher,
        model_pag: ModelPatcher,
        scale: float = 3.0,
        adaptive_scale: float = 0.0,
        sigma_start: float = -1.0,
        sigma_end: float = -1.0,
        rescale: float = 0.0,
        rescale_mode: str = "full",
    ):
        """Clone ``model_base`` and register a post-CFG hook that evaluates the
        PAG pass on ``model_pag``."""
        m = model_base.clone()

        # sigma_start < 0 means "active from the very first sampling step".
        sigma_start = float("inf") if sigma_start < 0 else sigma_start

        def post_cfg_function(args):
            """CFG+PAG"""
            model = args["model"]
            cond_pred = args["cond_denoised"]
            cond = args["cond"]
            cfg_result = args["denoised"]
            sigma = args["sigma"]
            x = args["input"]

            # Adaptive scaling fades PAG out at low timesteps.
            signal_scale = scale
            if adaptive_scale > 0:
                t = model.model_sampling.timestep(sigma)[0].item()
                signal_scale -= scale * (adaptive_scale**4) * (1000 - t)
                if signal_scale < 0:
                    signal_scale = 0

            if signal_scale == 0 or not (sigma_end < sigma[0] <= sigma_start):
                return cfg_result

            # Perturbed pass runs on the separate pre-patched engine.
            (pag_cond_pred,) = calc_cond_batch(model_pag.model, [cond], x, sigma, model_pag.model_options)

            pag = (cond_pred - pag_cond_pred) * signal_scale

            return cfg_result + rescale_guidance(pag, cond_pred, cfg_result, rescale, rescale_mode)

        m.set_model_sampler_post_cfg_function(post_cfg_function)

        return (m,)
|
||||
83
custom_nodes/sd-perturbed-attention/pladis_nodes.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
|
||||
from comfy.ldm.modules.attention import BasicTransformerBlock
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
|
||||
from .guidance_utils import parse_unet_blocks
|
||||
from .pladis_utils import SPARSE_FUNCTIONS, pladis_attention_wrapper
|
||||
|
||||
|
||||
class Pladis(ComfyNodeABC):
    """PLADIS node: replaces cross-attention (attn2) with an extrapolation
    between dense softmax attention and sparse (entmax) attention."""

    @classmethod
    def INPUT_TYPES(cls) -> InputTypeDict:
        return {
            "required": {
                "model": (IO.MODEL, {}),
                "scale": (IO.FLOAT, {"default": 2.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}),
                "sparse_func": (IO.COMBO, {"default": SPARSE_FUNCTIONS[0], "options": SPARSE_FUNCTIONS}),
            },
            "optional": {
                "unet_block_list": (
                    IO.STRING,
                    {
                        "default": "",
                        "tooltip": (
                            "Comma-separated blocks to which Pladis is being applied to. When the list is empty, PLADIS is being applied to all `u` and `d` blocks.\n"
                            "Read README from sd-perturbed-attention for more details."
                        ),
                    },
                ),
            },
        }

    RETURN_TYPES = (IO.MODEL,)
    FUNCTION = "patch"

    CATEGORY = "model_patches/unet"
    EXPERIMENTAL = True

    def patch(
        self,
        model: ModelPatcher,
        scale=2.0,
        sparse_func=SPARSE_FUNCTIONS[0],
        unet_block_list="",
    ):
        """Clone the model and install the PLADIS attn2 replacement on matching
        blocks. An empty ``unet_block_list`` applies it to every cross-attention
        block, matching the tooltip and the NAG node's behavior."""
        m = model.clone()
        inner_model: BaseModel = m.model
        pladis_attention = pladis_attention_wrapper(scale, sparse_func)

        blocks, block_names = parse_unet_blocks(m, unet_block_list, "attn2") if unet_block_list else (None, None)

        # Apply PLADIS only to transformer blocks with cross-attention (attn2)
        for name, module in (
            (n, m)
            for n, m in inner_model.diffusion_model.named_modules()
            if isinstance(m, BasicTransformerBlock) and getattr(m, "attn2", None)
        ):
            parts = name.split(".")
            block_name: str = parts[0].split("_")[0]
            block_id = int(parts[1])
            if block_name == "middle":
                block_id = block_id - 1

            t_idx = None
            if "transformer_blocks" in parts:
                t_pos = parts.index("transformer_blocks") + 1
                t_idx = int(parts[t_pos])

            # BUGFIX: an earlier `if not blocks: continue` here skipped every
            # module whenever `unet_block_list` was empty, making the node a
            # no-op and leaving the `not blocks` fallback below unreachable.
            # Removed so the empty-list case patches all attn2 blocks, as the
            # tooltip documents.
            if not blocks or (block_name, block_id, t_idx) in blocks or (block_name, block_id, None) in blocks:
                m.set_model_attn2_replace(pladis_attention, block_name, block_id, t_idx)

        return (m,)
|
||||
|
||||
|
||||
# Registration tables read by ComfyUI when this module is imported.
NODE_CLASS_MAPPINGS = {
    "PLADIS": Pladis,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "PLADIS": "PLADIS",
}
|
||||
166
custom_nodes/sd-perturbed-attention/pladis_utils.py
Normal file
@@ -0,0 +1,166 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
ENTMAX15_FUNC = "entmax1.5"  # sparse attention with alpha=1.5
SPARSEMAX_FUNC = "sparsemax"  # sparse attention with alpha=2

# Choices offered in the node's combo box; the first entry is the default.
SPARSE_FUNCTIONS: list[str] = [ENTMAX15_FUNC, SPARSEMAX_FUNC]
|
||||
|
||||
|
||||
def pladis_attention_wrapper(pladis_scale=2.0, sparse_func=SPARSE_FUNCTIONS[0]):
    """Build an attn2 replacement implementing PLADIS-style sparse attention.

    The returned function extrapolates from dense (softmax) attention toward a
    sparse attention map (entmax1.5 or sparsemax) by ``pladis_scale``.
    """

    # Simplified attention_basic with sparse functions instead of a softmax
    def _pladis_sparse_attention(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        extra_options: dict,
    ):
        heads = extra_options["n_heads"]
        # NOTE(review): attn_precision is read but unused below — presumably kept
        # for interface parity with other attention replacers; confirm.
        attn_precision = extra_options.get("attn_precision")

        b, _, dim_head = q.shape
        dim_head //= heads

        # 1/sqrt(d) attention scaling (annotation corrected: this is a float).
        scale: float = dim_head**-0.5

        # (b, tokens, heads*dim_head) -> (b*heads, tokens, dim_head)
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(b, -1, heads, dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * heads, -1, dim_head)
            .contiguous(),
            (q, k, v),
        )

        sim = q @ k.transpose(-2, -1) * scale

        # Free the projections early; sim can be large.
        del q, k

        dense_sim = torch.softmax(sim, dim=-1)
        if sparse_func == ENTMAX15_FUNC:
            sparse_sim = Entmax.entmax15(sim, dim=-1)
        elif sparse_func == SPARSEMAX_FUNC:
            sparse_sim = Entmax.sparsemax(sim, dim=-1)
        else:  # fallback to the default from paper
            sparse_sim = Entmax.entmax15(sim, dim=-1)

        # Extrapolate past the sparse map: scale * sparse + (1 - scale) * dense.
        pladis_sim = pladis_scale * sparse_sim + (1 - pladis_scale) * dense_sim

        out = pladis_sim.to(v.dtype) @ v

        # (b*heads, tokens, dim_head) -> (b, tokens, heads*dim_head)
        out = out.unsqueeze(0).reshape(b, heads, -1, dim_head).permute(0, 2, 1, 3).reshape(b, -1, heads * dim_head)
        return out

    return _pladis_sparse_attention
|
||||
|
||||
|
||||
class Entmax:
|
||||
"""
|
||||
Activations from `entmax` module converted to a static class.
|
||||
|
||||
Both sparsemax and entmax15, and all their inner function implementations
|
||||
are taken from https://github.com/deep-spin/entmax/blob/c2bec6d5e7d649cba7766c2172d89123ec2a6d70/entmax/activations.py
|
||||
(as recommended by PLADIS paper).
|
||||
|
||||
Author: Ben Peters
|
||||
|
||||
Author: Vlad Niculae <vlad@vene.ro>
|
||||
|
||||
License: MIT
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def entmax15(X: torch.Tensor, dim=-1, k: Optional[int] = None):
|
||||
max_val, _ = X.max(dim=dim, keepdim=True)
|
||||
X = X - max_val # same numerical stability trick as for softmax
|
||||
X = X / 2 # divide by 2 to solve actual Entmax
|
||||
|
||||
tau_star, _ = Entmax._entmax_threshold_and_support(X, dim=dim, k=k)
|
||||
|
||||
Y = torch.clamp(X - tau_star, min=0) ** 2
|
||||
return Y
|
||||
|
||||
@staticmethod
|
||||
def sparsemax(X: torch.Tensor, dim=-1, k: Optional[int] = None):
|
||||
max_val, _ = X.max(dim=dim, keepdim=True)
|
||||
X = X - max_val # same numerical stability trick as softmax
|
||||
|
||||
tau, _ = Entmax._sparsemax_threshold_and_support(X, dim=dim, k=k)
|
||||
|
||||
output = torch.clamp(X - tau, min=0)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def _entmax_threshold_and_support(X, dim=-1, k=None):
|
||||
if k is None or k >= X.shape[dim]: # do full sort
|
||||
Xsrt, _ = torch.sort(X, dim=dim, descending=True)
|
||||
else:
|
||||
Xsrt, _ = torch.topk(X, k=k, dim=dim)
|
||||
|
||||
rho = Entmax._make_ix_like(Xsrt, dim)
|
||||
mean = Xsrt.cumsum(dim) / rho
|
||||
mean_sq = (Xsrt**2).cumsum(dim) / rho
|
||||
ss = rho * (mean_sq - mean**2)
|
||||
delta = (1 - ss) / rho
|
||||
|
||||
delta_nz = torch.clamp(delta, 0)
|
||||
tau = mean - torch.sqrt(delta_nz)
|
||||
|
||||
support_size = (tau <= Xsrt).sum(dim).unsqueeze(dim)
|
||||
tau_star = tau.gather(dim, support_size - 1)
|
||||
|
||||
if k is not None and k < X.shape[dim]:
|
||||
unsolved = (support_size == k).squeeze(dim)
|
||||
|
||||
if torch.any(unsolved):
|
||||
X_ = Entmax._roll_last(X, dim)[unsolved]
|
||||
tau_, ss_ = Entmax._entmax_threshold_and_support(X_, dim=-1, k=2 * k)
|
||||
Entmax._roll_last(tau_star, dim)[unsolved] = tau_
|
||||
Entmax._roll_last(support_size, dim)[unsolved] = ss_
|
||||
|
||||
return tau_star, support_size
|
||||
|
||||
@staticmethod
|
||||
def _sparsemax_threshold_and_support(X: torch.Tensor, dim=-1, k=None):
|
||||
if k is None or k >= X.shape[dim]: # do full sort
|
||||
topk, _ = torch.sort(X, dim=dim, descending=True)
|
||||
else:
|
||||
topk, _ = torch.topk(X, k=k, dim=dim)
|
||||
|
||||
topk_cumsum = topk.cumsum(dim) - 1
|
||||
rhos = Entmax._make_ix_like(topk, dim)
|
||||
support = rhos * topk > topk_cumsum
|
||||
|
||||
support_size = support.sum(dim=dim).unsqueeze(dim)
|
||||
tau = topk_cumsum.gather(dim, support_size - 1)
|
||||
tau /= support_size.to(X.dtype)
|
||||
|
||||
if k is not None and k < X.shape[dim]:
|
||||
unsolved = (support_size == k).squeeze(dim)
|
||||
|
||||
if torch.any(unsolved):
|
||||
in_ = Entmax._roll_last(X, dim)[unsolved]
|
||||
tau_, ss_ = Entmax._sparsemax_threshold_and_support(in_, dim=-1, k=2 * k)
|
||||
Entmax._roll_last(tau, dim)[unsolved] = tau_
|
||||
Entmax._roll_last(support_size, dim)[unsolved] = ss_
|
||||
|
||||
return tau, support_size
|
||||
|
||||
@staticmethod
|
||||
def _make_ix_like(X: torch.Tensor, dim=-1):
|
||||
d = X.size(dim)
|
||||
rho = torch.arange(1, d + 1, device=X.device, dtype=X.dtype)
|
||||
view = [1] * X.dim()
|
||||
view[0] = -1
|
||||
return rho.view(view).transpose(0, dim)
|
||||
|
||||
@staticmethod
|
||||
def _roll_last(X: torch.Tensor, dim=-1):
|
||||
if dim == -1:
|
||||
return X
|
||||
elif dim < 0:
|
||||
dim = X.dim() - dim
|
||||
|
||||
perm = [i for i in range(X.dim()) if i != dim] + [dim]
|
||||
return X.permute(perm)
|
||||
14
custom_nodes/sd-perturbed-attention/pyproject.toml
Normal file
@@ -0,0 +1,14 @@
|
||||
[project]
|
||||
name = "sd-perturbed-attention"
|
||||
description = "Perturbed-Attention Guidance (PAG), Smoothed Energy Guidance (SEG), Sliding Window Guidance (SWG), PLADIS, Normalized Attention Guidance (NAG), Token Perturbation Guidance (TPG) for ComfyUI and SD reForge."
|
||||
version = "1.2.15"
|
||||
license = { text = "MIT License" }
|
||||
|
||||
[project.urls]
|
||||
Repository = "https://github.com/pamparamm/sd-perturbed-attention"
|
||||
# Used by Comfy Registry https://comfyregistry.org
|
||||
|
||||
[tool.comfy]
|
||||
PublisherId = "pamparamm"
|
||||
DisplayName = "sd-perturbed-attention"
|
||||
Icon = ""
|
||||
|
After Width: | Height: | Size: 32 KiB |
|
After Width: | Height: | Size: 12 KiB |
BIN
custom_nodes/sd-perturbed-attention/res/comfyui-node-seg.png
Normal file
|
After Width: | Height: | Size: 32 KiB |
BIN
custom_nodes/sd-perturbed-attention/res/forge-pag.png
Normal file
|
After Width: | Height: | Size: 42 KiB |
BIN
custom_nodes/sd-perturbed-attention/res/forge-seg.png
Normal file
|
After Width: | Height: | Size: 27 KiB |
BIN
custom_nodes/sd-perturbed-attention/res/trt-engines.png
Normal file
|
After Width: | Height: | Size: 162 KiB |
BIN
custom_nodes/sd-perturbed-attention/res/trt-inference.png
Normal file
|
After Width: | Height: | Size: 104 KiB |
127
custom_nodes/sd-perturbed-attention/tpg_nodes.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
|
||||
from comfy.ldm.modules.attention import BasicTransformerBlock
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.samplers import calc_cond_batch
|
||||
|
||||
from .guidance_utils import parse_unet_blocks, rescale_guidance, set_model_options_value, snf_guidance
|
||||
|
||||
TPG_OPTION = "tpg"
|
||||
|
||||
|
||||
# Implementation of 2506.10036 'Token Perturbation Guidance for Diffusion Models'
|
||||
class TPGTransformerWrapper(nn.Module):
|
||||
def __init__(self, transformer_block: BasicTransformerBlock) -> None:
|
||||
super().__init__()
|
||||
self.wrapped_block = transformer_block
|
||||
|
||||
def shuffle_tokens(self, x: torch.Tensor):
|
||||
# ComfyUI's torch.manual_seed generator should produce the same results here.
|
||||
permutation = torch.randperm(x.shape[1], device=x.device)
|
||||
return x[:, permutation]
|
||||
|
||||
def forward(self, x: torch.Tensor, context: torch.Tensor | None = None, transformer_options: dict[str, Any] = {}):
|
||||
is_tpg = transformer_options.get(TPG_OPTION, False)
|
||||
x_ = self.shuffle_tokens(x) if is_tpg else x
|
||||
return self.wrapped_block(x_, context=context, transformer_options=transformer_options)
|
||||
|
||||
|
||||
class TokenPerturbationGuidance(ComfyNodeABC):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"model": (IO.MODEL, {}),
|
||||
"scale": (IO.FLOAT, {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}),
|
||||
"sigma_start": (IO.FLOAT, {"default": -1.0, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}),
|
||||
"sigma_end": (IO.FLOAT, {"default": -1.0, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}),
|
||||
"rescale": (IO.FLOAT, {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
"rescale_mode": (IO.COMBO, {"options": ["full", "partial", "snf"], "default": "full"}),
|
||||
},
|
||||
"optional": {
|
||||
"unet_block_list": (IO.STRING, {"default": "d2.2-9,d3", "tooltip": "Blocks to which TPG is applied. "}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.MODEL,)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "model_patches/unet"
|
||||
|
||||
def patch(
|
||||
self,
|
||||
model: ModelPatcher,
|
||||
scale: float = 3.0,
|
||||
sigma_start: float = -1.0,
|
||||
sigma_end: float = -1.0,
|
||||
rescale: float = 0.0,
|
||||
rescale_mode: str = "full",
|
||||
unet_block_list: str = "",
|
||||
):
|
||||
m = model.clone()
|
||||
inner_model: BaseModel = m.model
|
||||
|
||||
sigma_start = float("inf") if sigma_start < 0 else sigma_start
|
||||
|
||||
blocks, block_names = parse_unet_blocks(model, unet_block_list, None) if unet_block_list else (None, None)
|
||||
|
||||
# Patch transformer blocks with TPG wrapper
|
||||
for name, module in inner_model.diffusion_model.named_modules():
|
||||
if (
|
||||
isinstance(module, BasicTransformerBlock)
|
||||
and not "wrapped_block" in name
|
||||
and (block_names is None or name in block_names)
|
||||
):
|
||||
# Potential memory leak?
|
||||
wrapper = TPGTransformerWrapper(module)
|
||||
m.add_object_patch(f"diffusion_model.{name}", wrapper)
|
||||
|
||||
def post_cfg_function(args):
|
||||
"""CFG+TPG"""
|
||||
model: BaseModel = args["model"]
|
||||
cond_pred = args["cond_denoised"]
|
||||
uncond_pred = args["uncond_denoised"]
|
||||
cond = args["cond"]
|
||||
cfg_result = args["denoised"]
|
||||
sigma = args["sigma"]
|
||||
model_options = args["model_options"].copy()
|
||||
x = args["input"]
|
||||
|
||||
signal_scale = scale
|
||||
|
||||
if signal_scale == 0 or not (sigma_end < sigma[0] <= sigma_start):
|
||||
return cfg_result
|
||||
|
||||
# Enable TPG in patched transformer blocks
|
||||
for name, module in model.diffusion_model.named_modules():
|
||||
if isinstance(module, TPGTransformerWrapper):
|
||||
set_model_options_value(model_options, TPG_OPTION, True)
|
||||
|
||||
(tpg_cond_pred,) = calc_cond_batch(model, [cond], x, sigma, model_options)
|
||||
|
||||
tpg = (cond_pred - tpg_cond_pred) * signal_scale
|
||||
|
||||
if rescale_mode == "snf":
|
||||
if uncond_pred.any():
|
||||
return uncond_pred + snf_guidance(cfg_result - uncond_pred, tpg)
|
||||
return cfg_result + tpg
|
||||
|
||||
return cfg_result + rescale_guidance(tpg, cond_pred, cfg_result, rescale, rescale_mode)
|
||||
|
||||
m.set_model_sampler_post_cfg_function(post_cfg_function, rescale_mode == "snf")
|
||||
|
||||
return (m,)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"TokenPerturbationGuidance": TokenPerturbationGuidance,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"TokenPerturbationGuidance": "Token Perturbation Guidance",
|
||||
}
|
||||