Add custom nodes, Civitai loras (LFS), and vast.ai setup script
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Includes 30 custom nodes committed directly, 7 Civitai-exclusive LoRAs stored via Git LFS, and a setup script that installs all dependencies and downloads HuggingFace-hosted models on vast.ai.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
167
custom_nodes/x-flux-comfyui/xflux/.gitignore
vendored
Normal file
@@ -0,0 +1,167 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
Makefile
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
weights/

share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache/
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

.DS_Store
201
custom_nodes/x-flux-comfyui/xflux/LICENSE
Normal file
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
234
custom_nodes/x-flux-comfyui/xflux/README.md
Normal file
@@ -0,0 +1,234 @@


This repository provides training scripts for [Flux model](https://github.com/black-forest-labs/flux) by Black Forest Labs. <br/>
[XLabs AI](https://github.com/XLabs-AI) team is happy to publish fine-tuning Flux scripts, including:

- **LoRA** 🔥
- **ControlNet** 🔥

# Training

We trained LoRA and ControlNet models using [DeepSpeed](https://github.com/microsoft/DeepSpeed)! <br/>
It's available for 1024x1024 resolution!

## Models

We trained **Canny ControlNet**, **Depth ControlNet**, **HED ControlNet** and **LoRA** checkpoints for [`FLUX.1 [dev]`](https://github.com/black-forest-labs/flux). <br/>
You can download them on HuggingFace:

- [flux-controlnet-collections](https://huggingface.co/XLabs-AI/flux-controlnet-collections)
- [flux-controlnet-canny](https://huggingface.co/XLabs-AI/flux-controlnet-canny)
- [flux-RealismLora](https://huggingface.co/XLabs-AI/flux-RealismLora)
- [flux-lora-collections](https://huggingface.co/XLabs-AI/flux-lora-collection)
- [flux-furry-lora](https://huggingface.co/XLabs-AI/flux-furry-lora)

### LoRA

```bash
accelerate launch train_flux_lora_deepspeed.py --config "train_configs/test_lora.yaml"
```

### ControlNet

```bash
accelerate launch train_flux_deepspeed_controlnet.py --config "train_configs/test_canny_controlnet.yaml"
```

## Training Dataset

The dataset must use the following layout for training:

```text
├── images/
│    ├── 1.png
│    ├── 1.json
│    ├── 2.png
│    ├── 2.json
│    ├── ...
```

### Example `images/*.json` file

Each `.json` file contains a "caption" field with a text prompt.

```json
{
    "caption": "A figure stands in a misty landscape, wearing a mask with antlers and dark, embellished attire, exuding mystery and otherworldliness"
}
```
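For reference, a minimal PyTorch-style loader for this layout could look like the sketch below; the class name and the absence of image transforms are illustrative, and the repository's actual dataset code may differ.

```python
import json
from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset


class ImageCaptionFolder(Dataset):
    """Pairs images/N.png with the caption stored in images/N.json."""

    def __init__(self, root: str):
        # one sample per PNG; the matching JSON sits next to it
        self.images = sorted(Path(root).glob("*.png"))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        path = self.images[idx]
        caption = json.loads(path.with_suffix(".json").read_text())["caption"]
        return Image.open(path).convert("RGB"), caption
```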
## Inference

To test our checkpoints, use the commands presented below.

### LoRA

prompt: "A girl in a suit covered with bold tattoos and holding a vest pistol, beautiful woman, 25 years old, cool, future fantasy, turquoise & light orange ping curl hair"

prompt: "A handsome man in a suit, 25 years old, cool, futuristic"

```bash
python3 main.py \
--prompt "Female furry Pixie with text 'hello world'" \
--lora_repo_id XLabs-AI/flux-furry-lora --lora_name furry_lora.safetensors --device cuda --offload --use_lora \
--model_type flux-dev-fp8 --width 1024 --height 1024 \
--timestep_to_start_cfg 1 --num_steps 25 --true_gs 3.5 --guidance 4
```



```bash
python3 main.py \
--prompt "A cute corgi lives in a house made out of sushi, anime" \
--lora_repo_id XLabs-AI/flux-lora-collection --lora_name anime_lora.safetensors \
--device cuda --offload --use_lora --model_type flux-dev-fp8 --width 1024 --height 1024
```



```bash
python3 main.py \
--use_lora --lora_weight 0.7 \
--width 1024 --height 768 \
--lora_repo_id XLabs-AI/flux-lora-collection --lora_name realism_lora.safetensors \
--guidance 4 \
--prompt "contrast play photography of a black female wearing white suit and albino asian geisha female wearing black suit, solid background, avant garde, high fashion"
```



## Canny ControlNet

```bash
python3 main.py \
--prompt "a viking man with white hair looking, cinematic, MM full HD" \
--image input_image_canny.jpg \
--control_type canny \
--repo_id XLabs-AI/flux-controlnet-collections --name flux-canny-controlnet.safetensors --device cuda --use_controlnet \
--model_type flux-dev --width 768 --height 768 \
--timestep_to_start_cfg 1 --num_steps 25 --true_gs 3.5 --guidance 4
```



## Depth ControlNet

```bash
python3 main.py \
--prompt "Photo of the bold man with beard and laptop, full hd, cinematic photo" \
--image input_image_depth1.jpg \
--control_type depth \
--repo_id XLabs-AI/flux-controlnet-collections --name flux-depth-controlnet.safetensors --device cuda --use_controlnet \
--model_type flux-dev --width 1024 --height 1024 \
--timestep_to_start_cfg 1 --num_steps 25 --true_gs 3.5 --guidance 4
```



```bash
python3 main.py \
--prompt "photo of handsome fluffy black dog standing on a forest path, full hd, cinematic photo" \
--image input_image_depth2.jpg \
--control_type depth \
--repo_id XLabs-AI/flux-controlnet-collections --name flux-depth-controlnet.safetensors --device cuda --use_controlnet \
--model_type flux-dev --width 1024 --height 1024 \
--timestep_to_start_cfg 1 --num_steps 25 --true_gs 3.5 --guidance 4
```



```bash
python3 main.py \
--prompt "Photo of japanese village with houses and sakura, full hd, cinematic photo" \
--image input_image_depth3.webp \
--control_type depth \
--repo_id XLabs-AI/flux-controlnet-collections --name flux-depth-controlnet.safetensors --device cuda --use_controlnet \
--model_type flux-dev --width 1024 --height 1024 \
--timestep_to_start_cfg 1 --num_steps 25 --true_gs 3.5 --guidance 4
```



## HED ControlNet

```bash
python3 main.py \
--prompt "2d art of a sitting african rich woman, full hd, cinematic photo" \
--image input_image_hed1.jpg \
--control_type hed \
--repo_id XLabs-AI/flux-controlnet-collections --name flux-hed-controlnet.safetensors --device cuda --use_controlnet \
--model_type flux-dev --width 768 --height 768 \
--timestep_to_start_cfg 1 --num_steps 25 --true_gs 3.5 --guidance 4
```



```bash
python3 main.py \
--prompt "anime ghibli style art of a running happy white dog, full hd" \
--image input_image_hed2.jpg \
--control_type hed \
--repo_id XLabs-AI/flux-controlnet-collections --name flux-hed-controlnet.safetensors --device cuda --use_controlnet \
--model_type flux-dev --width 768 --height 768 \
--timestep_to_start_cfg 1 --num_steps 25 --true_gs 3.5 --guidance 4
```



## Low memory mode

Use the FP8 versions of LoRA and ControlNet, based on [Flux-dev-F8](https://huggingface.co/XLabs-AI/flux-dev-fp8), together with the `--offload` flag and `--name flux-dev-fp8` to lower VRAM usage to 22 GB:

```bash
python3 main.py \
--offload --name flux-dev-fp8 \
--lora_repo_id XLabs-AI/flux-lora-collection --lora_name realism_lora.safetensors \
--guidance 4 \
--prompt "A handsome girl in a suit covered with bold tattoos and holding a pistol. Animatrix illustration style, fantasy style, natural photo cinematic"
```



## Requirements

Install our dependencies by running the following command:

```bash
pip3 install -r requirements.txt
```

## Accelerate Configuration Example

```yaml
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 2
  gradient_clipping: 1.0
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```

## Models Licence

Our models fall under the [FLUX.1 [dev] Non-Commercial License](https://github.com/black-forest-labs/flux/blob/main/model_licenses/LICENSE-FLUX1-dev); <br/> our training and inference scripts are released under the Apache 2.0 License.

## Upcoming Updates

We are working on releasing new ControlNet weight models for Flux: **OpenPose**, **Depth** and more! <br/>
Stay tuned to [XLabs AI](https://github.com/XLabs-AI) to see **IP-Adapters** for Flux.


11
custom_nodes/x-flux-comfyui/xflux/src/flux/__init__.py
Normal file
@@ -0,0 +1,11 @@
try:
    from ._version import version as __version__  # type: ignore
    from ._version import version_tuple
except ImportError:
    __version__ = "unknown (no version information available)"
    version_tuple = (0, 0, "unknown", "noinfo")

from pathlib import Path

PACKAGE = __package__.replace("_", "-")
PACKAGE_ROOT = Path(__file__).parent
38
custom_nodes/x-flux-comfyui/xflux/src/flux/annotator/util.py
Normal file
@@ -0,0 +1,38 @@
import numpy as np
import cv2
import os


annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')


def HWC3(x):
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    H, W, C = x.shape
    assert C == 1 or C == 3 or C == 4
    if C == 3:
        return x
    if C == 1:
        return np.concatenate([x, x, x], axis=2)
    if C == 4:
        color = x[:, :, 0:3].astype(np.float32)
        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
        y = color * alpha + 255.0 * (1.0 - alpha)
        y = y.clip(0, 255).astype(np.uint8)
        return y


def resize_image(input_image, resolution):
    H, W, C = input_image.shape
    H = float(H)
    W = float(W)
    k = float(resolution) / min(H, W)
    H *= k
    W *= k
    H = int(np.round(H / 64.0)) * 64
    W = int(np.round(W / 64.0)) * 64
    img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
    return img
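Note: `HWC3` normalizes grayscale/RGBA inputs to 3-channel uint8 (compositing RGBA over white), and `resize_image` scales the short side to `resolution` and snaps both sides to multiples of 64, as the latent models downstream expect. A minimal usage sketch, assuming a hypothetical `example.jpg` on disk and that the module is importable under the path shown:

```python
import cv2

from xflux.src.flux.annotator.util import HWC3, resize_image  # import path assumed

img = cv2.imread("example.jpg")           # hypothetical input file
img = HWC3(img)                           # guarantee 3-channel uint8
img = resize_image(img, resolution=768)   # short side ~768, both sides multiples of 64
print(img.shape)                          # e.g. (768, 1152, 3) for a 2:3 input
```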
222
custom_nodes/x-flux-comfyui/xflux/src/flux/controlnet.py
Normal file
@@ -0,0 +1,222 @@
from dataclasses import dataclass
from typing import Any, Dict

import torch
from torch import Tensor, nn
from einops import rearrange
# is_torch_version is required by the gradient-checkpointing path below;
# the original file called it without importing it
from diffusers.utils import is_torch_version

from .modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
                             MLPEmbedder, SingleStreamBlock,
                             timestep_embedding)


@dataclass
class FluxParams:
    in_channels: int
    vec_in_dim: int
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int
    depth_single_blocks: int
    axes_dim: list[int]
    theta: int
    qkv_bias: bool
    guidance_embed: bool


def zero_module(module):
    for p in module.parameters():
        nn.init.zeros_(p)
    return module


class ControlNetFlux(nn.Module):
    """
    Transformer model for flow matching on sequences.
    """
    _supports_gradient_checkpointing = True

    def __init__(self, params: FluxParams, controlnet_depth=2):
        super().__init__()

        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
        )
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                )
                for _ in range(controlnet_depth)
            ]
        )

        # add ControlNet blocks
        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(controlnet_depth):
            controlnet_block = nn.Linear(self.hidden_size, self.hidden_size)
            controlnet_block = zero_module(controlnet_block)
            self.controlnet_blocks.append(controlnet_block)
        self.pos_embed_input = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.gradient_checkpointing = False
        self.input_hint_block = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(16, 16, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(16, 16, 3, padding=1, stride=2),
            nn.SiLU(),
            nn.Conv2d(16, 16, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(16, 16, 3, padding=1, stride=2),
            nn.SiLU(),
            nn.Conv2d(16, 16, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(16, 16, 3, padding=1, stride=2),
            nn.SiLU(),
            zero_module(nn.Conv2d(16, 16, 3, padding=1))
        )

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    @property
    def attn_processors(self):
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
            if hasattr(module, "set_processor"):
                processors[f"{name}.processor"] = module.processor

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    def set_attn_processor(self, processor):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        controlnet_cond: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
    ) -> Tensor:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        controlnet_cond = self.input_hint_block(controlnet_cond)
        controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
        controlnet_cond = self.pos_embed_input(controlnet_cond)
        img = img + controlnet_cond
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        block_res_samples = ()

        for block in self.double_blocks:
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                # assign the checkpointed outputs back to (img, txt); the original
                # code stored them in unused variables, which silently discarded
                # the block's result under gradient checkpointing
                img, txt = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    img,
                    txt,
                    vec,
                    pe,
                    **ckpt_kwargs,
                )
            else:
                img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

            block_res_samples = block_res_samples + (img,)

        controlnet_block_res_samples = ()
        for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
            block_res_sample = controlnet_block(block_res_sample)
            controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)

        return controlnet_block_res_samples
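Reviewer note: the zero-initialized `controlnet_blocks` and final hint conv make the ControlNet branch a no-op at initialization, so training starts from the unmodified base model. The tuple of per-block residuals it returns is consumed by `Flux.forward` (model.py, further down) through `block_controlnet_hidden_states`; a hedged sketch of the wiring, where `model`, `controlnet`, the input tensors, and the `controlnet_gs` strength scalar are all assumed to be constructed elsewhere:

```python
# Sketch only: every name below (model, controlnet, img, hint, controlnet_gs, ...)
# is assumed to be set up elsewhere; this just shows the data flow.
residuals = controlnet(
    img, img_ids, hint, txt, txt_ids, timesteps, y, guidance=guidance
)
pred = model(
    img, img_ids, txt, txt_ids, timesteps, y,
    block_controlnet_hidden_states=[r * controlnet_gs for r in residuals],
    guidance=guidance,
)
```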
30
custom_nodes/x-flux-comfyui/xflux/src/flux/math.py
Normal file
@@ -0,0 +1,30 @@
import torch
from einops import rearrange
from torch import Tensor


def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    q, k = apply_rope(q, k, pe)

    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    x = rearrange(x, "B H L D -> B L (H D)")

    return x


def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    assert dim % 2 == 0
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos, omega)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.float()


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
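Note: `rope` builds, for every position and channel pair, a 2x2 rotation matrix, and `apply_rope` applies those rotations to the query/key projections. A small shape-check sketch (dimensions arbitrary, import path assumed from this package layout):

```python
import torch

from xflux.src.flux.math import apply_rope, rope  # import path assumed

pos = torch.arange(8, dtype=torch.float64)[None]  # (batch=1, seq=8) token positions
pe = rope(pos, dim=64, theta=10_000)              # (1, 8, 32, 2, 2): one 2x2 rotation per channel pair
pe = pe.unsqueeze(1)                              # add a heads axis, as EmbedND.forward does
q = k = torch.randn(1, 4, 8, 64)                  # (batch, heads, seq, head_dim)
xq, xk = apply_rope(q, k, pe)
assert xq.shape == q.shape and xk.shape == k.shape
```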
217
custom_nodes/x-flux-comfyui/xflux/src/flux/model.py
Normal file
@@ -0,0 +1,217 @@
from dataclasses import dataclass
from typing import Any, Dict, List

import torch
from torch import Tensor, nn
from einops import rearrange
# is_torch_version is required by the gradient-checkpointing paths below;
# the original file never imported it (and once called the nonexistent
# torch.is_torch_version)
from diffusers.utils import is_torch_version

from .modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
                             MLPEmbedder, SingleStreamBlock,
                             timestep_embedding)


@dataclass
class FluxParams:
    in_channels: int
    vec_in_dim: int
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int
    depth_single_blocks: int
    axes_dim: list[int]
    theta: int
    qkv_bias: bool
    guidance_embed: bool


class Flux(nn.Module):
    """
    Transformer model for flow matching on sequences.
    """
    _supports_gradient_checkpointing = True

    def __init__(self, params: FluxParams):
        super().__init__()

        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
        )
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
                for _ in range(params.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    @property
    def attn_processors(self):
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
            if hasattr(module, "set_processor"):
                processors[f"{name}.processor"] = module.processor

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    def set_attn_processor(self, processor):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        block_controlnet_hidden_states=None,
        guidance: Tensor | None = None,
    ) -> Tensor:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)
        if block_controlnet_hidden_states is not None:
            controlnet_depth = len(block_controlnet_hidden_states)
        for index_block, block in enumerate(self.double_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                # assign the checkpointed outputs back to (img, txt); the original
                # code stored them in unused variables, which discarded the result
                img, txt = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    img,
                    txt,
                    vec,
                    pe,
                    **ckpt_kwargs,
                )
            else:
                img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
            # controlnet residual, cycled across the double blocks
            if block_controlnet_hidden_states is not None:
                # the original hard-coded `% 2`; indexing by the actual residual
                # count is identical for the default controlnet_depth=2
                img = img + block_controlnet_hidden_states[index_block % controlnet_depth]

        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                img = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    img,
                    vec,
                    pe,
                    **ckpt_kwargs,
                )
            else:
                img = block(img, vec=vec, pe=pe)
        img = img[:, txt.shape[1]:, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img
312
custom_nodes/x-flux-comfyui/xflux/src/flux/modules/autoencoder.py
Normal file
@@ -0,0 +1,312 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
@dataclass
|
||||
class AutoEncoderParams:
|
||||
resolution: int
|
||||
in_channels: int
|
||||
ch: int
|
||||
out_ch: int
|
||||
ch_mult: list[int]
|
||||
num_res_blocks: int
|
||||
z_channels: int
|
||||
scale_factor: float
|
||||
shift_factor: float
|
||||
|
||||
|
||||
def swish(x: Tensor) -> Tensor:
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels: int):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
||||
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
||||
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
||||
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
||||
|
||||
def attention(self, h_: Tensor) -> Tensor:
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
b, c, h, w = q.shape
|
||||
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
|
||||
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
|
||||
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
|
||||
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
|
||||
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return x + self.proj_out(self.attention(x))
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
|
||||
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = swish(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
h = self.norm2(h)
|
||||
h = swish(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, in_channels: int):
|
||||
super().__init__()
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
pad = (0, 1, 0, 1)
|
||||
x = nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, in_channels: int):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
resolution: int,
|
||||
in_channels: int,
|
||||
ch: int,
|
||||
ch_mult: list[int],
|
||||
num_res_blocks: int,
|
||||
z_channels: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
# downsampling
|
||||
self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
self.down = nn.ModuleList()
|
||||
block_in = self.ch
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for _ in range(self.num_res_blocks):
|
||||
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
||||
block_in = block_out
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
down.downsample = Downsample(block_in)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||
self.mid.attn_1 = AttnBlock(block_in)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||
|
||||
# end
|
||||
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
||||
self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1])
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h)
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = swish(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
ch: int,
|
||||
out_ch: int,
|
||||
ch_mult: list[int],
|
||||
num_res_blocks: int,
|
||||
        in_channels: int,
        resolution: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.ffactor = 2 ** (self.num_resolutions - 1)

        # compute block_in and curr_res at the lowest resolution
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class DiagonalGaussian(nn.Module):
    def __init__(self, sample: bool = True, chunk_dim: int = 1):
        super().__init__()
        self.sample = sample
        self.chunk_dim = chunk_dim

    def forward(self, z: Tensor) -> Tensor:
        mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
        if self.sample:
            std = torch.exp(0.5 * logvar)
            return mean + std * torch.randn_like(mean)
        else:
            return mean


class AutoEncoder(nn.Module):
    def __init__(self, params: AutoEncoderParams):
        super().__init__()
        self.encoder = Encoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.decoder = Decoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            out_ch=params.out_ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.reg = DiagonalGaussian()

        self.scale_factor = params.scale_factor
        self.shift_factor = params.shift_factor

    def encode(self, x: Tensor) -> Tensor:
        z = self.reg(self.encoder(x))
        z = self.scale_factor * (z - self.shift_factor)
        return z

    def decode(self, z: Tensor) -> Tensor:
        z = z / self.scale_factor + self.shift_factor
        return self.decoder(z)

    def forward(self, x: Tensor) -> Tensor:
        return self.decode(self.encode(x))
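For orientation, a minimal round-trip sketch (hedged, not part of the committed file; it assumes an AutoEncoderParams populated as in the "flux-dev" config from util.py further below, where z_channels=16 and ch_mult=[1, 2, 4, 4] give 16 latent channels at 1/8 spatial resolution):

# hedged usage sketch
ae = AutoEncoder(params).to("cuda")
img = torch.randn(1, 3, 1024, 1024, device="cuda")  # RGB in [-1, 1]
z = ae.encode(img)       # (1, 16, 128, 128): sampled, scaled and shifted latent
recon = ae.decode(z)     # back to (1, 3, 1024, 1024)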
@@ -0,0 +1,38 @@
from torch import Tensor, nn
from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
                          T5Tokenizer)


class HFEmbedder(nn.Module):
    def __init__(self, version: str, max_length: int, **hf_kwargs):
        super().__init__()
        self.is_clip = version.startswith("openai")
        self.max_length = max_length
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

        if self.is_clip:
            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
        else:
            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)

        self.hf_module = self.hf_module.eval().requires_grad_(False)

    def forward(self, text: list[str]) -> Tensor:
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )

        outputs = self.hf_module(
            input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
            attention_mask=None,
            output_hidden_states=False,
        )
        return outputs[self.output_key]
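A hedged usage sketch of the two embedder variants (the model names are the ones load_t5/load_clip in util.py below pass in):

import torch
clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16)
t5 = HFEmbedder("xlabs-ai/xflux_text_encoders", max_length=512, torch_dtype=torch.bfloat16)
vec = clip(["a photo of a cat"])  # (1, 768) pooler_output, since version starts with "openai"
txt = t5(["a photo of a cat"])    # (1, 512, 4096) last_hidden_state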
358
custom_nodes/x-flux-comfyui/xflux/src/flux/modules/layers.py
Normal file
@@ -0,0 +1,358 @@
import math
from dataclasses import dataclass

import torch
from einops import rearrange
from torch import Tensor, nn

from ..math import attention, rope


class EmbedND(nn.Module):
    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )

        return emb.unsqueeze(1)


def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """
    Create sinusoidal timestep embeddings.

    :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param time_factor: scales t before embedding (the sampler's timesteps live in [0, 1]).
    :return: an (N, D) Tensor of positional embeddings.
    """
    t = time_factor * t
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
        t.device
    )

    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        embedding = embedding.to(t)
    return embedding
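As a quick sanity check on the shapes (a hedged sketch, not part of the commit):

t = torch.tensor([0.5])               # one timestep in [0, 1]
emb = timestep_embedding(t, dim=256)  # internally embeds 1000 * 0.5 = 500.0
print(emb.shape)                      # torch.Size([1, 256]): 128 cosines then 128 sines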
class MLPEmbedder(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        return self.out_layer(self.silu(self.in_layer(x)))


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor):
        x_dtype = x.dtype
        x = x.float()
        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
        return (x * rrms).to(dtype=x_dtype) * self.scale


class QKNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        q = self.query_norm(q)
        k = self.key_norm(k)
        return q.to(v), k.to(v)


class LoRALinearLayer(nn.Module):
    def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
        super().__init__()

        self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
        self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        self.network_alpha = network_alpha
        self.rank = rank

        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype

        down_hidden_states = self.down(hidden_states.to(dtype))
        up_hidden_states = self.up(down_hidden_states)

        if self.network_alpha is not None:
            up_hidden_states *= self.network_alpha / self.rank

        return up_hidden_states.to(orig_dtype)
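The adapter computes up(down(x)), optionally rescaled by network_alpha / rank; a hedged shape sketch:

lora = LoRALinearLayer(in_features=3072, out_features=3072 * 3, rank=16, network_alpha=16)
x = torch.randn(1, 4096, 3072)
delta = lora(x)  # (1, 4096, 9216); all zeros at init because `up` starts zeroed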
class FLuxSelfAttnProcessor:
    def __call__(self, attn, x, pe, **attention_kwargs):
        qkv = attn.qkv(x)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        q, k = attn.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        x = attn.proj(x)
        return x


class LoraFluxAttnProcessor(nn.Module):
    def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
        super().__init__()
        self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
        self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha)
        self.lora_weight = lora_weight

    def __call__(self, attn, x, pe, **attention_kwargs):
        qkv = attn.qkv(x) + self.qkv_lora(x) * self.lora_weight
        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        q, k = attn.norm(q, k, v)
        x = attention(q, k, v, pe=pe)
        x = attn.proj(x) + self.proj_lora(x) * self.lora_weight
        return x


class SelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self):
        # this module only holds the qkv/norm/proj weights; the owning
        # block's processor calls them directly
        pass


@dataclass
class ModulationOut:
    shift: Tensor
    scale: Tensor
    gate: Tensor


class Modulation(nn.Module):
    def __init__(self, dim: int, double: bool):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)

        return (
            ModulationOut(*out[:3]),
            ModulationOut(*out[3:]) if self.is_double else None,
        )
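Modulation projects the conditioning vector into shift/scale/gate triples that the blocks apply around attention and MLP; a hedged shape sketch (vec here stands in for the model's pooled-text-plus-timestep conditioning):

mod = Modulation(dim=3072, double=True)
vec = torch.randn(2, 3072)
m1, m2 = mod(vec)        # two ModulationOut triples when double=True
print(m1.shift.shape)    # torch.Size([2, 1, 3072]); the 1 broadcasts over sequence length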
class DoubleStreamBlockLoraProcessor(nn.Module):
    def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
        super().__init__()
        self.qkv_lora1 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
        self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha)
        self.qkv_lora2 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
        self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha)
        self.lora_weight = lora_weight

    def forward(self, attn, img, txt, vec, pe, **attention_kwargs):
        img_mod1, img_mod2 = attn.img_mod(vec)
        txt_mod1, txt_mod2 = attn.txt_mod(vec)

        # prepare image for attention
        img_modulated = attn.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = attn.img_attn.qkv(img_modulated) + self.qkv_lora1(img_modulated) * self.lora_weight
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = attn.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = attn.txt_attn.qkv(txt_modulated) + self.qkv_lora2(txt_modulated) * self.lora_weight
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn1 = attention(q, k, v, pe=pe)
        txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]

        # calculate the img blocks
        img = img + img_mod1.gate * attn.img_attn.proj(img_attn) + img_mod1.gate * self.proj_lora1(img_attn) * self.lora_weight
        img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) + txt_mod1.gate * self.proj_lora2(txt_attn) * self.lora_weight
        txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
        return img, txt


class DoubleStreamBlockProcessor(nn.Module):
    def __init__(self):
        super().__init__()

    def __call__(self, attn, img, txt, vec, pe, **attention_kwargs):
        img_mod1, img_mod2 = attn.img_mod(vec)
        txt_mod1, txt_mod2 = attn.txt_mod(vec)

        # prepare image for attention
        img_modulated = attn.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = attn.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = attn.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = attn.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn1 = attention(q, k, v, pe=pe)
        txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]

        # calculate the img blocks
        img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
        img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
        return img, txt


class DoubleStreamBlock(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
        super().__init__()
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )
        processor = DoubleStreamBlockProcessor()
        self.set_processor(processor)

    def set_processor(self, processor) -> None:
        self.processor = processor

    def get_processor(self):
        return self.processor

    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
        return self.processor(self, img, txt, vec, pe)


class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # proj and mlp_out
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        mod, _ = self.modulation(vec)
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe)
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        return x + mod.gate * output


class LastLayer(nn.Module):
    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        x = self.linear(x)
        return x
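The processor indirection is what the LoRA path exploits: swapping a block's processor changes behaviour without touching its weights. A hedged sketch:

block = DoubleStreamBlock(hidden_size=3072, num_heads=24, mlp_ratio=4.0, qkv_bias=True)
block.set_processor(DoubleStreamBlockLoraProcessor(dim=3072, rank=16, lora_weight=0.7))
# forward() now routes through the LoRA processor; with `up` zero-initialised the
# outputs match the stock DoubleStreamBlockProcessor until LoRA weights are loaded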
248
custom_nodes/x-flux-comfyui/xflux/src/flux/sampling.py
Normal file
@@ -0,0 +1,248 @@
import math
from typing import Callable

import torch
from einops import rearrange, repeat
from torch import Tensor

from .model import Flux
from .modules.conditioner import HFEmbedder


def get_noise(
    num_samples: int,
    height: int,
    width: int,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
):
    return torch.randn(
        num_samples,
        16,
        # allow for packing
        2 * math.ceil(height / 16),
        2 * math.ceil(width / 16),
        device=device,
        dtype=dtype,
        generator=torch.Generator(device=device).manual_seed(seed),
    )


def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
    bs, c, h, w = img.shape
    if bs == 1 and not isinstance(prompt, str):
        bs = len(prompt)

    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    if img.shape[0] == 1 and bs > 1:
        img = repeat(img, "1 ... -> bs ...", bs=bs)

    img_ids = torch.zeros(h // 2, w // 2, 3)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

    if isinstance(prompt, str):
        prompt = [prompt]
    txt = t5(prompt)
    if txt.shape[0] == 1 and bs > 1:
        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
    txt_ids = torch.zeros(bs, txt.shape[1], 3)

    vec = clip(prompt)
    if vec.shape[0] == 1 and bs > 1:
        vec = repeat(vec, "1 ... -> bs ...", bs=bs)

    return {
        "img": img,
        "img_ids": img_ids.to(img.device),
        "txt": txt.to(img.device),
        "txt_ids": txt_ids.to(img.device),
        "vec": vec.to(img.device),
    }


def time_shift(mu: float, sigma: float, t: Tensor):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def get_lin_function(
    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b


def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    # extra step for zero
    timesteps = torch.linspace(1, 0, num_steps + 1)

    # shift the schedule to favor high timesteps for longer (higher-resolution) sequences
    if shift:
        # estimate mu via linear interpolation between two reference points
        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
        timesteps = time_shift(mu, 1.0, timesteps)

    return timesteps.tolist()
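A hedged numeric sketch of the shift: at 1024x1024 the packed sequence length is 4096, so mu hits max_shift and mid-schedule steps are pulled toward t = 1:

seq_len = (1024 // 16) * (1024 // 16)  # 4096 packed tokens
ts = get_schedule(num_steps=4, image_seq_len=seq_len)
print([round(t, 3) for t in ts])       # ~[1.0, 0.905, 0.76, 0.513, 0.0]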
def denoise(
    model: Flux,
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    neg_txt: Tensor,
    neg_txt_ids: Tensor,
    neg_vec: Tensor,
    # sampling parameters
    timesteps: list[float],
    guidance: float = 4.0,
    true_gs=1,
    timestep_to_start_cfg=0,
    image2image_strength=None,
    orig_image=None,
):
    i = 0

    # image-to-image: start part-way along the schedule and blend in the source image
    if image2image_strength is not None and orig_image is not None:
        t_idx = int((1 - image2image_strength) * len(timesteps))
        t = timesteps[t_idx]
        timesteps = timesteps[t_idx:]
        img = t * img + (1.0 - t) * orig_image.to(img.dtype)
    # the guidance embedding is ignored for schnell
    if hasattr(model, "guidance_in"):
        guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
    else:
        guidance_vec = None
    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        pred = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
        )
        if i >= timestep_to_start_cfg:
            neg_pred = model(
                img=img,
                img_ids=img_ids,
                txt=neg_txt,
                txt_ids=neg_txt_ids,
                y=neg_vec,
                timesteps=t_vec,
                guidance=guidance_vec,
            )
            pred = neg_pred + true_gs * (pred - neg_pred)
        # Euler step of the rectified-flow ODE
        img = img + (t_prev - t_curr) * pred
        i += 1
    return img


def denoise_controlnet(
    model: Flux,
    controlnet,
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    neg_txt: Tensor,
    neg_txt_ids: Tensor,
    neg_vec: Tensor,
    controlnet_cond,
    # sampling parameters
    timesteps: list[float],
    guidance: float = 4.0,
    true_gs=1,
    controlnet_gs=0.7,
    timestep_to_start_cfg=0,
    image2image_strength=None,
    orig_image=None,
):
    i = 0

    # image-to-image: start part-way along the schedule and blend in the source image
    if image2image_strength is not None and orig_image is not None:
        t_idx = int((1 - image2image_strength) * len(timesteps))
        t = timesteps[t_idx]
        timesteps = timesteps[t_idx:]
        img = t * img + (1.0 - t) * orig_image.to(img.dtype)
    # the guidance embedding is ignored for schnell
    if hasattr(model, "guidance_in"):
        guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
    else:
        guidance_vec = None
    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        block_res_samples = controlnet(
            img=img,
            img_ids=img_ids,
            controlnet_cond=controlnet_cond,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
        )
        pred = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
            block_controlnet_hidden_states=[sample * controlnet_gs for sample in block_res_samples],
        )
        if i >= timestep_to_start_cfg:
            neg_block_res_samples = controlnet(
                img=img,
                img_ids=img_ids,
                controlnet_cond=controlnet_cond,
                txt=neg_txt,
                txt_ids=neg_txt_ids,
                y=neg_vec,
                timesteps=t_vec,
                guidance=guidance_vec,
            )
            neg_pred = model(
                img=img,
                img_ids=img_ids,
                txt=neg_txt,
                txt_ids=neg_txt_ids,
                y=neg_vec,
                timesteps=t_vec,
                guidance=guidance_vec,
                block_controlnet_hidden_states=[sample * controlnet_gs for sample in neg_block_res_samples],
            )
            pred = neg_pred + true_gs * (pred - neg_pred)

        img = img + (t_prev - t_curr) * pred

        i += 1
    return img


def unpack(x: Tensor, height: int, width: int) -> Tensor:
    return rearrange(
        x,
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=math.ceil(height / 16),
        w=math.ceil(width / 16),
        ph=2,
        pw=2,
    )
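get_noise, the packing in prepare, and unpack are consistent with each other; a hedged round-trip sketch:

lat = get_noise(1, 1024, 1024, device=torch.device("cpu"), dtype=torch.float32, seed=0)
packed = rearrange(lat, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
print(packed.shape)                            # torch.Size([1, 4096, 64]): matches in_channels=64
assert torch.equal(unpack(packed, 1024, 1024), lat)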
350
custom_nodes/x-flux-comfyui/xflux/src/flux/util.py
Normal file
@@ -0,0 +1,350 @@
import os
from dataclasses import dataclass

import torch
import json
import cv2
import numpy as np
from einops import rearrange
from PIL import Image
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from safetensors.torch import load_file as load_sft

from .model import Flux, FluxParams
from .controlnet import ControlNetFlux
from .modules.autoencoder import AutoEncoder, AutoEncoderParams
from .modules.conditioner import HFEmbedder

# NOTE: assumed import locations for names used below but not imported in the
# committed file: the detectors are expected to ship with the bundled annotator
# package, and WatermarkEncoder comes from the invisible-watermark package.
from imwatermark import WatermarkEncoder
from .annotator.canny import CannyDetector
from .annotator.dwpose import DWposeDetector
from .annotator.hed import HEDdetector
from .annotator.midas import MidasDetector
from .annotator.mlsd import MLSDdetector
from .annotator.tile import TileDetector


def load_safetensors(path):
    tensors = {}
    with safe_open(path, framework="pt", device="cpu") as f:
        for key in f.keys():
            tensors[key] = f.get_tensor(key)
    return tensors


def get_lora_rank(checkpoint):
    # the rank is the output dimension of any LoRA down-projection
    for k in checkpoint.keys():
        if k.endswith(".down.weight"):
            return checkpoint[k].shape[0]


def load_checkpoint(local_path, repo_id, name):
    if local_path is not None:
        if '.safetensors' in local_path:
            print("Loading .safetensors checkpoint...")
            checkpoint = load_safetensors(local_path)
        else:
            print("Loading checkpoint...")
            checkpoint = torch.load(local_path, map_location='cpu')
    elif repo_id is not None and name is not None:
        print("Loading checkpoint from repo id...")
        checkpoint = load_from_repo_id(repo_id, name)
    else:
        raise ValueError(
            "LOADING ERROR: you must specify local_path or repo_id with name in HF to download"
        )
    return checkpoint


def c_crop(image):
    # center-crop to a square of side min(width, height)
    width, height = image.size
    new_size = min(width, height)
    left = (width - new_size) / 2
    top = (height - new_size) / 2
    right = (width + new_size) / 2
    bottom = (height + new_size) / 2
    return image.crop((left, top, right, bottom))


class Annotator:
    def __init__(self, name: str, device: str):
        if name == "canny":
            processor = CannyDetector()
        elif name == "openpose":
            processor = DWposeDetector(device)
        elif name == "depth":
            processor = MidasDetector()
        elif name == "hed":
            processor = HEDdetector()
        elif name == "hough":
            processor = MLSDdetector()
        elif name == "tile":
            processor = TileDetector()
        else:
            raise ValueError(f"unknown annotator name: {name}")
        self.name = name
        self.processor = processor

    def __call__(self, image: Image.Image, width: int, height: int):
        image = c_crop(image)
        image = image.resize((width, height))
        image = np.array(image)
        if self.name == "canny":
            result = self.processor(image, low_threshold=100, high_threshold=200)
        elif self.name == "hough":
            result = self.processor(image, thr_v=0.05, thr_d=5)
        elif self.name == "depth":
            result = self.processor(image)
            result, _ = result
        else:
            result = self.processor(image)

        if result.ndim != 3:
            result = result[:, :, None]
            result = np.concatenate([result, result, result], axis=2)
        return result


@dataclass
class ModelSpec:
    params: FluxParams
    ae_params: AutoEncoderParams
    ckpt_path: str | None
    ae_path: str | None
    repo_id: str | None
    repo_flow: str | None
    repo_ae: str | None
    repo_id_ae: str | None


configs = {
    "flux-dev": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-dev",
        repo_id_ae="black-forest-labs/FLUX.1-dev",
        repo_flow="flux1-dev.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path=os.getenv("FLUX_DEV"),
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_path=os.getenv("AE"),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-dev-fp8": ModelSpec(
        repo_id="XLabs-AI/flux-dev-fp8",
        repo_id_ae="black-forest-labs/FLUX.1-dev",
        repo_flow="flux-dev-fp8.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path=os.getenv("FLUX_DEV_FP8"),
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_path=os.getenv("AE"),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-schnell": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-schnell",
        repo_id_ae="black-forest-labs/FLUX.1-dev",
        repo_flow="flux1-schnell.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path=os.getenv("FLUX_SCHNELL"),
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=False,
        ),
        ae_path=os.getenv("AE"),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
}


def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
    if len(missing) > 0 and len(unexpected) > 0:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
        print("\n" + "-" * 79 + "\n")
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
    elif len(missing) > 0:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
    elif len(unexpected) > 0:
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))


def load_from_repo_id(repo_id, checkpoint_name):
    ckpt_path = hf_hub_download(repo_id, checkpoint_name)
    sd = load_sft(ckpt_path, device='cpu')
    return sd


def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
    # Loading Flux
    print("Init model")
    ckpt_path = configs[name].ckpt_path
    if (
        ckpt_path is None
        and configs[name].repo_id is not None
        and configs[name].repo_flow is not None
        and hf_download
    ):
        ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)

    with torch.device("meta" if ckpt_path is not None else device):
        model = Flux(configs[name].params).to(torch.bfloat16)

    if ckpt_path is not None:
        print("Loading checkpoint")
        # load_sft doesn't support torch.device
        sd = load_sft(ckpt_path, device=str(device))
        missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
        print_load_warning(missing, unexpected)
    return model


def load_flow_model2(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
    # variant of load_flow_model that maps legacy "sft" filenames and keeps the default dtype
    print("Init model")
    ckpt_path = configs[name].ckpt_path
    if (
        ckpt_path is None
        and configs[name].repo_id is not None
        and configs[name].repo_flow is not None
        and hf_download
    ):
        ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors"))

    with torch.device("meta" if ckpt_path is not None else device):
        model = Flux(configs[name].params)

    if ckpt_path is not None:
        print("Loading checkpoint")
        # load_sft doesn't support torch.device
        sd = load_sft(ckpt_path, device=str(device))
        missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
        print_load_warning(missing, unexpected)
    return model


def load_controlnet(name, device, transformer=None):
    with torch.device(device):
        controlnet = ControlNetFlux(configs[name].params)
    if transformer is not None:
        controlnet.load_state_dict(transformer.state_dict(), strict=False)
    return controlnet


def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
    # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
    return HFEmbedder("xlabs-ai/xflux_text_encoders", max_length=max_length, torch_dtype=torch.bfloat16).to(device)


def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
    return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)


def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
    ckpt_path = configs[name].ae_path
    if (
        ckpt_path is None
        and configs[name].repo_id is not None
        and configs[name].repo_ae is not None
        and hf_download
    ):
        ckpt_path = hf_hub_download(configs[name].repo_id_ae, configs[name].repo_ae)

    # Loading the autoencoder
    print("Init AE")
    with torch.device("meta" if ckpt_path is not None else device):
        ae = AutoEncoder(configs[name].ae_params)

    if ckpt_path is not None:
        sd = load_sft(ckpt_path, device=str(device))
        missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
        print_load_warning(missing, unexpected)
    return ae


class WatermarkEmbedder:
    def __init__(self, watermark):
        self.watermark = watermark
        self.num_bits = len(WATERMARK_BITS)
        self.encoder = WatermarkEncoder()
        self.encoder.set_watermark("bits", self.watermark)

    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        """
        Adds a predefined watermark to the input image

        Args:
            image: ([N,] B, RGB, H, W) in range [-1, 1]

        Returns:
            same as input but watermarked
        """
        image = 0.5 * image + 0.5
        squeeze = len(image.shape) == 4
        if squeeze:
            # add a leading N dim for 4-D (B, C, H, W) inputs; removed again below
            image = image[None, ...]
        n = image.shape[0]
        # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) in [0, 255];
        # the watermarking library expects input in cv2 BGR format
        image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1]
        for k in range(image_np.shape[0]):
            image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
        image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to(
            image.device
        )
        image = torch.clamp(image / 255, min=0.0, max=1.0)
        if squeeze:
            image = image[0]
        image = 2 * image - 1
        return image


# A fixed 48-bit message that was chosen at random
WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
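A hedged sketch of the loader entry points; weights come from the HF repos listed in `configs` unless the FLUX_DEV / AE environment variables point at local checkpoints:

model = load_flow_model("flux-dev", device="cuda")  # bf16 Flux transformer
ae = load_ae("flux-dev", device="cuda")
t5, clip = load_t5("cuda"), load_clip("cuda")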
152
custom_nodes/x-flux-comfyui/xflux/src/flux/xflux_pipeline.py
Normal file
@@ -0,0 +1,152 @@
from PIL import Image
import numpy as np
import torch

from einops import rearrange

from src.flux.modules.layers import DoubleStreamBlockLoraProcessor
from src.flux.sampling import denoise, denoise_controlnet, get_noise, get_schedule, prepare, unpack
from src.flux.util import (load_ae, load_clip, load_flow_model, load_t5, load_controlnet,
                           load_flow_model_quintized, Annotator, get_lora_rank, load_checkpoint)


class XFluxPipeline:
    def __init__(self, model_type, device, offload: bool = False, seed: int | None = None):
        self.device = torch.device(device)
        self.offload = offload
        self.seed = seed
        self.model_type = model_type

        self.clip = load_clip(self.device)
        self.t5 = load_t5(self.device, max_length=512)
        self.ae = load_ae(model_type, device="cpu" if offload else self.device)
        if "fp8" in model_type:
            self.model = load_flow_model_quintized(model_type, device="cpu" if offload else self.device)
        else:
            self.model = load_flow_model(model_type, device="cpu" if offload else self.device)

        self.hf_lora_collection = "XLabs-AI/flux-lora-collection"
        self.lora_types_to_names = {
            "realism": "lora.safetensors",
        }
        self.controlnet_loaded = False

    def set_lora(self, local_path: str = None, repo_id: str = None,
                 name: str = None, lora_weight: float = 0.7):
        checkpoint = load_checkpoint(local_path, repo_id, name)
        self.update_model_with_lora(checkpoint, lora_weight)

    def set_lora_from_collection(self, lora_type: str = "realism", lora_weight: float = 0.7):
        checkpoint = load_checkpoint(
            None, self.hf_lora_collection, self.lora_types_to_names[lora_type]
        )
        self.update_model_with_lora(checkpoint, lora_weight)

    def update_model_with_lora(self, checkpoint, lora_weight):
        rank = get_lora_rank(checkpoint)
        lora_attn_procs = {}

        for name, _ in self.model.attn_processors.items():
            lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=3072, rank=rank)
            # collect the checkpoint entries belonging to this processor,
            # stripping the "<name>." prefix and pre-scaling by lora_weight
            lora_state_dict = {}
            for k in checkpoint.keys():
                if name in k:
                    lora_state_dict[k[len(name) + 1:]] = checkpoint[k] * lora_weight
            lora_attn_procs[name].load_state_dict(lora_state_dict)
            lora_attn_procs[name].to(self.device)

        self.model.set_attn_processor(lora_attn_procs)

    def set_controlnet(self, control_type: str, local_path: str = None, repo_id: str = None, name: str = None):
        self.model.to(self.device)
        self.controlnet = load_controlnet(self.model_type, self.device).to(torch.bfloat16)

        checkpoint = load_checkpoint(local_path, repo_id, name)
        self.controlnet.load_state_dict(checkpoint, strict=False)

        if control_type == "depth":
            self.controlnet_gs = 0.9
        else:
            self.controlnet_gs = 0.7
        self.annotator = Annotator(control_type, self.device)
        self.controlnet_loaded = True

    def __call__(self,
                 prompt: str,
                 controlnet_image: Image.Image = None,
                 width: int = 512,
                 height: int = 512,
                 guidance: float = 4,
                 num_steps: int = 50,
                 true_gs=3,
                 neg_prompt: str = '',
                 timestep_to_start_cfg: int = 0,
                 ):
        # round the requested size down to a multiple of 16
        width = 16 * width // 16
        height = 16 * height // 16
        if self.controlnet_loaded:
            controlnet_image = self.annotator(controlnet_image, width, height)
            controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1)
            controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(torch.bfloat16).to(self.device)

        return self.forward(prompt, width, height, guidance, num_steps, controlnet_image,
                            timestep_to_start_cfg=timestep_to_start_cfg, true_gs=true_gs, neg_prompt=neg_prompt)

    def forward(self, prompt, width, height, guidance, num_steps, controlnet_image=None, timestep_to_start_cfg=0, true_gs=3, neg_prompt=""):
        x = get_noise(
            1, height, width, device=self.device,
            dtype=torch.bfloat16, seed=self.seed
        )
        timesteps = get_schedule(
            num_steps,
            (width // 8) * (height // 8) // (16 * 16),
            shift=True,
        )
        torch.manual_seed(self.seed)
        with torch.no_grad():
            if self.offload:
                self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
            inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=prompt)
            neg_inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=neg_prompt)

            if self.offload:
                self.offload_model_to_cpu(self.t5, self.clip)
                self.model = self.model.to(self.device)
            if self.controlnet_loaded:
                x = denoise_controlnet(
                    self.model, **inp_cond, controlnet=self.controlnet,
                    timesteps=timesteps, guidance=guidance,
                    controlnet_cond=controlnet_image,
                    timestep_to_start_cfg=timestep_to_start_cfg,
                    neg_txt=neg_inp_cond['txt'],
                    neg_txt_ids=neg_inp_cond['txt_ids'],
                    neg_vec=neg_inp_cond['vec'],
                    true_gs=true_gs,
                    controlnet_gs=self.controlnet_gs,
                )
            else:
                x = denoise(self.model, **inp_cond, timesteps=timesteps, guidance=guidance,
                            timestep_to_start_cfg=timestep_to_start_cfg,
                            neg_txt=neg_inp_cond['txt'],
                            neg_txt_ids=neg_inp_cond['txt_ids'],
                            neg_vec=neg_inp_cond['vec'],
                            true_gs=true_gs
                            )

            if self.offload:
                self.offload_model_to_cpu(self.model)
                self.ae.decoder.to(x.device)
            x = unpack(x.float(), height, width)
            x = self.ae.decode(x)
            self.offload_model_to_cpu(self.ae.decoder)

        x1 = x.clamp(-1, 1)
        x1 = rearrange(x1[-1], "c h w -> h w c")
        output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy())
        return output_img

    def offload_model_to_cpu(self, *models):
        if not self.offload:
            return
        for model in models:
            model.cpu()
            torch.cuda.empty_cache()
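End-to-end, the pipeline is driven roughly like this (a hedged sketch, not part of the committed file):

pipe = XFluxPipeline("flux-dev", device="cuda", offload=True, seed=42)
pipe.set_lora_from_collection("realism", lora_weight=0.7)  # lora.safetensors from XLabs-AI/flux-lora-collection
image = pipe("a vintage photo of a lighthouse at dawn", width=1024, height=1024, num_steps=25)
image.save("result.png")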