Add custom nodes, Civitai loras (LFS), and vast.ai setup script
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Includes 30 custom nodes committed directly, 7 Civitai-exclusive loras stored via Git LFS, and a setup script that installs all dependencies and downloads HuggingFace-hosted models on vast.ai. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
6
custom_nodes/ComfyUI-Easy-Use/py/__init__.py
Normal file
6
custom_nodes/ComfyUI-Easy-Use/py/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .libs.loader import easyLoader
from .libs.sampler import easySampler

# Module-level singletons shared by the node pack: one sampler helper and one
# model/loader cache.  Other modules import these instead of constructing
# their own, so loaded models are cached once per process.
sampler = easySampler()
easyCache = easyLoader()
|
||||
|
||||
407
custom_nodes/ComfyUI-Easy-Use/py/config.py
Normal file
407
custom_nodes/ComfyUI-Easy-Use/py/config.py
Normal file
@@ -0,0 +1,407 @@
|
||||
import os
|
||||
import folder_paths
|
||||
from pathlib import Path
|
||||
|
||||
# Preset (width, height) pairs offered by resolution dropdowns.  The first
# entry is a header/sentinel tuple ("width", "height"), not a real size —
# consumers are expected to skip or special-case index 0.
BASE_RESOLUTIONS = [
    ("width", "height"),
    (512, 512),
    (512, 768),
    (576, 1024),
    (768, 512),
    (768, 768),
    (768, 1024),
    (768, 1280),
    (768, 1344),
    (768, 1536),
    (816, 1920),
    (832, 1152),
    (832, 1216),
    (896, 1152),
    (896, 1088),
    (1024, 1024),
    (1024, 576),
    (1024, 768),
    (1080, 1920),
    (1440, 2560),
    (1088, 896),
    (1216, 832),
    (1152, 832),
    (1152, 896),
    (1280, 768),
    (1344, 768),
    (1536, 640),
    (1536, 768),
    (1920, 816),
    (1920, 1080),
    (2560, 1440),
]
# Upper bound for seed widgets (equals 2**50).
MAX_SEED_NUM = 1125899906842624


# Bundled resource files shipped two directories above this module
# (…/ComfyUI-Easy-Use/resources).
RESOURCES_DIR = os.path.join(Path(__file__).parent.parent, "resources")

# inpaint
# Download target for inpaint models inside the ComfyUI models directory.
INPAINT_DIR = os.path.join(folder_paths.models_dir, "inpaint")
# Style definitions shipped with the pack (…/ComfyUI-Easy-Use/styles).
FOOOCUS_STYLES_DIR = os.path.join(Path(__file__).parent.parent, "styles")
# Base URL for Fooocus style preview thumbnails.
FOOOCUS_STYLES_SAMPLES = 'https://raw.githubusercontent.com/lllyasviel/Fooocus/main/sdxl_styles/samples/'
|
||||
# Fooocus inpainting: the head model is always required; the patch models are
# the selectable inpaint engines (display names include the download size).
FOOOCUS_INPAINT_HEAD = {
    "fooocus_inpaint_head": {
        "model_url": "https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/fooocus_inpaint_head.pth"
    }
}
FOOOCUS_INPAINT_PATCH = {
    "inpaint_v26 (1.32GB)": {
        "model_url": "https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/inpaint_v26.fooocus.patch"
    },
    "inpaint_v25 (2.58GB)": {
        "model_url": "https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/inpaint_v25.fooocus.patch"
    },
    "inpaint (1.32GB)": {
        "model_url": "https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/inpaint.fooocus.patch"
    },
}
# BrushNet inpainting weights, keyed by mask type, then by base-model family
# ("sd1" = SD 1.x, "sdxl" = SDXL).
BRUSHNET_MODELS = {
    "random_mask": {
        "sd1": {
            "model_url": "https://huggingface.co/Kijai/BrushNet-fp16/resolve/main/brushnet_random_mask_fp16.safetensors"
        },
        "sdxl": {
            "model_url": "https://huggingface.co/yolain/brushnet/resolve/main/brushnet_random_mask_sdxl.safetensors"
        }
    },
    "segmentation_mask": {
        "sd1": {
            "model_url": "https://huggingface.co/Kijai/BrushNet-fp16/resolve/main/brushnet_segmentation_mask_fp16.safetensors"
        },
        "sdxl": {
            "model_url": "https://huggingface.co/yolain/brushnet/resolve/main/brushnet_segmentation_mask_sdxl.safetensors"
        }
    }
}
# PowerPaint weights; v2.1 additionally needs its own text-encoder weights
# ("clip_url") next to the diffusion model.
POWERPAINT_MODELS = {
    "base_fp16": {
        "model_url": "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/text_encoder/model.fp16.safetensors"
    },
    "v2.1": {
        "model_url": "https://huggingface.co/JunhaoZhuang/PowerPaint-v2-1/resolve/main/PowerPaint_Brushnet/diffusion_pytorch_model.safetensors",
        "clip_url": "https://huggingface.co/JunhaoZhuang/PowerPaint-v2-1/resolve/main/PowerPaint_Brushnet/pytorch_model.bin",
    }
}
|
||||
|
||||
# layerDiffuse
LAYER_DIFFUSION_DIR = os.path.join(folder_paths.models_dir, "layer_model")
# Transparent-VAE weights, keyed by direction (encode/decode) then base model.
LAYER_DIFFUSION_VAE = {
    "encode": {
        "sdxl": {
            "model_url": "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/vae_transparent_encoder.safetensors"
        }
    },
    "decode": {
        "sd1": {
            "model_url": "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_sd15_vae_transparent_decoder.safetensors"
        },
        "sdxl": {
            "model_url": "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/vae_transparent_decoder.safetensors"
        }
    }
}
# LayerDiffusion patch weights per injection method and base model.
# "model_url": None marks a method/base-model combination with no published
# weights (consumers must treat it as unavailable).
LAYER_DIFFUSION = {
    "Attention Injection": {
        "sd1": {
            "model_url": "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_sd15_transparent_attn.safetensors"
        },
        "sdxl": {
            "model_url": "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_xl_transparent_attn.safetensors"
        },
    },
    "Conv Injection": {
        "sdxl": {
            "model_url": "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_xl_transparent_conv.safetensors"
        },
        "sd1": {
            "model_url": None
        }
    },
    "Everything": {
        "sd1": {
            "model_url": "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_sd15_joint.safetensors"
        },
        "sdxl": {
            "model_url": None
        }
    },
    # NOTE(review): "Foreground" and "Foreground to Background" point at the
    # same sd1 file (layer_sd15_fg2bg), and likewise "Background" /
    # "Background to Foreground" share layer_sd15_bg2fg — confirm intended.
    "Foreground": {
        "sd1": {
            "model_url": "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_sd15_fg2bg.safetensors"
        },
        "sdxl": {
            "model_url": "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_xl_fg2ble.safetensors"
        }
    },
    "Foreground to Background": {
        "sd1": {
            "model_url": "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_sd15_fg2bg.safetensors"
        },
        "sdxl": {
            "model_url": "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_xl_fgble2bg.safetensors"
        }
    },
    "Background": {
        "sd1": {
            "model_url": "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_sd15_bg2fg.safetensors"
        },
        "sdxl": {
            "model_url": "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_xl_bg2ble.safetensors"
        }
    },
    "Background to Foreground": {
        "sd1": {
            "model_url": "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_sd15_bg2fg.safetensors"
        },
        "sdxl": {
            "model_url": "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_xl_bgble2fg.safetensors"
        }
    },
}

# IC Light
# Relighting UNet weights; SDXL variants are not published (None).
IC_LIGHT_MODELS = {
    "Foreground": {
        "sd1": {
            "model_url": "https://huggingface.co/huchenlei/IC-Light-ldm/resolve/main/iclight_sd15_fc_unet_ldm.safetensors"
        },
        "sdxl": {
            "model_url": None
        }
    },
    "Foreground&Background": {
        "sd1": {
            "model_url": "https://huggingface.co/huchenlei/IC-Light-ldm/resolve/main/iclight_sd15_fbc_unet_ldm.safetensors"
        },
        "sdxl": {
            "model_url": None
        }
    }
}
|
||||
|
||||
|
||||
# REMBG
REMBG_DIR = os.path.join(folder_paths.models_dir, "rembg")
# Background-removal models.  Note RMBG-2.0's "model_url" is a HuggingFace
# repo id rather than a direct file URL — presumably fetched through the HF
# hub API by the consumer; verify against the download code.
REMBG_MODELS = {
    "RMBG-1.4": {
        "model_url": "https://huggingface.co/briaai/RMBG-1.4/resolve/main/model.pth"
    },
    "RMBG-2.0": {
        "model_url": "briaai/RMBG-2.0"
    },
    "BEN2": {
        "model_url": "https://huggingface.co/PramaLLC/BEN2/resolve/main/BEN2_Base.pth"
    }
}
|
||||
|
||||
#ipadapter
IPADAPTER_DIR = os.path.join(folder_paths.models_dir, "ipadapter")
# IP-Adapter presets keyed by UI display name, then base-model family
# ("sd1", "sdxl", and for the REGULAR preset "flux"/"sd3").  Empty string
# means "no weights for this family".  Optional keys: "lora_url" (companion
# lora), "model_file_name" (local filename override for ambiguous uploads).
IPADAPTER_MODELS = {
    "LIGHT - SD1.5 only (low strength)": {
        "sd1": {
            "model_url": "https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter_sd15_light_v11.bin"
        },
        "sdxl": {
            "model_url": ""
        }
    },
    "STANDARD (medium strength)": {
        "sd1": {
            "model_url": "https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter_sd15.safetensors"
        },
        "sdxl": {
            "model_url": "https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/ip-adapter_sdxl_vit-h.safetensors"
        }
    },
    "VIT-G (medium strength)": {
        "sd1": {
            "model_url": "https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter_sd15_vit-G.safetensors"
        },
        "sdxl": {
            "model_url": "https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/ip-adapter_sdxl.safetensors"
        }
    },
    "PLUS (high strength)": {
        "sd1": {
            "model_url": "https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter-plus_sd15.safetensors"
        },
        "sdxl": {
            "model_url": "https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/ip-adapter-plus_sdxl_vit-h.safetensors"
        }
    },
    # NOTE(review): "genernal" is a typo, but this is a runtime dropdown key;
    # renaming would break saved workflows that reference it.
    "PLUS (kolors genernal)": {
        "sd1": {
            "model_url": ""
        },
        "sdxl": {
            "model_url":"https://huggingface.co/Kwai-Kolors/Kolors-IP-Adapter-Plus/resolve/main/ip_adapter_plus_general.bin"
        }
    },
    "REGULAR - FLUX and SD3.5 only (high strength)": {
        "flux": {
            "model_url": "https://huggingface.co/InstantX/FLUX.1-dev-IP-Adapter/resolve/main/ip-adapter.bin",
            "model_file_name": "ip-adapter_flux_1_dev.bin",
        },
        "sd3": {
            "model_url": "https://huggingface.co/InstantX/SD3.5-Large-IP-Adapter/resolve/main/ip-adapter.bin",
            "model_file_name": "ip-adapter_sd35.bin",
        },
    },
    "PLUS FACE (portraits)": {
        "sd1": {
            "model_url": "https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter-plus-face_sd15.safetensors"
        },
        "sdxl": {
            "model_url": "https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/ip-adapter-plus-face_sdxl_vit-h.safetensors"
        }
    },
    "FULL FACE - SD1.5 only (portraits stronger)": {
        "sd1": {
            "model_url": "https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter-full-face_sd15.safetensors"
        },
        "sdxl": {
            "model_url": ""
        }
    },
    "FACEID": {
        "sd1": {
            "model_url": "https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid_sd15.bin",
            "lora_url": "https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid_sd15_lora.safetensors"
        },
        "sdxl": {
            "model_url": "https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid_sdxl.bin",
            "lora_url": "https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid_sdxl_lora.safetensors"
        }
    },
    "FACEID PLUS - SD1.5 only": {
        "sd1": {
            "model_url": "https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid-plus_sd15.bin",
            "lora_url": "https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid-plus_sd15_lora.safetensors"
        },
        "sdxl": {
            "model_url": "",
            "lora_url": ""
        }
    },
    "FACEID PLUS V2": {
        "sd1": {
            "model_url": "https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid-plusv2_sd15.bin",
            "lora_url": "https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid-plusv2_sd15_lora.safetensors"
        },
        "sdxl": {
            "model_url": "https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid-plusv2_sdxl.bin",
            "lora_url": "https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid-plusv2_sdxl_lora.safetensors"
        }
    },
    "FACEID PLUS KOLORS":{
        # sd1 entry intentionally empty: no weights for this combination.
        "sd1":{

        },
        "sdxl":{
            "model_url":"https://huggingface.co/Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus/resolve/main/ipa-faceid-plus.bin"
        }
    },
    "FACEID PORTRAIT (style transfer)": {
        "sd1": {
            "model_url": "https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid-portrait-v11_sd15.bin",
        },
        "sdxl": {
            "model_url": "https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid-portrait_sdxl.bin",
        }
    },
    "FACEID PORTRAIT UNNORM - SDXL only (strong)": {
        "sd1": {
            "model_url":""
        },
        "sdxl": {
            "model_url": "https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid-portrait_sdxl_unnorm.bin",
        }
    },
    "COMPOSITION": {
        "sd1": {
            "model_url": "https://huggingface.co/ostris/ip-composition-adapter/resolve/main/ip_plus_composition_sd15.safetensors"
        },
        "sdxl": {
            "model_url": "https://huggingface.co/ostris/ip-composition-adapter/resolve/main/ip_plus_composition_sdxl.safetensors"
        }
    }
}
# CLIP-vision encoders the IP-Adapter presets depend on.
IPADAPTER_CLIPVISION_MODELS = {
    "clip-vit-large-patch14-336":{
        "model_url": "https://huggingface.co/openai/clip-vit-large-patch14-336/resolve/main/pytorch_model.bin"
    },
    "clip-vit-h-14-laion2B-s32B-b79K":{
        "model_url": "https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_model.safetensors"
    },
    "sigclip_vision_patch14_384":{
        "model_url": "https://huggingface.co/Comfy-Org/sigclip_vision_384/resolve/main/sigclip_vision_patch14_384.safetensors"
    }
}
|
||||
|
||||
# dynamiCrafter
DYNAMICRAFTER_DIR = os.path.join(folder_paths.models_dir, "dynamicrafter_models")
# DynamiCrafter UNet variants; only the 512 entry lists its companion
# VAE / text-encoder / CLIP-vision weights — presumably the other variants
# reuse them once downloaded; verify against the loader code.
DYNAMICRAFTER_MODELS = {
    "dynamicrafter_unet_512 (2.98GB)": {
        "model_url": "https://huggingface.co/ExponentialML/DynamiCrafterUNet/resolve/main/dynamicrafter_unet_512.safetensors",
        "vae_url": "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.safetensors",
        "clip_url": "https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/text_encoder/model.safetensors",
        "clip_vision_url": "https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_model.safetensors",
    },
    "dynamicrafter_unet_512_interp (2.98GB)": {
        "model_url": "https://huggingface.co/ExponentialML/DynamiCrafterUNet/resolve/main/dynamicrafter_unet_512_interp.safetensors"
    },
    "dynamicrafter_unet_1024 (2.98GB)": {
        "model_url": "https://huggingface.co/ExponentialML/DynamiCrafterUNet/resolve/main/dynamicrafter_unet_1024.safetensors"
    },
    "dynamicrafter_unet_256 (2.98GB)": {
        "model_url": "https://huggingface.co/ExponentialML/DynamiCrafterUNet/resolve/main/dynamicrafter_unet_256.safetensors"
    },
}

#humanParsing
# Human-parsing / segmentation models.  Entries use either a direct file
# "model_url" (ONNX) or a HuggingFace repo id under "model_name".
HUMANPARSING_MODELS = {
    "parsing_lip": {
        "model_url": "https://huggingface.co/levihsu/OOTDiffusion/resolve/main/checkpoints/humanparsing/parsing_lip.onnx",
    },
    "human-parts":{
        "model_url":"https://huggingface.co/Metal3d/deeplabv3p-resnet50-human/resolve/main/deeplabv3p-resnet50-human.onnx",
    },
    "segformer_b3_clothes":{
        "model_name": "sayeed99/segformer_b3_clothes",
    },
    "segformer_b3_fashion":{
        "model_name": "sayeed99/segformer-b3-fashion",
    },
    "face_parsing":{
        "model_name": "jonathandinu/face-parsing"
    }
}

#mediapipe
MEDIAPIPE_DIR = os.path.join(folder_paths.models_dir, "mediapipe")
# MediaPipe selfie-segmentation model (.tflite).
MEDIAPIPE_MODELS = {
    "selfie_multiclass_256x256": {
        "model_url": "https://huggingface.co/yolain/selfie_multiclass_256x256/resolve/main/selfie_multiclass_256x256.tflite"
    }
}
|
||||
|
||||
|
||||
#prompt template
|
||||
PROMPT_TEMPLATE = {
|
||||
"prefix": ["Detailed photo of", "Amateur photo of", "Flicker 2008 photo of", "Fantastic artwork of",
|
||||
"Vintage photograph of", "Unreal 5 render of", "Surrealist painting of",
|
||||
"Professional advertising design of"],
|
||||
"subject": ["a man", "a woman", "a young man", "a young woman", "a handsome man", "a beautiful woman", "a monster", "a toy", "a product", "a buddha", "a dog", "a cat"],
|
||||
"action": ["looking at viewer", "looking away", "looking up", "looking down", "looking back", "open mouth", "half-closed mouth", "closed mouth", "open eyes", "half-closed eyes", "closed eyes", "wink", "standing", "sitting", "lying", "walking", "running", "adjusting hair", "waving", "hand on hip", "crossed arms", "smile", "sad", "angry", "sleepy", "tired", "expressionless"],
|
||||
"clothes": ["underwear", "clothed", "casual", "dress", "swimsuit", "uniform", "bikini", "one-piece swimsuit", "shirt", "blouse", "sweater", "hoodie", "jeans", "pants", "shorts", "skirt", "vest", "coat", "trenchoat", "jacket", "short dress", "long dress", "off-shoulder", "backless", "hairbow", "hair ribbon", "hair tie", "hairband", "cap", "beanie", "bucket hat", "sun hat", "straw hat", "rice hat", "witch hat", "crown", "chain necklace", "tooth necklace", "choker", "pendant", "bracelet", "watch", "ring", "earring", "anklet", "belt", "scarf", "gloves", "mittens", "socks", "stockings", "tights", "leggings", "boots", "sneakers", "heels", "sandals", "flip-flops", "slippers", "loafers", "mules", "oxfords", "brogues", "derbies", "monk shoes", "chelsea boots", "combat boots", "riding boots", "rain boots", "wedge heels", "platform heels", "stilettos", "block heels", "kitten heels", "moccasins", "espadrilles", "pumps", "flats", "ballet flats", "mary janes", "slingbacks", "peep-toe", "mule sandals", "gladiator sandals", "thong sandals", "slide sandals", "espadrille sandals", "wedge sandals", "platform sandals", "ankle boots", "knee-high boots", "over-the-knee boots", "thigh-high boots", "wellington boots", "chukka boots", "desert boots", "chelsea boots", "hiking boots", "work boots", "snow boots", "rain boots", "riding boots", "cowboy boots", "combat boots", "biker boots", "duck boots", "military boots", "western boots", "ankle strap heels", "block heels", "chunky heels", "cone heels", "kitten heels", "platform heels", "pumps", "slingback heels", "stiletto heels", "wedge heels", "mules", "slingbacks", "slides", "thong sandals", "gladiator sandals", "espadrilles", "wedge sandals", "platform sandals", "ankle boots", "knee-high boots", "over-the-knee boots", "thigh-high boots", "wellington boots", "chukka boots", "desert boots", "chelsea boots", "hiking boots", "work boots", "snow boots", "rain boots", "riding boots", "cowboy boots", "combat boots", "biker boots", "duck boots", 
"military boots", "western boots", "ankle strap heels", "block heels" ],
|
||||
"environment": ["sunshine from window", "neon night, city", "sunset over sea", "golden time", "sci-fi RGB glowing, cyberpunk", "natural lighting", "warm atmosphere, at home, bedroom", "magic lit", "evil, gothic, in a cave", "light and shadow", "shadow from window", "soft studio lighting", "home atmosphere, cozy bedroom illumination", "neon, Wong Kar-wai, warm", "moonlight through curtains", "stormy sky lighting", "underwater glow, deep sea", "foggy forest at dawn", "golden hour in a meadow", "rainbow reflections, neon", "cozy candlelight", "apocalyptic, smoky atmosphere", "red glow, emergency lights", "mystical glow, enchanted forest", "campfire light", "harsh, industrial lighting", "sunrise in the mountains", "evening glow in the desert", "moonlight in a dark alley", "golden glow at a fairground", "midnight in the forest", "purple and pink hues at twilight", "foggy morning, muted light", "candle-lit room, rustic vibe", "fluorescent office lighting", "lightning flash in storm", "night, cozy warm light from fireplace", "ethereal glow, magical forest", "dusky evening on a beach", "afternoon light filtering through trees", "blue neon light, urban street", "red and blue police lights in rain", "aurora borealis glow, arctic landscape", "sunrise through foggy mountains", "golden hour on a city skyline", "mysterious twilight, heavy mist", "early morning rays, forest clearing", "colorful lantern light at festival", "soft glow through stained glass", "harsh spotlight in dark room", "mellow evening glow on a lake", "crystal reflections in a cave", "vibrant autumn lighting in a forest", "gentle snowfall at dusk", "hazy light of a winter morning", "soft, diffused foggy glow", "underwater luminescence", "rain-soaked reflections in city lights", "golden sunlight streaming through trees", "fireflies lighting up a summer night", "glowing embers from a forge", "dim candlelight in a gothic castle", "midnight sky with bright starlight", "warm sunset in a rural village", "flickering 
light in a haunted house", "desert sunset with mirage-like glow", "golden beams piercing through storm clouds"],
|
||||
"background": ["cars and people", "a cozy bed and a lamp", "a forest clearing with mist", "a bustling marketplace", "a quiet beach at dusk", "an old, cobblestone street", "a futuristic cityscape", "a tranquil lake with mountains", "a mysterious cave entrance", "bookshelves and plants in the background", "an ancient temple in ruins", "tall skyscrapers and neon signs", "a starry sky over a desert", "a bustling café", "rolling hills and farmland", "a modern living room with a fireplace", "an abandoned warehouse", "a picturesque mountain range", "a starry night sky", "the interior of a futuristic spaceship", "the cluttered workshop of an inventor", "the glowing embers of a bonfire", "a misty lake surrounded by trees", "an ornate palace hall", "a busy street market", "a vast desert landscape", "a peaceful library corner", "bustling train station", "a mystical, enchanted forest", "an underwater reef with colorful fish", "a quiet rural village", "a sandy beach with palm trees", "a vibrant coral reef, teeming with life", "snow-capped mountains in distance", "a stormy ocean, waves crashing", "a rustic barn in open fields", "a futuristic lab with glowing screens", "a dark, abandoned castle", "the ruins of an ancient civilization", "a bustling urban street in rain", "an elegant grand ballroom", "a sprawling field of wildflowers", "a dense jungle with sunlight filtering through", "a dimly lit, vintage bar", "an ice cave with sparkling crystals", "a serene riverbank at sunset", "a narrow alley with graffiti walls", "a peaceful zen garden with koi pond", "a high-tech control room", "a quiet mountain village at dawn", "a lighthouse on a rocky coast", "a rainy street with flickering lights", "a frozen lake with ice formations", "an abandoned theme park", "a small fishing village on a pier", "rolling sand dunes in a desert", "a dense forest with towering redwoods", "a snowy cabin in the mountains", "a mystical cave with bioluminescent plants", "a castle courtyard under moonlight", 
"a bustling open-air night market", "an old train station with steam", "a tranquil waterfall surrounded by trees", "a vineyard in the countryside", "a quaint medieval village", "a bustling harbor with boats", "a high-tech futuristic mall", "a lush tropical rainforest"],
|
||||
"nsfw": ["nude", "breast", "small breast", "middle breast", "large breast", "nipples", "clothes lift", "pussy juice trail", "pussy juice puddle", "small testicles", "medium testicles", "large testicles", "disembodied penis", "cum on body", "cum inside", "cum outside", "fingering", "handjob", "fellatio", "licking penis", "paizuri", "doggystyle", "cowgirl", "reversed cowgirl", "piledriver", "suspended congress", "full nelson",],
|
||||
}
|
||||
|
||||
# Extra scheduler names this node pack offers on top of ComfyUI's built-ins.
NEW_SCHEDULERS = ['align_your_steps', 'gits']
|
||||
113
custom_nodes/ComfyUI-Easy-Use/py/libs/add_resources.py
Normal file
113
custom_nodes/ComfyUI-Easy-Use/py/libs/add_resources.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import urllib.parse
|
||||
from os import PathLike
|
||||
from aiohttp import web
|
||||
from aiohttp.web_urldispatcher import AbstractRoute, UrlDispatcher
|
||||
from server import PromptServer
|
||||
from pathlib import Path
|
||||
|
||||
# Maximum size (in MB) of a file the limited static routes will serve.
max_size = 50
|
||||
def suffix_limiter(self: web.StaticResource, request: web.Request):
    """Reject requests for files whose extension is not an allowed image type.

    Raises:
        web.HTTPForbidden: for absolute paths (which could escape the served
            directory) or for existing files with a disallowed suffix.

    Missing files fall through so the resource's normal 404 handling applies.
    (The original wrapped the body in a no-op ``try/.../finally: pass``;
    removed — it neither caught nor cleaned up anything.)
    """
    suffixes = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".svg", ".ico", ".apng", ".tif", ".hdr", ".exr"}
    rel_url = request.match_info["filename"]
    filename = Path(rel_url)
    if filename.anchor:
        # An anchored (absolute) path must never be resolved against the root.
        raise web.HTTPForbidden()
    filepath = self._directory.joinpath(filename).resolve()
    if filepath.exists() and filepath.suffix.lower() not in suffixes:
        raise web.HTTPForbidden(reason="File type is not allowed")
|
||||
|
||||
def filesize_limiter(self: web.StaticResource, request: web.Request):
    """Reject requests for files larger than ``max_size`` megabytes.

    Raises:
        web.HTTPForbidden: when the resolved file exists and exceeds the limit.

    (The original wrapped the body in a no-op ``try/.../finally: pass``;
    removed — it neither caught nor cleaned up anything.)
    """
    rel_url = request.match_info["filename"]
    filepath = self._directory.joinpath(Path(rel_url)).resolve()
    if filepath.exists() and filepath.stat().st_size > max_size * 1024 * 1024:
        raise web.HTTPForbidden(reason="File size is too large")
|
||||
class LimitResource(web.StaticResource):
    """Static-file resource that runs a chain of request limiters before serving.

    Each limiter is a callable ``(resource, request)`` that raises an aiohttp
    HTTP error (e.g. ``HTTPForbidden``) to reject the request.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Fix: ``limiters`` was previously a *class* attribute, so every
        # LimitResource instance shared (and kept appending to) one list —
        # limiters accumulated across all registered routes.  Per-instance now.
        self.limiters = []

    def push_limiter(self, limiter):
        """Append *limiter* to this resource's validation chain."""
        self.limiters.append(limiter)

    async def _handle(self, request: web.Request) -> web.StreamResponse:
        """Run all limiters, then delegate to StaticResource's handler."""
        try:
            for limiter in self.limiters:
                limiter(self, request)
        except (ValueError, FileNotFoundError) as error:
            # Malformed/unresolvable paths become a 404 instead of a 500.
            raise web.HTTPNotFound() from error

        return await super()._handle(request)

    def __repr__(self) -> str:
        name = "'" + self.name + "'" if self.name is not None else ""
        return f'<LimitResource {name} {self._prefix} -> {self._directory!r}>'
|
||||
|
||||
class LimitRouter(web.StaticDef):
    """A static-route definition whose registration installs a LimitResource
    (with the suffix and file-size limiters) instead of aiohttp's plain
    StaticResource.
    """

    def __repr__(self) -> str:
        info = []
        for name, value in sorted(self.kwargs.items()):
            info.append(f", {name}={value!r}")
        return f'<LimitRouter {self.prefix} -> {self.path}{"".join(info)}>'

    def register(self, router: UrlDispatcher) -> list[AbstractRoute]:
        """Register this static route on *router*, wiring in the limiters.

        The nested ``add_static`` mirrors ``UrlDispatcher.add_static`` (same
        signature and defaults) but constructs a LimitResource; it cannot use
        the stock method because the resource class is hard-wired there.
        """
        # resource = router.add_static(self.prefix, self.path, **self.kwargs)
        def add_static(
            self: UrlDispatcher,
            prefix: str,
            path: PathLike,
            *,
            name=None,
            expect_handler=None,
            chunk_size: int = 256 * 1024,
            show_index: bool = False,
            follow_symlinks: bool = False,
            append_version: bool = False,
        ) -> web.AbstractResource:
            assert prefix.startswith("/")
            if prefix.endswith("/"):
                prefix = prefix[:-1]
            resource = LimitResource(
                prefix,
                path,
                name=name,
                expect_handler=expect_handler,
                chunk_size=chunk_size,
                show_index=show_index,
                follow_symlinks=follow_symlinks,
                append_version=append_version,
            )
            # Every limited resource screens by file extension and by size.
            resource.push_limiter(suffix_limiter)
            resource.push_limiter(filesize_limiter)
            self.register_resource(resource)
            return resource
        resource = add_static(router, self.prefix, self.path, **self.kwargs)
        routes = resource.get_info().get("routes", {})
        return list(routes.values())
|
||||
|
||||
def path_to_url(path):
    """Normalize a filesystem-style path into a URL path.

    Backslashes become forward slashes, a leading slash is guaranteed,
    runs of leading slashes are collapsed to one, and one pass of
    internal ``//`` -> ``/`` replacement is applied.  Empty input is
    returned unchanged.
    """
    if not path:
        return path
    url = path.replace("\\", "/")
    if not url.startswith("/"):
        url = "/" + url
    while url.startswith("//"):
        url = url[1:]
    return url.replace("//", "/")
|
||||
|
||||
def add_static_resource(prefix, path,limit=False):
    """Mount *path* as static files at URL *prefix* on the running PromptServer.

    When *limit* is true the route is served through LimitRouter, which
    enforces the image-suffix and file-size limiters; otherwise a plain
    aiohttp static route is used.
    """
    app = PromptServer.instance.app
    prefix = path_to_url(prefix)
    # Percent-encode the prefix, then re-normalize: quoting can reintroduce
    # sequences the URL form should not keep.
    prefix = urllib.parse.quote(prefix)
    prefix = path_to_url(prefix)
    if limit:
        route = LimitRouter(prefix, path, {"follow_symlinks": True})
    else:
        route = web.static(prefix, path, follow_symlinks=True)
    app.add_routes([route])
|
||||
427
custom_nodes/ComfyUI-Easy-Use/py/libs/adv_encode.py
Normal file
427
custom_nodes/ComfyUI-Easy-Use/py/libs/adv_encode.py
Normal file
@@ -0,0 +1,427 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import re
|
||||
import itertools
|
||||
|
||||
from comfy import model_management
|
||||
from comfy.sdxl_clip import SDXLClipModel, SDXLRefinerClipModel, SDXLClipG
|
||||
try:
|
||||
from comfy.text_encoders.sd3_clip import SD3ClipModel, T5XXLModel
|
||||
except ImportError:
|
||||
from comfy.sd3_clip import SD3ClipModel, T5XXLModel
|
||||
|
||||
from nodes import NODE_CLASS_MAPPINGS, ConditioningConcat, ConditioningZeroOut, ConditioningSetTimestepRange, ConditioningCombine
|
||||
|
||||
def _grouper(n, iterable):
|
||||
it = iter(iterable)
|
||||
while True:
|
||||
chunk = list(itertools.islice(it, n))
|
||||
if not chunk:
|
||||
return
|
||||
yield chunk
|
||||
|
||||
|
||||
def _norm_mag(w, n):
|
||||
d = w - 1
|
||||
return 1 + np.sign(d) * np.sqrt(np.abs(d) ** 2 / n)
|
||||
# return np.sign(w) * np.sqrt(np.abs(w)**2 / n)
|
||||
|
||||
|
||||
def divide_length(word_ids, weights):
    """Rescale token weights by the token count of the word each belongs to.

    Tokens with word id 0 (padding/specials) keep weight 1.0; all others are
    renormalized via _norm_mag over their word's token count.
    """
    ids, counts = np.unique(word_ids, return_counts=True)
    tokens_per_word = dict(zip(ids, counts))
    tokens_per_word[0] = 1  # id 0 is never length-normalized
    return [
        [1.0 if wid == 0 else _norm_mag(w, tokens_per_word[wid])
         for w, wid in zip(row_w, row_id)]
        for row_w, row_id in zip(weights, word_ids)
    ]
|
||||
|
||||
|
||||
def shift_mean_weight(word_ids, weights):
    """Shift all real-token weights by a constant so their mean becomes 1.0.

    Tokens with word id 0 are excluded from the mean and left unshifted.
    """
    real_weights = [
        w
        for row_w, row_id in zip(weights, word_ids)
        for w, wid in zip(row_w, row_id)
        if wid != 0
    ]
    delta = 1 - np.mean(real_weights)
    return [
        [w if wid == 0 else w + delta for w, wid in zip(row_w, row_id)]
        for row_w, row_id in zip(weights, word_ids)
    ]
|
||||
|
||||
|
||||
def scale_to_norm(weights, word_ids, w_max):
    """Linearly rescale weights so the largest maps to min(max(weights), w_max).

    Tokens with word id 0 are pinned to the ceiling value itself.
    """
    peak = np.max(weights)
    ceiling = min(peak, w_max)
    return [
        [ceiling if wid == 0 else (w / peak) * ceiling
         for w, wid in zip(row_w, row_id)]
        for row_w, row_id in zip(weights, word_ids)
    ]
|
||||
|
||||
|
||||
def from_zero(weights, base_emb):
    """Scale each token position of *base_emb* by its per-token prompt weight."""
    w = torch.tensor(weights, dtype=base_emb.dtype, device=base_emb.device)
    return base_emb * w.reshape(1, -1, 1).expand(base_emb.shape)
|
||||
|
||||
|
||||
def mask_word_id(tokens, word_ids, target_id, mask_token):
    """Replace every token belonging to *target_id* with *mask_token*.

    Returns the masked token rows and a boolean numpy mask of the replaced
    positions.
    """
    masked = []
    for toks, ids in zip(tokens, word_ids):
        masked.append([mask_token if wid == target_id else tok
                       for tok, wid in zip(toks, ids)])
    return masked, np.array(word_ids) == target_id
|
||||
|
||||
|
||||
def batched_clip_encode(tokens, length, encode_func, num_chunks):
    """Encode token rows through *encode_func* in batches of 32.

    The per-row encodings are then regrouped so every *num_chunks* consecutive
    rows form one prompt of shape (length * num_chunks, dim).
    """
    pieces = []
    for batch in _grouper(32, tokens):
        encoded, _pooled = encode_func(batch)
        pieces.append(encoded.reshape((len(batch), length, -1)))
    stacked = torch.cat(pieces)
    return stacked.reshape((len(tokens) // num_chunks, length * num_chunks, -1))
|
||||
|
||||
|
||||
def from_masked(tokens, weights, word_ids, base_emb, length, encode_func, m_token=266):
    """Compute a delta embedding for non-neutral word weights via masking.

    For every word whose weight differs from 1.0, the word's tokens are
    replaced with *m_token* and the prompt re-encoded; the difference to the
    base embedding, scaled by (weight - 1), yields the contribution of that
    word. Returns ``(delta_emb, pooled)``.
    """
    # Pooled vector of the base prompt: the embedding at the last position.
    pooled_base = base_emb[0, length - 1:length, :]
    # One representative weight per distinct word id (first occurrence).
    wids, inds = np.unique(np.array(word_ids).reshape(-1), return_index=True)
    weight_dict = dict((id, w)
                       for id, w in zip(wids, np.array(weights).reshape(-1)[inds])
                       if w != 1.0)

    # All weights neutral: the delta contribution is zero.
    if len(weight_dict) == 0:
        return torch.zeros_like(base_emb), base_emb[0, length - 1:length, :]

    weight_tensor = torch.tensor(weights, dtype=base_emb.dtype, device=base_emb.device)
    weight_tensor = weight_tensor.reshape(1, -1, 1).expand(base_emb.shape)

    # m_token = (clip.tokenizer.end_token, 1.0) if clip.tokenizer.pad_with_end else (0,1.0)
    # TODO: find most suitable masking token here
    m_token = (m_token, 1.0)

    ws = []
    masked_tokens = []
    masks = []

    # create prompts: one fully-masked variant per weighted word
    for id, w in weight_dict.items():
        masked, m = mask_word_id(tokens, word_ids, id, m_token)
        masked_tokens.extend(masked)

        m = torch.tensor(m, dtype=base_emb.dtype, device=base_emb.device)
        m = m.reshape(1, -1, 1).expand(base_emb.shape)
        masks.append(m)

        ws.append(w)

    # batch process prompts
    embs = batched_clip_encode(masked_tokens, length, encode_func, len(tokens))
    masks = torch.cat(masks)

    # Difference between the base prompt and each masked variant.
    embs = (base_emb.expand(embs.shape) - embs)
    # NOTE(review): only the first masked variant's last position is used as
    # the pooled delta here — verify this is intentional.
    pooled = embs[0, length - 1:length, :]

    # Keep only the masked positions, then sum the per-word contributions.
    embs *= masks
    embs = embs.sum(axis=0, keepdim=True)

    pooled_start = pooled_base.expand(len(ws), -1)
    ws = torch.tensor(ws).reshape(-1, 1).expand(pooled_start.shape)
    pooled = (pooled - pooled_start) * (ws - 1)
    pooled = pooled.mean(axis=0, keepdim=True)

    return ((weight_tensor - 1) * embs), pooled_base + pooled
|
||||
|
||||
|
||||
def mask_inds(tokens, inds, mask_token):
    """Replace tokens at the given flat indices with *mask_token*.

    Indices are interpreted row-major over the (assumed rectangular) token
    rows.
    """
    row_len = len(tokens[0])
    targets = set(inds)
    out = []
    for row, chunk in enumerate(tokens):
        base = row * row_len
        out.append([mask_token if base + col in targets else tok
                    for col, tok in enumerate(chunk)])
    return out
|
||||
|
||||
|
||||
def down_weight(tokens, weights, word_ids, base_emb, length, encode_func, m_token=266):
    """Apply weights below 1.0 by progressively masking tokens and blending.

    Tokens are masked in ascending weight order; the resulting embeddings are
    mixed with the base embedding using the differences between consecutive
    weight levels. Returns ``(weighted_emb, final_masked_tokens, pooled)``.
    """
    # Distinct weight values (sorted ascending) and the inverse index map.
    w, w_inv = np.unique(weights, return_inverse=True)

    # No weight below 1.0: nothing to down-weight.
    if np.sum(w < 1) == 0:
        return base_emb, tokens, base_emb[0, length - 1:length, :]
    # m_token = (clip.tokenizer.end_token, 1.0) if clip.tokenizer.pad_with_end else (0,1.0)
    # using the comma token as a masking token seems to work better than aos tokens for SD 1.x
    m_token = (m_token, 1.0)

    masked_tokens = []

    # Mask cumulatively: each lower weight level keeps the previous masks.
    masked_current = tokens
    for i in range(len(w)):
        if w[i] >= 1:
            continue
        masked_current = mask_inds(masked_current, np.where(w_inv == i)[0], m_token)
        masked_tokens.extend(masked_current)

    embs = batched_clip_encode(masked_tokens, length, encode_func, len(tokens))
    # Prepend the unmasked base embedding as the top weight level.
    embs = torch.cat([base_emb, embs])
    w = w[w <= 1.0]
    # Mixing coefficients: successive differences of the weight levels.
    w_mix = np.diff([0] + w.tolist())
    w_mix = torch.tensor(w_mix, dtype=embs.dtype, device=embs.device).reshape((-1, 1, 1))

    # Weighted blend over all mask levels.
    weighted_emb = (w_mix * embs).sum(axis=0, keepdim=True)
    return weighted_emb, masked_current, weighted_emb[0, length - 1:length, :]
|
||||
|
||||
|
||||
def scale_emb_to_mag(base_emb, weighted_emb):
    """Rescale *weighted_emb* so its L2 norm matches that of *base_emb*."""
    ratio = torch.linalg.norm(base_emb) / torch.linalg.norm(weighted_emb)
    return ratio * weighted_emb
|
||||
|
||||
|
||||
def recover_dist(base_emb, weighted_emb):
    """Give *weighted_emb* the mean and standard deviation of *base_emb*."""
    rescaled = (base_emb.std() / weighted_emb.std()) * (weighted_emb - weighted_emb.mean())
    return rescaled + (base_emb.mean() - rescaled.mean())
|
||||
|
||||
|
||||
def A1111_renorm(base_emb, weighted_emb):
    """Rescale *weighted_emb* so its mean matches *base_emb*'s (A1111 behaviour)."""
    return (base_emb.mean() / weighted_emb.mean()) * weighted_emb
|
||||
|
||||
|
||||
def advanced_encode_from_tokens(tokenized, token_normalization, weight_interpretation, encode_func, m_token=266,
                                length=77, w_max=1.0, return_pooled=False, apply_to_pooled=False):
    """Encode pre-tokenized text applying a weight normalization and one of
    several weight-interpretation schemes (comfy / A1111 / compel / comfy++ /
    down_weight).

    *tokenized* rows contain ``(token, weight, word_id)`` triples. Returns
    ``(weighted_emb, pooled_or_None)``.
    """
    tokens = [[t for t, _, _ in x] for x in tokenized]
    weights = [[w for _, w, _ in x] for x in tokenized]
    word_ids = [[wid for _, _, wid in x] for x in tokenized]

    # weight normalization
    # ====================

    # distribute down/up weights over word lengths
    if token_normalization.startswith("length"):
        weights = divide_length(word_ids, weights)

    # make mean of word tokens 1
    if token_normalization.endswith("mean"):
        weights = shift_mean_weight(word_ids, weights)

    # weight interpretation
    # =====================
    pooled = None

    if weight_interpretation == "comfy":
        # Native ComfyUI behaviour: pass the weights straight to the encoder.
        weighted_tokens = [[(t, w) for t, w in zip(x, y)] for x, y in zip(tokens, weights)]
        weighted_emb, pooled_base = encode_func(weighted_tokens)
        pooled = pooled_base
    else:
        # All other schemes start from an unweighted encoding.
        unweighted_tokens = [[(t, 1.0) for t, _, _ in x] for x in tokenized]
        base_emb, pooled_base = encode_func(unweighted_tokens)

        if weight_interpretation == "A1111":
            weighted_emb = from_zero(weights, base_emb)
            weighted_emb = A1111_renorm(base_emb, weighted_emb)
            pooled = pooled_base

        if weight_interpretation == "compel":
            # Up-weights applied directly; down-weights via masking.
            pos_tokens = [[(t, w) if w >= 1.0 else (t, 1.0) for t, w in zip(x, y)] for x, y in zip(tokens, weights)]
            weighted_emb, _ = encode_func(pos_tokens)
            weighted_emb, _, pooled = down_weight(pos_tokens, weights, word_ids, weighted_emb, length, encode_func)

        if weight_interpretation == "comfy++":
            # Down-weights via masking, then up-weights via masked deltas.
            weighted_emb, tokens_down, _ = down_weight(unweighted_tokens, weights, word_ids, base_emb, length, encode_func)
            weights = [[w if w > 1.0 else 1.0 for w in x] for x in weights]
            # unweighted_tokens = [[(t,1.0) for t, _,_ in x] for x in tokens_down]
            embs, pooled = from_masked(unweighted_tokens, weights, word_ids, base_emb, length, encode_func)
            weighted_emb += embs

        if weight_interpretation == "down_weight":
            weights = scale_to_norm(weights, word_ids, w_max)
            weighted_emb, _, pooled = down_weight(unweighted_tokens, weights, word_ids, base_emb, length, encode_func)

    if return_pooled:
        if apply_to_pooled:
            return weighted_emb, pooled
        else:
            return weighted_emb, pooled_base
    return weighted_emb, None
|
||||
|
||||
|
||||
def encode_token_weights_g(model, token_weight_pairs):
    # Encode through the model's CLIP-G branch.
    return model.clip_g.encode_token_weights(token_weight_pairs)
|
||||
|
||||
|
||||
def encode_token_weights_l(model, token_weight_pairs):
    # Encode through the model's CLIP-L branch; returns (embedding, pooled).
    l_out, pooled = model.clip_l.encode_token_weights(token_weight_pairs)
    return l_out, pooled
|
||||
|
||||
def encode_token_weights_t5(model, token_weight_pairs):
    # Encode through the model's T5-XXL branch (used by SD3).
    return model.t5xxl.encode_token_weights(token_weight_pairs)
|
||||
|
||||
|
||||
def encode_token_weights(model, token_weight_pairs, encode_func):
    """Select the clip layer, load the model to GPU, then run *encode_func*."""
    if model.layer_idx is not None:
        # ComfyUI revision 2016 [c2cb8e88] and later removed the clip_layer
        # method from the SDXL clip; use set_clip_options instead.
        # if compare_revision(2016):
        model.cond_stage_model.set_clip_options({'layer': model.layer_idx})
        # else:
        #     model.cond_stage_model.clip_layer(model.layer_idx)

    model_management.load_model_gpu(model.patcher)
    return encode_func(model.cond_stage_model, token_weight_pairs)
|
||||
|
||||
def prepareXL(embs_l, embs_g, pooled, clip_balance):
    """Blend SDXL CLIP-L and CLIP-G embeddings according to *clip_balance*.

    A balance of 0.5 keeps both at full strength; moving toward 1.0 fades
    out CLIP-L, toward 0.0 fades out CLIP-G.
    """
    weight_l = 1 - max(0, clip_balance - .5) * 2
    weight_g = 1 - max(0, .5 - clip_balance) * 2
    if embs_l is None:
        return embs_g, pooled
    blended = torch.cat([embs_l * weight_l, embs_g * weight_g], dim=-1)
    return blended, pooled
|
||||
|
||||
def prepareSD3(out, pooled, clip_balance):
    """Blend the CLIP-L/G part and the T5 part of an SD3 encoding.

    When *out* has more than one leading entry, entry 0 is treated as the
    CLIP (l+g) embedding and entry 1 as the T5 embedding.
    """
    weight_lg = 1 - max(0, clip_balance - .5) * 2
    weight_t5 = 1 - max(0, .5 - clip_balance) * 2
    if out.shape[0] <= 1:
        return out, pooled
    merged = torch.cat([out[0] * weight_lg, out[1] * weight_t5], dim=-1)
    return merged, pooled
|
||||
|
||||
def advanced_encode(clip, text, token_normalization, weight_interpretation, w_max=1.0, clip_balance=.5,
                    apply_to_pooled=True, width=1024, height=1024, crop_w=0, crop_h=0, target_width=1024, target_height=1024, a1111_prompt_style=False, steps=1):
    """Encode *text* into conditioning, handling BREAK segments, TIMESTEP
    ranges and SD1/SDXL/SD3 clip models.

    Returns a ComfyUI conditioning list. When *a1111_prompt_style* is set the
    encode is delegated to the smzNodes custom node instead.
    """

    # Use clip text encode by smzNodes like same as a1111, when if you need installed the smzNodes
    if a1111_prompt_style:
        if "smZ CLIPTextEncode" in NODE_CLASS_MAPPINGS:
            cls = NODE_CLASS_MAPPINGS['smZ CLIPTextEncode']
            embeddings_final, = cls().encode(clip, text, weight_interpretation, True, True, False, False, 6, 1024, 1024, 0, 0, 1024, 1024, '', '', steps)
            return embeddings_final
        else:
            raise Exception(f"[smzNodes Not Found] you need to install 'ComfyUI-smzNodes'")

    # Optional "TIMESTEP[:start[:end]]" directive embedded in the prompt.
    time_start = 0
    time_end = 1
    match = re.search(r'TIMESTEP.*$', text)
    if match:
        timestep = match.group()
        timestep = timestep.split(' ')
        timestep = timestep[0]
        text = text.replace(timestep, '')
        value = timestep.split(':')
        if len(value) >= 3:
            time_start = float(value[1])
            time_end = float(value[2])
        elif len(value) == 2:
            time_start = float(value[1])
            time_end = 1
        elif len(value) == 1:
            # Bare "TIMESTEP" with no range: default to [0.1, 1].
            time_start = 0.1
            time_end = 1

    # Split the prompt on BREAK; each segment is encoded separately.
    pass3 = [x.strip() for x in text.split("BREAK")]
    pass3 = [x for x in pass3 if x != '']

    if len(pass3) == 0:
        pass3 = ['']

    # pass3_str = [f'[{x}]' for x in pass3]
    # print(f"CLIP: {str.join(' + ', pass3_str)}")

    conditioning = None

    for text in pass3:
        tokenized = clip.tokenize(text, return_word_ids=True)
        if SD3ClipModel and isinstance(clip.cond_stage_model, SD3ClipModel):
            # --- SD3: combine CLIP-L, CLIP-G and T5-XXL branches ---
            lg_out = None
            pooled = None
            out = None

            if len(tokenized['l']) > 0 or len(tokenized['g']) > 0:
                if clip.cond_stage_model.clip_l is not None:
                    lg_out, l_pooled = advanced_encode_from_tokens(tokenized['l'],
                                                                   token_normalization,
                                                                   weight_interpretation,
                                                                   lambda x: encode_token_weights(clip, x, encode_token_weights_l),
                                                                   w_max=w_max, return_pooled=True,)
                else:
                    l_pooled = torch.zeros((1, 768), device=model_management.intermediate_device())

                if clip.cond_stage_model.clip_g is not None:
                    g_out, g_pooled = advanced_encode_from_tokens(tokenized['g'],
                                                                  token_normalization,
                                                                  weight_interpretation,
                                                                  lambda x: encode_token_weights(clip, x, encode_token_weights_g),
                                                                  w_max=w_max, return_pooled=True)
                    if lg_out is not None:
                        lg_out = torch.cat([lg_out, g_out], dim=-1)
                    else:
                        # No CLIP-L output: left-pad with 768 zero channels.
                        lg_out = torch.nn.functional.pad(g_out, (768, 0))
                else:
                    g_out = None
                    g_pooled = torch.zeros((1, 1280), device=model_management.intermediate_device())

                if lg_out is not None:
                    # Right-pad the channel dim up to T5's 4096 width.
                    lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
                    out = lg_out
                pooled = torch.cat((l_pooled, g_pooled), dim=-1)

            # t5xxl
            if 't5xxl' in tokenized:
                t5_out, t5_pooled = advanced_encode_from_tokens(tokenized['t5xxl'],
                                                                token_normalization,
                                                                weight_interpretation,
                                                                lambda x: encode_token_weights(clip, x, encode_token_weights_t5),
                                                                w_max=w_max, return_pooled=True)
                if lg_out is not None:
                    out = torch.cat([lg_out, t5_out], dim=-2)
                else:
                    out = t5_out

            if out is None:
                out = torch.zeros((1, 77, 4096), device=model_management.intermediate_device())

            if pooled is None:
                pooled = torch.zeros((1, 768 + 1280), device=model_management.intermediate_device())

            embeddings_final, pooled = prepareSD3(out, pooled, clip_balance)
            cond = [[embeddings_final, {"pooled_output": pooled}]]

        elif isinstance(clip.cond_stage_model, (SDXLClipModel, SDXLRefinerClipModel, SDXLClipG)):
            # --- SDXL: combine CLIP-L and CLIP-G branches ---
            embs_l = None
            embs_g = None
            pooled = None
            if 'l' in tokenized and isinstance(clip.cond_stage_model, SDXLClipModel):
                embs_l, _ = advanced_encode_from_tokens(tokenized['l'],
                                                        token_normalization,
                                                        weight_interpretation,
                                                        lambda x: encode_token_weights(clip, x, encode_token_weights_l),
                                                        w_max=w_max,
                                                        return_pooled=False)
            if 'g' in tokenized:
                embs_g, pooled = advanced_encode_from_tokens(tokenized['g'],
                                                             token_normalization,
                                                             weight_interpretation,
                                                             lambda x: encode_token_weights(clip, x,
                                                                                            encode_token_weights_g),
                                                             w_max=w_max,
                                                             return_pooled=True,
                                                             apply_to_pooled=apply_to_pooled)

            embeddings_final, pooled = prepareXL(embs_l, embs_g, pooled, clip_balance)

            cond = [[embeddings_final, {"pooled_output": pooled}]]
            # cond = [[embeddings_final,
            #          {"pooled_output": pooled, "width": width, "height": height, "crop_w": crop_w,
            #           "crop_h": crop_h, "target_width": target_width, "target_height": target_height}]]
        else:
            # --- SD1.x / plain CLIP-L models ---
            embeddings_final, pooled = advanced_encode_from_tokens(tokenized['l'],
                                                                   token_normalization,
                                                                   weight_interpretation,
                                                                   lambda x: encode_token_weights(clip, x, encode_token_weights_l),
                                                                   w_max=w_max, return_pooled=True,)
            cond = [[embeddings_final, {"pooled_output": pooled}]]

        # Concatenate BREAK segments into one conditioning.
        if conditioning is not None:
            conditioning = ConditioningConcat().concat(conditioning, cond)[0]
        else:
            conditioning = cond

    # setTimeStepRange
    if time_start > 0 or time_end < 1:
        conditioning_2, = ConditioningSetTimestepRange().set_range(conditioning, 0, time_start)
        conditioning_1, = ConditioningZeroOut().zero_out(conditioning)
        conditioning_1, = ConditioningSetTimestepRange().set_range(conditioning_1, time_start, time_end)
        conditioning, = ConditioningCombine().combine(conditioning_1, conditioning_2)

    return conditioning
|
||||
|
||||
|
||||
|
||||
372
custom_nodes/ComfyUI-Easy-Use/py/libs/api/bizyair.py
Normal file
372
custom_nodes/ComfyUI-Easy-Use/py/libs/api/bizyair.py
Normal file
@@ -0,0 +1,372 @@
|
||||
import yaml
|
||||
import pathlib
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import zlib
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
from enum import Enum
|
||||
from functools import singledispatch
|
||||
from typing import Any, List, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
# Repository root: four directory levels up from py/libs/api/bizyair.py.
root_path = pathlib.Path(__file__).parent.parent.parent.parent
# Shared plugin configuration file (holds the API keys).
config_path = os.path.join(root_path, 'config.yaml')
|
||||
|
||||
class BizyAIRAPI:
    """Minimal client for the BizyAir (SiliconFlow) HTTP API."""

    def __init__(self):
        # Base endpoint for all BizyAir calls.
        self.base_url = 'https://bizyair-api.siliconflow.cn/x/v1'
        # API key, lazily loaded from config.yaml on first use.
        self.api_key = None

    def getAPIKey(self):
        """Return the cached key, loading BIZYAIR_API_KEY from config.yaml.

        Raises if the config file or the key entry is missing.
        """
        if self.api_key is None:
            if os.path.isfile(config_path):
                with open(config_path, 'r') as f:
                    data = yaml.load(f, Loader=yaml.FullLoader)
                    if 'BIZYAIR_API_KEY' not in data:
                        raise Exception("Please add BIZYAIR_API_KEY to config.yaml")
                    self.api_key = data['BIZYAIR_API_KEY']
            else:
                raise Exception("Please add config.yaml to root path")
        return self.api_key

    def send_post_request(self, url, payload, headers):
        """POST *payload* as JSON to *url* and return the decoded body text.

        URL errors are rewrapped into user-facing exceptions.
        """
        try:
            data = json.dumps(payload).encode("utf-8")
            req = urllib.request.Request(url, data=data, headers=headers, method="POST")
            with urllib.request.urlopen(req) as response:
                response_data = response.read().decode("utf-8")
                return response_data
        except urllib.error.URLError as e:
            # 401-style failures surface as "Unauthorized" in the error text.
            if "Unauthorized" in str(e):
                raise Exception(
                    "Key is invalid, please refer to https://cloud.siliconflow.cn to get the API key.\n"
                    "If you have the key, please click the 'BizyAir Key' button at the bottom right to set the key."
                )
            else:
                raise Exception(
                    f"Failed to connect to the server: {e}, if you have no key, "
                )

    # joycaption
    def joyCaption(self, payload, image, apikey_override=None, API_URL='/supernode/joycaption2'):
        """Call the JoyCaption2 supernode to caption *image*; returns the caption.

        *apikey_override* bypasses the config-file key lookup.
        """
        if apikey_override is not None:
            api_key = apikey_override
        else:
            api_key = self.getAPIKey()
        url = f"{self.base_url}{API_URL}"
        print('Sending request to:', url)
        auth = f"Bearer {api_key}"
        headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "authorization": auth,
        }
        # Image is sent as a raw TENSOR payload (marker disabled).
        input_image = encode_data(image, disable_image_marker=True)
        payload["image"] = input_image

        ret: str = self.send_post_request(url=url, payload=payload, headers=headers)
        ret = json.loads(ret)

        # Some responses wrap the actual JSON in a "result" string field.
        try:
            if "result" in ret:
                ret = json.loads(ret["result"])
        except Exception as e:
            raise Exception(f"Unexpected response: {ret} {e=}")

        if ret["type"] == "error":
            raise Exception(ret["message"])

        msg = ret["data"]
        if msg["type"] not in ("comfyair", "bizyair",):
            raise Exception(f"Unexpected response type: {msg}")

        caption = msg["data"]

        return caption
|
||||
|
||||
# Module-level singleton shared by the nodes that talk to BizyAir.
bizyairAPI = BizyAIRAPI()


# When True, payload sizes are printed during encode/decode.
BIZYAIR_DEBUG = True
# Marker to identify base64-encoded tensors
TENSOR_MARKER = "TENSOR:"
# Marker prefix for base64-encoded IMAGE batches.
IMAGE_MARKER = "IMAGE:"
|
||||
|
||||
|
||||
class TaskStatus(Enum):
    """Lifecycle states of a remote task."""
    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
|
||||
|
||||
|
||||
def convert_image_to_rgb(image: Image.Image) -> Image.Image:
    """Return *image* in RGB mode, converting only when necessary."""
    return image if image.mode == "RGB" else image.convert("RGB")
|
||||
|
||||
|
||||
def encode_image_to_base64(
    image: Image.Image, format: str = "png", quality: int = 100, lossless=False
) -> str:
    """Serialize a PIL image (coerced to RGB) and return it base64-encoded."""
    rgb = convert_image_to_rgb(image)
    with io.BytesIO() as buffer:
        rgb.save(buffer, format=format, quality=quality, lossless=lossless)
        raw = buffer.getvalue()
    if BIZYAIR_DEBUG:
        print(f"encode_image_to_base64: {format_bytes(len(raw))}")
    return base64.b64encode(raw).decode("utf-8")
|
||||
|
||||
|
||||
def decode_base64_to_np(img_data: str, format: str = "png") -> np.ndarray:
    """Decode a base64 image string into an RGB numpy array."""
    raw = base64.b64decode(img_data)
    if BIZYAIR_DEBUG:
        print(f"decode_base64_to_np: {format_bytes(len(raw))}")
    with io.BytesIO(raw) as buffer:
        # https://github.com/comfyanonymous/ComfyUI/blob/a178e25912b01abf436eba1cfaab316ba02d272d/nodes.py#L1511
        rgb = Image.open(buffer).convert("RGB")
        return np.array(rgb)
|
||||
|
||||
|
||||
def decode_base64_to_image(img_data: str) -> Image.Image:
    """Decode a base64 string into a fully-loaded PIL image.

    :param img_data: base64-encoded image bytes.
    :return: the decoded :class:`PIL.Image.Image`.
    """
    img_bytes = base64.b64decode(img_data)
    with io.BytesIO(img_bytes) as input_buffer:
        img = Image.open(input_buffer)
        if BIZYAIR_DEBUG:
            format_info = img.format.upper() if img.format else "Unknown"
            print(f"decode image format: {format_info}")
        # BUGFIX: Image.open is lazy — force the pixel data to be read now,
        # because the backing buffer is closed when this `with` block exits
        # and later pixel access on the returned image would fail.
        img.load()
    return img
|
||||
|
||||
|
||||
def format_bytes(num_bytes: int) -> str:
    """
    Converts a number of bytes to a human-readable string with units (B, KB, or MB).

    :param num_bytes: The number of bytes to convert.
    :return: A string representing the number of bytes in a human-readable format.
    """
    kib = 1024
    mib = kib * 1024
    if num_bytes < kib:
        return f"{num_bytes} B"
    if num_bytes < mib:
        return f"{num_bytes / kib:.2f} KB"
    return f"{num_bytes / mib:.2f} MB"
|
||||
|
||||
|
||||
def _legacy_encode_comfy_image(image: torch.Tensor, image_format="png") -> str:
    """Encode only the first image of a batch (legacy single-image protocol)."""
    first = 255.0 * image.cpu().detach().numpy()[0]
    pixels = np.clip(first, 0, 255).astype(np.uint8)
    return encode_image_to_base64(Image.fromarray(pixels), format=image_format)
|
||||
|
||||
|
||||
def _legacy_decode_comfy_image(
    img_data: Union[List, str], image_format="png"
) -> torch.tensor:
    """Decode a legacy single-image payload (or a list of them) to a tensor batch."""
    if isinstance(img_data, List):
        # Recursively decode each element and stack along the batch dim.
        frames = [decode_comfy_image(part, old_version=True) for part in img_data]
        return torch.cat(frames, dim=0)

    pixels = decode_base64_to_np(img_data, format=image_format)
    pixels = np.array(pixels).astype(np.float32) / 255.0
    return torch.from_numpy(pixels)[None,]
|
||||
|
||||
|
||||
def _new_encode_comfy_image(images: torch.Tensor, image_format="WEBP", **kwargs) -> str:
    """https://docs.comfy.org/essentials/custom_node_snippets#save-an-image-batch
    Encode a batch of images to base64 strings.

    Args:
        images (torch.Tensor): A batch of images.
        image_format (str, optional): The format of the images. Defaults to "WEBP".

    Returns:
        str: A JSON string containing the base64-encoded images.
    """
    encoded = {}
    for index, image in enumerate(images):
        pixels = np.clip(255.0 * image.cpu().numpy(), 0, 255).astype(np.uint8)
        encoded[index] = encode_image_to_base64(
            Image.fromarray(pixels), format=image_format, **kwargs
        )
    return json.dumps(encoded)
|
||||
|
||||
|
||||
def _new_decode_comfy_image(img_datas: str, image_format="WEBP") -> torch.tensor:
    """
    Decode a batch of base64-encoded images.

    Args:
        img_datas (str): A JSON string containing the base64-encoded images.
        image_format (str, optional): The format of the images. Defaults to "WEBP".

    Returns:
        torch.Tensor: A tensor containing the decoded images.
    """
    entries = json.loads(img_datas)

    tensors = []
    for payload in entries.values():
        frame = decode_base64_to_np(payload, format=image_format)
        frame = np.array(frame).astype(np.float32) / 255.0
        tensors.append(torch.from_numpy(frame)[None,])

    return torch.cat(tensors, dim=0)
|
||||
|
||||
|
||||
def encode_comfy_image(
    image: torch.Tensor, image_format="WEBP", old_version=False, lossless=False
) -> str:
    """Dispatch to the legacy or current image-batch encoder."""
    return (
        _legacy_encode_comfy_image(image, image_format)
        if old_version
        else _new_encode_comfy_image(image, image_format, lossless=lossless)
    )
|
||||
|
||||
|
||||
def decode_comfy_image(
    img_data: Union[List, str], image_format="WEBP", old_version=False
) -> torch.tensor:
    """Dispatch to the legacy or current image-batch decoder."""
    decoder = _legacy_decode_comfy_image if old_version else _new_decode_comfy_image
    return decoder(img_data, image_format)
|
||||
|
||||
|
||||
def tensor_to_base64(tensor: torch.Tensor, compress=True) -> str:
    """Pickle a tensor as a numpy array, optionally zlib-compress, then base64."""
    payload = pickle.dumps(tensor.cpu().detach().numpy())
    if compress:
        payload = zlib.compress(payload)
    return base64.b64encode(payload).decode("utf-8")
|
||||
|
||||
|
||||
def base64_to_tensor(tensor_b64: str, compress=True) -> torch.Tensor:
    """Inverse of tensor_to_base64: base64-decode, optionally decompress, unpickle.

    SECURITY NOTE(review): pickle.loads executes arbitrary code if the payload
    is attacker-controlled; only use with trusted data.
    """
    payload = base64.b64decode(tensor_b64)
    if compress:
        payload = zlib.decompress(payload)
    return torch.from_numpy(pickle.loads(payload))
|
||||
|
||||
|
||||
@singledispatch
def decode_data(input, old_version=False):
    """Recursively decode an incoming payload; fallback for unsupported types."""
    raise NotImplementedError(f"Unsupported type: {type(input)}")


@decode_data.register(int)
@decode_data.register(float)
@decode_data.register(bool)
@decode_data.register(type(None))
def _(input, **kwargs):
    # Plain scalars need no decoding.
    return input


@decode_data.register(dict)
def _(input, **kwargs):
    # Decode every value, preserving the keys.
    return {k: decode_data(v, **kwargs) for k, v in input.items()}


@decode_data.register(list)
def _(input, **kwargs):
    return [decode_data(x, **kwargs) for x in input]


@decode_data.register(str)
def _(input: str, **kwargs):
    # Strings may carry marker-prefixed base64 payloads.
    if input.startswith(TENSOR_MARKER):
        tensor_b64 = input[len(TENSOR_MARKER) :]
        return base64_to_tensor(tensor_b64)
    elif input.startswith(IMAGE_MARKER):
        tensor_b64 = input[len(IMAGE_MARKER) :]
        old_version = kwargs.get("old_version", False)
        return decode_comfy_image(tensor_b64, old_version=old_version)
    # Unmarked strings pass through untouched.
    return input
|
||||
|
||||
|
||||
@singledispatch
def encode_data(output, disable_image_marker=False, old_version=False):
    """Recursively encode a payload for transport; fallback for unsupported types."""
    raise NotImplementedError(f"Unsupported type: {type(output)}")


@encode_data.register(dict)
def _(output, **kwargs):
    # Encode every value, preserving the keys.
    return {k: encode_data(v, **kwargs) for k, v in output.items()}


@encode_data.register(list)
def _(output, **kwargs):
    return [encode_data(x, **kwargs) for x in output]
|
||||
|
||||
|
||||
def is_image_tensor(tensor) -> bool:
    """https://docs.comfy.org/essentials/custom_node_datatypes#image

    Check if the given tensor is in the format of an IMAGE (shape [B, H, W, C] where C=3).

    `Args`:
        tensor (torch.Tensor): The tensor to check.

    `Returns`:
        bool: True if the tensor is in the IMAGE format, False otherwise.
    """
    try:
        if not isinstance(tensor, torch.Tensor):
            return False

        if len(tensor.shape) != 4:
            return False

        # Channels-last layout; IMAGE tensors carry exactly 3 channels.
        B, H, W, C = tensor.shape
        if C != 3:
            return False

        return True
    # BUGFIX: was a bare `except:`, which would also swallow
    # KeyboardInterrupt / SystemExit.
    except Exception:
        return False
|
||||
|
||||
|
||||
@encode_data.register(torch.Tensor)
def _(output, **kwargs):
    # IMAGE-shaped tensors (B,H,W,3) go through the image codec unless the
    # caller explicitly disables the image marker.
    if is_image_tensor(output) and not kwargs.get("disable_image_marker", False):
        old_version = kwargs.get("old_version", False)
        lossless = kwargs.get("lossless", True)
        return IMAGE_MARKER + encode_comfy_image(
            output, image_format="WEBP", old_version=old_version, lossless=lossless
        )
    # Everything else is shipped as a pickled/compressed tensor payload.
    return TENSOR_MARKER + tensor_to_base64(output)
|
||||
|
||||
|
||||
@encode_data.register(int)
@encode_data.register(float)
@encode_data.register(bool)
@encode_data.register(type(None))
def _(output, **kwargs):
    # Scalars are JSON-serializable as-is.
    return output


@encode_data.register(str)
def _(output, **kwargs):
    # Outgoing strings are passed through unmodified.
    return output
|
||||
51
custom_nodes/ComfyUI-Easy-Use/py/libs/api/fluxai.py
Normal file
51
custom_nodes/ComfyUI-Easy-Use/py/libs/api/fluxai.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import json
|
||||
import os
|
||||
import yaml
|
||||
import requests
|
||||
import pathlib
|
||||
from aiohttp import web
|
||||
|
||||
# Repository root: four directory levels up from py/libs/api/fluxai.py.
root_path = pathlib.Path(__file__).parent.parent.parent.parent
# Shared plugin configuration file (holds FLUXAI_COOKIE, etc.).
config_path = os.path.join(root_path,'config.yaml')
|
||||
class FluxAIAPI:
    """Thin client for the fluxaiimagegenerator.com prompt-generator endpoint."""

    def __init__(self):
        self.api_url = "https://fluxaiimagegenerator.com/api"
        self.origin = "https://fluxaiimagegenerator.com"
        # Optional User-Agent and Cookie, lazily loaded from config.yaml.
        self.user_agent = None
        self.cookie = None

    def promptGenerate(self, text, cookies=None):
        """Ask the prompt-generator service to expand *text*.

        :param text: the seed prompt sent to the service.
        :param cookies: explicit cookie string; when None the cached cookie
            (or FLUXAI_COOKIE from config.yaml) is used.
        :return: the generated prompt, the service's error string, or None
            when the response contains neither.
        """
        cookie = self.cookie if cookies is None else cookies
        if cookie is None:
            if os.path.isfile(config_path):
                with open(config_path, 'r') as f:
                    data = yaml.load(f, Loader=yaml.FullLoader)
                    if 'FLUXAI_COOKIE' not in data:
                        raise Exception("Please add FLUXAI_COOKIE to config.yaml")
                    if "FLUXAI_USER_AGENT" in data:
                        self.user_agent = data["FLUXAI_USER_AGENT"]
                    self.cookie = cookie = data['FLUXAI_COOKIE']

        headers = {
            "Cookie": cookie,
            "Referer": "https://fluxaiimagegenerator.com/flux-prompt-generator",
            "Origin": self.origin,
            "Content-Type": "application/json",
        }
        if self.user_agent is not None:
            headers['User-Agent'] = self.user_agent

        url = self.api_url + '/prompt'
        # BUGFIX: this local was previously named `json`, shadowing the
        # imported json module inside the method.
        payload = {
            "prompt": text
        }

        response = requests.post(url, json=payload, headers=headers)
        res = response.json()
        if "error" in res:
            return res['error']
        elif "data" in res and "prompt" in res['data']:
            return res['data']['prompt']


fluxaiAPI = FluxAIAPI()
|
||||
|
||||
200
custom_nodes/ComfyUI-Easy-Use/py/libs/api/stability.py
Normal file
200
custom_nodes/ComfyUI-Easy-Use/py/libs/api/stability.py
Normal file
@@ -0,0 +1,200 @@
|
||||
import json
|
||||
import os
|
||||
import yaml
|
||||
import requests
|
||||
import pathlib
|
||||
from aiohttp import web
|
||||
from server import PromptServer
|
||||
from ..image import tensor2pil, pil2tensor, image2base64, pil2byte
|
||||
from ..log import log_node_error
|
||||
|
||||
|
||||
# Repository root: four directory levels up from py/libs/api/stability.py.
root_path = pathlib.Path(__file__).parent.parent.parent.parent
# Shared plugin configuration file (holds the Stability keys).
config_path = os.path.join(root_path,'config.yaml')
# Placeholder entry written when no Stability key has been configured yet.
default_key = [{'name':'Default', 'key':''}]
|
||||
|
||||
|
||||
class StabilityAPI:
    """Client for the Stability AI REST API (key management + SD3 generation)."""

    def __init__(self):
        # Base endpoint for all Stability REST calls.
        self.api_url = "https://api.stability.ai"
        # List of {'name', 'key'} dicts loaded from config.yaml (lazy).
        self.api_keys = None
        # Index into api_keys of the currently selected key.
        self.api_current = 0
        # Cache dict for account information.
        self.user_info = {}
|
||||
|
||||
def getErrors(self, code):
|
||||
errors = {
|
||||
400: "Bad Request",
|
||||
403: "ApiKey Forbidden",
|
||||
413: "Your request was larger than 10MiB.",
|
||||
429: "You have made more than 150 requests in 10 seconds.",
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
return errors.get(code, "Unknown Error")
|
||||
|
||||
def getAPIKeys(self):
|
||||
if os.path.isfile(config_path):
|
||||
with open(config_path, 'r') as f:
|
||||
data = yaml.load(f, Loader=yaml.FullLoader)
|
||||
if not data:
|
||||
data = {'STABILITY_API_KEY': default_key, 'STABILITY_API_DEFAULT':0}
|
||||
with open(config_path, 'w') as f:
|
||||
yaml.dump(data, f)
|
||||
if 'STABILITY_API_KEY' not in data:
|
||||
data['STABILITY_API_KEY'] = default_key
|
||||
data['STABILITY_API_DEFAULT'] = 0
|
||||
with open(config_path, 'w') as f:
|
||||
yaml.dump(data, f)
|
||||
api_keys = data['STABILITY_API_KEY']
|
||||
self.api_current = data['STABILITY_API_DEFAULT']
|
||||
self.api_keys = api_keys
|
||||
return api_keys
|
||||
else:
|
||||
# create a yaml file
|
||||
with open(config_path, 'w') as f:
|
||||
data = {'STABILITY_API_KEY': default_key, 'STABILITY_API_DEFAULT':0}
|
||||
yaml.dump(data, f)
|
||||
return data['STABILITY_API_KEY']
|
||||
pass
|
||||
|
||||
def setAPIKeys(self, api_keys):
|
||||
if len(api_keys) > 0:
|
||||
self.api_keys = api_keys
|
||||
# load and save the yaml file
|
||||
with open(config_path, 'r') as f:
|
||||
data = yaml.load(f, Loader=yaml.FullLoader)
|
||||
data['STABILITY_API_KEY'] = api_keys
|
||||
with open(config_path, 'w') as f:
|
||||
yaml.dump(data, f)
|
||||
return True
|
||||
|
||||
def setAPIDefault(self, current):
|
||||
if current is not None:
|
||||
self.api_current = current
|
||||
# load and save the yaml file
|
||||
with open(config_path, 'r') as f:
|
||||
data = yaml.load(f, Loader=yaml.FullLoader)
|
||||
data['STABILITY_API_DEFAULT'] = current
|
||||
with open(config_path, 'w') as f:
|
||||
yaml.dump(data, f)
|
||||
return True
|
||||
|
||||
def generate_sd3_image(self, prompt, negative_prompt, aspect_ratio, model, seed, mode='text-to-image', image=None, strength=1, output_format='png', node_name='easy stableDiffusion3API'):
|
||||
url = f"{self.api_url}/v2beta/stable-image/generate/sd3"
|
||||
api_key = self.api_keys[self.api_current]['key']
|
||||
files = None
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
"mode": mode,
|
||||
"model": model,
|
||||
"seed": seed,
|
||||
"output_format": output_format,
|
||||
}
|
||||
if model == 'sd3':
|
||||
data['negative_prompt'] = negative_prompt
|
||||
|
||||
if mode == 'text-to-image':
|
||||
files = {"none": ''}
|
||||
data['aspect_ratio'] = aspect_ratio
|
||||
elif mode == 'image-to-image':
|
||||
pil_image = tensor2pil(image)
|
||||
image_byte = pil2byte(pil_image)
|
||||
files = {"image": ("output.png", image_byte, 'image/png')}
|
||||
data['strength'] = strength
|
||||
|
||||
response = requests.post(url,
|
||||
headers={"authorization": f"{api_key}", "accept": "application/json"},
|
||||
files=files,
|
||||
data=data,
|
||||
)
|
||||
if response.status_code == 200:
|
||||
PromptServer.instance.send_sync('stable-diffusion-api-generate-succeed',{"model":model})
|
||||
json_data = response.json()
|
||||
image_base64 = json_data['image']
|
||||
image_data = image2base64(image_base64)
|
||||
output_t = pil2tensor(image_data)
|
||||
return output_t
|
||||
else:
|
||||
if 'application/json' in response.headers['Content-Type']:
|
||||
error_info = response.json()
|
||||
log_node_error(node_name, error_info.get('name', 'No name provided'))
|
||||
log_node_error(node_name, error_info.get('errors', ['No details provided']))
|
||||
error_status_text = self.getErrors(response.status_code)
|
||||
PromptServer.instance.send_sync('easyuse-toast',{"type": "error", "content": error_status_text})
|
||||
raise Exception(f"Failed to generate image: {error_status_text}")
|
||||
|
||||
# get user account
|
||||
async def getUserAccount(self, cache=True):
|
||||
url = f"{self.api_url}/v1/user/account"
|
||||
api_key = self.api_keys[self.api_current]['key']
|
||||
name = self.api_keys[self.api_current]['name']
|
||||
if cache and name in self.user_info:
|
||||
return self.user_info[name]
|
||||
else:
|
||||
response = requests.get(url, headers={"Authorization": f"Bearer {api_key}"})
|
||||
if response.status_code == 200:
|
||||
user_info = response.json()
|
||||
self.user_info[name] = user_info
|
||||
return user_info
|
||||
else:
|
||||
PromptServer.instance.send_sync('easyuse-toast',{'type': 'error', 'content': self.getErrors(response.status_code)})
|
||||
return None
|
||||
|
||||
# get user balance
|
||||
async def getUserBalance(self):
|
||||
url = f"{self.api_url}/v1/user/balance"
|
||||
api_key = self.api_keys[self.api_current]['key']
|
||||
response = requests.get(url, headers={
|
||||
"Authorization": f"Bearer {api_key}"
|
||||
})
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
PromptServer.instance.send_sync('easyuse-toast', {'type': 'error', 'content': self.getErrors(response.status_code)})
|
||||
return None
|
||||
|
||||
# Module-level singleton shared by every route below.
stableAPI = StabilityAPI()


@PromptServer.instance.routes.get("/easyuse/stability/api_keys")
async def get_stability_api_keys(request):
    """Web endpoint: return the stored key list and the active key index."""
    stableAPI.getAPIKeys()
    payload = {"keys": stableAPI.api_keys, "current": stableAPI.api_current}
    return web.json_response(payload)
|
||||
|
||||
@PromptServer.instance.routes.post("/easyuse/stability/set_api_keys")
async def set_stability_api_keys(request):
    """Web endpoint: replace the stored key list and optionally the active index.

    Expects form fields: api_keys (JSON list of {'name','key'}), optional current.
    Returns account+balance when the active key changed, 400 on missing api_keys.
    """
    post = await request.post()
    api_keys = post.get("api_keys")
    current = post.get('current')
    if api_keys is not None:
        api_keys = json.loads(api_keys)
        stableAPI.setAPIKeys(api_keys)
        if current is not None:
            # BUGFIX: removed leftover debug print(current)
            stableAPI.setAPIDefault(int(current))
            account = await stableAPI.getUserAccount()
            balance = await stableAPI.getUserBalance()
            return web.json_response({'account': account, 'balance': balance})
        else:
            return web.json_response({'status': 'ok'})
    else:
        return web.Response(status=400)
|
||||
|
||||
@PromptServer.instance.routes.post("/easyuse/stability/set_apikey_default")
async def set_stability_api_default(request):
    """Web endpoint: switch the active Stability API key by list index."""
    post = await request.post()
    current = post.get("current")
    if current is None:
        return web.Response(status=400)
    # BUGFIX: form values arrive as strings; the original compared
    # str < int which raises TypeError on Python 3.
    try:
        current = int(current)
    except (TypeError, ValueError):
        return web.Response(status=400)
    keys = stableAPI.api_keys or []  # api_keys may still be None before first load
    if 0 <= current < len(keys):
        stableAPI.api_current = current
        return web.json_response({'status': 'ok'})
    return web.Response(status=400)
|
||||
|
||||
@PromptServer.instance.routes.get("/easyuse/stability/user_info")
async def get_account_info(request):
    """Web endpoint: return account info and credit balance for the active key."""
    account = await stableAPI.getUserAccount()
    balance = await stableAPI.getUserBalance()
    payload = {'account': account, 'balance': balance}
    return web.json_response(payload)
|
||||
|
||||
@PromptServer.instance.routes.get("/easyuse/stability/balance")
async def get_balance_info(request):
    """Web endpoint: return only the credit balance for the active key."""
    balance = await stableAPI.getUserBalance()
    return web.json_response({'balance': balance})
|
||||
86
custom_nodes/ComfyUI-Easy-Use/py/libs/cache.py
Normal file
86
custom_nodes/ComfyUI-Easy-Use/py/libs/cache.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import itertools
|
||||
from typing import Optional
|
||||
|
||||
class TaggedCache:
    """Dict-like cache whose entries are grouped by tag, each tag with its own
    (optionally LRU-bounded) bucket.

    Values are stored as ``(tag: str, (islist: bool, data))`` tuples; lookups
    search every bucket, so keys are globally unique across tags.
    """

    def __init__(self, tag_settings: Optional[dict] = None):
        self._tag_settings = tag_settings or {}  # per-tag max sizes override the defaults
        self._data = {}  # tag -> bucket (LRUCache when cachetools is installed, else dict)

    def _find_bucket(self, key):
        """Return the bucket that currently holds *key*, or None."""
        for bucket in self._data.values():
            if key in bucket:
                return bucket
        return None

    def __getitem__(self, key):
        bucket = self._find_bucket(key)
        if bucket is None:
            raise KeyError(f'Key `{key}` does not exist')
        return bucket[key]

    def __setitem__(self, key, value: tuple):
        # value: (tag: str, (islist: bool, data: *))
        # a key may move between tags, so evict any stale copy first
        stale = self._find_bucket(key)
        if stale is not None:
            stale.pop(key, None)

        tag = value[0]
        if tag not in self._data:
            try:
                from cachetools import LRUCache
                # built-in size defaults: checkpoints are huge, latents/images cheap
                if 'ckpt' in tag:
                    default_size = 5
                elif tag in ('latent', 'image'):
                    default_size = 100
                else:
                    default_size = 20
                self._data[tag] = LRUCache(maxsize=self._tag_settings.get(tag, default_size))
            except (ImportError, ModuleNotFoundError):
                # TODO: implement a simple lru dict
                self._data[tag] = {}
        self._data[tag][key] = value

    def __delitem__(self, key):
        bucket = self._find_bucket(key)
        if bucket is None:
            raise KeyError(f'Key `{key}` does not exist')
        del bucket[key]

    def __contains__(self, key):
        return self._find_bucket(key) is not None

    def items(self):
        for bucket in self._data.values():
            yield from bucket.items()

    def get(self, key, default=None):
        """D.get(k[,d]) -> D[k] if k in D, else d. d defaults to None."""
        bucket = self._find_bucket(key)
        return default if bucket is None else bucket[key]

    def clear(self):
        # drop every tag bucket at once
        self._data = {}
|
||||
|
||||
# Module-level tagged cache shared by the loader helpers below.
cache_settings = {}
cache = TaggedCache(cache_settings)
cache_count = {}  # key -> number of refreshes (first store counts as 0)


def update_cache(k, tag, v):
    """Store *v* under *k* with *tag* and bump the per-key refresh counter."""
    cache[k] = (tag, v)
    # absent -> 0 on first store, then +1 per refresh
    cache_count[k] = cache_count.get(k, -1) + 1


def remove_cache(key):
    """Drop one cached entry, or reset the whole cache when key == '*'."""
    global cache
    if key == '*':
        cache = TaggedCache(cache_settings)
    elif key in cache:
        del cache[key]
    else:
        print(f"invalid {key}")
|
||||
153
custom_nodes/ComfyUI-Easy-Use/py/libs/chooser.py
Normal file
153
custom_nodes/ComfyUI-Easy-Use/py/libs/chooser.py
Normal file
@@ -0,0 +1,153 @@
|
||||
from threading import Event
|
||||
|
||||
import torch
|
||||
|
||||
from server import PromptServer
|
||||
from aiohttp import web
|
||||
from comfy import model_management as mm
|
||||
from comfy_execution.graph import ExecutionBlocker
|
||||
import time
|
||||
|
||||
class ChooserCancelled(Exception):
    """Raised when the user cancels a manual image selection."""
|
||||
|
||||
def get_chooser_cache():
    """Return the server-wide dict holding image-chooser state, creating it on first use."""
    server = PromptServer.instance
    if not hasattr(server, '_easyuse_chooser_node'):
        server._easyuse_chooser_node = {}
    return server._easyuse_chooser_node
|
||||
|
||||
def cleanup_session_data(node_id):
    """Drop transient per-run fields for *node_id* (keeps e.g. last_selection)."""
    node_data = get_chooser_cache()
    entry = node_data.get(node_id)
    if entry is None:
        return
    for field in ("event", "selected", "images", "total_count", "cancelled"):
        entry.pop(field, None)
|
||||
|
||||
def wait_for_chooser(id, images, mode, period=0.1):
    """Block node execution until the user picks images in the frontend.

    Splits the batch into single-image tensors, publishes them under *id* in
    the chooser cache and polls every *period* seconds until the web handler
    records a selection or a cancel.  Returns ``{"result": (tensor,)}`` with
    the chosen images concatenated back into one batch.

    mode == "Keep Last Selection" replays the previous selection stored for
    this node, if any, without waiting for the user again.
    """
    try:
        node_data = get_chooser_cache()
        # one (1, H, W, C) tensor per batch element
        images = [images[i:i + 1, ...] for i in range(images.shape[0])]

        if mode == "Keep Last Selection":
            if id in node_data and "last_selection" in node_data[id]:
                last_selection = node_data[id]["last_selection"]
                if last_selection and len(last_selection) > 0:
                    valid_indices = [idx for idx in last_selection if 0 <= idx < len(images)]
                    if valid_indices:
                        try:
                            # best-effort UI hint; never fail the run over it
                            PromptServer.instance.send_sync("easyuse-image-keep-selection", {
                                "id": id,
                                "selected": valid_indices
                            })
                        except Exception:
                            pass
                        cleanup_session_data(id)
                        images = [images[idx] for idx in valid_indices]
                        images = torch.cat(images, dim=0)
                        return {"result": (images,)}

        # start a fresh selection session for this node
        if id in node_data:
            del node_data[id]

        event = Event()
        node_data[id] = {
            "event": event,
            "images": images,
            "selected": None,
            "total_count": len(images),
            "cancelled": False,
        }

        # NOTE(review): the web handler sets `event`, but this loop polls the
        # dict instead of waiting on the event — kept as-is to preserve behavior.
        while id in node_data:
            node_info = node_data[id]
            if node_info.get("cancelled", False):
                cleanup_session_data(id)
                raise ChooserCancelled("Manual selection cancelled")

            if "selected" in node_info and node_info["selected"] is not None:
                break

            time.sleep(period)

        if id in node_data:
            node_info = node_data[id]
            selected_indices = node_info.get("selected")

            if selected_indices is not None and len(selected_indices) > 0:
                valid_indices = [idx for idx in selected_indices if 0 <= idx < len(images)]
                if valid_indices:
                    selected_images = [images[idx] for idx in valid_indices]

                    # remember the choice so "Keep Last Selection" can replay it
                    if id not in node_data:
                        node_data[id] = {}
                    node_data[id]["last_selection"] = valid_indices
                    cleanup_session_data(id)
                    selected_images = torch.cat(selected_images, dim=0)
                    return {"result": (selected_images,)}
                else:
                    cleanup_session_data(id)
                    return {"result": (images[0] if len(images) > 0 else ExecutionBlocker(None),)}
            else:
                cleanup_session_data(id)
                return {"result": (images[0] if len(images) > 0 else ExecutionBlocker(None),)}
        else:
            return {"result": (images[0] if len(images) > 0 else ExecutionBlocker(None),)}

    except ChooserCancelled:
        # surface the cancel as a normal ComfyUI interrupt
        raise mm.InterruptProcessingException()
    except Exception:
        node_data = get_chooser_cache()
        if id in node_data:
            cleanup_session_data(id)
        # BUGFIX: the original tested `'image_list' in locals()` — a name that
        # never exists, so this fallback was dead — and returned `(images[0])`
        # without the tuple comma. Fall back to the first image (if the batch
        # was already split) wrapped in a one-tuple like every other return.
        if isinstance(images, list) and len(images) > 0:
            return {"result": (images[0],)}
        return {"result": (ExecutionBlocker(None),)}
|
||||
|
||||
|
||||
@PromptServer.instance.routes.post('/easyuse/image_chooser_message')
async def handle_image_selection(request):
    """Web endpoint: record the user's chooser selection (or cancel) for a node."""
    try:
        data = await request.json()
        node_id = data.get("node_id")
        selected = data.get("selected", [])
        action = data.get("action")

        node_data = get_chooser_cache()

        if node_id not in node_data:
            return web.json_response({"code": -1, "error": "Node data does not exist"})

        try:
            node_info = node_data[node_id]

            # total_count is removed once the run finishes; reject stale messages
            if "total_count" not in node_info:
                return web.json_response({"code": -1, "error": "The node has been processed"})

            if action == "cancel":
                node_info["cancelled"] = True
                node_info["selected"] = []
            elif action == "select" and isinstance(selected, list):
                total = node_info["total_count"]
                valid_indices = [i for i in selected if isinstance(i, int) and 0 <= i < total]
                if not valid_indices:
                    return web.json_response({"code": -1, "error": "Invalid Selection Index"})
                node_info["selected"] = valid_indices
                node_info["cancelled"] = False
            else:
                return web.json_response({"code": -1, "error": "Invalid operation"})

            # wake anything waiting on this node's event
            node_info["event"].set()
            return web.json_response({"code": 1})

        except Exception:
            # make sure a waiting node is not left blocked forever
            if node_id in node_data and "event" in node_data[node_id]:
                node_data[node_id]["event"].set()
            return web.json_response({"code": -1, "message": "Processing Failed"})

    except Exception:
        return web.json_response({"code": -1, "message": "Request Failed"})
|
||||
115
custom_nodes/ComfyUI-Easy-Use/py/libs/colorfix.py
Normal file
115
custom_nodes/ComfyUI-Easy-Use/py/libs/colorfix.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch import Tensor
|
||||
from torch.nn import functional as F
|
||||
|
||||
from torchvision.transforms import ToTensor, ToPILImage
|
||||
|
||||
def adain_color_fix(target: Image, source: Image):
    """Transfer *source*'s per-channel color statistics onto *target* via AdaIN."""
    to_tensor = ToTensor()
    target_t = to_tensor(target).unsqueeze(0)
    source_t = to_tensor(source).unsqueeze(0)

    fixed = adaptive_instance_normalization(target_t, source_t)

    # back to PIL, clamping into the valid [0, 1] range first
    return ToPILImage()(fixed.squeeze(0).clamp_(0.0, 1.0))
|
||||
|
||||
def wavelet_color_fix(target: Image, source: Image):
    """Recolor *target* using *source*'s low-frequency (color) wavelet bands."""
    # match resolutions before comparing frequency bands
    source = source.resize(target.size, resample=Image.Resampling.LANCZOS)

    to_tensor = ToTensor()
    target_t = to_tensor(target).unsqueeze(0)
    source_t = to_tensor(source).unsqueeze(0)

    fixed = wavelet_reconstruction(target_t, source_t)

    # back to PIL, clamping into the valid [0, 1] range first
    return ToPILImage()(fixed.squeeze(0).clamp_(0.0, 1.0))
|
||||
|
||||
def calc_mean_std(feat: Tensor, eps=1e-5):
    """Calculate mean and std for adaptive_instance_normalization.

    Args:
        feat (Tensor): 4D tensor.
        eps (float): A small value added to the variance to avoid
            divide-by-zero. Default: 1e-5.
    """
    assert feat.dim() == 4, 'The input feature should be 4D tensor.'
    b, c = feat.size(0), feat.size(1)
    flat = feat.view(b, c, -1)
    feat_std = (flat.var(dim=2) + eps).sqrt().view(b, c, 1, 1)
    feat_mean = flat.mean(dim=2).view(b, c, 1, 1)
    return feat_mean, feat_std
|
||||
|
||||
def adaptive_instance_normalization(content_feat: Tensor, style_feat: Tensor):
    """Adaptive instance normalization.

    Whitens *content_feat* per channel, then rescales it to *style_feat*'s
    per-channel mean/std so the content adopts the style's color and
    illumination.

    Args:
        content_feat (Tensor): The reference feature.
        style_feat (Tensor): The degradate features.
    """
    shape = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)
    whitened = (content_feat - content_mean.expand(shape)) / content_std.expand(shape)
    return whitened * style_std.expand(shape) + style_mean.expand(shape)
|
||||
|
||||
def wavelet_blur(image: Tensor, radius: int):
    """
    Apply wavelet blur to the input tensor.

    Depthwise 3x3 binomial low-pass (kernel sums to 1), dilated by *radius*,
    with replicate padding so the output keeps the input's spatial size.
    Input shape: (1, 3, H, W).
    """
    kernel = torch.tensor(
        [
            [0.0625, 0.125, 0.0625],
            [0.125, 0.25, 0.125],
            [0.0625, 0.125, 0.0625],
        ],
        dtype=image.dtype,
        device=image.device,
    )
    # (3, 1, 3, 3): one depthwise kernel per RGB channel
    kernel = kernel[None, None].repeat(3, 1, 1, 1)
    padded = F.pad(image, (radius, radius, radius, radius), mode='replicate')
    return F.conv2d(padded, kernel, groups=3, dilation=radius)
|
||||
|
||||
def wavelet_decomposition(image: Tensor, levels=5):
    """
    Apply wavelet decomposition to the input tensor.
    This function only returns the low frequency & the high frequency.

    Each level blurs with a dilation of 2**level and accumulates the residual
    into the high-frequency band; the final blur is the low-frequency band.
    """
    high_freq = torch.zeros_like(image)
    # BUGFIX: initialize low_freq so levels=0 returns (zeros, image) instead of
    # raising NameError on the unbound loop variable.
    low_freq = image
    for i in range(levels):
        radius = 2 ** i
        blurred = wavelet_blur(low_freq, radius)
        high_freq += (low_freq - blurred)
        low_freq = blurred

    return high_freq, low_freq
|
||||
|
||||
def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor):
    """
    Apply wavelet decomposition, so that the content will have the same color as the style.

    Combines the content's high-frequency detail with the style's
    low-frequency color band.
    """
    content_high, content_low = wavelet_decomposition(content_feat)
    del content_low  # free the unused band early
    style_high, style_low = wavelet_decomposition(style_feat)
    del style_high  # free the unused band early
    return content_high + style_low
|
||||
57
custom_nodes/ComfyUI-Easy-Use/py/libs/conditioning.py
Normal file
57
custom_nodes/ComfyUI-Easy-Use/py/libs/conditioning.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from .utils import find_wildcards_seed, find_nearest_steps, is_linked_styles_selector
|
||||
from .log import log_node_warn
|
||||
from .translate import zh_to_en, has_chinese
|
||||
from .wildcards import process_with_loras
|
||||
from .adv_encode import advanced_encode
|
||||
|
||||
from nodes import ConditioningConcat, ConditioningCombine, ConditioningAverage, ConditioningSetTimestepRange, CLIPTextEncode
|
||||
|
||||
def prompt_to_cond(type, model, clip, clip_skip, lora_stack, text, prompt_token_normalization, prompt_weight_interpretation, a1111_prompt_style ,my_unique_id, prompt, easyCache, can_load_lora=True, steps=None, model_type=None):
    """Encode a prompt string into conditioning, handling translation, wildcards
    and LoRA loading along the way.

    Returns a 4-tuple ``(conditioning, wildcard_prompt, model, clip)``; the
    conditioning is None when *text* is None.  For hydit/flux/mochi models the
    plain CLIPTextEncode path is used and wildcards/LoRAs are skipped.
    """
    styles_selector = is_linked_styles_selector(prompt, my_unique_id, type)
    title = "Positive encoding" if type == 'positive' else "Negative encoding"

    # Translate Chinese prompt text to English (hydit handles Chinese natively,
    # so it is excluded).
    if model_type not in ['hydit'] and text is not None and has_chinese(text):
        text = zh_to_en([text])[0]

    # Models with their own text encoders take the plain encode path; no
    # wildcard/LoRA processing and no clip-skip.
    if model_type in ['hydit', 'flux', 'mochi']:
        log_node_warn(title + "...")
        embeddings_final, = CLIPTextEncode().encode(clip, text) if text is not None else (None,)

        return (embeddings_final, "", model, clip)

    log_node_warn(title + "...")

    # Expand wildcards (seeded per node) and load any LoRAs referenced inline;
    # this may return a patched model/clip.
    positive_seed = find_wildcards_seed(my_unique_id, text, prompt)
    model, clip, text, cond_decode, show_prompt, pipe_lora_stack = process_with_loras(
        text, model, clip, type, positive_seed, can_load_lora, lora_stack, easyCache)
    wildcard_prompt = cond_decode if show_prompt or styles_selector else ""

    clipped = clip.clone()
    # Clip-skip is only applied when the clip model has no t5xxl component
    # (i.e. not an SD3/flux-style dual encoder).
    if not hasattr(clip.cond_stage_model, 't5xxl'):
        if clip_skip != 0:
            clipped.clip_layer(clip_skip)

    # Steps are needed by advanced_encode for timestep-dependent weighting;
    # fall back to the nearest sampler node's step count.
    steps = steps if steps is not None else find_nearest_steps(my_unique_id, prompt)
    return (advanced_encode(clipped, text, prompt_token_normalization,
                            prompt_weight_interpretation, w_max=1.0,
                            apply_to_pooled='enable',
                            a1111_prompt_style=a1111_prompt_style, steps=steps) if text is not None else None, wildcard_prompt, model, clipped)
|
||||
|
||||
def set_cond(old_cond, new_cond, mode, average_strength, old_cond_start, old_cond_end, new_cond_start, new_cond_end):
    """Merge a newly-encoded conditioning into an existing one according to *mode*.

    Modes: replace / concat / combine / average / timestep. Returns the merged
    conditioning; when *old_cond* is falsy the new conditioning passes through.
    """
    if not old_cond:
        return new_cond
    if mode == "replace":
        return new_cond
    if mode == "concat":
        return ConditioningConcat().concat(new_cond, old_cond)[0]
    if mode == "combine":
        return ConditioningCombine().combine(old_cond, new_cond)[0]
    if mode == 'average':
        return ConditioningAverage().addWeighted(new_cond, old_cond, average_strength)[0]
    if mode == 'timestep':
        # restrict each conditioning to its own timestep window, then combine
        ranged_old = ConditioningSetTimestepRange().set_range(old_cond, old_cond_start, old_cond_end)[0]
        ranged_new = ConditioningSetTimestepRange().set_range(new_cond, new_cond_start, new_cond_end)[0]
        return ConditioningCombine().combine(ranged_old, ranged_new)[0]
|
||||
93
custom_nodes/ComfyUI-Easy-Use/py/libs/controlnet.py
Normal file
93
custom_nodes/ComfyUI-Easy-Use/py/libs/controlnet.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import folder_paths
|
||||
import comfy.controlnet
|
||||
import comfy.model_management
|
||||
from nodes import NODE_CLASS_MAPPINGS
|
||||
|
||||
union_controlnet_types = {"auto": -1, "openpose": 0, "depth": 1, "hed/pidi/scribble/ted": 2, "canny/lineart/anime_lineart/mlsd": 3, "normal": 4, "segment": 5, "tile": 6, "repaint": 7}
|
||||
|
||||
class easyControlnet:
    """Helper that loads a ControlNet and applies it to conditioning."""

    def __init__(self):
        pass

    def apply(self, control_net_name, image, positive, negative, strength, start_percent=0, end_percent=1, control_net=None, scale_soft_weights=1, mask=None, union_type=None, easyCache=None, use_cache=True, model=None, vae=None):
        """Attach *control_net* (loaded by name when not given) to the
        positive/negative conditioning with the given strength and step range.

        Returns ``(positive, negative)``; when *negative* is None only the
        positive side is patched (with control_apply_to_uncond=True).
        """
        if strength == 0:
            # nothing to do; pass the conditioning through untouched
            return (positive, negative)

        # kolors controlnet patch
        from ..modules.kolors.loader import is_kolors_model, applyKolorsUnet
        if is_kolors_model(model):
            from ..modules.kolors.model_patch import patch_controlnet
            if control_net is None:
                with applyKolorsUnet():
                    control_net = easyCache.load_controlnet(control_net_name, scale_soft_weights, use_cache)
            control_net = patch_controlnet(model, control_net)
        else:
            if control_net is None:
                if easyCache is not None:
                    control_net = easyCache.load_controlnet(control_net_name, scale_soft_weights, use_cache)
                else:
                    controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
                    control_net = comfy.controlnet.load_controlnet(controlnet_path)

        # union controlnet: select which control-type channel to activate
        if union_type is not None:
            control_net = control_net.copy()
            type_number = union_controlnet_types[union_type]
            if type_number >= 0:
                control_net.set_extra_arg("control_type", [type_number])
            else:
                control_net.set_extra_arg("control_type", [])

        if mask is not None:
            # BUGFIX: the original called mask.to(self.device), but self.device
            # is never assigned anywhere in this class, so any mask raised
            # AttributeError. Use comfy's active torch device instead.
            mask = mask.to(comfy.model_management.get_torch_device())
            if len(mask.shape) < 3:
                mask = mask.unsqueeze(0)  # ensure a batch dimension

        # conditioning hints expect channels-first images
        control_hint = image.movedim(-1, 1)

        if negative is None:
            # positive-only: apply the hint to uncond as well
            # NOTE(review): this branch does not pass `vae` to set_cond_hint
            # while the branch below does — confirm whether that is intentional.
            p = []
            for t in positive:
                n = [t[0], t[1].copy()]
                c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent))
                if 'control' in t[1]:
                    c_net.set_previous_controlnet(t[1]['control'])
                n[1]['control'] = c_net
                n[1]['control_apply_to_uncond'] = True
                if mask is not None:
                    n[1]['mask'] = mask
                    n[1]['set_area_to_bounds'] = False
                p.append(n)
            positive = p
        else:
            # share one patched controlnet per previous-controlnet chain
            cnets = {}
            out = []
            for conditioning in [positive, negative]:
                c = []
                for t in conditioning:
                    d = t[1].copy()

                    prev_cnet = d.get('control', None)
                    if prev_cnet in cnets:
                        c_net = cnets[prev_cnet]
                    else:
                        c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae)
                        c_net.set_previous_controlnet(prev_cnet)
                        cnets[prev_cnet] = c_net

                    d['control'] = c_net
                    d['control_apply_to_uncond'] = False

                    if mask is not None:
                        d['mask'] = mask
                        d['set_area_to_bounds'] = False

                    n = [t[0], d]
                    c.append(n)
                out.append(c)
            positive = out[0]
            negative = out[1]

        return (positive, negative)
|
||||
167
custom_nodes/ComfyUI-Easy-Use/py/libs/dynthres_core.py
Normal file
167
custom_nodes/ComfyUI-Easy-Use/py/libs/dynthres_core.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import torch, math
|
||||
|
||||
######################### DynThresh Core #########################
|
||||
|
||||
class DynThresh:
    """Dynamic thresholding for CFG: computes guidance at a high cfg_scale but
    rescales its statistics to mimic a lower scale, avoiding burn-out.

    NOTE(review): `self.step` is read by interpret_scale/dynthresh but never set
    in __init__ — it appears to be assigned externally by the sampler wrapper
    before each call; confirm at the call site.
    """

    # UI option lists for the scheduling / scaling choices below.
    Modes = ["Constant", "Linear Down", "Cosine Down", "Half Cosine Down", "Linear Up", "Cosine Up", "Half Cosine Up", "Power Up", "Power Down", "Linear Repeating", "Cosine Repeating", "Sawtooth"]
    Startpoints = ["MEAN", "ZERO"]
    Variabilities = ["AD", "STD"]

    def __init__(self, mimic_scale, threshold_percentile, mimic_mode, mimic_scale_min, cfg_mode, cfg_scale_min, sched_val, experiment_mode, max_steps, separate_feature_channels, scaling_startpoint, variability_measure, interpolate_phi):
        self.mimic_scale = mimic_scale
        self.threshold_percentile = threshold_percentile
        self.mimic_mode = mimic_mode
        self.cfg_mode = cfg_mode
        self.max_steps = max_steps
        self.cfg_scale_min = cfg_scale_min
        self.mimic_scale_min = mimic_scale_min
        self.experiment_mode = experiment_mode
        self.sched_val = sched_val
        self.sep_feat_channels = separate_feature_channels
        self.scaling_startpoint = scaling_startpoint
        self.variability_measure = variability_measure
        self.interpolate_phi = interpolate_phi

    def interpret_scale(self, scale, mode, min):
        # Schedule `scale` over sampling progress (self.step / max_steps) using
        # the named curve; `min` is the floor the scale decays toward.
        # NOTE: the parameters `min`/local `max` shadow builtins — kept as-is.
        scale -= min
        max = self.max_steps - 1
        frac = self.step / max
        if mode == "Constant":
            pass
        elif mode == "Linear Down":
            scale *= 1.0 - frac
        elif mode == "Half Cosine Down":
            scale *= math.cos(frac)
        elif mode == "Cosine Down":
            scale *= math.cos(frac * 1.5707)
        elif mode == "Linear Up":
            scale *= frac
        elif mode == "Half Cosine Up":
            scale *= 1.0 - math.cos(frac)
        elif mode == "Cosine Up":
            scale *= 1.0 - math.cos(frac * 1.5707)
        elif mode == "Power Up":
            scale *= math.pow(frac, self.sched_val)
        elif mode == "Power Down":
            scale *= 1.0 - math.pow(frac, self.sched_val)
        elif mode == "Linear Repeating":
            portion = (frac * self.sched_val) % 1.0
            scale *= (0.5 - portion) * 2 if portion < 0.5 else (portion - 0.5) * 2
        elif mode == "Cosine Repeating":
            scale *= math.cos(frac * 6.28318 * self.sched_val) * 0.5 + 0.5
        elif mode == "Sawtooth":
            scale *= (frac * self.sched_val) % 1.0
        scale += min
        return scale

    def dynthresh(self, cond, uncond, cfg_scale, weights):
        """Apply dynamically-thresholded CFG to one denoising step.

        cond/uncond are the model outputs; `weights` (optional) reweights the
        per-cond differences before summation. Returns the guided latent.
        """
        mimic_scale = self.interpret_scale(self.mimic_scale, self.mimic_mode, self.mimic_scale_min)
        cfg_scale = self.interpret_scale(cfg_scale, self.cfg_mode, self.cfg_scale_min)
        # uncond shape is (batch, 4, height, width)
        conds_per_batch = cond.shape[0] / uncond.shape[0]
        assert conds_per_batch == int(conds_per_batch), "Expected # of conds per batch to be constant across batches"
        cond_stacked = cond.reshape((-1, int(conds_per_batch)) + uncond.shape[1:])

        ### Normal first part of the CFG Scale logic, basically
        diff = cond_stacked - uncond.unsqueeze(1)
        if weights is not None:
            diff = diff * weights
        relative = diff.sum(1)

        ### Get the normal result for both mimic and normal scale
        mim_target = uncond + relative * mimic_scale
        cfg_target = uncond + relative * cfg_scale
        ### If we weren't doing mimic scale, we'd just return cfg_target here

        ### Now recenter the values relative to their average rather than absolute, to allow scaling from average
        mim_flattened = mim_target.flatten(2)
        cfg_flattened = cfg_target.flatten(2)
        mim_means = mim_flattened.mean(dim=2).unsqueeze(2)
        cfg_means = cfg_flattened.mean(dim=2).unsqueeze(2)
        mim_centered = mim_flattened - mim_means
        cfg_centered = cfg_flattened - cfg_means

        # Reference spread of each target, either per feature channel or global.
        if self.sep_feat_channels:
            if self.variability_measure == 'STD':
                mim_scaleref = mim_centered.std(dim=2).unsqueeze(2)
                cfg_scaleref = cfg_centered.std(dim=2).unsqueeze(2)
            else: # 'AD'
                mim_scaleref = mim_centered.abs().max(dim=2).values.unsqueeze(2)
                cfg_scaleref = torch.quantile(cfg_centered.abs(), self.threshold_percentile, dim=2).unsqueeze(2)

        else:
            if self.variability_measure == 'STD':
                mim_scaleref = mim_centered.std()
                cfg_scaleref = cfg_centered.std()
            else: # 'AD'
                mim_scaleref = mim_centered.abs().max()
                cfg_scaleref = torch.quantile(cfg_centered.abs(), self.threshold_percentile)

        if self.scaling_startpoint == 'ZERO':
            # scale raw values so the cfg spread matches the mimic spread
            scaling_factor = mim_scaleref / cfg_scaleref
            result = cfg_flattened * scaling_factor

        else: # 'MEAN'
            if self.variability_measure == 'STD':
                cfg_renormalized = (cfg_centered / cfg_scaleref) * mim_scaleref
            else: # 'AD'
                ### Get the maximum value of all datapoints (with an optional threshold percentile on the uncond)
                max_scaleref = torch.maximum(mim_scaleref, cfg_scaleref)
                ### Clamp to the max
                cfg_clamped = cfg_centered.clamp(-max_scaleref, max_scaleref)
                ### Now shrink from the max to normalize and grow to the mimic scale (instead of the CFG scale)
                cfg_renormalized = (cfg_clamped / max_scaleref) * mim_scaleref

            ### Now add it back onto the averages to get into real scale again and return
            result = cfg_renormalized + cfg_means

        actual_res = result.unflatten(2, mim_target.shape[2:])

        # phi < 1 blends the thresholded result back toward the plain CFG result
        if self.interpolate_phi != 1.0:
            actual_res = actual_res * self.interpolate_phi + cfg_target * (1.0 - self.interpolate_phi)

        # Experimental per-pixel tweaks below assume a 64x64, 4-channel latent.
        if self.experiment_mode == 1:
            num = actual_res.cpu().numpy()
            for y in range(0, 64):
                for x in range (0, 64):
                    # NOTE(review): this tests channel 0 but scales channel 1 —
                    # possibly intentional for the experiment, possibly a typo.
                    if num[0][0][y][x] > 1.0:
                        num[0][1][y][x] *= 0.5
                    if num[0][1][y][x] > 1.0:
                        num[0][1][y][x] *= 0.5
                    if num[0][2][y][x] > 1.5:
                        num[0][2][y][x] *= 0.5
            actual_res = torch.from_numpy(num).to(device=uncond.device)
        elif self.experiment_mode == 2:
            num = actual_res.cpu().numpy()
            for y in range(0, 64):
                for x in range (0, 64):
                    over_scale = False
                    for z in range(0, 4):
                        if abs(num[0][z][y][x]) > 1.5:
                            over_scale = True
                    if over_scale:
                        for z in range(0, 4):
                            num[0][z][y][x] *= 0.7
            actual_res = torch.from_numpy(num).to(device=uncond.device)
        elif self.experiment_mode == 3:
            # project latent channels into an approximate RGBW basis, clamp, invert
            coefs = torch.tensor([
                # R G B W
                [0.298, 0.207, 0.208, 0.0], # L1
                [0.187, 0.286, 0.173, 0.0], # L2
                [-0.158, 0.189, 0.264, 0.0], # L3
                [-0.184, -0.271, -0.473, 1.0], # L4
            ], device=uncond.device)
            res_rgb = torch.einsum("laxy,ab -> lbxy", actual_res, coefs)
            max_r, max_g, max_b, max_w = res_rgb[0][0].max(), res_rgb[0][1].max(), res_rgb[0][2].max(), res_rgb[0][3].max()
            max_rgb = max(max_r, max_g, max_b)
            print(f"test max = r={max_r}, g={max_g}, b={max_b}, w={max_w}, rgb={max_rgb}")
            if self.step / (self.max_steps - 1) > 0.2:
                if max_rgb < 2.0 and max_w < 3.0:
                    res_rgb /= max_rgb / 2.4
            else:
                if max_rgb > 2.4 and max_w > 3.0:
                    res_rgb /= max_rgb / 2.4
            actual_res = torch.einsum("laxy,ab -> lbxy", res_rgb, coefs.inverse())

        return actual_res
|
||||
27
custom_nodes/ComfyUI-Easy-Use/py/libs/easing.py
Normal file
27
custom_nodes/ComfyUI-Easy-Use/py/libs/easing.py
Normal file
@@ -0,0 +1,27 @@
|
||||
@staticmethod
|
||||
def easyIn(t: float)-> float:
|
||||
return t*t
|
||||
@staticmethod
|
||||
def easyOut(t: float)-> float:
|
||||
return -(t * (t - 2))
|
||||
@staticmethod
|
||||
def easyInOut(t: float)-> float:
|
||||
if t < 0.5:
|
||||
return 2*t*t
|
||||
else:
|
||||
return (-2*t*t) + (4*t) - 1
|
||||
|
||||
class EasingBase:
    """Base class providing easing-curve dispatch and linear interpolation."""

    def easing(self, t: float, function='linear') -> float:
        """Evaluate the named easing curve at t; unknown names fall back to linear.

        Lambdas defer the name lookup of the curve functions to call time,
        matching the original's lazy resolution of easyIn/easyOut/easyInOut.
        """
        curves = {
            'easyIn': lambda v: easyIn(v),
            'easyOut': lambda v: easyOut(v),
            'easyInOut': lambda v: easyInOut(v),
        }
        curve = curves.get(function)
        if curve is None:
            return t
        return curve(t)

    def ease(self, start, end, t) -> float:
        """Linear interpolation from start (t=0) to end (t=1)."""
        # Operand order swapped vs. the original; IEEE addition is commutative,
        # so the result is bit-identical.
        return start * (1 - t) + end * t
|
||||
@@ -0,0 +1,273 @@
|
||||
import torch
|
||||
from torchvision.transforms.functional import gaussian_blur
|
||||
from comfy.k_diffusion.sampling import default_noise_sampler, get_ancestral_step, to_d, BrownianTreeNoiseSampler
|
||||
from tqdm.auto import trange
|
||||
|
||||
@torch.no_grad()
def sample_euler_ancestral(
    model,
    x,
    sigmas,
    extra_args=None,
    callback=None,
    disable=None,
    eta=1.0,
    s_noise=1.0,
    noise_sampler=None,
    upscale_ratio=2.0,
    start_step=5,
    end_step=15,
    upscale_n_step=3,
    unsharp_kernel_size=3,
    unsharp_sigma=0.5,
    unsharp_strength=0.0,
):
    """Ancestral sampling with Euler method steps.

    Variant of k-diffusion's sampler that progressively upscales the latent:
    between ``start_step`` and ``end_step`` the latent is bicubically resized
    every ``upscale_n_step`` steps, ramping up to ``upscale_ratio`` times the
    original size, with an optional unsharp mask applied after each resize.

    NOTE(review): a caller-supplied ``noise_sampler`` is replaced with
    ``default_noise_sampler(x)`` on every step (needed because the latent's
    spatial size can change after a resize) — confirm this is intended even
    when no resize happened.
    """
    extra_args = {} if extra_args is None else extra_args
    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])

    # make upscale info: which step indices trigger a resize ...
    upscale_steps = []
    step = start_step - 1
    while step < end_step - 1:
        upscale_steps.append(step)
        step += upscale_n_step
    height, width = x.shape[2:]
    # ... and the target (H, W) for each, ramping from a fraction of the
    # final ratio up to the full upscale_ratio on the last resize.
    upscale_shapes = [
        (int(height * (((upscale_ratio - 1) / i) + 1)), int(width * (((upscale_ratio - 1) / i) + 1)))
        for i in reversed(range(1, len(upscale_steps) + 1))
    ]
    upscale_info = {k: v for k, v in zip(upscale_steps, upscale_shapes)}

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
        d = to_d(x, sigmas[i], denoised)
        # Euler method
        dt = sigma_down - sigmas[i]
        x = x + d * dt
        if sigmas[i + 1] > 0:
            # Resize
            if i in upscale_info:
                x = torch.nn.functional.interpolate(x, size=upscale_info[i], mode="bicubic", align_corners=False)
                if unsharp_strength > 0:
                    # Unsharp mask to recover detail softened by bicubic resizing.
                    blurred = gaussian_blur(x, kernel_size=unsharp_kernel_size, sigma=unsharp_sigma)
                    x = x + unsharp_strength * (x - blurred)

            # Rebuild the noise sampler for the (possibly resized) latent.
            noise_sampler = default_noise_sampler(x)
            noise = noise_sampler(sigmas[i], sigmas[i + 1])
            x = x + noise * sigma_up * s_noise
    return x
|
||||
|
||||
|
||||
@torch.no_grad()
def sample_dpmpp_2s_ancestral(
    model,
    x,
    sigmas,
    extra_args=None,
    callback=None,
    disable=None,
    eta=1.0,
    s_noise=1.0,
    noise_sampler=None,
    upscale_ratio=2.0,
    start_step=5,
    end_step=15,
    upscale_n_step=3,
    unsharp_kernel_size=3,
    unsharp_sigma=0.5,
    unsharp_strength=0.0,
):
    """Ancestral sampling with DPM-Solver++(2S) second-order steps.

    Same progressive-upscale scheme as the Euler variant in this file: the
    latent is bicubically resized between ``start_step`` and ``end_step``
    (every ``upscale_n_step`` steps, up to ``upscale_ratio``x), with an
    optional unsharp mask after each resize.

    NOTE(review): the ``noise_sampler`` parameter is never used — a default
    sampler is rebuilt every step (required after a resize changes x's shape).
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    # Mappings between sigma-space and the solver's log-time t-space.
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()

    # make upscale info: resize step indices -> target (H, W)
    upscale_steps = []
    step = start_step - 1
    while step < end_step - 1:
        upscale_steps.append(step)
        step += upscale_n_step
    height, width = x.shape[2:]
    upscale_shapes = [
        (int(height * (((upscale_ratio - 1) / i) + 1)), int(width * (((upscale_ratio - 1) / i) + 1)))
        for i in reversed(range(1, len(upscale_steps) + 1))
    ]
    upscale_info = {k: v for k, v in zip(upscale_steps, upscale_shapes)}

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
        if sigma_down == 0:
            # Euler method (final step: second-order update is undefined at sigma 0)
            d = to_d(x, sigmas[i], denoised)
            dt = sigma_down - sigmas[i]
            x = x + d * dt
        else:
            # DPM-Solver++(2S): midpoint evaluation at s = t + h/2
            t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
            r = 1 / 2
            h = t_next - t
            s = t + r * h
            x_2 = (sigma_fn(s) / sigma_fn(t)) * x - (-h * r).expm1() * denoised
            denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_2
        # Noise addition
        if sigmas[i + 1] > 0:
            # Resize
            if i in upscale_info:
                x = torch.nn.functional.interpolate(x, size=upscale_info[i], mode="bicubic", align_corners=False)
                if unsharp_strength > 0:
                    # Unsharp mask to recover detail softened by the resize.
                    blurred = gaussian_blur(x, kernel_size=unsharp_kernel_size, sigma=unsharp_sigma)
                    x = x + unsharp_strength * (x - blurred)
            # Rebuild the noise sampler for the (possibly resized) latent.
            noise_sampler = default_noise_sampler(x)
            noise = noise_sampler(sigmas[i], sigmas[i + 1])
            x = x + noise * sigma_up * s_noise
    return x
|
||||
|
||||
|
||||
@torch.no_grad()
def sample_dpmpp_2m_sde(
    model,
    x,
    sigmas,
    extra_args=None,
    callback=None,
    disable=None,
    eta=1.0,
    s_noise=1.0,
    noise_sampler=None,
    solver_type="midpoint",
    upscale_ratio=2.0,
    start_step=5,
    end_step=15,
    upscale_n_step=3,
    unsharp_kernel_size=3,
    unsharp_sigma=0.5,
    unsharp_strength=0.0,
):
    """DPM-Solver++(2M) SDE with progressive latent upscaling.

    Between ``start_step`` and ``end_step`` the latent is bicubically resized
    every ``upscale_n_step`` steps (up to ``upscale_ratio``x), with an
    optional unsharp mask after each resize.

    Raises:
        ValueError: if ``solver_type`` is not 'heun' or 'midpoint'.
    """
    if solver_type not in {"heun", "midpoint"}:
        raise ValueError("solver_type must be 'heun' or 'midpoint'")

    # Fix: normalize extra_args BEFORE reading from it. The original called
    # extra_args.get("seed") first, which raised AttributeError whenever
    # extra_args was left at its default of None.
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    s_in = x.new_ones([x.shape[0]])

    # Second-order history (previous denoised prediction and step size).
    old_denoised = None
    h_last = None
    h = None

    # make upscale info: resize step indices -> target (H, W)
    upscale_steps = []
    step = start_step - 1
    while step < end_step - 1:
        upscale_steps.append(step)
        step += upscale_n_step
    height, width = x.shape[2:]
    upscale_shapes = [
        (int(height * (((upscale_ratio - 1) / i) + 1)), int(width * (((upscale_ratio - 1) / i) + 1)))
        for i in reversed(range(1, len(upscale_steps) + 1))
    ]
    upscale_info = {k: v for k, v in zip(upscale_steps, upscale_shapes)}

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
        if sigmas[i + 1] == 0:
            # Denoising step
            x = denoised
        else:
            # DPM-Solver++(2M) SDE
            t, s = -sigmas[i].log(), -sigmas[i + 1].log()
            h = s - t
            eta_h = eta * h

            x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised

            if old_denoised is not None:
                r = h_last / h
                if solver_type == "heun":
                    x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
                elif solver_type == "midpoint":
                    x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)

            if eta:
                # Resize
                if i in upscale_info:
                    x = torch.nn.functional.interpolate(x, size=upscale_info[i], mode="bicubic", align_corners=False)
                    if unsharp_strength > 0:
                        # Unsharp mask to recover detail softened by the resize.
                        blurred = gaussian_blur(x, kernel_size=unsharp_kernel_size, sigma=unsharp_sigma)
                        x = x + unsharp_strength * (x - blurred)
                    # The denoised prediction no longer matches the next step's
                    # size, so drop it rather than feed a mismatched history term.
                    denoised = None
                # Rebuild the Brownian-tree sampler for the (possibly resized) latent.
                noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True)
                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise

        old_denoised = denoised
        h_last = h
    return x
|
||||
|
||||
|
||||
@torch.no_grad()
def sample_lcm(
    model,
    x,
    sigmas,
    extra_args=None,
    callback=None,
    disable=None,
    noise_sampler=None,
    eta=None,
    s_noise=None,
    upscale_ratio=2.0,
    start_step=5,
    end_step=15,
    upscale_n_step=3,
    unsharp_kernel_size=3,
    unsharp_sigma=0.5,
    unsharp_strength=0.0,
):
    """LCM sampling with progressive latent upscaling.

    Each step replaces x with the model's denoised prediction and, while the
    next sigma is positive, re-noises it for the next step. ``eta`` and
    ``s_noise`` are accepted for sampler-API compatibility but unused here;
    a caller-supplied ``noise_sampler`` is likewise replaced each step (the
    latent's size can change after a resize).
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    # make upscale info: resize step indices -> target (H, W)
    upscale_steps = []
    step = start_step - 1
    while step < end_step - 1:
        upscale_steps.append(step)
        step += upscale_n_step
    height, width = x.shape[2:]
    upscale_shapes = [
        (int(height * (((upscale_ratio - 1) / i) + 1)), int(width * (((upscale_ratio - 1) / i) + 1)))
        for i in reversed(range(1, len(upscale_steps) + 1))
    ]
    upscale_info = {k: v for k, v in zip(upscale_steps, upscale_shapes)}

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})

        x = denoised
        if sigmas[i + 1] > 0:
            # Resize
            if i in upscale_info:
                x = torch.nn.functional.interpolate(x, size=upscale_info[i], mode="bicubic", align_corners=False)
                if unsharp_strength > 0:
                    # Unsharp mask to recover detail softened by the resize.
                    blurred = gaussian_blur(x, kernel_size=unsharp_kernel_size, sigma=unsharp_sigma)
                    x = x + unsharp_strength * (x - blurred)
            # Rebuild the noise sampler for the (possibly resized) latent.
            noise_sampler = default_noise_sampler(x)
            x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])

    return x
|
||||
227
custom_nodes/ComfyUI-Easy-Use/py/libs/image.py
Normal file
227
custom_nodes/ComfyUI-Easy-Use/py/libs/image.py
Normal file
@@ -0,0 +1,227 @@
|
||||
import os
|
||||
import base64
|
||||
import torch
|
||||
import numpy as np
|
||||
from enum import Enum
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from typing import List, Union
|
||||
|
||||
import folder_paths
|
||||
from .utils import install_package
|
||||
|
||||
# PIL to Tensor
|
||||
def pil2tensor(image):
    """Convert a PIL image (or array-like) to a float tensor in [0, 1] with a leading batch dim."""
    arr = np.array(image).astype(np.float32) / 255.0
    return torch.from_numpy(arr).unsqueeze(0)
|
||||
# Tensor to PIL
|
||||
def tensor2pil(image):
    """Convert a [0, 1] float tensor to an 8-bit PIL image (batch dim squeezed away)."""
    arr = np.clip(255. * image.cpu().numpy().squeeze(), 0, 255)
    return Image.fromarray(arr.astype(np.uint8))
|
||||
# np to Tensor
|
||||
def np2tensor(img_np: Union[np.ndarray, List[np.ndarray]]) -> torch.Tensor:
    """Convert a uint8 ndarray (or list of them) to a [0, 1] float tensor with a batch dim."""
    if isinstance(img_np, list):
        # Recurse on each element and stack along the batch dimension.
        parts = [np2tensor(img) for img in img_np]
        return torch.cat(parts, dim=0)
    return torch.from_numpy(img_np.astype(np.float32) / 255.0).unsqueeze(0)
|
||||
# Tensor to np
|
||||
def tensor2np(tensor: torch.Tensor) -> List[np.ndarray]:
    """Convert a [0, 1] float tensor to uint8 ndarray(s).

    A 3-D tensor yields a single array; a batch yields a list of arrays
    (matching the original's asymmetric return type).
    """
    def _convert(t):
        return np.clip(255.0 * t.cpu().numpy(), 0, 255).astype(np.uint8)

    if tensor.ndim == 3:  # Single image
        return _convert(tensor)
    return [_convert(t) for t in tensor]  # Batch of images
|
||||
|
||||
def pil2byte(pil_image, format='PNG'):
    """Serialize a PIL image into an in-memory byte buffer, rewound to the start."""
    buffer = BytesIO()
    pil_image.save(buffer, format=format)
    buffer.seek(0)
    return buffer
|
||||
|
||||
def image2base64(image_base64):
    """Decode a base64 string into a PIL Image.

    NOTE: despite the name, this converts base64 -> image; kept as-is for
    API compatibility with existing callers.
    """
    raw = base64.b64decode(image_base64)
    return Image.open(BytesIO(raw))
|
||||
|
||||
# Get new bounds
|
||||
def get_new_bounds(width, height, left, right, top, bottom):
    """Returns the new bounds for an image with inset crop data.

    left/top are kept as offsets from the origin; right/bottom are converted
    from inset amounts into absolute coordinates.
    """
    new_left = left
    new_right = width - right
    new_top = top
    new_bottom = height - bottom
    return (new_left, new_right, new_top, new_bottom)
|
||||
|
||||
def RGB2RGBA(image: Image, mask: Image) -> Image:
    """Attach *mask* (converted to single-channel L) as the alpha of *image*'s RGB channels."""
    channels = image.convert('RGB').split()
    alpha = mask.convert('L')
    return Image.merge('RGBA', (*channels, alpha))
|
||||
|
||||
def image2mask(image: Image) -> torch.Tensor:
    """Convert a PIL image into a single-channel mask tensor of shape (1, H, W).

    NOTE(review): ``split()[0]`` on an RGBA image is the *red* channel, not
    alpha, despite the variable name below — confirm whether red or alpha is
    the intended mask source before changing it.
    """
    _image = image.convert('RGBA')
    alpha = _image.split()[0]
    bg = Image.new("L", _image.size)
    # Rebuild an RGBA image whose alpha carries the extracted channel, then
    # read that alpha channel back out (index 3) as the mask values.
    _image = Image.merge('RGBA', (bg, bg, bg, alpha))
    ret_mask = torch.tensor([pil2tensor(_image)[0, :, :, 3].tolist()])
    return ret_mask
|
||||
|
||||
def mask2image(mask: torch.Tensor) -> Image:
    """Render a mask tensor as a white-on-black RGBA PIL image.

    NOTE(review): only the image built from the LAST mask in the batch is
    returned (the loop overwrites ``_image`` each iteration), and ``_image``
    would be unbound for an empty batch — confirm whether batch support was
    intended here.
    """
    masks = tensor2np(mask)
    for m in masks:
        _mask = Image.fromarray(m).convert("L")
        _image = Image.new("RGBA", _mask.size, color='white')
        # Use the mask as the compositing stencil: white where masked,
        # black elsewhere.
        _image = Image.composite(
            _image, Image.new("RGBA", _mask.size, color='black'), _mask)
    return _image
|
||||
|
||||
# 图像融合
|
||||
class blendImage:
    """Channel-wise blending of float image tensors in [0, 1] (b, h, w, c layout)."""

    def g(self, x):
        """Piecewise helper used by the soft-light blend (W3C soft-light D(x) curve)."""
        poly = ((16 * x - 12) * x + 4) * x
        return torch.where(x <= 0.25, poly, torch.sqrt(x))

    def blend_mode(self, img1, img2, mode):
        """Blend img2 onto img1 with the named mode.

        Raises:
            ValueError: for an unrecognized mode name.
        """
        ops = {
            "normal": lambda a, b: b,
            "multiply": lambda a, b: a * b,
            "screen": lambda a, b: 1 - (1 - a) * (1 - b),
            "overlay": lambda a, b: torch.where(a <= 0.5, 2 * a * b, 1 - 2 * (1 - a) * (1 - b)),
            "soft_light": lambda a, b: torch.where(b <= 0.5, a - (1 - 2 * b) * a * (1 - a),
                                                   a + (2 * b - 1) * (self.g(a) - a)),
            "difference": lambda a, b: a - b,
        }
        if mode not in ops:
            raise ValueError(f"Unsupported blend mode: {mode}")
        return ops[mode](img1, img2)

    def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str = 'normal'):
        """Blend image2 over image1, lerp by blend_factor, and clamp to [0, 1]."""
        image2 = image2.to(image1.device)
        if image1.shape != image2.shape:
            # Match image2's spatial size to image1 (channels-first for the
            # upscale helper, then back to channels-last).
            image2 = image2.permute(0, 3, 1, 2)
            image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic',
                                                crop='center')
            image2 = image2.permute(0, 2, 3, 1)

        mixed = self.blend_mode(image1, image2, blend_mode)
        mixed = image1 * (1 - blend_factor) + mixed * blend_factor
        return torch.clamp(mixed, 0, 1)
|
||||
|
||||
|
||||
def empty_image(width, height, batch_size=1, color=0):
    """Create a solid-color image tensor of shape (batch, height, width, 3).

    *color* is a packed 0xRRGGBB integer; each channel is normalized to [0, 1].
    """
    shape = [batch_size, height, width, 1]
    # Extract R, G, B bytes from the packed integer, high byte first.
    channels = [
        torch.full(shape, ((color >> shift) & 0xFF) / 0xFF)
        for shift in (16, 8, 0)
    ]
    return torch.cat(channels, dim=-1)
|
||||
|
||||
|
||||
class ResizeMode(Enum):
    """Resize strategies, mirroring the A1111/ControlNet resize-mode options."""
    RESIZE = "Just Resize"
    INNER_FIT = "Crop and Resize"
    OUTER_FIT = "Resize and Fill"

    def int_value(self):
        """Return the integer code used by downstream consumers (0/1/2)."""
        codes = {
            ResizeMode.RESIZE: 0,
            ResizeMode.INNER_FIT: 1,
            ResizeMode.OUTER_FIT: 2,
        }
        if self in codes:
            return codes[self]
        assert False, "NOTREACHED"
|
||||
|
||||
# credit by https://github.com/chflame163/ComfyUI_LayerStyle/blob/main/py/imagefunc.py#L591C1-L617C22
|
||||
def fit_resize_image(image: Image, target_width: int, target_height: int, fit: str, resize_sampler: str,
                     background_color: str = '#000000') -> Image:
    """Resize *image* to (target_width, target_height) with one of three fit modes.

    fit='letterbox' pads with *background_color*; fit='crop' center-crops to
    the target aspect before resizing; anything else stretches directly.

    NOTE(review): the ``image is not None`` check below is dead — the
    ``convert`` call above would already have raised for None, and the
    False branch would leave ``ret_image`` unbound.
    """
    image = image.convert('RGB')
    orig_width, orig_height = image.size
    if image is not None:
        if fit == 'letterbox':
            if orig_width / orig_height > target_width / target_height:  # source is wider: black bars top/bottom
                fit_width = target_width
                fit_height = int(target_width / orig_width * orig_height)
            else:  # source is narrower: black bars left/right
                fit_height = target_height
                fit_width = int(target_height / orig_height * orig_width)
            fit_image = image.resize((fit_width, fit_height), resize_sampler)
            ret_image = Image.new('RGB', size=(target_width, target_height), color=background_color)
            # Paste centered onto the background canvas.
            ret_image.paste(fit_image, box=((target_width - fit_width) // 2, (target_height - fit_height) // 2))
        elif fit == 'crop':
            if orig_width / orig_height > target_width / target_height:  # source is wider: crop left/right
                fit_width = int(orig_height * target_width / target_height)
                fit_image = image.crop(
                    ((orig_width - fit_width) // 2, 0, (orig_width - fit_width) // 2 + fit_width, orig_height))
            else:  # source is narrower: crop top/bottom
                fit_height = int(orig_width * target_height / target_width)
                fit_image = image.crop(
                    (0, (orig_height - fit_height) // 2, orig_width, (orig_height - fit_height) // 2 + fit_height))
            ret_image = fit_image.resize((target_width, target_height), resize_sampler)
        else:
            # Plain stretch to the target size, ignoring aspect ratio.
            ret_image = image.resize((target_width, target_height), resize_sampler)
    return ret_image
|
||||
|
||||
# CLIP反推
|
||||
import comfy.utils
|
||||
from torchvision import transforms
|
||||
Config, Interrogator = None, None
|
||||
class CI_Inference:
    """Lazy wrapper around clip-interrogator's Interrogator for image -> prompt.

    The interrogator model is loaded on first use and cached on the instance;
    it is reloaded only when the CLIP model name or the low-VRAM flag changes.
    """
    ci_model = None   # cached Interrogator instance (None until first load)
    cache_path: str   # download/cache directory under the ComfyUI models dir

    def __init__(self):
        self.ci_model = None
        self.low_vram = False
        self.cache_path = os.path.join(folder_paths.models_dir, "clip_interrogator")

    def _load_model(self, model_name, low_vram=False):
        """(Re)load the interrogator unless the cached one already matches."""
        if not (self.ci_model and model_name == self.ci_model.config.clip_model_name and self.low_vram == low_vram):
            self.low_vram = low_vram
            print(f"Load model: {model_name}")

            config = Config(
                device="cuda" if torch.cuda.is_available() else "cpu",
                download_cache=True,
                clip_model_name=model_name,
                clip_model_path=self.cache_path,
                cache_path=self.cache_path,
                caption_model_name='blip-large'
            )

            if low_vram:
                config.apply_low_vram_defaults()

            self.ci_model = Interrogator(config)

    def _interrogate(self, image, mode, caption=None):
        """Dispatch to the Interrogator method matching *mode*.

        Raises:
            Exception: if *mode* is not one of best/classic/fast/negative.
        """
        if mode == 'best':
            prompt = self.ci_model.interrogate(image, caption=caption)
        elif mode == 'classic':
            prompt = self.ci_model.interrogate_classic(image, caption=caption)
        elif mode == 'fast':
            prompt = self.ci_model.interrogate_fast(image, caption=caption)
        elif mode == 'negative':
            prompt = self.ci_model.interrogate_negative(image)
        else:
            raise Exception(f"Unknown mode {mode}")
        return prompt

    def image_to_prompt(self, image, mode, model_name='ViT-L-14/openai', low_vram=False):
        """Interrogate a batch of image tensors, returning one prompt string per image."""
        # Fix: the global declaration must come BEFORE any assignment to these
        # names in this scope; the original placed it after the in-function
        # import, which is a SyntaxError ("assigned before global declaration").
        global Config, Interrogator
        try:
            from clip_interrogator import Config, Interrogator
        except ImportError:
            # Install the dependency on demand, then retry the import.
            install_package("clip_interrogator", "0.6.0")
            from clip_interrogator import Config, Interrogator

        pbar = comfy.utils.ProgressBar(len(image))

        self._load_model(model_name, low_vram)
        prompt = []
        for i in range(len(image)):
            im = image[i]

            # Tensor -> PIL RGB, the format the interrogator expects.
            im = tensor2pil(im)
            im = im.convert('RGB')

            _prompt = self._interrogate(im, mode)
            pbar.update(1)
            prompt.append(_prompt)

        return prompt
|
||||
|
||||
# Module-level singleton shared by the node implementations in this package.
ci = CI_Inference()
|
||||
237
custom_nodes/ComfyUI-Easy-Use/py/libs/lllite.py
Normal file
237
custom_nodes/ComfyUI-Easy-Use/py/libs/lllite.py
Normal file
@@ -0,0 +1,237 @@
|
||||
import math
|
||||
import torch
|
||||
import comfy
|
||||
|
||||
|
||||
def extra_options_to_module_prefix(extra_options):
    """Build the LLLite module-name prefix for an attention call site.

    extra_options carries entries like
    {'block': ('input', 7), 'block_index': 8, 'transformer_index': 2, ...};
    the prefix encodes which UNet block ('input'/'middle'/'output', plus its
    number) and which transformer sub-block is currently executing.

    Raises:
        Exception: if the block kind is not input/middle/output.
    """
    block = extra_options["block"]
    block_index = extra_options["block_index"]
    kind = block[0]
    if kind == "input":
        return f"lllite_unet_input_blocks_{block[1]}_1_transformer_blocks_{block_index}"
    if kind == "middle":
        return f"lllite_unet_middle_block_1_transformer_blocks_{block_index}"
    if kind == "output":
        return f"lllite_unet_output_blocks_{block[1]}_1_transformer_blocks_{block_index}"
    raise Exception("invalid block name")
|
||||
|
||||
|
||||
def load_control_net_lllite_patch(path, cond_image, multiplier, num_steps, start_percent, end_percent):
    """Load a ControlNet-LLLite checkpoint and build an attention patch from it.

    Returns a callable patch object ``(q, k, v, extra_options) -> (q, k, v)``
    that adds per-module LLLite deltas to the attention projections, active
    between start_percent and end_percent of the sampling steps.
    """
    # calculate start and end step (percent values are 0-100, hence * 0.01)
    start_step = math.floor(num_steps * start_percent * 0.01) if start_percent > 0 else 0
    end_step = math.floor(num_steps * end_percent * 0.01) if end_percent > 0 else num_steps

    # load weights
    ctrl_sd = comfy.utils.load_torch_file(path, safe_load=True)

    # split each weights for each module: the state-dict key's first dotted
    # fragment names the LLLite module, the rest is the parameter name.
    module_weights = {}
    for key, value in ctrl_sd.items():
        fragments = key.split(".")
        module_name = fragments[0]
        weight_name = ".".join(fragments[1:])

        if module_name not in module_weights:
            module_weights[module_name] = {}
        module_weights[module_name][weight_name] = value

    # load each module
    modules = {}
    for module_name, weights in module_weights.items():
        # TODO: improve this automatic depth detection (translated from the
        # original Japanese note) — it infers depth from which conditioning
        # layers exist and their kernel shape.
        if "conditioning1.4.weight" in weights:
            depth = 3
        elif weights["conditioning1.2.weight"].shape[-1] == 4:
            depth = 2
        else:
            depth = 1

        module = LLLiteModule(
            name=module_name,
            is_conv2d=weights["down.0.weight"].ndim == 4,
            in_dim=weights["down.0.weight"].shape[1],
            depth=depth,
            cond_emb_dim=weights["conditioning1.0.weight"].shape[0] * 2,
            mlp_dim=weights["down.0.weight"].shape[0],
            multiplier=multiplier,
            num_steps=num_steps,
            start_step=start_step,
            end_step=end_step,
        )
        info = module.load_state_dict(weights)  # NOTE(review): load result is unused
        modules[module_name] = module
        if len(modules) == 1:
            # Mark the first loaded module so only it emits start/end log lines.
            module.is_first = True

    print(f"loaded {path} successfully, {len(modules)} modules")

    # set the cond image on every module
    cond_image = cond_image.permute(0, 3, 1, 2)  # b,h,w,3 -> b,3,h,w
    cond_image = cond_image * 2.0 - 1.0  # 0-1 -> -1-+1

    for module in modules.values():
        module.set_cond_image(cond_image)

    class control_net_lllite_patch:
        # Attention patch: adds the LLLite delta to q/k/v for whichever
        # modules exist for the current attention site.
        def __init__(self, modules):
            self.modules = modules

        def __call__(self, q, k, v, extra_options):
            module_pfx = extra_options_to_module_prefix(extra_options)

            is_attn1 = q.shape[-1] == k.shape[-1]  # self attention
            if is_attn1:
                module_pfx = module_pfx + "_attn1"
            else:
                module_pfx = module_pfx + "_attn2"

            module_pfx_to_q = module_pfx + "_to_q"
            module_pfx_to_k = module_pfx + "_to_k"
            module_pfx_to_v = module_pfx + "_to_v"

            if module_pfx_to_q in self.modules:
                q = q + self.modules[module_pfx_to_q](q)
            if module_pfx_to_k in self.modules:
                k = k + self.modules[module_pfx_to_k](k)
            if module_pfx_to_v in self.modules:
                v = v + self.modules[module_pfx_to_v](v)

            return q, k, v

        def to(self, device):
            # Move every wrapped module; return self so calls can be chained.
            for d in self.modules.keys():
                self.modules[d] = self.modules[d].to(device)
            return self

    return control_net_lllite_patch(modules)
|
||||
|
||||
class LLLiteModule(torch.nn.Module):
    """One ControlNet-LLLite adapter for a single attention projection.

    The conditioning image is encoded by a small conv stack (conditioning1),
    concatenated with a down-projection of the attention activation, mixed
    (mid) and up-projected to produce a delta the caller adds to q/k/v.
    When num_steps > 0 the module is only active between start_step and
    end_step, tracked by a per-module call counter.
    """
    def __init__(
        self,
        name: str,
        is_conv2d: bool,
        in_dim: int,
        depth: int,
        cond_emb_dim: int,
        mlp_dim: int,
        multiplier: int,
        num_steps: int,
        start_step: int,
        end_step: int,
    ):
        super().__init__()
        self.name = name
        self.is_conv2d = is_conv2d
        self.multiplier = multiplier
        self.num_steps = num_steps
        self.start_step = start_step
        self.end_step = end_step
        self.is_first = False  # set True on one module so start/end logs print once

        # Conditioning encoder: downsample the cond image to this block's
        # resolution; deeper blocks need stronger total downsampling.
        modules = []
        modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))  # to latent (from VAE) size*2
        if depth == 1:
            modules.append(torch.nn.ReLU(inplace=True))
            modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
        elif depth == 2:
            modules.append(torch.nn.ReLU(inplace=True))
            modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0))
        elif depth == 3:
            # kernel size 8 is too large, so set it to 4 (two stacked convs instead)
            modules.append(torch.nn.ReLU(inplace=True))
            modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))
            modules.append(torch.nn.ReLU(inplace=True))
            modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))

        self.conditioning1 = torch.nn.Sequential(*modules)

        # down/mid/up bottleneck that mixes the cond embedding into the
        # attention activation: 1x1 convs for conv2d sites, linears otherwise.
        if self.is_conv2d:
            self.down = torch.nn.Sequential(
                torch.nn.Conv2d(in_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
                torch.nn.ReLU(inplace=True),
            )
            self.mid = torch.nn.Sequential(
                torch.nn.Conv2d(mlp_dim + cond_emb_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
                torch.nn.ReLU(inplace=True),
            )
            self.up = torch.nn.Sequential(
                torch.nn.Conv2d(mlp_dim, in_dim, kernel_size=1, stride=1, padding=0),
            )
        else:
            self.down = torch.nn.Sequential(
                torch.nn.Linear(in_dim, mlp_dim),
                torch.nn.ReLU(inplace=True),
            )
            self.mid = torch.nn.Sequential(
                torch.nn.Linear(mlp_dim + cond_emb_dim, mlp_dim),
                torch.nn.ReLU(inplace=True),
            )
            self.up = torch.nn.Sequential(
                torch.nn.Linear(mlp_dim, in_dim),
            )

        self.depth = depth
        self.cond_image = None   # raw conditioning image, set via set_cond_image
        self.cond_emb = None     # lazily computed embedding of cond_image
        self.current_step = 0    # per-module step counter driven by forward()

    # @torch.inference_mode()
    def set_cond_image(self, cond_image):
        """Store a new conditioning image and reset the cached embedding and step counter."""
        # print("set_cond_image", self.name)
        self.cond_image = cond_image
        self.cond_emb = None
        self.current_step = 0

    def forward(self, x):
        """Return the LLLite delta for activation x (zeros outside the active step window)."""
        if self.num_steps > 0:
            # Step gating: count calls and emit zeros before start_step or at/after
            # end_step. The counter wraps back to 0 once num_steps is reached so
            # the schedule repeats on the next sampling run.
            if self.current_step < self.start_step:
                self.current_step += 1
                return torch.zeros_like(x)
            elif self.current_step >= self.end_step:
                if self.is_first and self.current_step == self.end_step:
                    print(f"end LLLite: step {self.current_step}")
                self.current_step += 1
                if self.current_step >= self.num_steps:
                    self.current_step = 0  # reset
                return torch.zeros_like(x)
            else:
                if self.is_first and self.current_step == self.start_step:
                    print(f"start LLLite: step {self.current_step}")
                self.current_step += 1
                if self.current_step >= self.num_steps:
                    self.current_step = 0  # reset

        if self.cond_emb is None:
            # Lazily encode the conditioning image on first use (on x's device/dtype).
            # print(f"cond_emb is None, {self.name}")
            cx = self.conditioning1(self.cond_image.to(x.device, dtype=x.dtype))
            if not self.is_conv2d:
                # reshape / b,c,h,w -> b,h*w,c (token layout for linear sites)
                n, c, h, w = cx.shape
                cx = cx.view(n, c, h * w).permute(0, 2, 1)
            self.cond_emb = cx

        cx = self.cond_emb
        # print(f"forward {self.name}, {cx.shape}, {x.shape}")

        # With uncond/cond batched together, x has a multiple of the cond
        # embedding's batch size; tile the embedding to match (translated
        # from the original Japanese note).
        if x.shape[0] != cx.shape[0]:
            if self.is_conv2d:
                cx = cx.repeat(x.shape[0] // cx.shape[0], 1, 1, 1)
            else:
                # print("x.shape[0] != cx.shape[0]", x.shape[0], cx.shape[0])
                cx = cx.repeat(x.shape[0] // cx.shape[0], 1, 1)

        # Concatenate along channels (conv2d) or feature dim (linear), then
        # run the bottleneck and scale by the configured multiplier.
        cx = torch.cat([cx, self.down(x)], dim=1 if self.is_conv2d else 2)
        cx = self.mid(cx)
        cx = self.up(cx)
        return cx * self.multiplier
|
||||
570
custom_nodes/ComfyUI-Easy-Use/py/libs/loader.py
Normal file
570
custom_nodes/ComfyUI-Easy-Use/py/libs/loader.py
Normal file
@@ -0,0 +1,570 @@
|
||||
import re, time, os, psutil
|
||||
import folder_paths
|
||||
import comfy.utils
|
||||
import comfy.sd
|
||||
import comfy.controlnet
|
||||
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from nodes import NODE_CLASS_MAPPINGS
|
||||
from collections import defaultdict
|
||||
from .log import log_node_info, log_node_error
|
||||
from ..modules.dit.pixArt.loader import load_pixart
|
||||
|
||||
diffusion_loaders = ["easy fullLoader", "easy a1111Loader", "easy fluxLoader", "easy comfyLoader", "easy hunyuanDiTLoader", "easy zero123Loader", "easy svdLoader"]
|
||||
stable_cascade_loaders = ["easy cascadeLoader"]
|
||||
dit_loaders = ['easy pixArtLoader']
|
||||
controlnet_loaders = ["easy controlnetLoader", "easy controlnetLoaderADV", "easy controlnetLoader++"]
|
||||
instant_loaders = ["easy instantIDApply", "easy instantIDApplyADV"]
|
||||
cascade_vae_node = ["easy preSamplingCascade", "easy fullCascadeKSampler"]
|
||||
model_merge_node = ["easy XYInputs: ModelMergeBlocks"]
|
||||
lora_widget = ["easy fullLoader", "easy a1111Loader", "easy comfyLoader", "easy fluxLoader"]
|
||||
|
||||
class easyLoader:
|
||||
def __init__(self):
    # Cache of loaded model objects, keyed by object kind then by name.
    self.loaded_objects = {
        "ckpt": defaultdict(tuple),  # {ckpt_name: (model, ...)}
        "unet": defaultdict(tuple),
        "clip": defaultdict(tuple),
        "clip_vision": defaultdict(tuple),
        "bvae": defaultdict(tuple),
        "vae": defaultdict(object),
        "lora": defaultdict(dict),  # {lora_name: {UID: (model_lora, clip_lora)}}
        "controlnet": defaultdict(dict),
        "t5": defaultdict(tuple),
        "chatglm3": defaultdict(tuple),
    }
    # Memory threshold used for cache eviction decisions.
    # NOTE(review): the meaning of the argument (1) is defined by
    # determine_memory_threshold, which is not visible here — confirm units.
    self.memory_threshold = self.determine_memory_threshold(1)
    self.lora_name_cache = []
|
||||
|
||||
def clean_values(self, values: str):
    """Split a '; '-separated string into a list, coercing numeric tokens.

    Each token is stripped of stray semicolons/whitespace; empty tokens are
    dropped; tokens parseable as int (then float) are converted, otherwise
    kept as strings.
    """
    cleaned_values = []
    for raw in values.split("; "):
        token = raw.strip(';').strip()
        if not token:
            continue
        # Try int first, then float; fall through to the raw string.
        for caster in (int, float):
            try:
                token = caster(token)
                break
            except ValueError:
                continue
        cleaned_values.append(token)
    return cleaned_values
|
||||
|
||||
def clear_unused_objects(self, desired_names: set, object_type: str):
|
||||
keys = set(self.loaded_objects[object_type].keys())
|
||||
for key in keys - desired_names:
|
||||
del self.loaded_objects[object_type][key]
|
||||
|
||||
def get_input_value(self, entry, key, prompt=None):
    """Resolve the value of *key* from a prompt-graph node entry.

    Strings are returned as-is; a list value is a node link, resolved via
    *prompt* when available (falling back to the linked node id); anything
    else is stringified.
    """
    val = entry["inputs"][key]
    if isinstance(val, str):
        return val
    if isinstance(val, list):
        if prompt is not None and val[0]:
            # Follow the link into the referenced node's inputs.
            return prompt[val[0]]['inputs'][key]
        return val[0]
    return str(val)
|
||||
|
||||
def process_pipe_loader(self, entry, desired_ckpt_names, desired_vae_names, desired_lora_names, desired_lora_settings, num_loras=3, suffix=""):
|
||||
for idx in range(1, num_loras + 1):
|
||||
lora_name_key = f"{suffix}lora{idx}_name"
|
||||
desired_lora_names.add(self.get_input_value(entry, lora_name_key))
|
||||
setting = f'{self.get_input_value(entry, lora_name_key)};{entry["inputs"][f"{suffix}lora{idx}_model_strength"]};{entry["inputs"][f"{suffix}lora{idx}_clip_strength"]}'
|
||||
desired_lora_settings.add(setting)
|
||||
|
||||
desired_ckpt_names.add(self.get_input_value(entry, f"{suffix}ckpt_name"))
|
||||
desired_vae_names.add(self.get_input_value(entry, f"{suffix}vae_name"))
|
||||
|
||||
def update_loaded_objects(self, prompt):
|
||||
desired_ckpt_names = set()
|
||||
desired_unet_names = set()
|
||||
desired_clip_names = set()
|
||||
desired_vae_names = set()
|
||||
desired_lora_names = set()
|
||||
desired_lora_settings = set()
|
||||
desired_controlnet_names = set()
|
||||
desired_t5_names = set()
|
||||
desired_glm3_names = set()
|
||||
|
||||
for entry in prompt.values():
|
||||
class_type = entry["class_type"]
|
||||
if class_type in lora_widget:
|
||||
lora_name = self.get_input_value(entry, "lora_name")
|
||||
desired_lora_names.add(lora_name)
|
||||
setting = f'{lora_name};{entry["inputs"]["lora_model_strength"]};{entry["inputs"]["lora_clip_strength"]}'
|
||||
desired_lora_settings.add(setting)
|
||||
|
||||
if class_type in diffusion_loaders:
|
||||
desired_ckpt_names.add(self.get_input_value(entry, "ckpt_name", prompt))
|
||||
desired_vae_names.add(self.get_input_value(entry, "vae_name"))
|
||||
|
||||
elif class_type in ['easy kolorsLoader']:
|
||||
desired_unet_names.add(self.get_input_value(entry, "unet_name"))
|
||||
desired_vae_names.add(self.get_input_value(entry, "vae_name"))
|
||||
desired_glm3_names.add(self.get_input_value(entry, "chatglm3_name"))
|
||||
|
||||
elif class_type in dit_loaders:
|
||||
t5_name = self.get_input_value(entry, "mt5_name") if "mt5_name" in entry["inputs"] else None
|
||||
clip_name = self.get_input_value(entry, "clip_name") if "clip_name" in entry["inputs"] else None
|
||||
model_name = self.get_input_value(entry, "model_name")
|
||||
ckpt_name = self.get_input_value(entry, "ckpt_name", prompt)
|
||||
if t5_name:
|
||||
desired_t5_names.add(t5_name)
|
||||
if clip_name:
|
||||
desired_clip_names.add(clip_name)
|
||||
desired_ckpt_names.add(ckpt_name+'_'+model_name)
|
||||
|
||||
elif class_type in stable_cascade_loaders:
|
||||
desired_unet_names.add(self.get_input_value(entry, "stage_c"))
|
||||
desired_unet_names.add(self.get_input_value(entry, "stage_b"))
|
||||
desired_clip_names.add(self.get_input_value(entry, "clip_name"))
|
||||
desired_vae_names.add(self.get_input_value(entry, "stage_a"))
|
||||
|
||||
elif class_type in cascade_vae_node:
|
||||
encode_vae_name = self.get_input_value(entry, "encode_vae_name")
|
||||
decode_vae_name = self.get_input_value(entry, "decode_vae_name")
|
||||
if encode_vae_name and encode_vae_name != 'None':
|
||||
desired_vae_names.add(encode_vae_name)
|
||||
if decode_vae_name and decode_vae_name != 'None':
|
||||
desired_vae_names.add(decode_vae_name)
|
||||
|
||||
elif class_type in controlnet_loaders:
|
||||
control_net_name = self.get_input_value(entry, "control_net_name", prompt)
|
||||
scale_soft_weights = self.get_input_value(entry, "scale_soft_weights")
|
||||
desired_controlnet_names.add(f'{control_net_name};{scale_soft_weights}')
|
||||
|
||||
elif class_type in instant_loaders:
|
||||
control_net_name = self.get_input_value(entry, "control_net_name", prompt)
|
||||
scale_soft_weights = self.get_input_value(entry, "cn_soft_weights")
|
||||
desired_controlnet_names.add(f'{control_net_name};{scale_soft_weights}')
|
||||
|
||||
elif class_type in model_merge_node:
|
||||
desired_ckpt_names.add(self.get_input_value(entry, "ckpt_name_1"))
|
||||
desired_ckpt_names.add(self.get_input_value(entry, "ckpt_name_2"))
|
||||
vae_use = self.get_input_value(entry, "vae_use")
|
||||
if vae_use != 'Use Model 1' and vae_use != 'Use Model 2':
|
||||
desired_vae_names.add(vae_use)
|
||||
|
||||
object_types = ["ckpt", "unet", "clip", "bvae", "vae", "lora", "controlnet", "t5"]
|
||||
for object_type in object_types:
|
||||
if object_type == 'unet':
|
||||
desired_names = desired_unet_names
|
||||
elif object_type in ["ckpt", "clip", "bvae"]:
|
||||
if object_type == 'clip':
|
||||
desired_names = desired_ckpt_names.union(desired_clip_names)
|
||||
else:
|
||||
desired_names = desired_ckpt_names
|
||||
elif object_type == "vae":
|
||||
desired_names = desired_vae_names
|
||||
elif object_type == "controlnet":
|
||||
desired_names = desired_controlnet_names
|
||||
elif object_type == "t5":
|
||||
desired_names = desired_t5_names
|
||||
elif object_type == "chatglm3":
|
||||
desired_names = desired_glm3_names
|
||||
else:
|
||||
desired_names = desired_lora_names
|
||||
self.clear_unused_objects(desired_names, object_type)
|
||||
|
||||
def add_to_cache(self, obj_type, key, value):
|
||||
"""
|
||||
Add an item to the cache with the current timestamp.
|
||||
"""
|
||||
timestamped_value = (value, time.time())
|
||||
self.loaded_objects[obj_type][key] = timestamped_value
|
||||
|
||||
def determine_memory_threshold(self, percentage=0.8):
|
||||
"""
|
||||
Determines the memory threshold as a percentage of the total available memory.
|
||||
Args:
|
||||
- percentage (float): The fraction of total memory to use as the threshold.
|
||||
Should be a value between 0 and 1. Default is 0.8 (80%).
|
||||
Returns:
|
||||
- memory_threshold (int): Memory threshold in bytes.
|
||||
"""
|
||||
total_memory = psutil.virtual_memory().total
|
||||
memory_threshold = total_memory * percentage
|
||||
return memory_threshold
|
||||
|
||||
def get_memory_usage(self):
|
||||
"""
|
||||
Returns the memory usage of the current process in bytes.
|
||||
"""
|
||||
process = psutil.Process(os.getpid())
|
||||
return process.memory_info().rss
|
||||
|
||||
def eviction_based_on_memory(self):
|
||||
"""
|
||||
Evicts objects from cache based on memory usage and priority.
|
||||
"""
|
||||
current_memory = self.get_memory_usage()
|
||||
if current_memory < self.memory_threshold:
|
||||
return
|
||||
eviction_order = ["vae", "lora", "bvae", "clip", "ckpt", "controlnet", "unet", "t5", "chatglm3"]
|
||||
for obj_type in eviction_order:
|
||||
if current_memory < self.memory_threshold:
|
||||
break
|
||||
# Sort items based on age (using the timestamp)
|
||||
items = list(self.loaded_objects[obj_type].items())
|
||||
items.sort(key=lambda x: x[1][1]) # Sorting by timestamp
|
||||
|
||||
for item in items:
|
||||
if current_memory < self.memory_threshold:
|
||||
break
|
||||
del self.loaded_objects[obj_type][item[0]]
|
||||
current_memory = self.get_memory_usage()
|
||||
|
||||
    def load_checkpoint(self, ckpt_name, config_name=None, load_vision=False):
        """Load a checkpoint (or fetch it from cache).

        Returns (model, clip, baked_vae, clip_vision).  Exactly one of
        clip / clip_vision is populated, chosen by `load_vision`.
        """
        # Checkpoints loaded with an explicit config get a combined cache key
        # so they don't collide with the default-config load of the same file.
        cache_name = ckpt_name
        if config_name not in [None, "Default"]:
            cache_name = ckpt_name + "_" + config_name
        if cache_name in self.loaded_objects["ckpt"]:
            clip_vision = self.loaded_objects["clip_vision"][cache_name][0] if load_vision else None
            clip = self.loaded_objects["clip"][cache_name][0] if not load_vision else None
            return self.loaded_objects["ckpt"][cache_name][0], clip, self.loaded_objects["bvae"][cache_name][0], clip_vision

        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)

        # Load either the text encoder or the vision encoder, never both.
        output_clip = False if load_vision else True
        output_clipvision = True if load_vision else False
        if config_name not in [None, "Default"]:
            config_path = folder_paths.get_full_path("configs", config_name)
            loaded_ckpt = comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
        else:
            model_options = {}
            # Checkpoints with "nf4" in the name need bitsandbytes NF4 operations.
            if re.search("nf4", ckpt_name):
                from ..modules.bitsandbytes_NF4 import OPS
                model_options = {"custom_operations": OPS}
            loaded_ckpt = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=output_clip, output_clipvision=output_clipvision, embedding_directory=folder_paths.get_folder_paths("embeddings"), model_options=model_options)

        # loaded_ckpt tuple layout: (model, clip, vae, clip_vision).
        self.add_to_cache("ckpt", cache_name, loaded_ckpt[0])
        self.add_to_cache("bvae", cache_name, loaded_ckpt[2])

        clip = loaded_ckpt[1]
        clip_vision = loaded_ckpt[3]
        if clip:
            self.add_to_cache("clip", cache_name, clip)
        if clip_vision:
            self.add_to_cache("clip_vision", cache_name, clip_vision)

        self.eviction_based_on_memory()

        return loaded_ckpt[0], clip, loaded_ckpt[2], clip_vision
|
||||
|
||||
def load_vae(self, vae_name):
|
||||
if vae_name in self.loaded_objects["vae"]:
|
||||
return self.loaded_objects["vae"][vae_name][0]
|
||||
|
||||
vae_path = folder_paths.get_full_path("vae", vae_name)
|
||||
sd = comfy.utils.load_torch_file(vae_path)
|
||||
loaded_vae = comfy.sd.VAE(sd=sd)
|
||||
self.add_to_cache("vae", vae_name, loaded_vae)
|
||||
self.eviction_based_on_memory()
|
||||
|
||||
return loaded_vae
|
||||
|
||||
def load_unet(self, unet_name):
|
||||
if unet_name in self.loaded_objects["unet"]:
|
||||
log_node_info("Load UNet", f"{unet_name} cached")
|
||||
return self.loaded_objects["unet"][unet_name][0]
|
||||
|
||||
unet_path = folder_paths.get_full_path("unet", unet_name)
|
||||
model = comfy.sd.load_unet(unet_path)
|
||||
self.add_to_cache("unet", unet_name, model)
|
||||
self.eviction_based_on_memory()
|
||||
|
||||
return model
|
||||
|
||||
def load_controlnet(self, control_net_name, scale_soft_weights=1, use_cache=True):
|
||||
unique_id = f'{control_net_name};{str(scale_soft_weights)}'
|
||||
if use_cache and unique_id in self.loaded_objects["controlnet"]:
|
||||
return self.loaded_objects["controlnet"][unique_id][0]
|
||||
if scale_soft_weights < 1:
|
||||
if "ScaledSoftControlNetWeights" in NODE_CLASS_MAPPINGS:
|
||||
soft_weight_cls = NODE_CLASS_MAPPINGS['ScaledSoftControlNetWeights']
|
||||
(weights, timestep_keyframe) = soft_weight_cls().load_weights(scale_soft_weights, False)
|
||||
cn_adv_cls = NODE_CLASS_MAPPINGS['ControlNetLoaderAdvanced']
|
||||
control_net, = cn_adv_cls().load_controlnet(control_net_name, timestep_keyframe)
|
||||
else:
|
||||
raise Exception(f"[Advanced-ControlNet Not Found] you need to install 'COMFYUI-Advanced-ControlNet'")
|
||||
else:
|
||||
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
|
||||
control_net = comfy.controlnet.load_controlnet(controlnet_path)
|
||||
if use_cache:
|
||||
self.add_to_cache("controlnet", unique_id, control_net)
|
||||
self.eviction_based_on_memory()
|
||||
|
||||
return control_net
|
||||
def load_clip(self, clip_name, type='stable_diffusion', load_clip=None):
|
||||
if clip_name in self.loaded_objects["clip"]:
|
||||
return self.loaded_objects["clip"][clip_name][0]
|
||||
|
||||
if type == 'stable_diffusion':
|
||||
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
|
||||
elif type == 'stable_cascade':
|
||||
clip_type = comfy.sd.CLIPType.STABLE_CASCADE
|
||||
elif type == 'sd3':
|
||||
clip_type = comfy.sd.CLIPType.SD3
|
||||
elif type == 'flux':
|
||||
clip_type = comfy.sd.CLIPType.FLUX
|
||||
elif type == 'stable_audio':
|
||||
clip_type = comfy.sd.CLIPType.STABLE_AUDIO
|
||||
clip_path = folder_paths.get_full_path("clip", clip_name)
|
||||
load_clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
|
||||
self.add_to_cache("clip", clip_name, load_clip)
|
||||
self.eviction_based_on_memory()
|
||||
|
||||
return load_clip
|
||||
|
||||
    def load_lora(self, lora, model=None, clip=None, type=None , use_cache=True):
        """Apply the LoRA described by the `lora` dict to (model, clip).

        `lora` keys: lora_name, model, clip, model_strength, clip_strength,
        and optionally lbw/lbw_a/lbw_b for lora-block-weight loading via the
        Inspire Pack.  Explicit `model`/`clip` arguments override the dict.
        Returns the patched (model, clip); results are cached per
        (model, clip, lora, strengths) combination when `use_cache` is True.
        """
        lora_name = lora["lora_name"]
        model = model if model is not None else lora["model"]
        clip = clip if clip is not None else lora["clip"]
        model_strength = lora["model_strength"]
        clip_strength = lora["clip_strength"]
        lbw = lora["lbw"] if "lbw" in lora else None
        lbw_a = lora["lbw_a"] if "lbw_a" in lora else None
        lbw_b = lora["lbw_b"] if "lbw_b" in lora else None

        # NOTE(review): repr() slices act as cheap identity "hashes" for the
        # patcher objects; fragile, but only used to build the cache key.
        model_hash = str(model)[44:-1]
        clip_hash = str(clip)[25:-1] if clip else ''

        unique_id = f'{model_hash};{clip_hash};{lora_name};{model_strength};{clip_strength}'

        if use_cache and unique_id in self.loaded_objects["lora"]:
            log_node_info("Load LORA",f"{lora_name} cached")
            return self.loaded_objects["lora"][unique_id][0]

        # Resolve partial names (e.g. without subdirectory) to real files.
        orig_lora_name = lora_name
        lora_name = self.resolve_lora_name(lora_name)

        if lora_name is not None:
            lora_path = folder_paths.get_full_path("loras", lora_name)
        else:
            lora_path = None

        if lora_path is not None:
            log_node_info("Load LORA",f"{lora_name}: model={model_strength:.3f}, clip={clip_strength:.3f}, LBW={lbw}, A={lbw_a}, B={lbw_b}")
            if lbw:
                # Lora Block Weight path: delegate to the Inspire Pack node.
                lbw = lora["lbw"]
                lbw_a = lora["lbw_a"]
                lbw_b = lora["lbw_b"]
                if 'LoraLoaderBlockWeight //Inspire' not in NODE_CLASS_MAPPINGS:
                    raise Exception('[InspirePack Not Found] you need to install ComfyUI-Inspire-Pack')
                cls = NODE_CLASS_MAPPINGS['LoraLoaderBlockWeight //Inspire']
                model, clip, _ = cls().doit(model, clip, lora_name, model_strength, clip_strength, False, 0,
                                            lbw_a, lbw_b, "", lbw)
            else:
                _lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
                keys = _lora.keys()
                if "down_blocks.0.resnets.0.norm1.bias" in keys:
                    # ResAdapter lora: its norm weights are written straight
                    # into the model state dict instead of patched as a LoRA.
                    # Note this path returns early and bypasses the cache.
                    print('Using LORA for Resadapter')
                    key_map = {}
                    key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
                    mapping_norm = {}

                    for key in keys:
                        if ".weight" in key:
                            key_name_in_ori_sd = key_map[key.replace(".weight", "")]
                            mapping_norm[key_name_in_ori_sd] = _lora[key]
                        elif ".bias" in key:
                            key_name_in_ori_sd = key_map[key.replace(".bias", "")]
                            mapping_norm[key_name_in_ori_sd.replace(".weight", ".bias")] = _lora[
                                key
                            ]
                        else:
                            print("===>Unexpected key", key)
                            mapping_norm[key] = _lora[key]

                    for k in mapping_norm.keys():
                        if k not in model.model.state_dict():
                            print("===>Missing key:", k)
                    model.model.load_state_dict(mapping_norm, strict=False)
                    return (model, clip)

                # PixArt
                if type is not None and type == 'PixArt':
                    from ..modules.dit.pixArt.loader import load_pixart_lora
                    model = load_pixart_lora(model, _lora, lora_path, model_strength)
                else:
                    model, clip = comfy.sd.load_lora_for_models(model, clip, _lora, model_strength, clip_strength)

            if use_cache:
                self.add_to_cache("lora", unique_id, (model, clip))
                self.eviction_based_on_memory()
        else:
            # Missing file: log and return the inputs unpatched.
            log_node_error(f"LORA NOT FOUND", orig_lora_name)

        return model, clip
|
||||
|
||||
def resolve_lora_name(self, name):
|
||||
if os.path.exists(name):
|
||||
return name
|
||||
else:
|
||||
if len(self.lora_name_cache) == 0:
|
||||
loras = folder_paths.get_filename_list("loras")
|
||||
self.lora_name_cache.extend(loras)
|
||||
for x in self.lora_name_cache:
|
||||
if x.endswith(name):
|
||||
return x
|
||||
|
||||
# 如果刷新网页后新添加的lora走这个逻辑
|
||||
log_node_info("LORA NOT IN CACHE", f"{name}")
|
||||
loras = folder_paths.get_filename_list("loras")
|
||||
for x in loras:
|
||||
if x.endswith(name):
|
||||
self.lora_name_cache.append(x)
|
||||
return x
|
||||
|
||||
return None
|
||||
|
||||
def load_main(self, ckpt_name, config_name, vae_name, lora_name, lora_model_strength, lora_clip_strength, optional_lora_stack, model_override, clip_override, vae_override, prompt, nf4=False):
|
||||
model: ModelPatcher | None = None
|
||||
clip: comfy.sd.CLIP | None = None
|
||||
vae: comfy.sd.VAE | None = None
|
||||
clip_vision = None
|
||||
lora_stack = []
|
||||
|
||||
# Check for model override
|
||||
can_load_lora = True
|
||||
# 判断是否存在 模型或Lora叠加xyplot, 若存在优先缓存第一个模型
|
||||
# Determine whether there is a model or Lora overlapping xyplot, and if there is, prioritize caching the first model.
|
||||
xy_model_id = next((x for x in prompt if str(prompt[x]["class_type"]) in ["easy XYInputs: ModelMergeBlocks",
|
||||
"easy XYInputs: Checkpoint"]), None)
|
||||
# This will find nodes that aren't actively connected to anything, and skip loading lora's for them.
|
||||
xy_lora_id = next((x for x in prompt if str(prompt[x]["class_type"]) == "easy XYInputs: Lora"), None)
|
||||
if xy_lora_id is not None:
|
||||
can_load_lora = False
|
||||
if xy_model_id is not None:
|
||||
node = prompt[xy_model_id]
|
||||
if "ckpt_name_1" in node["inputs"]:
|
||||
ckpt_name_1 = node["inputs"]["ckpt_name_1"]
|
||||
model, clip, vae, clip_vision = self.load_checkpoint(ckpt_name_1)
|
||||
can_load_lora = False
|
||||
elif model_override is not None and clip_override is not None and vae_override is not None:
|
||||
model = model_override
|
||||
clip = clip_override
|
||||
vae = vae_override
|
||||
else:
|
||||
model, clip, vae, clip_vision = self.load_checkpoint(ckpt_name, config_name)
|
||||
if model_override is not None:
|
||||
model = model_override
|
||||
if vae_override is not None:
|
||||
vae = vae_override
|
||||
elif clip_override is not None:
|
||||
clip = clip_override
|
||||
|
||||
|
||||
if optional_lora_stack is not None and can_load_lora:
|
||||
for lora in optional_lora_stack:
|
||||
# This is a subtle bit of code because it uses the model created by the last call, and passes it to the next call.
|
||||
lora = {"lora_name": lora[0], "model": model, "clip": clip, "model_strength": lora[1],
|
||||
"clip_strength": lora[2]}
|
||||
model, clip = self.load_lora(lora)
|
||||
lora['model'] = model
|
||||
lora['clip'] = clip
|
||||
lora_stack.append(lora)
|
||||
|
||||
if lora_name != "None" and can_load_lora:
|
||||
lora = {"lora_name": lora_name, "model": model, "clip": clip, "model_strength": lora_model_strength,
|
||||
"clip_strength": lora_clip_strength}
|
||||
model, clip = self.load_lora(lora)
|
||||
lora_stack.append(lora)
|
||||
|
||||
# Check for custom VAE
|
||||
if vae_name not in ["Baked VAE", "Baked-VAE"]:
|
||||
vae = self.load_vae(vae_name)
|
||||
# CLIP skip
|
||||
if not clip:
|
||||
raise Exception("No CLIP found")
|
||||
|
||||
return model, clip, vae, clip_vision, lora_stack
|
||||
|
||||
# Kolors
|
||||
    def load_kolors_unet(self, unet_name):
        """Load a Kolors UNet by file name, using the shared "unet" cache.

        The state dict is loaded inside the applyKolorsUnet() patch context
        so comfy detects the Kolors model type correctly.
        """
        if unet_name in self.loaded_objects["unet"]:
            log_node_info("Load Kolors UNet", f"{unet_name} cached")
            return self.loaded_objects["unet"][unet_name][0]
        else:
            from ..modules.kolors.loader import applyKolorsUnet
            with applyKolorsUnet():
                unet_path = folder_paths.get_full_path("unet", unet_name)
                sd = comfy.utils.load_torch_file(unet_path)
                model = comfy.sd.load_unet_state_dict(sd)
                if model is None:
                    raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))

            self.add_to_cache("unet", unet_name, model)
            self.eviction_based_on_memory()

            return model
|
||||
|
||||
def load_chatglm3(self, chatglm3_name):
|
||||
from ..modules.kolors.loader import load_chatglm3
|
||||
if chatglm3_name in self.loaded_objects["chatglm3"]:
|
||||
log_node_info("Load ChatGLM3", f"{chatglm3_name} cached")
|
||||
return self.loaded_objects["chatglm3"][chatglm3_name][0]
|
||||
|
||||
chatglm_model = load_chatglm3(model_path=folder_paths.get_full_path("llm", chatglm3_name))
|
||||
self.add_to_cache("chatglm3", chatglm3_name, chatglm_model)
|
||||
self.eviction_based_on_memory()
|
||||
|
||||
return chatglm_model
|
||||
|
||||
|
||||
# DiT
|
||||
def load_dit_ckpt(self, ckpt_name, model_name, **kwargs):
|
||||
if (ckpt_name+'_'+model_name) in self.loaded_objects["ckpt"]:
|
||||
return self.loaded_objects["ckpt"][ckpt_name+'_'+model_name][0]
|
||||
model = None
|
||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
||||
model_type = kwargs['model_type'] if "model_type" in kwargs else 'PixArt'
|
||||
if model_type == 'PixArt':
|
||||
pixart_conf = kwargs['pixart_conf']
|
||||
model_conf = pixart_conf[model_name]
|
||||
model = load_pixart(ckpt_path, model_conf)
|
||||
if model:
|
||||
self.add_to_cache("ckpt", ckpt_name + '_' + model_name, model)
|
||||
self.eviction_based_on_memory()
|
||||
return model
|
||||
|
||||
def load_t5_from_sd3_clip(self, sd3_clip, padding):
|
||||
try:
|
||||
from comfy.text_encoders.sd3_clip import SD3Tokenizer, SD3ClipModel
|
||||
except:
|
||||
from comfy.sd3_clip import SD3Tokenizer, SD3ClipModel
|
||||
import copy
|
||||
|
||||
clip = sd3_clip.clone()
|
||||
assert clip.cond_stage_model.t5xxl is not None, "CLIP must have T5 loaded!"
|
||||
|
||||
# remove transformer
|
||||
transformer = clip.cond_stage_model.t5xxl.transformer
|
||||
clip.cond_stage_model.t5xxl.transformer = None
|
||||
|
||||
# clone object
|
||||
tmp = SD3ClipModel(clip_l=False, clip_g=False, t5=False)
|
||||
tmp.t5xxl = copy.deepcopy(clip.cond_stage_model.t5xxl)
|
||||
# put transformer back
|
||||
clip.cond_stage_model.t5xxl.transformer = transformer
|
||||
tmp.t5xxl.transformer = transformer
|
||||
|
||||
# override special tokens
|
||||
tmp.t5xxl.special_tokens = copy.deepcopy(clip.cond_stage_model.t5xxl.special_tokens)
|
||||
tmp.t5xxl.special_tokens.pop("end") # make sure empty tokens match
|
||||
|
||||
# tokenizer
|
||||
tok = SD3Tokenizer()
|
||||
tok.t5xxl.min_length = padding
|
||||
|
||||
clip.cond_stage_model = tmp
|
||||
clip.tokenizer = tok
|
||||
|
||||
return clip
|
||||
77
custom_nodes/ComfyUI-Easy-Use/py/libs/log.py
Normal file
77
custom_nodes/ComfyUI-Easy-Use/py/libs/log.py
Normal file
@@ -0,0 +1,77 @@
|
||||
# ANSI escape sequences used by the log helpers below.
# Foreground colors.
COLORS_FG = {
    'BLACK': '\33[30m',
    'RED': '\33[31m',
    'GREEN': '\33[32m',
    'YELLOW': '\33[33m',
    'BLUE': '\33[34m',
    'MAGENTA': '\33[35m',
    'CYAN': '\33[36m',
    'WHITE': '\33[37m',
    'GREY': '\33[90m',
    'BRIGHT_RED': '\33[91m',
    'BRIGHT_GREEN': '\33[92m',
    'BRIGHT_YELLOW': '\33[93m',
    'BRIGHT_BLUE': '\33[94m',
    'BRIGHT_MAGENTA': '\33[95m',
    'BRIGHT_CYAN': '\33[96m',
    'BRIGHT_WHITE': '\33[97m',
}
# Text style modifiers (RESET restores the terminal default).
COLORS_STYLE = {
    'RESET': '\33[0m',
    'BOLD': '\33[1m',
    'NORMAL': '\33[22m',
    'ITALIC': '\33[3m',
    'UNDERLINE': '\33[4m',
    'BLINK': '\33[5m',
    'BLINK2': '\33[6m',
    'SELECTED': '\33[7m',
}
# Background colors.
COLORS_BG = {
    'BLACK': '\33[40m',
    'RED': '\33[41m',
    'GREEN': '\33[42m',
    'YELLOW': '\33[43m',
    'BLUE': '\33[44m',
    'MAGENTA': '\33[45m',
    'CYAN': '\33[46m',
    'WHITE': '\33[47m',
    'GREY': '\33[100m',
    'BRIGHT_RED': '\33[101m',
    'BRIGHT_GREEN': '\33[102m',
    'BRIGHT_YELLOW': '\33[103m',
    'BRIGHT_BLUE': '\33[104m',
    'BRIGHT_MAGENTA': '\33[105m',
    'BRIGHT_CYAN': '\33[106m',
    'BRIGHT_WHITE': '\33[107m',
}
|
||||
|
||||
def log_node_success(node_name, message=None):
    """Logs a success message (green)."""
    _log_node(COLORS_FG["GREEN"], node_name, message)


def log_node_info(node_name, message=None):
    """Logs an info message (cyan)."""
    _log_node(COLORS_FG["CYAN"], node_name, message)


def log_node_warn(node_name, message=None):
    """Logs a warning message (yellow)."""
    _log_node(COLORS_FG["YELLOW"], node_name, message)


def log_node_error(node_name, message=None):
    """Logs an error message (red)."""
    # Fix: docstring previously said "warn", copied from log_node_warn.
    _log_node(COLORS_FG["RED"], node_name, message)


def log_node(node_name, message=None):
    """Logs a plain message (cyan)."""
    _log_node(COLORS_FG["CYAN"], node_name, message)


def _log_node(color, node_name, message=None, prefix=''):
    """Print a formatted, colored log line for `node_name`."""
    print(_get_log_msg(color, node_name, message, prefix=prefix))


def _get_log_msg(color, node_name, message=None, prefix=''):
    """Build the colored '[EasyUse] node: message' string."""
    msg = f'{COLORS_STYLE["BOLD"]}{color}{prefix}[EasyUse] {node_name.replace(" (EasyUse)", "")}'
    msg += f':{COLORS_STYLE["RESET"]} {message}' if message is not None else f'{COLORS_STYLE["RESET"]}'
    return msg
|
||||
|
||||
133
custom_nodes/ComfyUI-Easy-Use/py/libs/math.py
Normal file
133
custom_nodes/ComfyUI-Easy-Use/py/libs/math.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
Math utility functions for formula evaluation
|
||||
"""
|
||||
import math
|
||||
import re
|
||||
|
||||
def evaluate_formula(formula: str, a=0, b=0, c=0, d=0) -> float:
    """Evaluate a math formula given as a string.

    Supported:
      - arithmetic: +, -, *, /, //, %, **
      - comparisons: >, <, >=, <=, ==, != (booleans coerce to 0.0/1.0)
      - functions: abs, pow, round, ceil, floor, sqrt, exp, log, log10,
        sin, cos, tan, asin, acos, atan
      - constants: pi, e
      - variables: a, b, c, d

    Args:
        formula: the expression string, may reference a, b, c, d.
        a, b, c, d: variable values (coerced to float).

    Returns:
        The result as a float.

    Raises:
        ValueError: when the formula fails to evaluate.

    Examples:
        >>> evaluate_formula("a + b", 1, 2)
        3.0
        >>> evaluate_formula("pow(a, 2)", 5)
        25.0
        >>> evaluate_formula("ceil(a / b)", 5, 2)
        3.0
        >>> evaluate_formula("(a>b)*b+(a<=b)*a", 5, 3)
        3.0
        >>> evaluate_formula("(a>b)*b+(a<=b)*a", 2, 3)
        2.0
    """
    # Whitelist of names visible to the formula.
    allowed = {name: getattr(math, name) for name in (
        'ceil', 'floor', 'sqrt', 'exp', 'log', 'log10',
        'sin', 'cos', 'tan', 'asin', 'acos', 'atan', 'pi', 'e')}
    allowed.update({
        'abs': abs,
        'pow': pow,
        'round': round,
        'a': float(a),
        'b': float(b),
        'c': float(c),
        'd': float(d),
    })

    try:
        # SECURITY NOTE: eval with an empty __builtins__ is NOT a real
        # sandbox — dunder attribute access can still escape.  Do not feed
        # untrusted input to this function.
        result = eval(formula, {"__builtins__": {}}, allowed)
        return float(result)
    except Exception as e:
        raise ValueError(f"公式计算错误: {str(e)}")
|
||||
|
||||
|
||||
def ceil_value(value: float) -> int:
    """Round up to the nearest integer."""
    return math.ceil(value)


def floor_value(value: float) -> int:
    """Round down to the nearest integer."""
    return math.floor(value)


def round_value(value: float, decimals: int = 0) -> float:
    """Round `value` to `decimals` decimal places (banker's rounding)."""
    return round(value, decimals)


def power(base: float, exponent: float) -> float:
    """Return `base` raised to `exponent`."""
    return math.pow(base, exponent)


def sqrt_value(value: float) -> float:
    """Return the square root; rejects negative input."""
    if value < 0:
        raise ValueError("不能对负数求平方根")
    return math.sqrt(value)


def add(a: float, b: float) -> float:
    """Return a + b."""
    return a + b


def subtract(a: float, b: float) -> float:
    """Return a - b."""
    return a - b


def multiply(a: float, b: float) -> float:
    """Return a * b."""
    return a * b


def divide(a: float, b: float) -> float:
    """Return a / b; rejects division by zero."""
    if b == 0:
        raise ValueError("除数不能为零")
    return a / b
|
||||
55
custom_nodes/ComfyUI-Easy-Use/py/libs/messages.py
Normal file
55
custom_nodes/ComfyUI-Easy-Use/py/libs/messages.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from server import PromptServer
|
||||
from aiohttp import web
|
||||
import time
|
||||
import json
|
||||
|
||||
class MessageCancelled(Exception):
    """Raised when a pending waitForMessage is cancelled."""
    pass


class Message:
    """In-memory mailbox for messages posted by the frontend to blocked nodes."""
    stash = {}       # scratch storage, reset on '__start__'
    messages = {}    # node id (str) -> last message payload
    cancelled = False

    @classmethod
    def addMessage(cls, id, message):
        """Store a message for `id`; '__cancel__' / '__start__' are control words."""
        if message == '__cancel__':
            cls.messages = {}
            cls.cancelled = True
        elif message == '__start__':
            cls.messages = {}
            cls.stash = {}
            cls.cancelled = False
        else:
            cls.messages[str(id)] = message

    @classmethod
    def waitForMessage(cls, id, period=0.1, asList=False):
        """Block until a message for `id` (or the broadcast id "-1") arrives.

        Polls every `period` seconds and raises MessageCancelled when a
        cancel arrives.  JSON payloads are decoded; with asList=True the
        payload is split on commas instead.
        """
        sid = str(id)
        while not (sid in cls.messages) and not ("-1" in cls.messages):
            if cls.cancelled:
                cls.cancelled = False
                raise MessageCancelled()
            time.sleep(period)
        if cls.cancelled:
            cls.cancelled = False
            raise MessageCancelled()
        message = cls.messages.pop(str(id), None) or cls.messages.pop("-1")
        try:
            if asList:
                return [str(x.strip()) for x in message.split(",")]
            else:
                try:
                    return json.loads(message)
                except ValueError:
                    return message
        except ValueError:
            # Fix: the old message used JS-style ${...} placeholders inside a
            # Python f-string, so the values were never interpolated.
            print(f"ERROR IN MESSAGE - failed to parse '{message}' as {'comma separated list of strings' if asList else 'string'}")
            return [message] if asList else message
|
||||
|
||||
|
||||
# HTTP endpoint the frontend POSTs to in order to deliver a message to a
# node currently blocked inside Message.waitForMessage.
@PromptServer.instance.routes.post('/easyuse/message_callback')
async def message_callback(request):
    # Form fields: "id" (target node id) and "message" (payload / control word).
    post = await request.post()
    Message.addMessage(post.get("id"), post.get("message"))
    return web.json_response({})
|
||||
58
custom_nodes/ComfyUI-Easy-Use/py/libs/model.py
Normal file
58
custom_nodes/ComfyUI-Easy-Use/py/libs/model.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import json
|
||||
import os
|
||||
import folder_paths
|
||||
import server
|
||||
from .utils import find_tags
|
||||
|
||||
class easyModelManager:
    """Collects metadata (thumbnail, tags, path) for models on disk."""

    def __init__(self):
        # Fix: removed the duplicated ".tiff" entry from the image suffix list.
        self.img_suffixes = [".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp", ".tiff", ".svg", ".tif"]
        self.default_suffixes = [".ckpt", ".pt", ".bin", ".pth", ".safetensors"]
        # Model folders this manager knows how to scan, with accepted suffixes.
        self.models_config = {
            "checkpoints": {"suffix": self.default_suffixes},
            "loras": {"suffix": self.default_suffixes},
            "unet": {"suffix": self.default_suffixes},
        }
        self.model_lists = {}

    def find_thumbnail(self, model_type, name):
        """Return the path of an image sharing the model's base name, or None."""
        file_no_ext = os.path.splitext(name)[0]
        for ext in self.img_suffixes:
            full_path = folder_paths.get_full_path(model_type, file_no_ext + ext)
            # get_full_path may return None for missing files; guard before isfile.
            if full_path is not None and os.path.isfile(str(full_path)):
                return full_path
        return None

    def get_model_lists(self, model_type):
        """Return a list of metadata dicts for every model of `model_type`."""
        if model_type not in self.models_config:
            return []
        accepted = self.models_config[model_type]["suffix"]
        model_lists = []
        for name in folder_paths.get_filename_list(model_type):
            model_suffix = os.path.splitext(name)[-1]
            if model_suffix not in accepted:
                continue
            model_lists.append({
                "name": os.path.basename(os.path.splitext(name)[0]),
                "full_name": name,
                "remark": '',
                "file_path": folder_paths.get_full_path(model_type, name),
                "type": model_type,
                "suffix": model_suffix,
                "dir_tags": find_tags(name),
                "cover": self.find_thumbnail(model_type, name),
                "metadata": None,
                "sha256": None
            })

        return model_lists

    def get_model_info(self, model_type, model_name):
        """Not implemented yet."""
        pass
|
||||
1053
custom_nodes/ComfyUI-Easy-Use/py/libs/sampler.py
Normal file
1053
custom_nodes/ComfyUI-Easy-Use/py/libs/sampler.py
Normal file
File diff suppressed because it is too large
Load Diff
148
custom_nodes/ComfyUI-Easy-Use/py/libs/styleAlign.py
Normal file
148
custom_nodes/ComfyUI-Easy-Use/py/libs/styleAlign.py
Normal file
@@ -0,0 +1,148 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from typing import Union
|
||||
|
||||
T = torch.Tensor
|
||||
|
||||
|
||||
def exists(val):
    """True unless *val* is None."""
    return not (val is None)
|
||||
|
||||
|
||||
def default(val, d):
    """Return *val* when it is not None, otherwise the fallback *d*."""
    return val if val is not None else d
|
||||
|
||||
|
||||
class StyleAlignedArgs:
    """Holds which attention projections (q/k/v) get AdaIN-normalised.

    ``share_attn`` is a spec string such as "q+k" or "q+k+v"; each
    letter present enables AdaIN for that projection.
    """

    # Class-level defaults; q/k/v flags are overridden per-instance.
    share_attention: bool = True
    adain_queries: bool = True
    adain_keys: bool = True
    adain_values: bool = True

    def __init__(self, share_attn: str) -> None:
        self.adain_queries = "q" in share_attn
        self.adain_keys = "k" in share_attn
        self.adain_values = "v" in share_attn
|
||||
|
||||
|
||||
def expand_first(
    feat: T,
    scale=1.0,
) -> T:
    """Tile the reference (first) item of each half-batch over its half.

    The batch is treated as two halves; element 0 and element b//2 are
    the style references. Each is repeated across its half so the result
    matches ``feat``'s shape. With ``scale`` != 1 every tiled copy except
    the leading one is multiplied by ``scale``.
    """
    batch = feat.shape[0]
    half = batch // 2
    reference = torch.stack((feat[0], feat[half])).unsqueeze(1)
    if scale == 1:
        tiled = reference.expand(2, half, *feat.shape[1:])
    else:
        tiled = reference.repeat(1, half, 1, 1, 1)
        tiled = torch.cat([tiled[:, :1], scale * tiled[:, 1:]], dim=1)
    return tiled.reshape(*feat.shape)
|
||||
|
||||
|
||||
def concat_first(feat: T, dim=2, scale=1.0) -> T:
    """Append the expanded style reference to *feat* along *dim*."""
    style = expand_first(feat, scale=scale)
    return torch.cat((feat, style), dim=dim)
|
||||
|
||||
|
||||
def calc_mean_std(feat, eps: float = 1e-5) -> "tuple[T, T]":
    """Mean and eps-stabilised std over the token axis (dim=-2)."""
    mean = feat.mean(dim=-2, keepdims=True)
    # eps keeps the sqrt finite for constant features
    std = (feat.var(dim=-2, keepdims=True) + eps).sqrt()
    return mean, std
|
||||
|
||||
def adain(feat: T) -> T:
    """AdaIN: renormalise *feat* onto the style reference's statistics."""
    mean, std = calc_mean_std(feat)
    style_mean = expand_first(mean)
    style_std = expand_first(std)
    normalized = (feat - mean) / std
    return normalized * style_std + style_mean
|
||||
|
||||
class SharedAttentionProcessor:
    """attn1 patch applying AdaIN and key/value sharing per the args."""

    def __init__(self, args: StyleAlignedArgs, scale: float):
        self.args = args
        self.scale = scale

    def __call__(self, q, k, v, extra_options):
        cfg = self.args
        if cfg.adain_queries:
            q = adain(q)
        if cfg.adain_keys:
            k = adain(k)
        if cfg.adain_values:
            v = adain(v)
        if cfg.share_attention:
            # expose the reference tokens to every batch element
            k = concat_first(k, -2, scale=self.scale)
            v = concat_first(v, -2)
        return q, k, v
|
||||
|
||||
|
||||
def get_norm_layers(
    layer: nn.Module,
    norm_layers_: "dict[str, list[Union[nn.GroupNorm, nn.LayerNorm]]]",
    share_layer_norm: bool,
    share_group_norm: bool,
):
    """Recursively collect LayerNorm/GroupNorm modules into *norm_layers_*.

    Note: recursion into children is skipped only for GroupNorms that
    were collected, mirroring the original if / if-else pairing.
    """
    if share_layer_norm and isinstance(layer, nn.LayerNorm):
        norm_layers_["layer"].append(layer)
    if share_group_norm and isinstance(layer, nn.GroupNorm):
        norm_layers_["group"].append(layer)
    else:
        for child in layer.children():
            get_norm_layers(child, norm_layers_, share_layer_norm, share_group_norm)
|
||||
|
||||
|
||||
def register_norm_forward(
    norm_layer: Union[nn.GroupNorm, nn.LayerNorm],
) -> Union[nn.GroupNorm, nn.LayerNorm]:
    """Wrap a norm layer so it normalises with the style tokens appended.

    The stock forward is stashed once as ``orig_forward``; the wrapper
    concatenates the reference features, runs the stock norm, then drops
    the appended tokens again.
    """
    if not hasattr(norm_layer, "orig_forward"):
        setattr(norm_layer, "orig_forward", norm_layer.forward)
    stock_forward = norm_layer.orig_forward

    def forward_(hidden_states):
        token_count = hidden_states.shape[-2]
        widened = concat_first(hidden_states, dim=-2)
        return stock_forward(widened)[..., :token_count, :]  # type: ignore

    norm_layer.forward = forward_  # type: ignore
    return norm_layer
|
||||
|
||||
|
||||
def register_shared_norm(
    model: ModelPatcher,
    share_group_norm: bool = True,
    share_layer_norm: bool = True,
):
    """Patch the selected norm layers in *model* and return them."""
    collected = {"group": [], "layer": []}
    get_norm_layers(model.model, collected, share_layer_norm, share_group_norm)
    print(
        f"Patching {len(collected['group'])} group norms, {len(collected['layer'])} layer norms."
    )
    patched = [register_norm_forward(layer) for layer in collected["group"]]
    patched += [register_norm_forward(layer) for layer in collected["layer"]]
    return patched
|
||||
|
||||
|
||||
# UI option lists: which norm layers to share, and which attention
# projections to align (consumed by styleAlignBatch).
SHARE_NORM_OPTIONS = ["both", "group", "layer", "disabled"]
SHARE_ATTN_OPTIONS = ["q+k", "q+k+v", "disabled"]
|
||||
|
||||
|
||||
def styleAlignBatch(model, share_norm, share_attn, scale=1.0):
    """Return a cloned model patched for StyleAligned batch generation.

    NOTE(review): norm layers are registered on the *original* patcher;
    clone() appears to share the underlying model so the effect should
    be the same — confirm before refactoring.
    """
    patched = model.clone()
    group = share_norm in ["group", "both"]
    layer = share_norm in ["layer", "both"]
    register_shared_norm(model, group, layer)
    attn_args = StyleAlignedArgs(share_attn)
    patched.set_model_attn1_patch(SharedAttentionProcessor(attn_args, scale))
    return patched
|
||||
247
custom_nodes/ComfyUI-Easy-Use/py/libs/translate.py
Normal file
247
custom_nodes/ComfyUI-Easy-Use/py/libs/translate.py
Normal file
@@ -0,0 +1,247 @@
|
||||
#credit to shadowcz007 for this module
|
||||
#from https://github.com/shadowcz007/comfyui-mixlab-nodes/blob/main/nodes/TextGenerateNode.py
|
||||
import re
|
||||
import os
|
||||
import folder_paths
|
||||
|
||||
import comfy.utils
|
||||
import torch
|
||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
|
||||
from .utils import install_package
|
||||
# lark powers the prompt grammar below; install it on demand on first use.
try:
    from lark import Lark, Transformer, v_args
except ImportError:
    # was a bare except: only a failed import should trigger an install,
    # and a bare except would also swallow KeyboardInterrupt/SystemExit
    print('install lark...')
    install_package('lark')
    from lark import Lark, Transformer, v_args

# Local translation model location; loaded lazily by translate()/zh_to_en().
model_path = os.path.join(folder_paths.models_dir, 'prompt_generator')
zh_en_model_path = os.path.join(model_path, 'opus-mt-zh-en')
zh_en_model, zh_en_tokenizer = None, None
|
||||
|
||||
def correct_prompt_syntax(prompt=""):
    """Repair common punctuation/bracket mistakes in a prompt string.

    Normalises punctuation, collapses whitespace, then balances (), []
    and <> within each comma-separated element.
    """
    corrected_elements = []
    # Normalise punctuation to the ASCII forms the parser expects.
    # NOTE(review): the first argument of each replace() pair is expected to
    # be the full-width/Chinese variant — verify the characters survived
    # encoding before editing this line.
    prompt = prompt.replace('(', '(').replace(')', ')').replace(',', ',').replace(';', ',').replace('。', '.').replace(':',':').replace('\\',',')
    # Collapse redundant whitespace.
    prompt = re.sub(r'\s+', ' ', prompt).strip()
    prompt = prompt.replace("< ","<").replace(" >",">").replace("( ","(").replace(" )",")").replace("[ ","[").replace(' ]',']')

    # Split into comma-separated elements.
    prompt_elements = prompt.split(',')

    def balance_brackets(element, open_bracket, close_bracket):
        # Append however many closing brackets are missing.
        open_brackets_count = element.count(open_bracket)
        close_brackets_count = element.count(close_bracket)
        return element + close_bracket * (open_brackets_count - close_brackets_count)

    for element in prompt_elements:
        element = element.strip()

        # Skip empty elements.
        if not element:
            continue

        # Balance round, square and angle brackets.
        if element[0] in '([':
            corrected_element = balance_brackets(element, '(', ')') if element[0] == '(' else balance_brackets(element, '[', ']')
        elif element[0] == '<':
            corrected_element = balance_brackets(element, '<', '>')
        else:
            # Strip stray closing brackets at the start of the element.
            corrected_element = element.lstrip(')]')

        corrected_elements.append(corrected_element)

    # Reassemble the corrected prompt.
    return ','.join(corrected_elements)
|
||||
|
||||
def detect_language(input_str):
    """Guess the dominant script of *input_str*.

    Counts CJK-range characters against other alphabetic characters and
    returns "cn", "en", or "unknow" on a tie (including empty input).
    """
    cn_total = 0
    en_total = 0
    for ch in input_str:
        if '\u4e00' <= ch <= '\u9fff':
            cn_total += 1
        elif ch.isalpha():
            en_total += 1
    if cn_total > en_total:
        return "cn"
    if en_total > cn_total:
        return "en"
    # keep the original spelling of the tie tag — callers compare against it
    return "unknow"
|
||||
|
||||
def has_chinese(text):
    """True when *text* contains CJK characters outside of lora tags
    (<...>), wildcard tokens (__...__) and embedding references."""
    stripped = re.sub(r'<.*?>', '', text)
    stripped = re.sub(r'__.*?__', '', stripped)
    stripped = re.sub(r'embedding:.*?$', '', stripped)
    return any('\u4e00' <= ch <= '\u9fff' for ch in stripped)
|
||||
|
||||
def translate(text):
    """Translate one Chinese string to English using opus-mt-zh-en.

    Lazily loads the model into module globals on first use; falls back
    to resolving the model id on the HuggingFace hub when no local copy
    exists.
    """
    global zh_en_model_path, zh_en_model, zh_en_tokenizer

    # No local checkpoint -> let transformers resolve the hub id instead.
    if not os.path.exists(zh_en_model_path):
        zh_en_model_path = 'Helsinki-NLP/opus-mt-zh-en'

    if zh_en_model is None:

        zh_en_model = AutoModelForSeq2SeqLM.from_pretrained(zh_en_model_path).eval()
        zh_en_tokenizer = AutoTokenizer.from_pretrained(zh_en_model_path, padding=True, truncation=True)

    zh_en_model.to("cuda" if torch.cuda.is_available() else "cpu")
    with torch.no_grad():
        encoded = zh_en_tokenizer([text], return_tensors="pt")
        encoded.to(zh_en_model.device)
        sequences = zh_en_model.generate(**encoded)
        # Single input -> single decoded sequence.
        return zh_en_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
|
||||
|
||||
@v_args(inline=True)  # Decorator to flatten the tree directly into the function arguments
class ChinesePromptTranslate(Transformer):
    """Lark transformer that walks a parsed prompt tree and translates
    Chinese words to English while re-emitting the surrounding prompt
    syntax (emphasis, weights, lora/embedding tags, schedules) intact."""

    def sentence(self, *args):
        # Top level: rejoin translated phrases with commas.
        return ", ".join(args)

    def phrase(self, *args):
        return "".join(args)

    def emphasis(self, *args):
        # Reconstruct the emphasis with translated content
        return "(" + "".join(args) + ")"

    def weak_emphasis(self, *args):
        print('weak_emphasis:', args)
        return "[" + "".join(args) + "]"

    def embedding(self, *args):
        # Re-emit an "embedding:name[:num[:num]]" reference untranslated.
        print('prompt embedding', args[0])
        if len(args) == 1:
            embedding_name = str(args[0])
            return f"embedding:{embedding_name}"
        elif len(args) > 1:
            embedding_name, *numbers = args

            if len(numbers) == 2:
                return f"embedding:{embedding_name}:{numbers[0]}:{numbers[1]}"
            elif len(numbers) == 1:
                return f"embedding:{embedding_name}:{numbers[0]}"
            else:
                return f"embedding:{embedding_name}"

    def lora(self, *args):
        # Re-emit a "<lora:name[:num[:num]]>" tag untranslated.
        if len(args) == 1:
            return f"<lora:{args[0]}>"
        elif len(args) > 1:
            _, loar_name, *numbers = args
            loar_name = str(loar_name).strip()
            if len(numbers) == 2:
                return f"<lora:{loar_name}:{numbers[0]}:{numbers[1]}>"
            elif len(numbers) == 1:
                return f"<lora:{loar_name}:{numbers[0]}>"
            else:
                return f"<lora:{loar_name}>"

    def weight(self, word, number):
        # Translate the word but keep its numeric weight: "(word:1.2)".
        translated_word = translate(str(word)).rstrip('.')
        return f"({translated_word}:{str(number).strip()})"

    def schedule(self, *args):
        print('prompt schedule', args)
        data = [str(arg).strip() for arg in args]

        return f"[{':'.join(data)}]"

    def word(self, word):
        # Translate each word using the dictionary
        word = str(word)
        # @...@ spans are protected from translation; the text around them
        # is translated separately.
        match_cn = re.search(r'@.*?@', word)
        if re.search(r'__.*?__', word):
            return word.rstrip('.')
        elif match_cn:
            chinese = match_cn.group()
            before = word.split('@', 1)
            before = before[0] if len(before) > 0 else ''
            before = translate(str(before)).rstrip('.') if before else ''
            after = word.rsplit('@', 1)
            after = after[len(after)-1] if len(after) > 1 else ''
            after = translate(after).rstrip('.') if after else ''
            return before + chinese.replace('@', '').rstrip('.') + after
        elif detect_language(word) == "cn":
            return translate(word).rstrip('.')
        else:
            return word.rstrip('.')
|
||||
|
||||
|
||||
# Prompt grammar (lark): a sentence is comma-separated phrases; a phrase
# may be an emphasis group, a weighted word, a lora/embedding tag, a
# schedule, or a plain word.
grammar = r"""
start: sentence
sentence: phrase ("," phrase)*
phrase: emphasis | weight | word | lora | embedding | schedule
emphasis: "(" sentence ")" -> emphasis
| "[" sentence "]" -> weak_emphasis
weight: "(" word ":" NUMBER ")"
schedule: "[" word ":" word ":" NUMBER "]"
lora: "<" WORD ":" WORD (":" NUMBER)? (":" NUMBER)? ">"
embedding: "embedding" ":" WORD (":" NUMBER)? (":" NUMBER)?
word: WORD

NUMBER: /\s*-?\d+(\.\d+)?\s*/
WORD: /[^,:\(\)\[\]<>]+/
"""
|
||||
def zh_to_en(text):
    """Translate a list of prompt strings from Chinese to English.

    Each entry is syntax-corrected, parsed with the prompt grammar and
    translated word-by-word by ChinesePromptTranslate. Returns a list of
    translated prompts ([""] when nothing was produced).
    """
    global zh_en_model_path, zh_en_model, zh_en_tokenizer
    # Progress bar: one tick per prompt plus one for model setup.
    pbar = comfy.utils.ProgressBar(len(text) + 1)
    texts = [correct_prompt_syntax(t) for t in text]

    # The tokenizer needs sentencepiece at runtime.
    install_package('sentencepiece', '0.2.0')

    # No local checkpoint -> let transformers resolve the hub id instead.
    if not os.path.exists(zh_en_model_path):
        zh_en_model_path = 'Helsinki-NLP/opus-mt-zh-en'

    if zh_en_model is None:
        zh_en_model = AutoModelForSeq2SeqLM.from_pretrained(zh_en_model_path).eval()
        zh_en_tokenizer = AutoTokenizer.from_pretrained(zh_en_model_path, padding=True, truncation=True)

    zh_en_model.to("cuda" if torch.cuda.is_available() else "cpu")

    prompt_result = []

    en_texts = []

    for t in texts:
        if t:
            # Translation happens inside parsing: lalr + inline transformer.
            parser = Lark(grammar, start="start", parser="lalr", transformer=ChinesePromptTranslate())
            result = parser.parse(t).children
            en_texts.append(result[0])

    # Move the model off the GPU once translation is done.
    zh_en_model.to('cpu')

    pbar.update(1)
    for t in en_texts:
        prompt_result.append(t)
        pbar.update(1)

    if len(prompt_result) == 0:
        prompt_result = [""]

    return prompt_result
|
||||
282
custom_nodes/ComfyUI-Easy-Use/py/libs/utils.py
Normal file
282
custom_nodes/ComfyUI-Easy-Use/py/libs/utils.py
Normal file
@@ -0,0 +1,282 @@
|
||||
class AlwaysEqualProxy(str):
    """A str subclass that compares equal to absolutely everything.

    Used as a wildcard socket type so any type comparison succeeds.
    """

    def __eq__(self, other):
        return True

    def __ne__(self, other):
        return False
|
||||
|
||||
class TautologyStr(str):
    """A str that never reports inequality, so type checks always pass."""

    def __ne__(self, other):
        return False


class ByPassTypeTuple(tuple):
    """A tuple whose every positive index resolves to the first element,
    with string elements wrapped as TautologyStr."""

    def __getitem__(self, index):
        # Clamp all positive indices onto slot 0; negatives pass through.
        slot = 0 if index > 0 else index
        value = super().__getitem__(slot)
        return TautologyStr(value) if isinstance(value, str) else value
|
||||
|
||||
# Cached ComfyUI revision; populated lazily by compare_revision().
comfy_ui_revision = None
def get_comfyui_revision():
    """Return the commit count of the local ComfyUI checkout.

    Falls back to the string "Unknown" when GitPython is unavailable or
    the install is not a git repository.
    """
    try:
        import git
        import os
        import folder_paths
        repo = git.Repo(os.path.dirname(folder_paths.__file__))
        comfy_ui_revision = len(list(repo.iter_commits('HEAD')))
    except Exception:
        # was a bare except: don't swallow KeyboardInterrupt/SystemExit
        comfy_ui_revision = "Unknown"
    return comfy_ui_revision
|
||||
|
||||
|
||||
import sys
|
||||
import importlib.util
|
||||
import importlib.metadata
|
||||
import comfy.model_management as mm
|
||||
import gc
|
||||
from packaging import version
|
||||
from server import PromptServer
|
||||
def is_package_installed(package):
    """Return True when *package* is importable in this interpreter."""
    try:
        module = importlib.util.find_spec(package)
        return module is not None
    except ImportError as e:
        # find_spec can raise for submodules of a missing parent package
        print(e)
        return False


def install_package(package, v=None, compare=True, compare_version=None):
    """Install *package* via pip when it is missing or too old.

    Args:
        package: distribution/import name to check and install.
        v: exact version pinned in the pip requirement (``package==v``).
        compare: when False and *v* is given, always reinstall.
        compare_version: minimum acceptable installed version
            (defaults to *v*).

    Returns:
        True when pip ran and succeeded; False when the install was
        skipped or pip failed.
    """
    run_install = True
    if is_package_installed(package):
        try:
            installed_version = importlib.metadata.version(package)
            if v is not None:
                if compare_version is None:
                    compare_version = v
                if not compare or version.parse(installed_version) >= version.parse(compare_version):
                    run_install = False
            else:
                run_install = False
        except Exception:
            # was a bare except: metadata lookup fails for stdlib/editable
            # packages — treat "importable but unversioned" as installed
            run_install = False

    if run_install:
        import subprocess
        package_command = package + '==' + v if v is not None else package
        PromptServer.instance.send_sync("easyuse-toast", {'content': f"Installing {package_command}...", 'duration': 5000})
        # list-form argv (shell=False) avoids shell-injection via package names
        result = subprocess.run([sys.executable, '-s', '-m', 'pip', 'install', package_command], capture_output=True, text=True)
        if result.returncode == 0:
            PromptServer.instance.send_sync("easyuse-toast", {'content': f"{package} installed successfully", 'type': 'success', 'duration': 5000})
            print(f"Package {package} installed successfully")
            return True
        else:
            PromptServer.instance.send_sync("easyuse-toast", {'content': f"{package} installed failed", 'type': 'error', 'duration': 5000})
            print(f"Package {package} installed failed")
            return False
    else:
        return False
|
||||
|
||||
def compare_revision(num):
    """True when the local ComfyUI revision is at least *num*
    (or when the revision could not be determined)."""
    global comfy_ui_revision
    if not comfy_ui_revision:
        comfy_ui_revision = get_comfyui_revision()
    return comfy_ui_revision == 'Unknown' or int(comfy_ui_revision) >= num
|
||||
|
||||
def find_tags(string: str, sep="/") -> list[str]:
    """Split a model's relative path into its directory tags.

    Both path-separator styles are accepted; the final component (the
    file name itself) is not a tag.
    """
    if not string:
        return []
    normalized = string.replace("\\", "/")
    # squash accidental doubled separators
    while "//" in normalized:
        normalized = normalized.replace("//", "/")
    if sep not in normalized:
        return []
    return normalized.split(sep)[:-1]
|
||||
|
||||
|
||||
from comfy.model_base import BaseModel
|
||||
import comfy.supported_models
|
||||
import comfy.supported_models_base
|
||||
def get_sd_version(model):
    """Map a loaded model patcher to a short base-model tag.

    Returns one of 'sdxl', 'sdxl_refiner', 'sd1', 'svd', 'sd3', 'hydit',
    'flux', 'mochi', or 'unknown'.
    """
    base: BaseModel = model.model
    model_config = base.model_config
    # Ordered (class, tag) table replacing the original elif ladder;
    # order matters, so it mirrors the original check sequence.
    version_table = (
        (comfy.supported_models.SDXL, 'sdxl'),
        (comfy.supported_models.SDXLRefiner, 'sdxl_refiner'),
        ((comfy.supported_models.SD15, comfy.supported_models.SD20), 'sd1'),
        (comfy.supported_models.SVD_img2vid, 'svd'),
        (comfy.supported_models.SD3, 'sd3'),
        (comfy.supported_models.HunyuanDiT, 'hydit'),
        (comfy.supported_models.Flux, 'flux'),
        (comfy.supported_models.GenmoMochi, 'mochi'),
    )
    for cls, tag in version_table:
        if isinstance(model_config, cls):
            return tag
    return 'unknown'
|
||||
|
||||
def find_nearest_steps(clip_id, prompt):
    """Find the steps value of the sampler node wired to *clip_id*.

    Scans the prompt graph for Sampler/Sampling nodes whose "pipe" input
    references the clip id; returns their "steps" input, or 1 when no
    such node is found.
    """

    def _references_clip(node):
        # A node references the clip when any of its "pipe" link ids
        # equals the clip id (link ids of 0 are placeholders).
        pipe_links = node["inputs"].get("pipe")
        if pipe_links is None:
            return False
        return any(link != 0 and link == str(clip_id) for link in pipe_links)

    for node_id, node in prompt.items():
        class_type = node["class_type"]
        if "Sampler" not in class_type and "sampler" not in class_type and "Sampling" not in class_type:
            continue
        if _references_clip(node):
            return node["inputs"].get("steps", 1)
    return 1
|
||||
|
||||
def find_wildcards_seed(clip_id, text, prompt):
    """ Find easy wildcards seed value"""
    # Returns the seed of the wildcards node feeding *clip_id*'s positive
    # chain, or None when *text* contains no wildcard tokens or no
    # wildcards node is wired in.

    def find_link_clip_id(id, seed, wildcard_id):
        # Follow "positive" links upstream until the wildcards node is hit.
        node = prompt[id]
        if "positive" in node['inputs']:
            link_ids = node["inputs"]["positive"]
            if type(link_ids) == list:
                for id in link_ids:
                    if id != 0:
                        if id == wildcard_id:
                            wildcard_node = prompt[wildcard_id]
                            # Prefer "seed"; "seed_num" is the legacy input name.
                            seed = wildcard_node["inputs"]["seed"] if "seed" in wildcard_node["inputs"] else None
                            if seed is None:
                                seed = wildcard_node["inputs"]["seed_num"] if "seed_num" in wildcard_node["inputs"] else None
                            return seed
                        else:
                            return find_link_clip_id(id, seed, wildcard_id)
            else:
                return None
        else:
            return None

    # Only prompts containing "__token__" can carry a wildcards seed.
    if "__" in text:
        seed = None
        for id in prompt:
            node = prompt[id]
            if "wildcards" in node["class_type"]:
                wildcard_id = id
                # NOTE(review): stops at the first wildcards node found —
                # confirm graphs never contain more than one.
                return find_link_clip_id(str(clip_id), seed, wildcard_id)
        return seed
    else:
        return None
|
||||
|
||||
def is_linked_styles_selector(prompt, unique_id, prompt_type='positive'):
    """True when this node's *prompt_type* input is wired to an
    'easy stylesSelector' node."""
    # Subgraph ids look like "parent.child" — keep only the local part.
    if "." in unique_id:
        unique_id = unique_id.split('.')[-1]
    node_inputs = prompt[unique_id]['inputs']
    link = node_inputs[prompt_type] if prompt_type in node_inputs else None
    if type(link) == list and link != 'undefined' and link[0]:
        upstream = prompt[link[0]]
        return True if upstream and upstream['class_type'] == 'easy stylesSelector' else False
    return False
|
||||
|
||||
# Sticky flag: once huggingface.co is unreachable, all later downloads go
# straight to the hf-mirror.com mirror.
use_mirror = False
def get_local_filepath(url, dirname, local_file_name=None):
    """Get local file path when is already downloaded or download it"""
    import os
    from server import PromptServer
    from urllib.parse import urlparse
    from torch.hub import download_url_to_file
    global use_mirror
    if not os.path.exists(dirname):
        os.makedirs(dirname)
    if not local_file_name:
        # Default to the file name embedded in the URL path.
        parsed_url = urlparse(url)
        local_file_name = os.path.basename(parsed_url.path)
    destination = os.path.join(dirname, local_file_name)
    if not os.path.exists(destination):
        try:
            if use_mirror:
                url = url.replace('huggingface.co', 'hf-mirror.com')
            print(f'downloading {url} to {destination}')
            PromptServer.instance.send_sync("easyuse-toast", {'content': f'Downloading model to {destination}, please wait...', 'duration': 10000})
            download_url_to_file(url, destination)
        except Exception as e:
            # First attempt failed: flip to the mirror permanently and retry once.
            use_mirror = True
            url = url.replace('huggingface.co', 'hf-mirror.com')
            print(f'Unable to download from huggingface, trying mirror: {url}')
            PromptServer.instance.send_sync("easyuse-toast", {'content': f'Unable to connect to huggingface, trying mirror: {url}', 'duration': 10000})
            try:
                download_url_to_file(url, destination)
            except Exception as err:
                error_msg = str(err.args[0]) if err.args else str(err)
                PromptServer.instance.send_sync("easyuse-toast",
                                                {'content': f'Unable to download model from {url}', 'type':'error'})
                raise Exception(f'Download failed. Original URL and mirror both failed.\nError: {error_msg}')
    return destination
|
||||
|
||||
def to_lora_patch_dict(state_dict: dict) -> dict:
    """Convert a raw lora state_dict into a patch dict usable by the
    model patcher.

    Keys look like "model_key::patch_type::weight_index"; weights of the
    same (key, type) pair are gathered into a 16-slot list.
    """
    grouped: dict = {}
    for full_key, weight in state_dict.items():
        model_key, patch_type, weight_index = full_key.split('::')
        slots = grouped.setdefault(model_key, {}).setdefault(patch_type, [None] * 16)
        slots[int(weight_index)] = weight

    # Flatten to {model_key: (patch_type, weights)}; when several patch
    # types share a key, the last one wins (as in the original loop).
    flattened = {}
    for model_key, by_type in grouped.items():
        for patch_type, weights in by_type.items():
            flattened[model_key] = (patch_type, weights)

    return flattened
|
||||
|
||||
def easySave(images, filename_prefix, output_type, prompt=None, extra_pnginfo=None):
    """Save or preview *images* depending on *output_type*.

    Returns the UI image descriptors, or an empty list for hidden output.
    """
    from nodes import PreviewImage, SaveImage
    if output_type in ["Hide", "None"]:
        return list()
    if output_type in ["Preview", "Preview&Choose"]:
        results = PreviewImage().save_images(images, 'easyPreview', prompt, extra_pnginfo)
    else:
        results = SaveImage().save_images(images, filename_prefix, prompt, extra_pnginfo)
    return results['ui']['images']
|
||||
|
||||
def getMetadata(filepath):
    """Read the raw JSON header bytes of a safetensors file.

    Raises BufferError when the size prefix is invalid or the file is
    truncated before the full header could be read.
    """
    with open(filepath, "rb") as file:
        # https://github.com/huggingface/safetensors#format
        # 8 bytes: N, an unsigned little-endian 64-bit integer, containing the size of the header
        header_size = int.from_bytes(file.read(8), "little", signed=False)

        if header_size <= 0:
            raise BufferError("Invalid header size")

        header = file.read(header_size)
        # Fix: the original re-checked header_size here, so a truncated file
        # always passed validation; verify the bytes actually read instead.
        if len(header) < header_size:
            raise BufferError("Invalid header")

        return header
|
||||
|
||||
def cleanGPUUsedForce():
    """Aggressively free VRAM: run garbage collection, unload every
    model and empty ComfyUI's cache."""
    gc.collect()
    mm.unload_all_models()
    mm.soft_empty_cache()
|
||||
476
custom_nodes/ComfyUI-Easy-Use/py/libs/wildcards.py
Normal file
476
custom_nodes/ComfyUI-Easy-Use/py/libs/wildcards.py
Normal file
@@ -0,0 +1,476 @@
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
from math import prod
|
||||
|
||||
import yaml
|
||||
|
||||
import folder_paths
|
||||
|
||||
from .log import log_node_info
|
||||
|
||||
# Registry of wildcard key -> list of candidate replacement lines,
# populated by read_wildcard_dict().
easy_wildcard_dict = {}


def get_wildcard_list():
    """Return every known wildcard formatted as a "__name__" token."""
    return [f"__{key}__" for key in easy_wildcard_dict]
|
||||
|
||||
def wildcard_normalize(x):
    """Canonicalise a wildcard key: forward slashes, lower case."""
    return x.lower().replace("\\", "/")
|
||||
|
||||
def read_wildcard(k, v):
    """Register one wildcard entry; dict values recurse into "k/sub" keys."""
    if isinstance(v, list):
        easy_wildcard_dict[wildcard_normalize(k)] = v
    elif isinstance(v, dict):
        for sub_key, sub_value in v.items():
            read_wildcard(wildcard_normalize(f"{k}/{sub_key}"), sub_value)
|
||||
|
||||
def read_wildcard_dict(wildcard_path):
    """Load every .txt/.yaml/.json wildcard file under *wildcard_path*
    into the module-level easy_wildcard_dict and return it."""
    global easy_wildcard_dict
    for root, directories, files in os.walk(wildcard_path, followlinks=True):
        for file in files:
            if file.endswith('.txt'):
                # .txt: one replacement candidate per line, keyed by the
                # file's relative path without extension.
                file_path = os.path.join(root, file)
                rel_path = os.path.relpath(file_path, wildcard_path)
                key = os.path.splitext(rel_path)[0].replace('\\', '/').lower()

                try:
                    with open(file_path, 'r', encoding="UTF-8", errors="ignore") as f:
                        lines = f.read().splitlines()
                        easy_wildcard_dict[key] = lines
                except UnicodeDecodeError:
                    # Legacy files: retry with latin-1.
                    with open(file_path, 'r', encoding="ISO-8859-1") as f:
                        lines = f.read().splitlines()
                        easy_wildcard_dict[key] = lines
            elif file.endswith('.yaml'):
                # .yaml: nested mappings become "a/b/c" keys via read_wildcard.
                file_path = os.path.join(root, file)
                with open(file_path, 'r') as f:
                    yaml_data = yaml.load(f, Loader=yaml.FullLoader)

                    for k, v in yaml_data.items():
                        read_wildcard(k, v)
            elif file.endswith('.json'):
                file_path = os.path.join(root, file)
                try:
                    with open(file_path, 'r') as f:
                        json_data = json.load(f)
                        for key, value in json_data.items():
                            key = wildcard_normalize(key)
                            easy_wildcard_dict[key] = value
                except ValueError:
                    print('json files load error')
    return easy_wildcard_dict
|
||||
|
||||
|
||||
def process(text, seed=None):
    """Expand wildcard syntax in *text*.

    Two passes run until a fixed point (bounded depth): "{a|b|...}"
    option groups (with optional "N$$", "N-M$$" or "N$$sep$$" multi-select
    prefixes and "weight::option" probabilities) are resolved first, then
    "__name__" wildcard tokens are replaced from easy_wildcard_dict.
    A *seed* makes the expansion deterministic.
    """

    if seed is not None:
        random.seed(seed)

    def replace_options(string):
        replacements_found = False

        def replace_option(match):
            nonlocal replacements_found
            options = match.group(1).split('|')

            multi_select_pattern = options[0].split('$$')
            select_range = None
            select_sep = ' '
            range_pattern = r'(\d+)(-(\d+))?'
            range_pattern2 = r'-(\d+)'

            if len(multi_select_pattern) > 1:
                r = re.match(range_pattern, options[0])

                if r is None:
                    r = re.match(range_pattern2, options[0])
                    a = '1'
                    # guard: the prefix may match neither pattern (e.g. "x$$...")
                    b = r.group(1).strip() if r is not None else None
                else:
                    a = r.group(1).strip()
                    # fix: group(3) is None for a bare count like "2$$opt|...";
                    # the original called .strip() on None and crashed
                    b = r.group(3).strip() if r.group(3) is not None else None

                if r is not None:
                    if b is not None and is_numeric_string(a) and is_numeric_string(b):
                        # PATTERN: num1-num2
                        select_range = int(a), int(b)
                    elif is_numeric_string(a):
                        # PATTERN: num
                        x = int(a)
                        select_range = (x, x)

                if select_range is not None and len(multi_select_pattern) == 2:
                    # PATTERN: count$$
                    options[0] = multi_select_pattern[1]
                elif select_range is not None and len(multi_select_pattern) == 3:
                    # PATTERN: count$$ sep $$
                    select_sep = multi_select_pattern[1]
                    options[0] = multi_select_pattern[2]

            adjusted_probabilities = []

            total_prob = 0

            for option in options:
                parts = option.split('::', 1)
                if len(parts) == 2 and is_numeric_string(parts[0].strip()):
                    config_value = float(parts[0].strip())
                else:
                    config_value = 1  # Default value if no configuration is provided

                adjusted_probabilities.append(config_value)
                total_prob += config_value

            normalized_probabilities = [prob / total_prob for prob in adjusted_probabilities]

            if select_range is None:
                select_count = 1
            else:
                select_count = random.randint(select_range[0], select_range[1])

            if select_count > len(options):
                selected_items = options
            else:
                selected_items = random.choices(options, weights=normalized_probabilities, k=select_count)
                selected_items = set(selected_items)

                # weighted sampling may return duplicates; retry a few times
                # to reach the requested count
                try_count = 0
                while len(selected_items) < select_count and try_count < 10:
                    remaining_count = select_count - len(selected_items)
                    additional_items = random.choices(options, weights=normalized_probabilities, k=remaining_count)
                    selected_items |= set(additional_items)
                    try_count += 1

            # strip the "weight::" prefix from each chosen option
            selected_items2 = [re.sub(r'^\s*[0-9.]+::', '', x, 1) for x in selected_items]
            replacement = select_sep.join(selected_items2)
            if '::' in replacement:
                pass

            replacements_found = True
            return replacement

        pattern = r'{([^{}]*?)}'
        replaced_string = re.sub(pattern, replace_option, string)

        return replaced_string, replacements_found

    def replace_wildcard(string):
        global easy_wildcard_dict
        pattern = r"__([\w\s.\-+/*\\]+?)__"
        matches = re.findall(pattern, string)
        replacements_found = False

        for match in matches:
            keyword = match.lower()
            keyword = wildcard_normalize(keyword)
            if keyword in easy_wildcard_dict:
                replacement = random.choice(easy_wildcard_dict[keyword])
                replacements_found = True
                string = string.replace(f"__{match}__", replacement, 1)
            elif '*' in keyword:
                # glob-style lookup: pool every key matching the pattern
                subpattern = keyword.replace('*', '.*').replace('+', r'\+')
                total_patterns = []
                found = False
                for k, v in easy_wildcard_dict.items():
                    if re.match(subpattern, k) is not None:
                        total_patterns += v
                        found = True

                if found:
                    replacement = random.choice(total_patterns)
                    replacements_found = True
                    string = string.replace(f"__{match}__", replacement, 1)
            elif '/' not in keyword:
                # bare name: retry as a "match any directory" pattern
                string_fallback = string.replace(f"__{match}__", f"__*/{match}__", 1)
                string, replacements_found = replace_wildcard(string_fallback)

        return string, replacements_found

    replace_depth = 100
    stop_unwrap = False
    while not stop_unwrap and replace_depth > 1:
        replace_depth -= 1  # prevent infinite loop

        # pass1: replace options
        pass1, is_replaced1 = replace_options(text)

        while is_replaced1:
            pass1, is_replaced1 = replace_options(pass1)

        # pass2: replace wildcards
        text, is_replaced2 = replace_wildcard(pass1)
        stop_unwrap = not is_replaced1 and not is_replaced2

    return text


def is_numeric_string(input_str):
    """True when *input_str* is a plain (optionally negative) number."""
    return re.match(r'^-?\d+(\.\d+)?$', input_str) is not None
|
||||
|
||||
|
||||
def safe_float(x):
    """Parse *x* as a float, defaulting to 1.0 for non-numeric input."""
    return float(x) if is_numeric_string(x) else 1.0
|
||||
|
||||
|
||||
def extract_lora_values(string):
    """Parse every ``<lora:...>`` tag found in *string*.

    A tag looks like ``<lora:name[:model_w[:clip_w]][:LBW=spec]>`` where
    *spec* is a ``;``-separated list that may contain ``A=``/``B=`` floats
    and a free-form block-weight vector.  A named LBW preset directly after
    ``LBW=`` (e.g. ``LBW=SDXL-X:...``) is stripped.  Duplicate lora names
    keep only their first occurrence, in order of appearance.

    Returns a list of tuples ``(lora, a, b, lbw, lbw_a, lbw_b)``.
    """
    _numeric = re.compile(r'^-?\d+(\.\d+)?$')

    def _num(text):
        # safe_float equivalent: 1.0 for anything that isn't a plain number.
        return float(text) if _numeric.match(text) else 1.0

    def _strip_lbw_preset(text):
        # Drop a named preset right after 'LBW=' so only the weights remain.
        return re.sub(r'LBW=[A-Za-z][A-Za-z0-9_-]*:', r'LBW=', text)

    raw_tags = re.findall(r'<lora:([^>]+)>', string)

    seen = set()
    parsed = []
    for tag in (_strip_lbw_preset(t.strip(':')) for t in raw_tags):
        fields = tag.split(':')

        name = fields[0] if len(fields) > 0 else None
        model_w = None
        clip_w = None
        lbw = None
        lbw_a = None
        lbw_b = None

        for field in fields[1:]:
            if _numeric.match(field):
                # First number is the model weight, second the clip weight;
                # any further numbers are ignored.
                if model_w is None:
                    model_w = float(field)
                elif clip_w is None:
                    clip_w = float(field)
            elif field.startswith("LBW="):
                for part in field[4:].split(';'):
                    if part.startswith("A="):
                        lbw_a = _num(part[2:].strip())
                    elif part.startswith("B="):
                        lbw_b = _num(part[2:].strip())
                    elif part.strip() != '':
                        lbw = part

        if model_w is None:
            model_w = 1.0
        if clip_w is None:
            clip_w = 1.0

        if name is not None and name not in seen:
            parsed.append((name, model_w, clip_w, lbw, lbw_a, lbw_b))
            seen.add(name)

    return parsed
|
||||
|
||||
|
||||
def remove_lora_tags(string):
    """Return *string* with every ``<lora:...>`` tag stripped out."""
    return re.sub(r'<lora:[^>]+>', '', string)
|
||||
|
||||
def process_with_loras(wildcard_opt, model, clip, title="Positive", seed=None, can_load_lora=True, pipe_lora_stack=None, easyCache=None):
    """Expand wildcards in *wildcard_opt*, load any ``<lora:...>`` tags it
    contains, and strip the tags from the prompt.

    Parameters
    ----------
    wildcard_opt : str   raw prompt, possibly containing wildcards/lora tags
    model, clip          current model/clip; replaced as loras are applied
    title : str          label used for logging ("Positive"/"Negative")
    seed : int | None    wildcard RNG seed, forwarded to process()
    can_load_lora : bool when False, lora tags are stripped but not loaded
    pipe_lora_stack : list | None
                         accumulator of applied-lora dicts; a fresh list is
                         created when None (fixes the shared mutable-default
                         pitfall of the previous ``=[]`` signature)
    easyCache            loader cache providing ``load_lora``

    Returns (model, clip, clean_prompt, decoded_prompt,
             show_wildcard_prompt, pipe_lora_stack).
    """
    if pipe_lora_stack is None:
        pipe_lora_stack = []

    pass1 = process(wildcard_opt, seed)
    loras = extract_lora_values(pass1)
    pass2 = remove_lora_tags(pass1)

    has_noodle_key = "__" in wildcard_opt
    has_loras = loras != []
    show_wildcard_prompt = has_noodle_key or has_loras

    if can_load_lora and has_loras:
        for lora_name, model_weight, clip_weight, lbw, lbw_a, lbw_b in loras:
            # Append a default extension only when the tag has no supported
            # one.  (The previous check compared a bare "ext" against the
            # ".ext" entries in supported_pt_extensions, so it appended
            # ".safetensors" even to names that already carried it.)
            if ('.' + lora_name.split('.')[-1]) not in folder_paths.supported_pt_extensions:
                lora_name = lora_name + ".safetensors"
            lora = {
                "lora_name": lora_name, "model": model, "clip": clip, "model_strength": model_weight,
                "clip_strength": clip_weight,
                "lbw_a": lbw_a,
                "lbw_b": lbw_b,
                "lbw": lbw
            }
            model, clip = easyCache.load_lora(lora)
            lora["model"] = model
            lora["clip"] = clip
            pipe_lora_stack.append(lora)

    log_node_info("easy wildcards", f"{title}: {pass2}")
    if pass1 != pass2:
        log_node_info("easy wildcards", f'{title}_decode: {pass1}')

    return model, clip, pass2, pass1, show_wildcard_prompt, pipe_lora_stack
|
||||
|
||||
|
||||
def expand_wildcard(keyword: str) -> tuple[str, ...] | None:
    """Return every option for a file-wildcard *keyword* from easy_wildcard_dict.

    Lookup order mirrors replace_wildcard(): exact key, then glob ('*')
    match across all keys, then a single '*/keyword' retry for bare names.
    Returns None when nothing matches (all fall-through paths return
    implicitly), so callers must be prepared for a non-tuple result.
    """
    global easy_wildcard_dict
    if keyword in easy_wildcard_dict:
        return tuple(easy_wildcard_dict[keyword])
    elif '*' in keyword:
        # Glob: '*' -> '.*'; '+' escaped so it is not a regex quantifier.
        subpattern = keyword.replace('*', '.*').replace('+', r"\+")
        total_pattern = []
        for k, v in easy_wildcard_dict.items():
            if re.match(subpattern, k) is not None:
                total_pattern.extend(v)
        if total_pattern:
            return tuple(total_pattern)
    elif '/' not in keyword:
        # Bare name: search every sub-path once.
        return expand_wildcard(f"*/{keyword}")
|
||||
|
||||
def expand_options(options: str) -> tuple[str]:
    """Expand an option group (the text between '{' and '}').

    Splits on '|' and returns every alternative verbatim — whitespace and
    special characters are preserved, and empty alternatives stay empty
    strings.
    """
    alternatives = options.split("|")
    return tuple(alternatives)
|
||||
|
||||
|
||||
def decimal_to_irregular(n, bases):
    """Convert a decimal number to a mixed-radix ("irregular base") list.

    :param n: non-negative decimal integer
    :param bases: radix of each digit position, least-significant first
    :return: digit list, least-significant first (same length as *bases*;
             ``[0]`` when n == 0 and *bases* is empty)
    """
    if n == 0:
        return [0] * len(bases) if bases else [0]

    digits = []
    remaining = n
    # Peel off one digit per position, least-significant first.
    for base in bases:
        remaining, digit = divmod(remaining, base)
        digits.append(digit)
    return digits
|
||||
|
||||
|
||||
class WildcardProcessor:
    """Wildcard processor that can enumerate every expansion deterministically.

    Wildcard syntax:
      + option  : {a|b}
      + wildcard: __keyword__ — choices are resolved from the Easy-Use
        plugin's easy_wildcard_dict
    """

    RE_OPTIONS = re.compile(r"{([^{}]*?)}")
    RE_WILDCARD = re.compile(r"__([\w\s.\-+/*\\]+?)__")
    # Combined pattern: group 1 matches an option body, group 2 a wildcard.
    RE_REPLACER = re.compile(r"{([^{}]*?)}|__([\w\s.\-+/*\\]+?)__")

    # str.format-compatible template in which every option/wildcard has been
    # replaced by a sequential placeholder such as {r0}, {r1}, ...
    template: str
    # Choice lists in template order; identical lists are stored only once.
    replacers: dict[int, tuple[str, ...]]
    # Maps each placeholder name to the index of its choice list, so
    # duplicated choice lists take no extra storage.
    placeholder_mapping: dict[str, int]  # placeholder_id => replacer_id
    # Number of choices behind each placeholder, precomputed for reuse.
    placeholder_choices: dict[str, int]  # placeholder_id => len(replacer)

    def __init__(self, text: str):
        self.__make_template(text)
        self.__total = None  # lazy cache for total()

    def random(self, seed=None) -> str:
        "Pick one of all possible expansions at random."
        if seed is not None:
            random.seed(seed)
        return self.getn(random.randint(0, self.total() - 1))

    def getn(self, n: int) -> str:
        "Return the n-th expansion, cycling with period self.total()."
        n = n % self.total()
        # Interpret n as a mixed-radix number whose digits index each
        # placeholder's choice list (least-significant digit -> first
        # placeholder, relying on dict insertion order).
        indice = decimal_to_irregular(n, self.placeholder_choices.values())
        replacements = {
            placeholder_id: self.replacers[self.placeholder_mapping[placeholder_id]][i]
            for placeholder_id, i in zip(self.placeholder_mapping.keys(), indice)
        }
        return self.template.format(**replacements)

    def getmany(self, limit: int, offset: int = 0) -> list[str]:
        """Return a batch of expansions.  *limit* caps the list length to
        keep memory bounded; *offset* shifts the starting index.  When
        offset+limit runs past the last expansion, indices wrap around to
        the beginning (see getn()).
        """
        return [self.getn(n) for n in range(offset, offset + limit)]

    def total(self) -> int:
        "Count all possible expansions (product of all choice counts)."
        if self.__total is None:
            self.__total = prod(self.placeholder_choices.values())
        return self.__total

    def __make_template(self, text: str):
        """Turn the input prompt into a str.format template, leaving
        sequential placeholders {r0}, {r1}, ... where each option/wildcard
        stood.  Identical options/wildcards still receive distinct
        placeholder numbers so every occurrence varies independently.
        """
        self.placeholder_mapping = {}
        placeholder_id = 0
        replacer_id = 0
        replacers_rev = {}  # choices tuple => replacer id
        blocks = []
        # End position of the previously handled wildcard, used to stitch
        # the surrounding literal text back together.
        tail = 0
        for match in self.RE_REPLACER.finditer(text):
            # Extract and expand the matched option/wildcard.
            m = match.group(0)
            if m.startswith("{"):
                choices = expand_options(m[1:-1])
            elif m.startswith("__"):
                keyword = m[2:-2].lower()
                keyword = wildcard_normalize(keyword)
                # NOTE(review): expand_wildcard may return None for an
                # unknown keyword, which would break len() below — confirm
                # inputs are always resolvable.
                choices = expand_wildcard(keyword)
            else:
                raise ValueError(f"{m!r} is not a wildcard or option")

            # Record the choice list and its id; duplicates keep the first id.
            if choices not in replacers_rev:
                replacers_rev[choices] = replacer_id
                replacer_id += 1

            # Append the literal text preceding this wildcard.
            start, end = match.span()
            blocks.append(text[tail:start])
            tail = end
            # Replace the wildcard with a placeholder and remember which
            # choice list it refers to.
            blocks.append(f"{{r{placeholder_id}}}")
            self.placeholder_mapping[f"r{placeholder_id}"] = replacers_rev[choices]
            placeholder_id += 1

        if tail < len(text):
            blocks.append(text[tail:])
        self.template = "".join(blocks)
        self.replacers = {v: k for k, v in replacers_rev.items()}
        self.placeholder_choices = {
            placeholder_id: len(self.replacers[replacer_id])
            for placeholder_id, replacer_id in self.placeholder_mapping.items()
        }
|
||||
|
||||
|
||||
def test_option():
    """A single option group enumerates each alternative exactly once."""
    text = "{|a|b|c}"
    answer = ["", "a", "b", "c"]
    p = WildcardProcessor(text)
    assert p.total() == len(answer)
    assert p.getn(0) == answer[0]
    assert p.getmany(4) == answer
    # getmany always returns `limit` items and wraps past the end, so an
    # offset of 1 with limit 4 cycles back to the first expansion.  (The
    # previous assertion compared 4 items against the 3-item answer[1:].)
    assert p.getmany(4, 1) == answer[1:] + answer[:1]
|
||||
|
||||
|
||||
def test_same():
    """Identical option groups still vary independently of each other."""
    text = "{a|b},{a|b}"
    answer = ["a,a", "b,a", "a,b", "b,b"]
    p = WildcardProcessor(text)
    assert p.total() == len(answer)
    assert p.getn(0) == answer[0]
    assert p.getmany(4) == answer
    # getmany always returns `limit` items and wraps past the end, so an
    # offset of 1 with limit 4 cycles back to the first expansion.  (The
    # previous assertion compared 4 items against the 3-item answer[1:].)
    assert p.getmany(4, 1) == answer[1:] + answer[:1]
|
||||
|
||||
697
custom_nodes/ComfyUI-Easy-Use/py/libs/xyplot.py
Normal file
697
custom_nodes/ComfyUI-Easy-Use/py/libs/xyplot.py
Normal file
@@ -0,0 +1,697 @@
|
||||
import os, torch
|
||||
from pathlib import Path
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from .utils import easySave, get_sd_version
|
||||
from .adv_encode import advanced_encode
|
||||
from .controlnet import easyControlnet
|
||||
from .log import log_node_warn
|
||||
from ..modules.layer_diffuse import LayerDiffuse
|
||||
from ..config import RESOURCES_DIR
|
||||
from nodes import CLIPTextEncode
|
||||
import pprint
|
||||
# Optional dependency: FluxGuidance only exists in newer ComfyUI builds.
# Degrade gracefully instead of failing the whole module import.
try:
    from comfy_extras.nodes_flux import FluxGuidance
except Exception:  # narrowed from a bare `except:` so SystemExit/KeyboardInterrupt propagate
    FluxGuidance = None
|
||||
|
||||
class easyXYPlot():
|
||||
|
||||
def __init__(self, xyPlotData, save_prefix, image_output, prompt, extra_pnginfo, my_unique_id, sampler, easyCache):
    """Build an XY-plot run from the node's plot data.

    xyPlotData keys read here: x_axis / y_axis (strings shaped
    "<node_type>: <type>"), x_vals / y_vals, custom_font, grid_spacing,
    output_individuals.  *sampler* supplies safe_split/common_ksampler
    helpers and *easyCache* loads checkpoints/loras/VAEs.
    """
    # Axis strings look like "advanced: Steps"; split into node type + type.
    self.x_node_type, self.x_type = sampler.safe_split(xyPlotData.get("x_axis"), ': ')
    self.y_node_type, self.y_type = sampler.safe_split(xyPlotData.get("y_axis"), ': ')
    self.x_values = xyPlotData.get("x_vals") if self.x_type != "None" else []
    self.y_values = xyPlotData.get("y_vals") if self.y_type != "None" else []
    self.custom_font = xyPlotData.get("custom_font")

    self.grid_spacing = xyPlotData.get("grid_spacing")
    self.latent_id = 0  # which latent of a batch ends up in the plot
    self.output_individuals = xyPlotData.get("output_individuals")

    # Axis labels and per-cell size bookkeeping, filled while sampling.
    self.x_label, self.y_label = [], []
    self.max_width, self.max_height = 0, 0
    self.latents_plot = []
    self.image_list = []

    # An axis without values still occupies one row/column.
    self.num_cols = len(self.x_values) if len(self.x_values) > 0 else 1
    self.num_rows = len(self.y_values) if len(self.y_values) > 0 else 1

    self.total = self.num_cols * self.num_rows
    self.num = 0  # progress counter

    self.save_prefix = save_prefix
    self.image_output = image_output
    self.prompt = prompt
    self.extra_pnginfo = extra_pnginfo
    self.my_unique_id = my_unique_id

    self.sampler = sampler
    self.easyCache = easyCache
|
||||
|
||||
# Helper Functions
|
||||
@staticmethod
|
||||
def define_variable(plot_image_vars, value_type, value, index):
|
||||
|
||||
plot_image_vars[value_type] = value
|
||||
if value_type in ["seed", "Seeds++ Batch"]:
|
||||
value_label = f"seed: {value}"
|
||||
else:
|
||||
value_label = f"{value_type}: {value}"
|
||||
|
||||
if "ControlNet" in value_type:
|
||||
value_label = f"ControlNet {index + 1}"
|
||||
|
||||
if value_type in ['Lora', 'Checkpoint']:
|
||||
arr = value.split(',')
|
||||
model_name = os.path.basename(os.path.splitext(arr[0])[0])
|
||||
trigger_words = ' ' + arr[3] if value_type == 'Lora' and len(arr[3]) > 2 else ''
|
||||
lora_weight = float(arr[1]) if value_type == 'Lora' and len(arr) > 1 else 0
|
||||
lora_weight_desc = f"({lora_weight:.2f})" if lora_weight > 0 else ''
|
||||
value_label = f"{model_name[:30]}{lora_weight_desc} {trigger_words}"
|
||||
|
||||
if value_type in ["ModelMergeBlocks"]:
|
||||
if ":" in value:
|
||||
line = value.split(':')
|
||||
value_label = f"{line[0]}"
|
||||
elif len(value) > 16:
|
||||
value_label = f"ModelMergeBlocks {index + 1}"
|
||||
else:
|
||||
value_label = f"MMB: {value}"
|
||||
|
||||
if value_type in ["Pos Condition"]:
|
||||
value_label = f"pos cond {index + 1}" if index>0 else f"pos cond"
|
||||
if value_type in ["Neg Condition"]:
|
||||
value_label = f"neg cond {index + 1}" if index>0 else f"neg cond"
|
||||
|
||||
if value_type in ["Positive Prompt S/R"]:
|
||||
value_label = f"pos prompt {index + 1}" if index>0 else f"pos prompt"
|
||||
if value_type in ["Negative Prompt S/R"]:
|
||||
value_label = f"neg prompt {index + 1}" if index>0 else f"neg prompt"
|
||||
|
||||
if value_type in ["steps", "cfg", "denoise", "clip_skip",
|
||||
"lora_model_strength", "lora_clip_strength"]:
|
||||
value_label = f"{value_type}: {value}"
|
||||
|
||||
if value_type == "positive":
|
||||
value_label = f"pos prompt {index + 1}"
|
||||
elif value_type == "negative":
|
||||
value_label = f"neg prompt {index + 1}"
|
||||
|
||||
return plot_image_vars, value_label
|
||||
|
||||
@staticmethod
def get_font(font_size, font_path=None):
    """Load a TrueType font at *font_size*; falls back to the bundled
    OpenSans-Medium font when *font_path* is None."""
    if font_path is None:
        font_path = str(Path(os.path.join(RESOURCES_DIR, 'OpenSans-Medium.ttf')))
    return ImageFont.truetype(font_path, font_size)
|
||||
|
||||
@staticmethod
|
||||
def update_label(label, value, num_items):
|
||||
if len(label) < num_items:
|
||||
return [*label, value]
|
||||
return label
|
||||
|
||||
@staticmethod
|
||||
def rearrange_tensors(latent, num_cols, num_rows):
|
||||
new_latent = []
|
||||
for i in range(num_rows):
|
||||
for j in range(num_cols):
|
||||
index = j * num_rows + i
|
||||
new_latent.append(latent[index])
|
||||
return new_latent
|
||||
|
||||
def calculate_background_dimensions(self):
|
||||
border_size = int((self.max_width // 8) * 1.5) if self.y_type != "None" or self.x_type != "None" else 0
|
||||
|
||||
bg_width = self.num_cols * (self.max_width + self.grid_spacing) - self.grid_spacing + border_size * (
|
||||
self.y_type != "None")
|
||||
bg_height = self.num_rows * (self.max_height + self.grid_spacing) - self.grid_spacing + border_size * (
|
||||
self.x_type != "None")
|
||||
|
||||
# Add space at the bottom of the image for common informaiton about the image
|
||||
bg_height = bg_height + (border_size*2)
|
||||
# print(f"Grid Size: width = {bg_width} height = {bg_height} border_size = {border_size}")
|
||||
|
||||
x_offset_initial = border_size if self.y_type != "None" else 0
|
||||
y_offset = border_size if self.x_type != "None" else 0
|
||||
|
||||
return bg_width, bg_height, x_offset_initial, y_offset
|
||||
|
||||
|
||||
def adjust_font_size(self, text, initial_font_size, label_width):
|
||||
font = self.get_font(initial_font_size, self.custom_font)
|
||||
text_width = font.getbbox(text)
|
||||
# pprint.pp(f"Initial font size: {initial_font_size}, text: {text}, text_width: {text_width}")
|
||||
if text_width and text_width[2]:
|
||||
text_width = text_width[2]
|
||||
|
||||
scaling_factor = 0.9
|
||||
if text_width > (label_width * scaling_factor):
|
||||
# print(f"Adjusting font size from {initial_font_size} to fit text width {text_width} into label width {label_width} scaling_factor {scaling_factor}")
|
||||
return int(initial_font_size * (label_width / text_width) * scaling_factor)
|
||||
else:
|
||||
return initial_font_size
|
||||
|
||||
def textsize(self, d, text, font):
|
||||
_, _, width, height = d.textbbox((0, 0), text=text, font=font)
|
||||
return width, height
|
||||
|
||||
def create_label(self, img, text, initial_font_size, is_x_label=True, max_font_size=70, min_font_size=10, label_width=0, label_height=0):
    """Render an axis label onto a transparent RGBA strip.

    The font size is auto-fitted to *label_width* (clamped between
    min_font_size and max_font_size) and overlong text is ellipsised.
    Multi-line text ('\\n') is centred line by line.  Returns the label as
    a PIL Image.
    """
    # If the label_width is specified, leave it alone; otherwise derive it
    # from the cell image (width for X labels, height for Y labels).
    if label_width == 0:
        label_width = img.width if is_x_label else img.height

    text_lines = text.split('\n')
    longest_line = max(text_lines, key=len)

    # Adjust font size to the longest line, then clamp.
    font_size = self.adjust_font_size(longest_line, initial_font_size, label_width)
    font_size = min(max_font_size, font_size)  # Ensure font isn't too large
    font_size = max(min_font_size, font_size)  # Ensure font isn't too small

    if label_height == 0:
        # X labels get extra head-room; Y labels are as tall as the font.
        label_height = int(font_size * 1.5) if is_x_label else font_size

    label_bg = Image.new('RGBA', (label_width, label_height), color=(255, 255, 255, 0))
    d = ImageDraw.Draw(label_bg)

    font = self.get_font(font_size, self.custom_font)

    # Check if text will fit; if not, trim characters and append an ellipsis.
    # NOTE(review): the truncated `text` is not re-split, so the drawing
    # below still uses the original text_lines — confirm multi-line
    # truncation is meant to behave this way.
    if self.textsize(d, text, font=font)[0] > label_width:
        while self.textsize(d, text + '...', font=font)[0] > label_width and len(text) > 0:
            text = text[:-1]
        text = text + '...'

    # Compute text width and height for multi-line text.
    text_widths, text_heights = zip(*[self.textsize(d, line, font=font) for line in text_lines])
    max_text_width = max(text_widths)
    total_text_height = sum(text_heights)

    # Compute the centred position for each line of text.
    lines_positions = []
    current_y = 0
    for line, line_width, line_height in zip(text_lines, text_widths, text_heights):
        text_x = (label_width - line_width) // 2
        text_y = current_y + (label_height - total_text_height) // 2
        current_y += line_height
        lines_positions.append((line, (text_x, text_y)))

    # Draw each line of text.
    for line, (text_x, text_y) in lines_positions:
        d.text((text_x, text_y), line, fill='black', font=font)

    return label_bg
|
||||
|
||||
def sample_plot_image(self, plot_image_vars, samples, preview_latent, latents_plot, image_list, disable_noise,
                      start_step, last_step, force_full_denoise, x_value=None, y_value=None):
    """Render one cell of the XY plot.

    Resolves the current x/y axis values into concrete sampling inputs
    (model / clip / vae / conditioning / sampler settings), runs the
    ksampler, decodes the latent and appends the resulting PIL image.
    Mutates self.max_width / self.max_height and returns
    (image_list, max_width, max_height, latents_plot).
    """
    # Anything left as None here falls back to plot_image_vars at the end.
    model, clip, vae, positive, negative, seed, steps, cfg = None, None, None, None, None, None, None, None
    sampler_name, scheduler, denoise = None, None, None

    a1111_prompt_style = plot_image_vars['a1111_prompt_style'] if "a1111_prompt_style" in plot_image_vars else False
    clip = clip if clip is not None else plot_image_vars["clip"]
    steps = plot_image_vars['steps'] if "steps" in plot_image_vars else 1

    sd_version = get_sd_version(plot_image_vars['model'])
    # Advanced usage: axis values come from the "advanced" XY input nodes.
    if plot_image_vars["x_node_type"] == "advanced" or plot_image_vars["y_node_type"] == "advanced":
        if self.x_type == "Seeds++ Batch" or self.y_type == "Seeds++ Batch":
            seed = int(x_value) if self.x_type == "Seeds++ Batch" else int(y_value)
        if self.x_type == "Steps" or self.y_type == "Steps":
            steps = int(x_value) if self.x_type == "Steps" else int(y_value)
        if self.x_type == "StartStep" or self.y_type == "StartStep":
            start_step = int(x_value) if self.x_type == "StartStep" else int(y_value)
        if self.x_type == "EndStep" or self.y_type == "EndStep":
            last_step = int(x_value) if self.x_type == "EndStep" else int(y_value)
        if self.x_type == "CFG Scale" or self.y_type == "CFG Scale":
            cfg = float(x_value) if self.x_type == "CFG Scale" else float(y_value)
        if self.x_type == "Sampler" or self.y_type == "Sampler":
            sampler_name = x_value if self.x_type == "Sampler" else y_value
        if self.x_type == "Scheduler" or self.y_type == "Scheduler":
            scheduler = x_value if self.x_type == "Scheduler" else y_value
        if self.x_type == "Sampler&Scheduler" or self.y_type == "Sampler&Scheduler":
            # Value is "sampler,scheduler"; 'None' keeps the current setting.
            arr = x_value.split(',') if self.x_type == "Sampler&Scheduler" else y_value.split(',')
            if arr[0] and arr[0]!= 'None':
                sampler_name = arr[0]
            if arr[1] and arr[1]!= 'None':
                scheduler = arr[1]
        if self.x_type == "Denoise" or self.y_type == "Denoise":
            denoise = float(x_value) if self.x_type == "Denoise" else float(y_value)
        if self.x_type == "Pos Condition" or self.y_type == "Pos Condition":
            positive = plot_image_vars['positive_cond_stack'][int(x_value)] if self.x_type == "Pos Condition" else plot_image_vars['positive_cond_stack'][int(y_value)]
        if self.x_type == "Neg Condition" or self.y_type == "Neg Condition":
            negative = plot_image_vars['negative_cond_stack'][int(x_value)] if self.x_type == "Neg Condition" else plot_image_vars['negative_cond_stack'][int(y_value)]
        # Model merge: blend two checkpoints with per-block ratios.
        if self.x_type == "ModelMergeBlocks" or self.y_type == "ModelMergeBlocks":
            ckpt_name_1, ckpt_name_2 = plot_image_vars['models']
            model1, clip1, vae1, clip_vision = self.easyCache.load_checkpoint(ckpt_name_1)
            model2, clip2, vae2, clip_vision = self.easyCache.load_checkpoint(ckpt_name_2)
            xy_values = x_value if self.x_type == "ModelMergeBlocks" else y_value
            # "label:weights" form — keep only the weights part.
            if ":" in xy_values:
                xy_line = xy_values.split(':')
                xy_values = xy_line[1]

            xy_arrs = xy_values.split(',')
            # 3 weights: input/middle/out groups; 30 weights: one per block.
            if len(xy_arrs) == 3:
                input, middle, out = xy_arrs
                kwargs = {
                    "input": input,
                    "middle": middle,
                    "out": out
                }
            elif len(xy_arrs) == 30:
                kwargs = {}
                kwargs["time_embed."] = xy_arrs[0]
                kwargs["label_emb."] = xy_arrs[1]

                for i in range(12):
                    kwargs["input_blocks.{}.".format(i)] = xy_arrs[2+i]

                for i in range(3):
                    kwargs["middle_block.{}.".format(i)] = xy_arrs[14+i]

                for i in range(12):
                    kwargs["output_blocks.{}.".format(i)] = xy_arrs[17+i]

                kwargs["out."] = xy_arrs[29]
            else:
                raise Exception("ModelMergeBlocks weight length error")
            default_ratio = next(iter(kwargs.values()))

            m = model1.clone()
            kp = model2.get_key_patches("diffusion_model.")

            for k in kp:
                ratio = float(default_ratio)
                k_unet = k[len("diffusion_model."):]

                # Longest matching key prefix wins.
                last_arg_size = 0
                for arg in kwargs:
                    if k_unet.startswith(arg) and last_arg_size < len(arg):
                        ratio = float(kwargs[arg])
                        last_arg_size = len(arg)

                m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)

            vae_use = plot_image_vars['vae_use']

            clip = clip2 if vae_use == 'Use Model 2' else clip1
            if vae_use == 'Use Model 2':
                vae = vae2
            elif vae_use == 'Use Model 1':
                vae = vae1
            else:
                vae = self.easyCache.load_vae(vae_use)
            model = m

            # Apply loras when a lora_stack is present.
            optional_lora_stack = plot_image_vars['lora_stack']
            if optional_lora_stack is not None and optional_lora_stack != []:
                for lora in optional_lora_stack:
                    # NOTE(review): unlike the Checkpoint branch below, this
                    # loop does not refresh lora['model']/lora['clip'] before
                    # loading — confirm that is intentional.
                    model, clip = self.easyCache.load_lora(lora)

            # Handle clip skip.
            clip = clip.clone()
            if plot_image_vars['clip_skip'] != 0:
                clip.clip_layer(plot_image_vars['clip_skip'])

        # CheckPoint
        if self.x_type == "Checkpoint" or self.y_type == "Checkpoint":
            xy_values = x_value if self.x_type == "Checkpoint" else y_value
            ckpt_name, clip_skip, vae_name = xy_values.split(",")
            # '*' stands in for ',' inside names (',' is the field separator).
            ckpt_name = ckpt_name.replace('*', ',')
            vae_name = vae_name.replace('*', ',')
            model, clip, vae, clip_vision = self.easyCache.load_checkpoint(ckpt_name)
            if vae_name != 'None':
                vae = self.easyCache.load_vae(vae_name)

            # Apply loras when a lora_stack is present.
            optional_lora_stack = plot_image_vars['lora_stack']
            if optional_lora_stack is not None and optional_lora_stack != []:
                for lora in optional_lora_stack:
                    lora['model'] = model
                    lora['clip'] = clip
                    model, clip = self.easyCache.load_lora(lora)

            # Handle clip skip.
            clip = clip.clone()
            if clip_skip != 'None':
                clip.clip_layer(int(clip_skip))
            positive = plot_image_vars['positive']
            negative = plot_image_vars['negative']
            a1111_prompt_style = plot_image_vars['a1111_prompt_style']
            steps = plot_image_vars['steps']
            clip = clip if clip is not None else plot_image_vars["clip"]
            positive = advanced_encode(clip, positive,
                                       plot_image_vars['positive_token_normalization'],
                                       plot_image_vars['positive_weight_interpretation'],
                                       w_max=1.0,
                                       apply_to_pooled="enable",
                                       a1111_prompt_style=a1111_prompt_style, steps=steps)

            negative = advanced_encode(clip, negative,
                                       plot_image_vars['negative_token_normalization'],
                                       plot_image_vars['negative_weight_interpretation'],
                                       w_max=1.0,
                                       apply_to_pooled="enable",
                                       a1111_prompt_style=a1111_prompt_style, steps=steps)
            if "positive_cond" in plot_image_vars:
                positive = positive + plot_image_vars["positive_cond"]
            if "negative_cond" in plot_image_vars:
                negative = negative + plot_image_vars["negative_cond"]

        # Lora
        if self.x_type == "Lora" or self.y_type == "Lora":
            model = model if model is not None else plot_image_vars["model"]
            clip = clip if clip is not None else plot_image_vars["clip"]
            xy_values = x_value if self.x_type == "Lora" else y_value
            lora_name, lora_model_strength, lora_clip_strength, _ = xy_values.split(",")
            lora_stack = [{"lora_name": lora_name, "model": model, "clip" :clip, "model_strength": float(lora_model_strength), "clip_strength": float(lora_clip_strength)}]

            if 'lora_stack' in plot_image_vars:
                lora_stack = lora_stack + plot_image_vars['lora_stack']

            if lora_stack is not None and lora_stack != []:
                for lora in lora_stack:
                    # Each generation of the model, must use the reference to previously created model / clip objects.
                    lora['model'] = model
                    lora['clip'] = clip
                    model, clip = self.easyCache.load_lora(lora)

        # Prompts
        if "Positive" in self.x_type or "Positive" in self.y_type:
            if self.x_type == 'Positive Prompt S/R' or self.y_type == 'Positive Prompt S/R':
                positive = x_value if self.x_type == "Positive Prompt S/R" else y_value

            if sd_version == 'flux':
                positive, = CLIPTextEncode().encode(clip, positive)
            else:
                positive = advanced_encode(clip, positive,
                                           plot_image_vars['positive_token_normalization'],
                                           plot_image_vars['positive_weight_interpretation'],
                                           w_max=1.0,
                                           apply_to_pooled="enable", a1111_prompt_style=a1111_prompt_style, steps=steps)

        if "Negative" in self.x_type or "Negative" in self.y_type:
            if self.x_type == 'Negative Prompt S/R' or self.y_type == 'Negative Prompt S/R':
                negative = x_value if self.x_type == "Negative Prompt S/R" else y_value

            if sd_version == 'flux':
                negative, = CLIPTextEncode().encode(clip, negative)
            else:
                negative = advanced_encode(clip, negative,
                                           plot_image_vars['negative_token_normalization'],
                                           plot_image_vars['negative_weight_interpretation'],
                                           w_max=1.0,
                                           apply_to_pooled="enable", a1111_prompt_style=a1111_prompt_style, steps=steps)

        # ControlNet
        if "ControlNet" in self.x_type or "ControlNet" in self.y_type:
            cnet = plot_image_vars["cnet"] if "cnet" in plot_image_vars else None
            # NOTE(review): these membership tests check "positive"/"negative"
            # but read "positive_cond"/"negative_cond" — confirm the keys.
            positive = plot_image_vars["positive_cond"] if "positive" in plot_image_vars else None
            negative = plot_image_vars["negative_cond"] if "negative" in plot_image_vars else None
            if cnet:
                index = x_value if "ControlNet" in self.x_type else y_value
                controlnet = cnet[index]
                for index, item in enumerate(controlnet):
                    control_net_name = item[0]
                    image = item[1]
                    strength = item[2]
                    start_percent = item[3]
                    end_percent = item[4]
                    provided_control_net = item[5] if len(item) > 5 else None
                    positive, negative = easyControlnet().apply(control_net_name, image, positive, negative, strength, start_percent, end_percent, provided_control_net, 1)
        # Flux guidance
        if self.x_type == "Flux Guidance" or self.y_type == "Flux Guidance":
            # NOTE(review): FluxGuidance may be None when the top-of-file
            # import failed — this call would then raise.
            positive = plot_image_vars["positive_cond"] if "positive" in plot_image_vars else None
            flux_guidance = float(x_value) if self.x_type == "Flux Guidance" else float(y_value)
            positive, = FluxGuidance().append(positive, flux_guidance)

    # Simple usage: axis values come from the loader node.
    if plot_image_vars["x_node_type"] == "loader" or plot_image_vars["y_node_type"] == "loader":
        if self.x_type == 'ckpt_name' or self.y_type == 'ckpt_name':
            ckpt_name = x_value if self.x_type == "ckpt_name" else y_value
            model, clip, vae, clip_vision = self.easyCache.load_checkpoint(ckpt_name)

        if self.x_type == 'lora_name' or self.y_type == 'lora_name':
            model, clip, vae, clip_vision = self.easyCache.load_checkpoint(plot_image_vars['ckpt_name'])
            lora_name = x_value if self.x_type == "lora_name" else y_value
            lora = {"lora_name": lora_name, "model": model, "clip": clip, "model_strength": 1, "clip_strength": 1}
            model, clip = self.easyCache.load_lora(lora)

        if self.x_type == 'lora_model_strength' or self.y_type == 'lora_model_strength':
            model, clip, vae, clip_vision = self.easyCache.load_checkpoint(plot_image_vars['ckpt_name'])
            lora_model_strength = float(x_value) if self.x_type == "lora_model_strength" else float(y_value)
            lora = {"lora_name": plot_image_vars['lora_name'], "model": model, "clip": clip, "model_strength": lora_model_strength, "clip_strength": plot_image_vars['lora_clip_strength']}
            model, clip = self.easyCache.load_lora(lora)

        if self.x_type == 'lora_clip_strength' or self.y_type == 'lora_clip_strength':
            model, clip, vae, clip_vision = self.easyCache.load_checkpoint(plot_image_vars['ckpt_name'])
            lora_clip_strength = float(x_value) if self.x_type == "lora_clip_strength" else float(y_value)
            lora = {"lora_name": plot_image_vars['lora_name'], "model": model, "clip": clip, "model_strength": plot_image_vars['lora_model_strength'], "clip_strength": lora_clip_strength}
            model, clip = self.easyCache.load_lora(lora)

        # Check for custom VAE
        if self.x_type == 'vae_name' or self.y_type == 'vae_name':
            vae_name = x_value if self.x_type == "vae_name" else y_value
            vae = self.easyCache.load_vae(vae_name)

        # CLIP skip
        if not clip:
            raise Exception("No CLIP found")
        clip = clip.clone()
        clip.clip_layer(plot_image_vars['clip_skip'])

        if sd_version == 'flux':
            # NOTE(review): `positive` is normally still None on this path —
            # confirm whether plot_image_vars['positive'] was intended here.
            positive, = CLIPTextEncode().encode(clip, positive)
        else:
            positive = advanced_encode(clip, plot_image_vars['positive'],
                                       plot_image_vars['positive_token_normalization'],
                                       plot_image_vars['positive_weight_interpretation'], w_max=1.0,
                                       apply_to_pooled="enable",a1111_prompt_style=a1111_prompt_style, steps=steps)

        if sd_version == 'flux':
            negative, = CLIPTextEncode().encode(clip, negative)
        else:
            negative = advanced_encode(clip, plot_image_vars['negative'],
                                       plot_image_vars['negative_token_normalization'],
                                       plot_image_vars['negative_weight_interpretation'], w_max=1.0,
                                       apply_to_pooled="enable", a1111_prompt_style=a1111_prompt_style, steps=steps)


    # Fall back to the shared plot variables for anything not set above.
    model = model if model is not None else plot_image_vars["model"]
    vae = vae if vae is not None else plot_image_vars["vae"]
    positive = positive if positive is not None else plot_image_vars["positive_cond"]
    negative = negative if negative is not None else plot_image_vars["negative_cond"]

    seed = seed if seed is not None else plot_image_vars["seed"]
    steps = steps if steps is not None else plot_image_vars["steps"]
    cfg = cfg if cfg is not None else plot_image_vars["cfg"]
    sampler_name = sampler_name if sampler_name is not None else plot_image_vars["sampler_name"]
    scheduler = scheduler if scheduler is not None else plot_image_vars["scheduler"]
    denoise = denoise if denoise is not None else plot_image_vars["denoise"]

    noise_device = plot_image_vars["noise_device"] if "noise_device" in plot_image_vars else 'cpu'

    # LayerDiffuse
    layer_diffusion_method = plot_image_vars["layer_diffusion_method"] if "layer_diffusion_method" in plot_image_vars else None
    empty_samples = plot_image_vars["empty_samples"] if "empty_samples" in plot_image_vars else None

    if layer_diffusion_method:
        samp_blend_samples = plot_image_vars["blend_samples"] if "blend_samples" in plot_image_vars else None
        additional_cond = plot_image_vars["layer_diffusion_cond"] if "layer_diffusion_cond" in plot_image_vars else None

        images = plot_image_vars["images"].movedim(-1, 1) if "images" in plot_image_vars else None
        weight = plot_image_vars['layer_diffusion_weight'] if 'layer_diffusion_weight' in plot_image_vars else 1.0
        model, positive, negative = LayerDiffuse().apply_layer_diffusion(model, layer_diffusion_method, weight, samples,
                                                                        samp_blend_samples, positive,
                                                                        negative, images, additional_cond)

    samples = empty_samples if layer_diffusion_method is not None and empty_samples is not None else samples
    # Sample
    samples = self.sampler.common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, samples,
                                           denoise=denoise, disable_noise=disable_noise, preview_latent=preview_latent,
                                           start_step=start_step, last_step=last_step,
                                           force_full_denoise=force_full_denoise, noise_device=noise_device)

    # Decode images and store
    latent = samples["samples"]

    # Add the latent tensor to the tensors list
    latents_plot.append(latent)

    # Decode the image
    image = vae.decode(latent).cpu()

    if self.output_individuals in [True, "True"]:
        easySave(image, self.save_prefix, self.image_output)

    # Convert the image from tensor to PIL Image and add it to the list
    pil_image = self.sampler.tensor2pil(image)
    image_list.append(pil_image)

    # Update max dimensions
    self.max_width = max(self.max_width, pil_image.width)
    self.max_height = max(self.max_height, pil_image.height)

    # Return the touched variables
    return image_list, self.max_width, self.max_height, latents_plot
|
||||
|
||||
# Process Functions
|
||||
def validate_xy_plot(self):
    """Return True when at least one plot axis has a type selected.

    When both axes are 'None' there is nothing to plot: warn and return
    False so the caller falls back to a plain sampling pass.
    """
    has_plot_axis = not (self.x_type == 'None' and self.y_type == 'None')
    if has_plot_axis:
        return True
    log_node_warn(f'#{self.my_unique_id}','No Valid Plot Types - Reverting to default sampling...')
    return False
|
||||
|
||||
def get_latent(self, samples):
    """Select the single latent addressed by ``self.latent_id`` from a batch.

    ``samples`` is a latent dict whose 'samples' tensor may hold several
    batch elements; the batch is split into per-element latent dicts and
    one of them is returned. An out-of-range ``latent_id`` is clamped to
    the last element (with a warning) rather than raising.
    """
    # One {'samples': tensor} dict per batch element, each with batch size 1.
    per_image = torch.split(samples["samples"], 1, dim=0)
    latent_list = [{'samples': chunk} for chunk in per_image]

    if self.latent_id >= len(latent_list):
        log_node_warn(f'#{self.my_unique_id}',f'The selected latent_id ({self.latent_id}) is out of range.')
        log_node_warn(f'#{self.my_unique_id}', f'Automatically setting the latent_id to the last image in the list (index: {len(latent_list) - 1}).')
        self.latent_id = len(latent_list) - 1

    return latent_list[self.latent_id]
|
||||
|
||||
def get_labels_and_sample(self, plot_image_vars, latent_image, preview_latent, start_step, last_step,
                          force_full_denoise, disable_noise):
    """Walk the X (and optional Y) value grid, sampling one image per cell.

    For each axis value, ``define_variable`` mutates ``plot_image_vars`` and
    yields a display label; ``sample_plot_image`` accumulates images, sizes
    and latents on self. Returns the grid's latents concatenated into one
    batched tensor (also stored on ``self.latents_plot``).
    """
    for x_index, x_value in enumerate(self.x_values):
        # Apply the X-axis value to the sampling variables and record its label.
        plot_image_vars, x_value_label = self.define_variable(plot_image_vars, self.x_type, x_value,
                                                              x_index)
        self.x_label = self.update_label(self.x_label, x_value_label, len(self.x_values))
        if self.y_type != 'None':
            # Full 2-D plot: sample once per (x, y) combination.
            for y_index, y_value in enumerate(self.y_values):
                plot_image_vars, y_value_label = self.define_variable(plot_image_vars, self.y_type, y_value,
                                                                      y_index)
                self.y_label = self.update_label(self.y_label, y_value_label, len(self.y_values))

                self.image_list, self.max_width, self.max_height, self.latents_plot = self.sample_plot_image(
                    plot_image_vars, latent_image, preview_latent, self.latents_plot, self.image_list,
                    disable_noise, start_step, last_step, force_full_denoise, x_value, y_value)
                self.num += 1
        else:
            # 1-D plot: only the X axis varies.
            self.image_list, self.max_width, self.max_height, self.latents_plot = self.sample_plot_image(
                plot_image_vars, latent_image, preview_latent, self.latents_plot, self.image_list, disable_noise,
                start_step, last_step, force_full_denoise, x_value)
            self.num += 1

    # Rearrange latent array to match preview image grid
    self.latents_plot = self.rearrange_tensors(self.latents_plot, self.num_cols, self.num_rows)

    # Concatenate the tensors along the first dimension (dim=0)
    self.latents_plot = torch.cat(self.latents_plot, dim=0)

    return self.latents_plot
|
||||
|
||||
def plot_images_and_labels(self, plot_image_vars):
    """Compose the sampled images into one labelled grid image.

    Pastes ``self.image_list`` onto a white RGBA background in column-major
    order, draws X labels above the first row and rotated Y labels left of
    the first column, then appends a footer listing the sampling settings
    that were *not* used as plot axes (checkpoint, vae, sampler, scheduler,
    steps, guidance, seed, LoRAs). Returns ``(grid_tensor, per_image_tensors)``.
    """

    bg_width, bg_height, x_offset_initial, y_offset = self.calculate_background_dimensions()

    background = Image.new('RGBA', (int(bg_width), int(bg_height)), color=(255, 255, 255, 255))

    output_image = []
    for row_index in range(self.num_rows):
        x_offset = x_offset_initial

        for col_index in range(self.num_cols):
            # image_list is stored column-major: column stride is num_rows.
            index = col_index * self.num_rows + row_index
            img = self.image_list[index]
            output_image.append(self.sampler.pil2tensor(img))
            background.paste(img, (x_offset, y_offset))

            # X-axis label: centered in the margin above the first row.
            # Label font size scales with image width (48px at 512).
            if row_index == 0 and self.x_type != "None":
                label_bg = self.create_label(img, self.x_label[col_index], int(48 * img.width / 512))
                label_y = (y_offset - label_bg.height) // 2
                background.alpha_composite(label_bg, (x_offset, label_y))

            # Y-axis label: rendered horizontally, then rotated 90° and
            # centered in the margin left of the first column.
            if col_index == 0 and self.y_type != "None":
                label_bg = self.create_label(img, self.y_label[row_index], int(48 * img.height / 512), False)
                label_bg = label_bg.rotate(90, expand=True)

                label_x = (x_offset - label_bg.width) // 2
                label_y = y_offset + (img.height - label_bg.height) // 2
                background.alpha_composite(label_bg, (label_x, label_y))

            x_offset += img.width + self.grid_spacing

        y_offset += img.height + self.grid_spacing

    # Footer: summarize the settings shared by every cell of the grid.
    common_label = ""

    # LORAs are handled separately below because there can be multiple of them.
    labels = [
        {"id": "ckpt_name", "id_desc": "ckpt", "axis_type" : "Checkpoint"},
        {"id": "vae_name", "id_desc": '', "axis_type" : "vae_name"},
        {"id": "sampler_name", "id_desc": "sampler", "axis_type" : "Sampler"},
        {"id": "scheduler", "id_desc": '', "axis_type" : "Scheduler"},
        {"id": "steps", "id_desc": '', "axis_type" : "Steps"},
        {"id": "Flux Guidance", "id_desc": 'guidance', "axis_type" : "Flux Guidance"},
        {"id": "seed", "id_desc": '', "axis_type" : "Seeds++ Batch"}
    ]

    for item in labels:
        # Only add the label if the setting is not one of the plot axes
        # (axis values differ per cell, so they are not "common").
        if self.x_type != item['axis_type'] and self.y_type != item['axis_type']:
            common_label += self.add_common_label(item['id'], plot_image_vars, item['id_desc'])
            common_label += f"\n"

    if plot_image_vars['lora_stack'] is not None and plot_image_vars['lora_stack'] != []:
        for lora in plot_image_vars['lora_stack']:

            lora_name = lora['lora_name']
            lora_weight = lora['model_strength']
            if lora_name is not None and len(lora_name) > 0 and lora_weight > 0:
                common_label += f"LORA: {lora_name} weight: {lora_weight:.2f} \n"

    common_label = common_label.strip()

    if len(common_label) > 0:
        # Use whatever vertical space remains below the grid for the footer.
        label_height = background.height - y_offset
        label_bg = self.create_label(background, common_label, int(48 * background.width / 512), label_width=background.width, label_height=label_height)
        label_x = (background.width - label_bg.width) // 2
        label_y = y_offset
        background.alpha_composite(label_bg, (label_x, label_y))

    return (self.sampler.pil2tensor(background), output_image)
|
||||
|
||||
def add_common_label(self, tag, plot_image_vars, description = ''):
    """Build one "name: value " label fragment for the plot footer.

    Returns '' when ``tag`` is absent from ``plot_image_vars`` or its value
    is None / the string 'None'. ``description`` overrides the displayed
    name; when empty, the raw tag is shown.
    """
    if description == '':
        description = tag
    value = plot_image_vars.get(tag)
    if value is None or value == 'None':
        return ''
    return f"{description}: {value} "
|
||||
1310
custom_nodes/ComfyUI-Easy-Use/py/modules/ben/model.py
Normal file
1310
custom_nodes/ComfyUI-Easy-Use/py/modules/ben/model.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,167 @@
|
||||
#credit to comfyanonymous for this module
|
||||
#from https://github.com/comfyanonymous/ComfyUI_bitsandbytes_NF4
|
||||
import comfy.ops
|
||||
import torch
|
||||
import folder_paths
|
||||
from ...libs.utils import install_package
|
||||
|
||||
try:
|
||||
from bitsandbytes.nn.modules import Params4bit, QuantState
|
||||
except ImportError:
|
||||
Params4bit = torch.nn.Parameter
|
||||
raise ImportError("Please install bitsandbytes>=0.43.3")
|
||||
|
||||
def functional_linear_4bits(x, weight, bias):
    """Forward pass of a 4-bit-quantized linear layer via bitsandbytes.

    ``weight`` is a bitsandbytes Params4bit whose ``quant_state`` carries the
    quantization metadata needed to dequantize on the fly.

    Raises:
        ImportError: if bitsandbytes>=0.43.3 cannot be installed/imported.
    """
    try:
        # Ensure a compatible bitsandbytes is present before importing it.
        install_package("bitsandbytes", "0.43.3", True, "0.43.3")
        import bitsandbytes as bnb
    except ImportError:
        raise ImportError("Please install bitsandbytes>=0.43.3")

    out = bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state)
    # Match the input's dtype/device for the caller.
    out = out.to(x)
    return out
|
||||
|
||||
|
||||
def copy_quant_state(state, device: torch.device = None):
    """Deep-copy a bitsandbytes QuantState onto ``device``.

    Handles nested (double-quantized) states: when ``state.nested`` is set,
    the inner ``state2`` (quantization constants of the absmax values) is
    copied too. Returns None for a None input.
    """
    if state is None:
        return None

    # Default to wherever the state's tensors currently live.
    device = device or state.absmax.device

    # Inner state only exists for nested (double) quantization.
    state2 = (
        QuantState(
            absmax=state.state2.absmax.to(device),
            shape=state.state2.shape,
            code=state.state2.code.to(device),
            blocksize=state.state2.blocksize,
            quant_type=state.state2.quant_type,
            dtype=state.state2.dtype,
        )
        if state.nested
        else None
    )

    return QuantState(
        absmax=state.absmax.to(device),
        shape=state.shape,
        code=state.code.to(device),
        blocksize=state.blocksize,
        quant_type=state.quant_type,
        dtype=state.dtype,
        # offset is only meaningful for nested states.
        offset=state.offset.to(device) if state.nested else None,
        state2=state2,
    )
|
||||
|
||||
|
||||
class ForgeParams4bit(Params4bit):
    """Params4bit whose ``.to()`` keeps the quant_state consistent across moves.

    Moving an unquantized parameter to CUDA triggers quantization; any other
    move produces a fresh ForgeParams4bit with a device-copied quant_state and
    also updates ``self``/``self.module`` in place so stale references keep
    working.
    """

    def to(self, *args, **kwargs):
        # Resolve the (device, dtype, ...) target the same way torch does.
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None and device.type == "cuda" and not self.bnb_quantized:
            # First move to CUDA performs the actual 4-bit quantization.
            return self._quantize(device)
        else:
            n = ForgeParams4bit(
                torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking),
                requires_grad=self.requires_grad,
                quant_state=copy_quant_state(self.quant_state, device),
                blocksize=self.blocksize,
                compress_statistics=self.compress_statistics,
                quant_type=self.quant_type,
                quant_storage=self.quant_storage,
                bnb_quantized=self.bnb_quantized,
                module=self.module
            )
            # Keep the owning module and this (old) parameter in sync with the
            # new copy, since callers may still hold references to either.
            self.module.quant_state = n.quant_state
            self.data = n.data
            self.quant_state = n.quant_state
            return n
|
||||
|
||||
class ForgeLoader4Bit(torch.nn.Module):
    """nn.Module that lazily materializes a 4-bit weight from a state dict.

    A placeholder ``dummy`` parameter records the target device/dtype until
    ``_load_from_state_dict`` replaces it with the real (pre-quantized or
    to-be-quantized) weight, after which ``dummy`` is deleted.
    """

    def __init__(self, *, device, dtype, quant_type, **kwargs):
        super().__init__()
        # Placeholder carrying device/dtype until the real weight is loaded.
        self.dummy = torch.nn.Parameter(torch.empty(1, device=device, dtype=dtype))
        self.weight = None
        self.quant_state = None
        self.bias = None
        self.quant_type = quant_type

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        """Serialize the weight plus its packed quant_state entries."""
        super()._save_to_state_dict(destination, prefix, keep_vars)
        quant_state = getattr(self.weight, "quant_state", None)
        if quant_state is not None:
            # bitsandbytes packs its metadata under "weight.<key>" entries.
            for k, v in quant_state.as_dict(packed=True).items():
                destination[prefix + "weight." + k] = v if keep_vars else v.detach()
        return

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        """Load either a pre-quantized checkpoint or a full-precision weight.

        Pre-quantized checkpoints are detected by 'bitsandbytes' appearing in
        the packed quant-state keys; otherwise a full-precision weight is
        wrapped so it quantizes on first CUDA move.
        """
        quant_state_keys = {k[len(prefix + "weight."):] for k in state_dict.keys() if k.startswith(prefix + "weight.")}

        if any('bitsandbytes' in k for k in quant_state_keys):
            # Checkpoint already stores 4-bit data + quantization stats.
            quant_state_dict = {k: state_dict[prefix + "weight." + k] for k in quant_state_keys}

            self.weight = ForgeParams4bit().from_prequantized(
                data=state_dict[prefix + 'weight'],
                quantized_stats=quant_state_dict,
                requires_grad=False,
                device=self.dummy.device,
                module=self
            )
            self.quant_state = self.weight.quant_state

            if prefix + 'bias' in state_dict:
                # .to(self.dummy) adopts the placeholder's device and dtype.
                self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))

            del self.dummy
        elif hasattr(self, 'dummy'):
            # Full-precision checkpoint: wrap so quantization happens lazily
            # on the first move to CUDA (see ForgeParams4bit.to).
            if prefix + 'weight' in state_dict:
                self.weight = ForgeParams4bit(
                    state_dict[prefix + 'weight'].to(self.dummy),
                    requires_grad=False,
                    compress_statistics=True,
                    quant_type=self.quant_type,
                    quant_storage=torch.uint8,
                    module=self,
                )
                self.quant_state = self.weight.quant_state

            if prefix + 'bias' in state_dict:
                self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))

            del self.dummy
        else:
            # Already loaded once (dummy is gone): defer to the default loader.
            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
# Module-level configuration read when OPS.Linear layers are constructed.
# The loader is expected to set these before building the model.
current_device = None  # NOTE(review): not referenced in this module's visible code — confirm external use
current_dtype = None   # NOTE(review): not referenced in this module's visible code — confirm external use
current_manual_cast_enabled = False  # when True, OPS.Linear.forward takes the manual-cast path
current_bnb_dtype = None  # bitsandbytes quant type passed to ForgeLoader4Bit (presumably 'nf4'/'fp4' — TODO confirm)
|
||||
|
||||
class OPS(comfy.ops.manual_cast):
    """comfy ops collection whose Linear runs 4-bit bitsandbytes matmuls."""

    class Linear(ForgeLoader4Bit):
        def __init__(self, *args, device=None, dtype=None, **kwargs):
            # quant type / manual-cast flag are snapshotted from the module
            # globals at construction time.
            super().__init__(device=device, dtype=dtype, quant_type=current_bnb_dtype)
            self.parameters_manual_cast = current_manual_cast_enabled

        def forward(self, x):
            # Re-attach the layer's quant_state; the weight tensor may have
            # been swapped by model management since the last call.
            self.weight.quant_state = self.quant_state

            if self.bias is not None and self.bias.dtype != x.dtype:
                # Maybe this can also be set to all non-bnb ops since the cost is very low.
                # And it only invokes one time, and most linear does not have bias
                self.bias.data = self.bias.data.to(x.dtype)

            if not self.parameters_manual_cast:
                # Fast path: weight already lives on the compute device.
                return functional_linear_4bits(x, self.weight, self.bias)
            elif not self.weight.bnb_quantized:
                # Weight still full-precision (e.g. held in CPU RAM): quantize
                # on the compute device for this call, then move it back.
                assert x.device.type == 'cuda', 'BNB Must Use CUDA as Computation Device!'
                layer_original_device = self.weight.device
                self.weight = self.weight._quantize(x.device)
                bias = self.bias.to(x.device) if self.bias is not None else None
                out = functional_linear_4bits(x, self.weight, bias)
                self.weight = self.weight.to(layer_original_device)
                return out
            else:
                # NOTE(review): weights_manual_cast and main_stream_worker are
                # not defined or imported in this module — this branch raises
                # NameError if reached. Confirm where they were meant to come
                # from (the upstream ComfyUI_bitsandbytes_NF4 module).
                weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True, skip_bias_dtype=True)
                with main_stream_worker(weight, bias, signal):
                    return functional_linear_4bits(x, weight, bias)
|
||||
475
custom_nodes/ComfyUI-Easy-Use/py/modules/briaai/rembg.py
Normal file
475
custom_nodes/ComfyUI-Easy-Use/py/modules/briaai/rembg.py
Normal file
@@ -0,0 +1,475 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torchvision.transforms.functional import normalize
|
||||
import numpy as np
|
||||
|
||||
|
||||
class REBNCONV(nn.Module):
    """Conv2d(3x3) -> BatchNorm -> ReLU block used throughout the RSU stages.

    ``dirate`` sets both dilation and padding, so spatial size is preserved
    whenever ``stride`` is 1.
    """

    def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
        super().__init__()
        # Padding tracks dilation to keep the output the same H x W at stride 1.
        self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv -> bn -> relu to a (B, C, H, W) tensor."""
        return self.relu_s1(self.bn_s1(self.conv_s1(x)))
|
||||
|
||||
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
|
||||
def _upsample_like(src,tar):
|
||||
|
||||
src = F.interpolate(src,size=tar.shape[2:],mode='bilinear')
|
||||
|
||||
return src
|
||||
|
||||
|
||||
### RSU-7 ###
|
||||
class RSU7(nn.Module):
    """Residual U-block of depth 7 (U^2-Net): 5 pooled encoder stages, a
    dilated bottleneck, and a mirrored decoder, with a residual skip from the
    input projection to the output.

    ``img_size`` is accepted for signature compatibility but unused here.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
        super(RSU7, self).__init__()

        # Kept as attributes for introspection by callers.
        self.in_ch = in_ch
        self.mid_ch = mid_ch
        self.out_ch = out_ch

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)  # input projection / residual branch

        # Encoder: conv + 2x downsample, five times.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bottleneck uses dilation instead of further pooling.
        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder: each stage fuses the upsampled deeper feature with the
        # same-scale encoder feature (channel concat -> 2*mid_ch in).
        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the U-block; output has out_ch channels at the input's H x W."""
        # (fix: removed unused `b, c, h, w = x.shape` unpack)
        hx = x
        hxin = self.rebnconvin(hx)

        # Encoder path.
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)

        hx6 = self.rebnconv6(hx)

        hx7 = self.rebnconv7(hx6)

        # Decoder path with skip connections.
        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
        hx6dup = _upsample_like(hx6d, hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        # Residual connection from the input projection.
        return hx1d + hxin
|
||||
|
||||
|
||||
### RSU-6 ###
|
||||
class RSU6(nn.Module):
    """Residual U-block of depth 6: 4 pooled encoder stages, a dilated
    bottleneck, mirrored decoder, and a residual skip from the input
    projection (see RSU7 for the general pattern)."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU6,self).__init__()

        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)  # input projection / residual branch

        # Encoder: conv + 2x downsample, four times.
        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)

        # Dilated bottleneck (no further pooling).
        self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)

        # Decoder: fuse upsampled deeper feature with same-scale encoder output.
        self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):
        """Run the U-block; output keeps the input's spatial size."""
        hx = x

        hxin = self.rebnconvin(hx)

        # Encoder path.
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)

        hx6 = self.rebnconv6(hx5)

        # Decoder path with skip connections.
        hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
        hx5dup = _upsample_like(hx5d,hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))

        # Residual connection from the input projection.
        return hx1d + hxin
|
||||
|
||||
### RSU-5 ###
|
||||
class RSU5(nn.Module):
    """Residual U-block of depth 5: 3 pooled encoder stages, dilated
    bottleneck, mirrored decoder, residual skip (see RSU7)."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5,self).__init__()

        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)  # input projection / residual branch

        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)

        # Dilated bottleneck.
        self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)

        # Decoder stages (concat -> 2*mid_ch in).
        self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):
        """Run the U-block; output keeps the input's spatial size."""
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)

        hx5 = self.rebnconv5(hx4)

        # Decoder with skip connections.
        hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))

        return hx1d + hxin
|
||||
|
||||
### RSU-4 ###
|
||||
class RSU4(nn.Module):
    """Residual U-block of depth 4: 2 pooled encoder stages, dilated
    bottleneck, mirrored decoder, residual skip (see RSU7)."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4,self).__init__()

        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)  # input projection / residual branch

        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
        self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)

        # Dilated bottleneck.
        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)

        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):
        """Run the U-block; output keeps the input's spatial size."""
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)

        hx4 = self.rebnconv4(hx3)

        # Decoder with skip connections.
        hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))

        return hx1d + hxin
|
||||
|
||||
### RSU-4F ###
|
||||
class RSU4F(nn.Module):
    """Dilated ("F" = flat) residual U-block: no pooling/upsampling — the
    receptive field grows via dilation rates 1/2/4/8 at full resolution."""

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F,self).__init__()

        self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)  # input projection / residual branch

        # Encoder with increasing dilation instead of downsampling.
        self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)

        # Decoder mirrors the dilation schedule.
        self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)

    def forward(self,x):
        """Run the dilated U-block; spatial size is unchanged throughout."""
        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)

        hx4 = self.rebnconv4(hx3)

        # No upsampling needed — all features share the input resolution.
        hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
        hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
        hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))

        return hx1d + hxin
|
||||
|
||||
|
||||
class myrebnconv(nn.Module):
    """Configurable Conv2d -> BatchNorm -> ReLU block (kernel size, stride,
    padding, dilation and groups are all exposed, unlike REBNCONV)."""

    def __init__(self, in_ch=3,
                 out_ch=1,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 dilation=1,
                 groups=1):
        super().__init__()

        self.conv = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )
        self.bn = nn.BatchNorm2d(out_ch)
        self.rl = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv -> bn -> relu to a (B, C, H, W) tensor."""
        y = self.conv(x)
        y = self.bn(y)
        return self.rl(y)
|
||||
|
||||
def preprocess_image(im, model_input_size: list) -> torch.Tensor:
    """Convert an HWC image into the normalized (1, 3, H, W) model input.

    Resizes with bilinear interpolation to ``model_input_size``, scales to
    [0, 1], then shifts by 0.5 (normalize with mean 0.5, std 1.0) to roughly
    [-0.5, 0.5].

    NOTE(review): the ``.type(torch.uint8)`` cast right after interpolation
    truncates fractional pixel values before the division; this matches the
    upstream BriaRMBG preprocessing but loses precision — confirm intentional.
    """
    im_np = np.array(im)
    # HWC -> CHW, float for interpolation.
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2,0,1)
    im_tensor = F.interpolate(torch.unsqueeze(im_tensor,0), size=model_input_size, mode='bilinear').type(torch.uint8)
    image = torch.divide(im_tensor,255.0)
    image = normalize(image,[0.5,0.5,0.5],[1.0,1.0,1.0])
    return image
|
||||
|
||||
def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
    """Resize a (1, C, H, W) prediction to ``im_size`` and min-max scale it
    to a uint8 HxW (or HxWxC for C>1) numpy array.

    NOTE(review): a constant input makes max == min and divides by zero —
    confirm callers never pass a flat saliency map.
    """
    resized = F.interpolate(result, size=im_size, mode='bilinear').squeeze(0)
    lo = resized.min()
    hi = resized.max()
    scaled = (resized - lo) / (hi - lo)
    arr = (scaled * 255).permute(1, 2, 0).detach().cpu().numpy().astype(np.uint8)
    return np.squeeze(arr)
|
||||
|
||||
class BriaRMBG(nn.Module):
    """BRIA RMBG background-removal network (U^2-Net style).

    Six encoder stages (RSU7..RSU4F) with max-pool downsampling, five decoder
    stages with skip connections, and six sigmoid side outputs (d1 finest ..
    d6 coarsest), each upsampled back to the input resolution.
    """

    def __init__(self, config:dict={"in_ch":3,"out_ch":1}):
        super(BriaRMBG,self).__init__()
        in_ch = config["in_ch"]
        out_ch = config["out_ch"]
        # Stem: stride-2 conv halves the resolution before stage 1.
        self.conv_in = nn.Conv2d(in_ch,64,3,stride=2,padding=1)
        self.pool_in = nn.MaxPool2d(2,stride=2,ceil_mode=True)  # defined but unused in forward (see comment there)

        # Encoder: progressively deeper channels, shallower RSU blocks.
        self.stage1 = RSU7(64,32,64)
        self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage2 = RSU6(64,32,128)
        self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage3 = RSU5(128,64,256)
        self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage4 = RSU4(256,128,512)
        self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage5 = RSU4F(512,256,512)
        self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)

        self.stage6 = RSU4F(512,256,512)

        # decoder (inputs are encoder+decoder concats, hence doubled channels)
        self.stage5d = RSU4F(1024,256,512)
        self.stage4d = RSU4(1024,128,256)
        self.stage3d = RSU5(512,64,128)
        self.stage2d = RSU6(256,32,64)
        self.stage1d = RSU7(128,16,64)

        # Side-output heads, one per decoder stage (plus the bottleneck).
        self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
        self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
        self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
        self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
        self.side6 = nn.Conv2d(512,out_ch,3,padding=1)

    def forward(self,x):
        """Return ([sigmoid side outputs d1..d6], [decoder feature maps]).

        All side outputs are upsampled to the spatial size of ``x``; d1 is
        the finest mask and is what callers typically use.
        """
        hx = x

        hxin = self.conv_in(hx)
        # pool_in is intentionally skipped; the stem conv already downsamples.

        #stage 1
        hx1 = self.stage1(hxin)
        hx = self.pool12(hx1)

        #stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        #stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        #stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        #stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        #stage 6 (bottleneck)
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6,hx5)

        #-------------------- decoder --------------------
        # Each stage fuses the upsampled deeper output with the matching
        # encoder feature map (channel concat).
        hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
        hx5dup = _upsample_like(hx5d,hx4)

        hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
        hx4dup = _upsample_like(hx4d,hx3)

        hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
        hx3dup = _upsample_like(hx3d,hx2)

        hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
        hx2dup = _upsample_like(hx2d,hx1)

        hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))

        #side output: project each scale to out_ch and upsample to input size
        d1 = self.side1(hx1d)
        d1 = _upsample_like(d1,x)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2,x)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3,x)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4,x)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5,x)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6,x)

        return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)],[hx1d,hx2d,hx3d,hx4d,hx5d,hx6]
|
||||
822
custom_nodes/ComfyUI-Easy-Use/py/modules/brushnet/__init__.py
Normal file
822
custom_nodes/ComfyUI-Easy-Use/py/modules/brushnet/__init__.py
Normal file
@@ -0,0 +1,822 @@
|
||||
#credit to nullquant for this module
|
||||
#from https://github.com/nullquant/ComfyUI-BrushNet
|
||||
|
||||
import os
|
||||
import types
|
||||
|
||||
import torch
|
||||
try:
|
||||
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
|
||||
except:
|
||||
init_empty_weights, load_checkpoint_and_dispatch = None, None
|
||||
|
||||
import comfy
|
||||
|
||||
try:
|
||||
from .model import BrushNetModel, PowerPaintModel
|
||||
from .model_patch import add_model_patch_option, patch_model_function_wrapper
|
||||
from .powerpaint_utils import TokenizerWrapper, add_tokens
|
||||
except:
|
||||
BrushNetModel, PowerPaintModel = None, None
|
||||
add_model_patch_option, patch_model_function_wrapper = None, None
|
||||
TokenizerWrapper, add_tokens = None, None
|
||||
|
||||
cwd_path = os.path.dirname(os.path.realpath(__file__))
|
||||
brushnet_config_file = os.path.join(cwd_path, 'config', 'brushnet.json')
|
||||
brushnet_xl_config_file = os.path.join(cwd_path, 'config', 'brushnet_xl.json')
|
||||
powerpaint_config_file = os.path.join(cwd_path, 'config', 'powerpaint.json')
|
||||
|
||||
sd15_scaling_factor = 0.18215
|
||||
sdxl_scaling_factor = 0.13025
|
||||
|
||||
ModelsToUnload = [comfy.sd1_clip.SD1ClipModel, comfy.ldm.models.autoencoder.AutoencoderKL]
|
||||
|
||||
class BrushNet:
|
||||
|
||||
# Check models compatibility
|
||||
def check_compatibilty(self, model, brushnet):
    """Verify the base model family matches the BrushNet checkpoint.

    Returns ``(is_SDXL, is_PP)`` where is_PP marks a PowerPaint checkpoint.

    Raises:
        Exception: on an SD15/SDXL mismatch or an unsupported base model.
    """
    is_SDXL = False
    is_PP = False
    if isinstance(model.model.model_config, comfy.supported_models.SD15):
        print('Base model type: SD1.5')
        is_SDXL = False
        if brushnet["SDXL"]:
            raise Exception("Base model is SD15, but BrushNet is SDXL type")
        if brushnet["PP"]:
            is_PP = True
    elif isinstance(model.model.model_config, comfy.supported_models.SDXL):
        print('Base model type: SDXL')
        is_SDXL = True
        if not brushnet["SDXL"]:
            raise Exception("Base model is SDXL, but BrushNet is SD15 type")
    else:
        # Anything else (SD2, cascade, ...) is unsupported by BrushNet.
        print('Base model type: ', type(model.model.model_config))
        raise Exception("Unsupported model type: " + str(type(model.model.model_config)))

    return (is_SDXL, is_PP)
|
||||
|
||||
def check_image_mask(self, image, mask, name):
    """Normalize shapes so image is [B, H, W, C] and mask is [B, H, W].

    Also reconciles mismatched batch sizes: a single mask is tiled across
    the image batch, too few masks are padded with empty (zero) masks, and
    surplus masks are dropped. ``name`` prefixes the diagnostic prints.
    """
    if len(image.shape) < 4:
        # image tensor shape should be [B, H, W, C]; restore the missing batch dim
        image = image[None, :, :, :]

    if len(mask.shape) > 3:
        # got [B, H, W, C] (an image?) — keep the red channel as the mask
        mask = (mask[:, :, :, 0])[:, :, :]
    elif len(mask.shape) < 3:
        # restore the missing batch dim
        mask = mask[None, :, :]

    n_images = image.shape[0]
    n_masks = mask.shape[0]

    if n_images > n_masks:
        print(name, "gets batch of images (%d) but only %d masks" % (image.shape[0], mask.shape[0]))
        if n_masks == 1:
            print(name, "will copy the mask to fill batch")
            mask = torch.cat([mask] * n_images, dim=0)
        else:
            print(name, "will add empty masks to fill batch")
            empty_mask = torch.zeros([n_images - n_masks, mask.shape[1], mask.shape[2]])
            mask = torch.cat([mask, empty_mask], dim=0)
    elif n_images < n_masks:
        print(name, "gets batch of images (%d) but too many (%d) masks" % (image.shape[0], mask.shape[0]))
        mask = mask[:n_images, :, :]

    return (image, mask)
|
||||
|
||||
# Prepare image and mask
def prepare_image(self, image, mask):
    """Normalize inputs and produce the masked image used for conditioning."""
    image, mask = self.check_image_mask(image, mask, 'BrushNet')

    print("BrushNet image.shape =", image.shape, "mask.shape =", mask.shape)

    if (mask.shape[1], mask.shape[2]) != (image.shape[1], image.shape[2]):
        raise Exception("Image and mask should be the same size")

    # As a suggestion of inferno46n2 (https://github.com/nullquant/ComfyUI-BrushNet/issues/64)
    mask = mask.round()

    # Zero out the inpaint region; broadcast the [B, H, W] mask over channels.
    masked_image = image * (1.0 - mask[:, :, :, None])

    return (masked_image, mask)
|
||||
|
||||
# Get origin of the mask
def cut_with_mask(self, mask, width, height):
    """Choose a width x height crop origin centered on the mask's bounding box.

    Returns (x0, y0, w, h) as ints; w/h shrink to the mask tensor's own size
    when it is smaller than the requested crop. Raises if the active mask
    region is larger than the requested dimensions.
    """
    rows, cols = (mask == 1).nonzero(as_tuple=True)

    h0, w0 = mask.shape

    if rows.numel() == 0:
        # no active pixels: center the crop on the whole mask
        center_x = w0 / 2.0
        center_y = h0 / 2.0
    else:
        left = cols.min().item()
        right = cols.max().item()
        top = rows.min().item()
        bottom = rows.max().item()

        if right - left > width or bottom - top > height:
            raise Exception("Mask is bigger than provided dimensions")

        center_x = (left + right) / 2.0
        center_y = (top + bottom) / 2.0

    half_w = width / 2.0
    half_h = height / 2.0

    if w0 <= width:
        x0, w = 0, w0
    else:
        x0 = max(0, center_x - half_w)
        w = width
        if x0 + width > w0:
            # clamp so the crop stays inside the mask
            x0 = w0 - width

    if h0 <= height:
        y0, h = 0, h0
    else:
        y0 = max(0, center_y - half_h)
        h = height
        if y0 + height > h0:
            y0 = h0 - height

    return (int(x0), int(y0), int(w), int(h))
|
||||
|
||||
# Prepare conditioning_latents
@torch.inference_mode()
def get_image_latents(self, masked_image, mask, vae, scaling_factor):
    """Encode the masked image with the VAE and downsample the inverted mask
    to the latent resolution.

    Returns [image_latents, interpolated_mask] on the latents' device.
    """
    # only the first 3 channels (RGB) are encoded
    latents = vae.encode(masked_image.to(vae.device)[:, :, :, :3]) * scaling_factor

    # invert: 1 where the image is kept, 0 inside the inpaint region
    inverted = 1. - mask[:, None, :, :]
    latent_size = (latents.shape[-2], latents.shape[-1])
    resized = torch.nn.functional.interpolate(inverted, size=latent_size)
    resized = resized.to(latents.device)

    conditioning_latents = [latents, resized]

    print('BrushNet CL: image_latents shape =', latents.shape, 'interpolated_mask shape =',
          resized.shape)

    return conditioning_latents
|
||||
|
||||
def brushnet_blocks(self, sd):
    """Count down/mid/up BrushNet block tensors in a state dict.

    Returns (down_count, mid_count, up_count, total_keys); the tuple is used
    to fingerprint which checkpoint flavor was loaded.
    """
    markers = ('brushnet_down_block', 'brushnet_mid_block', 'brushnet_up_block')
    counts = dict.fromkeys(markers, 0)
    for key in sd:
        for marker in markers:
            if marker in key:
                counts[marker] += 1
    return (counts['brushnet_down_block'],
            counts['brushnet_mid_block'],
            counts['brushnet_up_block'],
            len(sd))
|
||||
|
||||
def get_model_type(self, brushnet_file):
    """Inspect a checkpoint and classify it as SD1.5/SDXL BrushNet or PowerPaint.

    Returns (is_SDXL, is_PP); raises for unrecognized block layouts.
    """
    sd = comfy.utils.load_torch_file(brushnet_file)
    down, mid, up, keys = self.brushnet_blocks(sd)
    del sd  # free the state dict before any model instantiation
    if (down, mid, up) == (24, 2, 30):
        is_SDXL = False
        # 322 keys is the plain BrushNet SD1.5 layout; any other key count
        # with this block layout is a PowerPaint checkpoint
        is_PP = keys != 322
        if is_PP:
            print('PowerPaint model type: SD1.5')
        else:
            print('BrushNet model type: SD1.5')
    elif (down, mid, up) == (18, 2, 22):
        print('BrushNet model type: Loading SDXL')
        is_SDXL = True
        is_PP = False
    else:
        raise Exception("Unknown BrushNet model")
    return is_SDXL, is_PP
|
||||
|
||||
def load_brushnet_model(self, brushnet_file, dtype='float16'):
    """Instantiate the matching model class for a checkpoint and load weights.

    Args:
        brushnet_file: path to the BrushNet/PowerPaint checkpoint.
        dtype: 'float16', 'bfloat16' or 'float32'; any other string falls
            back to float64 (preserved original behavior).

    Returns:
        One-tuple holding the model dict consumed by the apply nodes.
    """
    is_SDXL, is_PP = self.get_model_type(brushnet_file)

    # build the empty (meta-device) skeleton first; weights are dispatched below
    with init_empty_weights():
        if is_SDXL:
            config = BrushNetModel.load_config(brushnet_xl_config_file)
            brushnet_model = BrushNetModel.from_config(config)
        elif is_PP:
            config = PowerPaintModel.load_config(powerpaint_config_file)
            brushnet_model = PowerPaintModel.from_config(config)
        else:
            config = BrushNetModel.load_config(brushnet_config_file)
            brushnet_model = BrushNetModel.from_config(config)

    if is_PP:
        print("PowerPaint model file:", brushnet_file)
    else:
        print("BrushNet model file:", brushnet_file)

    # unrecognized dtype strings intentionally fall through to float64
    torch_dtype = {
        'float16': torch.float16,
        'bfloat16': torch.bfloat16,
        'float32': torch.float32,
    }.get(dtype, torch.float64)

    brushnet_model = load_checkpoint_and_dispatch(
        brushnet_model,
        brushnet_file,
        device_map="sequential",
        max_memory=None,
        offload_folder=None,
        offload_state_dict=False,
        dtype=torch_dtype,
        force_hooks=False,
    )

    if is_PP:
        print("PowerPaint model is loaded")
    elif is_SDXL:
        print("BrushNet SDXL model is loaded")
    else:
        print("BrushNet SD1.5 model is loaded")

    return ({"brushnet": brushnet_model, "SDXL": is_SDXL, "PP": is_PP, "dtype": torch_dtype},)
|
||||
|
||||
def brushnet_model_update(self, model, vae, image, mask, brushnet, positive, negative, scale, start_at, end_at):
    """Clone the base model and patch BrushNet into its UNet forward pass.

    Encodes the masked image into conditioning latents, pads prompt
    embeddings to a common token length, prepares SDXL add-embeds when
    applicable, and installs the BrushNet patch via add_brushnet_patch.

    Returns (patched_model, positive, negative, {"samples": zero_latent}).
    Raises if a PowerPaint checkpoint was loaded into this node.
    """

    is_SDXL, is_PP = self.check_compatibilty(model, brushnet)

    if is_PP:
        raise Exception("PowerPaint model was loaded, please use PowerPaint node")

    # Make a copy of the model so that we're not patching it everywhere in the workflow.
    model = model.clone()

    # prepare image and mask
    # no batches for original image and mask
    masked_image, mask = self.prepare_image(image, mask)

    batch = masked_image.shape[0]
    width = masked_image.shape[2]
    height = masked_image.shape[1]

    # prefer the model's own latent scale factor; fall back to the
    # hard-coded SD15/SDXL constants defined elsewhere in this module
    if hasattr(model.model.model_config, 'latent_format') and hasattr(model.model.model_config.latent_format,
                                                                      'scale_factor'):
        scaling_factor = model.model.model_config.latent_format.scale_factor
    elif is_SDXL:
        scaling_factor = sdxl_scaling_factor
    else:
        scaling_factor = sd15_scaling_factor

    torch_dtype = brushnet['dtype']

    # prepare conditioning latents
    conditioning_latents = self.get_image_latents(masked_image, mask, vae, scaling_factor)
    conditioning_latents[0] = conditioning_latents[0].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
    conditioning_latents[1] = conditioning_latents[1].to(dtype=torch_dtype).to(brushnet['brushnet'].device)

    # unload vae
    del vae
    # for loaded_model in comfy.model_management.current_loaded_models:
    #     if type(loaded_model.model.model) in ModelsToUnload:
    #         comfy.model_management.current_loaded_models.remove(loaded_model)
    #         loaded_model.model_unload()
    #         del loaded_model

    # prepare embeddings
    prompt_embeds = positive[0][0].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
    negative_prompt_embeds = negative[0][0].to(dtype=torch_dtype).to(brushnet['brushnet'].device)

    # pad the shorter embedding to the longer one in 77-token chunks so
    # both branches of CFG see the same sequence length
    max_tokens = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1])
    if prompt_embeds.shape[1] < max_tokens:
        multiplier = max_tokens // 77 - prompt_embeds.shape[1] // 77
        prompt_embeds = torch.concat([prompt_embeds] + [prompt_embeds[:, -77:, :]] * multiplier, dim=1)
        print('BrushNet: negative prompt more than 75 tokens:', negative_prompt_embeds.shape,
              'multiplying prompt_embeds')
    if negative_prompt_embeds.shape[1] < max_tokens:
        multiplier = max_tokens // 77 - negative_prompt_embeds.shape[1] // 77
        negative_prompt_embeds = torch.concat(
            [negative_prompt_embeds] + [negative_prompt_embeds[:, -77:, :]] * multiplier, dim=1)
        print('BrushNet: positive prompt more than 75 tokens:', prompt_embeds.shape,
              'multiplying negative_prompt_embeds')

    if len(positive[0]) > 1 and 'pooled_output' in positive[0][1] and positive[0][1]['pooled_output'] is not None:
        pooled_prompt_embeds = positive[0][1]['pooled_output'].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
    else:
        print('BrushNet: positive conditioning has not pooled_output')
        if is_SDXL:
            print('BrushNet will not produce correct results')
        # placeholder; 1280 matches the SDXL pooled-embedding width
        pooled_prompt_embeds = torch.empty([2, 1280], device=brushnet['brushnet'].device).to(dtype=torch_dtype)

    if len(negative[0]) > 1 and 'pooled_output' in negative[0][1] and negative[0][1]['pooled_output'] is not None:
        negative_pooled_prompt_embeds = negative[0][1]['pooled_output'].to(dtype=torch_dtype).to(
            brushnet['brushnet'].device)
    else:
        print('BrushNet: negative conditioning has not pooled_output')
        if is_SDXL:
            print('BrushNet will not produce correct results')
        negative_pooled_prompt_embeds = torch.empty([1, pooled_prompt_embeds.shape[1]],
                                                    device=brushnet['brushnet'].device).to(dtype=torch_dtype)

    # SDXL micro-conditioning: (orig_h, orig_w, crop_top, crop_left, target_h, target_w)
    time_ids = torch.FloatTensor([[height, width, 0., 0., height, width]]).to(dtype=torch_dtype).to(
        brushnet['brushnet'].device)

    if not is_SDXL:
        pooled_prompt_embeds = None
        negative_pooled_prompt_embeds = None
        time_ids = None

    # apply patch to model
    brushnet_conditioning_scale = scale
    control_guidance_start = start_at
    control_guidance_end = end_at

    add_brushnet_patch(model,
                       brushnet['brushnet'],
                       torch_dtype,
                       conditioning_latents,
                       (brushnet_conditioning_scale, control_guidance_start, control_guidance_end),
                       prompt_embeds, negative_prompt_embeds,
                       pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids,
                       False)

    # empty latent sized to the conditioning latents, for downstream samplers
    latent = torch.zeros([batch, 4, conditioning_latents[0].shape[2], conditioning_latents[0].shape[3]],
                        device=brushnet['brushnet'].device)

    return (model, positive, negative, {"samples": latent},)
|
||||
|
||||
# PowerPaint
def load_powerpaint_clip(self, base_clip_file, pp_clip_file):
    """Build a PowerPaint CLIP: base CLIP plus learned placeholder tokens
    (P_ctxt / P_shape / P_obj) and fine-tuned text-encoder weights."""
    clip = comfy.sd.load_clip(ckpt_paths=[base_clip_file])

    print('PowerPaint base CLIP file: ', base_clip_file)

    tokenizer = TokenizerWrapper(clip.tokenizer.clip_l.tokenizer)
    text_encoder = clip.patcher.model.clip_l.transformer

    # register the three PowerPaint task tokens, 10 vectors each
    add_tokens(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        placeholder_tokens=["P_ctxt", "P_shape", "P_obj"],
        initialize_tokens=["a", "a", "a"],
        num_vectors_per_token=10,
    )

    # overlay the fine-tuned PowerPaint weights (non-strict: key mismatch ok)
    text_encoder.load_state_dict(comfy.utils.load_torch_file(pp_clip_file), strict=False)

    print('PowerPaint CLIP file: ', pp_clip_file)

    clip.tokenizer.clip_l.tokenizer = tokenizer
    clip.patcher.model.clip_l.transformer = text_encoder

    return (clip,)
|
||||
|
||||
def powerpaint_model_update(self, model, vae, image, mask, powerpaint, clip, positive, negative, fitting, function, scale, start_at, end_at, save_memory):
    """Clone the base model and patch PowerPaint into its UNet forward pass.

    Selects task-token prompts per the chosen `function`, blends the A/B
    embeddings by `fitting`, and installs the patch via add_brushnet_patch.

    Returns (patched_model, positive, negative, {"samples": zero_latent}).
    Raises if a plain BrushNet checkpoint was loaded into this node.
    """
    is_SDXL, is_PP = self.check_compatibilty(model, powerpaint)
    if not is_PP:
        raise Exception("BrushNet model was loaded, please use BrushNet node")

    # Make a copy of the model so that we're not patching it everywhere in the workflow.
    model = model.clone()

    # prepare image and mask
    # no batches for original image and mask
    masked_image, mask = self.prepare_image(image, mask)

    batch = masked_image.shape[0]
    # width = masked_image.shape[2]
    # height = masked_image.shape[1]

    if hasattr(model.model.model_config, 'latent_format') and hasattr(model.model.model_config.latent_format,
                                                                      'scale_factor'):
        scaling_factor = model.model.model_config.latent_format.scale_factor
    else:
        # PowerPaint is SD1.5-only, so only the SD15 fallback applies here
        scaling_factor = sd15_scaling_factor

    torch_dtype = powerpaint['dtype']

    # prepare conditioning latents
    conditioning_latents = self.get_image_latents(masked_image, mask, vae, scaling_factor)
    conditioning_latents[0] = conditioning_latents[0].to(dtype=torch_dtype).to(powerpaint['brushnet'].device)
    conditioning_latents[1] = conditioning_latents[1].to(dtype=torch_dtype).to(powerpaint['brushnet'].device)

    # prepare embeddings

    # map the task to PowerPaint's learned placeholder tokens; A/B pairs
    # are blended with `fitting` below
    if function == "object removal":
        promptA = "P_ctxt"
        promptB = "P_ctxt"
        negative_promptA = "P_obj"
        negative_promptB = "P_obj"
        print('You should add to positive prompt: "empty scene blur"')
        # positive = positive + " empty scene blur"
    elif function == "context aware":
        promptA = "P_ctxt"
        promptB = "P_ctxt"
        negative_promptA = ""
        negative_promptB = ""
        # positive = positive + " empty scene"
        print('You should add to positive prompt: "empty scene"')
    elif function == "shape guided":
        promptA = "P_shape"
        promptB = "P_ctxt"
        negative_promptA = "P_shape"
        negative_promptB = "P_ctxt"
    elif function == "image outpainting":
        promptA = "P_ctxt"
        promptB = "P_ctxt"
        negative_promptA = "P_obj"
        negative_promptB = "P_obj"
        # positive = positive + " empty scene"
        print('You should add to positive prompt: "empty scene"')
    else:
        # default: "text guided" inpainting
        promptA = "P_obj"
        promptB = "P_obj"
        negative_promptA = "P_obj"
        negative_promptB = "P_obj"

    tokens = clip.tokenize(promptA)
    prompt_embedsA = clip.encode_from_tokens(tokens, return_pooled=False)

    tokens = clip.tokenize(negative_promptA)
    negative_prompt_embedsA = clip.encode_from_tokens(tokens, return_pooled=False)

    tokens = clip.tokenize(promptB)
    prompt_embedsB = clip.encode_from_tokens(tokens, return_pooled=False)

    tokens = clip.tokenize(negative_promptB)
    negative_prompt_embedsB = clip.encode_from_tokens(tokens, return_pooled=False)

    # linear blend between the A and B task embeddings
    prompt_embeds_pp = (prompt_embedsA * fitting + (1.0 - fitting) * prompt_embedsB).to(dtype=torch_dtype).to(
        powerpaint['brushnet'].device)
    negative_prompt_embeds_pp = (negative_prompt_embedsA * fitting + (1.0 - fitting) * negative_prompt_embedsB).to(
        dtype=torch_dtype).to(powerpaint['brushnet'].device)

    # unload vae and CLIPs
    del vae
    del clip
    # for loaded_model in comfy.model_management.current_loaded_models:
    #     if type(loaded_model.model.model) in ModelsToUnload:
    #         comfy.model_management.current_loaded_models.remove(loaded_model)
    #         loaded_model.model_unload()
    #         del loaded_model

    # apply patch to model

    brushnet_conditioning_scale = scale
    control_guidance_start = start_at
    control_guidance_end = end_at

    if save_memory != 'none':
        powerpaint['brushnet'].set_attention_slice(save_memory)

    # NOTE(review): embeds are passed (negative, positive) — the opposite
    # order of brushnet_model_update; confirm this is intentional for
    # PowerPaint before changing.
    add_brushnet_patch(model,
                       powerpaint['brushnet'],
                       torch_dtype,
                       conditioning_latents,
                       (brushnet_conditioning_scale, control_guidance_start, control_guidance_end),
                       negative_prompt_embeds_pp, prompt_embeds_pp,
                       None, None, None,
                       False)

    latent = torch.zeros([batch, 4, conditioning_latents[0].shape[2], conditioning_latents[0].shape[3]],
                        device=powerpaint['brushnet'].device)

    return (model, positive, negative, {"samples": latent},)
|
||||
@torch.inference_mode()
def brushnet_inference(x, timesteps, transformer_options, debug):
    """Run the BrushNet/PowerPaint side network for one UNet call.

    Pulls patch state from transformer_options['model_patch']['brushnet'],
    assembles per-latent conditioning (handling CFG cond/uncond passes and
    AnimateDiff sub-indexes), and returns the model's
    (input_samples, mid_sample, output_samples); ([], 0, []) when the
    patch state is missing or malformed.
    """
    if 'model_patch' not in transformer_options:
        print('BrushNet inference: there is no model_patch key in transformer_options')
        return ([], 0, [])
    mp = transformer_options['model_patch']
    if 'brushnet' not in mp:
        # NOTE(review): "mdel_patch" is a typo in this log message (runtime
        # string, left untouched here)
        print('BrushNet inference: there is no brushnet key in mdel_patch')
        return ([], 0, [])
    bo = mp['brushnet']
    if 'model' not in bo:
        print('BrushNet inference: there is no model key in brushnet')
        return ([], 0, [])
    brushnet = bo['model']
    if not (isinstance(brushnet, BrushNetModel) or isinstance(brushnet, PowerPaintModel)):
        print('BrushNet model is not a BrushNetModel class')
        return ([], 0, [])

    torch_dtype = bo['dtype']
    # cl_list: [image_latents (B,4,h,w), interpolated_mask (B,1,h,w)]
    cl_list = bo['latents']
    brushnet_conditioning_scale, control_guidance_start, control_guidance_end = bo['controls']
    pe = bo['prompt_embeds']
    npe = bo['negative_prompt_embeds']
    ppe, nppe, time_ids = bo['add_embeds']

    #do_classifier_free_guidance = mp['free_guidance']
    do_classifier_free_guidance = len(transformer_options['cond_or_uncond']) > 1

    x = x.detach().clone()
    x = x.to(torch_dtype).to(brushnet.device)

    timesteps = timesteps.detach().clone()
    timesteps = timesteps.to(torch_dtype).to(brushnet.device)

    total_steps = mp['total_steps']
    step = mp['step']

    added_cond_kwargs = {}

    if do_classifier_free_guidance and step == 0:
        print('BrushNet inference: do_classifier_free_guidance is True')

    sub_idx = None
    if 'ad_params' in transformer_options and 'sub_idxs' in transformer_options['ad_params']:
        sub_idx = transformer_options['ad_params']['sub_idxs']

    # we have batch input images
    batch = cl_list[0].shape[0]
    # we have incoming latents
    latents_incoming = x.shape[0]
    # and we already got some
    latents_got = bo['latent_id']
    if step == 0 or batch > 1:
        print('BrushNet inference, step = %d: image batch = %d, got %d latents, starting from %d' \
              % (step, batch, latents_incoming, latents_got))

    image_latents = []
    masks = []
    prompt_embeds = []
    negative_prompt_embeds = []
    pooled_prompt_embeds = []
    negative_pooled_prompt_embeds = []
    if sub_idx:
        # AnimateDiff indexes detected
        if step == 0:
            print('BrushNet inference: AnimateDiff indexes detected and applied')

        batch = len(sub_idx)

        if do_classifier_free_guidance:
            # first half: cond; second half (same latents): uncond
            for i in sub_idx:
                image_latents.append(cl_list[0][i][None,:,:,:])
                masks.append(cl_list[1][i][None,:,:,:])
                prompt_embeds.append(pe)
                negative_prompt_embeds.append(npe)
                pooled_prompt_embeds.append(ppe)
                negative_pooled_prompt_embeds.append(nppe)
            for i in sub_idx:
                image_latents.append(cl_list[0][i][None,:,:,:])
                masks.append(cl_list[1][i][None,:,:,:])
        else:
            for i in sub_idx:
                image_latents.append(cl_list[0][i][None,:,:,:])
                masks.append(cl_list[1][i][None,:,:,:])
                prompt_embeds.append(pe)
                pooled_prompt_embeds.append(ppe)
    else:
        # do_classifier_free_guidance = 2 passes, 1st pass is cond, 2nd is uncond
        # `latent_id` in bo persists progress across calls when the sampler
        # splits a large image batch over several UNet invocations
        continue_batch = True
        for i in range(latents_incoming):
            number = latents_got + i
            if number < batch:
                # 1st pass, cond
                image_latents.append(cl_list[0][number][None,:,:,:])
                masks.append(cl_list[1][number][None,:,:,:])
                prompt_embeds.append(pe)
                pooled_prompt_embeds.append(ppe)
            elif do_classifier_free_guidance and number < batch * 2:
                # 2nd pass, uncond
                image_latents.append(cl_list[0][number-batch][None,:,:,:])
                masks.append(cl_list[1][number-batch][None,:,:,:])
                negative_prompt_embeds.append(npe)
                negative_pooled_prompt_embeds.append(nppe)
            else:
                # latent batch
                image_latents.append(cl_list[0][0][None,:,:,:])
                masks.append(cl_list[1][0][None,:,:,:])
                prompt_embeds.append(pe)
                pooled_prompt_embeds.append(ppe)
                latents_got = -i
                continue_batch = False

        if continue_batch:
            # we don't have full batch yet
            if do_classifier_free_guidance:
                if number < batch * 2 - 1:
                    bo['latent_id'] = number + 1
                else:
                    bo['latent_id'] = 0
            else:
                if number < batch - 1:
                    bo['latent_id'] = number + 1
                else:
                    bo['latent_id'] = 0
        else:
            bo['latent_id'] = 0

    # concat latents+mask per sample along channels (4+1=5), then stack batch
    cl = []
    for il, m in zip(image_latents, masks):
        cl.append(torch.concat([il, m], dim=1))
    cl2apply = torch.concat(cl, dim=0)

    conditioning_latents = cl2apply.to(torch_dtype).to(brushnet.device)

    prompt_embeds.extend(negative_prompt_embeds)
    prompt_embeds = torch.concat(prompt_embeds, dim=0).to(torch_dtype).to(brushnet.device)

    # SDXL path: pooled embeds present means add-embeds must be supplied
    if ppe is not None:
        added_cond_kwargs = {}
        added_cond_kwargs['time_ids'] = torch.concat([time_ids] * latents_incoming, dim = 0).to(torch_dtype).to(brushnet.device)

        pooled_prompt_embeds.extend(negative_pooled_prompt_embeds)
        pooled_prompt_embeds = torch.concat(pooled_prompt_embeds, dim=0).to(torch_dtype).to(brushnet.device)
        added_cond_kwargs['text_embeds'] = pooled_prompt_embeds
    else:
        added_cond_kwargs = None

    # resize conditioning to the incoming latent size (e.g. RAUNet/HiDiffusion)
    if x.shape[2] != conditioning_latents.shape[2] or x.shape[3] != conditioning_latents.shape[3]:
        if step == 0:
            print('BrushNet inference: image', conditioning_latents.shape, 'and latent', x.shape, 'have different size, resizing image')
        conditioning_latents = torch.nn.functional.interpolate(
            conditioning_latents, size=(
                x.shape[2],
                x.shape[3],
            ), mode='bicubic',
        ).to(torch_dtype).to(brushnet.device)

    if step == 0:
        print('BrushNet inference: sample', x.shape, ', CL', conditioning_latents.shape, 'dtype', torch_dtype)

    if debug: print('BrushNet: step =', step)

    # outside the guidance window the side network contributes nothing
    if step < control_guidance_start or step > control_guidance_end:
        cond_scale = 0.0
    else:
        cond_scale = brushnet_conditioning_scale

    return brushnet(x,
                    encoder_hidden_states=prompt_embeds,
                    brushnet_cond=conditioning_latents,
                    timestep = timesteps,
                    conditioning_scale=cond_scale,
                    guess_mode=False,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                    debug=debug,
                    )
|
||||
|
||||
def add_brushnet_patch(model, brushnet, torch_dtype, conditioning_latents,
                       controls,
                       prompt_embeds, negative_prompt_embeds,
                       pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids,
                       debug):
    """Install BrushNet into a cloned ComfyUI model.

    Registers a forward wrapper that runs brushnet_inference each UNet call
    and distributes its residual samples onto specific UNet layers, stores
    the patch state under model_patch['brushnet'], and monkey-patches every
    input/middle/output layer's forward to add its pending sample once.
    """

    is_SDXL = isinstance(model.model.model_config, comfy.supported_models.SDXL)

    if model.model.model_config.custom_operations is None:
        fp8 = model.model.model_config.optimizations.get("fp8", model.model.model_config.scaled_fp8 is not None)
        operations = comfy.ops.pick_operations(model.model.model_config.unet_config.get("dtype", None), model.model.manual_cast_dtype,
                                               fp8_optimizations=fp8, scaled_fp8=model.model.model_config.scaled_fp8)
    else:
        # such as gguf
        operations = model.model.model_config.custom_operations

    # (block index, layer type) pairs locating where each BrushNet residual
    # is injected; duplicated indexes (e.g. [2, ...] twice) target two
    # different layer types inside the same block
    if is_SDXL:
        input_blocks = [[0, operations.Conv2d],
                        [1, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
                        [2, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
                        [3, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
                        [4, comfy.ldm.modules.attention.SpatialTransformer],
                        [5, comfy.ldm.modules.attention.SpatialTransformer],
                        [6, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
                        [7, comfy.ldm.modules.attention.SpatialTransformer],
                        [8, comfy.ldm.modules.attention.SpatialTransformer]]
        middle_block = [0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]
        output_blocks = [[0, comfy.ldm.modules.attention.SpatialTransformer],
                         [1, comfy.ldm.modules.attention.SpatialTransformer],
                         [2, comfy.ldm.modules.attention.SpatialTransformer],
                         [2, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
                         [3, comfy.ldm.modules.attention.SpatialTransformer],
                         [4, comfy.ldm.modules.attention.SpatialTransformer],
                         [5, comfy.ldm.modules.attention.SpatialTransformer],
                         [5, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
                         [6, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
                         [7, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
                         [8, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]]
    else:
        input_blocks = [[0, operations.Conv2d],
                        [1, comfy.ldm.modules.attention.SpatialTransformer],
                        [2, comfy.ldm.modules.attention.SpatialTransformer],
                        [3, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
                        [4, comfy.ldm.modules.attention.SpatialTransformer],
                        [5, comfy.ldm.modules.attention.SpatialTransformer],
                        [6, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
                        [7, comfy.ldm.modules.attention.SpatialTransformer],
                        [8, comfy.ldm.modules.attention.SpatialTransformer],
                        [9, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
                        [10, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
                        [11, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]]
        middle_block = [0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]
        output_blocks = [[0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
                         [1, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
                         [2, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
                         [2, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
                         [3, comfy.ldm.modules.attention.SpatialTransformer],
                         [4, comfy.ldm.modules.attention.SpatialTransformer],
                         [5, comfy.ldm.modules.attention.SpatialTransformer],
                         [5, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
                         [6, comfy.ldm.modules.attention.SpatialTransformer],
                         [7, comfy.ldm.modules.attention.SpatialTransformer],
                         [8, comfy.ldm.modules.attention.SpatialTransformer],
                         [8, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
                         [9, comfy.ldm.modules.attention.SpatialTransformer],
                         [10, comfy.ldm.modules.attention.SpatialTransformer],
                         [11, comfy.ldm.modules.attention.SpatialTransformer]]

    def last_layer_index(block, tp):
        # Return (index of last layer of type tp, list of layer types);
        # (-1, None) when tp is absent.
        # NOTE(review): on the miss path `layer_list.reverse()` returns None,
        # so the caller's diagnostic print shows None instead of the list —
        # looks unintentional; confirm before fixing.
        layer_list = []
        for layer in block:
            layer_list.append(type(layer))
        layer_list.reverse()
        if tp not in layer_list:
            return -1, layer_list.reverse()
        return len(layer_list) - 1 - layer_list.index(tp), layer_list

    def brushnet_forward(model, x, timesteps, transformer_options, control):
        # Run BrushNet (when the patch state is present) and stage its
        # residual samples on the target layers before the UNet runs.
        if 'brushnet' not in transformer_options['model_patch']:
            input_samples = []
            mid_sample = 0
            output_samples = []
        else:
            # brushnet inference
            input_samples, mid_sample, output_samples = brushnet_inference(x, timesteps, transformer_options, debug)

        # give additional samples to blocks
        for i, tp in input_blocks:
            idx, layer_list = last_layer_index(model.input_blocks[i], tp)
            if idx < 0:
                print("BrushNet can't find", tp, "layer in", i, "input block:", layer_list)
                continue
            model.input_blocks[i][idx].add_sample_after = input_samples.pop(0) if input_samples else 0

        idx, layer_list = last_layer_index(model.middle_block, middle_block[1])
        if idx < 0:
            print("BrushNet can't find", middle_block[1], "layer in middle block", layer_list)
        # NOTE(review): no `continue`/`return` above — when idx == -1 the
        # assignment below writes to middle_block[-1] (the last layer);
        # confirm whether a guard is intended.
        model.middle_block[idx].add_sample_after = mid_sample

        for i, tp in output_blocks:
            idx, layer_list = last_layer_index(model.output_blocks[i], tp)
            if idx < 0:
                # NOTE(review): "outnput" typo in this runtime message
                print("BrushNet can't find", tp, "layer in", i, "outnput block:", layer_list)
                continue
            model.output_blocks[i][idx].add_sample_after = output_samples.pop(0) if output_samples else 0

    patch_model_function_wrapper(model, brushnet_forward)

    # stash all patch state under transformer_options['model_patch']['brushnet']
    to = add_model_patch_option(model)
    mp = to['model_patch']
    if 'brushnet' not in mp:
        mp['brushnet'] = {}
    bo = mp['brushnet']

    bo['model'] = brushnet
    bo['dtype'] = torch_dtype
    bo['latents'] = conditioning_latents
    bo['controls'] = controls
    bo['prompt_embeds'] = prompt_embeds
    bo['negative_prompt_embeds'] = negative_prompt_embeds
    bo['add_embeds'] = (pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids)
    bo['latent_id'] = 0

    # patch layers `forward` so we can apply brushnet
    def forward_patched_by_brushnet(self, x, *args, **kwargs):
        # Run the original forward, then add (and clear) any staged sample.
        h = self.original_forward(x, *args, **kwargs)
        if hasattr(self, 'add_sample_after') and type(self):
            to_add = self.add_sample_after
            if torch.is_tensor(to_add):
                # interpolate due to RAUNet
                if h.shape[2] != to_add.shape[2] or h.shape[3] != to_add.shape[3]:
                    to_add = torch.nn.functional.interpolate(to_add, size=(h.shape[2], h.shape[3]), mode='bicubic')
                h += to_add.to(h.dtype).to(h.device)
            else:
                h += self.add_sample_after
            # consume the sample so it is only applied once per staging
            self.add_sample_after = 0
        return h

    # install the patched forward on every layer; `original_forward` guard
    # keeps repeated patching idempotent
    for i, block in enumerate(model.model.diffusion_model.input_blocks):
        for j, layer in enumerate(block):
            if not hasattr(layer, 'original_forward'):
                layer.original_forward = layer.forward
            layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
            layer.add_sample_after = 0

    for j, layer in enumerate(model.model.diffusion_model.middle_block):
        if not hasattr(layer, 'original_forward'):
            layer.original_forward = layer.forward
        layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
        layer.add_sample_after = 0

    for i, block in enumerate(model.model.diffusion_model.output_blocks):
        for j, layer in enumerate(block):
            if not hasattr(layer, 'original_forward'):
                layer.original_forward = layer.forward
            layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
            layer.add_sample_after = 0
|
||||
@@ -0,0 +1,58 @@
|
||||
{
|
||||
"_class_name": "BrushNetModel",
|
||||
"_diffusers_version": "0.27.0.dev0",
|
||||
"_name_or_path": "runs/logs/brushnet_randommask/checkpoint-100000",
|
||||
"act_fn": "silu",
|
||||
"addition_embed_type": null,
|
||||
"addition_embed_type_num_heads": 64,
|
||||
"addition_time_embed_dim": null,
|
||||
"attention_head_dim": 8,
|
||||
"block_out_channels": [
|
||||
320,
|
||||
640,
|
||||
1280,
|
||||
1280
|
||||
],
|
||||
"brushnet_conditioning_channel_order": "rgb",
|
||||
"class_embed_type": null,
|
||||
"conditioning_channels": 5,
|
||||
"conditioning_embedding_out_channels": [
|
||||
16,
|
||||
32,
|
||||
96,
|
||||
256
|
||||
],
|
||||
"cross_attention_dim": 768,
|
||||
"down_block_types": [
|
||||
"DownBlock2D",
|
||||
"DownBlock2D",
|
||||
"DownBlock2D",
|
||||
"DownBlock2D"
|
||||
],
|
||||
"downsample_padding": 1,
|
||||
"encoder_hid_dim": null,
|
||||
"encoder_hid_dim_type": null,
|
||||
"flip_sin_to_cos": true,
|
||||
"freq_shift": 0,
|
||||
"global_pool_conditions": false,
|
||||
"in_channels": 4,
|
||||
"layers_per_block": 2,
|
||||
"mid_block_scale_factor": 1,
|
||||
"mid_block_type": "MidBlock2D",
|
||||
"norm_eps": 1e-05,
|
||||
"norm_num_groups": 32,
|
||||
"num_attention_heads": null,
|
||||
"num_class_embeds": null,
|
||||
"only_cross_attention": false,
|
||||
"projection_class_embeddings_input_dim": null,
|
||||
"resnet_time_scale_shift": "default",
|
||||
"transformer_layers_per_block": 1,
|
||||
"up_block_types": [
|
||||
"UpBlock2D",
|
||||
"UpBlock2D",
|
||||
"UpBlock2D",
|
||||
"UpBlock2D"
|
||||
],
|
||||
"upcast_attention": false,
|
||||
"use_linear_projection": false
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
{
|
||||
"_class_name": "BrushNetModel",
|
||||
"_diffusers_version": "0.27.0.dev0",
|
||||
"_name_or_path": "runs/logs/brushnetsdxl_randommask/checkpoint-80000",
|
||||
"act_fn": "silu",
|
||||
"addition_embed_type": "text_time",
|
||||
"addition_embed_type_num_heads": 64,
|
||||
"addition_time_embed_dim": 256,
|
||||
"attention_head_dim": [
|
||||
5,
|
||||
10,
|
||||
20
|
||||
],
|
||||
"block_out_channels": [
|
||||
320,
|
||||
640,
|
||||
1280
|
||||
],
|
||||
"brushnet_conditioning_channel_order": "rgb",
|
||||
"class_embed_type": null,
|
||||
"conditioning_channels": 5,
|
||||
"conditioning_embedding_out_channels": [
|
||||
16,
|
||||
32,
|
||||
96,
|
||||
256
|
||||
],
|
||||
"cross_attention_dim": 2048,
|
||||
"down_block_types": [
|
||||
"DownBlock2D",
|
||||
"DownBlock2D",
|
||||
"DownBlock2D"
|
||||
],
|
||||
"downsample_padding": 1,
|
||||
"encoder_hid_dim": null,
|
||||
"encoder_hid_dim_type": null,
|
||||
"flip_sin_to_cos": true,
|
||||
"freq_shift": 0,
|
||||
"global_pool_conditions": false,
|
||||
"in_channels": 4,
|
||||
"layers_per_block": 2,
|
||||
"mid_block_scale_factor": 1,
|
||||
"mid_block_type": "MidBlock2D",
|
||||
"norm_eps": 1e-05,
|
||||
"norm_num_groups": 32,
|
||||
"num_attention_heads": null,
|
||||
"num_class_embeds": null,
|
||||
"only_cross_attention": false,
|
||||
"projection_class_embeddings_input_dim": 2816,
|
||||
"resnet_time_scale_shift": "default",
|
||||
"transformer_layers_per_block": [
|
||||
1,
|
||||
2,
|
||||
10
|
||||
],
|
||||
"up_block_types": [
|
||||
"UpBlock2D",
|
||||
"UpBlock2D",
|
||||
"UpBlock2D"
|
||||
],
|
||||
"upcast_attention": null,
|
||||
"use_linear_projection": true
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
{
|
||||
"_class_name": "BrushNetModel",
|
||||
"_diffusers_version": "0.27.2",
|
||||
"act_fn": "silu",
|
||||
"addition_embed_type": null,
|
||||
"addition_embed_type_num_heads": 64,
|
||||
"addition_time_embed_dim": null,
|
||||
"attention_head_dim": 8,
|
||||
"block_out_channels": [
|
||||
320,
|
||||
640,
|
||||
1280,
|
||||
1280
|
||||
],
|
||||
"brushnet_conditioning_channel_order": "rgb",
|
||||
"class_embed_type": null,
|
||||
"conditioning_channels": 5,
|
||||
"conditioning_embedding_out_channels": [
|
||||
16,
|
||||
32,
|
||||
96,
|
||||
256
|
||||
],
|
||||
"cross_attention_dim": 768,
|
||||
"down_block_types": [
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D"
|
||||
],
|
||||
"downsample_padding": 1,
|
||||
"encoder_hid_dim": null,
|
||||
"encoder_hid_dim_type": null,
|
||||
"flip_sin_to_cos": true,
|
||||
"freq_shift": 0,
|
||||
"global_pool_conditions": false,
|
||||
"in_channels": 4,
|
||||
"layers_per_block": 2,
|
||||
"mid_block_scale_factor": 1,
|
||||
"mid_block_type": "UNetMidBlock2DCrossAttn",
|
||||
"norm_eps": 1e-05,
|
||||
"norm_num_groups": 32,
|
||||
"num_attention_heads": null,
|
||||
"num_class_embeds": null,
|
||||
"only_cross_attention": false,
|
||||
"projection_class_embeddings_input_dim": null,
|
||||
"resnet_time_scale_shift": "default",
|
||||
"transformer_layers_per_block": 1,
|
||||
"up_block_types": [
|
||||
"UpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D"
|
||||
],
|
||||
"upcast_attention": false,
|
||||
"use_linear_projection": false
|
||||
}
|
||||
1688
custom_nodes/ComfyUI-Easy-Use/py/modules/brushnet/model.py
Normal file
1688
custom_nodes/ComfyUI-Easy-Use/py/modules/brushnet/model.py
Normal file
File diff suppressed because it is too large
Load Diff
137
custom_nodes/ComfyUI-Easy-Use/py/modules/brushnet/model_patch.py
Normal file
137
custom_nodes/ComfyUI-Easy-Use/py/modules/brushnet/model_patch.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import torch
|
||||
import comfy
|
||||
|
||||
# Check and add 'model_patch' to model.model_options['transformer_options']
|
||||
def add_model_patch_option(model):
|
||||
if 'transformer_options' not in model.model_options:
|
||||
model.model_options['transformer_options'] = {}
|
||||
to = model.model_options['transformer_options']
|
||||
if "model_patch" not in to:
|
||||
to["model_patch"] = {}
|
||||
return to
|
||||
|
||||
|
||||
# Patch model with model_function_wrapper
|
||||
def patch_model_function_wrapper(model, forward_patch, remove=False):
|
||||
def brushnet_model_function_wrapper(apply_model_method, options_dict):
|
||||
to = options_dict['c']['transformer_options']
|
||||
|
||||
control = None
|
||||
if 'control' in options_dict['c']:
|
||||
control = options_dict['c']['control']
|
||||
|
||||
x = options_dict['input']
|
||||
timestep = options_dict['timestep']
|
||||
|
||||
# check if there are patches to execute
|
||||
if 'model_patch' not in to or 'forward' not in to['model_patch']:
|
||||
return apply_model_method(x, timestep, **options_dict['c'])
|
||||
|
||||
mp = to['model_patch']
|
||||
unet = mp['unet']
|
||||
|
||||
all_sigmas = mp['all_sigmas']
|
||||
sigma = to['sigmas'][0].item()
|
||||
total_steps = all_sigmas.shape[0] - 1
|
||||
step = torch.argmin((all_sigmas - sigma).abs()).item()
|
||||
|
||||
mp['step'] = step
|
||||
mp['total_steps'] = total_steps
|
||||
|
||||
# comfy.model_base.apply_model
|
||||
xc = model.model.model_sampling.calculate_input(timestep, x)
|
||||
if 'c_concat' in options_dict['c'] and options_dict['c']['c_concat'] is not None:
|
||||
xc = torch.cat([xc] + [options_dict['c']['c_concat']], dim=1)
|
||||
t = model.model.model_sampling.timestep(timestep).float()
|
||||
# execute all patches
|
||||
for method in mp['forward']:
|
||||
method(unet, xc, t, to, control)
|
||||
|
||||
return apply_model_method(x, timestep, **options_dict['c'])
|
||||
|
||||
if "model_function_wrapper" in model.model_options and model.model_options["model_function_wrapper"]:
|
||||
print('BrushNet is going to replace existing model_function_wrapper:',
|
||||
model.model_options["model_function_wrapper"])
|
||||
model.set_model_unet_function_wrapper(brushnet_model_function_wrapper)
|
||||
|
||||
to = add_model_patch_option(model)
|
||||
mp = to['model_patch']
|
||||
|
||||
if isinstance(model.model.model_config, comfy.supported_models.SD15):
|
||||
mp['SDXL'] = False
|
||||
elif isinstance(model.model.model_config, comfy.supported_models.SDXL):
|
||||
mp['SDXL'] = True
|
||||
else:
|
||||
print('Base model type: ', type(model.model.model_config))
|
||||
raise Exception("Unsupported model type: ", type(model.model.model_config))
|
||||
|
||||
if 'forward' not in mp:
|
||||
mp['forward'] = []
|
||||
|
||||
if remove:
|
||||
if forward_patch in mp['forward']:
|
||||
mp['forward'].remove(forward_patch)
|
||||
else:
|
||||
mp['forward'].append(forward_patch)
|
||||
|
||||
mp['unet'] = model.model.diffusion_model
|
||||
mp['step'] = 0
|
||||
mp['total_steps'] = 1
|
||||
|
||||
# apply patches to code
|
||||
if comfy.samplers.sample.__doc__ is None or 'BrushNet' not in comfy.samplers.sample.__doc__:
|
||||
comfy.samplers.original_sample = comfy.samplers.sample
|
||||
comfy.samplers.sample = modified_sample
|
||||
|
||||
if comfy.ldm.modules.diffusionmodules.openaimodel.apply_control.__doc__ is None or \
|
||||
'BrushNet' not in comfy.ldm.modules.diffusionmodules.openaimodel.apply_control.__doc__:
|
||||
comfy.ldm.modules.diffusionmodules.openaimodel.original_apply_control = comfy.ldm.modules.diffusionmodules.openaimodel.apply_control
|
||||
comfy.ldm.modules.diffusionmodules.openaimodel.apply_control = modified_apply_control
|
||||
|
||||
|
||||
# Model needs current step number and cfg at inference step. It is possible to write a custom KSampler but I'd like to use ComfyUI's one.
|
||||
# The first versions had modified_common_ksampler, but it broke custom KSampler nodes
|
||||
def modified_sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={},
|
||||
latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
''' Modified by BrushNet nodes'''
|
||||
cfg_guider = comfy.samplers.CFGGuider(model)
|
||||
cfg_guider.set_conds(positive, negative)
|
||||
cfg_guider.set_cfg(cfg)
|
||||
|
||||
### Modified part ######################################################################
|
||||
to = add_model_patch_option(model)
|
||||
to['model_patch']['all_sigmas'] = sigmas
|
||||
#######################################################################################
|
||||
|
||||
return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
|
||||
# To use Controlnet with RAUNet it is much easier to modify apply_control a little
|
||||
def modified_apply_control(h, control, name):
|
||||
'''Modified by BrushNet nodes'''
|
||||
if control is not None and name in control and len(control[name]) > 0:
|
||||
ctrl = control[name].pop()
|
||||
if ctrl is not None:
|
||||
if h.shape[2] != ctrl.shape[2] or h.shape[3] != ctrl.shape[3]:
|
||||
ctrl = torch.nn.functional.interpolate(ctrl, size=(h.shape[2], h.shape[3]), mode='bicubic').to(
|
||||
h.dtype).to(h.device)
|
||||
try:
|
||||
h += ctrl
|
||||
except:
|
||||
print.warning("warning control could not be applied {} {}".format(h.shape, ctrl.shape))
|
||||
return h
|
||||
|
||||
def add_model_patch(model):
|
||||
to = add_model_patch_option(model)
|
||||
mp = to['model_patch']
|
||||
if "brushnet" in mp:
|
||||
if isinstance(model.model.model_config, comfy.supported_models.SD15):
|
||||
mp['SDXL'] = False
|
||||
elif isinstance(model.model.model_config, comfy.supported_models.SDXL):
|
||||
mp['SDXL'] = True
|
||||
else:
|
||||
print('Base model type: ', type(model.model.model_config))
|
||||
raise Exception("Unsupported model type: ", type(model.model.model_config))
|
||||
|
||||
mp['unet'] = model.model.diffusion_model
|
||||
mp['step'] = 0
|
||||
mp['total_steps'] = 1
|
||||
@@ -0,0 +1,467 @@
|
||||
import copy
|
||||
import random
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import CLIPTokenizer
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
|
||||
class TokenizerWrapper:
|
||||
"""Tokenizer wrapper for CLIPTokenizer. Only support CLIPTokenizer
|
||||
currently. This wrapper is modified from https://github.com/huggingface/dif
|
||||
fusers/blob/e51f19aee82c8dd874b715a09dbc521d88835d68/src/diffusers/loaders.
|
||||
py#L358 # noqa.
|
||||
|
||||
Args:
|
||||
from_pretrained (Union[str, os.PathLike], optional): The *model id*
|
||||
of a pretrained model or a path to a *directory* containing
|
||||
model weights and config. Defaults to None.
|
||||
from_config (Union[str, os.PathLike], optional): The *model id*
|
||||
of a pretrained model or a path to a *directory* containing
|
||||
model weights and config. Defaults to None.
|
||||
|
||||
*args, **kwargs: If `from_pretrained` is passed, *args and **kwargs
|
||||
will be passed to `from_pretrained` function. Otherwise, *args
|
||||
and **kwargs will be used to initialize the model by
|
||||
`self._module_cls(*args, **kwargs)`.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: CLIPTokenizer):
|
||||
self.wrapped = tokenizer
|
||||
self.token_map = {}
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
if name in self.__dict__:
|
||||
return getattr(self, name)
|
||||
# if name == "wrapped":
|
||||
# return getattr(self, 'wrapped')#super().__getattr__("wrapped")
|
||||
|
||||
try:
|
||||
return getattr(self.wrapped, name)
|
||||
except AttributeError:
|
||||
raise AttributeError(
|
||||
"'name' cannot be found in both "
|
||||
f"'{self.__class__.__name__}' and "
|
||||
f"'{self.__class__.__name__}.tokenizer'."
|
||||
)
|
||||
|
||||
def try_adding_tokens(self, tokens: Union[str, List[str]], *args, **kwargs):
|
||||
"""Attempt to add tokens to the tokenizer.
|
||||
|
||||
Args:
|
||||
tokens (Union[str, List[str]]): The tokens to be added.
|
||||
"""
|
||||
num_added_tokens = self.wrapped.add_tokens(tokens, *args, **kwargs)
|
||||
assert num_added_tokens != 0, (
|
||||
f"The tokenizer already contains the token {tokens}. Please pass "
|
||||
"a different `placeholder_token` that is not already in the "
|
||||
"tokenizer."
|
||||
)
|
||||
|
||||
def get_token_info(self, token: str) -> dict:
|
||||
"""Get the information of a token, including its start and end index in
|
||||
the current tokenizer.
|
||||
|
||||
Args:
|
||||
token (str): The token to be queried.
|
||||
|
||||
Returns:
|
||||
dict: The information of the token, including its start and end
|
||||
index in current tokenizer.
|
||||
"""
|
||||
token_ids = self.__call__(token).input_ids
|
||||
start, end = token_ids[1], token_ids[-2] + 1
|
||||
return {"name": token, "start": start, "end": end}
|
||||
|
||||
def add_placeholder_token(self, placeholder_token: str, *args, num_vec_per_token: int = 1, **kwargs):
|
||||
"""Add placeholder tokens to the tokenizer.
|
||||
|
||||
Args:
|
||||
placeholder_token (str): The placeholder token to be added.
|
||||
num_vec_per_token (int, optional): The number of vectors of
|
||||
the added placeholder token.
|
||||
*args, **kwargs: The arguments for `self.wrapped.add_tokens`.
|
||||
"""
|
||||
output = []
|
||||
if num_vec_per_token == 1:
|
||||
self.try_adding_tokens(placeholder_token, *args, **kwargs)
|
||||
output.append(placeholder_token)
|
||||
else:
|
||||
output = []
|
||||
for i in range(num_vec_per_token):
|
||||
ith_token = placeholder_token + f"_{i}"
|
||||
self.try_adding_tokens(ith_token, *args, **kwargs)
|
||||
output.append(ith_token)
|
||||
|
||||
for token in self.token_map:
|
||||
if token in placeholder_token:
|
||||
raise ValueError(
|
||||
f"The tokenizer already has placeholder token {token} "
|
||||
f"that can get confused with {placeholder_token} "
|
||||
"keep placeholder tokens independent"
|
||||
)
|
||||
self.token_map[placeholder_token] = output
|
||||
|
||||
def replace_placeholder_tokens_in_text(
|
||||
self, text: Union[str, List[str]], vector_shuffle: bool = False, prop_tokens_to_load: float = 1.0
|
||||
) -> Union[str, List[str]]:
|
||||
"""Replace the keywords in text with placeholder tokens. This function
|
||||
will be called in `self.__call__` and `self.encode`.
|
||||
|
||||
Args:
|
||||
text (Union[str, List[str]]): The text to be processed.
|
||||
vector_shuffle (bool, optional): Whether to shuffle the vectors.
|
||||
Defaults to False.
|
||||
prop_tokens_to_load (float, optional): The proportion of tokens to
|
||||
be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0.
|
||||
|
||||
Returns:
|
||||
Union[str, List[str]]: The processed text.
|
||||
"""
|
||||
if isinstance(text, list):
|
||||
output = []
|
||||
for i in range(len(text)):
|
||||
output.append(self.replace_placeholder_tokens_in_text(text[i], vector_shuffle=vector_shuffle))
|
||||
return output
|
||||
|
||||
for placeholder_token in self.token_map:
|
||||
if placeholder_token in text:
|
||||
tokens = self.token_map[placeholder_token]
|
||||
tokens = tokens[: 1 + int(len(tokens) * prop_tokens_to_load)]
|
||||
if vector_shuffle:
|
||||
tokens = copy.copy(tokens)
|
||||
random.shuffle(tokens)
|
||||
text = text.replace(placeholder_token, " ".join(tokens))
|
||||
return text
|
||||
|
||||
def replace_text_with_placeholder_tokens(self, text: Union[str, List[str]]) -> Union[str, List[str]]:
|
||||
"""Replace the placeholder tokens in text with the original keywords.
|
||||
This function will be called in `self.decode`.
|
||||
|
||||
Args:
|
||||
text (Union[str, List[str]]): The text to be processed.
|
||||
|
||||
Returns:
|
||||
Union[str, List[str]]: The processed text.
|
||||
"""
|
||||
if isinstance(text, list):
|
||||
output = []
|
||||
for i in range(len(text)):
|
||||
output.append(self.replace_text_with_placeholder_tokens(text[i]))
|
||||
return output
|
||||
|
||||
for placeholder_token, tokens in self.token_map.items():
|
||||
merged_tokens = " ".join(tokens)
|
||||
if merged_tokens in text:
|
||||
text = text.replace(merged_tokens, placeholder_token)
|
||||
return text
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: Union[str, List[str]],
|
||||
*args,
|
||||
vector_shuffle: bool = False,
|
||||
prop_tokens_to_load: float = 1.0,
|
||||
**kwargs,
|
||||
):
|
||||
"""The call function of the wrapper.
|
||||
|
||||
Args:
|
||||
text (Union[str, List[str]]): The text to be tokenized.
|
||||
vector_shuffle (bool, optional): Whether to shuffle the vectors.
|
||||
Defaults to False.
|
||||
prop_tokens_to_load (float, optional): The proportion of tokens to
|
||||
be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0
|
||||
*args, **kwargs: The arguments for `self.wrapped.__call__`.
|
||||
"""
|
||||
replaced_text = self.replace_placeholder_tokens_in_text(
|
||||
text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
|
||||
)
|
||||
|
||||
return self.wrapped.__call__(replaced_text, *args, **kwargs)
|
||||
|
||||
def encode(self, text: Union[str, List[str]], *args, **kwargs):
|
||||
"""Encode the passed text to token index.
|
||||
|
||||
Args:
|
||||
text (Union[str, List[str]]): The text to be encode.
|
||||
*args, **kwargs: The arguments for `self.wrapped.__call__`.
|
||||
"""
|
||||
replaced_text = self.replace_placeholder_tokens_in_text(text)
|
||||
return self.wrapped(replaced_text, *args, **kwargs)
|
||||
|
||||
def decode(self, token_ids, return_raw: bool = False, *args, **kwargs) -> Union[str, List[str]]:
|
||||
"""Decode the token index to text.
|
||||
|
||||
Args:
|
||||
token_ids: The token index to be decoded.
|
||||
return_raw: Whether keep the placeholder token in the text.
|
||||
Defaults to False.
|
||||
*args, **kwargs: The arguments for `self.wrapped.decode`.
|
||||
|
||||
Returns:
|
||||
Union[str, List[str]]: The decoded text.
|
||||
"""
|
||||
text = self.wrapped.decode(token_ids, *args, **kwargs)
|
||||
if return_raw:
|
||||
return text
|
||||
replaced_text = self.replace_text_with_placeholder_tokens(text)
|
||||
return replaced_text
|
||||
|
||||
def __repr__(self):
|
||||
"""The representation of the wrapper."""
|
||||
s = super().__repr__()
|
||||
prefix = f"Wrapped Module Class: {self._module_cls}\n"
|
||||
prefix += f"Wrapped Module Name: {self._module_name}\n"
|
||||
if self._from_pretrained:
|
||||
prefix += f"From Pretrained: {self._from_pretrained}\n"
|
||||
s = prefix + s
|
||||
return s
|
||||
|
||||
|
||||
class EmbeddingLayerWithFixes(nn.Module):
|
||||
"""The revised embedding layer to support external embeddings. This design
|
||||
of this class is inspired by https://github.com/AUTOMATIC1111/stable-
|
||||
diffusion-webui/blob/22bcc7be428c94e9408f589966c2040187245d81/modules/sd_hi
|
||||
jack.py#L224 # noqa.
|
||||
|
||||
Args:
|
||||
wrapped (nn.Emebdding): The embedding layer to be wrapped.
|
||||
external_embeddings (Union[dict, List[dict]], optional): The external
|
||||
embeddings added to this layer. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, wrapped: nn.Embedding, external_embeddings: Optional[Union[dict, List[dict]]] = None):
|
||||
super().__init__()
|
||||
self.wrapped = wrapped
|
||||
self.num_embeddings = wrapped.weight.shape[0]
|
||||
|
||||
self.external_embeddings = []
|
||||
if external_embeddings:
|
||||
self.add_embeddings(external_embeddings)
|
||||
|
||||
self.trainable_embeddings = nn.ParameterDict()
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
"""Get the weight of wrapped embedding layer."""
|
||||
return self.wrapped.weight
|
||||
|
||||
def check_duplicate_names(self, embeddings: List[dict]):
|
||||
"""Check whether duplicate names exist in list of 'external
|
||||
embeddings'.
|
||||
|
||||
Args:
|
||||
embeddings (List[dict]): A list of embedding to be check.
|
||||
"""
|
||||
names = [emb["name"] for emb in embeddings]
|
||||
assert len(names) == len(set(names)), (
|
||||
"Found duplicated names in 'external_embeddings'. Name list: " f"'{names}'"
|
||||
)
|
||||
|
||||
def check_ids_overlap(self, embeddings):
|
||||
"""Check whether overlap exist in token ids of 'external_embeddings'.
|
||||
|
||||
Args:
|
||||
embeddings (List[dict]): A list of embedding to be check.
|
||||
"""
|
||||
ids_range = [[emb["start"], emb["end"], emb["name"]] for emb in embeddings]
|
||||
ids_range.sort() # sort by 'start'
|
||||
# check if 'end' has overlapping
|
||||
for idx in range(len(ids_range) - 1):
|
||||
name1, name2 = ids_range[idx][-1], ids_range[idx + 1][-1]
|
||||
assert ids_range[idx][1] <= ids_range[idx + 1][0], (
|
||||
f"Found ids overlapping between embeddings '{name1}' " f"and '{name2}'."
|
||||
)
|
||||
|
||||
def add_embeddings(self, embeddings: Optional[Union[dict, List[dict]]]):
|
||||
"""Add external embeddings to this layer.
|
||||
Use case:
|
||||
Args:
|
||||
embeddings (Union[dict, list[dict]]): The external embeddings to
|
||||
be added. Each dict must contain the following 4 fields: 'name'
|
||||
(the name of this embedding), 'embedding' (the embedding
|
||||
tensor), 'start' (the start token id of this embedding), 'end'
|
||||
(the end token id of this embedding). For example:
|
||||
`{name: NAME, start: START, end: END, embedding: torch.Tensor}`
|
||||
"""
|
||||
if isinstance(embeddings, dict):
|
||||
embeddings = [embeddings]
|
||||
|
||||
self.external_embeddings += embeddings
|
||||
self.check_duplicate_names(self.external_embeddings)
|
||||
self.check_ids_overlap(self.external_embeddings)
|
||||
|
||||
# set for trainable
|
||||
added_trainable_emb_info = []
|
||||
for embedding in embeddings:
|
||||
trainable = embedding.get("trainable", False)
|
||||
if trainable:
|
||||
name = embedding["name"]
|
||||
embedding["embedding"] = torch.nn.Parameter(embedding["embedding"])
|
||||
self.trainable_embeddings[name] = embedding["embedding"]
|
||||
added_trainable_emb_info.append(name)
|
||||
|
||||
added_emb_info = [emb["name"] for emb in embeddings]
|
||||
added_emb_info = ", ".join(added_emb_info)
|
||||
print(f"Successfully add external embeddings: {added_emb_info}.", "current")
|
||||
|
||||
if added_trainable_emb_info:
|
||||
added_trainable_emb_info = ", ".join(added_trainable_emb_info)
|
||||
print("Successfully add trainable external embeddings: " f"{added_trainable_emb_info}", "current")
|
||||
|
||||
def replace_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
"""Replace external input ids to 0.
|
||||
|
||||
Args:
|
||||
input_ids (torch.Tensor): The input ids to be replaced.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The replaced input ids.
|
||||
"""
|
||||
input_ids_fwd = input_ids.clone()
|
||||
input_ids_fwd[input_ids_fwd >= self.num_embeddings] = 0
|
||||
return input_ids_fwd
|
||||
|
||||
def replace_embeddings(
|
||||
self, input_ids: torch.Tensor, embedding: torch.Tensor, external_embedding: dict
|
||||
) -> torch.Tensor:
|
||||
"""Replace external embedding to the embedding layer. Noted that, in
|
||||
this function we use `torch.cat` to avoid inplace modification.
|
||||
|
||||
Args:
|
||||
input_ids (torch.Tensor): The original token ids. Shape like
|
||||
[LENGTH, ].
|
||||
embedding (torch.Tensor): The embedding of token ids after
|
||||
`replace_input_ids` function.
|
||||
external_embedding (dict): The external embedding to be replaced.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The replaced embedding.
|
||||
"""
|
||||
new_embedding = []
|
||||
|
||||
name = external_embedding["name"]
|
||||
start = external_embedding["start"]
|
||||
end = external_embedding["end"]
|
||||
target_ids_to_replace = [i for i in range(start, end)]
|
||||
ext_emb = external_embedding["embedding"].to(embedding.device)
|
||||
|
||||
# do not need to replace
|
||||
if not (input_ids == start).any():
|
||||
return embedding
|
||||
|
||||
# start replace
|
||||
s_idx, e_idx = 0, 0
|
||||
while e_idx < len(input_ids):
|
||||
if input_ids[e_idx] == start:
|
||||
if e_idx != 0:
|
||||
# add embedding do not need to replace
|
||||
new_embedding.append(embedding[s_idx:e_idx])
|
||||
|
||||
# check if the next embedding need to replace is valid
|
||||
actually_ids_to_replace = [int(i) for i in input_ids[e_idx: e_idx + end - start]]
|
||||
assert actually_ids_to_replace == target_ids_to_replace, (
|
||||
f"Invalid 'input_ids' in position: {s_idx} to {e_idx}. "
|
||||
f"Expect '{target_ids_to_replace}' for embedding "
|
||||
f"'{name}' but found '{actually_ids_to_replace}'."
|
||||
)
|
||||
|
||||
new_embedding.append(ext_emb)
|
||||
|
||||
s_idx = e_idx + end - start
|
||||
e_idx = s_idx + 1
|
||||
else:
|
||||
e_idx += 1
|
||||
|
||||
if e_idx == len(input_ids):
|
||||
new_embedding.append(embedding[s_idx:e_idx])
|
||||
|
||||
return torch.cat(new_embedding, dim=0)
|
||||
|
||||
def forward(self, input_ids: torch.Tensor, external_embeddings: Optional[List[dict]] = None, out_dtype = None):
|
||||
"""The forward function.
|
||||
|
||||
Args:
|
||||
input_ids (torch.Tensor): The token ids shape like [bz, LENGTH] or
|
||||
[LENGTH, ].
|
||||
external_embeddings (Optional[List[dict]]): The external
|
||||
embeddings. If not passed, only `self.external_embeddings`
|
||||
will be used. Defaults to None.
|
||||
|
||||
input_ids: shape like [bz, LENGTH] or [LENGTH].
|
||||
"""
|
||||
assert input_ids.ndim in [1, 2]
|
||||
if input_ids.ndim == 1:
|
||||
input_ids = input_ids.unsqueeze(0)
|
||||
|
||||
if external_embeddings is None and not self.external_embeddings:
|
||||
return self.wrapped(input_ids, out_dtype=out_dtype)
|
||||
|
||||
input_ids_fwd = self.replace_input_ids(input_ids)
|
||||
inputs_embeds = self.wrapped(input_ids_fwd)
|
||||
|
||||
vecs = []
|
||||
|
||||
if external_embeddings is None:
|
||||
external_embeddings = []
|
||||
elif isinstance(external_embeddings, dict):
|
||||
external_embeddings = [external_embeddings]
|
||||
embeddings = self.external_embeddings + external_embeddings
|
||||
|
||||
for input_id, embedding in zip(input_ids, inputs_embeds):
|
||||
new_embedding = embedding
|
||||
for external_embedding in embeddings:
|
||||
new_embedding = self.replace_embeddings(input_id, new_embedding, external_embedding)
|
||||
vecs.append(new_embedding)
|
||||
|
||||
return torch.stack(vecs).to(out_dtype)
|
||||
|
||||
|
||||
def add_tokens(
|
||||
tokenizer, text_encoder, placeholder_tokens: list, initialize_tokens: list = None,
|
||||
num_vectors_per_token: int = 1
|
||||
):
|
||||
"""Add token for training.
|
||||
|
||||
# TODO: support add tokens as dict, then we can load pretrained tokens.
|
||||
"""
|
||||
if initialize_tokens is not None:
|
||||
assert len(initialize_tokens) == len(
|
||||
placeholder_tokens
|
||||
), "placeholder_token should be the same length as initialize_token"
|
||||
for ii in range(len(placeholder_tokens)):
|
||||
tokenizer.add_placeholder_token(placeholder_tokens[ii], num_vec_per_token=num_vectors_per_token)
|
||||
|
||||
# text_encoder.set_embedding_layer()
|
||||
embedding_layer = text_encoder.text_model.embeddings.token_embedding
|
||||
text_encoder.text_model.embeddings.token_embedding = EmbeddingLayerWithFixes(embedding_layer)
|
||||
embedding_layer = text_encoder.text_model.embeddings.token_embedding
|
||||
|
||||
assert embedding_layer is not None, (
|
||||
"Do not support get embedding layer for current text encoder. " "Please check your configuration."
|
||||
)
|
||||
initialize_embedding = []
|
||||
if initialize_tokens is not None:
|
||||
for ii in range(len(placeholder_tokens)):
|
||||
init_id = tokenizer(initialize_tokens[ii]).input_ids[1]
|
||||
temp_embedding = embedding_layer.weight[init_id]
|
||||
initialize_embedding.append(temp_embedding[None, ...].repeat(num_vectors_per_token, 1))
|
||||
else:
|
||||
for ii in range(len(placeholder_tokens)):
|
||||
init_id = tokenizer("a").input_ids[1]
|
||||
temp_embedding = embedding_layer.weight[init_id]
|
||||
len_emb = temp_embedding.shape[0]
|
||||
init_weight = (torch.rand(num_vectors_per_token, len_emb) - 0.5) / 2.0
|
||||
initialize_embedding.append(init_weight)
|
||||
|
||||
# initialize_embedding = torch.cat(initialize_embedding,dim=0)
|
||||
|
||||
token_info_all = []
|
||||
for ii in range(len(placeholder_tokens)):
|
||||
token_info = tokenizer.get_token_info(placeholder_tokens[ii])
|
||||
token_info["embedding"] = initialize_embedding[ii]
|
||||
token_info["trainable"] = True
|
||||
token_info_all.append(token_info)
|
||||
embedding_layer.add_embeddings(token_info_all)
|
||||
3908
custom_nodes/ComfyUI-Easy-Use/py/modules/brushnet/unet_2d_blocks.py
Normal file
3908
custom_nodes/ComfyUI-Easy-Use/py/modules/brushnet/unet_2d_blocks.py
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
2
custom_nodes/ComfyUI-Easy-Use/py/modules/dit/__init__.py
Normal file
2
custom_nodes/ComfyUI-Easy-Use/py/modules/dit/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
#credit to city96 for this module
|
||||
#from https://github.com/city96/ComfyUI_ExtraModels/
|
||||
120
custom_nodes/ComfyUI-Easy-Use/py/modules/dit/config.py
Normal file
120
custom_nodes/ComfyUI-Easy-Use/py/modules/dit/config.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""
|
||||
List of all DiT model types / settings
|
||||
"""
|
||||
sampling_settings = {
|
||||
"beta_schedule" : "sqrt_linear",
|
||||
"linear_start" : 0.0001,
|
||||
"linear_end" : 0.02,
|
||||
"timesteps" : 1000,
|
||||
}
|
||||
|
||||
dit_conf = {
|
||||
"XL/2": { # DiT_XL_2
|
||||
"unet_config": {
|
||||
"depth" : 28,
|
||||
"num_heads" : 16,
|
||||
"patch_size" : 2,
|
||||
"hidden_size" : 1152,
|
||||
},
|
||||
"sampling_settings" : sampling_settings,
|
||||
},
|
||||
"XL/4": { # DiT_XL_4
|
||||
"unet_config": {
|
||||
"depth" : 28,
|
||||
"num_heads" : 16,
|
||||
"patch_size" : 4,
|
||||
"hidden_size" : 1152,
|
||||
},
|
||||
"sampling_settings" : sampling_settings,
|
||||
},
|
||||
"XL/8": { # DiT_XL_8
|
||||
"unet_config": {
|
||||
"depth" : 28,
|
||||
"num_heads" : 16,
|
||||
"patch_size" : 8,
|
||||
"hidden_size" : 1152,
|
||||
},
|
||||
"sampling_settings" : sampling_settings,
|
||||
},
|
||||
"L/2": { # DiT_L_2
|
||||
"unet_config": {
|
||||
"depth" : 24,
|
||||
"num_heads" : 16,
|
||||
"patch_size" : 2,
|
||||
"hidden_size" : 1024,
|
||||
},
|
||||
"sampling_settings" : sampling_settings,
|
||||
},
|
||||
"L/4": { # DiT_L_4
|
||||
"unet_config": {
|
||||
"depth" : 24,
|
||||
"num_heads" : 16,
|
||||
"patch_size" : 4,
|
||||
"hidden_size" : 1024,
|
||||
},
|
||||
"sampling_settings" : sampling_settings,
|
||||
},
|
||||
"L/8": { # DiT_L_8
|
||||
"unet_config": {
|
||||
"depth" : 24,
|
||||
"num_heads" : 16,
|
||||
"patch_size" : 8,
|
||||
"hidden_size" : 1024,
|
||||
},
|
||||
"sampling_settings" : sampling_settings,
|
||||
},
|
||||
"B/2": { # DiT_B_2
|
||||
"unet_config": {
|
||||
"depth" : 12,
|
||||
"num_heads" : 12,
|
||||
"patch_size" : 2,
|
||||
"hidden_size" : 768,
|
||||
},
|
||||
"sampling_settings" : sampling_settings,
|
||||
},
|
||||
"B/4": { # DiT_B_4
|
||||
"unet_config": {
|
||||
"depth" : 12,
|
||||
"num_heads" : 12,
|
||||
"patch_size" : 4,
|
||||
"hidden_size" : 768,
|
||||
},
|
||||
"sampling_settings" : sampling_settings,
|
||||
},
|
||||
"B/8": { # DiT_B_8
|
||||
"unet_config": {
|
||||
"depth" : 12,
|
||||
"num_heads" : 12,
|
||||
"patch_size" : 8,
|
||||
"hidden_size" : 768,
|
||||
},
|
||||
"sampling_settings" : sampling_settings,
|
||||
},
|
||||
"S/2": { # DiT_S_2
|
||||
"unet_config": {
|
||||
"depth" : 12,
|
||||
"num_heads" : 6,
|
||||
"patch_size" : 2,
|
||||
"hidden_size" : 384,
|
||||
},
|
||||
"sampling_settings" : sampling_settings,
|
||||
},
|
||||
"S/4": { # DiT_S_4
|
||||
"unet_config": {
|
||||
"depth" : 12,
|
||||
"num_heads" : 6,
|
||||
"patch_size" : 4,
|
||||
"hidden_size" : 384,
|
||||
},
|
||||
"sampling_settings" : sampling_settings,
|
||||
},
|
||||
"S/8": { # DiT_S_8
|
||||
"unet_config": {
|
||||
"depth" : 12,
|
||||
"num_heads" : 6,
|
||||
"patch_size" : 8,
|
||||
"hidden_size" : 384,
|
||||
},
|
||||
"sampling_settings" : sampling_settings,
|
||||
},
|
||||
}
|
||||
@@ -0,0 +1,661 @@
|
||||
GNU AFFERO GENERAL PUBLIC LICENSE
|
||||
Version 3, 19 November 2007
|
||||
|
||||
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
||||
Everyone is permitted to copy and distribute verbatim copies
|
||||
of this license document, but changing it is not allowed.
|
||||
|
||||
Preamble
|
||||
|
||||
The GNU Affero General Public License is a free, copyleft license for
|
||||
software and other kinds of works, specifically designed to ensure
|
||||
cooperation with the community in the case of network server software.
|
||||
|
||||
The licenses for most software and other practical works are designed
|
||||
to take away your freedom to share and change the works. By contrast,
|
||||
our General Public Licenses are intended to guarantee your freedom to
|
||||
share and change all versions of a program--to make sure it remains free
|
||||
software for all its users.
|
||||
|
||||
When we speak of free software, we are referring to freedom, not
|
||||
price. Our General Public Licenses are designed to make sure that you
|
||||
have the freedom to distribute copies of free software (and charge for
|
||||
them if you wish), that you receive source code or can get it if you
|
||||
want it, that you can change the software or use pieces of it in new
|
||||
free programs, and that you know you can do these things.
|
||||
|
||||
Developers that use our General Public Licenses protect your rights
|
||||
with two steps: (1) assert copyright on the software, and (2) offer
|
||||
you this License which gives you legal permission to copy, distribute
|
||||
and/or modify the software.
|
||||
|
||||
A secondary benefit of defending all users' freedom is that
|
||||
improvements made in alternate versions of the program, if they
|
||||
receive widespread use, become available for other developers to
|
||||
incorporate. Many developers of free software are heartened and
|
||||
encouraged by the resulting cooperation. However, in the case of
|
||||
software used on network servers, this result may fail to come about.
|
||||
The GNU General Public License permits making a modified version and
|
||||
letting the public access it on a server without ever releasing its
|
||||
source code to the public.
|
||||
|
||||
The GNU Affero General Public License is designed specifically to
|
||||
ensure that, in such cases, the modified source code becomes available
|
||||
to the community. It requires the operator of a network server to
|
||||
provide the source code of the modified version running there to the
|
||||
users of that server. Therefore, public use of a modified version, on
|
||||
a publicly accessible server, gives the public access to the source
|
||||
code of the modified version.
|
||||
|
||||
An older license, called the Affero General Public License and
|
||||
published by Affero, was designed to accomplish similar goals. This is
|
||||
a different license, not a version of the Affero GPL, but Affero has
|
||||
released a new version of the Affero GPL which permits relicensing under
|
||||
this license.
|
||||
|
||||
The precise terms and conditions for copying, distribution and
|
||||
modification follow.
|
||||
|
||||
TERMS AND CONDITIONS
|
||||
|
||||
0. Definitions.
|
||||
|
||||
"This License" refers to version 3 of the GNU Affero General Public License.
|
||||
|
||||
"Copyright" also means copyright-like laws that apply to other kinds of
|
||||
works, such as semiconductor masks.
|
||||
|
||||
"The Program" refers to any copyrightable work licensed under this
|
||||
License. Each licensee is addressed as "you". "Licensees" and
|
||||
"recipients" may be individuals or organizations.
|
||||
|
||||
To "modify" a work means to copy from or adapt all or part of the work
|
||||
in a fashion requiring copyright permission, other than the making of an
|
||||
exact copy. The resulting work is called a "modified version" of the
|
||||
earlier work or a work "based on" the earlier work.
|
||||
|
||||
A "covered work" means either the unmodified Program or a work based
|
||||
on the Program.
|
||||
|
||||
To "propagate" a work means to do anything with it that, without
|
||||
permission, would make you directly or secondarily liable for
|
||||
infringement under applicable copyright law, except executing it on a
|
||||
computer or modifying a private copy. Propagation includes copying,
|
||||
distribution (with or without modification), making available to the
|
||||
public, and in some countries other activities as well.
|
||||
|
||||
To "convey" a work means any kind of propagation that enables other
|
||||
parties to make or receive copies. Mere interaction with a user through
|
||||
a computer network, with no transfer of a copy, is not conveying.
|
||||
|
||||
An interactive user interface displays "Appropriate Legal Notices"
|
||||
to the extent that it includes a convenient and prominently visible
|
||||
feature that (1) displays an appropriate copyright notice, and (2)
|
||||
tells the user that there is no warranty for the work (except to the
|
||||
extent that warranties are provided), that licensees may convey the
|
||||
work under this License, and how to view a copy of this License. If
|
||||
the interface presents a list of user commands or options, such as a
|
||||
menu, a prominent item in the list meets this criterion.
|
||||
|
||||
1. Source Code.
|
||||
|
||||
The "source code" for a work means the preferred form of the work
|
||||
for making modifications to it. "Object code" means any non-source
|
||||
form of a work.
|
||||
|
||||
A "Standard Interface" means an interface that either is an official
|
||||
standard defined by a recognized standards body, or, in the case of
|
||||
interfaces specified for a particular programming language, one that
|
||||
is widely used among developers working in that language.
|
||||
|
||||
The "System Libraries" of an executable work include anything, other
|
||||
than the work as a whole, that (a) is included in the normal form of
|
||||
packaging a Major Component, but which is not part of that Major
|
||||
Component, and (b) serves only to enable use of the work with that
|
||||
Major Component, or to implement a Standard Interface for which an
|
||||
implementation is available to the public in source code form. A
|
||||
"Major Component", in this context, means a major essential component
|
||||
(kernel, window system, and so on) of the specific operating system
|
||||
(if any) on which the executable work runs, or a compiler used to
|
||||
produce the work, or an object code interpreter used to run it.
|
||||
|
||||
The "Corresponding Source" for a work in object code form means all
|
||||
the source code needed to generate, install, and (for an executable
|
||||
work) run the object code and to modify the work, including scripts to
|
||||
control those activities. However, it does not include the work's
|
||||
System Libraries, or general-purpose tools or generally available free
|
||||
programs which are used unmodified in performing those activities but
|
||||
which are not part of the work. For example, Corresponding Source
|
||||
includes interface definition files associated with source files for
|
||||
the work, and the source code for shared libraries and dynamically
|
||||
linked subprograms that the work is specifically designed to require,
|
||||
such as by intimate data communication or control flow between those
|
||||
subprograms and other parts of the work.
|
||||
|
||||
The Corresponding Source need not include anything that users
|
||||
can regenerate automatically from other parts of the Corresponding
|
||||
Source.
|
||||
|
||||
The Corresponding Source for a work in source code form is that
|
||||
same work.
|
||||
|
||||
2. Basic Permissions.
|
||||
|
||||
All rights granted under this License are granted for the term of
|
||||
copyright on the Program, and are irrevocable provided the stated
|
||||
conditions are met. This License explicitly affirms your unlimited
|
||||
permission to run the unmodified Program. The output from running a
|
||||
covered work is covered by this License only if the output, given its
|
||||
content, constitutes a covered work. This License acknowledges your
|
||||
rights of fair use or other equivalent, as provided by copyright law.
|
||||
|
||||
You may make, run and propagate covered works that you do not
|
||||
convey, without conditions so long as your license otherwise remains
|
||||
in force. You may convey covered works to others for the sole purpose
|
||||
of having them make modifications exclusively for you, or provide you
|
||||
with facilities for running those works, provided that you comply with
|
||||
the terms of this License in conveying all material for which you do
|
||||
not control copyright. Those thus making or running the covered works
|
||||
for you must do so exclusively on your behalf, under your direction
|
||||
and control, on terms that prohibit them from making any copies of
|
||||
your copyrighted material outside their relationship with you.
|
||||
|
||||
Conveying under any other circumstances is permitted solely under
|
||||
the conditions stated below. Sublicensing is not allowed; section 10
|
||||
makes it unnecessary.
|
||||
|
||||
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
||||
|
||||
No covered work shall be deemed part of an effective technological
|
||||
measure under any applicable law fulfilling obligations under article
|
||||
11 of the WIPO copyright treaty adopted on 20 December 1996, or
|
||||
similar laws prohibiting or restricting circumvention of such
|
||||
measures.
|
||||
|
||||
When you convey a covered work, you waive any legal power to forbid
|
||||
circumvention of technological measures to the extent such circumvention
|
||||
is effected by exercising rights under this License with respect to
|
||||
the covered work, and you disclaim any intention to limit operation or
|
||||
modification of the work as a means of enforcing, against the work's
|
||||
users, your or third parties' legal rights to forbid circumvention of
|
||||
technological measures.
|
||||
|
||||
4. Conveying Verbatim Copies.
|
||||
|
||||
You may convey verbatim copies of the Program's source code as you
|
||||
receive it, in any medium, provided that you conspicuously and
|
||||
appropriately publish on each copy an appropriate copyright notice;
|
||||
keep intact all notices stating that this License and any
|
||||
non-permissive terms added in accord with section 7 apply to the code;
|
||||
keep intact all notices of the absence of any warranty; and give all
|
||||
recipients a copy of this License along with the Program.
|
||||
|
||||
You may charge any price or no price for each copy that you convey,
|
||||
and you may offer support or warranty protection for a fee.
|
||||
|
||||
5. Conveying Modified Source Versions.
|
||||
|
||||
You may convey a work based on the Program, or the modifications to
|
||||
produce it from the Program, in the form of source code under the
|
||||
terms of section 4, provided that you also meet all of these conditions:
|
||||
|
||||
a) The work must carry prominent notices stating that you modified
|
||||
it, and giving a relevant date.
|
||||
|
||||
b) The work must carry prominent notices stating that it is
|
||||
released under this License and any conditions added under section
|
||||
7. This requirement modifies the requirement in section 4 to
|
||||
"keep intact all notices".
|
||||
|
||||
c) You must license the entire work, as a whole, under this
|
||||
License to anyone who comes into possession of a copy. This
|
||||
License will therefore apply, along with any applicable section 7
|
||||
additional terms, to the whole of the work, and all its parts,
|
||||
regardless of how they are packaged. This License gives no
|
||||
permission to license the work in any other way, but it does not
|
||||
invalidate such permission if you have separately received it.
|
||||
|
||||
d) If the work has interactive user interfaces, each must display
|
||||
Appropriate Legal Notices; however, if the Program has interactive
|
||||
interfaces that do not display Appropriate Legal Notices, your
|
||||
work need not make them do so.
|
||||
|
||||
A compilation of a covered work with other separate and independent
|
||||
works, which are not by their nature extensions of the covered work,
|
||||
and which are not combined with it such as to form a larger program,
|
||||
in or on a volume of a storage or distribution medium, is called an
|
||||
"aggregate" if the compilation and its resulting copyright are not
|
||||
used to limit the access or legal rights of the compilation's users
|
||||
beyond what the individual works permit. Inclusion of a covered work
|
||||
in an aggregate does not cause this License to apply to the other
|
||||
parts of the aggregate.
|
||||
|
||||
6. Conveying Non-Source Forms.
|
||||
|
||||
You may convey a covered work in object code form under the terms
|
||||
of sections 4 and 5, provided that you also convey the
|
||||
machine-readable Corresponding Source under the terms of this License,
|
||||
in one of these ways:
|
||||
|
||||
a) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by the
|
||||
Corresponding Source fixed on a durable physical medium
|
||||
customarily used for software interchange.
|
||||
|
||||
b) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by a
|
||||
written offer, valid for at least three years and valid for as
|
||||
long as you offer spare parts or customer support for that product
|
||||
model, to give anyone who possesses the object code either (1) a
|
||||
copy of the Corresponding Source for all the software in the
|
||||
product that is covered by this License, on a durable physical
|
||||
medium customarily used for software interchange, for a price no
|
||||
more than your reasonable cost of physically performing this
|
||||
conveying of source, or (2) access to copy the
|
||||
Corresponding Source from a network server at no charge.
|
||||
|
||||
c) Convey individual copies of the object code with a copy of the
|
||||
written offer to provide the Corresponding Source. This
|
||||
alternative is allowed only occasionally and noncommercially, and
|
||||
only if you received the object code with such an offer, in accord
|
||||
with subsection 6b.
|
||||
|
||||
d) Convey the object code by offering access from a designated
|
||||
place (gratis or for a charge), and offer equivalent access to the
|
||||
Corresponding Source in the same way through the same place at no
|
||||
further charge. You need not require recipients to copy the
|
||||
Corresponding Source along with the object code. If the place to
|
||||
copy the object code is a network server, the Corresponding Source
|
||||
may be on a different server (operated by you or a third party)
|
||||
that supports equivalent copying facilities, provided you maintain
|
||||
clear directions next to the object code saying where to find the
|
||||
Corresponding Source. Regardless of what server hosts the
|
||||
Corresponding Source, you remain obligated to ensure that it is
|
||||
available for as long as needed to satisfy these requirements.
|
||||
|
||||
e) Convey the object code using peer-to-peer transmission, provided
|
||||
you inform other peers where the object code and Corresponding
|
||||
Source of the work are being offered to the general public at no
|
||||
charge under subsection 6d.
|
||||
|
||||
A separable portion of the object code, whose source code is excluded
|
||||
from the Corresponding Source as a System Library, need not be
|
||||
included in conveying the object code work.
|
||||
|
||||
A "User Product" is either (1) a "consumer product", which means any
|
||||
tangible personal property which is normally used for personal, family,
|
||||
or household purposes, or (2) anything designed or sold for incorporation
|
||||
into a dwelling. In determining whether a product is a consumer product,
|
||||
doubtful cases shall be resolved in favor of coverage. For a particular
|
||||
product received by a particular user, "normally used" refers to a
|
||||
typical or common use of that class of product, regardless of the status
|
||||
of the particular user or of the way in which the particular user
|
||||
actually uses, or expects or is expected to use, the product. A product
|
||||
is a consumer product regardless of whether the product has substantial
|
||||
commercial, industrial or non-consumer uses, unless such uses represent
|
||||
the only significant mode of use of the product.
|
||||
|
||||
"Installation Information" for a User Product means any methods,
|
||||
procedures, authorization keys, or other information required to install
|
||||
and execute modified versions of a covered work in that User Product from
|
||||
a modified version of its Corresponding Source. The information must
|
||||
suffice to ensure that the continued functioning of the modified object
|
||||
code is in no case prevented or interfered with solely because
|
||||
modification has been made.
|
||||
|
||||
If you convey an object code work under this section in, or with, or
|
||||
specifically for use in, a User Product, and the conveying occurs as
|
||||
part of a transaction in which the right of possession and use of the
|
||||
User Product is transferred to the recipient in perpetuity or for a
|
||||
fixed term (regardless of how the transaction is characterized), the
|
||||
Corresponding Source conveyed under this section must be accompanied
|
||||
by the Installation Information. But this requirement does not apply
|
||||
if neither you nor any third party retains the ability to install
|
||||
modified object code on the User Product (for example, the work has
|
||||
been installed in ROM).
|
||||
|
||||
The requirement to provide Installation Information does not include a
|
||||
requirement to continue to provide support service, warranty, or updates
|
||||
for a work that has been modified or installed by the recipient, or for
|
||||
the User Product in which it has been modified or installed. Access to a
|
||||
network may be denied when the modification itself materially and
|
||||
adversely affects the operation of the network or violates the rules and
|
||||
protocols for communication across the network.
|
||||
|
||||
Corresponding Source conveyed, and Installation Information provided,
|
||||
in accord with this section must be in a format that is publicly
|
||||
documented (and with an implementation available to the public in
|
||||
source code form), and must require no special password or key for
|
||||
unpacking, reading or copying.
|
||||
|
||||
7. Additional Terms.
|
||||
|
||||
"Additional permissions" are terms that supplement the terms of this
|
||||
License by making exceptions from one or more of its conditions.
|
||||
Additional permissions that are applicable to the entire Program shall
|
||||
be treated as though they were included in this License, to the extent
|
||||
that they are valid under applicable law. If additional permissions
|
||||
apply only to part of the Program, that part may be used separately
|
||||
under those permissions, but the entire Program remains governed by
|
||||
this License without regard to the additional permissions.
|
||||
|
||||
When you convey a copy of a covered work, you may at your option
|
||||
remove any additional permissions from that copy, or from any part of
|
||||
it. (Additional permissions may be written to require their own
|
||||
removal in certain cases when you modify the work.) You may place
|
||||
additional permissions on material, added by you to a covered work,
|
||||
for which you have or can give appropriate copyright permission.
|
||||
|
||||
Notwithstanding any other provision of this License, for material you
|
||||
add to a covered work, you may (if authorized by the copyright holders of
|
||||
that material) supplement the terms of this License with terms:
|
||||
|
||||
a) Disclaiming warranty or limiting liability differently from the
|
||||
terms of sections 15 and 16 of this License; or
|
||||
|
||||
b) Requiring preservation of specified reasonable legal notices or
|
||||
author attributions in that material or in the Appropriate Legal
|
||||
Notices displayed by works containing it; or
|
||||
|
||||
c) Prohibiting misrepresentation of the origin of that material, or
|
||||
requiring that modified versions of such material be marked in
|
||||
reasonable ways as different from the original version; or
|
||||
|
||||
d) Limiting the use for publicity purposes of names of licensors or
|
||||
authors of the material; or
|
||||
|
||||
e) Declining to grant rights under trademark law for use of some
|
||||
trade names, trademarks, or service marks; or
|
||||
|
||||
f) Requiring indemnification of licensors and authors of that
|
||||
material by anyone who conveys the material (or modified versions of
|
||||
it) with contractual assumptions of liability to the recipient, for
|
||||
any liability that these contractual assumptions directly impose on
|
||||
those licensors and authors.
|
||||
|
||||
All other non-permissive additional terms are considered "further
|
||||
restrictions" within the meaning of section 10. If the Program as you
|
||||
received it, or any part of it, contains a notice stating that it is
|
||||
governed by this License along with a term that is a further
|
||||
restriction, you may remove that term. If a license document contains
|
||||
a further restriction but permits relicensing or conveying under this
|
||||
License, you may add to a covered work material governed by the terms
|
||||
of that license document, provided that the further restriction does
|
||||
not survive such relicensing or conveying.
|
||||
|
||||
If you add terms to a covered work in accord with this section, you
|
||||
must place, in the relevant source files, a statement of the
|
||||
additional terms that apply to those files, or a notice indicating
|
||||
where to find the applicable terms.
|
||||
|
||||
Additional terms, permissive or non-permissive, may be stated in the
|
||||
form of a separately written license, or stated as exceptions;
|
||||
the above requirements apply either way.
|
||||
|
||||
8. Termination.
|
||||
|
||||
You may not propagate or modify a covered work except as expressly
|
||||
provided under this License. Any attempt otherwise to propagate or
|
||||
modify it is void, and will automatically terminate your rights under
|
||||
this License (including any patent licenses granted under the third
|
||||
paragraph of section 11).
|
||||
|
||||
However, if you cease all violation of this License, then your
|
||||
license from a particular copyright holder is reinstated (a)
|
||||
provisionally, unless and until the copyright holder explicitly and
|
||||
finally terminates your license, and (b) permanently, if the copyright
|
||||
holder fails to notify you of the violation by some reasonable means
|
||||
prior to 60 days after the cessation.
|
||||
|
||||
Moreover, your license from a particular copyright holder is
|
||||
reinstated permanently if the copyright holder notifies you of the
|
||||
violation by some reasonable means, this is the first time you have
|
||||
received notice of violation of this License (for any work) from that
|
||||
copyright holder, and you cure the violation prior to 30 days after
|
||||
your receipt of the notice.
|
||||
|
||||
Termination of your rights under this section does not terminate the
|
||||
licenses of parties who have received copies or rights from you under
|
||||
this License. If your rights have been terminated and not permanently
|
||||
reinstated, you do not qualify to receive new licenses for the same
|
||||
material under section 10.
|
||||
|
||||
9. Acceptance Not Required for Having Copies.
|
||||
|
||||
You are not required to accept this License in order to receive or
|
||||
run a copy of the Program. Ancillary propagation of a covered work
|
||||
occurring solely as a consequence of using peer-to-peer transmission
|
||||
to receive a copy likewise does not require acceptance. However,
|
||||
nothing other than this License grants you permission to propagate or
|
||||
modify any covered work. These actions infringe copyright if you do
|
||||
not accept this License. Therefore, by modifying or propagating a
|
||||
covered work, you indicate your acceptance of this License to do so.
|
||||
|
||||
10. Automatic Licensing of Downstream Recipients.
|
||||
|
||||
Each time you convey a covered work, the recipient automatically
|
||||
receives a license from the original licensors, to run, modify and
|
||||
propagate that work, subject to this License. You are not responsible
|
||||
for enforcing compliance by third parties with this License.
|
||||
|
||||
An "entity transaction" is a transaction transferring control of an
|
||||
organization, or substantially all assets of one, or subdividing an
|
||||
organization, or merging organizations. If propagation of a covered
|
||||
work results from an entity transaction, each party to that
|
||||
transaction who receives a copy of the work also receives whatever
|
||||
licenses to the work the party's predecessor in interest had or could
|
||||
give under the previous paragraph, plus a right to possession of the
|
||||
Corresponding Source of the work from the predecessor in interest, if
|
||||
the predecessor has it or can get it with reasonable efforts.
|
||||
|
||||
You may not impose any further restrictions on the exercise of the
|
||||
rights granted or affirmed under this License. For example, you may
|
||||
not impose a license fee, royalty, or other charge for exercise of
|
||||
rights granted under this License, and you may not initiate litigation
|
||||
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
||||
any patent claim is infringed by making, using, selling, offering for
|
||||
sale, or importing the Program or any portion of it.
|
||||
|
||||
11. Patents.
|
||||
|
||||
A "contributor" is a copyright holder who authorizes use under this
|
||||
License of the Program or a work on which the Program is based. The
|
||||
work thus licensed is called the contributor's "contributor version".
|
||||
|
||||
A contributor's "essential patent claims" are all patent claims
|
||||
owned or controlled by the contributor, whether already acquired or
|
||||
hereafter acquired, that would be infringed by some manner, permitted
|
||||
by this License, of making, using, or selling its contributor version,
|
||||
but do not include claims that would be infringed only as a
|
||||
consequence of further modification of the contributor version. For
|
||||
purposes of this definition, "control" includes the right to grant
|
||||
patent sublicenses in a manner consistent with the requirements of
|
||||
this License.
|
||||
|
||||
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
||||
patent license under the contributor's essential patent claims, to
|
||||
make, use, sell, offer for sale, import and otherwise run, modify and
|
||||
propagate the contents of its contributor version.
|
||||
|
||||
In the following three paragraphs, a "patent license" is any express
|
||||
agreement or commitment, however denominated, not to enforce a patent
|
||||
(such as an express permission to practice a patent or covenant not to
|
||||
sue for patent infringement). To "grant" such a patent license to a
|
||||
party means to make such an agreement or commitment not to enforce a
|
||||
patent against the party.
|
||||
|
||||
If you convey a covered work, knowingly relying on a patent license,
|
||||
and the Corresponding Source of the work is not available for anyone
|
||||
to copy, free of charge and under the terms of this License, through a
|
||||
publicly available network server or other readily accessible means,
|
||||
then you must either (1) cause the Corresponding Source to be so
|
||||
available, or (2) arrange to deprive yourself of the benefit of the
|
||||
patent license for this particular work, or (3) arrange, in a manner
|
||||
consistent with the requirements of this License, to extend the patent
|
||||
license to downstream recipients. "Knowingly relying" means you have
|
||||
actual knowledge that, but for the patent license, your conveying the
|
||||
covered work in a country, or your recipient's use of the covered work
|
||||
in a country, would infringe one or more identifiable patents in that
|
||||
country that you have reason to believe are valid.
|
||||
|
||||
If, pursuant to or in connection with a single transaction or
|
||||
arrangement, you convey, or propagate by procuring conveyance of, a
|
||||
covered work, and grant a patent license to some of the parties
|
||||
receiving the covered work authorizing them to use, propagate, modify
|
||||
or convey a specific copy of the covered work, then the patent license
|
||||
you grant is automatically extended to all recipients of the covered
|
||||
work and works based on it.
|
||||
|
||||
A patent license is "discriminatory" if it does not include within
|
||||
the scope of its coverage, prohibits the exercise of, or is
|
||||
conditioned on the non-exercise of one or more of the rights that are
|
||||
specifically granted under this License. You may not convey a covered
|
||||
work if you are a party to an arrangement with a third party that is
|
||||
in the business of distributing software, under which you make payment
|
||||
to the third party based on the extent of your activity of conveying
|
||||
the work, and under which the third party grants, to any of the
|
||||
parties who would receive the covered work from you, a discriminatory
|
||||
patent license (a) in connection with copies of the covered work
|
||||
conveyed by you (or copies made from those copies), or (b) primarily
|
||||
for and in connection with specific products or compilations that
|
||||
contain the covered work, unless you entered into that arrangement,
|
||||
or that patent license was granted, prior to 28 March 2007.
|
||||
|
||||
Nothing in this License shall be construed as excluding or limiting
|
||||
any implied license or other defenses to infringement that may
|
||||
otherwise be available to you under applicable patent law.
|
||||
|
||||
12. No Surrender of Others' Freedom.
|
||||
|
||||
If conditions are imposed on you (whether by court order, agreement or
|
||||
otherwise) that contradict the conditions of this License, they do not
|
||||
excuse you from the conditions of this License. If you cannot convey a
|
||||
covered work so as to satisfy simultaneously your obligations under this
|
||||
License and any other pertinent obligations, then as a consequence you may
|
||||
not convey it at all. For example, if you agree to terms that obligate you
|
||||
to collect a royalty for further conveying from those to whom you convey
|
||||
the Program, the only way you could satisfy both those terms and this
|
||||
License would be to refrain entirely from conveying the Program.
|
||||
|
||||
13. Remote Network Interaction; Use with the GNU General Public License.
|
||||
|
||||
Notwithstanding any other provision of this License, if you modify the
|
||||
Program, your modified version must prominently offer all users
|
||||
interacting with it remotely through a computer network (if your version
|
||||
supports such interaction) an opportunity to receive the Corresponding
|
||||
Source of your version by providing access to the Corresponding Source
|
||||
from a network server at no charge, through some standard or customary
|
||||
means of facilitating copying of software. This Corresponding Source
|
||||
shall include the Corresponding Source for any work covered by version 3
|
||||
of the GNU General Public License that is incorporated pursuant to the
|
||||
following paragraph.
|
||||
|
||||
Notwithstanding any other provision of this License, you have
|
||||
permission to link or combine any covered work with a work licensed
|
||||
under version 3 of the GNU General Public License into a single
|
||||
combined work, and to convey the resulting work. The terms of this
|
||||
License will continue to apply to the part which is the covered work,
|
||||
but the work with which it is combined will remain governed by version
|
||||
3 of the GNU General Public License.
|
||||
|
||||
14. Revised Versions of this License.
|
||||
|
||||
The Free Software Foundation may publish revised and/or new versions of
|
||||
the GNU Affero General Public License from time to time. Such new versions
|
||||
will be similar in spirit to the present version, but may differ in detail to
|
||||
address new problems or concerns.
|
||||
|
||||
Each version is given a distinguishing version number. If the
|
||||
Program specifies that a certain numbered version of the GNU Affero General
|
||||
Public License "or any later version" applies to it, you have the
|
||||
option of following the terms and conditions either of that numbered
|
||||
version or of any later version published by the Free Software
|
||||
Foundation. If the Program does not specify a version number of the
|
||||
GNU Affero General Public License, you may choose any version ever published
|
||||
by the Free Software Foundation.
|
||||
|
||||
If the Program specifies that a proxy can decide which future
|
||||
versions of the GNU Affero General Public License can be used, that proxy's
|
||||
public statement of acceptance of a version permanently authorizes you
|
||||
to choose that version for the Program.
|
||||
|
||||
Later license versions may give you additional or different
|
||||
permissions. However, no additional obligations are imposed on any
|
||||
author or copyright holder as a result of your choosing to follow a
|
||||
later version.
|
||||
|
||||
15. Disclaimer of Warranty.
|
||||
|
||||
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
||||
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
||||
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
||||
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
||||
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
||||
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
||||
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
||||
|
||||
16. Limitation of Liability.
|
||||
|
||||
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
||||
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
||||
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
||||
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
||||
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
||||
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
||||
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
||||
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
||||
SUCH DAMAGES.
|
||||
|
||||
17. Interpretation of Sections 15 and 16.
|
||||
|
||||
If the disclaimer of warranty and limitation of liability provided
|
||||
above cannot be given local legal effect according to their terms,
|
||||
reviewing courts shall apply local law that most closely approximates
|
||||
an absolute waiver of all civil liability in connection with the
|
||||
Program, unless a warranty or assumption of liability accompanies a
|
||||
copy of the Program in return for a fee.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
How to Apply These Terms to Your New Programs
|
||||
|
||||
If you develop a new program, and you want it to be of the greatest
|
||||
possible use to the public, the best way to achieve this is to make it
|
||||
free software which everyone can redistribute and change under these terms.
|
||||
|
||||
To do so, attach the following notices to the program. It is safest
|
||||
to attach them to the start of each source file to most effectively
|
||||
state the exclusion of warranty; and each file should have at least
|
||||
the "copyright" line and a pointer to where the full notice is found.
|
||||
|
||||
<one line to give the program's name and a brief idea of what it does.>
|
||||
Copyright (C) <year> <name of author>
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published
|
||||
by the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
Also add information on how to contact you by electronic and paper mail.
|
||||
|
||||
If your software can interact with users remotely through a computer
|
||||
network, you should also make sure that it provides a way for users to
|
||||
get its source. For example, if your program is a web application, its
|
||||
interface could display a "Source" link that leads users to an archive
|
||||
of the code. There are many ways you could offer source, and different
|
||||
solutions will be better for different programs; see section 13 for the
|
||||
specific requirements.
|
||||
|
||||
You should also get your employer (if you work as a programmer) or school,
|
||||
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
||||
For more information on this, and how to apply and follow the GNU AGPL, see
|
||||
<https://www.gnu.org/licenses/>.
|
||||
139
custom_nodes/ComfyUI-Easy-Use/py/modules/dit/pixArt/config.py
Normal file
139
custom_nodes/ComfyUI-Easy-Use/py/modules/dit/pixArt/config.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""
|
||||
List of all PixArt model types / settings
|
||||
"""
|
||||
sampling_settings = {
|
||||
"beta_schedule" : "sqrt_linear",
|
||||
"linear_start" : 0.0001,
|
||||
"linear_end" : 0.02,
|
||||
"timesteps" : 1000,
|
||||
}
|
||||
|
||||
pixart_conf = {
|
||||
"PixArtMS_XL_2": { # models/PixArtMS
|
||||
"target": "PixArtMS",
|
||||
"unet_config": {
|
||||
"input_size" : 1024//8,
|
||||
"depth" : 28,
|
||||
"num_heads" : 16,
|
||||
"patch_size" : 2,
|
||||
"hidden_size" : 1152,
|
||||
"pe_interpolation": 2,
|
||||
},
|
||||
"sampling_settings" : sampling_settings,
|
||||
},
|
||||
"PixArtMS_Sigma_XL_2": {
|
||||
"target": "PixArtMSSigma",
|
||||
"unet_config": {
|
||||
"input_size" : 1024//8,
|
||||
"token_num" : 300,
|
||||
"depth" : 28,
|
||||
"num_heads" : 16,
|
||||
"patch_size" : 2,
|
||||
"hidden_size" : 1152,
|
||||
"micro_condition": False,
|
||||
"pe_interpolation": 2,
|
||||
"model_max_length": 300,
|
||||
},
|
||||
"sampling_settings" : sampling_settings,
|
||||
},
|
||||
"PixArtMS_Sigma_XL_2_900M": {
|
||||
"target": "PixArtMSSigma",
|
||||
"unet_config": {
|
||||
"input_size": 1024 // 8,
|
||||
"token_num": 300,
|
||||
"depth": 42,
|
||||
"num_heads": 16,
|
||||
"patch_size": 2,
|
||||
"hidden_size": 1152,
|
||||
"micro_condition": False,
|
||||
"pe_interpolation": 2,
|
||||
"model_max_length": 300,
|
||||
},
|
||||
"sampling_settings": sampling_settings,
|
||||
},
|
||||
"PixArtMS_Sigma_XL_2_2K": {
|
||||
"target": "PixArtMSSigma",
|
||||
"unet_config": {
|
||||
"input_size" : 2048//8,
|
||||
"token_num" : 300,
|
||||
"depth" : 28,
|
||||
"num_heads" : 16,
|
||||
"patch_size" : 2,
|
||||
"hidden_size" : 1152,
|
||||
"micro_condition": False,
|
||||
"pe_interpolation": 4,
|
||||
"model_max_length": 300,
|
||||
},
|
||||
"sampling_settings" : sampling_settings,
|
||||
},
|
||||
"PixArt_XL_2": { # models/PixArt
|
||||
"target": "PixArt",
|
||||
"unet_config": {
|
||||
"input_size" : 512//8,
|
||||
"token_num" : 120,
|
||||
"depth" : 28,
|
||||
"num_heads" : 16,
|
||||
"patch_size" : 2,
|
||||
"hidden_size" : 1152,
|
||||
"pe_interpolation": 1,
|
||||
},
|
||||
"sampling_settings" : sampling_settings,
|
||||
},
|
||||
}
|
||||
|
||||
pixart_conf.update({ # controlnet models
|
||||
"ControlPixArtHalf": {
|
||||
"target": "ControlPixArtHalf",
|
||||
"unet_config": pixart_conf["PixArt_XL_2"]["unet_config"],
|
||||
"sampling_settings": pixart_conf["PixArt_XL_2"]["sampling_settings"],
|
||||
},
|
||||
"ControlPixArtMSHalf": {
|
||||
"target": "ControlPixArtMSHalf",
|
||||
"unet_config": pixart_conf["PixArtMS_XL_2"]["unet_config"],
|
||||
"sampling_settings": pixart_conf["PixArtMS_XL_2"]["sampling_settings"],
|
||||
}
|
||||
})
|
||||
|
||||
pixart_res = {
|
||||
"PixArtMS_XL_2": { # models/PixArtMS 1024x1024
|
||||
'0.25': [512, 2048], '0.26': [512, 1984], '0.27': [512, 1920], '0.28': [512, 1856],
|
||||
'0.32': [576, 1792], '0.33': [576, 1728], '0.35': [576, 1664], '0.40': [640, 1600],
|
||||
'0.42': [640, 1536], '0.48': [704, 1472], '0.50': [704, 1408], '0.52': [704, 1344],
|
||||
'0.57': [768, 1344], '0.60': [768, 1280], '0.68': [832, 1216], '0.72': [832, 1152],
|
||||
'0.78': [896, 1152], '0.82': [896, 1088], '0.88': [960, 1088], '0.94': [960, 1024],
|
||||
'1.00': [1024,1024], '1.07': [1024, 960], '1.13': [1088, 960], '1.21': [1088, 896],
|
||||
'1.29': [1152, 896], '1.38': [1152, 832], '1.46': [1216, 832], '1.67': [1280, 768],
|
||||
'1.75': [1344, 768], '2.00': [1408, 704], '2.09': [1472, 704], '2.40': [1536, 640],
|
||||
'2.50': [1600, 640], '2.89': [1664, 576], '3.00': [1728, 576], '3.11': [1792, 576],
|
||||
'3.62': [1856, 512], '3.75': [1920, 512], '3.88': [1984, 512], '4.00': [2048, 512],
|
||||
},
|
||||
"PixArt_XL_2": { # models/PixArt 512x512
|
||||
'0.25': [256,1024], '0.26': [256, 992], '0.27': [256, 960], '0.28': [256, 928],
|
||||
'0.32': [288, 896], '0.33': [288, 864], '0.35': [288, 832], '0.40': [320, 800],
|
||||
'0.42': [320, 768], '0.48': [352, 736], '0.50': [352, 704], '0.52': [352, 672],
|
||||
'0.57': [384, 672], '0.60': [384, 640], '0.68': [416, 608], '0.72': [416, 576],
|
||||
'0.78': [448, 576], '0.82': [448, 544], '0.88': [480, 544], '0.94': [480, 512],
|
||||
'1.00': [512, 512], '1.07': [512, 480], '1.13': [544, 480], '1.21': [544, 448],
|
||||
'1.29': [576, 448], '1.38': [576, 416], '1.46': [608, 416], '1.67': [640, 384],
|
||||
'1.75': [672, 384], '2.00': [704, 352], '2.09': [736, 352], '2.40': [768, 320],
|
||||
'2.50': [800, 320], '2.89': [832, 288], '3.00': [864, 288], '3.11': [896, 288],
|
||||
'3.62': [928, 256], '3.75': [960, 256], '3.88': [992, 256], '4.00': [1024,256]
|
||||
},
|
||||
"PixArtMS_Sigma_XL_2_2K": {
|
||||
'0.25': [1024, 4096], '0.26': [1024, 3968], '0.27': [1024, 3840], '0.28': [1024, 3712],
|
||||
'0.32': [1152, 3584], '0.33': [1152, 3456], '0.35': [1152, 3328], '0.40': [1280, 3200],
|
||||
'0.42': [1280, 3072], '0.48': [1408, 2944], '0.50': [1408, 2816], '0.52': [1408, 2688],
|
||||
'0.57': [1536, 2688], '0.60': [1536, 2560], '0.68': [1664, 2432], '0.72': [1664, 2304],
|
||||
'0.78': [1792, 2304], '0.82': [1792, 2176], '0.88': [1920, 2176], '0.94': [1920, 2048],
|
||||
'1.00': [2048, 2048], '1.07': [2048, 1920], '1.13': [2176, 1920], '1.21': [2176, 1792],
|
||||
'1.29': [2304, 1792], '1.38': [2304, 1664], '1.46': [2432, 1664], '1.67': [2560, 1536],
|
||||
'1.75': [2688, 1536], '2.00': [2816, 1408], '2.09': [2944, 1408], '2.40': [3072, 1280],
|
||||
'2.50': [3200, 1280], '2.89': [3328, 1152], '3.00': [3456, 1152], '3.11': [3584, 1152],
|
||||
'3.62': [3712, 1024], '3.75': [3840, 1024], '3.88': [3968, 1024], '4.00': [4096, 1024]
|
||||
}
|
||||
}
|
||||
# These should be the same
|
||||
pixart_res.update({
|
||||
"PixArtMS_Sigma_XL_2": pixart_res["PixArtMS_XL_2"],
|
||||
"PixArtMS_Sigma_XL_2_512": pixart_res["PixArt_XL_2"],
|
||||
})
|
||||
@@ -0,0 +1,216 @@
|
||||
# For using the diffusers format weights
|
||||
# Based on the original ComfyUI function +
|
||||
# https://github.com/PixArt-alpha/PixArt-alpha/blob/master/tools/convert_pixart_alpha_to_diffusers.py
|
||||
import torch
|
||||
|
||||
# Extra (pixart_key, diffusers_key) rename pairs used only by multi-scale (MS)
# checkpoints: the resolution and aspect-ratio micro-condition embedder MLPs.
# Appended to get_conversion_map()'s output when the MS keys are present.
conversion_map_ms = [ # for multi_scale_train (MS)
    # Resolution embedder (two-layer MLP)
    ("csize_embedder.mlp.0.weight", "adaln_single.emb.resolution_embedder.linear_1.weight"),
    ("csize_embedder.mlp.0.bias", "adaln_single.emb.resolution_embedder.linear_1.bias"),
    ("csize_embedder.mlp.2.weight", "adaln_single.emb.resolution_embedder.linear_2.weight"),
    ("csize_embedder.mlp.2.bias", "adaln_single.emb.resolution_embedder.linear_2.bias"),
    # Aspect-ratio embedder (two-layer MLP)
    ("ar_embedder.mlp.0.weight", "adaln_single.emb.aspect_ratio_embedder.linear_1.weight"),
    ("ar_embedder.mlp.0.bias", "adaln_single.emb.aspect_ratio_embedder.linear_1.bias"),
    ("ar_embedder.mlp.2.weight", "adaln_single.emb.aspect_ratio_embedder.linear_2.weight"),
    ("ar_embedder.mlp.2.bias", "adaln_single.emb.aspect_ratio_embedder.linear_2.bias"),
]
|
||||
|
||||
|
||||
def get_depth(state_dict):
    """Count transformer blocks in a diffusers-format state dict.

    Each block contributes exactly one self-attention key bias, so counting
    keys with that suffix yields the model depth.
    """
    suffix = '.attn1.to_k.bias'
    return sum(1 for name in state_dict if name.endswith(suffix))
|
||||
|
||||
|
||||
def get_lora_depth(state_dict):
    """Count transformer blocks in a PEFT-format LoRA state dict.

    Mirrors get_depth() but matches the LoRA down-projection of the
    self-attention key instead of the full-model bias.
    """
    suffix = '.attn1.to_k.lora_A.weight'
    return sum(1 for name in state_dict if name.endswith(suffix))
|
||||
|
||||
|
||||
def get_conversion_map(state_dict):
    """Build the (pixart_key, diffusers_key) rename map for a base checkpoint.

    Covers every straight one-to-one rename; the fused q/k/v attention
    projections need a torch.cat and are handled separately by the caller.
    The per-block section is repeated once per transformer block found in
    *state_dict*.
    """
    rename_pairs = [  # main SD conversion map (PixArt reference, HF Diffusers)
        # Patch embeddings
        ("x_embedder.proj.weight", "pos_embed.proj.weight"),
        ("x_embedder.proj.bias", "pos_embed.proj.bias"),
        # Caption projection
        ("y_embedder.y_embedding", "caption_projection.y_embedding"),
        ("y_embedder.y_proj.fc1.weight", "caption_projection.linear_1.weight"),
        ("y_embedder.y_proj.fc1.bias", "caption_projection.linear_1.bias"),
        ("y_embedder.y_proj.fc2.weight", "caption_projection.linear_2.weight"),
        ("y_embedder.y_proj.fc2.bias", "caption_projection.linear_2.bias"),
        # AdaLN-single LN
        ("t_embedder.mlp.0.weight", "adaln_single.emb.timestep_embedder.linear_1.weight"),
        ("t_embedder.mlp.0.bias", "adaln_single.emb.timestep_embedder.linear_1.bias"),
        ("t_embedder.mlp.2.weight", "adaln_single.emb.timestep_embedder.linear_2.weight"),
        ("t_embedder.mlp.2.bias", "adaln_single.emb.timestep_embedder.linear_2.bias"),
        # Shared norm
        ("t_block.1.weight", "adaln_single.linear.weight"),
        ("t_block.1.bias", "adaln_single.linear.bias"),
        # Final block
        ("final_layer.linear.weight", "proj_out.weight"),
        ("final_layer.linear.bias", "proj_out.bias"),
        ("final_layer.scale_shift_table", "scale_shift_table"),
    ]

    # One per-block group for every transformer block present
    # (depth = number of self-attention key biases in the state dict).
    n_blocks = sum(name.endswith('.attn1.to_k.bias') for name in state_dict)
    for idx in range(n_blocks):
        rename_pairs += [
            (f"blocks.{idx}.scale_shift_table", f"transformer_blocks.{idx}.scale_shift_table"),
            # Self-attention output projection
            (f"blocks.{idx}.attn.proj.weight", f"transformer_blocks.{idx}.attn1.to_out.0.weight"),
            (f"blocks.{idx}.attn.proj.bias", f"transformer_blocks.{idx}.attn1.to_out.0.bias"),
            # Feed-forward
            (f"blocks.{idx}.mlp.fc1.weight", f"transformer_blocks.{idx}.ff.net.0.proj.weight"),
            (f"blocks.{idx}.mlp.fc1.bias", f"transformer_blocks.{idx}.ff.net.0.proj.bias"),
            (f"blocks.{idx}.mlp.fc2.weight", f"transformer_blocks.{idx}.ff.net.2.weight"),
            (f"blocks.{idx}.mlp.fc2.bias", f"transformer_blocks.{idx}.ff.net.2.bias"),
            # Cross-attention output projection
            (f"blocks.{idx}.cross_attn.proj.weight", f"transformer_blocks.{idx}.attn2.to_out.0.weight"),
            (f"blocks.{idx}.cross_attn.proj.bias", f"transformer_blocks.{idx}.attn2.to_out.0.bias"),
        ]
    return rename_pairs
|
||||
|
||||
|
||||
def find_prefix(state_dict, target_key):
    """Return the key prefix preceding the first key ending in *target_key*.

    Returns "" when no key matches. Uses slicing rather than
    ``str.split(target_key)[0]`` so the result is correct even when the
    prefix itself contains *target_key* as a substring (split would cut at
    the first occurrence, not the trailing one).
    """
    for k in state_dict:
        if k.endswith(target_key):
            return k[:len(k) - len(target_key)]
    return ""
|
||||
|
||||
|
||||
def convert_state_dict(state_dict):
    """Convert a diffusers-format PixArt state dict to the reference format.

    Straight renames come from get_conversion_map() (plus the multi-scale
    extras when the MS embedder keys are present); the attention q/k/v
    projections are fused with torch.cat since the reference checkpoints
    store them as single ``qkv`` / ``kv`` tensors. Leftover or missing keys
    are reported but do not raise.
    """
    if "adaln_single.emb.resolution_embedder.linear_1.weight" in state_dict:
        # Multi-scale checkpoint: also map resolution / aspect-ratio embedders.
        cmap = get_conversion_map(state_dict) + conversion_map_ms
    else:
        cmap = get_conversion_map(state_dict)

    missing = [k for k, v in cmap if v not in state_dict]
    new_state_dict = {k: state_dict[v] for k, v in cmap if k not in missing}
    matched = [v for k, v in cmap if v in state_dict]

    for depth in range(get_depth(state_dict)):
        for wb in ["weight", "bias"]:
            # Self-attention: fuse separate q/k/v tensors into one qkv tensor.
            key = lambda a: f"transformer_blocks.{depth}.attn1.to_{a}.{wb}"
            new_state_dict[f"blocks.{depth}.attn.qkv.{wb}"] = torch.cat((
                state_dict[key('q')], state_dict[key('k')], state_dict[key('v')]
            ), dim=0)
            matched += [key('q'), key('k'), key('v')]

            # Cross-attention: q stays separate, k/v are fused.
            key = lambda a: f"transformer_blocks.{depth}.attn2.to_{a}.{wb}"
            new_state_dict[f"blocks.{depth}.cross_attn.q_linear.{wb}"] = state_dict[key('q')]
            new_state_dict[f"blocks.{depth}.cross_attn.kv_linear.{wb}"] = torch.cat((
                state_dict[key('k')], state_dict[key('v')]
            ), dim=0)
            matched += [key('q'), key('k'), key('v')]

    if len(matched) < len(state_dict):
        print(f"PixArt: UNET conversion has leftover keys! ({len(matched)} vs {len(state_dict)})")
        print(list(set(state_dict.keys()) - set(matched)))

    if len(missing) > 0:
        # Fixed: was an f-string with no placeholders (ruff F541).
        print("PixArt: UNET conversion has missing keys!")
        print(missing)

    return new_state_dict
|
||||
|
||||
|
||||
# Same as above but for LoRA weights:
def convert_lora_state_dict(state_dict, peft=True):
    """Convert a diffusers-format PixArt LoRA state dict to the reference format.

    Supports two source layouts:
      - peft=True: PEFT naming (.lora_A/.lora_B); the common key prefix is
        stripped first. PEFT stores no alpha in the weights themselves.
      - peft=False: OneTrainer naming (underscored module path +
        .lora_down/.lora_up/.alpha, "lora_transformer_" prefix); any text
        encoder keys are dropped with a warning.

    Output always uses kohya-style target names (.lora_down/.lora_up/.alpha).
    Like convert_state_dict(), the per-block q/k/v LoRA matrices are fused
    with torch.cat to match the reference qkv / kv layout.
    """
    # kohya-style target-key builders (output side).
    rep_ak = lambda x: x.replace(".weight", ".lora_down.weight")
    rep_bk = lambda x: x.replace(".weight", ".lora_up.weight")
    rep_pk = lambda x: x.replace(".weight", ".alpha")
    if peft: # peft
        # PEFT source-key builders.
        rep_ap = lambda x: x.replace(".weight", ".lora_A.weight")
        rep_bp = lambda x: x.replace(".weight", ".lora_B.weight")
        rep_pp = lambda x: x.replace(".weight", ".alpha")

        # Strip whatever prefix precedes the well-known adaln key.
        prefix = find_prefix(state_dict, "adaln_single.linear.lora_A.weight")
        state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
    else: # OneTrainer
        # OneTrainer flattens module paths with underscores; [:-7] drops ".weight".
        rep_ap = lambda x: x.replace(".", "_")[:-7] + ".lora_down.weight"
        rep_bp = lambda x: x.replace(".", "_")[:-7] + ".lora_up.weight"
        rep_pp = lambda x: x.replace(".", "_")[:-7] + ".alpha"

        prefix = "lora_transformer_"
        t5_marker = "lora_te_encoder"
        t5_keys = []
        for key in list(state_dict.keys()):
            if key.startswith(prefix):
                state_dict[key[len(prefix):]] = state_dict.pop(key)
            elif t5_marker in key:
                # Text-encoder LoRA weights are discarded, not converted.
                t5_keys.append(state_dict.pop(key))
        if len(t5_keys) > 0:
            print(f"Text Encoder not supported for PixArt LoRA, ignoring {len(t5_keys)} keys")

    # Build the LoRA rename map from the full-model map: each ".weight"
    # rename becomes a down/up (and, for OneTrainer, alpha) rename.
    cmap = []
    cmap_unet = get_conversion_map(state_dict) + conversion_map_ms # todo: 512 model
    for k, v in cmap_unet:
        if v.endswith(".weight"):
            cmap.append((rep_ak(k), rep_ap(v)))
            cmap.append((rep_bk(k), rep_bp(v)))
            if not peft:
                cmap.append((rep_pk(k), rep_pp(v)))

    missing = [k for k, v in cmap if v not in state_dict]
    new_state_dict = {k: state_dict[v] for k, v in cmap if k not in missing}
    matched = list(v for k, v in cmap if v in state_dict.keys())

    # Fuse the attention LoRA matrices, once for the down (A) and once for
    # the up (B) projections.
    lora_depth = get_lora_depth(state_dict)
    for fp, fk in ((rep_ap, rep_ak), (rep_bp, rep_bk)):
        for depth in range(lora_depth):
            # Self Attention
            key = lambda a: fp(f"transformer_blocks.{depth}.attn1.to_{a}.weight")
            new_state_dict[fk(f"blocks.{depth}.attn.qkv.weight")] = torch.cat((
                state_dict[key('q')], state_dict[key('k')], state_dict[key('v')]
            ), dim=0)

            matched += [key('q'), key('k'), key('v')]
            if not peft:
                # NOTE(review): only the q alpha is kept for the fused qkv —
                # assumes q/k/v alphas are identical; confirm with OneTrainer output.
                akey = lambda a: rep_pp(f"transformer_blocks.{depth}.attn1.to_{a}.weight")
                new_state_dict[rep_pk((f"blocks.{depth}.attn.qkv.weight"))] = state_dict[akey("q")]
                matched += [akey('q'), akey('k'), akey('v')]

            # Self Attention projection?
            key = lambda a: fp(f"transformer_blocks.{depth}.attn1.to_{a}.weight")
            new_state_dict[fk(f"blocks.{depth}.attn.proj.weight")] = state_dict[key('out.0')]
            matched += [key('out.0')]

            # Cross-attention (linear)
            key = lambda a: fp(f"transformer_blocks.{depth}.attn2.to_{a}.weight")
            new_state_dict[fk(f"blocks.{depth}.cross_attn.q_linear.weight")] = state_dict[key('q')]
            new_state_dict[fk(f"blocks.{depth}.cross_attn.kv_linear.weight")] = torch.cat((
                state_dict[key('k')], state_dict[key('v')]
            ), dim=0)
            matched += [key('q'), key('k'), key('v')]
            if not peft:
                akey = lambda a: rep_pp(f"transformer_blocks.{depth}.attn2.to_{a}.weight")
                new_state_dict[rep_pk((f"blocks.{depth}.cross_attn.q_linear.weight"))] = state_dict[akey("q")]
                new_state_dict[rep_pk((f"blocks.{depth}.cross_attn.kv_linear.weight"))] = state_dict[akey("k")]
                matched += [akey('q'), akey('k'), akey('v')]

            # Cross Attention projection?
            key = lambda a: fp(f"transformer_blocks.{depth}.attn2.to_{a}.weight")
            new_state_dict[fk(f"blocks.{depth}.cross_attn.proj.weight")] = state_dict[key('out.0')]
            matched += [key('out.0')]

            # Feed-forward LoRA weights: plain renames, no fusing needed.
            key = fp(f"transformer_blocks.{depth}.ff.net.0.proj.weight")
            new_state_dict[fk(f"blocks.{depth}.mlp.fc1.weight")] = state_dict[key]
            matched += [key]

            key = fp(f"transformer_blocks.{depth}.ff.net.2.weight")
            new_state_dict[fk(f"blocks.{depth}.mlp.fc2.weight")] = state_dict[key]
            matched += [key]

    if len(matched) < len(state_dict):
        print(f"PixArt: LoRA conversion has leftover keys! ({len(matched)} vs {len(state_dict)})")
        print(list(set(state_dict.keys()) - set(matched)))

    if len(missing) > 0:
        print(f"PixArt: LoRA conversion has missing keys! (probably)")
        print(missing)

    return new_state_dict
|
||||
331
custom_nodes/ComfyUI-Easy-Use/py/modules/dit/pixArt/loader.py
Normal file
331
custom_nodes/ComfyUI-Easy-Use/py/modules/dit/pixArt/loader.py
Normal file
@@ -0,0 +1,331 @@
|
||||
import os
|
||||
import json
|
||||
import copy
|
||||
import torch
|
||||
import math
|
||||
import comfy.supported_models_base
|
||||
import comfy.latent_formats
|
||||
import comfy.model_patcher
|
||||
import comfy.model_base
|
||||
import comfy.utils
|
||||
import comfy.conds
|
||||
from comfy import model_management
|
||||
from .diffusers_convert import convert_state_dict, convert_lora_state_dict
|
||||
|
||||
# Checkpoint model description
class EXM_PixArt(comfy.supported_models_base.BASE):
    """Describes a PixArt checkpoint for ComfyUI's model-loading machinery.

    Wraps one entry of the config dicts (target / unet_config /
    sampling_settings) in the interface comfy expects.
    """

    unet_config = {}
    unet_extra_config = {}
    latent_format = comfy.latent_formats.SD15

    def __init__(self, conf):
        self.model_target = conf.get("target")
        self.unet_config = conf.get("unet_config", {})
        self.sampling_settings = conf.get("sampling_settings", {})
        self.latent_format = self.latent_format()
        # The UNET itself is built by this extension, not by comfy core.
        self.unet_config["disable_unet_model_creation"] = True

    def model_type(self, state_dict, prefix=""):
        # PixArt models are epsilon-prediction models.
        return comfy.model_base.ModelType.EPS
|
||||
|
||||
|
||||
class EXM_PixArt_Model(comfy.model_base.BaseModel):
    """BaseModel subclass that forwards PixArt-specific conditioning."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def extra_conds(self, **kwargs):
        """Attach img_hw / aspect_ratio / cn_hint conds when provided."""
        out = super().extra_conds(**kwargs)

        # Resolution and aspect-ratio conds are wrapped as tensors.
        for name in ("img_hw", "aspect_ratio"):
            value = kwargs.get(name)
            if value is not None:
                out[name] = comfy.conds.CONDRegular(torch.tensor(value))

        # ControlNet hint is already a tensor; pass it through as-is.
        hint = kwargs.get("cn_hint")
        if hint is not None:
            out["cn_hint"] = comfy.conds.CONDRegular(hint)

        return out
|
||||
|
||||
|
||||
def load_pixart(model_path, model_conf=None):
    """Load a PixArt checkpoint from disk and wrap it in a ModelPatcher.

    model_path: path to a checkpoint readable by comfy.utils.load_torch_file.
    model_conf: optional config dict (one pixart_conf entry); when None the
        config is guessed from the state dict via guess_pixart_config().
    Returns a comfy.model_patcher.ModelPatcher wrapping the built model.
    Raises NotImplementedError for an unknown config "target".
    """
    state_dict = comfy.utils.load_torch_file(model_path)
    # Some checkpoints nest the weights under a "model" key.
    state_dict = state_dict.get("model", state_dict)

    # Strip a known wrapper prefix if all/any keys carry it.
    for prefix in ["model.diffusion_model.", ]:
        if any(True for x in state_dict if x.startswith(prefix)):
            state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}

    # Diffusers-format weights are detected by this adaln key and converted.
    if "adaln_single.linear.weight" in state_dict:
        state_dict = convert_state_dict(state_dict) # Diffusers

    # guess auto config
    if model_conf is None:
        model_conf = guess_pixart_config(state_dict)

    parameters = comfy.utils.calculate_parameters(state_dict)
    unet_dtype = model_management.unet_dtype(model_params=parameters)
    load_device = comfy.model_management.get_torch_device()
    offload_device = comfy.model_management.unet_offload_device()

    # ignore fp8/etc and use directly for now
    manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
    if manual_cast_dtype:
        print(f"PixArt: falling back to {manual_cast_dtype}")
        unet_dtype = manual_cast_dtype

    model_conf = EXM_PixArt(model_conf) # convert to object
    model = EXM_PixArt_Model( # same as comfy.model_base.BaseModel
        model_conf,
        model_type=comfy.model_base.ModelType.EPS,
        device=model_management.get_torch_device()
    )

    # Build the diffusion model for the configured target architecture.
    # NOTE(review): model_conf.unet_config includes the injected
    # "disable_unet_model_creation" key — assumes the PixArt constructors
    # tolerate/consume it; confirm in the models package.
    if model_conf.model_target == "PixArtMS":
        from .models.PixArtMS import PixArtMS
        model.diffusion_model = PixArtMS(**model_conf.unet_config)
    elif model_conf.model_target == "PixArt":
        from .models.PixArt import PixArt
        model.diffusion_model = PixArt(**model_conf.unet_config)
    elif model_conf.model_target == "PixArtMSSigma":
        from .models.PixArtMS import PixArtMS
        model.diffusion_model = PixArtMS(**model_conf.unet_config)
        # Sigma uses the SDXL VAE latent space.
        model.latent_format = comfy.latent_formats.SDXL()
    elif model_conf.model_target == "ControlPixArtMSHalf":
        from .models.PixArtMS import PixArtMS
        from .models.pixart_controlnet import ControlPixArtMSHalf
        model.diffusion_model = PixArtMS(**model_conf.unet_config)
        model.diffusion_model = ControlPixArtMSHalf(model.diffusion_model)
    elif model_conf.model_target == "ControlPixArtHalf":
        from .models.PixArt import PixArt
        from .models.pixart_controlnet import ControlPixArtHalf
        model.diffusion_model = PixArt(**model_conf.unet_config)
        model.diffusion_model = ControlPixArtHalf(model.diffusion_model)
    else:
        raise NotImplementedError(f"Unknown model target '{model_conf.model_target}'")

    # Non-strict load: report, but tolerate, missing/leftover keys.
    m, u = model.diffusion_model.load_state_dict(state_dict, strict=False)
    if len(m) > 0: print("Missing UNET keys", m)
    if len(u) > 0: print("Leftover UNET keys", u)
    model.diffusion_model.dtype = unet_dtype
    model.diffusion_model.eval()
    model.diffusion_model.to(unet_dtype)

    model_patcher = comfy.model_patcher.ModelPatcher(
        model,
        load_device=load_device,
        offload_device=offload_device,
    )
    return model_patcher
|
||||
|
||||
|
||||
def guess_pixart_config(sd):
    """
    Guess a PixArt model config (target arch + unet_config + sampling) from a
    converted (reference-format) state dict.

    Heuristics:
      - depth: number of self-attention output projections (default 28).
      - model_max_length: caption embedding length (300 => Sigma, else Alpha).
      - input_size / pe_interpolation: derived from pos_embed when present.
    Returns a dict shaped like one pixart_conf entry.
    """
    # Shared settings based on DiT_XL_2 - could be enumerated
    config = {
        "num_heads": 16,    # get from attention
        "patch_size": 2,    # final layer I guess?
        "hidden_size": 1152,  # pos_embed.shape[2]
    }
    # Generator instead of a throwaway list (ruff C419); 0 blocks falls back to 28.
    config["depth"] = sum(key.endswith(".attn.proj.weight") for key in sd.keys()) or 28

    try:
        # this is not present in the diffusers version for sigma?
        config["model_max_length"] = sd["y_embedder.y_embedding"].shape[0]
    except KeyError:
        # need better logic to guess this
        config["model_max_length"] = 300

    if "pos_embed" in sd:
        config["input_size"] = int(math.sqrt(sd["pos_embed"].shape[1])) * config["patch_size"]
        config["pe_interpolation"] = config["input_size"] // (512 // 8)  # dumb guess

    target_arch = "PixArtMS"
    if config["model_max_length"] == 300:
        # Sigma
        target_arch = "PixArtMSSigma"
        config["micro_condition"] = False
        if "input_size" not in config:
            # The diffusers weights for 1K/2K are exactly the same...?
            # replace patch embed logic with HyDiT?
            # Fixed: was an f-string with no placeholders (ruff F541).
            print("PixArt: diffusers weights - 2K model will be broken, use manual loading!")
            config["input_size"] = 1024 // 8
    else:
        # Alpha
        if "csize_embedder.mlp.0.weight" in sd:
            # MS (micro-conditions present)
            target_arch = "PixArtMS"
            config["micro_condition"] = True
            if "input_size" not in config:
                config["input_size"] = 1024 // 8
                config["pe_interpolation"] = 2
        else:
            # PixArt
            target_arch = "PixArt"
            if "input_size" not in config:
                config["input_size"] = 512 // 8
                config["pe_interpolation"] = 1

    print("PixArt guessed config:", target_arch, config)
    return {
        "target": target_arch,
        "unet_config": config,
        "sampling_settings": {
            "beta_schedule": "sqrt_linear",
            "linear_start": 0.0001,
            "linear_end": 0.02,
            "timesteps": 1000,
        }
    }
|
||||
|
||||
# lora
|
||||
class EXM_PixArt_ModelPatcher(comfy.model_patcher.ModelPatcher):
|
||||
def calculate_weight(self, patches, weight, key):
|
||||
"""
|
||||
This is almost the same as the comfy function, but stripped down to just the LoRA patch code.
|
||||
The problem with the original code is the q/k/v keys being combined into one for the attention.
|
||||
In the diffusers code, they're treated as separate keys, but in the reference code they're recombined (q+kv|qkv).
|
||||
This means, for example, that the [1152,1152] weights become [3456,1152] in the state dict.
|
||||
The issue with this is that the LoRA weights are [128,1152],[1152,128] and become [384,1162],[3456,128] instead.
|
||||
|
||||
This is the best thing I could think of that would fix that, but it's very fragile.
|
||||
- Check key shape to determine if it needs the fallback logic
|
||||
- Cut the input into parts based on the shape (undoing the torch.cat)
|
||||
- Do the matrix multiplication logic
|
||||
- Recombine them to match the expected shape
|
||||
"""
|
||||
for p in patches:
|
||||
alpha = p[0]
|
||||
v = p[1]
|
||||
strength_model = p[2]
|
||||
if strength_model != 1.0:
|
||||
weight *= strength_model
|
||||
|
||||
if isinstance(v, list):
|
||||
v = (self.calculate_weight(v[1:], v[0].clone(), key),)
|
||||
|
||||
if len(v) == 2:
|
||||
patch_type = v[0]
|
||||
v = v[1]
|
||||
|
||||
if patch_type == "lora":
|
||||
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
|
||||
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
|
||||
if v[2] is not None:
|
||||
alpha *= v[2] / mat2.shape[0]
|
||||
try:
|
||||
mat1 = mat1.flatten(start_dim=1)
|
||||
mat2 = mat2.flatten(start_dim=1)
|
||||
|
||||
ch1 = mat1.shape[0] // mat2.shape[1]
|
||||
ch2 = mat2.shape[0] // mat1.shape[1]
|
||||
### Fallback logic for shape mismatch ###
|
||||
if mat1.shape[0] != mat2.shape[1] and ch1 == ch2 and (mat1.shape[0] / mat2.shape[1]) % 1 == 0:
|
||||
mat1 = mat1.chunk(ch1, dim=0)
|
||||
mat2 = mat2.chunk(ch1, dim=0)
|
||||
weight += torch.cat(
|
||||
[alpha * torch.mm(mat1[x], mat2[x]) for x in range(ch1)],
|
||||
dim=0,
|
||||
).reshape(weight.shape).type(weight.dtype)
|
||||
else:
|
||||
weight += (alpha * torch.mm(mat1, mat2)).reshape(weight.shape).type(weight.dtype)
|
||||
except Exception as e:
|
||||
print("ERROR", key, e)
|
||||
return weight
|
||||
|
||||
def clone(self):
    """Create a copy of this patcher that shares the underlying model.

    Patch lists are shallow-copied per key so the clone can be patched
    independently; model_options is deep-copied.
    """
    cloned = EXM_PixArt_ModelPatcher(
        self.model, self.load_device, self.offload_device, self.size,
        self.current_device, weight_inplace_update=self.weight_inplace_update,
    )
    cloned.patches = {key: patch_list[:] for key, patch_list in self.patches.items()}
    cloned.object_patches = self.object_patches.copy()
    cloned.model_options = copy.deepcopy(self.model_options)
    cloned.model_keys = self.model_keys
    return cloned
|
||||
|
||||
|
||||
def replace_model_patcher(model):
    """Build an EXM_PixArt_ModelPatcher mirroring an existing ModelPatcher.

    Copies device placement, per-key patch lists (shallow per key) and a
    deep copy of model_options from `model` onto the custom patcher.
    """
    patcher = EXM_PixArt_ModelPatcher(
        model=model.model,
        size=model.size,
        load_device=model.load_device,
        offload_device=model.offload_device,
        current_device=model.current_device,
        weight_inplace_update=model.weight_inplace_update,
    )
    patcher.patches = {key: patch_list[:] for key, patch_list in model.patches.items()}
    patcher.object_patches = model.object_patches.copy()
    patcher.model_options = copy.deepcopy(model.model_options)
    return patcher
|
||||
|
||||
|
||||
def find_peft_alpha(path):
    """Locate the LoRA alpha for a PEFT LoRA stored next to `path`.

    PEFT does not store `lora_alpha` in the weights themselves, so this looks
    for a JSON config next to the model file. Returns the alpha as a number,
    or the default 8.0 when no usable config is found.

    Bug fix: previously, a config file that existed but lacked
    `lora_alpha`/`alpha` caused `None` to be returned (which the caller feeds
    to `torch.tensor(...)` and crashes). Now such configs are skipped and the
    search continues, falling back to the 8.0 default.
    """
    def load_json(json_path):
        # Returns the alpha from the config, or None if the key is absent.
        with open(json_path) as f:
            data = json.load(f)
        alpha = data.get("lora_alpha") or data.get("alpha")
        if not alpha:
            print(" Found config but `lora_alpha` is missing!")
        else:
            print(f" Found config at {json_path} [alpha:{alpha}]")
        return alpha

    # For some weird reason peft doesn't include the alpha in the actual model
    print("PixArt: Warning! This is a PEFT LoRA. Trying to find config...")
    base = os.path.splitext(path)[0]
    candidates = [
        f"{base}.json",
        f"{base}.config.json",
        os.path.join(os.path.dirname(path), "adapter_config.json"),
    ]
    for candidate in candidates:
        if os.path.isfile(candidate):
            alpha = load_json(candidate)
            if alpha:  # skip configs without a usable alpha instead of returning None
                return alpha

    print(" Missing config/alpha! assuming alpha of 8. Consider converting it/adding a config json to it.")
    return 8.0
|
||||
|
||||
|
||||
def load_pixart_lora(model, lora, lora_path, strength):
    """Apply a PixArt LoRA state dict to a model via the custom patcher.

    model: ComfyUI ModelPatcher (or EXM_PixArt_ModelPatcher), or None.
    lora: raw LoRA state dict (PEFT or OneTrainer layout).
    lora_path: path the LoRA was loaded from (used to locate a PEFT config).
    strength: patch strength forwarded to add_patches.
    Returns a new patcher with the LoRA applied (None when model is None).
    """
    # Strip the ".lora_up.weight" suffix to recover the base key.
    k_back = lambda x: x.replace(".lora_up.weight", "")
    # need to convert the actual weights for this to work.
    if any(True for x in lora.keys() if x.endswith("adaln_single.linear.lora_A.weight")):
        # PEFT layout detected. PEFT doesn't embed alpha in the weights, so
        # recover it from a sidecar config and synthesize ".alpha" entries.
        lora = convert_lora_state_dict(lora, peft=True)
        alpha = find_peft_alpha(lora_path)
        lora.update({f"{k_back(x)}.alpha": torch.tensor(alpha) for x in lora.keys() if "lora_up" in x})
    else: # OneTrainer
        lora = convert_lora_state_dict(lora, peft=False)

    # Identity-style key map so comfy.lora.load_lora accepts the converted keys.
    key_map = {k_back(x): f"diffusion_model.{k_back(x)}.weight" for x in lora.keys() if "lora_up" in x} # fake

    loaded = comfy.lora.load_lora(lora, key_map)
    if model is not None:
        # switch to custom model patcher when using LoRAs
        if isinstance(model, EXM_PixArt_ModelPatcher):
            new_modelpatcher = model.clone()
        else:
            new_modelpatcher = replace_model_patcher(model)
        k = new_modelpatcher.add_patches(loaded, strength)
    else:
        k = ()
        new_modelpatcher = None

    # Report any LoRA keys that did not match a model weight.
    k = set(k)
    for x in loaded:
        if (x not in k):
            print("NOT LOADED", x)

    return new_modelpatcher
|
||||
@@ -0,0 +1,250 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
# References:
|
||||
# GLIDE: https://github.com/openai/glide-text2im
|
||||
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
|
||||
# --------------------------------------------------------
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import os
|
||||
import numpy as np
|
||||
from timm.models.layers import DropPath
|
||||
from timm.models.vision_transformer import PatchEmbed, Mlp
|
||||
|
||||
|
||||
from .utils import auto_grad_checkpoint, to_2tuple
|
||||
from .PixArt_blocks import t2i_modulate, CaptionEmbedder, AttentionKVCompress, MultiHeadCrossAttention, T2IFinalLayer, TimestepEmbedder, LabelEmbedder, FinalLayer
|
||||
|
||||
|
||||
class PixArtBlock(nn.Module):
    """
    A PixArt block with adaptive layer norm (adaLN-single) conditioning.

    Self-attention (optionally KV-compressed) -> cross-attention on the text
    condition -> MLP, each shifted/scaled/gated by the six adaLN-single
    signals carried in `t`.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0, input_size=None, sampling=None, sr_ratio=1, qk_norm=False, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = AttentionKVCompress(
            hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
            qk_norm=qk_norm, **block_kwargs
        )
        self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        # to be compatible with lower version pytorch
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        # Learned per-block offsets added to the shared adaLN-single signal (6, hidden).
        self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
        self.sampling = sampling
        self.sr_ratio = sr_ratio

    def forward(self, x, y, t, mask=None, **kwargs):
        """x: (B, N, C) image tokens; y: text tokens; t: adaLN-single signal (B, 6*C);
        mask: per-item condition lengths forwarded to cross-attention."""
        B, N, C = x.shape

        # Split the six modulation signals: shift/scale/gate for MSA and MLP.
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1)
        x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C))
        x = x + self.cross_attn(x, y, mask)
        x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))

        return x
|
||||
|
||||
|
||||
### Core PixArt Model ###
|
||||
class PixArt(nn.Module):
    """
    Diffusion model with a Transformer backbone.

    PixArt (DiT-style) denoiser: patchified latents plus a fixed sin-cos
    positional embedding, timestep conditioning via adaLN-single, and text
    conditioning via per-block cross-attention. `forward` adapts ComfyUI
    inputs to the original `forward_raw`.
    """
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        pred_sigma=True,
        drop_path: float = 0.,
        caption_channels=4096,
        pe_interpolation=1.0,
        pe_precision=None,
        config=None,
        model_max_length=120,
        qk_norm=False,
        kv_compress_config=None,
        **kwargs,
    ):
        super().__init__()
        self.pred_sigma = pred_sigma
        self.in_channels = in_channels
        # When predicting sigma the head emits 2x channels (eps + sigma).
        self.out_channels = in_channels * 2 if pred_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.pe_interpolation = pe_interpolation
        self.pe_precision = pe_precision
        self.depth = depth

        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        num_patches = self.x_embedder.num_patches
        self.base_size = input_size // self.patch_size
        # Will use fixed sin-cos embedding:
        self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))

        approx_gelu = lambda: nn.GELU(approximate="tanh")
        # adaLN-single: one shared projection of t into six modulation signals.
        self.t_block = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )
        self.y_embedder = CaptionEmbedder(
            in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob,
            act_layer=approx_gelu, token_num=model_max_length
        )
        drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)]  # stochastic depth decay rule
        self.kv_compress_config = kv_compress_config
        if kv_compress_config is None:
            self.kv_compress_config = {
                'sampling': None,
                'scale_factor': 1,
                'kv_compress_layer': [],
            }
        self.blocks = nn.ModuleList([
            PixArtBlock(
                hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
                input_size=(input_size // patch_size, input_size // patch_size),
                sampling=self.kv_compress_config['sampling'],
                # KV compression only on the layers listed in the config.
                sr_ratio=int(
                    self.kv_compress_config['scale_factor']
                ) if i in self.kv_compress_config['kv_compress_layer'] else 1,
                qk_norm=qk_norm,
            )
            for i in range(depth)
        ])
        self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels)

    def forward_raw(self, x, t, y, mask=None, data_info=None):
        """
        Original forward pass of PixArt.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N, 1, 120, C) tensor of class labels
        """
        # NOTE(review): `self.dtype` is not set in this class's __init__;
        # it appears to be provided by a subclass (PixArtMS sets it) — confirm.
        x = x.to(self.dtype)
        timestep = t.to(self.dtype)
        y = y.to(self.dtype)
        pos_embed = self.pos_embed.to(self.dtype)
        self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size
        x = self.x_embedder(x) + pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
        t = self.t_embedder(timestep.to(x.dtype))  # (N, D)
        t0 = self.t_block(t)  # adaLN-single modulation signals, (N, 6*D)
        y = self.y_embedder(y, self.training)  # (N, 1, L, D)
        if mask is not None:
            if mask.shape[0] != y.shape[0]:
                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
            mask = mask.squeeze(1).squeeze(1)
            # Drop padded caption tokens and flatten the batch into one sequence.
            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
            y_lens = mask.sum(dim=1).tolist()
        else:
            y_lens = [y.shape[2]] * y.shape[0]
            y = y.squeeze(1).view(1, -1, x.shape[-1])
        for block in self.blocks:
            x = auto_grad_checkpoint(block, x, y, t0, y_lens)  # (N, T, D) #support grad checkpoint
        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)  # (N, out_channels, H, W)
        return x

    def forward(self, x, timesteps, context, y=None, **kwargs):
        """
        Forward pass that adapts comfy input to original forward function
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        timesteps: (N,) tensor of diffusion timesteps
        context: (N, 1, 120, C) conditioning
        y: extra conditioning.
        """
        ## Still accepts the input w/o that dim but returns garbage
        if len(context.shape) == 3:
            context = context.unsqueeze(1)

        ## run original forward pass
        out = self.forward_raw(
            x = x.to(self.dtype),
            t = timesteps.to(self.dtype),
            y = context.to(self.dtype),
        )

        ## only return EPS
        out = out.to(torch.float)
        # `rest` is the predicted-sigma half of the output and is discarded.
        eps, rest = out[:, :self.in_channels], out[:, self.in_channels:]
        return eps

    def unpatchify(self, x):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, H, W, C)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        # Assumes a square token grid (h == w), enforced by the assert below.
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
        return imgs
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, pe_interpolation=1.0, base_size=16):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    if isinstance(grid_size, int):
        grid_size = to_2tuple(grid_size)
    # Coordinate axes, rescaled by base_size and the PE interpolation factor.
    coords_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / pe_interpolation
    coords_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / pe_interpolation
    mesh = np.stack(np.meshgrid(coords_w, coords_h), axis=0)  # here w goes first
    # NOTE(review): the reshape uses (W, H) order, matching the upstream PixArt
    # reference implementation — presumably only exercised with square grids
    # where it is harmless; confirm before using non-square grids here.
    mesh = mesh.reshape([2, 1, grid_size[1], grid_size[0]])

    embedding = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if cls_token and extra_tokens > 0:
        prefix = np.zeros([extra_tokens, embed_dim])
        embedding = np.concatenate([prefix, embedding], axis=0)
    return embedding.astype(np.float32)
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a (2, ...) coordinate grid, half the channels per axis."""
    assert embed_dim % 2 == 0

    half = embed_dim // 2
    # Height coordinates fill the first half of channels, width the second.
    emb_h = get_1d_sincos_pos_embed_from_grid(half, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(half, grid[1])  # (H*W, D/2)

    return np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0

    half = embed_dim // 2
    # Geometric frequency ladder from 1 down to 1/10000, in float64.
    freqs = 1. / 10000 ** (np.arange(half, dtype=np.float64) / half)  # (D/2,)

    angles = np.einsum('m,d->md', pos.reshape(-1), freqs)  # (M, D/2), outer product

    # First half sine, second half cosine.
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
|
||||
@@ -0,0 +1,273 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
# References:
|
||||
# GLIDE: https://github.com/openai/glide-text2im
|
||||
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
|
||||
# --------------------------------------------------------
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tqdm import tqdm
|
||||
from timm.models.layers import DropPath
|
||||
from timm.models.vision_transformer import Mlp
|
||||
|
||||
from .utils import auto_grad_checkpoint, to_2tuple
|
||||
from .PixArt_blocks import t2i_modulate, CaptionEmbedder, AttentionKVCompress, MultiHeadCrossAttention, T2IFinalLayer, TimestepEmbedder, SizeEmbedder
|
||||
from .PixArt import PixArt, get_2d_sincos_pos_embed
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding

    Variant without a fixed image size: any input divisible by the patch
    size is accepted. A non-overlapping Conv2d acts as the per-patch linear
    projection.
    """
    def __init__(
        self,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        norm_layer=None,
        flatten=True,
        bias=True,
    ):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size
        self.flatten = flatten
        # Stride == kernel size: one projection per non-overlapping patch.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        patches = self.proj(x)
        if self.flatten:
            patches = patches.flatten(2).transpose(1, 2)  # BCHW -> BNC
        return self.norm(patches)
|
||||
|
||||
|
||||
class PixArtMSBlock(nn.Module):
    """
    A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning.

    Same structure as PixArtBlock, but `forward` additionally threads the
    (H, W) token-grid shape through to the KV-compressing self-attention.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None,
                 sampling=None, sr_ratio=1, qk_norm=False, **block_kwargs):
        super().__init__()
        self.hidden_size = hidden_size
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = AttentionKVCompress(
            hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
            qk_norm=qk_norm, **block_kwargs
        )
        self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        # to be compatible with lower version pytorch
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        # Learned per-block offsets added to the shared adaLN-single signal (6, hidden).
        self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)

    def forward(self, x, y, t, mask=None, HW=None, **kwargs):
        """x: (B, N, C) tokens; y: text tokens; t: adaLN signal (B, 6*C);
        HW: (height, width) of the token grid for KV compression."""
        B, N, C = x.shape

        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1)
        x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
        x = x + self.cross_attn(x, y, mask)
        x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))

        return x
|
||||
|
||||
|
||||
### Core PixArt Model ###
|
||||
class PixArtMS(PixArt):
    """
    Diffusion model with a Transformer backbone.

    Multi-scale PixArt variant: the positional embedding is rebuilt per call
    to match the incoming latent size, and (optionally) image size / aspect
    ratio micro-conditioning is added to the timestep embedding.
    """
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        learn_sigma=True,
        pred_sigma=True,
        drop_path: float = 0.,
        caption_channels=4096,
        pe_interpolation=None,
        pe_precision=None,
        config=None,
        model_max_length=120,
        micro_condition=True,
        qk_norm=False,
        kv_compress_config=None,
        **kwargs,
    ):
        # NOTE(review): PixArt.__init__ has no `learn_sigma` parameter, so the
        # value below is absorbed by its **kwargs and otherwise unused — confirm.
        super().__init__(
            input_size=input_size,
            patch_size=patch_size,
            in_channels=in_channels,
            hidden_size=hidden_size,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            class_dropout_prob=class_dropout_prob,
            learn_sigma=learn_sigma,
            pred_sigma=pred_sigma,
            drop_path=drop_path,
            pe_interpolation=pe_interpolation,
            config=config,
            model_max_length=model_max_length,
            qk_norm=qk_norm,
            kv_compress_config=kv_compress_config,
            **kwargs,
        )
        self.dtype = torch.get_default_dtype()
        # Token-grid shape of the most recent forward pass; used by unpatchify.
        self.h = self.w = 0
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.t_block = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )
        # Size-agnostic patch embedder (replaces the fixed-size one from PixArt).
        self.x_embedder = PatchEmbed(patch_size, in_channels, hidden_size, bias=True)
        self.y_embedder = CaptionEmbedder(in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob, act_layer=approx_gelu, token_num=model_max_length)
        self.micro_conditioning = micro_condition
        if self.micro_conditioning:
            self.csize_embedder = SizeEmbedder(hidden_size//3)  # c_size embed
            self.ar_embedder = SizeEmbedder(hidden_size//3)  # aspect ratio embed
        drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)]  # stochastic depth decay rule
        if kv_compress_config is None:
            kv_compress_config = {
                'sampling': None,
                'scale_factor': 1,
                'kv_compress_layer': [],
            }
        self.blocks = nn.ModuleList([
            PixArtMSBlock(
                hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
                input_size=(input_size // patch_size, input_size // patch_size),
                sampling=kv_compress_config['sampling'],
                sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1,
                qk_norm=qk_norm,
            )
            for i in range(depth)
        ])
        self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels)

    def forward_raw(self, x, t, y, mask=None, data_info=None, **kwargs):
        """
        Original forward pass of PixArt.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N, 1, 120, C) tensor of class labels
        """
        bs = x.shape[0]
        x = x.to(self.dtype)
        timestep = t.to(self.dtype)
        y = y.to(self.dtype)

        pe_interpolation = self.pe_interpolation
        if pe_interpolation is None or self.pe_precision is not None:
            # calculate pe_interpolation on-the-fly
            pe_interpolation = round((x.shape[-1]+x.shape[-2])/2.0 / (512/8.0), self.pe_precision or 0)

        self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size
        # Rebuild the sin-cos positional embedding for this latent size.
        pos_embed = torch.from_numpy(
            get_2d_sincos_pos_embed(
                self.pos_embed.shape[-1], (self.h, self.w), pe_interpolation=pe_interpolation,
                base_size=self.base_size
            )
        ).unsqueeze(0).to(device=x.device, dtype=self.dtype)

        x = self.x_embedder(x) + pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
        t = self.t_embedder(timestep)  # (N, D)

        if self.micro_conditioning:
            # Add image-size and aspect-ratio embeddings to the timestep signal.
            c_size, ar = data_info['img_hw'].to(self.dtype), data_info['aspect_ratio'].to(self.dtype)
            csize = self.csize_embedder(c_size, bs)  # (N, D)
            ar = self.ar_embedder(ar, bs)  # (N, D)
            t = t + torch.cat([csize, ar], dim=1)

        t0 = self.t_block(t)  # adaLN-single modulation signals, (N, 6*D)
        y = self.y_embedder(y, self.training)  # (N, D)

        if mask is not None:
            if mask.shape[0] != y.shape[0]:
                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
            mask = mask.squeeze(1).squeeze(1)
            # Drop padded caption tokens and flatten the batch into one sequence.
            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
            y_lens = mask.sum(dim=1).tolist()
        else:
            y_lens = [y.shape[2]] * y.shape[0]
            y = y.squeeze(1).view(1, -1, x.shape[-1])
        for block in self.blocks:
            x = auto_grad_checkpoint(block, x, y, t0, y_lens, (self.h, self.w), **kwargs)  # (N, T, D) #support grad checkpoint

        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)  # (N, out_channels, H, W)

        return x

    def forward(self, x, timesteps, context, img_hw=None, aspect_ratio=None, **kwargs):
        """
        Forward pass that adapts comfy input to original forward function
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        timesteps: (N,) tensor of diffusion timesteps
        context: (N, 1, 120, C) conditioning
        img_hw: height|width conditioning
        aspect_ratio: aspect ratio conditioning
        """
        ## size/ar from cond with fallback based on the latent image shape.
        bs = x.shape[0]
        data_info = {}
        if img_hw is None:
            data_info["img_hw"] = torch.tensor(
                [[x.shape[2]*8, x.shape[3]*8]],
                dtype=self.dtype,
                device=x.device
            ).repeat(bs, 1)
        else:
            data_info["img_hw"] = img_hw.to(dtype=x.dtype, device=x.device)
        # NOTE(review): `or True` forces this branch, so a caller-supplied
        # `aspect_ratio` is never used and the else-branch is dead — looks
        # like a deliberate override, but confirm intent.
        if aspect_ratio is None or True:
            data_info["aspect_ratio"] = torch.tensor(
                [[x.shape[2]/x.shape[3]]],
                dtype=self.dtype,
                device=x.device
            ).repeat(bs, 1)
        else:
            data_info["aspect_ratio"] = aspect_ratio.to(dtype=x.dtype, device=x.device)

        ## Still accepts the input w/o that dim but returns garbage
        if len(context.shape) == 3:
            context = context.unsqueeze(1)

        ## run original forward pass
        out = self.forward_raw(
            x = x.to(self.dtype),
            t = timesteps.to(self.dtype),
            y = context.to(self.dtype),
            data_info=data_info,
        )

        ## only return EPS
        out = out.to(torch.float)
        # `rest` is the predicted-sigma half of the output and is discarded.
        eps, rest = out[:, :self.in_channels], out[:, self.in_channels:]
        return eps

    def unpatchify(self, x):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, H, W, C)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        # Uses the (h, w) recorded by the latest forward_raw; supports
        # non-square grids, unlike the base-class unpatchify.
        assert self.h * self.w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], self.h, self.w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, self.h * p, self.w * p))
        return imgs
|
||||
@@ -0,0 +1,477 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
# References:
|
||||
# GLIDE: https://github.com/openai/glide-text2im
|
||||
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
|
||||
# --------------------------------------------------------
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from timm.models.vision_transformer import Mlp, Attention as Attention_
|
||||
from einops import rearrange
|
||||
|
||||
from comfy import model_management
|
||||
if model_management.xformers_enabled():
|
||||
import xformers
|
||||
import xformers.ops
|
||||
else:
|
||||
print("""
|
||||
########################################
|
||||
PixArt: Not using xformers!
|
||||
Expect images to be non-deterministic!
|
||||
Batch sizes > 1 are most likely broken
|
||||
########################################
|
||||
""")
|
||||
|
||||
def modulate(x, shift, scale):
    """adaLN modulation: shift/scale are (N, C) and broadcast over the token dim."""
    return x * (scale.unsqueeze(1) + 1) + shift.unsqueeze(1)
|
||||
|
||||
def t2i_modulate(x, shift, scale):
    """adaLN-single modulation: shift/scale already carry a broadcastable token dim."""
    return shift + x * (scale + 1)
|
||||
|
||||
class MultiHeadCrossAttention(nn.Module):
    """Cross-attention from image tokens (queries) to condition tokens (keys/values).

    The batch is flattened into one long sequence; `mask` is a list of valid
    condition lengths per batch item (y_lens), used to build a block-diagonal
    attention pattern so items do not attend across the batch.
    """
    def __init__(self, d_model, num_heads, attn_drop=0., proj_drop=0., **block_kwargs):
        super(MultiHeadCrossAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.kv_linear = nn.Linear(d_model, d_model*2)  # fused K and V projection
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(d_model, d_model)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, cond, mask=None):
        # query/value: img tokens; key: condition; mask: if padding tokens
        B, N, C = x.shape

        # Flatten batch into one sequence (leading dim 1); the block-diagonal
        # mask below keeps batch items from attending to each other.
        q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
        kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
        k, v = kv.unbind(2)

        if model_management.xformers_enabled():
            attn_bias = None
            if mask is not None:
                attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
            x = xformers.ops.memory_efficient_attention(
                q, k, v,
                p=self.attn_drop.p,
                attn_bias=attn_bias
            )
        else:
            # SDPA expects (B, heads, seq, head_dim).
            q, k, v = map(lambda t: t.permute(0, 2, 1, 3),(q, k, v),)
            attn_mask = None
            if mask is not None and len(mask) > 1:

                # Create equivalent of xformer diagonal block mask, still only correct for square masks
                # But depth doesn't matter as tensors can expand in that dimension
                attn_mask_template = torch.ones(
                    [q.shape[2] // B, mask[0]],
                    dtype=torch.bool,
                    device=q.device
                )
                attn_mask = torch.block_diag(attn_mask_template)

                # create a mask on the diagonal for each mask in the batch
                for n in range(B - 1):
                    attn_mask = torch.block_diag(attn_mask, attn_mask_template)

            x = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask,
                dropout_p=self.attn_drop.p
            ).permute(0, 2, 1, 3).contiguous()
        x = x.view(B, -1, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
|
||||
|
||||
|
||||
class AttentionKVCompress(Attention_):
    """Multi-head Attention block with KV token compression and qk norm."""

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=True,
        sampling='conv',
        sr_ratio=1,
        qk_norm=False,
        **block_kwargs,
    ):
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            sampling (str): KV downsampling mode, one of
                ['conv', 'ave', 'uniform', 'uniform_every'].
            sr_ratio (int): KV spatial reduction ratio; 1 disables compression.
            qk_norm (bool): If True, LayerNorm the queries and keys.
        """
        super().__init__(dim, num_heads=num_heads, qkv_bias=qkv_bias, **block_kwargs)

        self.sampling=sampling    # ['conv', 'ave', 'uniform', 'uniform_every']
        self.sr_ratio = sr_ratio
        if sr_ratio > 1 and sampling == 'conv':
            # Avg Conv Init.  Depthwise conv initialized to act as average pooling.
            self.sr = nn.Conv2d(dim, dim, groups=dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.sr.weight.data.fill_(1/sr_ratio**2)
            self.sr.bias.data.zero_()
            self.norm = nn.LayerNorm(dim)
        if qk_norm:
            self.q_norm = nn.LayerNorm(dim)
            self.k_norm = nn.LayerNorm(dim)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()

    def downsample_2d(self, tensor, H, W, scale_factor, sampling=None):
        """Downsample (B, N, C) tokens over their 2D H x W layout.

        Returns (tensor, new_N) for all compressing paths.
        NOTE(review): the early-exit path below returns a bare tensor, not a
        tuple — callers unpack two values, so these combinations would raise
        if ever reached (forward only calls this when sr_ratio > 1); confirm.
        """
        if sampling is None or scale_factor == 1:
            return tensor
        B, N, C = tensor.shape

        if sampling == 'uniform_every':
            # Keep every scale_factor-th token along the sequence dimension.
            return tensor[:, ::scale_factor], int(N // scale_factor)

        tensor = tensor.reshape(B, H, W, C).permute(0, 3, 1, 2)
        new_H, new_W = int(H / scale_factor), int(W / scale_factor)
        new_N = new_H * new_W

        if sampling == 'ave':
            tensor = F.interpolate(
                tensor, scale_factor=1 / scale_factor, mode='nearest'
            ).permute(0, 2, 3, 1)
        elif sampling == 'uniform':
            tensor = tensor[:, :, ::scale_factor, ::scale_factor].permute(0, 2, 3, 1)
        elif sampling == 'conv':
            # Depthwise conv (initialized as average pooling) followed by LayerNorm.
            tensor = self.sr(tensor).reshape(B, C, -1).permute(0, 2, 1)
            tensor = self.norm(tensor)
        else:
            raise ValueError

        return tensor.reshape(B, new_N, C).contiguous(), new_N

    def forward(self, x, mask=None, HW=None, block_id=None):
        B, N, C = x.shape # 2 4096 1152
        new_N = N
        if HW is None:
            # Assume a square token grid when the layout isn't provided.
            H = W = int(N ** 0.5)
        else:
            H, W = HW
        qkv = self.qkv(x).reshape(B, N, 3, C)

        q, k, v = qkv.unbind(2)
        dtype = q.dtype
        q = self.q_norm(q)
        k = self.k_norm(k)

        # KV compression
        if self.sr_ratio > 1:
            k, new_N = self.downsample_2d(k, H, W, self.sr_ratio, sampling=self.sampling)
            v, new_N = self.downsample_2d(v, H, W, self.sr_ratio, sampling=self.sampling)

        q = q.reshape(B, N, self.num_heads, C // self.num_heads).to(dtype)
        k = k.reshape(B, new_N, self.num_heads, C // self.num_heads).to(dtype)
        v = v.reshape(B, new_N, self.num_heads, C // self.num_heads).to(dtype)

        attn_bias = None
        if mask is not None:
            # Additive bias: -inf where the mask is zero.
            attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device)
            attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float('-inf'))
        # Switch between torch / xformers attention
        if model_management.xformers_enabled():
            x = xformers.ops.memory_efficient_attention(
                q, k, v,
                p=self.attn_drop.p,
                attn_bias=attn_bias
            )
        else:
            # SDPA expects (B, heads, seq, head_dim).
            q, k, v = map(lambda t: t.transpose(1, 2),(q, k, v),)
            x = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p,
                attn_mask=attn_bias
            ).transpose(1, 2).contiguous()
        x = x.view(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
|
||||
|
||||
|
||||
#################################################################################
|
||||
# AMP attention with fp32 softmax to fix loss NaN problem during training #
|
||||
#################################################################################
|
||||
class Attention(Attention_):
    """Attention with an optional fp32 softmax path, used to avoid NaN loss under AMP training."""

    def forward(self, x):
        # x: (B, N, C). Project to qkv and split into heads:
        # (3, B, num_heads, N, head_dim).
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
        use_fp32_attention = getattr(self, 'fp32_attention', False)
        if use_fp32_attention:
            # Upcast q/k so the q·k^T logits and the softmax run in fp32.
            q, k = q.float(), k.float()
        # Disable autocast around the logits/softmax when fp32 attention is on,
        # so AMP does not downcast them back to half precision.
        with torch.cuda.amp.autocast(enabled=not use_fp32_attention):
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
    """Final adaLN-modulated projection of PixArt: LayerNorm + linear patch head."""

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        patch_dim = patch_size * patch_size * out_channels
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_dim, bias=True)
        # SiLU-gated projection producing per-sample (shift, scale) for adaLN.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        # Derive shift/scale from the conditioning vector and modulate the norm.
        mod = self.adaLN_modulation(c)
        shift, scale = mod.chunk(2, dim=1)
        modulated = modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)
|
||||
|
||||
|
||||
class T2IFinalLayer(nn.Module):
    """Final layer of PixArt (T2I variant): learned shift/scale table + linear patch head."""

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        patch_dim = patch_size * patch_size * out_channels
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_dim, bias=True)
        # Learned base (shift, scale), added to the timestep embedding per sample.
        self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size ** 0.5)
        self.out_channels = out_channels

    def forward(self, x, t):
        # Combine the learned table with the per-sample conditioning t, then split.
        combined = self.scale_shift_table[None] + t[:, None]
        shift, scale = combined.chunk(2, dim=1)
        modulated = t2i_modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)
|
||||
|
||||
|
||||
class MaskFinalLayer(nn.Module):
    """Final layer of PixArt for mask prediction: adaLN driven by a separate conditioning size."""

    def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels):
        super().__init__()
        patch_dim = patch_size * patch_size * out_channels
        self.norm_final = nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(final_hidden_size, patch_dim, bias=True)
        # Conditioning embedding (c_emb_size) -> per-sample shift/scale pair.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(c_emb_size, 2 * final_hidden_size, bias=True)
        )

    def forward(self, x, t):
        # Split the projected conditioning into shift and scale halves.
        mod = self.adaLN_modulation(t)
        shift, scale = mod.chunk(2, dim=1)
        modulated = modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)
|
||||
|
||||
|
||||
class DecoderLayer(nn.Module):
    """adaLN-modulated decoder projection: hidden_size -> decoder_hidden_size."""

    def __init__(self, hidden_size, decoder_hidden_size):
        super().__init__()
        self.norm_decoder = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, decoder_hidden_size, bias=True)
        # Conditioning vector -> per-sample (shift, scale) for the norm.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, t):
        # Modulate the normalized features with conditioning-derived shift/scale.
        mod = self.adaLN_modulation(t)
        shift, scale = mod.chunk(2, dim=1)
        modulated = modulate(self.norm_decoder(x), shift, scale)
        return self.linear(modulated)
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Embedding Layers for Timesteps and Class Labels #
|
||||
#################################################################################
|
||||
class TimestepEmbedder(nn.Module):
    """Maps scalar diffusion timesteps to hidden_size-dim embedding vectors."""

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        # Two-layer MLP applied on top of fixed sinusoidal features.
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.

        :param t: a 1-D Tensor of N indices, one per batch element (may be fractional).
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, dim) Tensor of positional embeddings, cos half first then
            sin half; an odd dim gets one zero-padded column.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        # Geometric frequency ladder from 1 down to ~1/max_period.
        steps = torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
        freqs = torch.exp(-math.log(max_period) * steps / half)
        angles = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
        if dim % 2:
            # Zero-pad when the requested dim is odd.
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        sinusoid = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(sinusoid.to(t.dtype))
|
||||
|
||||
|
||||
class SizeEmbedder(TimestepEmbedder):
    """Embeds per-sample size conditioning (one or more scalars per sample);
    each scalar is embedded independently and the results are concatenated."""

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size)
        # Rebuild the MLP so this subclass owns fresh (non-shared) weights.
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size
        self.outdim = hidden_size

    def forward(self, s, bs):
        # Promote a flat vector of scalars to a (batch, 1) column.
        if s.ndim == 1:
            s = s[:, None]
        assert s.ndim == 2
        # Tile the conditioning up to the requested batch size (e.g. CFG doubling).
        if s.shape[0] != bs:
            s = s.repeat(bs//s.shape[0], 1)
            assert s.shape[0] == bs
        batch, dims = s.shape[0], s.shape[1]
        # Flatten so every scalar is embedded independently
        # (equivalent to rearrange(s, "b d -> (b d)")).
        flat = s.reshape(-1)
        feats = self.timestep_embedding(flat, self.frequency_embedding_size)
        embedded = self.mlp(feats.to(flat.dtype))
        # Regroup per sample (equivalent to rearrange "(b d) d2 -> b (d d2)").
        return embedded.reshape(batch, dims * self.outdim)
|
||||
|
||||
|
||||
class LabelEmbedder(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout
    for classifier-free guidance.
    """

    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        # Reserve one extra embedding row (index == num_classes) for the CFG
        # "null" label when dropout is enabled.
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.

        labels: (N,) integer class labels.
        force_drop_ids: optional (N,) tensor; entries equal to 1 are dropped
            deterministically instead of by random sampling.
        """
        if force_drop_ids is None:
            # Fix: sample on the labels' own device instead of forcing .cuda(),
            # which crashed on CPU/MPS inference.
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        else:
            drop_ids = force_drop_ids == 1
        # Dropped entries are redirected to the dedicated null-label row.
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels

    def forward(self, labels, train, force_drop_ids=None):
        """Look up label embeddings, applying CFG dropout during training."""
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels)
        return embeddings
|
||||
|
||||
|
||||
class CaptionEmbedder(nn.Module):
    """
    Embeds caption features into vector representations. Also handles dropout
    for classifier-free guidance by swapping in a learned null caption.
    """

    def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120):
        super().__init__()
        self.y_proj = Mlp(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, drop=0)
        # Learned unconditional ("null") caption, substituted for dropped samples.
        self.register_buffer("y_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels ** 0.5))
        self.uncond_prob = uncond_prob

    def token_drop(self, caption, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.

        caption: (N, 1, token_num, in_channels) caption features.
        force_drop_ids: optional (N,) tensor; entries equal to 1 are dropped
            deterministically instead of by random sampling.
        """
        if force_drop_ids is None:
            # Fix: sample on the caption's device instead of hard-coding .cuda(),
            # which crashed on CPU/MPS inference.
            drop_ids = torch.rand(caption.shape[0], device=caption.device) < self.uncond_prob
        else:
            drop_ids = force_drop_ids == 1
        # Broadcast the null caption over dropped batch entries.
        caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
        return caption

    def forward(self, caption, train, force_drop_ids=None):
        """Apply CFG dropout (training only) and project the caption features."""
        if train:
            assert caption.shape[2:] == self.y_embedding.shape
        use_dropout = self.uncond_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            caption = self.token_drop(caption, force_drop_ids)
        caption = self.y_proj(caption)
        return caption
|
||||
|
||||
|
||||
class CaptionEmbedderDoubleBr(nn.Module):
    """
    Double-branch caption embedder: produces both a pooled global caption
    embedding and the per-token caption features. Also handles dropout for
    classifier-free guidance with learned null embeddings for each branch.
    """

    def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120):
        super().__init__()
        self.proj = Mlp(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, drop=0)
        # Null embeddings: one for the pooled global branch, one per-token.
        self.embedding = nn.Parameter(torch.randn(1, in_channels) / 10 ** 0.5)
        self.y_embedding = nn.Parameter(torch.randn(token_num, in_channels) / 10 ** 0.5)
        self.uncond_prob = uncond_prob

    def token_drop(self, global_caption, caption, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.

        Both branches are dropped together for the same batch entries.
        """
        if force_drop_ids is None:
            # Fix: sample on the input's device instead of hard-coding .cuda(),
            # which crashed on CPU/MPS inference.
            drop_ids = torch.rand(global_caption.shape[0], device=global_caption.device) < self.uncond_prob
        else:
            drop_ids = force_drop_ids == 1
        global_caption = torch.where(drop_ids[:, None], self.embedding, global_caption)
        caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
        return global_caption, caption

    def forward(self, caption, train, force_drop_ids=None):
        """Pool the caption, apply CFG dropout, and project the global branch."""
        assert caption.shape[2: ] == self.y_embedding.shape
        global_caption = caption.mean(dim=2).squeeze()
        use_dropout = self.uncond_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            global_caption, caption = self.token_drop(global_caption, caption, force_drop_ids)
        y_embed = self.proj(global_caption)
        return y_embed, caption
|
||||
@@ -0,0 +1,312 @@
|
||||
import re
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from copy import deepcopy
|
||||
from torch import Tensor
|
||||
from torch.nn import Module, Linear, init
|
||||
from typing import Any, Mapping
|
||||
|
||||
from .PixArt import PixArt, get_2d_sincos_pos_embed
|
||||
from .PixArtMS import PixArtMSBlock, PixArtMS
|
||||
from .utils import auto_grad_checkpoint
|
||||
|
||||
# The implementation of ControlNet-Half architrecture
|
||||
# https://github.com/lllyasviel/ControlNet/discussions/188
|
||||
class ControlT2IDitBlockHalf(Module):
    """One ControlNet-half block: a trainable copy of a PixArt block plus
    zero-initialized projections, following the ControlNet-Half design."""

    def __init__(self, base_block: PixArtMSBlock, block_index: int) -> None:
        super().__init__()
        self.copied_block = deepcopy(base_block)
        self.block_index = block_index

        # The copy trains while the base block stays frozen.
        for p in self.copied_block.parameters():
            p.requires_grad_(True)

        self.copied_block.load_state_dict(base_block.state_dict())
        self.copied_block.train()

        self.hidden_size = hidden_size = base_block.hidden_size
        # Zero-init projections so the controlnet starts as an identity
        # perturbation (standard ControlNet trick).
        if self.block_index == 0:
            self.before_proj = Linear(hidden_size, hidden_size)
            init.zeros_(self.before_proj.weight)
            init.zeros_(self.before_proj.bias)
        self.after_proj = Linear(hidden_size, hidden_size)
        init.zeros_(self.after_proj.weight)
        init.zeros_(self.after_proj.bias)

    def forward(self, x, y, t, mask=None, c=None):
        """Run the copied block on the control stream.

        Returns (c, c_skip): the updated control state and the zero-init
        projected skip connection added to the base model's stream.
        """
        if self.block_index == 0:
            # the first block: inject the hint into the base stream
            c = self.before_proj(c)
            c = self.copied_block(x + c, y, t, mask)
            c_skip = self.after_proj(c)
        else:
            # load from previous c and produce the c for skip connection
            c = self.copied_block(c, y, t, mask)
            c_skip = self.after_proj(c)

        return c, c_skip
|
||||
|
||||
|
||||
# The implementation of ControlPixArtHalf net
|
||||
class ControlPixArtHalf(Module):
    """ControlNet-Half wrapper around a frozen PixArt model (single-resolution).

    The first `copy_blocks_num` transformer blocks are duplicated as trainable
    control blocks; everything else delegates to the frozen base model.
    """
    # only support single res model
    def __init__(self, base_model: PixArt, copy_blocks_num: int = 13) -> None:
        super().__init__()
        self.dtype = torch.get_default_dtype()
        # Base model is frozen (eval + requires_grad False); only copies train.
        self.base_model = base_model.eval()
        self.controlnet = []
        self.copy_blocks_num = copy_blocks_num
        self.total_blocks_num = len(base_model.blocks)
        for p in self.base_model.parameters():
            p.requires_grad_(False)

        # Copy first copy_blocks_num block
        for i in range(copy_blocks_num):
            self.controlnet.append(ControlT2IDitBlockHalf(base_model.blocks[i], i))
        self.controlnet = nn.ModuleList(self.controlnet)

    def __getattr__(self, name: str) -> "Tensor | Module":
        # Delegate attribute lookups: own overridden methods come from the
        # instance dict, registered submodules from nn.Module, and everything
        # else falls through to the wrapped base model (pos_embed, patch_size, ...).
        # NOTE(review): __getattr__ is only invoked when normal lookup fails,
        # so the first branch is rarely (if ever) hit for bound methods.
        if name in ['forward', 'forward_with_dpmsolver', 'forward_with_cfg', 'forward_c', 'load_state_dict']:
            return self.__dict__[name]
        elif name in ['base_model', 'controlnet']:
            return super().__getattr__(name)
        else:
            return getattr(self.base_model, name)

    def forward_c(self, c):
        """Patch-embed the control hint `c` and add positional embeddings.

        Also caches the token-grid shape (self.h, self.w) for unpatchify.
        """
        self.h, self.w = c.shape[-2]//self.patch_size, c.shape[-1]//self.patch_size
        pos_embed = torch.from_numpy(get_2d_sincos_pos_embed(self.pos_embed.shape[-1], (self.h, self.w), lewei_scale=self.lewei_scale, base_size=self.base_size)).unsqueeze(0).to(c.device).to(self.dtype)
        return self.x_embedder(c) + pos_embed if c is not None else c

    # def forward(self, x, t, c, **kwargs):
    #     return self.base_model(x, t, c=self.forward_c(c), **kwargs)
    def forward_raw(self, x, timestep, y, mask=None, data_info=None, c=None, **kwargs):
        # modify the original PixArtMS forward function
        if c is not None:
            c = c.to(self.dtype)
            c = self.forward_c(c)
        """
        Forward pass of PixArt.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N, 1, 120, C) tensor of class labels
        """
        x = x.to(self.dtype)
        timestep = timestep.to(self.dtype)
        y = y.to(self.dtype)
        pos_embed = self.pos_embed.to(self.dtype)
        self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size
        x = self.x_embedder(x) + pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
        t = self.t_embedder(timestep.to(x.dtype))  # (N, D)
        t0 = self.t_block(t)
        y = self.y_embedder(y, self.training)  # (N, 1, L, D)
        if mask is not None:
            # Pack only the unmasked caption tokens into one flat sequence and
            # record per-sample lengths for the attention implementation.
            if mask.shape[0] != y.shape[0]:
                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
            mask = mask.squeeze(1).squeeze(1)
            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
            y_lens = mask.sum(dim=1).tolist()
        else:
            y_lens = [y.shape[2]] * y.shape[0]
            y = y.squeeze(1).view(1, -1, x.shape[-1])

        # define the first layer
        x = auto_grad_checkpoint(self.base_model.blocks[0], x, y, t0, y_lens, **kwargs)  # (N, T, D) #support grad checkpoint

        if c is not None:
            # update c: run the control blocks and add their skip outputs to x
            for index in range(1, self.copy_blocks_num + 1):
                c, c_skip = auto_grad_checkpoint(self.controlnet[index - 1], x, y, t0, y_lens, c, **kwargs)
                x = auto_grad_checkpoint(self.base_model.blocks[index], x + c_skip, y, t0, y_lens, **kwargs)

            # update x: remaining frozen blocks run without control injection
            for index in range(self.copy_blocks_num + 1, self.total_blocks_num):
                x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, **kwargs)
        else:
            for index in range(1, self.total_blocks_num):
                x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, **kwargs)

        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)  # (N, out_channels, H, W)
        return x

    def forward(self, x, timesteps, context, cn_hint=None, **kwargs):
        """
        Forward pass that adapts comfy input to original forward function
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        timesteps: (N,) tensor of diffusion timesteps
        context: (N, 1, 120, C) conditioning
        cn_hint: controlnet hint
        """
        ## Still accepts the input w/o that dim but returns garbage
        if len(context.shape) == 3:
            context = context.unsqueeze(1)

        ## run original forward pass
        out = self.forward_raw(
            x = x.to(self.dtype),
            timestep = timesteps.to(self.dtype),
            y = context.to(self.dtype),
            c = cn_hint,
        )

        ## only return EPS (discard any extra predicted channels, e.g. variance)
        out = out.to(torch.float)
        eps, rest = out[:, :self.in_channels], out[:, self.in_channels:]
        return eps

    def forward_with_dpmsolver(self, x, t, y, data_info, c, **kwargs):
        # DPM-Solver expects only the noise prediction (first chunk of channels).
        model_out = self.forward_raw(x, t, y, data_info=data_info, c=c, **kwargs)
        return model_out.chunk(2, dim=1)[0]

    # def forward_with_dpmsolver(self, x, t, y, data_info, c, **kwargs):
    #     return self.base_model.forward_with_dpmsolver(x, t, y, data_info=data_info, c=self.forward_c(c), **kwargs)

    def forward_with_cfg(self, x, t, y, cfg_scale, data_info, c, **kwargs):
        # Classifier-free guidance is delegated to the base model with the
        # hint pre-embedded via forward_c.
        return self.base_model.forward_with_cfg(x, t, y, cfg_scale, data_info, c=self.forward_c(c), **kwargs)

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        """Load weights, remapping plain PixArt checkpoints onto the wrapped model.

        Checkpoints already namespaced with base_model/controlnet load directly;
        otherwise block keys are rewritten (blocks.N.* -> blocks.N.base_block.*)
        and loaded into the base model.
        """
        if all((k.startswith('base_model') or k.startswith('controlnet')) for k in state_dict.keys()):
            return super().load_state_dict(state_dict, strict)
        else:
            new_key = {}
            for k in state_dict.keys():
                new_key[k] = re.sub(r"(blocks\.\d+)(.*)", r"\1.base_block\2", k)
            for k, v in new_key.items():
                if k != v:
                    print(f"replace {k} to {v}")
                    state_dict[v] = state_dict.pop(k)

            return self.base_model.load_state_dict(state_dict, strict)

    def unpatchify(self, x):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, H, W, C)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        # Requires self.h / self.w cached by forward_raw / forward_c.
        assert self.h * self.w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], self.h, self.w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, self.h * p, self.w * p))
        return imgs

    # @property
    # def dtype(self):
    #     # Return the dtype of the model parameters.
    #     return next(self.parameters()).dtype
|
||||
|
||||
|
||||
# The implementation for PixArtMS_Half + 1024 resolution
|
||||
class ControlPixArtMSHalf(ControlPixArtHalf):
    """ControlNet-Half wrapper for the multi-scale PixArtMS model (e.g. 1024px)."""
    # support multi-scale res model (multi-scale model can also be applied to single reso training & inference)
    def __init__(self, base_model: PixArtMS, copy_blocks_num: int = 13) -> None:
        super().__init__(base_model=base_model, copy_blocks_num=copy_blocks_num)

    def forward_raw(self, x, timestep, y, mask=None, data_info=None, c=None, **kwargs):
        # modify the original PixArtMS forward function
        """
        Forward pass of PixArt.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N, 1, 120, C) tensor of class labels
        """
        if c is not None:
            c = c.to(self.dtype)
            c = self.forward_c(c)
        bs = x.shape[0]
        x = x.to(self.dtype)
        timestep = timestep.to(self.dtype)
        y = y.to(self.dtype)
        # Multi-scale conditioning: image size and aspect ratio from data_info.
        c_size, ar = data_info['img_hw'].to(self.dtype), data_info['aspect_ratio'].to(self.dtype)
        self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size

        # Positional embedding is recomputed per input resolution.
        pos_embed = torch.from_numpy(get_2d_sincos_pos_embed(self.pos_embed.shape[-1], (self.h, self.w), lewei_scale=self.lewei_scale, base_size=self.base_size)).unsqueeze(0).to(x.device).to(self.dtype)
        x = self.x_embedder(x) + pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
        t = self.t_embedder(timestep)  # (N, D)
        csize = self.csize_embedder(c_size, bs)  # (N, D)
        ar = self.ar_embedder(ar, bs)  # (N, D)
        # Size/aspect embeddings are fused into the timestep conditioning.
        t = t + torch.cat([csize, ar], dim=1)
        t0 = self.t_block(t)
        y = self.y_embedder(y, self.training)  # (N, D)
        if mask is not None:
            # Pack only the unmasked caption tokens into one flat sequence and
            # record per-sample lengths for the attention implementation.
            if mask.shape[0] != y.shape[0]:
                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
            mask = mask.squeeze(1).squeeze(1)
            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
            y_lens = mask.sum(dim=1).tolist()
        else:
            y_lens = [y.shape[2]] * y.shape[0]
            y = y.squeeze(1).view(1, -1, x.shape[-1])

        # define the first layer
        x = auto_grad_checkpoint(self.base_model.blocks[0], x, y, t0, y_lens, **kwargs)  # (N, T, D) #support grad checkpoint

        if c is not None:
            # update c: run the control blocks and add their skip outputs to x
            for index in range(1, self.copy_blocks_num + 1):
                c, c_skip = auto_grad_checkpoint(self.controlnet[index - 1], x, y, t0, y_lens, c, **kwargs)
                x = auto_grad_checkpoint(self.base_model.blocks[index], x + c_skip, y, t0, y_lens, **kwargs)

            # update x: remaining frozen blocks run without control injection
            for index in range(self.copy_blocks_num + 1, self.total_blocks_num):
                x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, **kwargs)
        else:
            for index in range(1, self.total_blocks_num):
                x = auto_grad_checkpoint(self.base_model.blocks[index], x, y, t0, y_lens, **kwargs)

        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)  # (N, out_channels, H, W)
        return x

    def forward(self, x, timesteps, context, img_hw=None, aspect_ratio=None, cn_hint=None, **kwargs):
        """
        Forward pass that adapts comfy input to original forward function
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        timesteps: (N,) tensor of diffusion timesteps
        context: (N, 1, 120, C) conditioning
        img_hw: height|width conditioning
        aspect_ratio: aspect ratio conditioning
        cn_hint: controlnet hint
        """
        ## size/ar from cond with fallback based on the latent image shape.
        bs = x.shape[0]
        data_info = {}
        if img_hw is None:
            # Fallback assumes an 8x VAE downscale factor — TODO confirm for
            # the models this node targets.
            data_info["img_hw"] = torch.tensor(
                [[x.shape[2]*8, x.shape[3]*8]],
                dtype=self.dtype,
                device=x.device
            ).repeat(bs, 1)
        else:
            data_info["img_hw"] = img_hw.to(x.dtype)
        # NOTE(review): the "or True" forces the fallback branch, so a supplied
        # aspect_ratio is always ignored — confirm this is intentional.
        if aspect_ratio is None or True:
            data_info["aspect_ratio"] = torch.tensor(
                [[x.shape[2]/x.shape[3]]],
                dtype=self.dtype,
                device=x.device
            ).repeat(bs, 1)
        else:
            data_info["aspect_ratio"] = aspect_ratio.to(x.dtype)

        ## Still accepts the input w/o that dim but returns garbage
        if len(context.shape) == 3:
            context = context.unsqueeze(1)

        ## run original forward pass
        out = self.forward_raw(
            x = x.to(self.dtype),
            timestep = timesteps.to(self.dtype),
            y = context.to(self.dtype),
            c = cn_hint,
            data_info=data_info,
        )

        ## only return EPS (discard any extra predicted channels, e.g. variance)
        out = out.to(torch.float)
        eps, rest = out[:, :self.in_channels], out[:, self.in_channels:]
        return eps
|
||||
@@ -0,0 +1,122 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
|
||||
from collections.abc import Iterable
|
||||
from itertools import repeat
|
||||
|
||||
def _ntuple(n):
    """Return a converter that coerces a scalar into an n-tuple.

    Non-string iterables are passed through unchanged (timm-style helper).
    """
    def parse(value):
        passthrough = isinstance(value, Iterable) and not isinstance(value, str)
        return value if passthrough else tuple(repeat(value, n))
    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
|
||||
|
||||
def set_grad_checkpoint(model, use_fp32_attention=False, gc_step=1):
    """Mark every submodule of `model` for gradient checkpointing.

    Sets grad_checkpointing, fp32_attention and grad_checkpointing_step
    attributes, which auto_grad_checkpoint and the attention blocks read.
    """
    assert isinstance(model, nn.Module)

    def _mark(submodule):
        submodule.grad_checkpointing = True
        submodule.fp32_attention = use_fp32_attention
        submodule.grad_checkpointing_step = gc_step

    model.apply(_mark)
|
||||
|
||||
def auto_grad_checkpoint(module, *args, **kwargs):
    """Call `module`, wrapping it in gradient checkpointing when it has been
    flagged via set_grad_checkpoint.

    An iterable of blocks is checkpointed sequentially in chunks of
    grad_checkpointing_step; a single module is checkpointed as one unit.
    """
    if not getattr(module, 'grad_checkpointing', False):
        # Fast path: checkpointing not requested for this module.
        return module(*args, **kwargs)
    if isinstance(module, Iterable):
        gc_step = module[0].grad_checkpointing_step
        return checkpoint_sequential(module, gc_step, *args, **kwargs)
    return checkpoint(module, *args, **kwargs)
|
||||
|
||||
def checkpoint_sequential(functions, step, input, *args, **kwargs):
    """Run `functions` in order over `input`, checkpointing `step` functions
    per segment; extra positional `args` are passed to every function.

    Unlike torch's version, each function is called as fn(input, *args).
    The trailing segment is executed without checkpointing so its
    activations stay live for the backward pass.
    """
    # Hack for keyword-only parameter in a python 2.7-compliant way
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))

    def _segment_runner(first, last, fns):
        def forward(input):
            for idx in range(first, last + 1):
                input = fns[idx](input, *args)
            return input
        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = list(functions.children())

    # the last chunk has to be non-volatile
    last_end = -1
    segment = len(functions) // step
    for seg_start in range(0, step * (segment - 1), step):
        last_end = seg_start + step - 1
        input = checkpoint(_segment_runner(seg_start, last_end, functions), input, preserve_rng_state=preserve)
    return _segment_runner(last_end + 1, len(functions) - 1, functions)(input)
|
||||
|
||||
def get_rel_pos(q_size, k_size, rel_pos):
    """
    Gather relative positional embeddings for every (query, key) index pair.

    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embedding table (L, C).

    Returns:
        (q_size, k_size, C) tensor of embeddings indexed by relative offset.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    if rel_pos.shape[0] == max_rel_dist:
        table = rel_pos
    else:
        # Linearly resample the table along its first axis to the needed length.
        resampled = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode="linear",
        )
        table = resampled.reshape(-1, max_rel_dist).permute(1, 0)

    # When q and k grids differ, rescale coordinates by the length ratio so
    # offsets line up; shift so the smallest offset indexes row 0.
    q_scale = max(k_size / q_size, 1.0)
    k_scale = max(q_size / k_size, 1.0)
    q_coords = torch.arange(q_size)[:, None] * q_scale
    k_coords = torch.arange(k_size)[None, :] * k_scale
    offsets = (q_coords - k_coords) + (k_size - 1) * k_scale

    return table[offsets.long()]


def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
    """
    Add decomposed (axial) relative positional bias from :paper:`mvitv2` to an
    attention map.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950

    Args:
        attn (Tensor): attention map (B, q_h * q_w, k_h * k_w).
        q (Tensor): query tensor (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for the height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for the width axis.
        q_size (Tuple): (q_h, q_w) spatial size of the queries.
        k_size (Tuple): (k_h, k_w) spatial size of the keys.

    Returns:
        attn (Tensor): attention map with the relative positional bias added.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    grid_q = q.reshape(B, q_h, q_w, dim)
    # Per-axis bias terms: (B, q_h, q_w, k_h) and (B, q_h, q_w, k_w).
    rel_h = torch.einsum("bhwc,hkc->bhwk", grid_q, Rh)
    rel_w = torch.einsum("bhwc,wkc->bhwk", grid_q, Rw)

    biased = attn.view(B, q_h, q_w, k_h, k_w)
    biased = biased + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
    return biased.view(B, q_h * q_w, k_h * k_w)
|
||||
38
custom_nodes/ComfyUI-Easy-Use/py/modules/dit/utils.py
Normal file
38
custom_nodes/ComfyUI-Easy-Use/py/modules/dit/utils.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import torch
|
||||
from comfy import model_management
|
||||
|
||||
def string_to_dtype(s="none", mode=None):
|
||||
s = s.lower().strip()
|
||||
if s in ["default", "as-is"]:
|
||||
return None
|
||||
elif s in ["auto", "auto (comfy)"]:
|
||||
if mode == "vae":
|
||||
return model_management.vae_device()
|
||||
elif mode == "text_encoder":
|
||||
return model_management.text_encoder_dtype()
|
||||
elif mode == "unet":
|
||||
return model_management.unet_dtype()
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown dtype mode '{mode}'")
|
||||
elif s in ["none", "auto (hf)", "auto (hf/bnb)"]:
|
||||
return None
|
||||
elif s in ["fp32", "float32", "float"]:
|
||||
return torch.float32
|
||||
elif s in ["bf16", "bfloat16"]:
|
||||
return torch.bfloat16
|
||||
elif s in ["fp16", "float16", "half"]:
|
||||
return torch.float16
|
||||
elif "fp8" in s or "float8" in s:
|
||||
if "e5m2" in s:
|
||||
return torch.float8_e5m2
|
||||
elif "e4m3" in s:
|
||||
return torch.float8_e4m3fn
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown 8bit dtype '{s}'")
|
||||
elif "bnb" in s:
|
||||
assert s in ["bnb8bit", "bnb4bit"], f"Unknown bnb mode '{s}'"
|
||||
return s
|
||||
elif s is None:
|
||||
return None
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown dtype '{s}'")
|
||||
139
custom_nodes/ComfyUI-Easy-Use/py/modules/fooocus/__init__.py
Normal file
139
custom_nodes/ComfyUI-Easy-Use/py/modules/fooocus/__init__.py
Normal file
@@ -0,0 +1,139 @@
|
||||
#credit to Acly for this module
|
||||
#from https://github.com/Acly/comfyui-inpaint-nodes
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import comfy
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.model_management import cast_to_device
|
||||
|
||||
from ...libs.log import log_node_warn, log_node_error, log_node_info
|
||||
|
||||
class InpaintHead(torch.nn.Module):
    """Fooocus inpaint head: a single 3x3 convolution over a 5-channel input,
    producing 320 feature channels. Weights are loaded from the fooocus head
    checkpoint after construction."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.head = torch.nn.Parameter(torch.empty(size=(320, 5, 3, 3), device="cpu"))

    # NOTE: __call__ is overridden directly (rather than forward), which
    # bypasses nn.Module's hook machinery — kept to match the original.
    def __call__(self, x):
        padded = F.pad(x, (1, 1, 1, 1), "replicate")
        return F.conv2d(padded, weight=self.head)
|
||||
|
||||
# injected_model_patcher_calculate_weight = False
|
||||
# original_calculate_weight = None
|
||||
|
||||
class applyFooocusInpaint:
    """Context manager that temporarily swaps comfy's LoRA weight-merge
    function with a version that understands fooocus inpaint patches.

    A fooocus patch is a tuple ("fooocus", (w_quantized, w_min, w_max)):
    uint8-quantized weights that are dequantized back into [w_min, w_max]
    and added onto the base weight. All other patches fall through to the
    original comfy implementation.
    """

    def calculate_weight_patched(self, patches, weight, key, intermediate_dtype=torch.float32):
        """Merge fooocus patches into `weight` in place; delegate the rest."""
        remaining = []

        for p in patches:
            alpha = p[0]
            v = p[1]

            is_fooocus_patch = isinstance(v, tuple) and len(v) == 2 and v[0] == "fooocus"
            if not is_fooocus_patch:
                remaining.append(p)
                continue

            if alpha != 0.0:
                v = v[1]
                w1 = cast_to_device(v[0], weight.device, torch.float32)
                if w1.shape == weight.shape:
                    # Dequantize from the 0..255 range back into [w_min, w_max].
                    w_min = cast_to_device(v[1], weight.device, torch.float32)
                    w_max = cast_to_device(v[2], weight.device, torch.float32)
                    w1 = (w1 / 255.0) * (w_max - w_min) + w_min
                    weight += alpha * cast_to_device(w1, weight.device, weight.dtype)
                else:
                    print(
                        f"[ApplyFooocusInpaint] Shape mismatch {key}, weight not merged ({w1.shape} != {weight.shape})"
                    )

        # Non-fooocus patches are merged by comfy's original logic.
        if len(remaining) > 0:
            return self.original_calculate_weight(remaining, weight, key, intermediate_dtype)
        return weight

    def __enter__(self):
        # Newer comfy exposes the merge function at comfy.lora.calculate_weight;
        # older versions keep it on ModelPatcher.
        try:
            print("[comfyui-easy-use] Injecting patched comfy.lora.calculate_weight.calculate_weight")
            self.original_calculate_weight = comfy.lora.calculate_weight
            comfy.lora.calculate_weight = self.calculate_weight_patched
        except AttributeError:
            print("[comfyui-easy-use] Injecting patched comfy.model_patcher.ModelPatcher.calculate_weight")
            self.original_calculate_weight = ModelPatcher.calculate_weight
            ModelPatcher.calculate_weight = self.calculate_weight_patched

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore whichever hook point __enter__ patched.
        # Fix: catch AttributeError specifically — the original bare `except:`
        # would also swallow KeyboardInterrupt/SystemExit.
        try:
            comfy.lora.calculate_weight = self.original_calculate_weight
        except AttributeError:
            ModelPatcher.calculate_weight = self.original_calculate_weight
|
||||
|
||||
# def inject_patched_calculate_weight():
|
||||
# global injected_model_patcher_calculate_weight
|
||||
# if not injected_model_patcher_calculate_weight:
|
||||
# try:
|
||||
# print("[comfyui-easy-use] Injecting patched comfy.lora.calculate_weight.calculate_weight")
|
||||
# original_calculate_weight = comfy.lora.calculate_weight
|
||||
# comfy.lora.original_calculate_weight = original_calculate_weight
|
||||
# comfy.lora.calculate_weight = calculate_weight_patched
|
||||
# except AttributeError:
|
||||
# print("[comfyui-easy-use] Injecting patched comfy.model_patcher.ModelPatcher.calculate_weight")
|
||||
# original_calculate_weight = ModelPatcher.calculate_weight
|
||||
# ModelPatcher.original_calculate_weight = original_calculate_weight
|
||||
# ModelPatcher.calculate_weight = calculate_weight_patched
|
||||
# injected_model_patcher_calculate_weight = True
|
||||
|
||||
|
||||
class InpaintWorker:
    """Applies a Fooocus inpaint patch (head model + lora weights) to a
    ComfyUI model."""

    def __init__(self, node_name):
        # Display name used as a prefix in log messages.
        self.node_name = node_name if node_name is not None else ""

    def load_fooocus_patch(self, lora: dict, to_load: dict):
        """Wrap every lora tensor whose key maps into the model as a
        ("fooocus", value) patch entry.

        lora:    state dict of the fooocus patch file.
        to_load: mapping whose values are candidate model keys.
        Returns the patch dict consumed by ModelPatcher.add_patches().
        """
        patch_dict = {}
        loaded_keys = set()
        for key in to_load.values():
            if value := lora.get(key, None):
                patch_dict[key] = ("fooocus", value)
                loaded_keys.add(key)

        not_loaded = sum(1 for x in lora if x not in loaded_keys)
        if not_loaded > 0:
            log_node_info(self.node_name,
                          f"{len(loaded_keys)} Lora keys loaded, {not_loaded} remaining keys not found in model."
                          )
        return patch_dict

    def _input_block_patch(self, h: torch.Tensor, transformer_options: dict):
        # Inject the inpaint-head feature into the first input block only
        # (block index 0).
        if transformer_options["block"][1] == 0:
            if self._inpaint_block is None or self._inpaint_block.shape != h.shape:
                assert self._inpaint_head_feature is not None
                # Repeat the per-image feature to match the (possibly
                # cond+uncond batched) latent batch size.
                batch = h.shape[0] // self._inpaint_head_feature.shape[0]
                self._inpaint_block = self._inpaint_head_feature.to(h).repeat(batch, 1, 1, 1)
            h = h + self._inpaint_block
        return h

    def patch(self, model, latent, patch):
        """Return a clone of `model` patched for Fooocus inpainting.

        model:  ModelPatcher to clone.
        latent: dict with "samples" and "noise_mask" tensors.
        patch:  (inpaint_head_model, inpaint_lora) pair.
        Returns a 1-tuple with the patched ModelPatcher clone.
        """
        base_model: BaseModel = model.model
        latent_pixels = base_model.process_latent_in(latent["samples"])
        noise_mask = latent["noise_mask"].round()
        # Downscale the pixel-space mask to latent resolution (factor 8).
        latent_mask = F.max_pool2d(noise_mask, (8, 8)).round().to(latent_pixels)

        inpaint_head_model, inpaint_lora = patch
        # Head input: mask (1 ch) + latent (4 ch) -> the head's 5 channels.
        feed = torch.cat([latent_mask, latent_pixels], dim=1)
        inpaint_head_model.to(device=feed.device, dtype=feed.dtype)
        self._inpaint_head_feature = inpaint_head_model(feed)
        self._inpaint_block = None

        lora_keys = comfy.lora.model_lora_keys_unet(model.model, {})
        # Also allow direct state-dict keys to match patch keys 1:1.
        lora_keys.update({x: x for x in base_model.state_dict().keys()})
        loaded_lora = self.load_fooocus_patch(inpaint_lora, lora_keys)

        m = model.clone()
        m.set_model_input_block_patch(self._input_block_patch)
        patched = m.add_patches(loaded_lora, 1.0)
        m.model_options['transformer_options']['fooocus'] = True
        not_patched_count = sum(1 for x in loaded_lora if x not in patched)
        if not_patched_count > 0:
            log_node_error(self.node_name, f"Failed to patch {not_patched_count} keys")

        # inject_patched_calculate_weight()
        return (m,)
|
||||
@@ -0,0 +1,156 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import cv2
|
||||
import torchvision.transforms as transforms
|
||||
from torch.utils.data import DataLoader
|
||||
from .simple_extractor_dataset import SimpleFolderDataset
|
||||
from .transforms import transform_logits
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
|
||||
def get_palette(num_cls):
    """Return a flat RGB palette (3 * num_cls ints) for visualising a
    segmentation mask.

    Each class index is expanded into a colour by spreading its bits over
    the high bits of the R, G and B channels (the standard PASCAL VOC
    colour map).

    Args:
        num_cls: Number of classes.

    Returns:
        The color map as a list of length ``3 * num_cls``.
    """
    palette = [0] * (num_cls * 3)
    for cls_idx in range(num_cls):
        remaining_bits = cls_idx
        shift = 7
        while remaining_bits:
            # Bits 0/1/2 of the class id go to R/G/B respectively,
            # starting at the most significant bit of each channel.
            palette[cls_idx * 3 + 0] |= ((remaining_bits >> 0) & 1) << shift
            palette[cls_idx * 3 + 1] |= ((remaining_bits >> 1) & 1) << shift
            palette[cls_idx * 3 + 2] |= ((remaining_bits >> 2) & 1) << shift
            shift -= 1
            remaining_bits >>= 3
    return palette
|
||||
|
||||
|
||||
def delete_irregular(logits_result):
    """Suppress implausible garment classes based on contour geometry.

    Compares the vertical centroid of the largest upper-clothes region
    (class 4) with that of the largest dress region (class 7): if the
    dress centroid sits lower, the outfit is treated as a dress and the
    separate top/bottom channels are suppressed everywhere; otherwise the
    dress-like channels are suppressed above the upper-clothes centroid.

    logits_result: H x W x C per-class logits (modified in place).
    Returns (parsing_result, wear_type): the re-argmaxed label map padded
    with a 1-pixel zero border, and "dresses" or "cloth_pant".
    """
    parsing_result = np.argmax(logits_result, axis=2)
    # Largest upper-clothes (class 4) contour and its centroid y.
    upper_cloth = np.where(parsing_result == 4, 255, 0)
    contours, hierarchy = cv2.findContours(upper_cloth.astype(np.uint8),
                                           cv2.RETR_CCOMP, cv2.CHAIN_APPROX_TC89_L1)
    area = []
    for i in range(len(contours)):
        a = cv2.contourArea(contours[i], True)
        area.append(abs(a))
    if len(area) != 0:
        top = area.index(max(area))
        M = cv2.moments(contours[top])
        cY = int(M["m01"] / M["m00"])

    # Largest dress (class 7) contour and its centroid y.
    dresses = np.where(parsing_result == 7, 255, 0)
    contours_dress, hierarchy_dress = cv2.findContours(dresses.astype(np.uint8),
                                                       cv2.RETR_CCOMP, cv2.CHAIN_APPROX_TC89_L1)
    area_dress = []
    for j in range(len(contours_dress)):
        a_d = cv2.contourArea(contours_dress[j], True)
        area_dress.append(abs(a_d))
    if len(area_dress) != 0:
        top_dress = area_dress.index(max(area_dress))
        M_dress = cv2.moments(contours_dress[top_dress])
        cY_dress = int(M_dress["m01"] / M_dress["m00"])
    wear_type = "dresses"
    if len(area) != 0:
        if len(area_dress) != 0 and cY_dress > cY:
            # Dress centroid below the top: keep the dress, kill the
            # separate top/skirt/pants channels everywhere.
            irregular_list = np.array([4, 5, 6])
            logits_result[:, :, irregular_list] = -1
        else:
            # Top + bottom outfit: kill dress-like channels above the
            # upper-clothes centroid.
            irregular_list = np.array([5, 6, 7, 8, 9, 10, 12, 13])
            logits_result[:cY, :, irregular_list] = -1
            wear_type = "cloth_pant"
    parsing_result = np.argmax(logits_result, axis=2)
    # pad border
    parsing_result = np.pad(parsing_result, pad_width=1, mode='constant', constant_values=0)
    return parsing_result, wear_type
|
||||
|
||||
|
||||
|
||||
def hole_fill(img):
    """Fill enclosed holes in a binary (0/255) uint8 mask.

    Flood-fills the background from the top-left corner on a scratch
    copy, inverts it so only the enclosed holes remain, and ORs those
    holes back into the mask.

    Fix: the original flood-filled `img` itself, mutating the caller's
    array as a side effect; the fill now runs on a copy so the input is
    left untouched.  The returned mask is unchanged.
    """
    # OpenCV requires the flood mask to be 2 pixels larger than the image.
    flood_mask = np.zeros((img.shape[0] + 2, img.shape[1] + 2), dtype=np.uint8)
    flooded = img.copy()
    cv2.floodFill(flooded, flood_mask, (0, 0), 255)
    # Pixels NOT reachable from the border are enclosed holes.
    holes = cv2.bitwise_not(flooded)
    return cv2.bitwise_or(img, holes)
|
||||
|
||||
def refine_mask(mask):
    """Keep only the dominant connected component(s) of a binary mask.

    The largest contour is always kept; when it is big (> 2000 px^2) the
    other contours are kept too (preserves e.g. both arms in a skin
    mask).  Returns a uint8 mask with kept regions filled with 255.

    Fix: the local result variable previously shadowed this function's
    own name; renamed for clarity (behavior unchanged).
    """
    contours, hierarchy = cv2.findContours(mask.astype(np.uint8),
                                           cv2.RETR_CCOMP, cv2.CHAIN_APPROX_TC89_L1)
    area = []
    for j in range(len(contours)):
        a_d = cv2.contourArea(contours[j], True)
        area.append(abs(a_d))
    refined = np.zeros_like(mask).astype(np.uint8)
    if len(area) != 0:
        i = area.index(max(area))
        cv2.drawContours(refined, contours, i, color=255, thickness=-1)
        # keep large area in skin case
        for j in range(len(area)):
            if j != i and area[i] > 2000:
                cv2.drawContours(refined, contours, j, color=255, thickness=-1)
    return refined
|
||||
|
||||
def refine_hole(parsing_result_filled, parsing_result, arm_mask):
    """Find large holes that appeared when class 4 (upper clothes) was
    filled, excluding arm pixels, and return them combined with the arm
    mask.

    parsing_result_filled: label map after hole filling.
    parsing_result:        original label map.
    arm_mask:              binary (0/1) arm mask.
    Returns refine_hole_mask (255 for kept holes) + arm_mask.
    """
    # Pixels that are class 4 only in the filled result, minus arms.
    # NOTE(review): np.where(..., 255, 0) yields int64 arrays fed to cv2 —
    # confirm upstream that the implicit dtype handling is intended.
    filled_hole = cv2.bitwise_and(np.where(parsing_result_filled == 4, 255, 0),
                                  np.where(parsing_result != 4, 255, 0)) - arm_mask * 255
    contours, hierarchy = cv2.findContours(filled_hole, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_TC89_L1)
    refine_hole_mask = np.zeros_like(parsing_result).astype(np.uint8)
    for i in range(len(contours)):
        a = cv2.contourArea(contours[i], True)
        # keep hole > 2000 pixels
        if abs(a) > 2000:
            cv2.drawContours(refine_hole_mask, contours, i, color=255, thickness=-1)
    return refine_hole_mask + arm_mask
|
||||
|
||||
def onnx_inference(lip_session, input_dir, mask_components=[0]):
    """Run LIP human-parsing ONNX inference and build a component mask.

    lip_session:     onnxruntime InferenceSession for the LIP model.
    input_dir:       folder path, file path, or PIL image accepted by
                     SimpleFolderDataset.
    mask_components: class indices to include in the binary mask.
                     NOTE(review): mutable default argument — it is never
                     mutated here, but a tuple default would be safer.

    Returns (output_img, mask_image): the palettised parsing result and
    the selected-class mask, both float tensors in [0, 1] with a leading
    batch dimension.
    """
    # NOTE(review): mean/std look like ImageNet stats in BGR channel
    # order, matching the BGR arrays produced by SimpleFolderDataset —
    # confirm against the model's training pipeline.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
    ])
    input_size = [473, 473]

    dataset_lip = SimpleFolderDataset(root=input_dir, input_size=input_size, transform=transform)
    dataloader_lip = DataLoader(dataset_lip)
    palette = get_palette(20)
    with torch.no_grad():
        for _, batch in enumerate(tqdm(dataloader_lip)):
            image, meta = batch
            # Affine metadata used to warp results back onto the original
            # image geometry.
            c = meta['center'].numpy()[0]
            s = meta['scale'].numpy()[0]
            w = meta['width'].numpy()[0]
            h = meta['height'].numpy()[0]

            output = lip_session.run(None, {"input.1": image.numpy().astype(np.float32)})
            # Upsample the second output head back to network input size.
            upsample = torch.nn.Upsample(size=input_size, mode='bilinear', align_corners=True)
            upsample_output = upsample(torch.from_numpy(output[1][0]).unsqueeze(0))
            upsample_output = upsample_output.squeeze()
            upsample_output = upsample_output.permute(1, 2, 0)  # CHW -> HWC
            # Warp logits back to the original image, then per-pixel argmax.
            logits_result_lip = transform_logits(upsample_output.data.cpu().numpy(), c, s, w, h,
                                                 input_size=input_size)
            parsing_result = np.argmax(logits_result_lip, axis=2)

            output_img = Image.fromarray(np.asarray(parsing_result, dtype=np.uint8))
            output_img.putpalette(palette)

            # Binary mask over the requested classes.
            mask = np.isin(output_img, mask_components).astype(np.uint8)
            mask_image = Image.fromarray(mask * 255)
            mask_image = mask_image.convert("RGB")
            mask_image = torch.from_numpy(np.array(mask_image).astype(np.float32) / 255.0).unsqueeze(0)

            output_img = output_img.convert('RGB')
            output_img = torch.from_numpy(np.array(output_img).astype(np.float32) / 255.0).unsqueeze(0)

    # NOTE(review): only the LAST batch's results are returned —
    # effectively assumes the dataset contains a single image.
    return output_img, mask_image
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from .parsing_api import onnx_inference
|
||||
from ...libs.utils import install_package
|
||||
|
||||
class HumanParsing:
    """Lazy wrapper around an ONNX human-parsing model.

    The onnxruntime session is created on the first call so importing
    this module stays cheap and onnxruntime is only installed on demand.
    """

    def __init__(self, model_path):
        # Path to the .onnx model file.
        self.model_path = model_path
        # Created lazily in __call__.
        self.session = None

    def __call__(self, input_image, mask_components):
        """Parse `input_image` and return (parsed_image, mask) for the
        requested class indices (see onnx_inference)."""
        if self.session is None:
            # Install/import onnxruntime only when first needed.
            install_package('onnxruntime')
            import onnxruntime as ort

            session_options = ort.SessionOptions()
            session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
            session_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
            # session_options.add_session_config_entry('gpu_id', str(gpu_id))
            self.session = ort.InferenceSession(self.model_path, sess_options=session_options,
                                                providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])

        parsed_image, mask = onnx_inference(self.session, input_image, mask_components)
        return parsed_image, mask
|
||||
|
||||
|
||||
class HumanParts:
    """Lazy wrapper around a human-parts segmentation ONNX model
    (512x512 input, per-pixel class logits)."""

    def __init__(self, model_path):
        # Path to the .onnx model file.
        self.model_path = model_path
        # onnxruntime session, created lazily in __call__.
        self.session = None
        # self.classes_dict = {
        #     "background": 0,
        #     "hair": 2,
        #     "glasses": 4,
        #     "top-clothes": 5,
        #     "bottom-clothes": 9,
        #     "torso-skin": 10,
        #     "face": 13,
        #     "left-arm": 14,
        #     "right-arm": 15,
        #     "left-leg": 16,
        #     "right-leg": 17,
        #     "left-foot": 18,
        #     "right-foot": 19,
        # },
        # Model class ids, indexed by the positions used in mask_components.
        self.classes = [0, 13, 2, 4, 5, 9, 10, 14, 15, 16, 17, 18, 19]

    def __call__(self, input_image, mask_components):
        """Return the segmentation mask for the requested components."""
        if self.session is None:
            install_package('onnxruntime')
            import onnxruntime as ort

            self.session = ort.InferenceSession(self.model_path, providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'])

        # get_mask returns a 1-tuple; unpack it here.
        mask, = self.get_mask(self.session, input_image, 0, mask_components)
        return mask

    def get_mask(self, model, image, rotation, mask_components):
        """Segment `image` and build a mask of the requested classes.

        model:           onnxruntime session.
        image:           (1, H, W, C) float tensor with values in [0, 1]
                         — assumed from the *255 scaling below; confirm.
        rotation:        degrees to rotate before inference (undone after).
        mask_components: indices into self.classes to include.
        Returns a 1-tuple holding a (1, 1, H, W) uint8 torch tensor.
        """
        image = image.squeeze(0)
        image_np = image.numpy() * 255

        pil_image = Image.fromarray(image_np.astype(np.uint8))
        original_size = pil_image.size  # to resize the mask later
        # resize to 512x512 as the model expects
        pil_image = pil_image.resize((512, 512))
        center = (256, 256)

        if rotation != 0:
            pil_image = pil_image.rotate(rotation, center=center)

        # normalize the image to [-1, 1]
        image_np = np.array(pil_image).astype(np.float32) / 127.5 - 1
        image_np = np.expand_dims(image_np, axis=0)

        # use the onnx model to get the mask
        input_name = model.get_inputs()[0].name
        output_name = model.get_outputs()[0].name
        result = model.run([output_name], {input_name: image_np})
        # Per-pixel class id from the logits' last axis.
        result = np.array(result[0]).argmax(axis=3).squeeze(0)

        # NOTE(review): `score` is accumulated below but never used or
        # returned — dead code kept for reference.
        score: int = 0

        mask = np.zeros_like(result)
        for class_index in mask_components:
            detected = result == self.classes[class_index]
            mask[detected] = 255
            score += mask.sum()

        # back to the original size
        mask_image = Image.fromarray(mask.astype(np.uint8), mode="L")
        if rotation != 0:
            mask_image = mask_image.rotate(-rotation, center=center)

        mask_image = mask_image.resize(original_size)

        # and back to numpy...
        mask = np.array(mask_image).astype(np.float32) / 255

        # add 2 dimensions to match the expected output
        mask = np.expand_dims(mask, axis=0)
        mask = np.expand_dims(mask, axis=0)
        # ensure to return a "binary mask_image"
        # NOTE(review): astype(np.uint8) floors the float mask to {0, 1};
        # anti-aliased resize values below 1.0 become 0 — confirm intended.

        del image_np, result  # free up memory, maybe not necessary

        return (torch.from_numpy(mask.astype(np.uint8)),)
|
||||
@@ -0,0 +1,88 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
"""
|
||||
@Author : Peike Li
|
||||
@Contact : peike.li@yahoo.com
|
||||
@File : dataset.py
|
||||
@Time : 8/30/19 9:12 PM
|
||||
@Desc : Dataset Definition
|
||||
@License : This source code is licensed under the license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from torch.utils import data
|
||||
from .transforms import get_affine_transform
|
||||
|
||||
|
||||
class SimpleFolderDataset(data.Dataset):
    """Dataset over a folder of images, a single image file, or one PIL
    image, yielding affine-normalised network inputs plus the metadata
    needed to map predictions back to the original image geometry.
    """

    def __init__(self, root, input_size=[512, 512], transform=None):
        # NOTE(review): mutable default for input_size — it is replaced
        # by np.asarray below and never mutated, so harmless in practice.
        self.root = root
        self.input_size = input_size
        self.transform = transform
        # Width/height ratio used to grow boxes to the network aspect.
        self.aspect_ratio = input_size[1] * 1.0 / input_size[0]
        self.input_size = np.asarray(input_size)
        self.is_pil_image = False
        if isinstance(root, Image.Image):
            # A single in-memory PIL image.
            self.file_list = [root]
            self.is_pil_image = True
        elif os.path.isfile(root):
            # A single image file on disk.
            self.file_list = [os.path.basename(root)]
            self.root = os.path.dirname(root)
        else:
            # A directory of images.
            self.file_list = os.listdir(self.root)

    def __len__(self):
        return len(self.file_list)

    def _box2cs(self, box):
        # Convert an [x, y, w, h] box into (center, scale).
        x, y, w, h = box[:4]
        return self._xywh2cs(x, y, w, h)

    def _xywh2cs(self, x, y, w, h):
        # Box centre, with the box grown to the network aspect ratio
        # before being used as the affine "scale".
        center = np.zeros((2), dtype=np.float32)
        center[0] = x + w * 0.5
        center[1] = y + h * 0.5
        if w > self.aspect_ratio * h:
            h = w * 1.0 / self.aspect_ratio
        elif w < self.aspect_ratio * h:
            w = h * self.aspect_ratio
        scale = np.array([w, h], dtype=np.float32)
        return center, scale

    def __getitem__(self, index):
        if self.is_pil_image:
            # PIL yields RGB; reverse channels to BGR to match cv2.imread.
            img = np.asarray(self.file_list[index])[:, :, [2, 1, 0]]
        else:
            img_name = self.file_list[index]
            img_path = os.path.join(self.root, img_name)
            img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        h, w, _ = img.shape

        # Get person center and scale
        person_center, s = self._box2cs([0, 0, w - 1, h - 1])
        r = 0
        trans = get_affine_transform(person_center, s, r, self.input_size)
        input = cv2.warpAffine(
            img,
            trans,
            (int(self.input_size[1]), int(self.input_size[0])),
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_CONSTANT,
            borderValue=(0, 0, 0))

        input = self.transform(input)
        # Metadata consumed by transform_logits/transform_preds to warp
        # results back to the original image.
        meta = {
            'center': person_center,
            'height': h,
            'width': w,
            'scale': s,
            'rotation': r
        }

        return input, meta
|
||||
@@ -0,0 +1,167 @@
|
||||
# ------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft
|
||||
# Licensed under the MIT License.
|
||||
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
class BRG2Tensor_transform(object):
    """Convert an HWC numpy image into a CHW torch tensor.

    Byte images are promoted to float (values are NOT rescaled); any
    other dtype is returned unchanged.
    """

    def __call__(self, pic):
        tensor = torch.from_numpy(pic.transpose((2, 0, 1)))
        if not isinstance(tensor, torch.ByteTensor):
            return tensor
        return tensor.float()
|
||||
|
||||
class BGR2RGB_transform(object):
    """Swap the channel order of a CHW tensor (BGR <-> RGB)."""

    def __call__(self, tensor):
        # Index the channel axis in reverse order: (0, 1, 2) -> (2, 1, 0).
        return tensor[[2, 1, 0], :, :]
|
||||
|
||||
def flip_back(output_flipped, matched_parts):
    """Undo a horizontal flip of joint heatmaps.

    output_flipped: numpy.ndarray(batch_size, num_joints, height, width)
    matched_parts:  (left, right) joint index pairs to swap back.
    """
    assert output_flipped.ndim == 4,\
        'output_flipped should be [batch_size, num_joints, height, width]'

    # Mirror the width axis back (view into the input).
    restored = output_flipped[:, :, :, ::-1]

    # Swap each matched left/right joint channel pair.
    for left, right in matched_parts:
        left_map = restored[:, left, :, :].copy()
        restored[:, left, :, :] = restored[:, right, :, :]
        restored[:, right, :, :] = left_map

    return restored
|
||||
|
||||
|
||||
def fliplr_joints(joints, joints_vis, width, matched_parts):
    """
    flip coords

    Mirrors joint x-coordinates about the image and swaps each
    left/right joint pair (coordinates and visibility).  NOTE: the input
    arrays are modified in place.
    """
    # Flip horizontal
    joints[:, 0] = width - joints[:, 0] - 1

    # Change left-right parts
    for left, right in matched_parts:
        left_joint = joints[left, :].copy()
        joints[left, :] = joints[right, :]
        joints[right, :] = left_joint

        left_vis = joints_vis[left, :].copy()
        joints_vis[left, :] = joints_vis[right, :]
        joints_vis[right, :] = left_vis

    # Zero out coordinates of invisible joints.
    return joints * joints_vis, joints_vis
|
||||
|
||||
|
||||
def transform_preds(coords, center, scale, input_size):
    """Map predicted coordinates from model-input space back to the
    original image via the inverse affine transform."""
    inverse_trans = get_affine_transform(center, scale, 0, input_size, inv=1)
    target_coords = np.zeros(coords.shape)
    for idx in range(coords.shape[0]):
        target_coords[idx, 0:2] = affine_transform(coords[idx, 0:2], inverse_trans)
    return target_coords
|
||||
|
||||
def transform_parsing(pred, center, scale, width, height, input_size):
    """Warp a label map from model-input space back onto the original
    (width x height) image, using nearest-neighbour sampling so class
    labels are never blended."""
    inverse_trans = get_affine_transform(center, scale, 0, input_size, inv=1)
    return cv2.warpAffine(
        pred,
        inverse_trans,
        (int(width), int(height)),
        flags=cv2.INTER_NEAREST,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=(0))
|
||||
|
||||
def transform_logits(logits, center, scale, width, height, input_size):
    """Warp per-class logit planes (H x W x C) back onto the original
    (width x height) image with bilinear sampling, restacking them along
    the channel axis."""
    inverse_trans = get_affine_transform(center, scale, 0, input_size, inv=1)
    warped_channels = [
        cv2.warpAffine(
            logits[:, :, channel_idx],
            inverse_trans,
            (int(width), int(height)),
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_CONSTANT,
            borderValue=(0))
        for channel_idx in range(logits.shape[2])
    ]
    return np.stack(warped_channels, axis=2)
|
||||
|
||||
|
||||
def get_affine_transform(center,
                         scale,
                         rot,
                         output_size,
                         shift=np.array([0, 0], dtype=np.float32),
                         inv=0):
    """Build the 2x3 affine matrix mapping a person box to the network
    input (or back, when inv=1).

    center:      (x, y) box centre in the source image.
    scale:       (w, h) box size; a bare scalar is promoted to [s, s].
    rot:         rotation in degrees.
    output_size: (height, width) of the network input.
    shift:       relative offset applied in units of `scale`.
                 NOTE(review): mutable np.array default — never mutated
                 here, so harmless.
    inv:         when truthy, return the inverse (output -> source) map.

    The matrix is derived from three corresponding point pairs: the box
    centre, a point above it, and a third point completing a right angle.

    Fix: removed a stray debug `print(scale)` that fired for every
    scalar scale value.
    """
    if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
        scale = np.array([scale, scale])

    scale_tmp = scale

    src_w = scale_tmp[0]
    dst_w = output_size[1]
    dst_h = output_size[0]

    # Direction vector from the centre to the top of the box, rotated.
    rot_rad = np.pi * rot / 180
    src_dir = get_dir([0, src_w * -0.5], rot_rad)
    dst_dir = np.array([0, (dst_w-1) * -0.5], np.float32)

    # Three point correspondences define the affine map.
    src = np.zeros((3, 2), dtype=np.float32)
    dst = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale_tmp * shift
    src[1, :] = center + src_dir + scale_tmp * shift
    dst[0, :] = [(dst_w-1) * 0.5, (dst_h-1) * 0.5]
    dst[1, :] = np.array([(dst_w-1) * 0.5, (dst_h-1) * 0.5]) + dst_dir

    src[2:, :] = get_3rd_point(src[0, :], src[1, :])
    dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])

    if inv:
        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
    else:
        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))

    return trans
|
||||
|
||||
|
||||
def affine_transform(pt, t):
    """Apply a 2x3 affine matrix `t` to the 2-D point `pt`."""
    homogeneous = np.array([pt[0], pt[1], 1.])
    return (t @ homogeneous)[:2]
|
||||
|
||||
|
||||
def get_3rd_point(a, b):
    """Return the point completing a right angle at `b`: `b` plus the
    90-degree rotation of the vector a - b."""
    delta = a - b
    return b + np.array([-delta[1], delta[0]], dtype=np.float32)
|
||||
|
||||
|
||||
def get_dir(src_point, rot_rad):
    """Rotate the 2-D vector `src_point` by `rot_rad` radians
    (counter-clockwise); returns a 2-element list."""
    sin_r, cos_r = np.sin(rot_rad), np.cos(rot_rad)
    return [
        src_point[0] * cos_r - src_point[1] * sin_r,
        src_point[0] * sin_r + src_point[1] * cos_r,
    ]
|
||||
|
||||
|
||||
def crop(img, center, scale, output_size, rot=0):
    """Crop `img` to `output_size` around `center` at `scale`, optionally
    rotated by `rot` degrees, via a single warpAffine."""
    trans = get_affine_transform(center, scale, rot, output_size)
    return cv2.warpAffine(img,
                          trans,
                          (int(output_size[1]), int(output_size[0])),
                          flags=cv2.INTER_LINEAR)
|
||||
184
custom_nodes/ComfyUI-Easy-Use/py/modules/ic_light/__init__.py
Normal file
184
custom_nodes/ComfyUI-Easy-Use/py/modules/ic_light/__init__.py
Normal file
@@ -0,0 +1,184 @@
|
||||
#credit to huchenlei for this module
|
||||
#from https://github.com/huchenlei/ComfyUI-IC-Light-Native
|
||||
import torch
|
||||
import numpy as np
|
||||
from typing import Tuple, TypedDict, Callable
|
||||
|
||||
import comfy.model_management
|
||||
from comfy.sd import load_unet
|
||||
from comfy.ldm.models.autoencoder import AutoencoderKL
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from PIL import Image
|
||||
from nodes import VAEEncode
|
||||
from ...libs.image import np2tensor, pil2tensor
|
||||
|
||||
class UnetParams(TypedDict):
|
||||
input: torch.Tensor
|
||||
timestep: torch.Tensor
|
||||
c: dict
|
||||
cond_or_uncond: torch.Tensor
|
||||
|
||||
class VAEEncodeArgMax(VAEEncode):
    """VAEEncode variant that encodes deterministically.

    Temporarily disables latent sampling on the VAE's regularizer so the
    encoder returns the distribution mode instead of a random sample,
    then restores the original sampling flag.
    """

    def encode(self, vae, pixels):
        assert isinstance(
            vae.first_stage_model, AutoencoderKL
        ), "ArgMax only supported for AutoencoderKL"
        # Save the flag, encode without sampling, restore the flag.
        original_sample_mode = vae.first_stage_model.regularization.sample
        vae.first_stage_model.regularization.sample = False
        ret = super().encode(vae, pixels)
        vae.first_stage_model.regularization.sample = original_sample_mode
        return ret
|
||||
|
||||
class ICLight:
    """Applies an IC-Light relighting patch to a Stable Diffusion model.

    The IC-Light UNet takes 8 input channels (4 latent + 4 lighting
    condition), so the patch both concatenates the condition latent at
    every UNet call (via a model-function wrapper) and diff-patches the
    IC-Light weights onto the cloned model.
    """

    @staticmethod
    def apply_c_concat(params: UnetParams, concat_conds) -> UnetParams:
        """Apply c_concat on unet call."""
        sample = params["input"]
        # Tile the condition latent to match the (cond+uncond) batch size.
        params["c"]["c_concat"] = torch.cat(
            (
                [concat_conds.to(sample.device)]
                * (sample.shape[0] // concat_conds.shape[0])
            ),
            dim=0,
        )
        return params

    @staticmethod
    def create_custom_conv(
        original_conv: torch.nn.Module,
        dtype: torch.dtype,
        device=torch.device,
    ) -> torch.nn.Module:
        """Return an 8-input-channel copy of a 4-channel conv layer.

        The first 4 input channels copy the original weights; the extra
        4 are zero-initialised so the layer initially ignores the
        condition latent.
        NOTE(review): the `device` default is the torch.device *class*,
        not an instance — callers must always pass `device` explicitly.
        """
        with torch.no_grad():
            new_conv_in = torch.nn.Conv2d(
                8,
                original_conv.out_channels,
                original_conv.kernel_size,
                original_conv.stride,
                original_conv.padding,
            )
            new_conv_in.weight.zero_()
            new_conv_in.weight[:, :4, :, :].copy_(original_conv.weight)
            new_conv_in.bias = original_conv.bias
            return new_conv_in.to(dtype=dtype, device=device)

    def generate_lighting_image(self, original_image, direction):
        """Build a grayscale lighting-gradient image sized like
        `original_image` ((B, H, W, C) tensor) for the requested light
        `direction`; unknown directions yield a 1x1 black image."""
        _, image_height, image_width, _ = original_image.shape
        if direction == 'Left Light':
            # Bright on the left, fading to dark on the right.
            gradient = np.linspace(255, 0, image_width)
            image = np.tile(gradient, (image_height, 1))
            input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
            return np2tensor(input_bg)
        elif direction == 'Right Light':
            gradient = np.linspace(0, 255, image_width)
            image = np.tile(gradient, (image_height, 1))
            input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
            return np2tensor(input_bg)
        elif direction == 'Top Light':
            gradient = np.linspace(255, 0, image_height)[:, None]
            image = np.tile(gradient, (1, image_width))
            input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
            return np2tensor(input_bg)
        elif direction == 'Bottom Light':
            gradient = np.linspace(0, 255, image_height)[:, None]
            image = np.tile(gradient, (1, image_width))
            input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
            return np2tensor(input_bg)
        elif direction == 'Circle Light':
            # Radial gradient: bright centre, dark edges.
            x = np.linspace(-1, 1, image_width)
            y = np.linspace(-1, 1, image_height)
            x, y = np.meshgrid(x, y)
            r = np.sqrt(x ** 2 + y ** 2)
            r = r / r.max()
            color1 = np.array([0, 0, 0])[np.newaxis, np.newaxis, :]
            color2 = np.array([255, 255, 255])[np.newaxis, np.newaxis, :]
            gradient = (color1 * r[..., np.newaxis] + color2 * (1 - r)[..., np.newaxis]).astype(np.uint8)
            image = pil2tensor(Image.fromarray(gradient))
            return image
        else:
            # Fallback: 1x1 black placeholder.
            image = pil2tensor(Image.new('RGB', (1, 1), (0, 0, 0)))
            return image

    def generate_source_image(self, original_image, source):
        """Build the lighting-source image for `source`: either the
        horizontally flipped second batch image, a flat ambient gray, or
        a directional gradient in the 32-224 range; unknown sources yield
        a 1x1 black image."""
        batch_size, image_height, image_width, _ = original_image.shape
        if source == 'Use Flipped Background Image':
            if batch_size < 2:
                raise ValueError('Must be at least 2 image to use flipped background image.')
            original_image = [img.unsqueeze(0) for img in original_image]
            # Flip the second image along its width axis.
            image = torch.flip(original_image[1], [2])
            return image
        elif source == 'Ambient':
            # Flat mid-dark gray background.
            input_bg = np.zeros(shape=(image_height, image_width, 3), dtype=np.uint8) + 64
            return np2tensor(input_bg)
        elif source == 'Left Light':
            gradient = np.linspace(224, 32, image_width)
            image = np.tile(gradient, (image_height, 1))
            input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
            return np2tensor(input_bg)
        elif source == 'Right Light':
            gradient = np.linspace(32, 224, image_width)
            image = np.tile(gradient, (image_height, 1))
            input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
            return np2tensor(input_bg)
        elif source == 'Top Light':
            gradient = np.linspace(224, 32, image_height)[:, None]
            image = np.tile(gradient, (1, image_width))
            input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
            return np2tensor(input_bg)
        elif source == 'Bottom Light':
            gradient = np.linspace(32, 224, image_height)[:, None]
            image = np.tile(gradient, (1, image_width))
            input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
            return np2tensor(input_bg)
        else:
            # Fallback: 1x1 black placeholder.
            image = pil2tensor(Image.new('RGB', (1, 1), (0, 0, 0)))
            return image

    def apply(self, ic_model_path, model, c_concat: dict, ic_model=None) -> Tuple[ModelPatcher]:
        """Clone `model`, hook the condition-latent concat into its UNet
        wrapper chain, and diff-patch the IC-Light weights onto it.

        ic_model_path: checkpoint loaded when `ic_model` is not supplied.
        c_concat:      {"samples": latent} condition to concatenate.
        NOTE(review): returns (work_model, ic_model), not a 1-tuple as
        the Tuple[ModelPatcher] annotation suggests — callers rely on the
        second element to cache the loaded IC model.
        """
        device = comfy.model_management.get_torch_device()
        dtype = comfy.model_management.unet_dtype()
        work_model = model.clone()

        # Apply scale factor.
        base_model: BaseModel = work_model.model
        scale_factor = base_model.model_config.latent_format.scale_factor

        # [B, 4, H, W]
        concat_conds: torch.Tensor = c_concat["samples"] * scale_factor
        # [1, 4 * B, H, W]
        concat_conds = torch.cat([c[None, ...] for c in concat_conds], dim=1)

        def unet_dummy_apply(unet_apply: Callable, params: UnetParams):
            """A dummy unet apply wrapper serving as the endpoint of wrapper
            chain."""
            return unet_apply(x=params["input"], t=params["timestep"], **params["c"])

        # Preserve any wrapper another node already installed.
        existing_wrapper = work_model.model_options.get(
            "model_function_wrapper", unet_dummy_apply
        )

        def wrapper_func(unet_apply: Callable, params: UnetParams):
            return existing_wrapper(unet_apply, params=self.apply_c_concat(params, concat_conds))

        work_model.set_model_unet_function_wrapper(wrapper_func)
        if not ic_model:
            ic_model = load_unet(ic_model_path)
        ic_model_state_dict = ic_model.model.diffusion_model.state_dict()

        # Diff-patch every IC-Light weight onto the clone; the first conv
        # is flagged so its extra 4 input channels can be padded.
        work_model.add_patches(
            patches={
                ("diffusion_model." + key): (
                    'diff',
                    [
                        value.to(dtype=dtype, device=device),
                        {"pad_weight": key == 'input_blocks.0.0.weight'}
                    ]
                )
                for key, value in ic_model_state_dict.items()
            }
        )

        return (work_model, ic_model)
|
||||
268
custom_nodes/ComfyUI-Easy-Use/py/modules/ipadapter/__init__.py
Normal file
268
custom_nodes/ComfyUI-Easy-Use/py/modules/ipadapter/__init__.py
Normal file
@@ -0,0 +1,268 @@
|
||||
#credit to shakker-labs and instantX for this module
|
||||
#from https://github.com/Shakker-Labs/ComfyUI-IPAdapter-Flux
|
||||
import torch
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from .attention_processor import IPAFluxAttnProcessor2_0
|
||||
from .utils import is_model_pathched, FluxUpdateModules
|
||||
from .sd3.resampler import TimeResampler
|
||||
from .sd3.joinblock import JointBlockIPWrapper, IPAttnProcessor
|
||||
|
||||
image_proj_model = None
|
||||
class MLPProjModel(torch.nn.Module):
    """Project identity embeddings into `num_tokens` cross-attention
    tokens.

    A two-layer GELU MLP expands each embedding to
    ``num_tokens * cross_attention_dim`` values, which are reshaped into
    tokens and layer-normalised.
    """

    def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.num_tokens = num_tokens

        self.proj = torch.nn.Sequential(
            torch.nn.Linear(id_embeddings_dim, id_embeddings_dim * 2),
            torch.nn.GELU(),
            torch.nn.Linear(id_embeddings_dim * 2, cross_attention_dim * num_tokens),
        )
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, id_embeds):
        projected = self.proj(id_embeds)
        tokens = projected.reshape(-1, self.num_tokens, self.cross_attention_dim)
        return self.norm(tokens)
|
||||
|
||||
class InstantXFluxIpadapterApply:
    """Loads and applies the InstantX/Shakker-Labs IP-Adapter to a Flux model.

    Holds the CLIP image encoder, the adapter state dict and one attention
    processor per Flux block; `apply_ipadapter` is the Comfy-facing entry point.
    """

    def __init__(self, num_tokens=128):
        self.device = None                # set from `provider` in apply_ipadapter
        self.dtype = torch.float16        # adapter weights/activations run in fp16
        self.num_tokens = num_tokens      # number of image tokens fed to attention
        self.ip_ckpt = None               # adapter checkpoint file (set in apply_ipadapter)
        self.clip_vision = None
        self.image_encoder = None         # CLIP vision tower (set in apply_ipadapter)
        self.clip_image_processor = None  # CLIP preprocessing (set in apply_ipadapter)
        # state_dict with "image_proj" and "ip_adapter" sub-dicts
        self.state_dict = None
        self.joint_attention_dim = 4096   # Flux cross-attention (image-token) width
        self.hidden_size = 3072           # Flux hidden size

    def set_ip_adapter(self, flux_model, weight, timestep_percent_range=(0.0, 1.0)):
        """Build one IPAFluxAttnProcessor2_0 per double/single block.

        `timestep_percent_range` (0..1 sampling progress) is converted to sigma
        values so each processor can gate itself on the sampler timestep.
        """
        s = flux_model.model_sampling
        percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
        timestep_range = (percent_to_timestep_function(timestep_percent_range[0]),
                          percent_to_timestep_function(timestep_percent_range[1]))
        ip_attn_procs = {} # 19+38=57
        dsb_count = len(flux_model.diffusion_model.double_blocks)
        for i in range(dsb_count):
            name = f"double_blocks.{i}"
            ip_attn_procs[name] = IPAFluxAttnProcessor2_0(
                hidden_size=self.hidden_size,
                cross_attention_dim=self.joint_attention_dim,
                num_tokens=self.num_tokens,
                scale=weight,
                timestep_range=timestep_range
            ).to(self.device, dtype=self.dtype)
        ssb_count = len(flux_model.diffusion_model.single_blocks)
        for i in range(ssb_count):
            name = f"single_blocks.{i}"
            ip_attn_procs[name] = IPAFluxAttnProcessor2_0(
                hidden_size=self.hidden_size,
                cross_attention_dim=self.joint_attention_dim,
                num_tokens=self.num_tokens,
                scale=weight,
                timestep_range=timestep_range
            ).to(self.device, dtype=self.dtype)
        return ip_attn_procs

    def load_ip_adapter(self, flux_model, weight, timestep_percent_range=(0.0, 1.0)):
        """Load checkpoint weights into the shared projector and the processors."""
        global image_proj_model
        # NOTE(review): assumes apply_ipadapter already created image_proj_model;
        # calling this directly beforehand would fail on None.
        image_proj_model.load_state_dict(self.state_dict["image_proj"], strict=True)
        ip_attn_procs = self.set_ip_adapter(flux_model, weight, timestep_percent_range)
        ip_layers = torch.nn.ModuleList(ip_attn_procs.values())
        ip_layers.load_state_dict(self.state_dict["ip_adapter"], strict=True)
        return ip_attn_procs

    def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
        """Encode a PIL image (or reuse precomputed CLIP embeds) and project to IP tokens."""
        if pil_image is not None:
            if isinstance(pil_image, Image.Image):
                pil_image = [pil_image]
            clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
            # pooled CLIP embedding, moved to the adapter's device/dtype
            clip_image_embeds = self.image_encoder(
                clip_image.to(self.device, dtype=self.image_encoder.dtype)).pooler_output
            clip_image_embeds = clip_image_embeds.to(dtype=self.dtype)
        else:
            clip_image_embeds = clip_image_embeds.to(self.device, dtype=self.dtype)
        global image_proj_model
        image_prompt_embeds = image_proj_model(clip_image_embeds)
        return image_prompt_embeds

    def apply_ipadapter(self, model, ipadapter, image, weight, start_at, end_at, provider=None, use_tiled=False):
        """Entry point: clone `model` and patch it with the IP-Adapter.

        Returns (patched_model, image) so the node can pass both downstream.
        """
        self.device = provider.lower()
        if "clipvision" in ipadapter:
            self.image_encoder = ipadapter["clipvision"]['model']['image_encoder'].to(self.device, dtype=self.dtype)
            self.clip_image_processor = ipadapter["clipvision"]['model']['clip_image_processor']
        if "ipadapter" in ipadapter:
            self.ip_ckpt = ipadapter["ipadapter"]['file']
            self.state_dict = ipadapter["ipadapter"]['model']

        # process image — assumes a Comfy IMAGE tensor (B, H, W, C) in [0, 1];
        # TODO confirm against the calling node
        pil_image = image.numpy()[0] * 255.0
        pil_image = Image.fromarray(pil_image.astype(np.uint8))
        # initialize ipadapter: the projector is cached at module level and only
        # built once; its weights are (re)loaded on every call by load_ip_adapter
        global image_proj_model
        if image_proj_model is None:
            image_proj_model = MLPProjModel(
                cross_attention_dim=self.joint_attention_dim, # 4096
                id_embeddings_dim=1152,
                num_tokens=self.num_tokens,
            )
        image_proj_model.to(self.device, dtype=self.dtype)
        ip_attn_procs = self.load_ip_adapter(model.model, weight, (start_at, end_at))
        # process control image
        image_prompt_embeds = self.get_image_embeds(pil_image=pil_image, clip_image_embeds=None)
        # set model: patch a clone so the original ModelPatcher stays untouched
        bi = model.clone()
        FluxUpdateModules(bi, ip_attn_procs, image_prompt_embeds)

        return (bi, image)
|
||||
|
||||
|
||||
def patch_sd3(
    patcher,
    ip_procs,
    resampler: TimeResampler,
    clip_embeds,
    weight=1.0,
    start=0.0,
    end=1.0,
):
    """
    Patches a model_sampler to add the ipadapter.

    `patcher` is a Comfy ModelPatcher; `ip_procs` holds one IPAttnProcessor per
    joint block; `resampler` turns CLIP embeds + timestep into per-step image
    tokens. `start`/`end` bound (in 0..1 sampling progress) where IPA is active.
    """
    mmdit = patcher.model.diffusion_model
    timestep_schedule_max = patcher.model.model_config.sampling_settings.get(
        "timesteps", 1000
    )
    # hook the model's forward function
    # so that when it gets called, we can grab the timestep and send it to the resampler
    # shared mutable state: written here once per step, read by every
    # JointBlockIPWrapper during the same forward pass
    ip_options = {
        "hidden_states": None,
        "t_emb": None,
        "weight": weight,
    }

    def ddit_wrapper(forward, args):
        # this is between 0 and 1, so the adapters can calculate start_point and end_point
        # actually, do we need to get the sigma value instead?
        t_percent = 1 - args["timestep"].flatten()[0].cpu().item()
        if start <= t_percent <= end:
            batch_size = args["input"].shape[0] // len(args["cond_or_uncond"])
            # if we're only doing cond or only doing uncond, only pass one of them through the resampler
            embeds = clip_embeds[args["cond_or_uncond"]]
            # slight efficiency optimization todo: pass the embeds through and then afterwards
            # repeat to the batch size
            embeds = torch.repeat_interleave(embeds, batch_size, dim=0)
            # the resampler wants between 0 and MAX_STEPS
            timestep = args["timestep"] * timestep_schedule_max
            image_emb, t_emb = resampler(embeds, timestep, need_temb=True)
            # these will need to be accessible to the IPAdapters
            ip_options["hidden_states"] = image_emb
            ip_options["t_emb"] = t_emb
        else:
            # outside the enabled window: signal the block wrappers to skip IPA
            ip_options["hidden_states"] = None
            ip_options["t_emb"] = None

        return forward(args["input"], args["timestep"], **args["c"])

    patcher.set_model_unet_function_wrapper(ddit_wrapper)
    # patch each dit block
    for i, block in enumerate(mmdit.joint_blocks):
        wrapper = JointBlockIPWrapper(block, ip_procs[i], ip_options)
        patcher.set_model_patch_replace(wrapper, "dit", "double_block", i)
|
||||
|
||||
class InstantXSD3IpadapterApply:
    """Loads and applies the InstantX SD3.5 IP-Adapter (TimeResampler + per-block processors)."""

    def __init__(self):
        self.device = None                # set from `provider` in apply_ipadapter
        self.dtype = torch.float16
        self.clip_image_processor = None  # CLIP preprocessing (set in apply_ipadapter)
        self.image_encoder = None         # CLIP vision tower (set in apply_ipadapter)
        self.resampler = None             # TimeResampler: CLIP embeds + t -> image tokens
        self.procs = None                 # ModuleList of IPAttnProcessor, one per joint block
        # NOTE(review): self.ip_ckpt / self.state_dict are created later in
        # apply_ipadapter, not initialized here.

    @torch.inference_mode()
    def encode(self, image):
        """Encode an image tensor; returns stacked [cond, zero-uncond] CLIP hidden states."""
        clip_image = self.clip_image_processor.image_processor(image, return_tensors="pt", do_rescale=False).pixel_values
        # penultimate hidden layer, the representation the InstantX adapter was trained on
        clip_image_embeds = self.image_encoder(
            clip_image.to(self.device, dtype=self.image_encoder.dtype),
            output_hidden_states=True,
        ).hidden_states[-2]
        # append an all-zero row as the unconditional embedding for CFG
        clip_image_embeds = torch.cat(
            [clip_image_embeds, torch.zeros_like(clip_image_embeds)], dim=0
        )
        clip_image_embeds = clip_image_embeds.to(dtype=torch.float16)
        return clip_image_embeds

    def apply_ipadapter(self, model, ipadapter, image, weight, start_at, end_at, provider=None, use_tiled=False):
        """Entry point: build resampler + processors, load weights, patch a clone.

        Returns (patched_model_clone, image).
        """
        self.device = provider.lower()
        if "clipvision" in ipadapter:
            self.image_encoder = ipadapter["clipvision"]['model']['image_encoder'].to(self.device, dtype=self.dtype)
            self.clip_image_processor = ipadapter["clipvision"]['model']['clip_image_processor']
        if "ipadapter" in ipadapter:
            self.ip_ckpt = ipadapter["ipadapter"]['file']
            self.state_dict = ipadapter["ipadapter"]['model']

        # hyperparameters match the shipped SD3.5-Large adapter checkpoint
        self.resampler = TimeResampler(
            dim=1280,
            depth=4,
            dim_head=64,
            heads=20,
            num_queries=64,
            embedding_dim=1152,
            output_dim=2432,
            ff_mult=4,
            timestep_in_dim=320,
            timestep_flip_sin_to_cos=True,
            timestep_freq_shift=0,
        )
        self.resampler.eval()
        self.resampler.to(self.device, dtype=self.dtype)
        self.resampler.load_state_dict(self.state_dict["image_proj"])

        # now we'll create the attention processors
        # ip_adapter.keys looks like [0.proj, 0.to_k, ..., 1.proj, 1.to_k, ...]
        # so the number of distinct leading indices = number of processors
        n_procs = len(
            set(x.split(".")[0] for x in self.state_dict["ip_adapter"].keys())
        )
        self.procs = torch.nn.ModuleList(
            [
                # this is hardcoded for SD3.5L
                IPAttnProcessor(
                    hidden_size=2432,
                    cross_attention_dim=2432,
                    ip_hidden_states_dim=2432,
                    ip_encoder_hidden_states_dim=2432,
                    head_dim=64,
                    timesteps_emb_dim=1280,
                ).to(self.device, dtype=torch.float16)
                for _ in range(n_procs)
            ]
        )
        self.procs.load_state_dict(self.state_dict["ip_adapter"])

        # patch a clone so the caller's model stays untouched
        work_model = model.clone()
        embeds = self.encode(image)

        patch_sd3(
            work_model,
            self.procs,
            self.resampler,
            embeds,
            weight,
            start_at,
            end_at,
        )

        return (work_model, image)
|
||||
@@ -0,0 +1,87 @@
|
||||
import numbers
|
||||
from typing import Dict, Optional, Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
class RMSNorm(nn.Module):
    """Root-mean-square normalization (LayerNorm without mean subtraction).

    Normalizes by the RMS over the last dimension, computed in float32 for
    stability; optionally applies a learned per-element scale when
    `elementwise_affine` is True.
    """

    def __init__(self, dim, eps: float, elementwise_affine: bool = True):
        super().__init__()

        self.eps = eps

        # accept a bare int as shorthand for a 1-D normalized shape
        if isinstance(dim, numbers.Integral):
            dim = (dim,)

        self.dim = torch.Size(dim)

        self.weight = nn.Parameter(torch.ones(dim)) if elementwise_affine else None

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        # statistics in float32 for numerical stability
        mean_square = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(mean_square + self.eps)

        if self.weight is None:
            # no affine scale: restore the caller's dtype
            return hidden_states.to(orig_dtype)

        # convert into half-precision if necessary before applying the scale
        if self.weight.dtype in (torch.float16, torch.bfloat16):
            hidden_states = hidden_states.to(self.weight.dtype)
        return hidden_states * self.weight
|
||||
|
||||
class IPAFluxAttnProcessor2_0(nn.Module):
    """Attention processor used typically in processing the SD3-like self-attention projections.

    Computes an extra image-conditioned cross-attention term from `image_emb`;
    the Flux block adds the returned tensor (scaled by `self.scale`) onto its
    own attention output.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, timestep_range=None):
        super().__init__()

        self.hidden_size = hidden_size # 3072
        self.cross_attention_dim = cross_attention_dim # 4096
        self.scale = scale
        self.num_tokens = num_tokens

        # project the image tokens into this block's key/value space
        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

        # per-head norms; 128 is presumably the Flux head dim — TODO confirm
        self.norm_added_k = RMSNorm(128, eps=1e-5, elementwise_affine=False)
        self.norm_added_v = RMSNorm(128, eps=1e-5, elementwise_affine=False)
        # (sigma_start, sigma_end): outside this window __call__ is a no-op
        self.timestep_range = timestep_range

    def __call__(
        self,
        num_heads,
        query,
        image_emb: torch.FloatTensor,
        t: torch.FloatTensor
    ) -> torch.FloatTensor:
        # only apply IPA if timestep is within range
        if self.timestep_range is not None:
            # sigmas decrease during sampling, so range[0] >= range[1]
            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
                return None
        # `ip-adapter` projections
        ip_hidden_states = image_emb
        ip_hidden_states_key_proj = self.to_k_ip(ip_hidden_states)
        ip_hidden_states_value_proj = self.to_v_ip(ip_hidden_states)

        # split heads: (B, L, H*D) -> (B, H, L, D)
        ip_hidden_states_key_proj = rearrange(ip_hidden_states_key_proj, 'B L (H D) -> B H L D', H=num_heads)
        ip_hidden_states_value_proj = rearrange(ip_hidden_states_value_proj, 'B L (H D) -> B H L D', H=num_heads)

        ip_hidden_states_key_proj = self.norm_added_k(ip_hidden_states_key_proj)
        ip_hidden_states_value_proj = self.norm_added_v(ip_hidden_states_value_proj)

        # the block's query attends over the image tokens only (cross-attention);
        # query is moved to the image embedding's device/dtype first
        ip_hidden_states = F.scaled_dot_product_attention(query.to(image_emb.device).to(image_emb.dtype),
                                                          ip_hidden_states_key_proj,
                                                          ip_hidden_states_value_proj,
                                                          dropout_p=0.0, is_causal=False)

        # merge heads back: (B, H, L, D) -> (B, L, H*D), then match the query
        ip_hidden_states = rearrange(ip_hidden_states, "B H L D -> B L (H D)", H=num_heads)
        ip_hidden_states = ip_hidden_states.to(query.dtype).to(query.device)

        return self.scale * ip_hidden_states
|
||||
@@ -0,0 +1,153 @@
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .math import attention
|
||||
from ..attention_processor import IPAFluxAttnProcessor2_0
|
||||
from comfy.ldm.flux.layers import DoubleStreamBlock, SingleStreamBlock
|
||||
from comfy import model_management as mm
|
||||
|
||||
class DoubleStreamBlockIPA(nn.Module):
    """Drop-in replacement for a Flux DoubleStreamBlock with IP-Adapter support.

    Re-implements the original img/txt joint-attention forward pass and, after
    attention, adds each adapter's image-conditioned term to the image stream.
    Submodules are aliased from `original_block`, so no weights are duplicated.
    """

    def __init__(self, original_block: DoubleStreamBlock, ip_adapter, image_emb):
        super().__init__()

        # recover the MLP width from the original block
        # NOTE(review): mlp_hidden_dim/mlp_ratio are computed but never used below
        mlp_hidden_dim = original_block.img_mlp[0].out_features
        mlp_ratio = mlp_hidden_dim / original_block.hidden_size
        mlp_hidden_dim = int(original_block.hidden_size * mlp_ratio)
        self.num_heads = original_block.num_heads
        self.hidden_size = original_block.hidden_size
        # image-stream modules (aliases, not copies)
        self.img_mod = original_block.img_mod
        self.img_norm1 = original_block.img_norm1
        self.img_attn = original_block.img_attn

        self.img_norm2 = original_block.img_norm2
        self.img_mlp = original_block.img_mlp

        # text-stream modules (aliases, not copies)
        self.txt_mod = original_block.txt_mod
        self.txt_norm1 = original_block.txt_norm1
        self.txt_attn = original_block.txt_attn

        self.txt_norm2 = original_block.txt_norm2
        self.txt_mlp = original_block.txt_mlp
        self.flipped_img_txt = original_block.flipped_img_txt

        # parallel lists: one image embedding per adapter, applied in order
        self.ip_adapter = ip_adapter
        self.image_emb = image_emb
        self.device = mm.get_torch_device()

    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, t: Tensor, attn_mask=None):
        """Run the double-stream block; `t` is forwarded to the adapters for timestep gating."""
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        # (B, L, 3*H*D) -> 3 x (B, H, L, D)
        img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3,
                                                                                                              1, 4)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3,
                                                                                                              1, 4)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        if self.flipped_img_txt:
            # run actual attention (image tokens first in the joint sequence)
            attn = attention(torch.cat((img_q, txt_q), dim=2),
                             torch.cat((img_k, txt_k), dim=2),
                             torch.cat((img_v, txt_v), dim=2),
                             pe=pe, mask=attn_mask)

            img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
        else:
            # run actual attention (text tokens first in the joint sequence)
            attn = attention(torch.cat((txt_q, img_q), dim=2),
                             torch.cat((txt_k, img_k), dim=2),
                             torch.cat((txt_v, img_v), dim=2),
                             pe=pe, mask=attn_mask)

            txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]

        for adapter, image in zip(self.ip_adapter, self.image_emb):
            # this does a separate attention for each adapter
            # adapter returns None when t is outside its timestep window
            ip_hidden_states = adapter(self.num_heads, img_q, image, t)
            if ip_hidden_states is not None:
                ip_hidden_states = ip_hidden_states.to(self.device)
                img_attn = img_attn + ip_hidden_states

        # calculate the img bloks
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)

        # calculate the txt bloks
        txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)

        if txt.dtype == torch.float16:
            # clamp fp16 overflow/NaN to the finite fp16 range
            txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)

        return img, txt
|
||||
|
||||
|
||||
class SingleStreamBlockIPA(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.

    IP-Adapter variant of Comfy's SingleStreamBlock: aliases the original
    block's submodules (no weight copies) and adds each adapter's
    image-conditioned attention term before the output projection.
    """

    def __init__(self, original_block: SingleStreamBlock, ip_adapter, image_emb):
        super().__init__()
        self.hidden_dim = original_block.hidden_size
        self.num_heads = original_block.num_heads
        self.scale = original_block.scale

        self.mlp_hidden_dim = original_block.mlp_hidden_dim
        # qkv and mlp_in
        self.linear1 = original_block.linear1
        # proj and mlp_out
        self.linear2 = original_block.linear2

        self.norm = original_block.norm

        self.hidden_size = original_block.hidden_size
        self.pre_norm = original_block.pre_norm

        self.mlp_act = original_block.mlp_act
        self.modulation = original_block.modulation

        # parallel lists: one image embedding per adapter, applied in order
        self.ip_adapter = ip_adapter
        self.image_emb = image_emb
        self.device = mm.get_torch_device()

    def add_adapter(self, ip_adapter: IPAFluxAttnProcessor2_0, image_emb):
        """Register an extra adapter together with its image embedding."""
        self.ip_adapter.append(ip_adapter)
        self.image_emb.append(image_emb)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor, t: Tensor, attn_mask=None) -> Tensor:
        """Run the single-stream block; `t` is forwarded to the adapters for timestep gating."""
        mod, _ = self.modulation(vec)
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        # linear1 fuses the qkv projection and the MLP input projection
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        # (B, L, 3*H*D) -> 3 x (B, H, L, D)
        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe, mask=attn_mask)

        for adapter, image in zip(self.ip_adapter, self.image_emb):
            # this does a separate attention for each adapter
            # maybe we want a single joint attention call for all adapters?
            # adapter returns None when t is outside its timestep window
            ip_hidden_states = adapter(self.num_heads, q, image, t)
            if ip_hidden_states is not None:
                ip_hidden_states = ip_hidden_states.to(self.device)
                attn = attn + ip_hidden_states

        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        x += mod.gate * output
        if x.dtype == torch.float16:
            # clamp fp16 overflow/NaN to the finite fp16 range
            x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
        return x
|
||||
@@ -0,0 +1,35 @@
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import Tensor
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.model_management
|
||||
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
    """Rotary-embedded attention: rotate q/k by `pe`, then run Comfy's optimized attention.

    q/k/v are head-split tensors of shape (batch, heads, seq, head_dim);
    `pe` holds the precomputed rope rotation matrices.
    """
    q, k = apply_rope(q, k, pe)

    heads = q.shape[1]
    # skip_reshape=True: inputs are already split into heads
    x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
    return x
|
||||
|
||||
|
||||
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    """Build rotary position-embedding matrices for positions `pos`.

    Returns a float32 tensor of 2x2 rotation matrices, one per position and
    frequency (shape (..., dim//2, 2, 2)), on `pos`'s original device.
    """
    assert dim % 2 == 0
    # MPS/XPU lack the float64 support used below, so fall back to the CPU there
    if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu():
        compute_device = torch.device("cpu")
    else:
        compute_device = pos.device

    exponents = torch.linspace(0, (dim - 2) / dim, steps=dim // 2, dtype=torch.float64, device=compute_device)
    freqs = 1.0 / (theta ** exponents)
    # outer product of positions and frequencies -> rotation angles
    angles = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=compute_device), freqs)
    rot = torch.stack([torch.cos(angles), -torch.sin(angles), torch.sin(angles), torch.cos(angles)], dim=-1)
    rot = rearrange(rot, "b n d (i j) -> b n d i j", i=2, j=2)
    return rot.to(dtype=torch.float32, device=pos.device)
|
||||
|
||||
|
||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
    """Rotate query/key channel pairs by the precomputed rope matrices.

    Inputs are viewed as consecutive (even, odd) channel pairs, each pair is
    multiplied by its 2x2 rotation from `freqs_cis`, and the result is
    flattened back to the input shape and dtype.
    """
    # view the last dimension as pairs: (..., d) -> (..., d/2, 1, 2)
    q_pairs = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    k_pairs = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    q_rot = freqs_cis[..., 0] * q_pairs[..., 0] + freqs_cis[..., 1] * q_pairs[..., 1]
    k_rot = freqs_cis[..., 0] * k_pairs[..., 0] + freqs_cis[..., 1] * k_pairs[..., 1]
    return q_rot.reshape(*xq.shape).type_as(xq), k_rot.reshape(*xk.shape).type_as(xk)
|
||||
@@ -0,0 +1,219 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from einops import rearrange
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import (RMSNorm, JointBlock,)
|
||||
|
||||
|
||||
class AdaLayerNorm(nn.Module):
    """
    Norm layer adaptive layer norm zero (adaLN-Zero).

    A parameter-free LayerNorm whose shift/scale (and, in "zero" mode, the
    extra adaLN-Zero gates) are produced from a conditioning embedding.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        time_embedding_dim (`int`, optional): Width of the conditioning
            embedding; defaults to `embedding_dim`.
        mode (`str`): "normal" (shift/scale only) or "zero" (six adaLN params).
    """

    def __init__(self, embedding_dim: int, time_embedding_dim=None, mode="normal"):
        super().__init__()

        self.silu = nn.SiLU()
        # "zero" emits six adaLN-Zero parameters, "normal" only shift/scale
        num_params = dict(zero=6, normal=2)[mode]
        self.linear = nn.Linear(
            time_embedding_dim or embedding_dim, num_params * embedding_dim, bias=True
        )
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
        self.mode = mode

    def forward(
        self,
        x,
        hidden_dtype=None,
        emb=None,
    ):
        """Normalize `x` with shift/scale derived from `emb`.

        Returns `x` in "normal" mode, or the tuple
        (x, gate_msa, shift_mlp, scale_mlp, gate_mlp) in "zero" mode.
        """
        params = self.linear(self.silu(emb))
        if self.mode == "normal":
            shift_msa, scale_msa = params.chunk(2, dim=1)
            return self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]

        elif self.mode == "zero":
            (shift_msa, scale_msa, gate_msa,
             shift_mlp, scale_mlp, gate_mlp) = params.chunk(6, dim=1)
            x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
            return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
||||
|
||||
|
||||
class IPAttnProcessor(nn.Module):
    """SD3 IP-Adapter attention processor.

    Projects timestep-normalized image tokens to key/value, concatenates them
    with the block's own image keys/values, and attends with the block's
    queries to produce the image-conditioned attention output.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
        ip_hidden_states_dim=None,
        ip_encoder_hidden_states_dim=None,
        head_dim=None,
        timesteps_emb_dim=1280,
    ):
        super().__init__()

        # timestep-conditioned LayerNorm applied to the incoming image tokens
        self.norm_ip = AdaLayerNorm(
            ip_hidden_states_dim, time_embedding_dim=timesteps_emb_dim
        )
        self.to_k_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
        # QK-norm style per-head RMS norms
        self.norm_q = RMSNorm(head_dim, 1e-6)
        self.norm_k = RMSNorm(head_dim, 1e-6)
        self.norm_ip_k = RMSNorm(head_dim, 1e-6)

    def forward(
        self,
        ip_hidden_states,
        img_query,
        img_key=None,
        img_value=None,
        t_emb=None,
        n_heads=1,
    ):
        """Return the IP attention output of shape (B, L, H*D), or None when disabled."""
        # the resampler was skipped (timestep outside the enabled window)
        if ip_hidden_states is None:
            return None

        if not hasattr(self, "to_k_ip") or not hasattr(self, "to_v_ip"):
            return None

        # norm ip input
        norm_ip_hidden_states = self.norm_ip(ip_hidden_states, emb=t_emb)

        # to k and v
        ip_key = self.to_k_ip(norm_ip_hidden_states)
        ip_value = self.to_v_ip(norm_ip_hidden_states)

        # reshape: split heads (b, l, h*d) -> (b, h, l, d)
        img_query = rearrange(img_query, "b l (h d) -> b h l d", h=n_heads)
        img_key = rearrange(img_key, "b l (h d) -> b h l d", h=n_heads)
        # note that the image is in a different shape: b l h d
        # so we transpose to b h l d
        # or do we have to transpose here?
        img_value = torch.transpose(img_value, 1, 2)
        ip_key = rearrange(ip_key, "b l (h d) -> b h l d", h=n_heads)
        ip_value = rearrange(ip_value, "b l (h d) -> b h l d", h=n_heads)

        # norm
        img_query = self.norm_q(img_query)
        img_key = self.norm_k(img_key)
        ip_key = self.norm_ip_k(ip_key)

        # cat img: joint key/value over the block's image tokens plus ip tokens
        key = torch.cat([img_key, ip_key], dim=2)
        value = torch.cat([img_value, ip_value], dim=2)

        # single attention pass over the concatenated sequence
        ip_hidden_states = F.scaled_dot_product_attention(
            img_query, key, value, dropout_p=0.0, is_causal=False
        )
        ip_hidden_states = rearrange(ip_hidden_states, "b h l d -> b l (h d)")
        ip_hidden_states = ip_hidden_states.to(img_query.dtype)
        return ip_hidden_states
|
||||
|
||||
|
||||
class JointBlockIPWrapper:
    """To be used as a patch_replace with Comfy.

    Wraps one mmdit JointBlock; re-runs its attention mixing with an extra
    IP-Adapter term added to the image-stream attention output.
    """

    def __init__(
        self,
        original_block: JointBlock,
        adapter: IPAttnProcessor,
        ip_options=None,
    ):
        self.original_block = original_block
        self.adapter = adapter
        # shared dict, mutated once per step by patch_sd3's ddit_wrapper
        if ip_options is None:
            ip_options = {}
        self.ip_options = ip_options

    def block_mixing(self, context, x, context_block, x_block, c):
        """
        Comes from mmdit.py. Modified to add ipadapter attention.
        """
        context_qkv, context_intermediates = context_block.pre_attention(context, c)

        if x_block.x_block_self_attn:
            x_qkv, x_qkv2, x_intermediates = x_block.pre_attention_x(x, c)
        else:
            x_qkv, x_intermediates = x_block.pre_attention(x, c)

        # joint attention over the concatenated context + image sequences
        qkv = tuple(torch.cat((context_qkv[j], x_qkv[j]), dim=1) for j in range(3))

        attn = optimized_attention(
            qkv[0],
            qkv[1],
            qkv[2],
            heads=x_block.attn.num_heads,
        )
        # split the joint output back into context / image parts
        context_attn, x_attn = (
            attn[:, : context_qkv[0].shape[1]],
            attn[:, context_qkv[0].shape[1] :],
        )
        # if the current timestep is not in the ipadapter enabling range, then the resampler wasn't run
        # and the hidden states will be None
        if (
            self.ip_options["hidden_states"] is not None
            and self.ip_options["t_emb"] is not None
        ):
            # IP-Adapter
            ip_attn = self.adapter(
                self.ip_options["hidden_states"],
                *x_qkv,
                self.ip_options["t_emb"],
                x_block.attn.num_heads,
            )
            x_attn = x_attn + ip_attn * self.ip_options["weight"]

        # Everything else is unchanged
        if not context_block.pre_only:
            context = context_block.post_attention(context_attn, *context_intermediates)

        else:
            context = None
        if x_block.x_block_self_attn:
            attn2 = optimized_attention(
                x_qkv2[0],
                x_qkv2[1],
                x_qkv2[2],
                heads=x_block.attn2.num_heads,
            )
            x = x_block.post_attention_x(x_attn, attn2, *x_intermediates)
        else:
            x = x_block.post_attention(x_attn, *x_intermediates)
        return context, x

    def __call__(self, args, _):
        """Comfy patch_replace entry point: unpack args and run block_mixing."""
        # Code from mmdit.py:
        # in this case, we're blocks_replace[("double_block", i)]
        # note that although we're passed the original block,
        # we can't actually get it from inside its wrapper
        # (which would simplify the whole code...)
        # ```
        # def block_wrap(args):
        #     out = {}
        #     out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"])
        #     return out
        # out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap})
        # context = out["txt"]
        # x = out["img"]
        # ```
        c, x = self.block_mixing(
            args["txt"],
            args["img"],
            self.original_block.context_block,
            self.original_block.x_block,
            c=args["vec"],
        )
        return {"txt": c, "img": x}
|
||||
@@ -0,0 +1,385 @@
|
||||
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Optional
|
||||
|
||||
|
||||
# shared singleton activation modules, keyed by lowercase name
ACTIVATION_FUNCTIONS = {
    "swish": nn.SiLU(),
    "silu": nn.SiLU(),
    "mish": nn.Mish(),
    "gelu": nn.GELU(),
    "relu": nn.ReLU(),
}


def get_activation(act_fn: str) -> nn.Module:
    """Helper function to get activation function from string.

    The lookup is case-insensitive.

    Args:
        act_fn (str): Name of activation function.

    Returns:
        nn.Module: Activation function.

    Raises:
        ValueError: If the name is not a supported activation.
    """
    name = act_fn.lower()
    try:
        return ACTIVATION_FUNCTIONS[name]
    except KeyError:
        raise ValueError(f"Unsupported activation function: {name}") from None
|
||||
|
||||
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

    Args
        timesteps (torch.Tensor):
            a 1-D Tensor of N indices, one per batch element. These may be fractional.
        embedding_dim (int):
            the dimension of the output.
        flip_sin_to_cos (bool):
            Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
        downscale_freq_shift (float):
            Controls the delta between frequencies between dimensions
        scale (float):
            Scaling factor applied to the embeddings.
        max_period (int):
            Controls the maximum frequency of the embeddings
    Returns
        torch.Tensor: an [N x dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    # geometric frequency ladder: exp(-log(max_period) * i / (half_dim - shift))
    freq_exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    ) / (half_dim - downscale_freq_shift)
    frequencies = torch.exp(freq_exponent)

    # outer product of timesteps and frequencies, scaled
    args = scale * (timesteps[:, None].float() * frequencies[None, :])

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad odd dimensions
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
|
||||
|
||||
class Timesteps(nn.Module):
    """Stateless module wrapper around `get_timestep_embedding`.

    Stores the embedding hyperparameters at construction time and applies
    them to every batch of timesteps in `forward`.
    """

    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift
        self.scale = scale

    def forward(self, timesteps):
        """Return sinusoidal embeddings of shape (len(timesteps), num_channels)."""
        return get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            scale=self.scale,
        )
|
||||
|
||||
class TimestepEmbedding(nn.Module):
    """Two-layer MLP lifting a sinusoidal timestep embedding to `time_embed_dim`.

    Optionally projects an extra conditioning signal onto the input
    (`cond_proj_dim`) and applies a post-activation (`post_act_fn`).
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)

        # optional bias-free projection of a conditioning signal into input space
        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None

        self.act = get_activation(act_fn)

        time_embed_dim_out = out_dim if out_dim is not None else time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)

        self.post_act = get_activation(post_act_fn) if post_act_fn is not None else None

    def forward(self, sample, condition=None):
        """Map (N, in_channels) embeddings to (N, out_dim or time_embed_dim)."""
        if condition is not None:
            # NOTE(review): assumes cond_proj_dim was configured; raises otherwise
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample
|
||||
|
||||
|
||||
def FeedForward(dim, mult=4):
    """Pre-norm FFN block: LayerNorm -> Linear(dim -> dim*mult) -> GELU -> Linear(-> dim)."""
    hidden = int(dim * mult)
    layers = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden, bias=False),
        nn.GELU(),
        nn.Linear(hidden, dim, bias=False),
    ]
    return nn.Sequential(*layers)
|
||||
|
||||
|
||||
def reshape_tensor(x, heads):
    """Split the channel dim into attention heads.

    (bs, length, width) -> (bs, heads, length, width // heads)
    """
    bs, length, _ = x.shape
    # (bs, length, heads, dim_per_head)
    x = x.view(bs, length, heads, -1)
    # (bs, heads, length, dim_per_head); reshape also materializes a
    # contiguous tensor after the transpose.
    return x.transpose(1, 2).reshape(bs, heads, length, -1)
|
||||
|
||||
|
||||
class PerceiverAttention(nn.Module):
    # Perceiver-style cross-attention: a small set of learned latents attends
    # to the concatenation of the input features and the latents themselves.
    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        # NOTE(review): self.scale is stored but never used; forward computes
        # its own sqrt-sqrt scaling below.
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents, shift=None, scale=None):
        """
        Args:
            x (torch.Tensor): image features, shape (b, n1, D)
            latents (torch.Tensor): latent features, shape (b, n2, D)
            shift, scale (torch.Tensor, optional): adaLN modulation vectors;
                applied to the normalized latents only when BOTH are given.

        Returns:
            torch.Tensor: attended latents, shape (b, n2, D)
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        if shift is not None and scale is not None:
            # adaLN modulation; unsqueeze(1) broadcasts over the sequence dim
            latents = latents * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

        b, l, _ = latents.shape

        q = self.to_q(latents)
        # keys/values are computed from both the inputs and the latents
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q = reshape_tensor(q, self.heads)
        k = reshape_tensor(k, self.heads)
        v = reshape_tensor(v, self.heads)

        # attention — NOTE: this local `scale` shadows the `scale` parameter
        # above; scaling q and k each by dim_head**-0.25 equals the usual
        # dim_head**-0.5 on the product
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(
            -2, -1
        )  # More stable with f16 than dividing afterwards
        # softmax in fp32 for numerical stability, then cast back
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v

        # merge heads: (b, heads, l, dim_head) -> (b, l, heads * dim_head)
        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)
|
||||
|
||||
|
||||
class Resampler(nn.Module):
    """Perceiver resampler.

    Compresses a variable-length token sequence into ``num_queries`` learned
    query tokens via ``depth`` rounds of cross-attention + feed-forward, then
    projects to ``output_dim``.
    """

    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
        *args,
        **kwargs,
    ):
        super().__init__()

        # Learned queries, scaled down at init for stability.
        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)

        self.proj_in = nn.Linear(embedding_dim, dim)
        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.layers = nn.ModuleList(
            nn.ModuleList(
                [
                    PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                    FeedForward(dim=dim, mult=ff_mult),
                ]
            )
            for _ in range(depth)
        )

    def forward(self, x):
        batch = x.size(0)
        latents = self.latents.repeat(batch, 1, 1)
        x = self.proj_in(x)

        # Residual attention + FFN rounds over the fixed set of latents.
        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        return self.norm_out(self.proj_out(latents))
|
||||
|
||||
|
||||
class TimeResampler(nn.Module):
    # Resampler variant whose layers are modulated by a timestep embedding
    # via adaLN shift/scale (IP-Adapter style time conditioning).
    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
        timestep_in_dim=320,
        timestep_flip_sin_to_cos=True,
        timestep_freq_shift=0,
    ):
        super().__init__()

        # Learned query tokens, scaled down at init for stability.
        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)

        self.proj_in = nn.Linear(embedding_dim, dim)

        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        # msa
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        # ff (nn.Sequential: LayerNorm, Linear, GELU, Linear)
                        FeedForward(dim=dim, mult=ff_mult),
                        # adaLN: one projection produces shift/scale for both
                        # the attention and the FFN (4 * dim, chunked in forward)
                        nn.Sequential(nn.SiLU(), nn.Linear(dim, 4 * dim, bias=True)),
                    ]
                )
            )

        # time: sinusoidal projection followed by a 2-layer MLP
        self.time_proj = Timesteps(
            timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift
        )
        self.time_embedding = TimestepEmbedding(timestep_in_dim, dim, act_fn="silu")

    def forward(self, x, timestep, need_temb=False):
        """Resample x conditioned on timestep.

        Returns latents, or (latents, timestep_emb) when need_temb is True.
        """
        timestep_emb = self.embedding_time(x, timestep)  # bs, dim

        latents = self.latents.repeat(x.size(0), 1, 1)

        x = self.proj_in(x)
        # inject the time signal into the input tokens as well
        x = x + timestep_emb[:, None]

        for attn, ff, adaLN_modulation in self.layers:
            shift_msa, scale_msa, shift_mlp, scale_mlp = adaLN_modulation(
                timestep_emb
            ).chunk(4, dim=1)
            latents = attn(x, latents, shift_msa, scale_msa) + latents

            res = latents
            # Walk the FFN Sequential manually so the MLP shift/scale can be
            # injected immediately after its leading LayerNorm.
            for idx_ff in range(len(ff)):
                layer_ff = ff[idx_ff]
                latents = layer_ff(latents)
                if idx_ff == 0 and isinstance(layer_ff, nn.LayerNorm):  # adaLN
                    latents = latents * (
                        1 + scale_mlp.unsqueeze(1)
                    ) + shift_mlp.unsqueeze(1)
            latents = latents + res

        latents = self.proj_out(latents)
        latents = self.norm_out(latents)

        if need_temb:
            return latents, timestep_emb
        else:
            return latents

    def embedding_time(self, sample, timestep):
        """Convert a scalar, 0-d, or 1-d timestep into a (batch, dim) embedding."""
        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # mps does not support float64/int64 constants
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb, None)
        return emb
|
||||
136
custom_nodes/ComfyUI-Easy-Use/py/modules/ipadapter/utils.py
Normal file
136
custom_nodes/ComfyUI-Easy-Use/py/modules/ipadapter/utils.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from .flux.layers import DoubleStreamBlockIPA, SingleStreamBlockIPA
|
||||
from comfy.ldm.flux.layers import timestep_embedding
|
||||
from types import MethodType
|
||||
|
||||
def FluxUpdateModules(bi, ip_attn_procs, image_emb):
    """Patch a Flux ModelPatcher in-place with IP-Adapter attention layers.

    Reroutes ``forward_orig`` to the IPA-aware variant and wraps every
    double/single stream block. Existing IPA wrappers are preserved: their
    processors and embeddings are collected and the new ones appended, so
    multiple adapters stack.

    Args:
        bi: ComfyUI ModelPatcher holding the Flux diffusion model.
        ip_attn_procs: dict mapping "double_blocks.{i}" / "single_blocks.{i}"
            to IP-Adapter attention processors.
        image_emb: image embedding paired with each newly added processor.
    """
    flux_model = bi.model
    # Fix: plain string literal (was an f-string with no placeholders, ruff F541).
    bi.add_object_patch("diffusion_model.forward_orig", MethodType(forward_orig_ipa, flux_model.diffusion_model))
    for i, original in enumerate(flux_model.diffusion_model.double_blocks):
        patch_name = f"double_blocks.{i}"
        maybe_patched_layer = bi.get_model_object(f"diffusion_model.{patch_name}")
        # if there's already a patch there, collect its adapters and replace it
        procs = [ip_attn_procs[patch_name]]
        embs = [image_emb]
        if isinstance(maybe_patched_layer, DoubleStreamBlockIPA):
            procs = maybe_patched_layer.ip_adapter + procs
            embs = maybe_patched_layer.image_emb + embs
        # wrap the original block together with all processors/embeddings
        new_layer = DoubleStreamBlockIPA(original, procs, embs)
        # object patches are undone by ComfyUI when the model is unpatched,
        # the same mechanism used internally for loras
        bi.add_object_patch(f"diffusion_model.{patch_name}", new_layer)
    for i, original in enumerate(flux_model.diffusion_model.single_blocks):
        patch_name = f"single_blocks.{i}"
        maybe_patched_layer = bi.get_model_object(f"diffusion_model.{patch_name}")
        procs = [ip_attn_procs[patch_name]]
        embs = [image_emb]
        if isinstance(maybe_patched_layer, SingleStreamBlockIPA):
            procs = maybe_patched_layer.ip_adapter + procs
            embs = maybe_patched_layer.image_emb + embs
        new_layer = SingleStreamBlockIPA(original, procs, embs)
        bi.add_object_patch(f"diffusion_model.{patch_name}", new_layer)
|
||||
|
||||
def is_model_pathched(model):
    """Return True if any submodule of ``model`` is a DoubleStreamBlockIPA,
    i.e. the model has already been patched with an IP-Adapter."""

    def _search(module):
        if isinstance(module, DoubleStreamBlockIPA):
            return True
        return any(_search(child) for child in module.children())

    return _search(model)
|
||||
|
||||
|
||||
def forward_orig_ipa(
    self,
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    timesteps: Tensor,
    y: Tensor,
    guidance: Tensor | None = None,
    control=None,
    transformer_options={},  # NOTE(review): mutable default; only read here, never mutated
    attn_mask: Tensor = None,
) -> Tensor:
    """Flux ``forward_orig`` replacement that is aware of IP-Adapter blocks.

    Identical to the stock Flux forward except that DoubleStreamBlockIPA /
    SingleStreamBlockIPA blocks receive the extra ``t=timesteps`` argument.
    Bound onto the diffusion model via MethodType in FluxUpdateModules.
    """
    patches_replace = transformer_options.get("patches_replace", {})
    if img.ndim != 3 or txt.ndim != 3:
        raise ValueError("Input img and txt tensors must have 3 dimensions.")

    # running on sequences img
    img = self.img_in(img)
    vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
    if self.params.guidance_embed:
        if guidance is None:
            raise ValueError("Didn't get guidance strength for guidance distilled model.")
        vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))

    vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
    txt = self.txt_in(txt)

    # joint positional embedding over text then image tokens
    ids = torch.cat((txt_ids, img_ids), dim=1)
    pe = self.pe_embedder(ids)

    blocks_replace = patches_replace.get("dit", {})
    for i, block in enumerate(self.double_blocks):
        if ("double_block", i) in blocks_replace:
            # honor externally registered block replacements while still
            # forwarding the timestep to IPA-wrapped blocks
            def block_wrap(args):
                out = {}
                if isinstance(block, DoubleStreamBlockIPA): # ipadaper
                    out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], t=args["timesteps"], attn_mask=args.get("attn_mask"))
                else:
                    out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args.get("attn_mask"))
                return out
            out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "timesteps": timesteps, "attn_mask": attn_mask}, {"original_block": block_wrap})
            txt = out["txt"]
            img = out["img"]
        else:
            if isinstance(block, DoubleStreamBlockIPA): # ipadaper
                img, txt = block(img=img, txt=txt, vec=vec, pe=pe, t=timesteps, attn_mask=attn_mask)
            else:
                img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask)

        if control is not None: # Controlnet
            control_i = control.get("input")
            if i < len(control_i):
                add = control_i[i]
                if add is not None:
                    img += add

    # single-stream blocks operate on the concatenated (txt, img) sequence
    img = torch.cat((txt, img), 1)

    for i, block in enumerate(self.single_blocks):
        if ("single_block", i) in blocks_replace:
            def block_wrap(args):
                out = {}
                if isinstance(block, SingleStreamBlockIPA): # ipadaper
                    out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], t=args["timesteps"], attn_mask=args.get("attn_mask"))
                else:
                    out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args.get("attn_mask"))
                return out

            out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "timesteps": timesteps, "attn_mask": attn_mask}, {"original_block": block_wrap})
            img = out["img"]
        else:
            if isinstance(block, SingleStreamBlockIPA): # ipadaper
                img = block(img, vec=vec, pe=pe, t=timesteps, attn_mask=attn_mask)
            else:
                img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)

        if control is not None: # Controlnet
            control_o = control.get("output")
            if i < len(control_o):
                add = control_o[i]
                if add is not None:
                    # controlnet residual applies to the image tokens only
                    img[:, txt.shape[1] :, ...] += add

    # drop the text tokens; only image tokens go to the final projection
    img = img[:, txt.shape[1] :, ...]

    img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
    return img
|
||||
@@ -0,0 +1,42 @@
|
||||
{
|
||||
"_name_or_path": "THUDM/chatglm3-6b-base",
|
||||
"model_type": "chatglm",
|
||||
"architectures": [
|
||||
"ChatGLMModel"
|
||||
],
|
||||
"auto_map": {
|
||||
"AutoConfig": "configuration_chatglm.ChatGLMConfig",
|
||||
"AutoModel": "modeling_chatglm.ChatGLMForConditionalGeneration",
|
||||
"AutoModelForCausalLM": "modeling_chatglm.ChatGLMForConditionalGeneration",
|
||||
"AutoModelForSeq2SeqLM": "modeling_chatglm.ChatGLMForConditionalGeneration",
|
||||
"AutoModelForSequenceClassification": "modeling_chatglm.ChatGLMForSequenceClassification"
|
||||
},
|
||||
"add_bias_linear": false,
|
||||
"add_qkv_bias": true,
|
||||
"apply_query_key_layer_scaling": true,
|
||||
"apply_residual_connection_post_layernorm": false,
|
||||
"attention_dropout": 0.0,
|
||||
"attention_softmax_in_fp32": true,
|
||||
"bias_dropout_fusion": true,
|
||||
"ffn_hidden_size": 13696,
|
||||
"fp32_residual_connection": false,
|
||||
"hidden_dropout": 0.0,
|
||||
"hidden_size": 4096,
|
||||
"kv_channels": 128,
|
||||
"layernorm_epsilon": 1e-05,
|
||||
"multi_query_attention": true,
|
||||
"multi_query_group_num": 2,
|
||||
"num_attention_heads": 32,
|
||||
"num_layers": 28,
|
||||
"original_rope": true,
|
||||
"padded_vocab_size": 65024,
|
||||
"post_layer_norm": true,
|
||||
"rmsnorm": true,
|
||||
"seq_length": 32768,
|
||||
"use_cache": true,
|
||||
"torch_dtype": "float16",
|
||||
"transformers_version": "4.30.2",
|
||||
"tie_word_embeddings": false,
|
||||
"eos_token_id": 2,
|
||||
"pad_token_id": 0
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
class ChatGLMConfig(PretrainedConfig):
    """Hugging Face configuration class for ChatGLM.

    Field names mirror the upstream THUDM/chatglm3 ``config.json`` keys;
    they must not be renamed or checkpoint loading breaks.
    """

    model_type = "chatglm"

    def __init__(
        self,
        num_layers=28,
        padded_vocab_size=65024,
        hidden_size=4096,
        ffn_hidden_size=13696,
        kv_channels=128,
        num_attention_heads=32,
        seq_length=2048,
        hidden_dropout=0.0,
        classifier_dropout=None,
        attention_dropout=0.0,
        layernorm_epsilon=1e-5,
        rmsnorm=True,
        apply_residual_connection_post_layernorm=False,
        post_layer_norm=True,
        add_bias_linear=False,
        add_qkv_bias=False,
        bias_dropout_fusion=True,
        multi_query_attention=False,
        multi_query_group_num=1,
        apply_query_key_layer_scaling=True,
        attention_softmax_in_fp32=True,
        fp32_residual_connection=False,
        quantization_bit=0,
        pre_seq_len=None,
        prefix_projection=False,
        **kwargs
    ):
        self.num_layers = num_layers
        # vocab_size intentionally aliases padded_vocab_size (HF utilities
        # read ``vocab_size``; the model uses the padded value)
        self.vocab_size = padded_vocab_size
        self.padded_vocab_size = padded_vocab_size
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.kv_channels = kv_channels
        self.num_attention_heads = num_attention_heads
        self.seq_length = seq_length
        self.hidden_dropout = hidden_dropout
        self.classifier_dropout = classifier_dropout
        self.attention_dropout = attention_dropout
        self.layernorm_epsilon = layernorm_epsilon
        self.rmsnorm = rmsnorm
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
        self.post_layer_norm = post_layer_norm
        self.add_bias_linear = add_bias_linear
        self.add_qkv_bias = add_qkv_bias
        self.bias_dropout_fusion = bias_dropout_fusion
        self.multi_query_attention = multi_query_attention
        self.multi_query_group_num = multi_query_group_num
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.fp32_residual_connection = fp32_residual_connection
        self.quantization_bit = quantization_bit
        # P-tuning v2 options
        self.pre_seq_len = pre_seq_len
        self.prefix_projection = prefix_projection
        super().__init__(**kwargs)
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
@@ -0,0 +1,300 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import List, Optional, Union, Dict
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
from transformers import PreTrainedTokenizer
|
||||
from transformers.utils import logging, PaddingStrategy
|
||||
from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
|
||||
|
||||
class SPTokenizer:
    """SentencePiece wrapper that appends the ChatGLM special/role tokens
    after the base vocabulary (their ids start at ``sp_model.vocab_size()``)."""

    def __init__(self, model_path: str):
        # reload tokenizer
        assert os.path.isfile(model_path), model_path
        self.sp_model = SentencePieceProcessor(model_file=model_path)

        # BOS / EOS token IDs
        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        # padding reuses the unk id (this SP model has no dedicated pad piece)
        self.pad_id: int = self.sp_model.unk_id()
        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

        role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
        special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens
        self.special_tokens = {}        # token str -> id
        self.index_special_tokens = {}  # id -> token str
        for token in special_tokens:
            self.special_tokens[token] = self.n_words
            self.index_special_tokens[self.n_words] = token
            self.n_words += 1
        # regex alternation matching any role token literally
        self.role_special_token_expression = "|".join([re.escape(token) for token in role_special_tokens])

    def tokenize(self, s: str, encode_special_tokens=False):
        """Split into pieces; optionally keep role tokens as single pieces
        instead of letting SentencePiece split them apart."""
        if encode_special_tokens:
            last_index = 0
            t = []
            for match in re.finditer(self.role_special_token_expression, s):
                if last_index < match.start():
                    t.extend(self.sp_model.EncodeAsPieces(s[last_index:match.start()]))
                t.append(s[match.start():match.end()])
                last_index = match.end()
            if last_index < len(s):
                t.extend(self.sp_model.EncodeAsPieces(s[last_index:]))
            return t
        else:
            return self.sp_model.EncodeAsPieces(s)

    def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
        """Encode text to ids, optionally wrapping with BOS/EOS."""
        assert type(s) is str
        t = self.sp_model.encode(s)
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, t: List[int]) -> str:
        """Decode ids to text, emitting special tokens literally while
        batching ordinary ids through SentencePiece."""
        text, buffer = "", []
        for token in t:
            if token in self.index_special_tokens:
                if buffer:
                    text += self.sp_model.decode(buffer)
                    buffer = []
                text += self.index_special_tokens[token]
            else:
                buffer.append(token)
        if buffer:
            text += self.sp_model.decode(buffer)
        return text

    def decode_tokens(self, tokens: List[str]) -> str:
        """Decode a list of piece strings back into text."""
        text = self.sp_model.DecodePieces(tokens)
        return text

    def convert_token_to_id(self, token):
        """ Converts a token (str) in an id using the vocab. """
        if token in self.special_tokens:
            return self.special_tokens[token]
        return self.sp_model.PieceToId(token)

    def convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        if index in self.index_special_tokens:
            return self.index_special_tokens[index]
        # bos/eos/pad and invalid ids render as empty string
        if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
            return ""
        return self.sp_model.IdToPiece(index)
|
||||
|
||||
|
||||
class ChatGLMTokenizer(PreTrainedTokenizer):
    """Hugging Face tokenizer for ChatGLM, backed by :class:`SPTokenizer`.

    Pads on the left and exposes ``position_ids`` as a model input.
    """

    vocab_files_names = {"vocab_file": "tokenizer.model"}

    model_input_names = ["input_ids", "attention_mask", "position_ids"]

    def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, encode_special_tokens=False,
                 **kwargs):
        self.name = "GLMTokenizer"

        self.vocab_file = vocab_file
        self.tokenizer = SPTokenizer(vocab_file)
        # command tokens resolved by get_command before falling back to the
        # SP-level special tokens
        self.special_tokens = {
            "<bos>": self.tokenizer.bos_id,
            "<eos>": self.tokenizer.eos_id,
            "<pad>": self.tokenizer.pad_id
        }
        self.encode_special_tokens = encode_special_tokens
        super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                         encode_special_tokens=encode_special_tokens,
                         **kwargs)

    def get_command(self, token):
        """Return the id of a command token such as "<eos>" or "[gMASK]"."""
        if token in self.special_tokens:
            return self.special_tokens[token]
        assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
        return self.tokenizer.special_tokens[token]

    @property
    def unk_token(self) -> str:
        return "<unk>"

    @property
    def pad_token(self) -> str:
        # NOTE(review): pad_token is "<unk>" while pad_token_id maps through
        # "<pad>" (the unk id) — consistent ids, but confirm intent upstream
        return "<unk>"

    @property
    def pad_token_id(self):
        return self.get_command("<pad>")

    @property
    def eos_token(self) -> str:
        return "</s>"

    @property
    def eos_token_id(self):
        return self.get_command("<eos>")

    @property
    def vocab_size(self):
        # includes the appended special tokens
        return self.tokenizer.n_words

    def get_vocab(self):
        """ Returns vocab as a dict """
        vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text, **kwargs):
        return self.tokenizer.tokenize(text, encode_special_tokens=self.encode_special_tokens)

    def _convert_token_to_id(self, token):
        """ Converts a token (str) in an id using the vocab. """
        return self.tokenizer.convert_token_to_id(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.tokenizer.convert_id_to_token(index)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        return self.tokenizer.decode_tokens(tokens)

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """
        Save the vocabulary and special tokens file to a directory.

        Args:
            save_directory (`str`):
                The directory in which to save the vocabulary.
            filename_prefix (`str`, *optional*):
                An optional prefix to add to the named of the saved files.

        Returns:
            `Tuple(str)`: Paths to the files saved.
        """
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
                save_directory, self.vocab_files_names["vocab_file"]
            )
        else:
            vocab_file = save_directory

        # copy the raw SentencePiece model file byte-for-byte
        with open(self.vocab_file, 'rb') as fin:
            proto_str = fin.read()

        with open(vocab_file, "wb") as writer:
            writer.write(proto_str)

        return (vocab_file,)

    def get_prefix_tokens(self):
        """Tokens prepended to every encoded sequence: [gMASK] sop."""
        prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
        return prefix_tokens

    def build_single_message(self, role, metadata, message):
        """Encode one chat turn as <|role|> metadata\\n message."""
        assert role in ["system", "user", "assistant", "observation"], role
        role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n")
        message_tokens = self.tokenizer.encode(message)
        tokens = role_tokens + message_tokens
        return tokens

    def build_chat_input(self, query, history=None, role="user"):
        """Encode a chat history plus the current query into model inputs,
        ending with the <|assistant|> token to prompt generation."""
        if history is None:
            history = []
        input_ids = []
        for item in history:
            content = item["content"]
            # system turns may carry a tool list serialized as JSON
            if item["role"] == "system" and "tools" in item:
                content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
            input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content))
        input_ids.extend(self.build_single_message(role, "", query))
        input_ids.extend([self.get_command("<|assistant|>")])
        return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)

    def build_inputs_with_special_tokens(
            self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A BERT sequence has the following format:

        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        prefix_tokens = self.get_prefix_tokens()
        token_ids_0 = prefix_tokens + token_ids_0
        if token_ids_1 is not None:
            token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("<eos>")]
        return token_ids_0

    def _pad(
            self,
            encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
            max_length: Optional[int] = None,
            padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
            pad_to_multiple_of: Optional[int] = None,
            return_attention_mask: Optional[bool] = None,
            **kwargs
    ) -> dict:
        """
        Pad encoded inputs (on left/right and up to predefined length or max length in the batch)

        Args:
            encoded_inputs:
                Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
            max_length: maximum length of the returned list and optionally padding length (see below).
                Will truncate by taking into account the special tokens.
            padding_strategy: PaddingStrategy to use for padding.

                - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
                - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
                - PaddingStrategy.DO_NOT_PAD: Do not pad
                The tokenizer padding sides are defined in self.padding_side:

                    - 'left': pads on the left of the sequences
                    - 'right': pads on the right of the sequences
            pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
                This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
                `>= 7.5` (Volta).
            return_attention_mask:
                (optional) Set to False to avoid returning attention mask (default: set to model specifics)
        """
        # Load from model defaults
        # this implementation only supports left padding
        assert self.padding_side == "left"

        required_input = encoded_inputs[self.model_input_names[0]]
        seq_length = len(required_input)

        if padding_strategy == PaddingStrategy.LONGEST:
            max_length = len(required_input)

        # round max_length up to the next multiple when requested
        if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of

        needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length

        # Initialize attention mask if not present.
        if "attention_mask" not in encoded_inputs:
            encoded_inputs["attention_mask"] = [1] * seq_length

        if "position_ids" not in encoded_inputs:
            encoded_inputs["position_ids"] = list(range(seq_length))

        if needs_to_be_padded:
            difference = max_length - len(required_input)

            # left-pad every model input with zeros / pad ids
            if "attention_mask" in encoded_inputs:
                encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
            if "position_ids" in encoded_inputs:
                encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
            encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input

        return encoded_inputs
|
||||
Binary file not shown.
@@ -0,0 +1,12 @@
|
||||
{
|
||||
"name_or_path": "THUDM/chatglm3-6b-base",
|
||||
"remove_space": false,
|
||||
"do_lower_case": false,
|
||||
"tokenizer_class": "ChatGLMTokenizer",
|
||||
"auto_map": {
|
||||
"AutoTokenizer": [
|
||||
"tokenization_chatglm.ChatGLMTokenizer",
|
||||
null
|
||||
]
|
||||
}
|
||||
}
|
||||
Binary file not shown.
@@ -0,0 +1,18 @@
|
||||
{
|
||||
"attention_dropout": 0.0,
|
||||
"dropout": 0.0,
|
||||
"hidden_act": "quick_gelu",
|
||||
"hidden_size": 1024,
|
||||
"image_size": 336,
|
||||
"initializer_factor": 1.0,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 4096,
|
||||
"layer_norm_eps": 1e-05,
|
||||
"model_type": "clip_vision_model",
|
||||
"num_attention_heads": 16,
|
||||
"num_channels": 3,
|
||||
"num_hidden_layers": 24,
|
||||
"patch_size": 14,
|
||||
"projection_dim": 768,
|
||||
"torch_dtype": "float32"
|
||||
}
|
||||
303
custom_nodes/ComfyUI-Easy-Use/py/modules/kolors/loader.py
Normal file
303
custom_nodes/ComfyUI-Easy-Use/py/modules/kolors/loader.py
Normal file
@@ -0,0 +1,303 @@
|
||||
import json
|
||||
import os
|
||||
import torch
|
||||
import subprocess
|
||||
import sys
|
||||
import comfy.supported_models
|
||||
import comfy.model_patcher
|
||||
import comfy.model_management
|
||||
import comfy.model_detection as model_detection
|
||||
import comfy.model_base as model_base
|
||||
from comfy.model_base import sdxl_pooled, CLIPEmbeddingNoiseAugmentation, Timestep, ModelType
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel
|
||||
from comfy.clip_vision import ClipVisionModel, Output
|
||||
from comfy.utils import load_torch_file
|
||||
from .chatglm.modeling_chatglm import ChatGLMModel, ChatGLMConfig
|
||||
from .chatglm.tokenization_chatglm import ChatGLMTokenizer
|
||||
|
||||
class KolorsUNetModel(UNetModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.encoder_hid_proj = torch.nn.Linear(4096, 2048, bias=True)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
with torch.cuda.amp.autocast(enabled=True):
|
||||
if "context" in kwargs:
|
||||
kwargs["context"] = self.encoder_hid_proj(kwargs["context"])
|
||||
result = super().forward(*args, **kwargs)
|
||||
return result
|
||||
|
||||
class KolorsSDXL(model_base.SDXL):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||
model_base.BaseModel.__init__(self, model_config, model_type, device=device, unet_model=KolorsUNetModel)
|
||||
self.embedder = Timestep(256)
|
||||
self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**{"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 1280})
|
||||
|
||||
def encode_adm(self, **kwargs):
|
||||
clip_pooled = sdxl_pooled(kwargs, self.noise_augmentor)
|
||||
width = kwargs.get("width", 768)
|
||||
height = kwargs.get("height", 768)
|
||||
crop_w = kwargs.get("crop_w", 0)
|
||||
crop_h = kwargs.get("crop_h", 0)
|
||||
target_width = kwargs.get("target_width", width)
|
||||
target_height = kwargs.get("target_height", height)
|
||||
|
||||
out = []
|
||||
out.append(self.embedder(torch.Tensor([height])))
|
||||
out.append(self.embedder(torch.Tensor([width])))
|
||||
out.append(self.embedder(torch.Tensor([crop_h])))
|
||||
out.append(self.embedder(torch.Tensor([crop_w])))
|
||||
out.append(self.embedder(torch.Tensor([target_height])))
|
||||
out.append(self.embedder(torch.Tensor([target_width])))
|
||||
flat = torch.flatten(torch.cat(out)).unsqueeze(
|
||||
dim=0).repeat(clip_pooled.shape[0], 1)
|
||||
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
|
||||
|
||||
class Kolors(comfy.supported_models.SDXL):
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
"use_linear_in_transformer": True,
|
||||
"transformer_depth": [0, 0, 2, 2, 10, 10],
|
||||
"context_dim": 2048,
|
||||
"adm_in_channels": 5632,
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = KolorsSDXL(self, model_type=self.model_type(state_dict, prefix), device=device, )
|
||||
out.__class__ = model_base.SDXL
|
||||
if self.inpaint_model():
|
||||
out.set_inpaint()
|
||||
return out
|
||||
|
||||
def kolors_unet_config_from_diffusers_unet(state_dict, dtype=None):
|
||||
match = {}
|
||||
transformer_depth = []
|
||||
|
||||
attn_res = 1
|
||||
count_blocks = model_detection.count_blocks
|
||||
down_blocks = count_blocks(state_dict, "down_blocks.{}")
|
||||
for i in range(down_blocks):
|
||||
attn_blocks = count_blocks(
|
||||
state_dict, "down_blocks.{}.attentions.".format(i) + '{}')
|
||||
res_blocks = count_blocks(
|
||||
state_dict, "down_blocks.{}.resnets.".format(i) + '{}')
|
||||
for ab in range(attn_blocks):
|
||||
transformer_count = count_blocks(
|
||||
state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}')
|
||||
transformer_depth.append(transformer_count)
|
||||
if transformer_count > 0:
|
||||
match["context_dim"] = state_dict["down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(
|
||||
i, ab)].shape[1]
|
||||
|
||||
attn_res *= 2
|
||||
if attn_blocks == 0:
|
||||
for i in range(res_blocks):
|
||||
transformer_depth.append(0)
|
||||
|
||||
match["transformer_depth"] = transformer_depth
|
||||
|
||||
match["model_channels"] = state_dict["conv_in.weight"].shape[0]
|
||||
match["in_channels"] = state_dict["conv_in.weight"].shape[1]
|
||||
match["adm_in_channels"] = None
|
||||
if "class_embedding.linear_1.weight" in state_dict:
|
||||
match["adm_in_channels"] = state_dict["class_embedding.linear_1.weight"].shape[1]
|
||||
elif "add_embedding.linear_1.weight" in state_dict:
|
||||
match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1]
|
||||
|
||||
Kolors = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 5632, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
|
||||
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
Kolors_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True,
|
||||
'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 5632, 'dtype': dtype, 'in_channels': 9,
|
||||
'model_channels': 320,
|
||||
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4],
|
||||
'transformer_depth_middle': 10,
|
||||
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
|
||||
'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
Kolors_ip2p = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True,
|
||||
'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 5632, 'dtype': dtype, 'in_channels': 8,
|
||||
'model_channels': 320,
|
||||
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4],
|
||||
'transformer_depth_middle': 10,
|
||||
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
|
||||
'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True,
|
||||
'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4,
|
||||
'model_channels': 320,
|
||||
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4],
|
||||
'transformer_depth_middle': 10,
|
||||
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
|
||||
'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SDXL_mid_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True,
|
||||
'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4,
|
||||
'model_channels': 320,
|
||||
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 1, 1], 'channel_mult': [1, 2, 4],
|
||||
'transformer_depth_middle': 1,
|
||||
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
|
||||
'transformer_depth_output': [0, 0, 0, 0, 0, 0, 1, 1, 1],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SDXL_small_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True,
|
||||
'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4,
|
||||
'model_channels': 320,
|
||||
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 0, 0], 'channel_mult': [1, 2, 4],
|
||||
'transformer_depth_middle': 0,
|
||||
'use_linear_in_transformer': True, 'num_head_channels': 64, 'context_dim': 1,
|
||||
'transformer_depth_output': [0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
supported_models = [Kolors, Kolors_inpaint,
|
||||
Kolors_ip2p, SDXL, SDXL_mid_cnet, SDXL_small_cnet]
|
||||
|
||||
|
||||
for unet_config in supported_models:
|
||||
matches = True
|
||||
for k in match:
|
||||
if match[k] != unet_config[k]:
|
||||
# print("key {} does not match".format(k), match[k], "||", unet_config[k])
|
||||
matches = False
|
||||
break
|
||||
if matches:
|
||||
return model_detection.convert_config(unet_config)
|
||||
return None
|
||||
|
||||
# chatglm3 model
|
||||
class chatGLM3Model(torch.nn.Module):
|
||||
def __init__(self, textmodel_json_config=None, device='cpu', offload_device='cpu', model_path=None):
|
||||
super().__init__()
|
||||
if model_path is None:
|
||||
raise ValueError("model_path is required")
|
||||
self.device = device
|
||||
if textmodel_json_config is None:
|
||||
textmodel_json_config = os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)),
|
||||
"chatglm",
|
||||
"config_chatglm.json"
|
||||
)
|
||||
with open(textmodel_json_config, 'r') as file:
|
||||
config = json.load(file)
|
||||
textmodel_json_config = ChatGLMConfig(**config)
|
||||
is_accelerate_available = False
|
||||
try:
|
||||
from accelerate import init_empty_weights
|
||||
from accelerate.utils import set_module_tensor_to_device
|
||||
is_accelerate_available = True
|
||||
except:
|
||||
pass
|
||||
|
||||
from contextlib import nullcontext
|
||||
with (init_empty_weights() if is_accelerate_available else nullcontext()):
|
||||
with torch.no_grad():
|
||||
print('torch version:', torch.__version__)
|
||||
self.text_encoder = ChatGLMModel(textmodel_json_config).eval()
|
||||
if '4bit' in model_path:
|
||||
try:
|
||||
import cpm_kernels
|
||||
except ImportError:
|
||||
print("Installing cpm_kernels...")
|
||||
subprocess.run([sys.executable, "-m", "pip", "install", "cpm_kernels"], check=True)
|
||||
pass
|
||||
self.text_encoder.quantize(4)
|
||||
elif '8bit' in model_path:
|
||||
self.text_encoder.quantize(8)
|
||||
|
||||
sd = load_torch_file(model_path)
|
||||
if is_accelerate_available:
|
||||
for key in sd:
|
||||
set_module_tensor_to_device(self.text_encoder, key, device=offload_device, value=sd[key])
|
||||
else:
|
||||
print("WARNING: Accelerate not available, use load_state_dict load model")
|
||||
self.text_encoder.load_state_dict()
|
||||
|
||||
def load_chatglm3(model_path=None):
|
||||
if model_path is None:
|
||||
return
|
||||
|
||||
load_device = comfy.model_management.text_encoder_device()
|
||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||
|
||||
glm3model = chatGLM3Model(
|
||||
device=load_device,
|
||||
offload_device=offload_device,
|
||||
model_path=model_path
|
||||
)
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'chatglm', "tokenizer")
|
||||
tokenizer = ChatGLMTokenizer.from_pretrained(tokenizer_path)
|
||||
text_encoder = glm3model.text_encoder
|
||||
return {"text_encoder":text_encoder, "tokenizer":tokenizer}
|
||||
|
||||
|
||||
# clipvision model
|
||||
def load_clipvision_vitl_336(path):
|
||||
sd = load_torch_file(path)
|
||||
if "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
|
||||
else:
|
||||
raise Exception("Unsupported clip vision model")
|
||||
clip = ClipVisionModel(json_config)
|
||||
m, u = clip.load_sd(sd)
|
||||
if len(m) > 0:
|
||||
print("missing clip vision: {}".format(m))
|
||||
u = set(u)
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
if k not in u:
|
||||
t = sd.pop(k)
|
||||
del t
|
||||
return clip
|
||||
|
||||
class applyKolorsUnet:
|
||||
def __enter__(self):
|
||||
import comfy.ldm.modules.diffusionmodules.openaimodel
|
||||
import comfy.utils
|
||||
import comfy.clip_vision
|
||||
|
||||
self.original_UNET_MAP_BASIC = comfy.utils.UNET_MAP_BASIC.copy()
|
||||
comfy.utils.UNET_MAP_BASIC.add(("encoder_hid_proj.weight", "encoder_hid_proj.weight"),)
|
||||
comfy.utils.UNET_MAP_BASIC.add(("encoder_hid_proj.bias", "encoder_hid_proj.bias"),)
|
||||
|
||||
self.original_unet_config_from_diffusers_unet = model_detection.unet_config_from_diffusers_unet
|
||||
model_detection.unet_config_from_diffusers_unet = kolors_unet_config_from_diffusers_unet
|
||||
|
||||
import comfy.supported_models
|
||||
self.original_supported_models = comfy.supported_models.models
|
||||
comfy.supported_models.models = [Kolors]
|
||||
|
||||
self.original_load_clipvision_from_sd = comfy.clip_vision.load_clipvision_from_sd
|
||||
comfy.clip_vision.load_clipvision_from_sd = load_clipvision_vitl_336
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
import comfy.ldm.modules.diffusionmodules.openaimodel
|
||||
import comfy.utils
|
||||
import comfy.supported_models
|
||||
import comfy.clip_vision
|
||||
|
||||
comfy.utils.UNET_MAP_BASIC = self.original_UNET_MAP_BASIC
|
||||
|
||||
model_detection.unet_config_from_diffusers_unet = self.original_unet_config_from_diffusers_unet
|
||||
comfy.supported_models.models = self.original_supported_models
|
||||
|
||||
comfy.clip_vision.load_clipvision_from_sd = self.original_load_clipvision_from_sd
|
||||
|
||||
|
||||
def is_kolors_model(model):
|
||||
unet_config = model.model.model_config.unet_config if hasattr(model, 'model') else None
|
||||
if unet_config and "adm_in_channels" in unet_config and unet_config["adm_in_channels"] == 5632:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
@@ -0,0 +1,66 @@
|
||||
import torch
|
||||
from torch.nn import Linear
|
||||
from types import MethodType
|
||||
import comfy.model_management
|
||||
import comfy.samplers
|
||||
from comfy.cldm.cldm import ControlNet
|
||||
from comfy.controlnet import ControlLora
|
||||
|
||||
def patch_controlnet(model, control_net):
|
||||
import comfy.controlnet
|
||||
if isinstance(control_net, ControlLora):
|
||||
del_keys = []
|
||||
for k in control_net.control_weights:
|
||||
if k.startswith("label_emb.0.0."):
|
||||
del_keys.append(k)
|
||||
|
||||
for k in del_keys:
|
||||
control_net.control_weights.pop(k)
|
||||
|
||||
super_pre_run = ControlLora.pre_run
|
||||
super_copy = ControlLora.copy
|
||||
|
||||
super_forward = ControlNet.forward
|
||||
|
||||
def KolorsControlNet_forward(self, x, hint, timesteps, context, **kwargs):
|
||||
with torch.cuda.amp.autocast(enabled=True):
|
||||
context = model.model.diffusion_model.encoder_hid_proj(context)
|
||||
return super_forward(self, x, hint, timesteps, context, **kwargs)
|
||||
|
||||
def KolorsControlLora_pre_run(self, *args, **kwargs):
|
||||
result = super_pre_run(self, *args, **kwargs)
|
||||
|
||||
if hasattr(self, "control_model"):
|
||||
self.control_model.forward = MethodType(
|
||||
KolorsControlNet_forward, self.control_model)
|
||||
return result
|
||||
|
||||
control_net.pre_run = MethodType(
|
||||
KolorsControlLora_pre_run, control_net)
|
||||
|
||||
def KolorsControlLora_copy(self, *args, **kwargs):
|
||||
c = super_copy(self, *args, **kwargs)
|
||||
c.pre_run = MethodType(
|
||||
KolorsControlLora_pre_run, c)
|
||||
return c
|
||||
|
||||
control_net.copy = MethodType(KolorsControlLora_copy, control_net)
|
||||
|
||||
elif isinstance(control_net, comfy.controlnet.ControlNet):
|
||||
model_label_emb = model.model.diffusion_model.label_emb
|
||||
control_net.control_model.label_emb = model_label_emb
|
||||
control_net.control_model_wrapped.model.label_emb = model_label_emb
|
||||
super_forward = ControlNet.forward
|
||||
|
||||
def KolorsControlNet_forward(self, x, hint, timesteps, context, **kwargs):
|
||||
with torch.cuda.amp.autocast(enabled=True):
|
||||
context = model.model.diffusion_model.encoder_hid_proj(context)
|
||||
return super_forward(self, x, hint, timesteps, context, **kwargs)
|
||||
|
||||
control_net.control_model.forward = MethodType(
|
||||
KolorsControlNet_forward, control_net.control_model)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"Type {control_net} not supported for KolorsControlNetPatch")
|
||||
|
||||
return control_net
|
||||
105
custom_nodes/ComfyUI-Easy-Use/py/modules/kolors/text_encode.py
Normal file
105
custom_nodes/ComfyUI-Easy-Use/py/modules/kolors/text_encode.py
Normal file
@@ -0,0 +1,105 @@
|
||||
import re
|
||||
import random
|
||||
import gc
|
||||
import comfy.model_management as mm
|
||||
from nodes import ConditioningConcat, ConditioningZeroOut, ConditioningSetTimestepRange, ConditioningCombine
|
||||
|
||||
def chatglm3_text_encode(chatglm3_model, prompt, clean_gpu=False):
|
||||
device = mm.get_torch_device()
|
||||
offload_device = mm.unet_offload_device()
|
||||
if clean_gpu:
|
||||
mm.unload_all_models()
|
||||
mm.soft_empty_cache()
|
||||
# Function to randomly select an option from the brackets
|
||||
|
||||
def choose_random_option(match):
|
||||
options = match.group(1).split('|')
|
||||
return random.choice(options)
|
||||
|
||||
prompt = re.sub(r'\{([^{}]*)\}', choose_random_option, prompt)
|
||||
|
||||
if "|" in prompt:
|
||||
prompt = prompt.split("|")
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
|
||||
# Define tokenizers and text encoders
|
||||
tokenizer = chatglm3_model['tokenizer']
|
||||
text_encoder = chatglm3_model['text_encoder']
|
||||
text_encoder.to(device)
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=256,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
).to(device)
|
||||
|
||||
output = text_encoder(
|
||||
input_ids=text_inputs['input_ids'],
|
||||
attention_mask=text_inputs['attention_mask'],
|
||||
position_ids=text_inputs['position_ids'],
|
||||
output_hidden_states=True)
|
||||
|
||||
# [batch_size, 77, 4096]
|
||||
prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
|
||||
text_proj = output.hidden_states[-1][-1, :, :].clone() # [batch_size, 4096]
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, 1, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
|
||||
|
||||
bs_embed = text_proj.shape[0]
|
||||
text_proj = text_proj.repeat(1, 1).view(bs_embed, -1)
|
||||
text_encoder.to(offload_device)
|
||||
if clean_gpu:
|
||||
mm.soft_empty_cache()
|
||||
gc.collect()
|
||||
return [[prompt_embeds, {"pooled_output": text_proj},]]
|
||||
|
||||
def chatglm3_adv_text_encode(chatglm3_model, text, clean_gpu=False):
|
||||
time_start = 0
|
||||
time_end = 1
|
||||
match = re.search(r'TIMESTEP.*$', text)
|
||||
if match:
|
||||
timestep = match.group()
|
||||
timestep = timestep.split(' ')
|
||||
timestep = timestep[0]
|
||||
text = text.replace(timestep, '')
|
||||
value = timestep.split(':')
|
||||
if len(value) >= 3:
|
||||
time_start = float(value[1])
|
||||
time_end = float(value[2])
|
||||
elif len(value) == 2:
|
||||
time_start = float(value[1])
|
||||
time_end = 1
|
||||
elif len(value) == 1:
|
||||
time_start = 0.1
|
||||
time_end = 1
|
||||
|
||||
|
||||
pass3 = [x.strip() for x in text.split("BREAK")]
|
||||
pass3 = [x for x in pass3 if x != '']
|
||||
|
||||
if len(pass3) == 0:
|
||||
pass3 = ['']
|
||||
|
||||
conditioning = None
|
||||
|
||||
for text in pass3:
|
||||
cond = chatglm3_text_encode(chatglm3_model, text, clean_gpu)
|
||||
if conditioning is not None:
|
||||
conditioning = ConditioningConcat().concat(conditioning, cond)[0]
|
||||
else:
|
||||
conditioning = cond
|
||||
|
||||
# setTimeStepRange
|
||||
if time_start > 0 or time_end < 1:
|
||||
conditioning_2, = ConditioningSetTimestepRange().set_range(conditioning, 0, time_start)
|
||||
conditioning_1, = ConditioningZeroOut().zero_out(conditioning)
|
||||
conditioning_1, = ConditioningSetTimestepRange().set_range(conditioning_1, time_start, time_end)
|
||||
conditioning, = ConditioningCombine().combine(conditioning_1, conditioning_2)
|
||||
|
||||
return conditioning
|
||||
@@ -0,0 +1,219 @@
|
||||
#credit to huchenlei for this module
|
||||
#from https://github.com/huchenlei/ComfyUI-layerdiffuse
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import comfy.lora
|
||||
import copy
|
||||
from typing import Optional
|
||||
from enum import Enum
|
||||
from comfy.utils import load_torch_file
|
||||
from comfy.conds import CONDRegular
|
||||
from comfy_extras.nodes_compositing import JoinImageWithAlpha
|
||||
try:
|
||||
from .model import ModelPatcher, TransparentVAEDecoder, calculate_weight_adjust_channel
|
||||
except:
|
||||
ModelPatcher, TransparentVAEDecoder, calculate_weight_adjust_channel = None, None, None
|
||||
from .attension_sharing import AttentionSharingPatcher
|
||||
from ...config import LAYER_DIFFUSION, LAYER_DIFFUSION_DIR, LAYER_DIFFUSION_VAE
|
||||
from ...libs.utils import to_lora_patch_dict, get_local_filepath, get_sd_version
|
||||
|
||||
load_layer_model_state_dict = load_torch_file
|
||||
class LayerMethod(Enum):
|
||||
FG_ONLY_ATTN = "Attention Injection"
|
||||
FG_ONLY_CONV = "Conv Injection"
|
||||
FG_TO_BLEND = "Foreground"
|
||||
FG_BLEND_TO_BG = "Foreground to Background"
|
||||
BG_TO_BLEND = "Background"
|
||||
BG_BLEND_TO_FG = "Background to Foreground"
|
||||
EVERYTHING = "Everything"
|
||||
|
||||
class LayerDiffuse:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.vae_transparent_decoder = None
|
||||
self.frames = 1
|
||||
|
||||
def get_layer_diffusion_method(self, method, has_blend_latent):
|
||||
method = LayerMethod(method)
|
||||
if method == LayerMethod.BG_TO_BLEND and has_blend_latent:
|
||||
method = LayerMethod.BG_BLEND_TO_FG
|
||||
elif method == LayerMethod.FG_TO_BLEND and has_blend_latent:
|
||||
method = LayerMethod.FG_BLEND_TO_BG
|
||||
return method
|
||||
|
||||
def apply_layer_c_concat(self, cond, uncond, c_concat):
|
||||
def write_c_concat(cond):
|
||||
new_cond = []
|
||||
for t in cond:
|
||||
n = [t[0], t[1].copy()]
|
||||
if "model_conds" not in n[1]:
|
||||
n[1]["model_conds"] = {}
|
||||
n[1]["model_conds"]["c_concat"] = CONDRegular(c_concat)
|
||||
new_cond.append(n)
|
||||
return new_cond
|
||||
|
||||
return (write_c_concat(cond), write_c_concat(uncond))
|
||||
|
||||
def apply_layer_diffusion(self, model, method, weight, samples, blend_samples, positive, negative, image=None, additional_cond=(None, None, None)):
|
||||
control_img: Optional[torch.TensorType] = None
|
||||
sd_version = get_sd_version(model)
|
||||
model_url = LAYER_DIFFUSION[method.value][sd_version]["model_url"]
|
||||
|
||||
if image is not None:
|
||||
image = image.movedim(-1, 1)
|
||||
|
||||
try:
|
||||
if hasattr(comfy.lora, "calculate_weight"):
|
||||
comfy.lora.calculate_weight = calculate_weight_adjust_channel(comfy.lora.calculate_weight)
|
||||
else:
|
||||
ModelPatcher.calculate_weight = calculate_weight_adjust_channel(ModelPatcher.calculate_weight)
|
||||
except:
|
||||
pass
|
||||
|
||||
if method in [LayerMethod.FG_ONLY_CONV, LayerMethod.FG_ONLY_ATTN] and sd_version == 'sd1':
|
||||
self.frames = 1
|
||||
elif method in [LayerMethod.BG_TO_BLEND, LayerMethod.FG_TO_BLEND, LayerMethod.BG_BLEND_TO_FG, LayerMethod.FG_BLEND_TO_BG] and sd_version == 'sd1':
|
||||
self.frames = 2
|
||||
batch_size, _, height, width = samples['samples'].shape
|
||||
if batch_size % 2 != 0:
|
||||
raise Exception(f"The batch size should be a multiple of 2. 批次大小需为2的倍数")
|
||||
control_img = image
|
||||
elif method == LayerMethod.EVERYTHING and sd_version == 'sd1':
|
||||
batch_size, _, height, width = samples['samples'].shape
|
||||
self.frames = 3
|
||||
if batch_size % 3 != 0:
|
||||
raise Exception(f"The batch size should be a multiple of 3. 批次大小需为3的倍数")
|
||||
if model_url is None:
|
||||
raise Exception(f"{method.value} is not supported for {sd_version} model")
|
||||
|
||||
model_path = get_local_filepath(model_url, LAYER_DIFFUSION_DIR)
|
||||
layer_lora_state_dict = load_layer_model_state_dict(model_path)
|
||||
work_model = model.clone()
|
||||
if sd_version == 'sd1':
|
||||
patcher = AttentionSharingPatcher(
|
||||
work_model, self.frames, use_control=control_img is not None
|
||||
)
|
||||
patcher.load_state_dict(layer_lora_state_dict, strict=True)
|
||||
if control_img is not None:
|
||||
patcher.set_control(control_img)
|
||||
else:
|
||||
layer_lora_patch_dict = to_lora_patch_dict(layer_lora_state_dict)
|
||||
work_model.add_patches(layer_lora_patch_dict, weight)
|
||||
|
||||
# cond_contact
|
||||
if method in [LayerMethod.FG_ONLY_ATTN, LayerMethod.FG_ONLY_CONV]:
|
||||
samp_model = work_model
|
||||
elif sd_version == 'sdxl':
|
||||
if method in [LayerMethod.BG_TO_BLEND, LayerMethod.FG_TO_BLEND]:
|
||||
c_concat = model.model.latent_format.process_in(samples["samples"])
|
||||
else:
|
||||
c_concat = model.model.latent_format.process_in(torch.cat([samples["samples"], blend_samples["samples"]], dim=1))
|
||||
samp_model, positive, negative = (work_model,) + self.apply_layer_c_concat(positive, negative, c_concat)
|
||||
elif sd_version == 'sd1':
|
||||
if method in [LayerMethod.BG_TO_BLEND, LayerMethod.BG_BLEND_TO_FG]:
|
||||
additional_cond = (additional_cond[0], None)
|
||||
elif method in [LayerMethod.FG_TO_BLEND, LayerMethod.FG_BLEND_TO_BG]:
|
||||
additional_cond = (additional_cond[1], None)
|
||||
|
||||
work_model.model_options.setdefault("transformer_options", {})
|
||||
work_model.model_options["transformer_options"]["cond_overwrite"] = [
|
||||
cond[0][0] if cond is not None else None
|
||||
for cond in additional_cond
|
||||
]
|
||||
samp_model = work_model
|
||||
|
||||
return samp_model, positive, negative
|
||||
|
||||
def join_image_with_alpha(self, image, alpha):
|
||||
out = image.movedim(-1, 1)
|
||||
if out.shape[1] == 3: # RGB
|
||||
out = torch.cat([out, torch.ones_like(out[:, :1, :, :])], dim=1)
|
||||
for i in range(out.shape[0]):
|
||||
out[i, 3, :, :] = alpha
|
||||
return out.movedim(1, -1)
|
||||
|
||||
def image_to_alpha(self, image, latent):
|
||||
pixel = image.movedim(-1, 1) # [B, H, W, C] => [B, C, H, W]
|
||||
decoded = []
|
||||
sub_batch_size = 16
|
||||
for start_idx in range(0, latent.shape[0], sub_batch_size):
|
||||
decoded.append(
|
||||
self.vae_transparent_decoder.decode_pixel(
|
||||
pixel[start_idx: start_idx + sub_batch_size],
|
||||
latent[start_idx: start_idx + sub_batch_size],
|
||||
)
|
||||
)
|
||||
pixel_with_alpha = torch.cat(decoded, dim=0)
|
||||
# [B, C, H, W] => [B, H, W, C]
|
||||
pixel_with_alpha = pixel_with_alpha.movedim(1, -1)
|
||||
image = pixel_with_alpha[..., 1:]
|
||||
alpha = pixel_with_alpha[..., 0]
|
||||
|
||||
alpha = 1.0 - alpha
|
||||
try:
|
||||
new_images, = JoinImageWithAlpha().execute(image, alpha)
|
||||
except:
|
||||
new_images, = JoinImageWithAlpha().join_image_with_alpha(image, alpha)
|
||||
return new_images, alpha
|
||||
|
||||
def make_3d_mask(self, mask):
|
||||
if len(mask.shape) == 4:
|
||||
return mask.squeeze(0)
|
||||
|
||||
elif len(mask.shape) == 2:
|
||||
return mask.unsqueeze(0)
|
||||
|
||||
return mask
|
||||
|
||||
def masks_to_list(self, masks):
|
||||
if masks is None:
|
||||
empty_mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
|
||||
return ([empty_mask],)
|
||||
|
||||
res = []
|
||||
|
||||
for mask in masks:
|
||||
res.append(mask)
|
||||
|
||||
return [self.make_3d_mask(x) for x in res]
|
||||
|
||||
def layer_diffusion_decode(self, layer_diffusion_method, latent, blend_samples, samp_images, model):
|
||||
alpha = []
|
||||
if layer_diffusion_method is not None:
|
||||
sd_version = get_sd_version(model)
|
||||
if sd_version not in ['sdxl', 'sd1']:
|
||||
raise Exception(f"Only SDXL and SD1.5 model supported for Layer Diffusion")
|
||||
method = self.get_layer_diffusion_method(layer_diffusion_method, blend_samples is not None)
|
||||
sd15_allow = True if sd_version == 'sd1' and method in [LayerMethod.FG_ONLY_ATTN, LayerMethod.EVERYTHING, LayerMethod.BG_TO_BLEND, LayerMethod.BG_BLEND_TO_FG] else False
|
||||
sdxl_allow = True if sd_version == 'sdxl' and method in [LayerMethod.FG_ONLY_CONV, LayerMethod.FG_ONLY_ATTN, LayerMethod.BG_BLEND_TO_FG] else False
|
||||
if sdxl_allow or sd15_allow:
|
||||
if self.vae_transparent_decoder is None:
|
||||
model_url = LAYER_DIFFUSION_VAE['decode'][sd_version]["model_url"]
|
||||
if model_url is None:
|
||||
raise Exception(f"{method.value} is not supported for {sd_version} model")
|
||||
decoder_file = get_local_filepath(model_url, LAYER_DIFFUSION_DIR)
|
||||
self.vae_transparent_decoder = TransparentVAEDecoder(
|
||||
load_torch_file(decoder_file),
|
||||
device=comfy.model_management.get_torch_device(),
|
||||
dtype=(torch.float16 if comfy.model_management.should_use_fp16() else torch.float32),
|
||||
)
|
||||
if method in [LayerMethod.EVERYTHING, LayerMethod.BG_BLEND_TO_FG, LayerMethod.BG_TO_BLEND]:
|
||||
new_images = []
|
||||
sliced_samples = copy.copy({"samples": latent})
|
||||
for index in range(len(samp_images)):
|
||||
if index % self.frames == 0:
|
||||
img = samp_images[index::self.frames]
|
||||
alpha_images, _alpha = self.image_to_alpha(img, sliced_samples["samples"][index::self.frames])
|
||||
alpha.append(self.make_3d_mask(_alpha[0]))
|
||||
new_images.append(alpha_images[0])
|
||||
else:
|
||||
new_images.append(samp_images[index])
|
||||
else:
|
||||
new_images, alpha = self.image_to_alpha(samp_images, latent)
|
||||
else:
|
||||
new_images = samp_images
|
||||
else:
|
||||
new_images = samp_images
|
||||
|
||||
|
||||
return (new_images, samp_images, alpha)
|
||||
@@ -0,0 +1,359 @@
|
||||
# Currently only sd15
|
||||
|
||||
import functools
|
||||
import torch
|
||||
import einops
|
||||
|
||||
from comfy import model_management, utils
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
module_mapping_sd15 = {
|
||||
0: "input_blocks.1.1.transformer_blocks.0.attn1",
|
||||
1: "input_blocks.1.1.transformer_blocks.0.attn2",
|
||||
2: "input_blocks.2.1.transformer_blocks.0.attn1",
|
||||
3: "input_blocks.2.1.transformer_blocks.0.attn2",
|
||||
4: "input_blocks.4.1.transformer_blocks.0.attn1",
|
||||
5: "input_blocks.4.1.transformer_blocks.0.attn2",
|
||||
6: "input_blocks.5.1.transformer_blocks.0.attn1",
|
||||
7: "input_blocks.5.1.transformer_blocks.0.attn2",
|
||||
8: "input_blocks.7.1.transformer_blocks.0.attn1",
|
||||
9: "input_blocks.7.1.transformer_blocks.0.attn2",
|
||||
10: "input_blocks.8.1.transformer_blocks.0.attn1",
|
||||
11: "input_blocks.8.1.transformer_blocks.0.attn2",
|
||||
12: "output_blocks.3.1.transformer_blocks.0.attn1",
|
||||
13: "output_blocks.3.1.transformer_blocks.0.attn2",
|
||||
14: "output_blocks.4.1.transformer_blocks.0.attn1",
|
||||
15: "output_blocks.4.1.transformer_blocks.0.attn2",
|
||||
16: "output_blocks.5.1.transformer_blocks.0.attn1",
|
||||
17: "output_blocks.5.1.transformer_blocks.0.attn2",
|
||||
18: "output_blocks.6.1.transformer_blocks.0.attn1",
|
||||
19: "output_blocks.6.1.transformer_blocks.0.attn2",
|
||||
20: "output_blocks.7.1.transformer_blocks.0.attn1",
|
||||
21: "output_blocks.7.1.transformer_blocks.0.attn2",
|
||||
22: "output_blocks.8.1.transformer_blocks.0.attn1",
|
||||
23: "output_blocks.8.1.transformer_blocks.0.attn2",
|
||||
24: "output_blocks.9.1.transformer_blocks.0.attn1",
|
||||
25: "output_blocks.9.1.transformer_blocks.0.attn2",
|
||||
26: "output_blocks.10.1.transformer_blocks.0.attn1",
|
||||
27: "output_blocks.10.1.transformer_blocks.0.attn2",
|
||||
28: "output_blocks.11.1.transformer_blocks.0.attn1",
|
||||
29: "output_blocks.11.1.transformer_blocks.0.attn2",
|
||||
30: "middle_block.1.transformer_blocks.0.attn1",
|
||||
31: "middle_block.1.transformer_blocks.0.attn2",
|
||||
}
|
||||
|
||||
|
||||
def compute_cond_mark(cond_or_uncond, sigmas):
    """Expand per-batch cond/uncond flags into a flat per-sample tensor.

    Every entry of ``cond_or_uncond`` is repeated ``sigmas.shape[0]`` times
    consecutively, and the result is moved onto the same device/dtype as
    ``sigmas``.
    """
    repeat = int(sigmas.shape[0])
    flags = [flag for flag in cond_or_uncond for _ in range(repeat)]
    return torch.Tensor(flags).to(sigmas)
|
||||
|
||||
|
||||
class LoRALinearLayer(torch.nn.Module):
    """Linear layer augmented with a low-rank (LoRA) residual.

    The wrapped original layer is stored inside a plain list so it is NOT
    registered as a submodule: its weights stay out of this module's
    parameters/state dict and are only read at forward time.
    """

    def __init__(self, in_features: int, out_features: int, rank: int = 256, org=None):
        super().__init__()
        self.down = torch.nn.Linear(in_features, rank, bias=False)
        self.up = torch.nn.Linear(rank, out_features, bias=False)
        self.org = [org]

    def forward(self, h):
        base = self.org[0]
        # Effective weight = original weight + rank-limited update (up @ down).
        merged_weight = base.weight.to(h) + self.up.weight @ self.down.weight
        merged_bias = base.bias.to(h) if base.bias is not None else None
        return torch.nn.functional.linear(h, merged_weight, merged_bias)
|
||||
|
||||
|
||||
class AttentionSharingUnit(torch.nn.Module):
    """Replacement for an SD1.5 attention module that shares attention across
    ``frames`` latent frames.

    Each frame gets its own LoRA-wrapped q/k/v/out projections around the
    original module's weights, optionally offset by per-frame control-image
    features, followed by a temporal attention pass that mixes corresponding
    tokens across frames.

    NOTE(review): ``forward`` returns ``modified_hidden_states - h``, i.e. a
    residual relative to the input; presumably the caller adds ``h`` back —
    confirm at the patch site.
    """

    # `transformer_options` passed to the most recent BasicTransformerBlock.forward
    # call.
    # Deliberately class-level: the hijacked forward installed by
    # hijack_transformer_block() stashes the options here for all units.
    transformer_options: dict = {}

    def __init__(self, module, frames=2, use_control=True, rank=256):
        """Wrap *module* (an SD attention block with to_q/to_k/to_v/to_out)."""
        super().__init__()

        self.heads = module.heads
        self.frames = frames
        # Kept in a list so the original module is not registered as a child.
        self.original_module = [module]
        q_in_channels, q_out_channels = (
            module.to_q.in_features,
            module.to_q.out_features,
        )
        k_in_channels, k_out_channels = (
            module.to_k.in_features,
            module.to_k.out_features,
        )
        v_in_channels, v_out_channels = (
            module.to_v.in_features,
            module.to_v.out_features,
        )
        o_in_channels, o_out_channels = (
            module.to_out[0].in_features,
            module.to_out[0].out_features,
        )

        hidden_size = k_out_channels

        # One LoRA wrapper per frame for each projection of the original module.
        self.to_q_lora = [
            LoRALinearLayer(q_in_channels, q_out_channels, rank, module.to_q)
            for _ in range(self.frames)
        ]
        self.to_k_lora = [
            LoRALinearLayer(k_in_channels, k_out_channels, rank, module.to_k)
            for _ in range(self.frames)
        ]
        self.to_v_lora = [
            LoRALinearLayer(v_in_channels, v_out_channels, rank, module.to_v)
            for _ in range(self.frames)
        ]
        self.to_out_lora = [
            LoRALinearLayer(o_in_channels, o_out_channels, rank, module.to_out[0])
            for _ in range(self.frames)
        ]

        self.to_q_lora = torch.nn.ModuleList(self.to_q_lora)
        self.to_k_lora = torch.nn.ModuleList(self.to_k_lora)
        self.to_v_lora = torch.nn.ModuleList(self.to_v_lora)
        self.to_out_lora = torch.nn.ModuleList(self.to_out_lora)

        # Projections for the temporal (cross-frame) attention pass.
        self.temporal_i = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )
        self.temporal_n = torch.nn.LayerNorm(
            hidden_size, elementwise_affine=True, eps=1e-6
        )
        self.temporal_q = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )
        self.temporal_k = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )
        self.temporal_v = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )
        self.temporal_o = torch.nn.Linear(
            in_features=hidden_size, out_features=hidden_size
        )

        self.control_convs = None

        if use_control:
            # Per-frame conv heads that project 256-channel control features
            # (see AdditionalAttentionCondsEncoder) into this layer's width.
            self.control_convs = [
                torch.nn.Sequential(
                    torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1),
                    torch.nn.SiLU(),
                    torch.nn.Conv2d(256, hidden_size, kernel_size=1),
                )
                for _ in range(self.frames)
            ]
            self.control_convs = torch.nn.ModuleList(self.control_convs)

        # Set externally via AttentionSharingPatcher.set_control().
        self.control_signals = None

    def forward(self, h, context=None, value=None):
        """Compute the shared-attention residual for hidden states *h*.

        ``h`` is ``(b*frames, tokens, channels)``; it is unfolded to
        ``(frames, b, tokens, channels)`` for per-frame processing.
        """
        transformer_options = self.transformer_options

        modified_hidden_states = einops.rearrange(
            h, "(b f) d c -> f b d c", f=self.frames
        )

        if self.control_convs is not None:
            # Control signals are keyed by token count (H*W of the feature map).
            context_dim = int(modified_hidden_states.shape[2])
            control_outs = []
            for f in range(self.frames):
                control_signal = self.control_signals[context_dim].to(
                    modified_hidden_states
                )
                control = self.control_convs[f](control_signal)
                control = einops.rearrange(control, "b c h w -> b (h w) c")
                control_outs.append(control)
            control_outs = torch.stack(control_outs, dim=0)
            modified_hidden_states = modified_hidden_states + control_outs.to(
                modified_hidden_states
            )

        # Self-attention uses the hidden states themselves as context.
        if context is None:
            framed_context = modified_hidden_states
        else:
            framed_context = einops.rearrange(
                context, "(b f) d c -> f b d c", f=self.frames
            )

        # 1.0 marks uncond samples, 0.0 cond (per comfy's cond_or_uncond list),
        # reshaped to (frames, batch) for per-frame blending below.
        framed_cond_mark = einops.rearrange(
            compute_cond_mark(
                transformer_options["cond_or_uncond"],
                transformer_options["sigmas"],
            ),
            "(b f) -> f b",
            f=self.frames,
        ).to(modified_hidden_states)

        attn_outs = []
        for f in range(self.frames):
            fcf = framed_context[f]

            if context is not None:
                # Optionally replace the cond-side context of this frame with
                # an externally supplied conditioning tensor.
                cond_overwrite = transformer_options.get("cond_overwrite", [])
                if len(cond_overwrite) > f:
                    cond_overwrite = cond_overwrite[f]
                else:
                    cond_overwrite = None
                if cond_overwrite is not None:
                    cond_mark = framed_cond_mark[f][:, None, None]
                    fcf = cond_overwrite.to(fcf) * (1.0 - cond_mark) + fcf * cond_mark

            q = self.to_q_lora[f](modified_hidden_states[f])
            k = self.to_k_lora[f](fcf)
            v = self.to_v_lora[f](fcf)
            o = optimized_attention(q, k, v, self.heads)
            o = self.to_out_lora[f](o)
            # to_out[1] is the original module's output dropout.
            o = self.original_module[0].to_out[1](o)
            attn_outs.append(o)

        attn_outs = torch.stack(attn_outs, dim=0)
        modified_hidden_states = modified_hidden_states + attn_outs.to(
            modified_hidden_states
        )
        modified_hidden_states = einops.rearrange(
            modified_hidden_states, "f b d c -> (b f) d c", f=self.frames
        )

        # Temporal pass: attend across frames at each spatial token position.
        x = modified_hidden_states
        x = self.temporal_n(x)
        x = self.temporal_i(x)
        d = x.shape[1]

        x = einops.rearrange(x, "(b f) d c -> (b d) f c", f=self.frames)

        q = self.temporal_q(x)
        k = self.temporal_k(x)
        v = self.temporal_v(x)

        x = optimized_attention(q, k, v, self.heads)
        x = self.temporal_o(x)
        x = einops.rearrange(x, "(b d) f c -> (b f) d c", d=d)

        modified_hidden_states = modified_hidden_states + x

        # Return the delta relative to the input hidden states.
        return modified_hidden_states - h

    @classmethod
    def hijack_transformer_block(cls):
        """Monkey-patch BasicTransformerBlock.forward so each call records its
        transformer_options on this class before delegating to the original."""

        def register_get_transformer_options(func):
            @functools.wraps(func)
            def forward(self, x, context=None, transformer_options={}):
                cls.transformer_options = transformer_options
                return func(self, x, context, transformer_options)

            return forward

        from comfy.ldm.modules.attention import BasicTransformerBlock

        BasicTransformerBlock.forward = register_get_transformer_options(
            BasicTransformerBlock.forward
        )
|
||||
|
||||
|
||||
# Install the hook once at import time so every BasicTransformerBlock.forward
# call records its transformer_options for the sharing units to read.
AttentionSharingUnit.hijack_transformer_block()
|
||||
|
||||
|
||||
class AdditionalAttentionCondsEncoder(torch.nn.Module):
    """Encodes a conditioning image into 256-channel feature maps at four
    progressively halved resolutions.

    ``__call__`` returns a dict keyed by ``H*W`` of each feature map so that
    attention units can look up the signal matching their token count.
    """

    @staticmethod
    def _conv_silu(cin, cout, stride):
        # One conv + activation pair; 3x3, padding 1 throughout.
        return [
            torch.nn.Conv2d(cin, cout, kernel_size=3, padding=1, stride=stride),
            torch.nn.SiLU(),
        ]

    def __init__(self):
        super().__init__()

        # Stem: 3 -> 256 channels while downsampling 8x.
        stem_plan = [
            (3, 32, 1), (32, 32, 1), (32, 64, 2), (64, 64, 1),
            (64, 128, 2), (128, 128, 1), (128, 256, 2), (256, 256, 1),
        ]
        stem_layers = []
        for cin, cout, stride in stem_plan:
            stem_layers += self._conv_silu(cin, cout, stride)
        self.blocks_0 = torch.nn.Sequential(*stem_layers)  # 64*64*256

        def down_stage():
            # Halve the resolution, keep 256 channels.
            return torch.nn.Sequential(
                *(self._conv_silu(256, 256, 2) + self._conv_silu(256, 256, 1))
            )

        self.blocks_1 = down_stage()  # 32*32*256
        self.blocks_2 = down_stage()  # 16*16*256
        self.blocks_3 = down_stage()  # 8*8*256

        self.blks = [self.blocks_0, self.blocks_1, self.blocks_2, self.blocks_3]

    def __call__(self, h):
        results = {}
        for stage in self.blks:
            h = stage(h)
            results[int(h.shape[2]) * int(h.shape[3])] = h
        return results
|
||||
|
||||
|
||||
class HookerLayers(torch.nn.Module):
    """Thin container registering the sharing units as submodules so they
    participate in dtype casts (e.g. ``.half()``) and state-dict handling."""

    def __init__(self, layer_list):
        super().__init__()
        self.layers = torch.nn.ModuleList(layer_list)
|
||||
|
||||
|
||||
class AttentionSharingPatcher(torch.nn.Module):
    """Patches all 32 SD1.5 attention modules of *unet* with
    AttentionSharingUnit object patches, and optionally owns the control-image
    encoder whose signals feed those units."""

    def __init__(self, unet, frames=2, use_control=True, rank=256):
        super().__init__()
        # Object patches must not leak into clones of the same model.
        model_management.unload_model_clones(unet)

        units = []
        for i in range(32):
            real_key = module_mapping_sd15[i]
            attn_module = utils.get_attr(unet.model.diffusion_model, real_key)
            u = AttentionSharingUnit(
                attn_module, frames=frames, use_control=use_control, rank=rank
            )
            units.append(u)
            unet.add_object_patch("diffusion_model." + real_key, u)

        # Registered as submodules so the .half() below reaches every unit.
        self.hookers = HookerLayers(units)

        if use_control:
            self.kwargs_encoder = AdditionalAttentionCondsEncoder()
        else:
            self.kwargs_encoder = None

        self.dtype = torch.float32
        if model_management.should_use_fp16(model_management.get_torch_device()):
            self.dtype = torch.float16
            self.hookers.half()
        return

    def set_control(self, img):
        """Encode control image *img* and broadcast the signals to all units.

        ``img`` is assumed to be in [0, 1]; the encoder consumes [-1, 1].
        """
        img = img.cpu().float() * 2.0 - 1.0
        signals = self.kwargs_encoder(img)
        for m in self.hookers.layers:
            m.control_signals = signals
        return
|
||||
390
custom_nodes/ComfyUI-Easy-Use/py/modules/layer_diffuse/model.py
Normal file
390
custom_nodes/ComfyUI-Easy-Use/py/modules/layer_diffuse/model.py
Normal file
@@ -0,0 +1,390 @@
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
import cv2
|
||||
import numpy as np
|
||||
import comfy.model_management
|
||||
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from tqdm import tqdm
|
||||
from typing import Optional, Tuple
|
||||
from ...libs.utils import install_package
|
||||
from packaging import version
|
||||
|
||||
try:
|
||||
install_package("diffusers", "0.27.2", True, "0.25.0")
|
||||
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from diffusers import __version__
|
||||
if __version__:
|
||||
if version.parse(__version__) < version.parse("0.26.0"):
|
||||
from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
|
||||
else:
|
||||
from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
|
||||
|
||||
import functools
|
||||
|
||||
def zero_module(module):
    """Zero out every parameter of *module* in place and return it."""
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
|
||||
|
||||
|
||||
class LatentTransparencyOffsetEncoder(torch.nn.Module):
    """Downsamples a 4-channel (RGBA) image 8x into a 4-channel latent offset.

    The final projection is zero-initialized, so a freshly constructed encoder
    outputs exactly zero (a no-op offset) before any training.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # (in_channels, out_channels, stride) for each conv+SiLU pair.
        channel_plan = [
            (4, 32, 1), (32, 32, 1), (32, 64, 2), (64, 64, 1),
            (64, 128, 2), (128, 128, 1), (128, 256, 2), (256, 256, 1),
        ]
        layers = []
        for cin, cout, stride in channel_plan:
            layers.append(
                torch.nn.Conv2d(cin, cout, kernel_size=3, padding=1, stride=stride)
            )
            layers.append(nn.SiLU())
        # Zero-initialized head projecting back to 4 channels.
        layers.append(
            zero_module(torch.nn.Conv2d(256, 4, kernel_size=3, padding=1, stride=1))
        )
        self.blocks = torch.nn.Sequential(*layers)

    def __call__(self, x):
        return self.blocks(x)
|
||||
|
||||
|
||||
# 1024 * 1024 * 3 -> 16 * 16 * 512 -> 1024 * 1024 * 3
|
||||
class UNet1024(ModelMixin, ConfigMixin):
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
down_block_types: Tuple[str] = (
|
||||
"DownBlock2D",
|
||||
"DownBlock2D",
|
||||
"DownBlock2D",
|
||||
"DownBlock2D",
|
||||
"AttnDownBlock2D",
|
||||
"AttnDownBlock2D",
|
||||
"AttnDownBlock2D",
|
||||
),
|
||||
up_block_types: Tuple[str] = (
|
||||
"AttnUpBlock2D",
|
||||
"AttnUpBlock2D",
|
||||
"AttnUpBlock2D",
|
||||
"UpBlock2D",
|
||||
"UpBlock2D",
|
||||
"UpBlock2D",
|
||||
"UpBlock2D",
|
||||
),
|
||||
block_out_channels: Tuple[int] = (32, 32, 64, 128, 256, 512, 512),
|
||||
layers_per_block: int = 2,
|
||||
mid_block_scale_factor: float = 1,
|
||||
downsample_padding: int = 1,
|
||||
downsample_type: str = "conv",
|
||||
upsample_type: str = "conv",
|
||||
dropout: float = 0.0,
|
||||
act_fn: str = "silu",
|
||||
attention_head_dim: Optional[int] = 8,
|
||||
norm_num_groups: int = 4,
|
||||
norm_eps: float = 1e-5,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# input
|
||||
self.conv_in = nn.Conv2d(
|
||||
in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)
|
||||
)
|
||||
self.latent_conv_in = zero_module(
|
||||
nn.Conv2d(4, block_out_channels[2], kernel_size=1)
|
||||
)
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.mid_block = None
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
||||
# down
|
||||
output_channel = block_out_channels[0]
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=layers_per_block,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=None,
|
||||
add_downsample=not is_final_block,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
attention_head_dim=(
|
||||
attention_head_dim
|
||||
if attention_head_dim is not None
|
||||
else output_channel
|
||||
),
|
||||
downsample_padding=downsample_padding,
|
||||
resnet_time_scale_shift="default",
|
||||
downsample_type=downsample_type,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
self.mid_block = UNetMidBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=None,
|
||||
dropout=dropout,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_time_scale_shift="default",
|
||||
attention_head_dim=(
|
||||
attention_head_dim
|
||||
if attention_head_dim is not None
|
||||
else block_out_channels[-1]
|
||||
),
|
||||
resnet_groups=norm_num_groups,
|
||||
attn_groups=None,
|
||||
add_attention=True,
|
||||
)
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
for i, up_block_type in enumerate(up_block_types):
|
||||
prev_output_channel = output_channel
|
||||
output_channel = reversed_block_out_channels[i]
|
||||
input_channel = reversed_block_out_channels[
|
||||
min(i + 1, len(block_out_channels) - 1)
|
||||
]
|
||||
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
up_block = get_up_block(
|
||||
up_block_type,
|
||||
num_layers=layers_per_block + 1,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=None,
|
||||
add_upsample=not is_final_block,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
attention_head_dim=(
|
||||
attention_head_dim
|
||||
if attention_head_dim is not None
|
||||
else output_channel
|
||||
),
|
||||
resnet_time_scale_shift="default",
|
||||
upsample_type=upsample_type,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# out
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
||||
)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = nn.Conv2d(
|
||||
block_out_channels[0], out_channels, kernel_size=3, padding=1
|
||||
)
|
||||
|
||||
def forward(self, x, latent):
|
||||
sample_latent = self.latent_conv_in(latent)
|
||||
sample = self.conv_in(x)
|
||||
emb = None
|
||||
|
||||
down_block_res_samples = (sample,)
|
||||
for i, downsample_block in enumerate(self.down_blocks):
|
||||
if i == 3:
|
||||
sample = sample + sample_latent
|
||||
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
sample = self.mid_block(sample, emb)
|
||||
|
||||
for upsample_block in self.up_blocks:
|
||||
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
|
||||
down_block_res_samples = down_block_res_samples[
|
||||
: -len(upsample_block.resnets)
|
||||
]
|
||||
sample = upsample_block(sample, res_samples, emb)
|
||||
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
return sample
|
||||
|
||||
|
||||
def checkerboard(shape):
    """Return a 0/1 checkerboard array of the given *shape* (origin cell 0)."""
    index_grids = np.indices(shape)
    return index_grids.sum(axis=0) % 2
|
||||
|
||||
|
||||
def fill_checkerboard_bg(y: torch.Tensor) -> torch.Tensor:
    """Composite the foreground channels of *y* over a faint checkerboard.

    *y* is channel-last with alpha first: ``[B, H, W, 1 + C]``; returns the
    alpha-blended visualization shaped like the foreground part.
    """
    alpha, fg = y[..., :1], y[..., 1:]
    _, height, width, _ = fg.shape
    # 64-pixel cells, softened towards mid-grey (0.45 / 0.55).
    pattern = checkerboard(shape=(height // 64, width // 64))
    pattern = cv2.resize(pattern, (width, height), interpolation=cv2.INTER_NEAREST)
    pattern = (0.5 + (pattern - 0.5) * 0.1)[None, ..., None]
    pattern = torch.from_numpy(pattern).to(fg)
    return fg * alpha + pattern * (1 - alpha)
|
||||
|
||||
|
||||
class TransparentVAEDecoder:
    """Post-VAE decoder that estimates an RGBA image (alpha + color) from a
    decoded RGB image and its latent, using a UNet1024 trained for the task."""

    def __init__(self, sd, device, dtype):
        # sd: state dict for the UNet1024; loaded strictly so shape or key
        # mismatches fail loudly.
        self.load_device = device
        self.dtype = dtype

        model = UNet1024(in_channels=3, out_channels=4)
        model.load_state_dict(sd, strict=True)
        model.to(self.load_device, dtype=self.dtype)
        model.eval()
        self.model = model

    @torch.no_grad()
    def estimate_single_pass(self, pixel, latent):
        """One forward pass of the RGBA estimator."""
        y = self.model(pixel, latent)
        return y

    @torch.no_grad()
    def estimate_augmented(self, pixel, latent):
        """Test-time augmentation: run all 8 flip/rotation combinations and
        take the per-pixel median for a more robust estimate."""
        # (horizontal flip?, number of 90° rotations)
        args = [
            [False, 0],
            [False, 1],
            [False, 2],
            [False, 3],
            [True, 0],
            [True, 1],
            [True, 2],
            [True, 3],
        ]

        result = []

        for flip, rok in tqdm(args):
            feed_pixel = pixel.clone()
            feed_latent = latent.clone()

            if flip:
                feed_pixel = torch.flip(feed_pixel, dims=(3,))
                feed_latent = torch.flip(feed_latent, dims=(3,))

            feed_pixel = torch.rot90(feed_pixel, k=rok, dims=(2, 3))
            feed_latent = torch.rot90(feed_latent, k=rok, dims=(2, 3))

            # Undo the augmentation on the prediction before aggregating.
            eps = self.estimate_single_pass(feed_pixel, feed_latent).clip(0, 1)
            eps = torch.rot90(eps, k=-rok, dims=(2, 3))

            if flip:
                eps = torch.flip(eps, dims=(3,))

            result += [eps]

        result = torch.stack(result, dim=0)
        median = torch.median(result, dim=0).values
        return median

    @torch.no_grad()
    def decode_pixel(
        self, pixel: torch.Tensor, latent: torch.Tensor
    ) -> torch.Tensor:
        """Return the RGBA estimate for *pixel*/*latent*, on pixel's device."""
        # pixel.shape = [B, C=3, H, W]
        assert pixel.shape[1] == 3
        pixel_device = pixel.device
        pixel_dtype = pixel.dtype

        pixel = pixel.to(device=self.load_device, dtype=self.dtype)
        latent = latent.to(device=self.load_device, dtype=self.dtype)
        # y.shape = [B, C=4, H, W]
        y = self.estimate_augmented(pixel, latent)
        y = y.clip(0, 1)
        assert y.shape[1] == 4
        # Restore image to original device of input image.
        return y.to(pixel_device, dtype=pixel_dtype)
|
||||
|
||||
|
||||
def calculate_weight_adjust_channel(func):
    """Patches ComfyUI's LoRA weight application to accept multi-channel inputs.

    Wraps the original ``calculate_weight`` so that a 4-D "diff" patch whose
    shape disagrees with the base weight (e.g. extra input channels) is merged
    into a zero-padded union-shaped weight instead of failing.
    """

    @functools.wraps(func)
    def calculate_weight(
        patches, weight: torch.Tensor, key: str, intermediate_type=torch.float32
    ) -> torch.Tensor:
        # Let the original implementation apply everything it can first.
        weight = func(patches, weight, key, intermediate_type)

        for p in patches:
            alpha = p[0]
            v = p[1]

            # The recursion call should be handled in the main func call.
            if isinstance(v, list):
                continue

            # NOTE(review): patch_type stays unbound if len(v) is neither 1
            # nor 2 — presumably upstream guarantees this never happens;
            # verify against comfy's patch formats.
            if len(v) == 1:
                patch_type = "diff"
            elif len(v) == 2:
                patch_type = v[0]
                v = v[1]

            if patch_type == "diff":
                w1 = v[0]
                # Only act on active, shape-mismatched 4-D (conv) weights.
                if all(
                    (
                        alpha != 0.0,
                        w1.shape != weight.shape,
                        w1.ndim == weight.ndim == 4,
                    )
                ):
                    new_shape = [max(n, m) for n, m in zip(weight.shape, w1.shape)]
                    print(
                        f"Merged with {key} channel changed from {weight.shape} to {new_shape}"
                    )
                    new_diff = alpha * comfy.model_management.cast_to_device(
                        w1, weight.device, weight.dtype
                    )
                    # Zero-padded union shape: copy the base weight in, then
                    # add the diff over its own sub-region.
                    new_weight = torch.zeros(size=new_shape).to(weight)
                    new_weight[
                        : weight.shape[0],
                        : weight.shape[1],
                        : weight.shape[2],
                        : weight.shape[3],
                    ] = weight
                    new_weight[
                        : new_diff.shape[0],
                        : new_diff.shape[1],
                        : new_diff.shape[2],
                        : new_diff.shape[3],
                    ] += new_diff
                    new_weight = new_weight.contiguous().clone()
                    weight = new_weight
        return weight

    return calculate_weight
|
||||
|
||||
|
||||
except ImportError:
|
||||
ModelMixin = None
|
||||
ConfigMixin = None
|
||||
TransparentVAEDecoder = None
|
||||
calculate_weight_adjust_channel = None
|
||||
print("\33[33mModule 'diffusers' load failed. If you don't have it installed, do it:\033[0m")
|
||||
print("\33[33mpip install diffusers\033[0m")
|
||||
|
||||
|
||||
|
||||
1358
custom_nodes/ComfyUI-Easy-Use/py/nodes/adapter.py
Normal file
1358
custom_nodes/ComfyUI-Easy-Use/py/nodes/adapter.py
Normal file
File diff suppressed because it is too large
Load Diff
139
custom_nodes/ComfyUI-Easy-Use/py/nodes/api.py
Normal file
139
custom_nodes/ComfyUI-Easy-Use/py/nodes/api.py
Normal file
@@ -0,0 +1,139 @@
|
||||
import comfy.utils
|
||||
from ..libs.api.fluxai import fluxaiAPI
|
||||
from ..libs.api.bizyair import bizyairAPI, encode_data
|
||||
from nodes import NODE_CLASS_MAPPINGS as ALL_NODE_CLASS_MAPPINGS
|
||||
|
||||
class joyCaption2API:
    """ComfyUI node that captions an image via the BizyAir JoyCaption2 API.

    Images larger than 1536px on either side are scaled down before upload.
    Subclasses may override ``API_URL`` to target a different endpoint.
    """

    # Plain string; was an f-string with no placeholders (ruff F541).
    API_URL = "/supernode/joycaption2"

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "do_sample": ([True, False],),
                "temperature": (
                    "FLOAT",
                    {
                        "default": 0.5,
                        "min": 0.0,
                        "max": 2.0,
                        "step": 0.01,
                        "round": 0.001,
                        "display": "number",
                    },
                ),
                "max_tokens": (
                    "INT",
                    {
                        "default": 256,
                        "min": 16,
                        "max": 512,
                        "step": 16,
                        "display": "number",
                    },
                ),
                "caption_type": (
                    [
                        "Descriptive",
                        "Descriptive (Informal)",
                        "Training Prompt",
                        "MidJourney",
                        "Booru tag list",
                        "Booru-like tag list",
                        "Art Critic",
                        "Product Listing",
                        "Social Media Post",
                    ],
                ),
                "caption_length": (
                    ["any", "very short", "short", "medium-length", "long", "very long"]
                    + [str(i) for i in range(20, 261, 10)],
                ),
                "extra_options": (
                    "STRING",
                    {
                        "placeholder": "Extra options(e.g):\nIf there is a person/character in the image you must refer to them as {name}.",
                        "tooltip": "Extra options for the model",
                        "multiline": True,
                    },
                ),
                "name_input": (
                    "STRING",
                    {
                        "default": "",
                        "tooltip": "Name input is only used if an Extra Option is selected that requires it.",
                    },
                ),
                "custom_prompt": (
                    "STRING",
                    {
                        "default": "",
                        "multiline": True,
                    },
                ),
            },
            "optional":{
                "apikey_override": ("STRING", {"default": "", "forceInput": True, "tooltip":"Override the API key in the local config"}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("caption",)

    FUNCTION = "joycaption"
    OUTPUT_NODE = False

    CATEGORY = "EasyUse/API"

    def joycaption(
        self,
        image,
        do_sample,
        temperature,
        max_tokens,
        caption_type,
        caption_length,
        extra_options,
        name_input,
        custom_prompt,
        apikey_override=None
    ):
        """Upload *image* with the caption settings and return the caption.

        Returns a 1-tuple ``(caption,)`` as required by the node contract.
        """
        pbar = comfy.utils.ProgressBar(100)
        pbar.update_absolute(10)
        # The remote endpoint rejects very large images; scale down first.
        SIZE_LIMIT = 1536
        _, w, h, c = image.shape
        if w > SIZE_LIMIT or h > SIZE_LIMIT:
            node_class = ALL_NODE_CLASS_MAPPINGS['easy imageScaleDownToSize']
            image, = node_class().image_scale_down_to_size(image, SIZE_LIMIT, True)

        payload = {
            "image": None,
            # Coerce to a plain bool (was `do_sample == True`).
            "do_sample": bool(do_sample),
            "temperature": temperature,
            "max_new_tokens": max_tokens,
            "caption_type": caption_type,
            "caption_length": caption_length,
            "extra_options": [extra_options],
            "name_input": name_input,
            "custom_prompt": custom_prompt,
        }

        pbar.update_absolute(30)
        caption = bizyairAPI.joyCaption(payload, image, apikey_override, API_URL=self.API_URL)

        pbar.update_absolute(100)
        return (caption,)
|
||||
|
||||
class joyCaption3API(joyCaption2API):
    # Same node behavior as JoyCaption2, pointed at the v3 endpoint.
    API_URL = f"/supernode/joycaption3"

# Registration tables consumed by ComfyUI's node loader.
NODE_CLASS_MAPPINGS = {
    "easy joyCaption2API": joyCaption2API,
    "easy joyCaption3API": joyCaption3API,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "easy joyCaption2API": "JoyCaption2 (BizyAIR)",
    "easy joyCaption3API": "JoyCaption3 (BizyAIR)",
}
|
||||
521
custom_nodes/ComfyUI-Easy-Use/py/nodes/deprecated.py
Normal file
521
custom_nodes/ComfyUI-Easy-Use/py/nodes/deprecated.py
Normal file
@@ -0,0 +1,521 @@
|
||||
import numpy as np
|
||||
import os
|
||||
import json
|
||||
import torch
|
||||
import folder_paths
|
||||
import comfy
|
||||
import comfy.model_management
|
||||
from PIL import Image
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
from nodes import ConditioningSetMask, RepeatLatentBatch
|
||||
from comfy_extras.nodes_mask import LatentCompositeMasked
|
||||
from ..libs.log import log_node_info, log_node_warn
|
||||
from ..libs.adv_encode import advanced_encode
|
||||
from ..libs.utils import AlwaysEqualProxy
|
||||
any_type = AlwaysEqualProxy("*")
|
||||
|
||||
|
||||
class If:
    """Deprecated ternary node: outputs the `if` input when `any` is truthy,
    otherwise the `else` input."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "any": (any_type,),
                "if": (any_type,),
                "else": (any_type,),
            },
        }

    RETURN_TYPES = (any_type,)
    RETURN_NAMES = ("?",)
    FUNCTION = "execute"
    CATEGORY = "EasyUse/🚫 Deprecated"
    DEPRECATED = True

    def execute(self, *args, **kwargs):
        # "if"/"else" are Python keywords, hence the kwargs lookups.
        chosen = kwargs['if'] if kwargs['any'] else kwargs['else']
        return (chosen,)
|
||||
|
||||
|
||||
class poseEditor:
    """Deprecated pose-editor stub.

    The frontend widget did all the work; the node itself consumes the image
    path string and produces no outputs.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"image": ("STRING", {"default": ""})}}

    FUNCTION = "output_pose"
    CATEGORY = "EasyUse/🚫 Deprecated"
    DEPRECATED = True
    RETURN_TYPES = ()
    RETURN_NAMES = ()

    def output_pose(self, image):
        # Intentionally a no-op: nothing is returned to the graph.
        return ()
|
||||
|
||||
|
||||
class imageToMask:
    """Deprecated: extracts one RGB channel of an image as a MASK tensor."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "image": ("IMAGE",),
            "channel": (['red', 'green', 'blue'],),
        }
        }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "convert"
    CATEGORY = "EasyUse/🚫 Deprecated"
    DEPRECATED = True

    def convert_to_single_channel(self, image, channel='red'):
        """Return an RGB PIL image whose three channels all equal *channel*."""
        from PIL import Image
        # Work in RGB so each channel is addressable by index.
        rgb = image.convert('RGB')
        channel_index = {'red': 0, 'green': 1, 'blue': 2}
        if channel not in channel_index:
            raise ValueError(
                "Invalid channel option. Please choose 'red', 'green', or 'blue'.")
        grey = rgb.split()[channel_index[channel]].convert('L')
        # Replicate the greyscale channel back into an RGB image.
        return Image.merge('RGB', (grey, grey, grey))

    def convert(self, image, channel='red'):
        from ..libs.image import pil2tensor, tensor2pil
        single = self.convert_to_single_channel(tensor2pil(image), channel)
        tensor = pil2tensor(single)
        # All three channels are identical, so the mean collapses to the mask.
        return (tensor.squeeze().mean(2),)
|
||||
|
||||
# 显示推理时间
|
||||
# 显示推理时间 (shows inference time)
class showSpentTime:
    """Deprecated display node: shows the pipeline's recorded spent time in
    the matching workflow node's widget."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "pipe": ("PIPE_LINE",),
            },
            "hidden": {
                "unique_id": "UNIQUE_ID",
                "extra_pnginfo": "EXTRA_PNGINFO",
            },
        }

    FUNCTION = "notify"
    OUTPUT_NODE = True
    CATEGORY = "EasyUse/🚫 Deprecated"
    DEPRECATED = True
    RETURN_TYPES = ()
    RETURN_NAMES = ()

    def notify(self, pipe, spent_time=None, unique_id=None, extra_pnginfo=None):
        if unique_id and extra_pnginfo and "workflow" in extra_pnginfo:
            workflow = extra_pnginfo["workflow"]
            target = None
            for candidate in workflow["nodes"]:
                if str(candidate["id"]) == unique_id:
                    target = candidate
                    break
            if target is not None:
                spent_time = pipe['loader_settings'].get('spent_time', '')
                target["widgets_values"] = [spent_time]

        return {"ui": {"text": [spent_time]}, "result": {}}
|
||||
|
||||
|
||||
# 潜空间sigma相乘
|
||||
# 潜空间sigma相乘 (multiply latent by sigma)
class latentNoisy:
    """Deprecated node that scales a latent by the sigma difference between
    two scheduler steps, producing a noised latent for advanced sampling."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "sampler_name": (comfy.samplers.KSampler.SAMPLERS,),
            "scheduler": (comfy.samplers.KSampler.SCHEDULERS,),
            "steps": ("INT", {"default": 10000, "min": 0, "max": 10000}),
            "start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}),
            "end_at_step": ("INT", {"default": 10000, "min": 1, "max": 10000}),
            "source": (["CPU", "GPU"],),
            "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
        },
            "optional": {
                "pipe": ("PIPE_LINE",),
                "optional_model": ("MODEL",),
                "optional_latent": ("LATENT",)
            }}

    RETURN_TYPES = ("PIPE_LINE", "LATENT", "FLOAT",)
    RETURN_NAMES = ("pipe", "latent", "sigma",)
    FUNCTION = "run"
    DEPRECATED = True

    CATEGORY = "EasyUse/🚫 Deprecated"

    def run(self, sampler_name, scheduler, steps, start_at_step, end_at_step, source, seed, pipe=None, optional_model=None, optional_latent=None):
        """Scale a latent by sigma[start] - sigma[end] of the chosen scheduler.

        NOTE(review): if both `pipe` and `optional_model` are None, or
        `optional_latent` is None with no pipe, the `pipe[...]` lookups raise —
        presumably the UI guarantees at least one source; verify.
        """
        model = optional_model if optional_model is not None else pipe["model"]
        batch_size = pipe["loader_settings"]["batch_size"]
        empty_latent_height = pipe["loader_settings"]["empty_latent_height"]
        empty_latent_width = pipe["loader_settings"]["empty_latent_width"]

        if optional_latent is not None:
            samples = optional_latent
        else:
            # Seeded fresh noise at latent resolution (1/8 of pixel size).
            torch.manual_seed(seed)
            if source == "CPU":
                device = "cpu"
            else:
                device = comfy.model_management.get_torch_device()
            noise = torch.randn((batch_size, 4, empty_latent_height // 8, empty_latent_width // 8), dtype=torch.float32,
                                device=device).cpu()

            samples = {"samples": noise}

        device = comfy.model_management.get_torch_device()
        # Clamp the step window into [0, steps] with start <= end.
        end_at_step = min(steps, end_at_step)
        start_at_step = min(start_at_step, end_at_step)
        comfy.model_management.load_model_gpu(model)
        model_patcher = comfy.model_patcher.ModelPatcher(model.model, load_device=device, offload_device=comfy.model_management.unet_offload_device())
        sampler = comfy.samplers.KSampler(model_patcher, steps=steps, device=device, sampler=sampler_name,
                                          scheduler=scheduler, denoise=1.0, model_options=model.model_options)
        sigmas = sampler.sigmas
        sigma = sigmas[start_at_step] - sigmas[end_at_step]
        # Convert from model sigma space into latent space.
        sigma /= model.model.latent_format.scale_factor
        sigma = sigma.cpu().numpy()

        samples_out = samples.copy()

        s1 = samples["samples"]
        samples_out["samples"] = s1 * sigma

        if pipe is None:
            pipe = {}
        new_pipe = {
            **pipe,
            "samples": samples_out
        }
        del pipe

        return (new_pipe, samples_out, sigma)
|
||||
|
||||
# Latent mask compositing with per-text conditioning (deprecated)
class latentCompositeMaskedWithCond:
    """Deprecated node: for each prompt variant in ``text_combine`` builds a
    masked conditioning pair (source + destination), and composites the source
    latent onto the pipe's latent."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "pipe": ("PIPE_LINE",),
                "text_combine": ("LIST",),
                "source_latent": ("LATENT",),
                "source_mask": ("MASK",),
                "destination_mask": ("MASK",),
                "text_combine_mode": (["add", "replace", "cover"], {"default": "add"}),
                "replace_text": ("STRING", {"default": ""})
            },
            "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO", "my_unique_id": "UNIQUE_ID"},
        }

    # Only the conditioning output is a list (one entry per text variant).
    OUTPUT_IS_LIST = (False, False, True)

    RETURN_TYPES = ("PIPE_LINE", "LATENT", "CONDITIONING")
    RETURN_NAMES = ("pipe", "latent", "conditioning",)
    FUNCTION = "run"

    CATEGORY = "EasyUse/🚫 Deprecated"
    DEPRECATED = True

    def run(self, pipe, text_combine, source_latent, source_mask, destination_mask, text_combine_mode, replace_text, prompt=None, extra_pnginfo=None, my_unique_id=None):
        """Build one masked conditioning per text variant and composite latents.

        text_combine_mode:
          - 'cover':   the variant text replaces the whole positive prompt
          - 'replace': occurrences of replace_text in the pipe's positive prompt
                       are substituted with the variant text (only when
                       replace_text is non-empty)
          - 'add' (and any other case): variant text is appended with a comma
        """
        positive = None
        clip = pipe["clip"]
        destination_latent = pipe["samples"]

        conds = []

        for text in text_combine:
            if text_combine_mode == 'cover':
                positive = text
            elif text_combine_mode == 'replace' and replace_text != '':
                positive = pipe["loader_settings"]["positive"].replace(replace_text, text)
            else:
                positive = pipe["loader_settings"]["positive"] + ',' + text
            positive_token_normalization = pipe["loader_settings"]["positive_token_normalization"]
            positive_weight_interpretation = pipe["loader_settings"]["positive_weight_interpretation"]
            a1111_prompt_style = pipe["loader_settings"]["a1111_prompt_style"]
            positive_cond = pipe["positive"]

            log_node_warn("Positive encoding...")
            steps = pipe["loader_settings"]["steps"] if "steps" in pipe["loader_settings"] else 1
            positive_embeddings_final = advanced_encode(clip, positive,
                                                        positive_token_normalization,
                                                        positive_weight_interpretation, w_max=1.0,
                                                        apply_to_pooled='enable', a1111_prompt_style=a1111_prompt_style, steps=steps)

            # source cond: restrict the pipe's original positive to the source
            # mask, and the freshly-encoded variant to the destination mask.
            (cond_1,) = ConditioningSetMask().append(positive_cond, source_mask, "default", 1)
            (cond_2,) = ConditioningSetMask().append(positive_embeddings_final, destination_mask, "default", 1)
            positive_cond = cond_1 + cond_2

            conds.append(positive_cond)
        # latent composite masked (done once, after the per-text loop)
        (samples,) = LatentCompositeMasked().composite(destination_latent, source_latent, 0, 0, False)

        new_pipe = {
            **pipe,
            "samples": samples,
            "loader_settings": {
                **pipe["loader_settings"],
                # NOTE: stores only the LAST variant's prompt text.
                "positive": positive,
            }
        }

        del pipe

        return (new_pipe, samples, conds)
|
||||
|
||||
# Inject noise into a latent (deprecated)
class injectNoiseToLatent:
    """Deprecated node: blends a noise latent into a target latent, with
    optional normalization, masking, and extra random-noise mixing."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "strength": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 200.0, "step": 0.0001}),
            "normalize": ("BOOLEAN", {"default": False}),
            "average": ("BOOLEAN", {"default": False}),
        },
            "optional": {
                "pipe_to_noise": ("PIPE_LINE",),
                "image_to_latent": ("IMAGE",),
                "latent": ("LATENT",),
                "noise": ("LATENT",),
                "mask": ("MASK",),
                "mix_randn_amount": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.001}),
                "seed": ("INT", {"default": 123, "min": 0, "max": 0xffffffffffffffff, "step": 1}),
            }
        }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "inject"
    CATEGORY = "EasyUse/🚫 Deprecated"
    DEPRECATED = True

    def inject(self, strength, normalize, average, pipe_to_noise=None, noise=None, image_to_latent=None, latent=None, mix_randn_amount=0, mask=None, seed=None):
        """Return a LATENT dict whose samples are the target latent with noise injected.

        Target priority: encoded image_to_latent (needs a vae from the pipe) >
        explicit latent > a clone of the noise itself.  ``average`` takes the
        mean of latent and noise; otherwise noise is added scaled by strength.

        Raises:
            Exception:  when no noise source is available.
            ValueError: when latent and noise shapes differ.
        """
        # BUGFIX: the original read pipe_to_noise["vae"] in *both* branches of
        # this conditional, raising "TypeError: 'NoneType' object is not
        # subscriptable" whenever pipe_to_noise was None. Fall back to None.
        vae = pipe_to_noise["vae"] if pipe_to_noise is not None else None
        batch_size = pipe_to_noise["loader_settings"]["batch_size"] if pipe_to_noise is not None and "batch_size" in pipe_to_noise["loader_settings"] else 1
        if noise is None and pipe_to_noise is not None:
            noise = pipe_to_noise["samples"]
        elif noise is None:
            raise Exception("InjectNoiseToLatent: No noise provided")

        if image_to_latent is not None and vae is not None:
            # Encode only the RGB channels, then repeat to the pipe batch size.
            samples = {"samples": vae.encode(image_to_latent[:, :, :, :3])}
            latents = RepeatLatentBatch().repeat(samples, batch_size)[0]
        elif latent is not None:
            latents = latent
        else:
            latents = {"samples": noise["samples"].clone()}

        # Shallow copy keeps extra latent-dict keys (e.g. noise_mask) intact.
        samples = latents.copy()
        if latents["samples"].shape != noise["samples"].shape:
            raise ValueError("InjectNoiseToLatent: Latent and noise must have the same shape")
        if average:
            noised = (samples["samples"].clone() + noise["samples"].clone()) / 2
        else:
            noised = samples["samples"].clone() + noise["samples"].clone() * strength
        if normalize:
            # Rescale to unit standard deviation.
            noised = noised / noised.std()
        if mask is not None:
            # Resize the mask to the latent's spatial size and broadcast it
            # across channels; tile along batch if it is smaller than the batch.
            mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])),
                                                   size=(noised.shape[2], noised.shape[3]), mode="bilinear")
            mask = mask.expand((-1, noised.shape[1], -1, -1))
            if mask.shape[0] < noised.shape[0]:
                mask = mask.repeat((noised.shape[0] - 1) // mask.shape[0] + 1, 1, 1, 1)[:noised.shape[0]]
            # Masked areas get the noised values; the rest keeps the original.
            noised = mask * noised + (1 - mask) * latents["samples"]
        if mix_randn_amount > 0:
            if seed is not None:
                torch.manual_seed(seed)
            rand_noise = torch.randn_like(noised)
            # Variance-preserving blend of noised latent and fresh gaussian noise.
            noised = ((1 - mix_randn_amount) * noised + mix_randn_amount *
                      rand_noise) / ((mix_randn_amount ** 2 + (1 - mix_randn_amount) ** 2) ** 0.5)
        samples["samples"] = noised
        return (samples,)
|
||||
|
||||
|
||||
from ..libs.api.stability import stableAPI
|
||||
class stableDiffusion3API:
    """Deprecated node that generates images through Stability's SD3 API."""

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "positive": ("STRING", {"default": "", "placeholder": "Positive", "multiline": True}),
            "negative": ("STRING", {"default": "", "placeholder": "Negative", "multiline": True}),
            "model": (["sd3", "sd3-turbo"],),
            "aspect_ratio": (['16:9', '1:1', '21:9', '2:3', '3:2', '4:5', '5:4', '9:16', '9:21'],),
            "seed": ("INT", {"default": 0, "min": 0, "max": 4294967294}),
            "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0}),
        }
        optional = {
            "optional_image": ("IMAGE",),
        }
        hidden = {
            "unique_id": "UNIQUE_ID",
            "extra_pnginfo": "EXTRA_PNGINFO",
        }
        return {"required": required, "optional": optional, "hidden": hidden}

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)

    FUNCTION = "generate"
    OUTPUT_NODE = False

    CATEGORY = "EasyUse/🚫 Deprecated"
    DEPRECATED = True

    def generate(self, positive, negative, model, aspect_ratio, seed, denoise, optional_image=None, unique_id=None, extra_pnginfo=None):
        """Call the Stability SD3 endpoint and return the generated image.

        Switches to image-to-image mode when a source image is supplied;
        ``denoise`` is forwarded as the img2img strength.
        """
        stableAPI.getAPIKeys()
        mode = 'image-to-image' if optional_image is not None else 'text-to-image'
        output_image = stableAPI.generate_sd3_image(positive, negative, aspect_ratio, seed=seed, mode=mode, model=model, strength=denoise, image=optional_image)
        return (output_image,)
|
||||
|
||||
|
||||
class saveImageLazy():
    """Deprecated image-save node that rescans the output folder per image to
    pick the next free counter, instead of trusting the cached counter."""

    def __init__(self):
        # Base directory and metadata used when reporting saved files to the UI.
        self.output_dir = folder_paths.get_output_directory()
        self.type = "output"
        self.compress_level = 4

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"images": ("IMAGE",),
                     "filename_prefix": ("STRING", {"default": "ComfyUI"}),
                     "save_metadata": ("BOOLEAN", {"default": True}),
                     },
                "optional": {},
                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
                }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    OUTPUT_NODE = False
    FUNCTION = "save"

    DEPRECATED = True
    CATEGORY = "EasyUse/🚫 Deprecated"

    def save(self, images, filename_prefix, save_metadata, prompt=None, extra_pnginfo=None):
        """Save each image as PNG (optionally with workflow metadata) and
        return both the UI file list and the unchanged images.

        The per-image counter is recomputed by scanning existing files named
        ``<prefix>_<batch>_<counter:05>.png`` so saves never overwrite.
        """
        extension = 'png'

        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
            filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])

        results = list()
        for (batch_number, image) in enumerate(images):
            # Convert the 0..1 float tensor to an 8-bit PIL image.
            i = 255. * image.cpu().numpy()
            img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
            metadata = None

            filename_with_batch_num = filename.replace(
                "%batch_num%", str(batch_number))

            # Recomputed from disk below; the counter returned by
            # get_save_image_path is intentionally ignored.
            counter = 1

            if os.path.exists(full_output_folder) and os.listdir(full_output_folder):
                # Keep only files matching "<name>_<digits>" (suffix ".png"
                # stripped by the -4 slice) for this batch's base name.
                filtered_filenames = list(filter(
                    lambda filename: filename.startswith(
                        filename_with_batch_num + "_")
                    and filename[len(filename_with_batch_num) + 1:-4].isdigit(),
                    os.listdir(full_output_folder)
                ))

                if filtered_filenames:
                    max_counter = max(
                        int(filename[len(filename_with_batch_num) + 1:-4])
                        for filename in filtered_filenames
                    )
                    counter = max_counter + 1

            file = f"{filename_with_batch_num}_{counter:05}.{extension}"

            save_path = os.path.join(full_output_folder, file)

            if save_metadata:
                # Embed the prompt and any extra workflow info as PNG text chunks.
                metadata = PngInfo()
                if prompt is not None:
                    metadata.add_text("prompt", json.dumps(prompt))
                if extra_pnginfo is not None:
                    for x in extra_pnginfo:
                        metadata.add_text(
                            x, json.dumps(extra_pnginfo[x]))

            img.save(save_path, pnginfo=metadata)

            results.append({
                "filename": file,
                "subfolder": subfolder,
                "type": self.type
            })

        return {"ui": {"images": results} , "result": (images,)}
|
||||
|
||||
from .logic import saveText, showAnything
|
||||
|
||||
class showAnythingLazy(showAnything):
    """Deprecated variant of showAnything that also forwards its input as an
    output (OUTPUT_NODE=False) instead of being a terminal display node."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {}, "optional": {"anything": (any_type, {}), },
                "hidden": {"unique_id": "UNIQUE_ID", "extra_pnginfo": "EXTRA_PNGINFO",
                           }}

    RETURN_TYPES = (any_type,)
    RETURN_NAMES = ('output',)
    # Inputs arrive as lists; the single output is not a list.
    INPUT_IS_LIST = True
    OUTPUT_NODE = False
    OUTPUT_IS_LIST = (False,)
    DEPRECATED = True
    # Implementation lives in the showAnything base class.
    FUNCTION = "log_input"
    CATEGORY = "EasyUse/🚫 Deprecated"
|
||||
|
||||
class saveTextLazy(saveText):
    """Deprecated variant of saveText that passes its results through
    (OUTPUT_NODE=False) so downstream nodes can keep using them."""

    RETURN_TYPES = ("STRING", "IMAGE")
    RETURN_NAMES = ("text", 'image',)

    # Implementation lives in the saveText base class.
    FUNCTION = "save_text"
    OUTPUT_NODE = False
    DEPRECATED = True
    CATEGORY = "EasyUse/🚫 Deprecated"
|
||||
|
||||
# Registration table: ComfyUI node type name -> implementing class.
# All entries in this module are deprecated nodes kept for workflow
# backward compatibility.
NODE_CLASS_MAPPINGS = {
    "easy if": If,
    "easy poseEditor": poseEditor,
    "easy imageToMask": imageToMask,
    "easy showSpentTime": showSpentTime,
    "easy latentNoisy": latentNoisy,
    "easy latentCompositeMaskedWithCond": latentCompositeMaskedWithCond,
    "easy injectNoiseToLatent": injectNoiseToLatent,
    "easy stableDiffusion3API": stableDiffusion3API,
    "easy saveImageLazy": saveImageLazy,
    "easy saveTextLazy": saveTextLazy,
    "easy showAnythingLazy": showAnythingLazy,
}
|
||||
|
||||
# Human-readable display names shown in the ComfyUI node picker; the
# "(🚫Deprecated)" suffix warns users away from new usage.
NODE_DISPLAY_NAME_MAPPINGS = {
    "easy if": "If (🚫Deprecated)",
    "easy poseEditor": "PoseEditor (🚫Deprecated)",
    "easy imageToMask": "ImageToMask (🚫Deprecated)",
    "easy showSpentTime": "Show Spent Time (🚫Deprecated)",
    "easy latentNoisy": "LatentNoisy (🚫Deprecated)",
    "easy latentCompositeMaskedWithCond": "LatentCompositeMaskedWithCond (🚫Deprecated)",
    "easy injectNoiseToLatent": "InjectNoiseToLatent (🚫Deprecated)",
    "easy stableDiffusion3API": "StableDiffusion3API (🚫Deprecated)",
    "easy saveImageLazy": "SaveImageLazy (🚫Deprecated)",
    "easy saveTextLazy": "SaveTextLazy (🚫Deprecated)",
    "easy showAnythingLazy": "ShowAnythingLazy (🚫Deprecated)",
}
|
||||
643
custom_nodes/ComfyUI-Easy-Use/py/nodes/fix.py
Normal file
643
custom_nodes/ComfyUI-Easy-Use/py/nodes/fix.py
Normal file
@@ -0,0 +1,643 @@
|
||||
import sys
|
||||
import time
|
||||
import comfy
|
||||
import torch
|
||||
import folder_paths
|
||||
|
||||
from comfy_extras.chainner_models import model_loading
|
||||
|
||||
from server import PromptServer
|
||||
from nodes import MAX_RESOLUTION, NODE_CLASS_MAPPINGS as ALL_NODE_CLASS_MAPPINGS
|
||||
|
||||
from ..libs.utils import easySave, get_sd_version
|
||||
from ..libs.sampler import easySampler
|
||||
from .. import easyCache, sampler
|
||||
|
||||
class hiresFix:
    """Upscale an image with an ESRGAN-style model, optionally rescale the
    result, VAE-encode it back to a latent, and save/preview/send the image."""

    upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos", "bislerp"]
    crop_methods = ["disabled", "center"]

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model_name": (folder_paths.get_filename_list("upscale_models"),),
            "rescale_after_model": ([False, True], {"default": True}),
            "rescale_method": (s.upscale_methods,),
            "rescale": (["by percentage", "to Width/Height", 'to longer side - maintain aspect'],),
            "percent": ("INT", {"default": 50, "min": 0, "max": 1000, "step": 1}),
            "width": ("INT", {"default": 1024, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
            "height": ("INT", {"default": 1024, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
            "longer_side": ("INT", {"default": 1024, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
            "crop": (s.crop_methods,),
            "image_output": (["Hide", "Preview", "Save", "Hide&Save", "Sender", "Sender&Save"],{"default": "Preview"}),
            "link_id": ("INT", {"default": 0, "min": 0, "max": sys.maxsize, "step": 1}),
            "save_prefix": ("STRING", {"default": "ComfyUI"}),
        },
            "optional": {
                "pipe": ("PIPE_LINE",),
                "image": ("IMAGE",),
                "vae": ("VAE",),
            },
            "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO", "my_unique_id": "UNIQUE_ID",
                       },
        }

    RETURN_TYPES = ("PIPE_LINE", "IMAGE", "LATENT", )
    RETURN_NAMES = ('pipe', 'image', "latent", )

    FUNCTION = "upscale"
    CATEGORY = "EasyUse/Fix"
    OUTPUT_NODE = True

    def vae_encode_crop_pixels(self, pixels):
        """Center-crop pixel dims to multiples of 8 so the VAE can encode them."""
        x = (pixels.shape[1] // 8) * 8
        y = (pixels.shape[2] // 8) * 8
        if pixels.shape[1] != x or pixels.shape[2] != y:
            x_offset = (pixels.shape[1] % 8) // 2
            y_offset = (pixels.shape[2] % 8) // 2
            pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
        return pixels

    def upscale(self, model_name, rescale_after_model, rescale_method, rescale, percent, width, height,
                longer_side, crop, image_output, link_id, save_prefix, pipe=None, image=None, vae=None, prompt=None,
                extra_pnginfo=None, my_unique_id=None):
        """Run the upscale model tiled over the image, optionally rescale,
        re-encode to latent, and emit UI results per the image_output mode.

        Explicit image/vae inputs take precedence over the pipe's contents.
        """

        new_pipe = {}
        if pipe is not None:
            image = image if image is not None else pipe["images"]
            vae = vae if vae is not None else pipe.get("vae")
        elif image is None or vae is None:
            raise ValueError("pipe or image or vae missing.")
        # Load Model
        model_path = folder_paths.get_full_path("upscale_models", model_name)
        sd = comfy.utils.load_torch_file(model_path, safe_load=True)
        upscale_model = model_loading.load_state_dict(sd).eval()

        # Model upscale: move model + image (channels-first) to the compute device.
        device = comfy.model_management.get_torch_device()
        upscale_model.to(device)
        in_img = image.movedim(-1, -3).to(device)

        # Fixed tiling (192px tiles, 8px overlap) keeps memory use bounded.
        tile = 128 + 64
        overlap = 8
        steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile,
                                                                    tile_y=tile, overlap=overlap)
        pbar = comfy.utils.ProgressBar(steps)
        s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap,
                                    upscale_amount=upscale_model.scale, pbar=pbar)
        # Release GPU memory and clamp back to the valid 0..1 image range.
        upscale_model.cpu()
        s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)

        # Post Model Rescale: optionally resize the model output again.
        if rescale_after_model == True:
            samples = s.movedim(-1, 1)
            orig_height = samples.shape[2]
            orig_width = samples.shape[3]
            if rescale == "by percentage" and percent != 0:
                height = percent / 100 * orig_height
                width = percent / 100 * orig_width
                if (width > MAX_RESOLUTION):
                    width = MAX_RESOLUTION
                if (height > MAX_RESOLUTION):
                    height = MAX_RESOLUTION

                width = easySampler.enforce_mul_of_64(width)
                height = easySampler.enforce_mul_of_64(height)
            elif rescale == "to longer side - maintain aspect":
                longer_side = easySampler.enforce_mul_of_64(longer_side)
                if orig_width > orig_height:
                    width, height = longer_side, easySampler.enforce_mul_of_64(longer_side * orig_height / orig_width)
                else:
                    width, height = easySampler.enforce_mul_of_64(longer_side * orig_width / orig_height), longer_side

            # "to Width/Height" falls through here using the provided width/height.
            s = comfy.utils.common_upscale(samples, width, height, rescale_method, crop)
            s = s.movedim(1, -1)

        # vae encode: crop to /8 then encode RGB channels back to a latent.
        pixels = self.vae_encode_crop_pixels(s)
        t = vae.encode(pixels[:, :, :, :3])

        if pipe is not None:
            new_pipe = {
                "model": pipe['model'],
                "positive": pipe['positive'],
                "negative": pipe['negative'],
                "vae": vae,
                "clip": pipe['clip'],

                "samples": {"samples": t},
                "images": s,
                "seed": pipe['seed'],

                "loader_settings": {
                    **pipe["loader_settings"],
                }
            }
            del pipe
        else:
            new_pipe = {}

        results = easySave(s, save_prefix, image_output, prompt, extra_pnginfo)

        if image_output in ("Sender", "Sender&Save"):
            # Push the result to a linked receiver node over the websocket.
            PromptServer.instance.send_sync("img-send", {"link_id": link_id, "images": results})

        if image_output in ("Hide", "Hide&Save"):
            return (new_pipe, s, {"samples": t},)

        return {"ui": {"images": results},
                "result": (new_pipe, s, {"samples": t},)}
|
||||
|
||||
# Pre detailer-fix: validate and repack a pipe with detailer settings
class preDetailerFix:
    """Collects everything the detailer stage needs (model, clip, vae, images,
    conditioning, segm/sam pipes) into a single PIPE_LINE with the chosen
    detail-fix settings attached."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "pipe": ("PIPE_LINE",),
            "guide_size": ("FLOAT", {"default": 256, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
            "guide_size_for": ("BOOLEAN", {"default": True, "label_on": "bbox", "label_off": "crop_region"}),
            "max_size": ("FLOAT", {"default": 768, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
            "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
            "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
            "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}),
            "sampler_name": (comfy.samplers.KSampler.SAMPLERS,),
            "scheduler": (comfy.samplers.KSampler.SCHEDULERS + ['align_your_steps'],),
            "denoise": ("FLOAT", {"default": 0.5, "min": 0.0001, "max": 1.0, "step": 0.01}),
            "feather": ("INT", {"default": 5, "min": 0, "max": 100, "step": 1}),
            "noise_mask": ("BOOLEAN", {"default": True, "label_on": "enabled", "label_off": "disabled"}),
            "force_inpaint": ("BOOLEAN", {"default": True, "label_on": "enabled", "label_off": "disabled"}),
            "drop_size": ("INT", {"min": 1, "max": MAX_RESOLUTION, "step": 1, "default": 10}),
            "wildcard": ("STRING", {"multiline": True, "dynamicPrompts": False}),
            "cycle": ("INT", {"default": 1, "min": 1, "max": 10, "step": 1}),
        },
            "optional": {
                "bbox_segm_pipe": ("PIPE_LINE",),
                "sam_pipe": ("PIPE_LINE",),
                "optional_image": ("IMAGE",),
            },
        }

    RETURN_TYPES = ("PIPE_LINE",)
    RETURN_NAMES = ("pipe",)
    OUTPUT_IS_LIST = (False,)
    FUNCTION = "doit"

    CATEGORY = "EasyUse/Fix"

    def doit(self, pipe, guide_size, guide_size_for, max_size, seed, steps, cfg, sampler_name, scheduler, denoise, feather, noise_mask, force_inpaint, drop_size, wildcard, cycle, bbox_segm_pipe=None, sam_pipe=None, optional_image=None):
        """Validate required pipe entries and return a new pipe carrying
        ``detail_fix_settings`` for the downstream detailerFix node.

        optional_image overrides pipe['images']; explicit bbox_segm_pipe /
        sam_pipe inputs take precedence over the pipe's entries.

        Raises:
            Exception: when any required pipe entry is missing.
        """
        # dict.get replaces the repetitive `x if k in pipe else None` pattern;
        # the original also used f-strings with no placeholders here.
        model = pipe.get("model")
        if model is None:
            raise Exception("[ERROR] pipe['model'] is missing")
        clip = pipe.get("clip")
        if clip is None:
            raise Exception("[ERROR] pipe['clip'] is missing")
        vae = pipe.get("vae")
        if vae is None:
            raise Exception("[ERROR] pipe['vae'] is missing")
        images = optional_image if optional_image is not None else pipe.get("images")
        if images is None:
            # BUGFIX: the key checked is 'images'; the original message
            # misleadingly referred to pipe['image'].
            raise Exception("[ERROR] pipe['images'] is missing")
        positive = pipe.get("positive")
        if positive is None:
            raise Exception("[ERROR] pipe['positive'] is missing")
        negative = pipe.get("negative")
        if negative is None:
            raise Exception("[ERROR] pipe['negative'] is missing")
        bbox_segm_pipe = bbox_segm_pipe or pipe.get("bbox_segm_pipe")
        if bbox_segm_pipe is None:
            raise Exception("[ERROR] bbox_segm_pipe or pipe['bbox_segm_pipe'] is missing")
        sam_pipe = sam_pipe or pipe.get("sam_pipe")
        if sam_pipe is None:
            raise Exception("[ERROR] sam_pipe or pipe['sam_pipe'] is missing")

        loader_settings = pipe.get("loader_settings", {})

        if scheduler == 'align_your_steps':
            # Map the generic AYS option onto the model-specific schedule name.
            model_version = get_sd_version(model)
            if model_version == 'sdxl':
                scheduler = 'AYS SDXL'
            elif model_version == 'svd':
                scheduler = 'AYS SVD'
            else:
                scheduler = 'AYS SD1'

        new_pipe = {
            "images": images,
            "model": model,
            "clip": clip,
            "vae": vae,
            "positive": positive,
            "negative": negative,
            "seed": seed,

            "bbox_segm_pipe": bbox_segm_pipe,
            "sam_pipe": sam_pipe,

            "loader_settings": loader_settings,

            "detail_fix_settings": {
                "guide_size": guide_size,
                "guide_size_for": guide_size_for,
                "max_size": max_size,
                "seed": seed,
                "steps": steps,
                "cfg": cfg,
                "sampler_name": sampler_name,
                "scheduler": scheduler,
                "denoise": denoise,
                "feather": feather,
                "noise_mask": noise_mask,
                "force_inpaint": force_inpaint,
                "drop_size": drop_size,
                "wildcard": wildcard,
                "cycle": cycle
            }
        }

        del bbox_segm_pipe
        del sam_pipe

        return (new_pipe,)
|
||||
|
||||
# Pre mask detailer-fix
class preMaskDetailerFix:
    """Like preDetailerFix, but driven by an explicit mask instead of
    bbox/SAM pipes: optionally VAE-encodes the masked image into a latent
    with a noise_mask, then repacks everything into a new pipe."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "pipe": ("PIPE_LINE",),
            "mask": ("MASK",),

            "guide_size": ("FLOAT", {"default": 384, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
            "guide_size_for": ("BOOLEAN", {"default": True, "label_on": "bbox", "label_off": "crop_region"}),
            "max_size": ("FLOAT", {"default": 1024, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
            "mask_mode": ("BOOLEAN", {"default": True, "label_on": "masked only", "label_off": "whole"}),

            "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
            "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
            "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}),
            "sampler_name": (comfy.samplers.KSampler.SAMPLERS,),
            "scheduler": (comfy.samplers.KSampler.SCHEDULERS,),
            "denoise": ("FLOAT", {"default": 0.5, "min": 0.0001, "max": 1.0, "step": 0.01}),

            "feather": ("INT", {"default": 5, "min": 0, "max": 100, "step": 1}),
            "crop_factor": ("FLOAT", {"default": 3.0, "min": 1.0, "max": 10, "step": 0.1}),
            "drop_size": ("INT", {"min": 1, "max": MAX_RESOLUTION, "step": 1, "default": 10}),
            "refiner_ratio": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0}),
            "batch_size": ("INT", {"default": 1, "min": 1, "max": 100}),
            "cycle": ("INT", {"default": 1, "min": 1, "max": 10, "step": 1}),
        },
            "optional": {
                # "patch": ("INPAINT_PATCH",),
                "optional_image": ("IMAGE",),
                "inpaint_model": ("BOOLEAN", {"default": False, "label_on": "enabled", "label_off": "disabled"}),
                "noise_mask_feather": ("INT", {"default": 20, "min": 0, "max": 100, "step": 1}),
            },
        }

    RETURN_TYPES = ("PIPE_LINE",)
    RETURN_NAMES = ("pipe",)
    OUTPUT_IS_LIST = (False,)
    FUNCTION = "doit"

    CATEGORY = "EasyUse/Fix"

    def doit(self, pipe, mask, guide_size, guide_size_for, max_size, mask_mode, seed, steps, cfg, sampler_name, scheduler, denoise, feather, crop_factor, drop_size,refiner_ratio, batch_size, cycle, optional_image=None, inpaint_model=False, noise_mask_feather=20):
        """Validate pipe entries, build an inpaint-ready latent when the pipe's
        latent lacks a noise_mask, and return a new pipe with detail-fix and
        mask settings attached.

        Raises:
            Exception: when any required pipe entry (or image/vae during
            encoding) is missing.
        """
        model = pipe["model"] if "model" in pipe else None
        if model is None:
            raise Exception(f"[ERROR] pipe['model'] is missing")
        clip = pipe["clip"] if"clip" in pipe else None
        if clip is None:
            raise Exception(f"[ERROR] pipe['clip'] is missing")
        vae = pipe["vae"] if "vae" in pipe else None
        if vae is None:
            raise Exception(f"[ERROR] pipe['vae'] is missing")
        if optional_image is not None:
            images = optional_image
        else:
            images = pipe["images"] if "images" in pipe else None
            if images is None:
                raise Exception(f"[ERROR] pipe['image'] is missing")
        positive = pipe["positive"] if "positive" in pipe else None
        if positive is None:
            raise Exception(f"[ERROR] pipe['positive'] is missing")
        negative = pipe["negative"] if "negative" in pipe else None
        if negative is None:
            raise Exception(f"[ERROR] pipe['negative'] is missing")
        latent = pipe["samples"] if "samples" in pipe else None
        if latent is None:
            raise Exception(f"[ERROR] pipe['samples'] is missing")

        # If the latent has no noise mask yet, build one by masking the image
        # pixels to 0.5 (VAE-neutral grey) inside the mask and re-encoding.
        if 'noise_mask' not in latent:
            if images is None:
                raise Exception("No Images found")
            if vae is None:
                raise Exception("No VAE found")
            # Largest /8-aligned crop of the image (VAE requires multiples of 8).
            # NOTE(review): shape[1]/shape[2] are treated as H/W of a BHWC
            # image tensor — confirm against the IMAGE convention upstream.
            x = (images.shape[1] // 8) * 8
            y = (images.shape[2] // 8) * 8
            # Resize the mask to the image's spatial size (N,1,H,W layout).
            mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])),
                                                   size=(images.shape[1], images.shape[2]), mode="bilinear")

            pixels = images.clone()
            if pixels.shape[1] != x or pixels.shape[2] != y:
                # Center-crop pixels and mask together to the aligned size.
                x_offset = (pixels.shape[1] % 8) // 2
                y_offset = (pixels.shape[2] % 8) // 2
                pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
                mask = mask[:, :, x_offset:x + x_offset, y_offset:y + y_offset]

            mask_erosion = mask

            # Zero out masked pixels around the 0.5 midpoint for each RGB channel.
            m = (1.0 - mask.round()).squeeze(1)
            for i in range(3):
                pixels[:, :, :, i] -= 0.5
                pixels[:, :, :, i] *= m
                pixels[:, :, :, i] += 0.5
            t = vae.encode(pixels)

            latent = {"samples": t, "noise_mask": (mask_erosion[:, :, :x, :y].round())}
        # when patch was linked
        # if patch is not None:
        #     worker = InpaintWorker(node_name="easy kSamplerInpainting")
        #     model, = worker.patch(model, latent, patch)

        loader_settings = pipe["loader_settings"] if "loader_settings" in pipe else {}

        new_pipe = {
            "images": images,
            "model": model,
            "clip": clip,
            "vae": vae,
            "positive": positive,
            "negative": negative,
            "seed": seed,
            "mask": mask,

            "loader_settings": loader_settings,

            "detail_fix_settings": {
                "guide_size": guide_size,
                "guide_size_for": guide_size_for,
                "max_size": max_size,
                "seed": seed,
                "steps": steps,
                "cfg": cfg,
                "sampler_name": sampler_name,
                "scheduler": scheduler,
                "denoise": denoise,
                "feather": feather,
                "crop_factor": crop_factor,
                "drop_size": drop_size,
                "refiner_ratio": refiner_ratio,
                "batch_size": batch_size,
                "cycle": cycle
            },

            "mask_settings": {
                "mask_mode": mask_mode,
                "inpaint_model": inpaint_model,
                "noise_mask_feather": noise_mask_feather
            }
        }

        del pipe

        return (new_pipe,)
|
||||
|
||||
# Detail fix
class detailerFix:
    """Run Impact Pack's MaskDetailerPipe (when mask_settings is present) or
    FaceDetailer over the images carried by a PIPE_LINE and return the results."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "pipe": ("PIPE_LINE",),
            "image_output": (["Hide", "Preview", "Save", "Hide&Save", "Sender", "Sender&Save"], {"default": "Preview"}),
            "link_id": ("INT", {"default": 0, "min": 0, "max": sys.maxsize, "step": 1}),
            "save_prefix": ("STRING", {"default": "ComfyUI"}),
        },
            "optional": {
                "model": ("MODEL",),
            },
            "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO", "my_unique_id": "UNIQUE_ID", }
        }

    # NOTE(review): doit returns 6 values (incl. result_mask and result_cnet_images)
    # while only 4 outputs are declared here — confirm the extra outputs are intentional.
    RETURN_TYPES = ("PIPE_LINE", "IMAGE", "IMAGE", "IMAGE")
    RETURN_NAMES = ("pipe", "image", "cropped_refined", "cropped_enhanced_alpha")
    OUTPUT_NODE = True
    OUTPUT_IS_LIST = (False, False, True, True)
    FUNCTION = "doit"

    CATEGORY = "EasyUse/Fix"

    def doit(self, pipe, image_output, link_id, save_prefix, model=None, prompt=None, extra_pnginfo=None, my_unique_id=None):
        # Clean loaded_objects
        easyCache.update_loaded_objects(prompt)

        my_unique_id = int(my_unique_id)

        model = model or (pipe["model"] if "model" in pipe else None)
        if model is None:
            raise Exception("[ERROR] model or pipe['model'] is missing")

        detail_fix_settings = pipe.get("detail_fix_settings")
        if detail_fix_settings is None:
            raise Exception("[ERROR] detail_fix_settings or pipe['detail_fix_settings'] is missing")

        mask = pipe.get("mask")

        image = pipe["images"]
        clip = pipe["clip"]
        vae = pipe["vae"]
        seed = pipe["seed"]
        positive = pipe["positive"]
        negative = pipe["negative"]
        loader_settings = pipe.get("loader_settings", {})

        # Detailer parameters with their previous defaults.
        guide_size = detail_fix_settings.get("guide_size", 256)
        guide_size_for = detail_fix_settings.get("guide_size_for", True)
        max_size = detail_fix_settings.get("max_size", 768)
        steps = detail_fix_settings.get("steps", 20)
        cfg = detail_fix_settings.get("cfg", 1.0)
        sampler_name = detail_fix_settings.get("sampler_name")
        scheduler = detail_fix_settings.get("scheduler")
        denoise = detail_fix_settings.get("denoise", 0.5)
        feather = detail_fix_settings.get("feather", 5)
        crop_factor = detail_fix_settings.get("crop_factor", 3.0)
        drop_size = detail_fix_settings.get("drop_size", 10)
        # fix: refiner_ratio was looked up with `"refiner_ratio" in pipe` instead of
        # in detail_fix_settings, so a stored refiner_ratio was silently ignored.
        refiner_ratio = detail_fix_settings.get("refiner_ratio", 0.2)
        batch_size = detail_fix_settings.get("batch_size", 1)
        noise_mask = detail_fix_settings.get("noise_mask")
        force_inpaint = detail_fix_settings.get("force_inpaint", False)
        wildcard = detail_fix_settings.get("wildcard", "")
        cycle = detail_fix_settings.get("cycle", 1)

        bbox_segm_pipe = pipe.get("bbox_segm_pipe")
        sam_pipe = pipe.get("sam_pipe")

        # Start time of the fix pass (ms).
        start_time = int(time.time() * 1000)
        if "mask_settings" in pipe:
            mask_settings = pipe["mask_settings"]
            # fix: mask_mode was gated on the "inpaint_model" key instead of "mask_mode".
            mask_mode = mask_settings.get("mask_mode", True)
            inpaint_model = mask_settings.get("inpaint_model", False)
            noise_mask_feather = mask_settings.get("noise_mask_feather", 20)
            # fix: check availability BEFORE indexing the mapping — previously a missing
            # Impact Pack raised KeyError before this friendly message could be shown.
            if "MaskDetailerPipe" not in ALL_NODE_CLASS_MAPPINGS:
                raise Exception("[ERROR] To use MaskDetailerPipe, you need to install 'Impact Pack'")
            cls = ALL_NODE_CLASS_MAPPINGS["MaskDetailerPipe"]
            basic_pipe = (model, clip, vae, positive, negative)
            result_img, result_cropped_enhanced, result_cropped_enhanced_alpha, basic_pipe, refiner_basic_pipe_opt = cls().doit(
                image, mask, basic_pipe, guide_size, guide_size_for, max_size, mask_mode,
                seed, steps, cfg, sampler_name, scheduler, denoise,
                feather, crop_factor, drop_size, refiner_ratio, batch_size, cycle=1,
                refiner_basic_pipe_opt=None, detailer_hook=None,
                inpaint_model=inpaint_model, noise_mask_feather=noise_mask_feather)
            result_mask = mask
            result_cnet_images = ()
        else:
            if bbox_segm_pipe is None:
                raise Exception("[ERROR] bbox_segm_pipe or pipe['bbox_segm_pipe'] is missing")
            if sam_pipe is None:
                raise Exception("[ERROR] sam_pipe or pipe['sam_pipe'] is missing")
            bbox_detector_opt, bbox_threshold, bbox_dilation, bbox_crop_factor, segm_detector_opt = bbox_segm_pipe
            sam_model_opt, sam_detection_hint, sam_dilation, sam_threshold, sam_bbox_expansion, sam_mask_hint_threshold, sam_mask_hint_use_negative = sam_pipe
            if "FaceDetailer" not in ALL_NODE_CLASS_MAPPINGS:
                raise Exception("[ERROR] To use FaceDetailer, you need to install 'Impact Pack'")
            cls = ALL_NODE_CLASS_MAPPINGS["FaceDetailer"]

            result_img, result_cropped_enhanced, result_cropped_enhanced_alpha, result_mask, pipe, result_cnet_images = cls().doit(
                image, model, clip, vae, guide_size, guide_size_for, max_size, seed, steps, cfg, sampler_name,
                scheduler,
                positive, negative, denoise, feather, noise_mask, force_inpaint,
                bbox_threshold, bbox_dilation, bbox_crop_factor,
                sam_detection_hint, sam_dilation, sam_threshold, sam_bbox_expansion, sam_mask_hint_threshold,
                sam_mask_hint_use_negative, drop_size, bbox_detector_opt, wildcard, cycle, sam_model_opt,
                segm_detector_opt,
                detailer_hook=None)

        # End time of the fix pass (ms); store a human-readable duration.
        end_time = int(time.time() * 1000)

        spent_time = 'Fix:' + str((end_time - start_time) / 1000) + '"'

        results = easySave(result_img, save_prefix, image_output, prompt, extra_pnginfo)
        sampler.update_value_by_id("results", my_unique_id, results)

        # Clean loaded_objects
        easyCache.update_loaded_objects(prompt)

        new_pipe = {
            "samples": None,
            "images": result_img,
            "model": model,
            "clip": clip,
            "vae": vae,
            "seed": seed,
            "positive": positive,
            "negative": negative,
            "wildcard": wildcard,
            "bbox_segm_pipe": bbox_segm_pipe,
            "sam_pipe": sam_pipe,

            "loader_settings": {
                **loader_settings,
                "spent_time": spent_time
            },
            "detail_fix_settings": detail_fix_settings
        }
        if "mask_settings" in pipe:
            new_pipe["mask_settings"] = pipe["mask_settings"]

        sampler.update_value_by_id("pipe_line", my_unique_id, new_pipe)

        del bbox_segm_pipe
        del sam_pipe
        del pipe

        if image_output in ("Hide", "Hide&Save"):
            return (new_pipe, result_img, result_cropped_enhanced, result_cropped_enhanced_alpha, result_mask, result_cnet_images)

        if image_output in ("Sender", "Sender&Save"):
            PromptServer.instance.send_sync("img-send", {"link_id": link_id, "images": results})

        return {"ui": {"images": results}, "result": (new_pipe, result_img, result_cropped_enhanced, result_cropped_enhanced_alpha, result_mask, result_cnet_images)}
|
||||
class ultralyticsDetectorForDetailerFix:
    """Bundle an Ultralytics bbox/segm detector with its thresholds into a pipe."""

    @classmethod
    def INPUT_TYPES(s):
        # Offer every detection model, prefixed with its detector type.
        models = ["bbox/" + name for name in folder_paths.get_filename_list("ultralytics_bbox")]
        models += ["segm/" + name for name in folder_paths.get_filename_list("ultralytics_segm")]
        return {
            "required": {
                "model_name": (models,),
                "bbox_threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "bbox_dilation": ("INT", {"default": 10, "min": -512, "max": 512, "step": 1}),
                "bbox_crop_factor": ("FLOAT", {"default": 3.0, "min": 1.0, "max": 10, "step": 0.1}),
            }
        }

    RETURN_TYPES = ("PIPE_LINE",)
    RETURN_NAMES = ("bbox_segm_pipe",)
    FUNCTION = "doit"

    CATEGORY = "EasyUse/Fix"

    def doit(self, model_name, bbox_threshold, bbox_dilation, bbox_crop_factor):
        """Load the detectors via Impact Pack's UltralyticsDetectorProvider and pack them."""
        if 'UltralyticsDetectorProvider' not in ALL_NODE_CLASS_MAPPINGS:
            raise Exception(f"[ERROR] To use UltralyticsDetectorProvider, you need to install 'Impact Pack'")
        provider = ALL_NODE_CLASS_MAPPINGS['UltralyticsDetectorProvider']
        bbox_detector, segm_detector = provider().doit(model_name)
        return ((bbox_detector, bbox_threshold, bbox_dilation, bbox_crop_factor, segm_detector),)
||||
|
||||
class samLoaderForDetailerFix:
    """Load a SAM model and bundle it with its detection options into a pipe."""

    @classmethod
    def INPUT_TYPES(cls):
        detection_hints = ["center-1", "horizontal-2", "vertical-2", "rect-4", "diamond-4",
                           "mask-area", "mask-points", "mask-point-bbox", "none"]
        return {
            "required": {
                "model_name": (folder_paths.get_filename_list("sams"),),
                "device_mode": (["AUTO", "Prefer GPU", "CPU"], {"default": "AUTO"}),
                "sam_detection_hint": (detection_hints,),
                "sam_dilation": ("INT", {"default": 0, "min": -512, "max": 512, "step": 1}),
                "sam_threshold": ("FLOAT", {"default": 0.93, "min": 0.0, "max": 1.0, "step": 0.01}),
                "sam_bbox_expansion": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}),
                "sam_mask_hint_threshold": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01}),
                "sam_mask_hint_use_negative": (["False", "Small", "Outter"],),
            }
        }

    RETURN_TYPES = ("PIPE_LINE",)
    RETURN_NAMES = ("sam_pipe",)
    FUNCTION = "doit"

    CATEGORY = "EasyUse/Fix"

    def doit(self, model_name, device_mode, sam_detection_hint, sam_dilation, sam_threshold, sam_bbox_expansion, sam_mask_hint_threshold, sam_mask_hint_use_negative):
        """Load the SAM model via Impact Pack's SAMLoader and pack it with its options."""
        if 'SAMLoader' not in ALL_NODE_CLASS_MAPPINGS:
            raise Exception(f"[ERROR] To use SAMLoader, you need to install 'Impact Pack'")
        loader = ALL_NODE_CLASS_MAPPINGS['SAMLoader']
        (sam_model,) = loader().load_model(model_name, device_mode)
        return ((sam_model, sam_detection_hint, sam_dilation, sam_threshold,
                 sam_bbox_expansion, sam_mask_hint_threshold, sam_mask_hint_use_negative),)
||||
|
||||
|
||||
# Registers this module's node classes under their ComfyUI node-type names.
NODE_CLASS_MAPPINGS = {
    "easy hiresFix": hiresFix,
    "easy preDetailerFix": preDetailerFix,
    "easy preMaskDetailerFix": preMaskDetailerFix,
    "easy ultralyticsDetectorPipe": ultralyticsDetectorForDetailerFix,
    "easy samLoaderPipe": samLoaderForDetailerFix,
    "easy detailerFix": detailerFix
}
|
||||
|
||||
# Human-readable display names for the node types registered above.
NODE_DISPLAY_NAME_MAPPINGS = {
    "easy hiresFix": "HiresFix",
    "easy preDetailerFix": "PreDetailerFix",
    "easy preMaskDetailerFix": "preMaskDetailerFix",
    "easy ultralyticsDetectorPipe": "UltralyticsDetector (Pipe)",
    "easy samLoaderPipe": "SAMLoader (Pipe)",
    "easy detailerFix": "DetailerFix",
}
|
||||
2205
custom_nodes/ComfyUI-Easy-Use/py/nodes/image.py
Normal file
2205
custom_nodes/ComfyUI-Easy-Use/py/nodes/image.py
Normal file
File diff suppressed because it is too large
Load Diff
356
custom_nodes/ComfyUI-Easy-Use/py/nodes/inpaint.py
Normal file
356
custom_nodes/ComfyUI-Easy-Use/py/nodes/inpaint.py
Normal file
@@ -0,0 +1,356 @@
|
||||
import re
|
||||
import torch
|
||||
import comfy
|
||||
from comfy_extras.nodes_mask import GrowMask
|
||||
from nodes import VAEEncodeForInpaint, NODE_CLASS_MAPPINGS as ALL_NODE_CLASS_MAPPINGS
|
||||
from ..libs.utils import get_local_filepath
|
||||
from ..libs.log import log_node_info
|
||||
from ..libs import cache as backend_cache
|
||||
from ..config import *
|
||||
|
||||
# FooocusInpaint
class applyFooocusInpaint:
    """Patch a model with a Fooocus inpaint head and inpaint LoRA patch."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "latent": ("LATENT",),
                "head": (list(FOOOCUS_INPAINT_HEAD.keys()),),
                "patch": (list(FOOOCUS_INPAINT_PATCH.keys()),),
            },
        }

    RETURN_TYPES = ("MODEL",)
    RETURN_NAMES = ("model",)
    CATEGORY = "EasyUse/Inpaint"
    FUNCTION = "apply"

    def apply(self, model, latent, head, patch):
        """Return a clone of `model` patched with the selected head/patch pair."""
        from ..modules.fooocus import InpaintHead, InpaintWorker

        # Download (if needed) and load the inpaint head weights.
        head_path = get_local_filepath(FOOOCUS_INPAINT_HEAD[head]["model_url"], INPAINT_DIR)
        head_model = InpaintHead()
        # NOTE(review): torch.load without weights_only — head files come from a
        # configured URL; confirm that source is trusted.
        head_model.load_state_dict(torch.load(head_path, map_location='cpu'))

        # Download (if needed) and load the inpaint LoRA patch.
        patch_path = get_local_filepath(FOOOCUS_INPAINT_PATCH[patch]["model_url"], INPAINT_DIR)
        inpaint_lora = comfy.utils.load_torch_file(patch_path, safe_load=True)

        worker = InpaintWorker(node_name="easy kSamplerInpainting")
        patched, = worker.patch(model.clone(), latent, (head_model, inpaint_lora))
        return (patched,)
||||
|
||||
# brushnet
|
||||
from ..modules.brushnet import BrushNet
|
||||
class applyBrushNet:
    """Attach a BrushNet inpaint model to the pipe's model."""

    # fix: declared as @staticmethod — the function has no self/cls parameter yet
    # was only ever invoked through the class (s.get_files_with_extension()).
    @staticmethod
    def get_files_with_extension(folder='inpaint', extensions='.safetensors'):
        """List filenames in `folder` that end with `extensions`."""
        return [file for file in folder_paths.get_filename_list(folder) if file.endswith(extensions)]

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "pipe": ("PIPE_LINE",),
                "image": ("IMAGE",),
                "mask": ("MASK",),
                "brushnet": (s.get_files_with_extension(),),
                "dtype": (['float16', 'bfloat16', 'float32', 'float64'], ),
                "scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
                "start_at": ("INT", {"default": 0, "min": 0, "max": 10000}),
                "end_at": ("INT", {"default": 10000, "min": 0, "max": 10000}),
            },
        }

    RETURN_TYPES = ("PIPE_LINE",)
    RETURN_NAMES = ("pipe",)
    CATEGORY = "EasyUse/Inpaint"
    FUNCTION = "apply"

    def apply(self, pipe, image, mask, brushnet, dtype, scale, start_at, end_at):
        """Load (or reuse a cached) BrushNet model, wire it into the pipe's model/conditioning
        and return the updated pipe."""
        model = pipe['model']
        vae = pipe['vae']
        positive = pipe['positive']
        negative = pipe['negative']
        cls = BrushNet()
        if brushnet in backend_cache.cache:
            log_node_info("easy brushnetApply", f"Using {brushnet} Cached")
            _, brushnet_model = backend_cache.cache[brushnet][1]
        else:
            # fix: dropped the single-argument os.path.join() wrapper — it was a no-op.
            brushnet_file = folder_paths.get_full_path("inpaint", brushnet)
            brushnet_model, = cls.load_brushnet_model(brushnet_file, dtype)
            backend_cache.update_cache(brushnet, 'brushnet', (False, brushnet_model))
        m, positive, negative, latent = cls.brushnet_model_update(
            model=model, vae=vae, image=image, mask=mask,
            brushnet=brushnet_model, positive=positive,
            negative=negative, scale=scale, start_at=start_at,
            end_at=end_at)
        new_pipe = {
            **pipe,
            "model": m,
            "positive": positive,
            "negative": negative,
            "samples": latent,
        }
        del pipe
        return (new_pipe,)
||||
|
||||
# PowerPaint
class applyPowerPaint:
    """Attach a PowerPaint inpaint model (and its CLIP) to the pipe's model."""

    # fix: declared as @staticmethod — the function has no self/cls parameter yet
    # was only ever invoked through the class (s.get_files_with_extension()).
    @staticmethod
    def get_files_with_extension(folder='inpaint', extensions='.safetensors'):
        """List filenames in `folder` that end with `extensions`."""
        return [file for file in folder_paths.get_filename_list(folder) if file.endswith(extensions)]

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "pipe": ("PIPE_LINE",),
                "image": ("IMAGE",),
                "mask": ("MASK",),
                "powerpaint_model": (s.get_files_with_extension(),),
                "powerpaint_clip": (s.get_files_with_extension(extensions='.bin'),),
                "dtype": (['float16', 'bfloat16', 'float32', 'float64'],),
                "fitting": ("FLOAT", {"default": 1.0, "min": 0.3, "max": 1.0}),
                "function": (['text guided', 'shape guided', 'object removal', 'context aware', 'image outpainting'],),
                "scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
                "start_at": ("INT", {"default": 0, "min": 0, "max": 10000}),
                "end_at": ("INT", {"default": 10000, "min": 0, "max": 10000}),
                "save_memory": (['none', 'auto', 'max'],),
            },
        }

    RETURN_TYPES = ("PIPE_LINE",)
    RETURN_NAMES = ("pipe",)
    CATEGORY = "EasyUse/Inpaint"
    FUNCTION = "apply"

    def apply(self, pipe, image, mask, powerpaint_model, powerpaint_clip, dtype, fitting, function, scale, start_at, end_at, save_memory='none'):
        """Load (or reuse cached) PowerPaint clip+model, wire them into the pipe and return it."""
        model = pipe['model']
        vae = pipe['vae']
        positive = pipe['positive']
        negative = pipe['negative']

        cls = BrushNet()
        # Load (or reuse a cached) PowerPaint CLIP.
        if powerpaint_clip in backend_cache.cache:
            log_node_info("easy powerpaintApply", f"Using {powerpaint_clip} Cached")
            _, ppclip = backend_cache.cache[powerpaint_clip][1]
        else:
            model_url = POWERPAINT_MODELS['base_fp16']['model_url']
            base_clip = get_local_filepath(model_url, os.path.join(folder_paths.models_dir, 'clip'))
            # fix: dropped the single-argument os.path.join() wrapper — it was a no-op.
            ppclip, = cls.load_powerpaint_clip(base_clip, folder_paths.get_full_path("inpaint", powerpaint_clip))
            backend_cache.update_cache(powerpaint_clip, 'ppclip', (False, ppclip))
        # Load (or reuse a cached) PowerPaint model.
        if powerpaint_model in backend_cache.cache:
            log_node_info("easy powerpaintApply", f"Using {powerpaint_model} Cached")
            _, powerpaint = backend_cache.cache[powerpaint_model][1]
        else:
            # fix: dropped the single-argument os.path.join() wrapper — it was a no-op.
            powerpaint_file = folder_paths.get_full_path("inpaint", powerpaint_model)
            powerpaint, = cls.load_brushnet_model(powerpaint_file, dtype)
            backend_cache.update_cache(powerpaint_model, 'powerpaint', (False, powerpaint))
        m, positive, negative, latent = cls.powerpaint_model_update(
            model=model, vae=vae, image=image, mask=mask, powerpaint=powerpaint,
            clip=ppclip, positive=positive, negative=negative,
            fitting=fitting, function=function,
            scale=scale, start_at=start_at, end_at=end_at, save_memory=save_memory)
        new_pipe = {
            **pipe,
            "model": m,
            "positive": positive,
            "negative": negative,
            "samples": latent,
        }
        del pipe
        return (new_pipe,)
||||
|
||||
from node_helpers import conditioning_set_values
|
||||
class applyInpaint:
    """One-stop inpaint node: optionally attaches BrushNet/PowerPaint/Fooocus to the
    pipe's model, then encodes the image+mask according to the selected `encode` mode."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "pipe": ("PIPE_LINE",),
                "image": ("IMAGE",),
                "mask": ("MASK",),
                "inpaint_mode": (('normal', 'fooocus_inpaint', 'brushnet_random', 'brushnet_segmentation', 'powerpaint'),),
                "encode": (('none', 'vae_encode_inpaint', 'inpaint_model_conditioning', 'different_diffusion'), {"default": "none"}),
                "grow_mask_by": ("INT", {"default": 6, "min": 0, "max": 64, "step": 1}),
                "dtype": (['float16', 'bfloat16', 'float32', 'float64'],),
                "fitting": ("FLOAT", {"default": 1.0, "min": 0.3, "max": 1.0}),
                "function": (['text guided', 'shape guided', 'object removal', 'context aware', 'image outpainting'],),
                "scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
                "start_at": ("INT", {"default": 0, "min": 0, "max": 10000}),
                "end_at": ("INT", {"default": 10000, "min": 0, "max": 10000}),
            },
            "optional": {
                "noise_mask": ("BOOLEAN", {"default": True})
            }
        }

    RETURN_TYPES = ("PIPE_LINE",)
    RETURN_NAMES = ("pipe",)
    CATEGORY = "EasyUse/Inpaint"
    FUNCTION = "apply"

    def inpaint_model_conditioning(self, pipe, image, vae, mask, grow_mask_by, noise_mask=True):
        """Encode image+mask for an inpaint model: sets pipe['positive'/'negative'] with
        concat conditioning and pipe['samples'] with the original-image latent."""
        if grow_mask_by > 0:
            mask, = GrowMask().expand_mask(mask, grow_mask_by, False)
        positive, negative = pipe['positive'], pipe['negative']

        pixels = image
        # Target dimensions cropped to a multiple of 8 (latent stride).
        x = (pixels.shape[1] // 8) * 8
        y = (pixels.shape[2] // 8) * 8
        mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])),
                                               size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")

        orig_pixels = pixels
        pixels = orig_pixels.clone()
        if pixels.shape[1] != x or pixels.shape[2] != y:
            # Center-crop both pixels and mask to the 8-aligned size.
            x_offset = (pixels.shape[1] % 8) // 2
            y_offset = (pixels.shape[2] % 8) // 2
            pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
            mask = mask[:, :, x_offset:x + x_offset, y_offset:y + y_offset]

        # Neutralize masked pixels to 0.5 grey before encoding the concat latent.
        m = (1.0 - mask.round()).squeeze(1)
        for i in range(3):
            pixels[:, :, :, i] -= 0.5
            pixels[:, :, :, i] *= m
            pixels[:, :, :, i] += 0.5
        concat_latent = vae.encode(pixels)
        orig_latent = vae.encode(orig_pixels)

        out_latent = {"samples": orig_latent}
        if noise_mask:
            out_latent["noise_mask"] = mask

        out = []
        for conditioning in [positive, negative]:
            c = conditioning_set_values(conditioning, {"concat_latent_image": concat_latent,
                                                       "concat_mask": mask})
            out.append(c)

        pipe['positive'] = out[0]
        pipe['negative'] = out[1]
        pipe['samples'] = out_latent

        return pipe

    def get_brushnet_model(self, type, model):
        """Resolve (downloading if absent) the local BrushNet filename for the given mode.
        `type` must be 'brushnet_random' or 'brushnet_segmentation'."""
        model_type = 'sdxl' if isinstance(model.model.model_config, comfy.supported_models.SDXL) else 'sd1'
        if type == 'brushnet_random':
            brush_model = BRUSHNET_MODELS['random_mask'][model_type]['model_url']
            if model_type == 'sdxl':
                pattern = 'brushnet.random.mask.sdxl.*.(safetensors|bin)$'
            else:
                pattern = 'brushnet.random.mask.*.(safetensors|bin)$'
        elif type == 'brushnet_segmentation':
            brush_model = BRUSHNET_MODELS['segmentation_mask'][model_type]['model_url']
            if model_type == 'sdxl':
                pattern = 'brushnet.segmentation.mask.sdxl.*.(safetensors|bin)$'
            else:
                pattern = 'brushnet.segmentation.mask.*.(safetensors|bin)$'

        brushfile = [e for e in folder_paths.get_filename_list('inpaint') if re.search(pattern, e, re.IGNORECASE)]
        brushname = brushfile[0] if brushfile else None
        if not brushname:
            from urllib.parse import urlparse
            get_local_filepath(brush_model, INPAINT_DIR)
            parsed_url = urlparse(brush_model)
            brushname = os.path.basename(parsed_url.path)
        return brushname

    def get_powerpaint_model(self, model):
        """Download (if absent) PowerPaint v2.1 and return its local model and clip names."""
        model_type = 'sdxl' if isinstance(model.model.model_config, comfy.supported_models.SDXL) else 'sd1'
        if model_type == 'sdxl':
            raise Exception("Powerpaint not supported for SDXL models")

        powerpaint_model = POWERPAINT_MODELS['v2.1']['model_url']
        powerpaint_clip = POWERPAINT_MODELS['v2.1']['clip_url']

        from urllib.parse import urlparse
        get_local_filepath(powerpaint_model, os.path.join(INPAINT_DIR, 'powerpaint'))
        model_parsed_url = urlparse(powerpaint_model)
        clip_parsed_url = urlparse(powerpaint_clip)
        model_name = os.path.join("powerpaint", os.path.basename(model_parsed_url.path))
        clip_name = os.path.join("powerpaint", os.path.basename(clip_parsed_url.path))
        return model_name, clip_name

    def _patch_fooocus(self, new_pipe):
        """Patch new_pipe['model'] in place with the default Fooocus inpaint head/patch.
        (Extracted: this block was duplicated four times in apply().)"""
        model, = applyFooocusInpaint().apply(new_pipe['model'], new_pipe['samples'],
                                             list(FOOOCUS_INPAINT_HEAD.keys())[0],
                                             list(FOOOCUS_INPAINT_PATCH.keys())[0])
        new_pipe['model'] = model

    def apply(self, pipe, image, mask, inpaint_mode, encode, grow_mask_by, dtype, fitting, function, scale, start_at, end_at, noise_mask=True):
        new_pipe = {
            **pipe,
        }
        del pipe
        # Model-level attachments first.
        if inpaint_mode in ['brushnet_random', 'brushnet_segmentation']:
            brushnet = self.get_brushnet_model(inpaint_mode, new_pipe['model'])
            new_pipe, = applyBrushNet().apply(new_pipe, image, mask, brushnet, dtype, scale, start_at, end_at)
        elif inpaint_mode == 'powerpaint':
            powerpaint_model, powerpaint_clip = self.get_powerpaint_model(new_pipe['model'])
            new_pipe, = applyPowerPaint().apply(new_pipe, image, mask, powerpaint_model, powerpaint_clip, dtype, fitting, function, scale, start_at, end_at)

        vae = new_pipe['vae']
        if encode == 'none':
            if inpaint_mode == 'fooocus_inpaint':
                self._patch_fooocus(new_pipe)
        elif encode == 'vae_encode_inpaint':
            latent, = VAEEncodeForInpaint().encode(vae, image, mask, grow_mask_by)
            new_pipe['samples'] = latent
            if inpaint_mode == 'fooocus_inpaint':
                self._patch_fooocus(new_pipe)
        elif encode == 'inpaint_model_conditioning':
            if inpaint_mode == 'fooocus_inpaint':
                latent, = VAEEncodeForInpaint().encode(vae, image, mask, grow_mask_by)
                new_pipe['samples'] = latent
                self._patch_fooocus(new_pipe)
                # Mask already grown by the inpaint encode above, so grow_mask_by=0 here.
                new_pipe = self.inpaint_model_conditioning(new_pipe, image, vae, mask, 0, noise_mask=noise_mask)
            else:
                new_pipe = self.inpaint_model_conditioning(new_pipe, image, vae, mask, grow_mask_by, noise_mask=noise_mask)
        elif encode == 'different_diffusion':
            if inpaint_mode == 'fooocus_inpaint':
                latent, = VAEEncodeForInpaint().encode(vae, image, mask, grow_mask_by)
                new_pipe['samples'] = latent
                self._patch_fooocus(new_pipe)
                new_pipe = self.inpaint_model_conditioning(new_pipe, image, vae, mask, 0, noise_mask=noise_mask)
            else:
                new_pipe = self.inpaint_model_conditioning(new_pipe, image, vae, mask, grow_mask_by, noise_mask=noise_mask)
            # fix: use .get() — a plain [] lookup raised KeyError before the friendly
            # "not found" message below could ever be reached.
            cls = ALL_NODE_CLASS_MAPPINGS.get('DifferentialDiffusion')
            if cls is not None:
                try:
                    model, = cls().execute(new_pipe['model'])
                except Exception:
                    model, = cls().apply(new_pipe['model'])
                new_pipe['model'] = model
            else:
                raise Exception("Differential Diffusion not found,please update comfyui")

        return (new_pipe,)
||||
|
||||
# Registers this module's inpaint node classes under their ComfyUI node-type names.
NODE_CLASS_MAPPINGS = {
    "easy applyFooocusInpaint": applyFooocusInpaint,
    "easy applyBrushNet": applyBrushNet,
    "easy applyPowerPaint": applyPowerPaint,
    "easy applyInpaint": applyInpaint
}
|
||||
|
||||
# Human-readable display names for the node types registered above.
NODE_DISPLAY_NAME_MAPPINGS = {
    "easy applyFooocusInpaint": "Easy Apply Fooocus Inpaint",
    "easy applyBrushNet": "Easy Apply BrushNet",
    "easy applyPowerPaint": "Easy Apply PowerPaint",
    "easy applyInpaint": "Easy Apply Inpaint"
}
|
||||
1565
custom_nodes/ComfyUI-Easy-Use/py/nodes/loaders.py
Normal file
1565
custom_nodes/ComfyUI-Easy-Use/py/nodes/loaders.py
Normal file
File diff suppressed because it is too large
Load Diff
2013
custom_nodes/ComfyUI-Easy-Use/py/nodes/logic.py
Executable file
2013
custom_nodes/ComfyUI-Easy-Use/py/nodes/logic.py
Executable file
File diff suppressed because it is too large
Load Diff
778
custom_nodes/ComfyUI-Easy-Use/py/nodes/pipe.py
Normal file
778
custom_nodes/ComfyUI-Easy-Use/py/nodes/pipe.py
Normal file
@@ -0,0 +1,778 @@
|
||||
import os
|
||||
import folder_paths
|
||||
import comfy.samplers, comfy.supported_models
|
||||
|
||||
from nodes import LatentFromBatch, RepeatLatentBatch
|
||||
from ..config import MAX_SEED_NUM
|
||||
|
||||
from ..libs.log import log_node_warn
|
||||
from ..libs.utils import get_sd_version
|
||||
from ..libs.conditioning import prompt_to_cond, set_cond
|
||||
|
||||
from .. import easyCache
|
||||
|
||||
# Pipe in
class pipeIn:
    """Merge individually-supplied inputs into a PIPE_LINE dict; explicit inputs take
    precedence over values already carried by the incoming pipe."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {},
            "optional": {
                "pipe": ("PIPE_LINE",),
                "model": ("MODEL",),
                "pos": ("CONDITIONING",),
                "neg": ("CONDITIONING",),
                "latent": ("LATENT",),
                "vae": ("VAE",),
                "clip": ("CLIP",),
                "image": ("IMAGE",),
                "xyPlot": ("XYPLOT",),
            },
            "hidden": {"my_unique_id": "UNIQUE_ID"},
        }

    RETURN_TYPES = ("PIPE_LINE",)
    RETURN_NAMES = ("pipe",)
    FUNCTION = "flush"

    CATEGORY = "EasyUse/Pipe"

    def flush(self, pipe=None, model=None, pos=None, neg=None, latent=None, vae=None, clip=None, image=None, xyplot=None, my_unique_id=None):
        # NOTE(review): the optional input is declared as "xyPlot" while this parameter
        # is "xyplot" — confirm the frontend delivers it under the lowercase name.

        # Explicit inputs win; otherwise fall back to the incoming pipe. The pipe=None
        # guards fix an AttributeError that previously fired before the warning could.
        model = model if model is not None else (pipe.get("model") if pipe is not None else None)
        if model is None:
            log_node_warn(f'pipeIn[{my_unique_id}]', "Model missing from pipeLine")
        pos = pos if pos is not None else (pipe.get("positive") if pipe is not None else None)
        if pos is None:
            log_node_warn(f'pipeIn[{my_unique_id}]', "Pos Conditioning missing from pipeLine")
        neg = neg if neg is not None else (pipe.get("negative") if pipe is not None else None)
        if neg is None:
            log_node_warn(f'pipeIn[{my_unique_id}]', "Neg Conditioning missing from pipeLine")
        vae = vae if vae is not None else (pipe.get("vae") if pipe is not None else None)
        if vae is None:
            log_node_warn(f'pipeIn[{my_unique_id}]', "VAE missing from pipeLine")
        clip = clip if clip is not None else (pipe.get("clip") if pipe is not None and "clip" in pipe else None)

        if latent is not None:
            samples = latent
        elif image is None:
            samples = pipe.get("samples") if pipe is not None else None
            image = pipe.get("images") if pipe is not None else None
        else:
            # Encode the supplied image and repeat it to the loader's batch size.
            if pipe is None:
                batch_size = 1
            else:
                batch_size = pipe["loader_settings"].get("batch_size", 1)
            samples = {"samples": vae.encode(image[:, :, :, :3])}
            samples = RepeatLatentBatch().repeat(samples, batch_size)[0]

        if pipe is None:
            pipe = {"loader_settings": {"positive": "", "negative": "", "xyplot": None}}

        if xyplot is None:
            # fix: the original tested `xyplot in pipe['loader_settings']` (a value-vs-keys
            # membership test) instead of the key name, so a stored xyplot was never reused.
            xyplot = pipe['loader_settings'].get('xyplot')

        new_pipe = {
            **pipe,
            "model": model,
            "positive": pos,
            "negative": neg,
            "vae": vae,
            "clip": clip,

            "samples": samples,
            "images": image,
            "seed": pipe.get('seed'),

            "loader_settings": {
                **pipe["loader_settings"],
                "xyplot": xyplot
            }
        }
        del pipe

        return (new_pipe,)
||||
|
||||
# Pipe out
class pipeOut:
    """Unpack a PIPE_LINE dict into its individual components."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "pipe": ("PIPE_LINE",),
            },
            "hidden": {"my_unique_id": "UNIQUE_ID"},
        }

    RETURN_TYPES = ("PIPE_LINE", "MODEL", "CONDITIONING", "CONDITIONING", "LATENT", "VAE", "CLIP", "IMAGE", "INT",)
    RETURN_NAMES = ("pipe", "model", "pos", "neg", "latent", "vae", "clip", "image", "seed",)
    FUNCTION = "flush"

    CATEGORY = "EasyUse/Pipe"

    def flush(self, pipe, my_unique_id=None):
        """Return the pipe followed by its components (missing entries come back as None)."""
        component_keys = ("model", "positive", "negative", "samples",
                          "vae", "clip", "images", "seed")
        model, pos, neg, latent, vae, clip, image, seed = (pipe.get(key) for key in component_keys)
        return pipe, model, pos, neg, latent, vae, clip, image, seed
||||
|
||||
# 编辑节点束
|
||||
class pipeEdit:
    """Edit an existing PIPE_LINE: selectively override its model/vae/clip/latent/image
    and (re-)encode the positive/negative prompt texts, merging the fresh conditioning
    into the pipe's previous conditioning according to ``conditioning_mode``.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "clip_skip": ("INT", {"default": -1, "min": -24, "max": 0, "step": 1}),

                "optional_positive": ("STRING", {"default": "", "multiline": True}),
                "positive_token_normalization": (["none", "mean", "length", "length+mean"],),
                "positive_weight_interpretation": (["comfy", "A1111", "comfy++", "compel", "fixed attention"],),

                "optional_negative": ("STRING", {"default": "", "multiline": True}),
                "negative_token_normalization": (["none", "mean", "length", "length+mean"],),
                "negative_weight_interpretation": (["comfy", "A1111", "comfy++", "compel", "fixed attention"],),

                "a1111_prompt_style": ("BOOLEAN", {"default": False}),
                # How the freshly encoded conditioning is combined with the old one.
                "conditioning_mode": (['replace', 'concat', 'combine', 'average', 'timestep'], {"default": "replace"}),
                "average_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                # Timestep ranges used by the 'timestep' conditioning mode.
                "old_cond_start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
                "old_cond_end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
                "new_cond_start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
                "new_cond_end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
            },
            "optional": {
                "pipe": ("PIPE_LINE",),
                "model": ("MODEL",),
                "pos": ("CONDITIONING",),
                "neg": ("CONDITIONING",),
                "latent": ("LATENT",),
                "vae": ("VAE",),
                "clip": ("CLIP",),
                "image": ("IMAGE",),
            },
            "hidden": {"my_unique_id": "UNIQUE_ID", "prompt": "PROMPT"},
        }

    RETURN_TYPES = ("PIPE_LINE", "MODEL", "CONDITIONING", "CONDITIONING", "LATENT", "VAE", "CLIP", "IMAGE")
    RETURN_NAMES = ("pipe", "model", "pos", "neg", "latent", "vae", "clip", "image")
    FUNCTION = "edit"

    CATEGORY = "EasyUse/Pipe"

    def edit(self, clip_skip, optional_positive, positive_token_normalization, positive_weight_interpretation, optional_negative, negative_token_normalization, negative_weight_interpretation, a1111_prompt_style, conditioning_mode, average_strength, old_cond_start, old_cond_end, new_cond_start, new_cond_end, pipe=None, model=None, pos=None, neg=None, latent=None, vae=None, clip=None, image=None, my_unique_id=None, prompt=None):
        """Return the updated pipe plus its unpacked components.

        Explicitly wired inputs (model/vae/clip/latent/image/pos/neg) take
        precedence over the values already stored in ``pipe``.
        """
        # Fall back to pipe contents for anything not wired in explicitly.
        # NOTE(review): these pipe accesses run before the `if pipe is None`
        # guard further down, so a missing pipe raises AttributeError here first
        # — confirm whether a pipe input is effectively mandatory.
        model = model if model is not None else pipe.get("model")
        if model is None:
            log_node_warn(f'pipeIn[{my_unique_id}]', "Model missing from pipeLine")
        vae = vae if vae is not None else pipe.get("vae")
        if vae is None:
            log_node_warn(f'pipeIn[{my_unique_id}]', "VAE missing from pipeLine")
        clip = clip if clip is not None else pipe.get("clip")
        if clip is None:
            log_node_warn(f'pipeIn[{my_unique_id}]', "Clip missing from pipeLine")
        if image is None:
            image = pipe.get("images") if pipe is not None else None
        samples = latent if latent is not None else pipe.get("samples")
        if samples is None:
            log_node_warn(f'pipeIn[{my_unique_id}]', "Latent missing from pipeLine")
        else:
            # NOTE(review): this branch VAE-encodes the image and overwrites
            # `samples` even when a latent was already available — confirm intended.
            batch_size = pipe["loader_settings"]["batch_size"] if "batch_size" in pipe["loader_settings"] else 1
            # Drop a possible alpha channel before encoding; repeat to the
            # batch size recorded by the loader.
            samples = {"samples": vae.encode(image[:, :, :, :3])}
            samples = RepeatLatentBatch().repeat(samples, batch_size)[0]

        pipe_lora_stack = pipe.get("lora_stack") if pipe is not None and "lora_stack" in pipe else []

        steps = pipe["loader_settings"]["steps"] if "steps" in pipe["loader_settings"] else 1
        # Re-encode the positive prompt only when no conditioning was wired in
        # and a prompt text was actually supplied.
        if pos is None and optional_positive != '':
            pos, positive_wildcard_prompt, model, clip = prompt_to_cond('positive', model, clip, clip_skip,
                                                                        pipe_lora_stack, optional_positive, positive_token_normalization, positive_weight_interpretation,
                                                                        a1111_prompt_style, my_unique_id, prompt, easyCache, True, steps)
            # Merge new conditioning with the pipe's previous positive conditioning.
            pos = set_cond(pipe['positive'], pos, conditioning_mode, average_strength, old_cond_start, old_cond_end, new_cond_start, new_cond_end)
            pipe['loader_settings']['positive'] = positive_wildcard_prompt
            pipe['loader_settings']['positive_token_normalization'] = positive_token_normalization
            pipe['loader_settings']['positive_weight_interpretation'] = positive_weight_interpretation
            if a1111_prompt_style:
                pipe['loader_settings']['a1111_prompt_style'] = True
        else:
            pos = pipe.get("positive")
            if pos is None:
                log_node_warn(f'pipeIn[{my_unique_id}]', "Pos Conditioning missing from pipeLine")

        # Same treatment for the negative prompt.
        if neg is None and optional_negative != '':
            neg, negative_wildcard_prompt, model, clip = prompt_to_cond("negative", model, clip, clip_skip, pipe_lora_stack, optional_negative,
                                                                        negative_token_normalization, negative_weight_interpretation,
                                                                        a1111_prompt_style, my_unique_id, prompt, easyCache, True, steps)
            neg = set_cond(pipe['negative'], neg, conditioning_mode, average_strength, old_cond_start, old_cond_end, new_cond_start, new_cond_end)
            pipe['loader_settings']['negative'] = negative_wildcard_prompt
            pipe['loader_settings']['negative_token_normalization'] = negative_token_normalization
            pipe['loader_settings']['negative_weight_interpretation'] = negative_weight_interpretation
            if a1111_prompt_style:
                pipe['loader_settings']['a1111_prompt_style'] = True
        else:
            neg = pipe.get("negative")
            if neg is None:
                log_node_warn(f'pipeIn[{my_unique_id}]', "Neg Conditioning missing from pipeLine")
        if pipe is None:
            pipe = {"loader_settings": {"positive": "", "negative": "", "xyplot": None}}

        new_pipe = {
            **pipe,
            "model": model,
            "positive": pos,
            "negative": neg,
            "vae": vae,
            "clip": clip,

            "samples": samples,
            "images": image,
            "seed": pipe.get('seed') if pipe is not None and "seed" in pipe else None,
            "loader_settings": {
                **pipe["loader_settings"]
            }
        }
        del pipe

        # NOTE(review): the `latent` output is the raw input (None when not
        # wired), not the re-encoded `samples` stored in the pipe — confirm intended.
        return (new_pipe, model, pos, neg, latent, vae, clip, image)
|
||||
|
||||
# Edit pipe bundle prompts
|
||||
class pipeEditPrompt:
    """Re-encode the positive/negative prompts of an existing PIPE_LINE, reusing
    the text-encode settings the pipe was originally loaded with.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "pipe": ("PIPE_LINE",),
                "positive": ("STRING", {"default": "", "multiline": True}),
                "negative": ("STRING", {"default": "", "multiline": True}),
            },
            "hidden": {"my_unique_id": "UNIQUE_ID", "prompt": "PROMPT"},
        }

    RETURN_TYPES = ("PIPE_LINE",)
    RETURN_NAMES = ("pipe",)
    FUNCTION = "edit"

    CATEGORY = "EasyUse/Pipe"

    def edit(self, pipe, positive, negative, my_unique_id=None, prompt=None):
        """Return a new pipe whose positive/negative conditioning is re-encoded
        from the given prompt strings."""
        model = pipe.get("model")
        if model is None:
            log_node_warn(f'pipeEdit[{my_unique_id}]', "Model missing from pipeLine")

        # Kolors checkpoints report as 'sdxl' but use a ChatGLM3 text encoder
        # instead of CLIP, so they take a separate encode path.
        from ..modules.kolors.loader import is_kolors_model
        model_type = get_sd_version(model)
        if model_type == 'sdxl' and is_kolors_model(model):
            from ..modules.kolors.text_encode import chatglm3_adv_text_encode
            auto_clean_gpu = pipe["loader_settings"]["auto_clean_gpu"] if "auto_clean_gpu" in pipe["loader_settings"] else False
            chatglm3_model = pipe["chatglm3_model"] if "chatglm3_model" in pipe else None
            # text encode
            log_node_warn("Positive encoding...")
            positive_embeddings_final = chatglm3_adv_text_encode(chatglm3_model, positive, auto_clean_gpu)
            log_node_warn("Negative encoding...")
            negative_embeddings_final = chatglm3_adv_text_encode(chatglm3_model, negative, auto_clean_gpu)
        else:
            # Reuse the loader's original encode settings, falling back to
            # library defaults when the pipe predates a given setting.
            clip_skip = pipe["loader_settings"]["clip_skip"] if "clip_skip" in pipe["loader_settings"] else -1
            lora_stack = pipe.get("lora_stack") if pipe is not None and "lora_stack" in pipe else []
            clip = pipe.get("clip") if pipe is not None and "clip" in pipe else None
            positive_token_normalization = pipe["loader_settings"]["positive_token_normalization"] if "positive_token_normalization" in pipe["loader_settings"] else "none"
            positive_weight_interpretation = pipe["loader_settings"]["positive_weight_interpretation"] if "positive_weight_interpretation" in pipe["loader_settings"] else "comfy"
            negative_token_normalization = pipe["loader_settings"]["negative_token_normalization"] if "negative_token_normalization" in pipe["loader_settings"] else "none"
            negative_weight_interpretation = pipe["loader_settings"]["negative_weight_interpretation"] if "negative_weight_interpretation" in pipe["loader_settings"] else "comfy"
            a1111_prompt_style = pipe["loader_settings"]["a1111_prompt_style"] if "a1111_prompt_style" in pipe["loader_settings"] else False
            # Prompt to Conditioning
            positive_embeddings_final, positive_wildcard_prompt, model, clip = prompt_to_cond('positive', model, clip,
                                                                                              clip_skip, lora_stack,
                                                                                              positive,
                                                                                              positive_token_normalization,
                                                                                              positive_weight_interpretation,
                                                                                              a1111_prompt_style,
                                                                                              my_unique_id, prompt,
                                                                                              easyCache,
                                                                                              model_type=model_type)
            negative_embeddings_final, negative_wildcard_prompt, model, clip = prompt_to_cond('negative', model, clip,
                                                                                              clip_skip, lora_stack,
                                                                                              negative,
                                                                                              negative_token_normalization,
                                                                                              negative_weight_interpretation,
                                                                                              a1111_prompt_style,
                                                                                              my_unique_id, prompt,
                                                                                              easyCache,
                                                                                              model_type=model_type)
        new_pipe = {
            **pipe,
            "model": model,
            "positive": positive_embeddings_final,
            "negative": negative_embeddings_final,
        }
        del pipe

        return (new_pipe,)
|
||||
|
||||
|
||||
# Pipe bundle to basic pipe (pipe to ComfyUI-Impact-Pack's basic_pipe)
|
||||
class pipeToBasicPipe:
    """Convert an EasyUse PIPE_LINE dict into an Impact-Pack BASIC_PIPE tuple
    (model, clip, vae, positive, negative)."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {"pipe": ("PIPE_LINE",)},
            "hidden": {"my_unique_id": "UNIQUE_ID"},
        }

    RETURN_TYPES = ("BASIC_PIPE",)
    RETURN_NAMES = ("basic_pipe",)
    FUNCTION = "doit"

    CATEGORY = "EasyUse/Pipe"

    def doit(self, pipe, my_unique_id=None):
        """Pick the five basic-pipe components in the order Impact-Pack expects."""
        basic_pipe = tuple(pipe.get(key) for key in ('model', 'clip', 'vae', 'positive', 'negative'))
        del pipe
        return (basic_pipe,)
|
||||
|
||||
# Batch index
|
||||
class pipeBatchIndex:
    """Select a sub-batch of `length` latents starting at `batch_index` from the
    pipe's latent batch."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "pipe": ("PIPE_LINE",),
                "batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
                "length": ("INT", {"default": 1, "min": 1, "max": 64}),
            },
            "hidden": {"my_unique_id": "UNIQUE_ID"},
        }

    RETURN_TYPES = ("PIPE_LINE",)
    RETURN_NAMES = ("pipe",)
    FUNCTION = "doit"

    CATEGORY = "EasyUse/Pipe"

    def doit(self, pipe, batch_index, length, my_unique_id=None):
        """Return a copy of the pipe whose samples are the requested sub-batch."""
        selected, = LatentFromBatch().frombatch(pipe["samples"], batch_index, length)
        updated_pipe = {**pipe, "samples": selected}
        del pipe
        return (updated_pipe,)
|
||||
|
||||
# pipeXYPlot
|
||||
class pipeXYPlot:
    """Attach XY-plot sweep settings (axis names plus value lists) to a PIPE_LINE
    so a downstream sampler can render a parameter grid."""

    # Option lists shared by the axis value pickers.
    lora_list = ["None"] + folder_paths.get_filename_list("loras")
    lora_strengths = {"min": -4.0, "max": 4.0, "step": 0.01}
    token_normalization = ["none", "mean", "length", "length+mean"]
    weight_interpretation = ["comfy", "A1111", "compel", "comfy++"]

    # Loader-stage parameters that can be swept.
    loader_dict = {
        "ckpt_name": folder_paths.get_filename_list("checkpoints"),
        "vae_name": ["Baked-VAE"] + folder_paths.get_filename_list("vae"),
        "clip_skip": {"min": -24, "max": -1, "step": 1},
        "lora_name": lora_list,
        "lora_model_strength": lora_strengths,
        "lora_clip_strength": lora_strengths,
        "positive": [],
        "negative": [],
    }

    # Pre-sampling parameters that can be swept.
    sampler_dict = {
        "steps": {"min": 1, "max": 100, "step": 1},
        "cfg": {"min": 0.0, "max": 100.0, "step": 1.0},
        "sampler_name": comfy.samplers.KSampler.SAMPLERS,
        "scheduler": comfy.samplers.KSampler.SCHEDULERS,
        "denoise": {"min": 0.0, "max": 1.0, "step": 0.01},
        "seed": {"min": 0, "max": MAX_SEED_NUM},
    }

    plot_dict = {**sampler_dict, **loader_dict}

    # Flattened axis menu: separator, "preSampling: ..." entries, separator,
    # then "loader: ..." entries.
    plot_values = ["None", ]
    plot_values.append("---------------------")
    for k in sampler_dict:
        plot_values.append(f'preSampling: {k}')
    plot_values.append("---------------------")
    for k in loader_dict:
        plot_values.append(f'loader: {k}')

    def __init__(self):
        pass

    # Menu entries that mean "no axis selected".
    rejected = ["None", "---------------------", "Nothing"]

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "grid_spacing": ("INT", {"min": 0, "max": 500, "step": 5, "default": 0, }),
                "output_individuals": (["False", "True"], {"default": "False"}),
                "flip_xy": (["False", "True"], {"default": "False"}),
                "x_axis": (pipeXYPlot.plot_values, {"default": 'None'}),
                "x_values": (
                    "STRING", {"default": '', "multiline": True, "placeholder": 'insert values seperated by "; "'}),
                "y_axis": (pipeXYPlot.plot_values, {"default": 'None'}),
                "y_values": (
                    "STRING", {"default": '', "multiline": True, "placeholder": 'insert values seperated by "; "'}),
            },
            "optional": {
                "pipe": ("PIPE_LINE",)
            },
            "hidden": {
                "plot_dict": (pipeXYPlot.plot_dict,),
            },
        }

    RETURN_TYPES = ("PIPE_LINE",)
    RETURN_NAMES = ("pipe",)
    FUNCTION = "plot"

    CATEGORY = "EasyUse/Pipe"

    def plot(self, grid_spacing, output_individuals, flip_xy, x_axis, x_values, y_axis, y_values, pipe=None, font_path=None):
        """Parse the axis value strings and store the resulting xyplot dict in
        the pipe's loader_settings."""
        def clean_values(values):
            # Split the '"; "'-separated string, strip stray semicolons and
            # whitespace, and coerce numeric entries back to int/float.
            original_values = values.split("; ")
            cleaned_values = []

            for value in original_values:
                # Strip the semi-colon
                cleaned_value = value.strip(';').strip()

                if cleaned_value == "":
                    continue

                # Try to convert the cleaned_value back to int or float if possible
                try:
                    cleaned_value = int(cleaned_value)
                except ValueError:
                    try:
                        cleaned_value = float(cleaned_value)
                    except ValueError:
                        pass

                # Append the cleaned_value to the list
                cleaned_values.append(cleaned_value)

            return cleaned_values

        # Separator / "None" menu entries disable the axis entirely.
        if x_axis in self.rejected:
            x_axis = "None"
            x_values = []
        else:
            x_values = clean_values(x_values)

        if y_axis in self.rejected:
            y_axis = "None"
            y_values = []
        else:
            y_values = clean_values(y_values)

        if flip_xy == "True":
            x_axis, y_axis = y_axis, x_axis
            x_values, y_values = y_values, x_values

        xy_plot = {"x_axis": x_axis,
                   "x_vals": x_values,
                   "y_axis": y_axis,
                   "y_vals": y_values,
                   "custom_font": font_path,
                   "grid_spacing": grid_spacing,
                   "output_individuals": output_individuals}

        if pipe is not None:
            new_pipe = pipe.copy()
            new_pipe['loader_settings'] = {
                **pipe['loader_settings'],
                "xyplot": xy_plot
            }
            del pipe
        # NOTE(review): when pipe is None, `new_pipe` is unbound and this raises
        # NameError; also RETURN_TYPES declares one output but two values are
        # returned here — confirm whether the extra xy_plot element is relied on.
        return (new_pipe, xy_plot,)
|
||||
|
||||
# pipeXYPlotAdvanced
|
||||
import platform
|
||||
class pipeXYPlotAdvanced:
    """XY plot driven by X_Y axis descriptor inputs: resolves an optional label
    font, expands "advanced: ..." axis types into plain '"; "'-joined value
    strings (stashing any extra payloads in loader_settings), then delegates to
    pipeXYPlot.plot()."""

    # Platform-specific font directories used to populate the font picker.
    if platform.system() == "Windows":
        system_root = os.environ.get("SystemRoot")
        user_root = os.environ.get("USERPROFILE")
        font_dir = os.path.join(system_root, "Fonts") if system_root else None
        user_font_dir = os.path.join(user_root, "AppData", "Local", "Microsoft", "Windows", "Fonts") if user_root else None

    # Default debian-based Linux & MacOS font dirs
    elif platform.system() == "Linux":
        font_dir = "/usr/share/fonts/truetype"
        user_font_dir = None
    elif platform.system() == "Darwin":
        font_dir = "/System/Library/Fonts"
        user_font_dir = None
    else:
        font_dir = None
        user_font_dir = None

    @classmethod
    def INPUT_TYPES(s):
        # Gather .ttf files from the system and per-user font directories.
        files_list = []
        if s.font_dir and os.path.exists(s.font_dir):
            font_dir = s.font_dir
            files_list = files_list + [f for f in os.listdir(font_dir) if os.path.isfile(os.path.join(font_dir, f)) and f.lower().endswith(".ttf")]

        if s.user_font_dir and os.path.exists(s.user_font_dir):
            files_list = files_list + [f for f in os.listdir(s.user_font_dir) if os.path.isfile(os.path.join(s.user_font_dir, f)) and f.lower().endswith(".ttf")]

        return {
            "required": {
                "pipe": ("PIPE_LINE",),
                "grid_spacing": ("INT", {"min": 0, "max": 500, "step": 5, "default": 0, }),
                "output_individuals": (["False", "True"], {"default": "False"}),
                "flip_xy": (["False", "True"], {"default": "False"}),
            },
            "optional": {
                "X": ("X_Y",),
                "Y": ("X_Y",),
                "font": (["None"] + files_list,)
            },
            "hidden": {"my_unique_id": "UNIQUE_ID"}
        }

    RETURN_TYPES = ("PIPE_LINE",)
    RETURN_NAMES = ("pipe",)
    FUNCTION = "plot"

    CATEGORY = "EasyUse/Pipe"

    def plot(self, pipe, grid_spacing, output_individuals, flip_xy, X=None, Y=None, font=None, my_unique_id=None):
        """Expand the X/Y descriptors and delegate to pipeXYPlot.plot()."""
        # NOTE(review): with the default font=None, `font != "None"` is True, so
        # os.path.join(self.font_dir, None) raises TypeError (likewise when
        # font_dir is None on this platform) — confirm callers always pass a string.
        font_path = os.path.join(self.font_dir, font) if font != "None" else None
        if font_path and not os.path.exists(font_path):
            # Fall back to the per-user font directory.
            font_path = os.path.join(self.user_font_dir, font)

        # Missing axes are reported as "Nothing" with a single empty value.
        if X != None:
            x_axis = X.get('axis')
            x_values = X.get('values')
        else:
            x_axis = "Nothing"
            x_values = [""]
        if Y != None:
            y_axis = Y.get('axis')
            y_values = Y.get('values')
        else:
            y_axis = "Nothing"
            y_values = [""]

        if pipe is not None:
            new_pipe = pipe.copy()
            positive = pipe["loader_settings"]["positive"] if "positive" in pipe["loader_settings"] else ""
            negative = pipe["loader_settings"]["negative"] if "negative" in pipe["loader_settings"] else ""

            # advanced: ModelMergeBlocks — carry the merge model list and VAE
            # choice in loader_settings.
            if x_axis == 'advanced: ModelMergeBlocks':
                models = X.get('models')
                vae_use = X.get('vae_use')
                if models is None:
                    raise Exception("models is not found")
                new_pipe['loader_settings'] = {
                    **pipe['loader_settings'],
                    "models": models,
                    "vae_use": vae_use
                }
            if y_axis == 'advanced: ModelMergeBlocks':
                models = Y.get('models')
                vae_use = Y.get('vae_use')
                if models is None:
                    raise Exception("models is not found")
                new_pipe['loader_settings'] = {
                    **pipe['loader_settings'],
                    "models": models,
                    "vae_use": vae_use
                }

            # advanced: Lora / Checkpoint — normalize the axis lora stack into
            # dicts bound to the pipe's model/clip and append it to any stack
            # already present on the pipe.
            if x_axis in ['advanced: Lora', 'advanced: Checkpoint']:
                lora_stack = X.get('lora_stack')
                _lora_stack = []
                if lora_stack is not None:
                    for lora in lora_stack:
                        _lora_stack.append(
                            {"lora_name": lora[0], "model": pipe['model'], "clip": pipe['clip'], "model_strength": lora[1],
                             "clip_strength": lora[2]})
                del lora_stack
                x_values = "; ".join(x_values)
                lora_stack = pipe['lora_stack'] + _lora_stack if 'lora_stack' in pipe else _lora_stack
                new_pipe['loader_settings'] = {
                    **pipe['loader_settings'],
                    "lora_stack": lora_stack,
                }

            if y_axis in ['advanced: Lora', 'advanced: Checkpoint']:
                lora_stack = Y.get('lora_stack')
                _lora_stack = []
                if lora_stack is not None:
                    for lora in lora_stack:
                        _lora_stack.append(
                            {"lora_name": lora[0], "model": pipe['model'], "clip": pipe['clip'], "model_strength": lora[1],
                             "clip_strength": lora[2]})
                del lora_stack
                y_values = "; ".join(y_values)
                lora_stack = pipe['lora_stack'] + _lora_stack if 'lora_stack' in pipe else _lora_stack
                new_pipe['loader_settings'] = {
                    **pipe['loader_settings'],
                    "lora_stack": lora_stack,
                }

            # advanced: Seeds++ Batch — the axis value is a count; expand it
            # into consecutive seeds starting at the pipe's seed.
            if x_axis == 'advanced: Seeds++ Batch':
                if new_pipe['seed']:
                    value = x_values
                    x_values = []
                    for index in range(value):
                        x_values.append(str(new_pipe['seed'] + index))
                    x_values = "; ".join(x_values)
            if y_axis == 'advanced: Seeds++ Batch':
                if new_pipe['seed']:
                    value = y_values
                    y_values = []
                    for index in range(value):
                        y_values.append(str(new_pipe['seed'] + index))
                    y_values = "; ".join(y_values)

            # advanced: Positive Prompt S/R — each entry is
            # (search_txt, replace_txt, replace_all); produce the substituted prompts.
            if x_axis == 'advanced: Positive Prompt S/R':
                if positive:
                    x_value = x_values
                    x_values = []
                    for index, value in enumerate(x_value):
                        search_txt, replace_txt, replace_all = value
                        if replace_all:
                            txt = replace_txt if replace_txt is not None else positive
                            x_values.append(txt)
                        else:
                            txt = positive.replace(search_txt, replace_txt, 1) if replace_txt is not None else positive
                            x_values.append(txt)
                    x_values = "; ".join(x_values)
            if y_axis == 'advanced: Positive Prompt S/R':
                if positive:
                    y_value = y_values
                    y_values = []
                    for index, value in enumerate(y_value):
                        search_txt, replace_txt, replace_all = value
                        if replace_all:
                            txt = replace_txt if replace_txt is not None else positive
                            y_values.append(txt)
                        else:
                            txt = positive.replace(search_txt, replace_txt, 1) if replace_txt is not None else positive
                            y_values.append(txt)
                    y_values = "; ".join(y_values)

            # advanced: Negative Prompt S/R — same scheme on the negative prompt.
            if x_axis == 'advanced: Negative Prompt S/R':
                if negative:
                    x_value = x_values
                    x_values = []
                    for index, value in enumerate(x_value):
                        search_txt, replace_txt, replace_all = value
                        if replace_all:
                            txt = replace_txt if replace_txt is not None else negative
                            x_values.append(txt)
                        else:
                            txt = negative.replace(search_txt, replace_txt, 1) if replace_txt is not None else negative
                            x_values.append(txt)
                    x_values = "; ".join(x_values)
            if y_axis == 'advanced: Negative Prompt S/R':
                if negative:
                    y_value = y_values
                    y_values = []
                    for index, value in enumerate(y_value):
                        search_txt, replace_txt, replace_all = value
                        if replace_all:
                            txt = replace_txt if replace_txt is not None else negative
                            y_values.append(txt)
                        else:
                            txt = negative.replace(search_txt, replace_txt, 1) if replace_txt is not None else negative
                            y_values.append(txt)
                    y_values = "; ".join(y_values)

            # advanced: ControlNet — keep the real payloads in a cnet stack and
            # replace the axis values with their indices.
            if "advanced: ControlNet" in x_axis:
                x_value = x_values
                x_values = []
                cnet = []
                for index, value in enumerate(x_value):
                    cnet.append(value)
                    x_values.append(str(index))
                x_values = "; ".join(x_values)
                new_pipe['loader_settings'] = {
                    **pipe['loader_settings'],
                    "cnet_stack": cnet,
                }

            if "advanced: ControlNet" in y_axis:
                y_value = y_values
                y_values = []
                cnet = []
                for index, value in enumerate(y_value):
                    cnet.append(value)
                    y_values.append(str(index))
                y_values = "; ".join(y_values)
                new_pipe['loader_settings'] = {
                    **pipe['loader_settings'],
                    "cnet_stack": cnet,
                }

            # advanced: Pos/Neg Condition — stash the conditioning stacks.
            if "advanced: Pos Condition" in x_axis:
                x_values = "; ".join(x_values)
                cond = X.get('cond')
                new_pipe['loader_settings'] = {
                    **pipe['loader_settings'],
                    "positive_cond_stack": cond,
                }
            if "advanced: Pos Condition" in y_axis:
                y_values = "; ".join(y_values)
                cond = Y.get('cond')
                new_pipe['loader_settings'] = {
                    **pipe['loader_settings'],
                    "positive_cond_stack": cond,
                }

            if "advanced: Neg Condition" in x_axis:
                x_values = "; ".join(x_values)
                cond = X.get('cond')
                new_pipe['loader_settings'] = {
                    **pipe['loader_settings'],
                    "negative_cond_stack": cond,
                }
            if "advanced: Neg Condition" in y_axis:
                y_values = "; ".join(y_values)
                cond = Y.get('cond')
                new_pipe['loader_settings'] = {
                    **pipe['loader_settings'],
                    "negative_cond_stack": cond,
                }

            del pipe

        # NOTE(review): when pipe is None, `new_pipe` is unbound here — confirm
        # the PIPE_LINE input is effectively mandatory (it is listed as required).
        return pipeXYPlot().plot(grid_spacing, output_individuals, flip_xy, x_axis, x_values, y_axis, y_values, new_pipe, font_path)
|
||||
|
||||
|
||||
# Registration table: node id (as stored in workflow JSON) -> implementing class.
NODE_CLASS_MAPPINGS = {
    "easy pipeIn": pipeIn,
    "easy pipeOut": pipeOut,
    "easy pipeEdit": pipeEdit,
    "easy pipeEditPrompt": pipeEditPrompt,
    "easy pipeToBasicPipe": pipeToBasicPipe,
    "easy pipeBatchIndex": pipeBatchIndex,
    "easy XYPlot": pipeXYPlot,
    "easy XYPlotAdvanced": pipeXYPlotAdvanced
}

# Registration table: node id -> human-readable title shown in the ComfyUI menu.
NODE_DISPLAY_NAME_MAPPINGS = {
    "easy pipeIn": "Pipe In",
    "easy pipeOut": "Pipe Out",
    "easy pipeEdit": "Pipe Edit",
    "easy pipeEditPrompt": "Pipe Edit Prompt",
    "easy pipeBatchIndex": "Pipe Batch Index",
    "easy pipeToBasicPipe": "Pipe -> BasicPipe",
    "easy XYPlot": "XY Plot",
    "easy XYPlotAdvanced": "XY Plot Advanced"
}
|
||||
1002
custom_nodes/ComfyUI-Easy-Use/py/nodes/preSampling.py
Normal file
1002
custom_nodes/ComfyUI-Easy-Use/py/nodes/preSampling.py
Normal file
File diff suppressed because it is too large
Load Diff
790
custom_nodes/ComfyUI-Easy-Use/py/nodes/prompt.py
Normal file
790
custom_nodes/ComfyUI-Easy-Use/py/nodes/prompt.py
Normal file
@@ -0,0 +1,790 @@
|
||||
import json
|
||||
import os
|
||||
from urllib.request import urlopen
|
||||
import folder_paths
|
||||
|
||||
from .. import easyCache
|
||||
from ..config import FOOOCUS_STYLES_DIR, MAX_SEED_NUM, PROMPT_TEMPLATE, RESOURCES_DIR
|
||||
from ..libs.log import log_node_info
|
||||
from ..libs.wildcards import WildcardProcessor, get_wildcard_list, process
|
||||
|
||||
from comfy_api.latest import io
|
||||
|
||||
|
||||
# Positive prompt
|
||||
class positivePrompt(io.ComfyNode):
    """Pass-through node exposing a multiline positive-prompt string."""

    @classmethod
    def define_schema(cls):
        prompt_inputs = [
            io.String.Input("positive", default="", multiline=True, placeholder="Positive"),
        ]
        prompt_outputs = [
            io.String.Output(id="output_positive", display_name="positive"),
        ]
        return io.Schema(
            node_id="easy positive",
            category="EasyUse/Prompt",
            inputs=prompt_inputs,
            outputs=prompt_outputs,
        )

    @classmethod
    def execute(cls, positive):
        # No transformation — the text is forwarded unchanged.
        return io.NodeOutput(positive)
|
||||
|
||||
# Wildcard prompt
|
||||
class wildcardsPrompt(io.ComfyNode):
    """Prompt node with wildcard expansion; in multiline mode every line is
    expanded independently."""

    @classmethod
    def define_schema(cls):
        wildcard_list = get_wildcard_list()
        return io.Schema(
            node_id="easy wildcards",
            category="EasyUse/Prompt",
            inputs=[
                io.String.Input("text", default="", multiline=True, dynamic_prompts=False, placeholder="(Support wildcard)"),
                io.Combo.Input("Select to add LoRA", options=["Select the LoRA to add to the text"] + folder_paths.get_filename_list("loras")),
                io.Combo.Input("Select to add Wildcard", options=["Select the Wildcard to add to the text"] + wildcard_list),
                io.Int.Input("seed", default=0, min=0, max=MAX_SEED_NUM),
                io.Boolean.Input("multiline_mode", default=False),
            ],
            outputs=[
                io.String.Output(id="output_text", display_name="text", is_output_list=True),
                io.String.Output(id="populated_text", display_name="populated_text", is_output_list=True),
            ],
            hidden=[
                io.Hidden.prompt,
                io.Hidden.extra_pnginfo,
                io.Hidden.unique_id,
            ],
        )

    @classmethod
    def execute(cls, text, seed, multiline_mode, **kwargs):
        graph_prompt = cls.hidden.prompt

        # Clean loaded_objects: drop cached entries no longer in the graph.
        if graph_prompt:
            easyCache.update_loaded_objects(graph_prompt)

        if multiline_mode:
            # Each line is treated as an independent prompt.
            source_lines = text.split("\n")
            populated_text = [process(line, seed) for line in source_lines]
            text = source_lines
        else:
            populated_text = [process(text, seed)]
            text = [text]
        return io.NodeOutput(text, populated_text, ui={"value": [seed]})
|
||||
|
||||
# Wildcard prompt matrix: returns, in order, every possible prompt generated from a wildcard-containing prompt
|
||||
class wildcardsPromptMatrix(io.ComfyNode):
    """Enumerate wildcard combinations deterministically: returns up to
    `output_limit` expansions starting at `offset`, plus the total count and the
    per-placeholder choice factors."""

    @classmethod
    def define_schema(cls):
        wildcard_list = get_wildcard_list()
        return io.Schema(
            node_id="easy wildcardsMatrix",
            category="EasyUse/Prompt",
            inputs=[
                io.String.Input("text", default="", multiline=True, dynamic_prompts=False, placeholder="(Support Lora Block Weight and wildcard)"),
                io.Combo.Input("Select to add LoRA", options=["Select the LoRA to add to the text"] + folder_paths.get_filename_list("loras")),
                io.Combo.Input("Select to add Wildcard", options=["Select the Wildcard to add to the text"] + wildcard_list),
                io.Int.Input("offset", default=0, min=0, max=MAX_SEED_NUM, step=1, control_after_generate=True),
                io.Int.Input("output_limit", default=1, min=-1, step=1, tooltip="Output All Probilities", optional=True),
            ],
            outputs=[
                io.String.Output("populated_text", is_output_list=True),
                io.Int.Output("total"),
                io.Int.Output("factors", is_output_list=True),
            ],
            hidden=[
                io.Hidden.prompt,
                io.Hidden.extra_pnginfo,
                io.Hidden.unique_id,
            ],
        )

    @classmethod
    def execute(cls, text, offset, output_limit=1, **kwargs):
        graph_prompt = cls.hidden.prompt
        # Clean loaded_objects: drop cached entries no longer in the graph.
        if graph_prompt:
            easyCache.update_loaded_objects(graph_prompt)

        processor = WildcardProcessor(text)
        combinations = processor.total()
        # -1 means "emit everything"; otherwise clamp the limit to the total.
        if output_limit == -1 or output_limit > combinations:
            limit = combinations
        else:
            limit = output_limit
        if output_limit == -1:
            offset = 0
        if output_limit != 1:
            populated_text = processor.getmany(limit, offset)
        else:
            populated_text = [processor.getn(offset)]
        return io.NodeOutput(populated_text, processor.total(), list(processor.placeholder_choices.values()), ui={"value": [offset]})
|
||||
|
||||
# Negative prompt
|
||||
class negativePrompt(io.ComfyNode):
    """Pass-through node exposing a multiline negative-prompt string."""

    @classmethod
    def define_schema(cls):
        prompt_inputs = [
            io.String.Input("negative", default="", multiline=True, placeholder="Negative"),
        ]
        prompt_outputs = [
            io.String.Output(id="output_negative", display_name="negative"),
        ]
        return io.Schema(
            node_id="easy negative",
            category="EasyUse/Prompt",
            inputs=prompt_inputs,
            outputs=prompt_outputs,
        )

    @classmethod
    def execute(cls, negative):
        # No transformation — the text is forwarded unchanged.
        return io.NodeOutput(negative)
|
||||
|
||||
# Style prompt selector
|
||||
class stylesPromptSelector(io.ComfyNode):
    """Apply one or more named styles (loaded from a styles JSON file) to a
    positive/negative prompt pair.

    Each style entry may contain a ``prompt`` template (optionally with a
    ``{prompt}`` placeholder that consumes the user's positive text) and a
    ``negative_prompt`` fragment that is appended to the negative text.
    """

    @classmethod
    def define_schema(cls):
        # Available style sets: the bundled "fooocus_styles" plus any extra
        # *.json files dropped into the user styles directory.
        styles = ["fooocus_styles"]
        styles_dir = FOOOCUS_STYLES_DIR
        for file_name in os.listdir(styles_dir):
            file = os.path.join(styles_dir, file_name)
            if os.path.isfile(file) and file_name.endswith(".json"):
                if file_name != "fooocus_styles.json":
                    styles.append(file_name.split(".")[0])

        return io.Schema(
            node_id="easy stylesSelector",
            category="EasyUse/Prompt",
            inputs=[
                io.Combo.Input("styles", options=styles, default="fooocus_styles"),
                io.String.Input("positive", default="", force_input=True, optional=True),
                io.String.Input("negative", default="", force_input=True, optional=True),
                io.Custom(io_type="EASY_PROMPT_STYLES").Input("select_styles", optional=True),
            ],
            outputs=[
                io.String.Output(id="output_positive", display_name="positive"),
                io.String.Output(id="output_negative", display_name="negative"),
            ],
            hidden=[
                io.Hidden.prompt,
                io.Hidden.extra_pnginfo,
                io.Hidden.unique_id,
            ],
        )

    @classmethod
    def execute(cls, styles, positive='', negative='', select_styles=None, **kwargs):
        """Merge the selected styles into the incoming prompts.

        Returns the styled (positive, negative) pair; with no selection the
        inputs pass through untouched.
        """
        all_styles = {}
        positive_prompt, negative_prompt = '', negative
        # Fall back to the bundled resource copy when the user has no custom
        # fooocus_styles.json in the styles directory.
        fooocus_custom_dir = os.path.join(FOOOCUS_STYLES_DIR, 'fooocus_styles.json')
        if styles == "fooocus_styles" and not os.path.exists(fooocus_custom_dir):
            file = os.path.join(RESOURCES_DIR, styles + '.json')
        else:
            file = os.path.join(FOOOCUS_STYLES_DIR, styles + '.json')
        # Context manager so the handle is closed even if json.load raises.
        with open(file, 'r', encoding='utf-8') as f:
            data = json.load(f)
        for d in data:
            all_styles[d['name']] = d

        # select_styles arrives either as a comma-separated string or a list.
        if isinstance(select_styles, str):
            values = select_styles.split(',')
        else:
            values = select_styles if select_styles else []

        has_prompt = False
        if len(values) == 0:
            # Nothing selected: pass both prompts through unchanged.
            return io.NodeOutput(positive, negative)

        for val in values:
            if val not in all_styles:
                continue
            if 'prompt' in all_styles[val]:
                if "{prompt}" in all_styles[val]['prompt'] and not has_prompt:
                    # First template with a placeholder consumes the user's positive text.
                    positive_prompt = all_styles[val]['prompt'].replace('{prompt}', positive)
                    has_prompt = True
                elif "{prompt}" in all_styles[val]['prompt']:
                    # Later templates drop the placeholder and are appended.
                    positive_prompt += ', ' + all_styles[val]['prompt'].replace(', {prompt}', '').replace('{prompt}', '')
                else:
                    positive_prompt = all_styles[val]['prompt'] if positive_prompt == '' else positive_prompt + ', ' + all_styles[val]['prompt']
            if 'negative_prompt' in all_styles[val]:
                negative_prompt += ', ' + all_styles[val]['negative_prompt'] if negative_prompt else all_styles[val]['negative_prompt']

        # If no selected style consumed {prompt}, prepend the user's positive text.
        # NOTE(review): the trailing ', ' reproduces the original behavior — confirm intended.
        if not has_prompt and positive:
            positive_prompt = positive + positive_prompt + ', '

        return io.NodeOutput(positive_prompt, negative_prompt)
|
||||
|
||||
#prompt
|
||||
class prompt(io.ComfyNode):
    """Prompt builder node: a free text field plus template combo pickers.

    The combos are consumed by the front-end (they are not merged here);
    ``execute`` simply returns the text field.
    """
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="easy prompt",
            category="EasyUse/Prompt",
            inputs=[
                io.String.Input("text", default="", multiline=True, placeholder="Prompt"),
                io.Combo.Input("prefix", options=["Select the prefix add to the text"] + PROMPT_TEMPLATE["prefix"], default="Select the prefix add to the text"),
                io.Combo.Input("subject", options=["👤Select the subject add to the text"] + PROMPT_TEMPLATE["subject"], default="👤Select the subject add to the text"),
                io.Combo.Input("action", options=["🎬Select the action add to the text"] + PROMPT_TEMPLATE["action"], default="🎬Select the action add to the text"),
                io.Combo.Input("clothes", options=["👚Select the clothes add to the text"] + PROMPT_TEMPLATE["clothes"], default="👚Select the clothes add to the text"),
                io.Combo.Input("environment", options=["☀️Select the illumination environment add to the text"] + PROMPT_TEMPLATE["environment"], default="☀️Select the illumination environment add to the text"),
                io.Combo.Input("background", options=["🎞️Select the background add to the text"] + PROMPT_TEMPLATE["background"], default="🎞️Select the background add to the text"),
                # Fix: the default previously contained an extra U+FE0F variation
                # selector ("🔞️…") and so matched none of the combo options.
                io.Combo.Input("nsfw", options=["🔞Select the nsfw add to the text"] + PROMPT_TEMPLATE["nsfw"], default="🔞Select the nsfw add to the text"),
            ],
            outputs=[
                io.String.Output("prompt"),
            ],
            hidden=[
                io.Hidden.prompt,
                io.Hidden.extra_pnginfo,
                io.Hidden.unique_id,
            ],
        )

    @classmethod
    def execute(cls, text, **kwargs):
        # Combo selections arrive in kwargs but are handled client-side;
        # only the assembled text passes through.
        return io.NodeOutput(text)
|
||||
|
||||
#promptList
|
||||
class promptList(io.ComfyNode):
    """Collect up to five prompt fields (plus an optional incoming list) into
    one prompt list, emitted both as a LIST and as a string output list."""
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="easy promptList",
            category="EasyUse/Prompt",
            inputs=[
                io.String.Input("prompt_1", multiline=True, default=""),
                io.String.Input("prompt_2", multiline=True, default=""),
                io.String.Input("prompt_3", multiline=True, default=""),
                io.String.Input("prompt_4", multiline=True, default=""),
                io.String.Input("prompt_5", multiline=True, default=""),
                io.Custom(io_type="LIST").Input("optional_prompt_list", optional=True),
            ],
            outputs=[
                io.Custom(io_type="LIST").Output("prompt_list"),
                io.String.Output("prompt_strings", is_output_list=True),
            ],
        )

    @classmethod
    def execute(cls, prompt_1="", prompt_2="", prompt_3="", prompt_4="", prompt_5="", optional_prompt_list=None, **kwargs):
        """Return the concatenation of the incoming list and non-empty fields."""
        prompts = []

        # Upstream list entries come first, preserving their order.
        if optional_prompt_list:
            prompts.extend(optional_prompt_list)

        # Add individual prompts, skipping empty fields.
        for p in (prompt_1, prompt_2, prompt_3, prompt_4, prompt_5):
            if isinstance(p, str) and p != '':
                prompts.append(p)

        return io.NodeOutput(prompts, prompts)
|
||||
|
||||
#promptLine
|
||||
class promptLine(io.ComfyNode):
    """Split a multiline prompt into lines and emit a clamped window of rows."""

    @classmethod
    def define_schema(cls):
        """Declare the text field, window controls, and list outputs."""
        return io.Schema(
            node_id="easy promptLine",
            category="EasyUse/Prompt",
            inputs=[
                io.String.Input("prompt", multiline=True, default="text"),
                io.Int.Input("start_index", default=0, min=0, max=9999),
                io.Int.Input("max_rows", default=1000, min=1, max=9999),
                io.Boolean.Input("remove_empty_lines", default=True),
            ],
            outputs=[
                io.String.Output("STRING", is_output_list=True),
                io.Combo.Output("COMBO", is_output_list=True),
            ],
            hidden=[
                io.Hidden.prompt,
                io.Hidden.unique_id,
            ],
        )

    @classmethod
    def execute(cls, prompt, start_index, max_rows, remove_empty_lines=True, **kwargs):
        """Return up to ``max_rows`` lines starting at ``start_index``."""
        lines = prompt.split('\n')

        if remove_empty_lines:
            lines = [ln for ln in lines if ln.strip()]

        # Clamp the window so it always lies within the available rows.
        first = max(0, min(start_index, len(lines) - 1))
        last = min(first + max_rows, len(lines))
        rows = lines[first:last]

        return io.NodeOutput(rows, rows)
|
||||
|
||||
import comfy.utils
|
||||
from server import PromptServer
|
||||
from ..libs.messages import MessageCancelled, Message
|
||||
class promptAwait(io.ComfyNode):
    """Pause execution and wait for an interactive prompt edit from the front-end.

    Sends an 'easyuse_prompt_await' event to the browser, then blocks on
    Message.waitForMessage until the user confirms/cancels in the UI toolbar.
    """

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="easy promptAwait",
            category="EasyUse/Prompt",
            inputs=[
                # Current value to forward when the user keeps "now".
                io.AnyType.Input("now"),
                io.String.Input("prompt", multiline=True, default="", placeholder="Enter a prompt or use voice to enter to text"),
                # Front-end toolbar widget driving the await dialog.
                io.Custom(io_type="EASY_PROMPT_AWAIT_BAR").Input("toolbar"),
                # Previous value, selectable as an alternative to "now".
                io.AnyType.Input("prev", optional=True),
            ],
            outputs=[
                io.AnyType.Output(id="output", display_name="output"),
                io.String.Output(id="output_prompt", display_name="prompt"),
                io.Boolean.Output("continue"),
                io.Int.Output("seed"),
            ],
            hidden=[
                io.Hidden.prompt,
                io.Hidden.unique_id,
                io.Hidden.extra_pnginfo,
            ],
        )

    @classmethod
    def execute(cls, now, prompt, toolbar, prev=None, **kwargs):
        # Normalize the unique node id: keep the last dotted segment and strip
        # any ":<suffix>" (subgraph/instance decorations added by the executor).
        id = cls.hidden.unique_id
        id = id.split('.')[len(id.split('.')) - 1] if "." in id else id
        if ":" in id:
            id = id.split(":")[0]
        pbar = comfy.utils.ProgressBar(100)
        pbar.update_absolute(30)
        # Ask the front-end to open the await dialog for this node.
        PromptServer.instance.send_sync('easyuse_prompt_await', {"id": id})
        try:
            # Blocks until the browser replies (or the message is cancelled).
            res = Message.waitForMessage(id, asList=False)
            if res is None or res == "-1":
                # No/negative response: pass inputs through, do not continue.
                result = (now, prompt, False, 0)
            else:
                # assumes res carries 'select', 'prompt', 'result', 'seed',
                # 'unlock', 'last_seed' keys — defined by the front-end payload.
                input = now if res['select'] == 'now' or prev is None else prev
                result = (input, res['prompt'], False if res['result'] == -1 else True, res['seed'] if res['unlock'] else res['last_seed'])
            pbar.update_absolute(100)
            return io.NodeOutput(*result)
        except MessageCancelled:
            pbar.update_absolute(100)
            # NOTE(review): only comfy.utils is imported above — relies on
            # comfy.model_management being imported elsewhere; confirm.
            raise comfy.model_management.InterruptProcessingException()
|
||||
|
||||
class promptConcat(io.ComfyNode):
    """Join two prompt strings with a configurable separator."""
    @classmethod
    def define_schema(cls):
        """Declare two forced string inputs, a separator, and one output."""
        concat_inputs = [
            io.String.Input("prompt1", multiline=False, default="", force_input=True, optional=True),
            io.String.Input("prompt2", multiline=False, default="", force_input=True, optional=True),
            io.String.Input("separator", multiline=False, default="", optional=True),
        ]
        return io.Schema(
            node_id="easy promptConcat",
            category="EasyUse/Prompt",
            inputs=concat_inputs,
            outputs=[
                io.String.Output("prompt"),
            ],
        )

    @classmethod
    def execute(cls, prompt1="", prompt2="", separator=""):
        """Return prompt1 + separator + prompt2."""
        joined = separator.join((prompt1, prompt2))
        return io.NodeOutput(joined)
|
||||
|
||||
class promptReplace(io.ComfyNode):
    """Apply up to three find/replace substitutions to a prompt string."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="easy promptReplace",
            category="EasyUse/Prompt",
            inputs=[
                io.String.Input("prompt", multiline=True, default="", force_input=True),
                io.String.Input("find1", multiline=False, default="", optional=True),
                io.String.Input("replace1", multiline=False, default="", optional=True),
                io.String.Input("find2", multiline=False, default="", optional=True),
                io.String.Input("replace2", multiline=False, default="", optional=True),
                io.String.Input("find3", multiline=False, default="", optional=True),
                io.String.Input("replace3", multiline=False, default="", optional=True),
            ],
            outputs=[
                io.String.Output(id="output_prompt",display_name="prompt"),
            ],
        )

    @classmethod
    def execute(cls, prompt, find1="", replace1="", find2="", replace2="", find3="", replace3=""):
        """Run the three substitutions in order and return the result.

        Empty ``find`` strings are skipped: ``str.replace("", x)`` would
        otherwise insert ``x`` between every character of the prompt when a
        replacement is set while its find field is left blank.
        """
        for find, replacement in ((find1, replace1), (find2, replace2), (find3, replace3)):
            if find:
                prompt = prompt.replace(find, replacement)

        return io.NodeOutput(prompt)
|
||||
|
||||
|
||||
# Portrait Master
|
||||
# Created by AI Wiz Art (Stefano Flore)
|
||||
# Version: 2.2
|
||||
# https://stefanoflore.it
|
||||
# https://ai-wiz.art
|
||||
class portraitMaster(io.ComfyNode):
    """Assemble a weighted portrait prompt (and matching negative prompt)
    from structured choices: shot, subject, face, skin, eyes, and lighting.

    Option lists come from resources/portrait_prompt.json, downloaded on
    first use if missing.
    """

    @classmethod
    def define_schema(cls):
        # Upper bound shared by all weight sliders.
        max_float_value = 1.95
        prompt_path = os.path.join(RESOURCES_DIR, 'portrait_prompt.json')
        if not os.path.exists(prompt_path):
            # First run: fetch the option lists from the upstream repository
            # and cache them locally.
            response = urlopen('https://raw.githubusercontent.com/yolain/ComfyUI-Easy-Use/main/resources/portrait_prompt.json')
            temp_prompt = json.loads(response.read())
            prompt_serialized = json.dumps(temp_prompt, indent=4)
            with open(prompt_path, "w") as f:
                f.write(prompt_serialized)
            del response, temp_prompt
        # Load local
        with open(prompt_path, 'r') as f:
            data = json.load(f)

        inputs = []
        # Shot
        inputs.append(io.Combo.Input("shot", options=['-'] + data['shot_list']))
        inputs.append(io.Float.Input("shot_weight", default=0, step=0.05, min=0, max=max_float_value, display_mode=io.NumberDisplay.slider))
        # Gender and age
        inputs.append(io.Combo.Input("gender", options=['-'] + data['gender_list'], default="Woman"))
        inputs.append(io.Int.Input("age", default=30, min=18, max=90, step=1, display_mode=io.NumberDisplay.slider))
        # Nationality
        inputs.append(io.Combo.Input("nationality_1", options=['-'] + data['nationality_list'], default="Chinese"))
        inputs.append(io.Combo.Input("nationality_2", options=['-'] + data['nationality_list']))
        inputs.append(io.Float.Input("nationality_mix", default=0, step=0.05, min=0, max=max_float_value, display_mode=io.NumberDisplay.slider))
        # Body
        inputs.append(io.Combo.Input("body_type", options=['-'] + data['body_type_list']))
        inputs.append(io.Float.Input("body_type_weight", default=0, step=0.05, min=0, max=max_float_value, display_mode=io.NumberDisplay.slider))
        inputs.append(io.Combo.Input("model_pose", options=['-'] + data['model_pose_list']))
        inputs.append(io.Combo.Input("eyes_color", options=['-'] + data['eyes_color_list']))
        # Face
        inputs.append(io.Combo.Input("facial_expression", options=['-'] + data['face_expression_list']))
        inputs.append(io.Float.Input("facial_expression_weight", default=0, step=0.05, min=0, max=max_float_value, display_mode=io.NumberDisplay.slider))
        inputs.append(io.Combo.Input("face_shape", options=['-'] + data['face_shape_list']))
        inputs.append(io.Float.Input("face_shape_weight", default=0, step=0.05, min=0, max=max_float_value, display_mode=io.NumberDisplay.slider))
        inputs.append(io.Float.Input("facial_asymmetry", default=0, step=0.05, min=0, max=max_float_value, display_mode=io.NumberDisplay.slider))
        # Hair
        inputs.append(io.Combo.Input("hair_style", options=['-'] + data['hair_style_list']))
        inputs.append(io.Combo.Input("hair_color", options=['-'] + data['hair_color_list']))
        inputs.append(io.Float.Input("disheveled", default=0, step=0.05, min=0, max=max_float_value, display_mode=io.NumberDisplay.slider))
        inputs.append(io.Combo.Input("beard", options=['-'] + data['beard_list']))
        # Skin details
        inputs.append(io.Float.Input("skin_details", default=0, step=0.05, min=0, max=max_float_value, display_mode=io.NumberDisplay.slider))
        inputs.append(io.Float.Input("skin_pores", default=0, step=0.05, min=0, max=max_float_value, display_mode=io.NumberDisplay.slider))
        inputs.append(io.Float.Input("dimples", default=0, step=0.05, min=0, max=max_float_value, display_mode=io.NumberDisplay.slider))
        inputs.append(io.Float.Input("freckles", default=0, step=0.05, min=0, max=max_float_value, display_mode=io.NumberDisplay.slider))
        inputs.append(io.Float.Input("moles", default=0, step=0.05, min=0, max=max_float_value, display_mode=io.NumberDisplay.slider))
        inputs.append(io.Float.Input("skin_imperfections", default=0, step=0.05, min=0, max=max_float_value, display_mode=io.NumberDisplay.slider))
        inputs.append(io.Float.Input("skin_acne", default=0, step=0.05, min=0, max=max_float_value, display_mode=io.NumberDisplay.slider))
        inputs.append(io.Float.Input("tanned_skin", default=0, step=0.05, min=0, max=max_float_value, display_mode=io.NumberDisplay.slider))
        # Eyes
        inputs.append(io.Float.Input("eyes_details", default=0, step=0.05, min=0, max=max_float_value, display_mode=io.NumberDisplay.slider))
        inputs.append(io.Float.Input("iris_details", default=0, step=0.05, min=0, max=max_float_value, display_mode=io.NumberDisplay.slider))
        inputs.append(io.Float.Input("circular_iris", default=0, step=0.05, min=0, max=max_float_value, display_mode=io.NumberDisplay.slider))
        inputs.append(io.Float.Input("circular_pupil", default=0, step=0.05, min=0, max=max_float_value, display_mode=io.NumberDisplay.slider))
        # Light
        inputs.append(io.Combo.Input("light_type", options=['-'] + data['light_type_list']))
        inputs.append(io.Combo.Input("light_direction", options=['-'] + data['light_direction_list']))
        inputs.append(io.Float.Input("light_weight", default=0, step=0.05, min=0, max=max_float_value, display_mode=io.NumberDisplay.slider))
        # Additional
        inputs.append(io.Combo.Input("photorealism_improvement", options=["enable", "disable"]))
        inputs.append(io.String.Input("prompt_start", multiline=True, default="raw photo, (realistic:1.5)"))
        inputs.append(io.String.Input("prompt_additional", multiline=True, default=""))
        inputs.append(io.String.Input("prompt_end", multiline=True, default=""))
        inputs.append(io.String.Input("negative_prompt", multiline=True, default=""))

        return io.Schema(
            node_id="easy portraitMaster",
            category="EasyUse/Prompt",
            inputs=inputs,
            outputs=[
                io.String.Output("positive"),
                io.String.Output("negative"),
            ],
        )

    @classmethod
    def execute(cls, shot="-", shot_weight=1, gender="-", body_type="-", body_type_weight=0, eyes_color="-",
                facial_expression="-", facial_expression_weight=0, face_shape="-", face_shape_weight=0,
                nationality_1="-", nationality_2="-", nationality_mix=0.5, age=30, hair_style="-", hair_color="-",
                disheveled=0, dimples=0, freckles=0, skin_pores=0, skin_details=0, moles=0, skin_imperfections=0,
                wrinkles=0, tanned_skin=0, eyes_details=1, iris_details=1, circular_iris=1, circular_pupil=1,
                facial_asymmetry=0, prompt_additional="", prompt_start="", prompt_end="", light_type="-",
                light_direction="-", light_weight=0, negative_prompt="", photorealism_improvement="disable", beard="-",
                model_pose="-", skin_acne=0):
        # NOTE(review): 'wrinkles' has no matching input in define_schema, so it
        # stays at its default of 0 here — confirm whether the input was dropped
        # intentionally.

        # Accumulate prompt fragments; joined with ", " at the end.
        prompt = []

        if gender == "-":
            gender = ""
        else:
            # Young subjects use girl/boy wording instead of Woman/Man.
            if age <= 25 and gender == 'Woman':
                gender = 'girl'
            if age <= 25 and gender == 'Man':
                gender = 'boy'
            gender = " " + gender + " "

        # Two nationalities blend via prompt-editing syntax [a:b:mix].
        if nationality_1 != '-' and nationality_2 != '-':
            nationality = f"[{nationality_1}:{nationality_2}:{round(nationality_mix, 2)}]"
        elif nationality_1 != '-':
            nationality = nationality_1 + " "
        elif nationality_2 != '-':
            nationality = nationality_2 + " "
        else:
            nationality = ""

        if prompt_start != "":
            prompt.append(f"{prompt_start}")

        if shot != "-" and shot_weight > 0:
            prompt.append(f"({shot}:{round(shot_weight, 2)})")

        prompt.append(f"({nationality}{gender}{round(age)}-years-old:1.5)")

        if body_type != "-" and body_type_weight > 0:
            prompt.append(f"({body_type}, {body_type} body:{round(body_type_weight, 2)})")

        if model_pose != "-":
            prompt.append(f"({model_pose}:1.5)")

        if eyes_color != "-":
            prompt.append(f"({eyes_color} eyes:1.25)")

        if facial_expression != "-" and facial_expression_weight > 0:
            prompt.append(
                f"({facial_expression}, {facial_expression} expression:{round(facial_expression_weight, 2)})")

        if face_shape != "-" and face_shape_weight > 0:
            prompt.append(f"({face_shape} shape face:{round(face_shape_weight, 2)})")

        if hair_style != "-":
            prompt.append(f"({hair_style} hairstyle:1.25)")

        if hair_color != "-":
            prompt.append(f"({hair_color} hair:1.25)")

        if beard != "-":
            prompt.append(f"({beard}:1.15)")

        # NOTE(review): disheveled is a float, so `!= "-"` is always True in
        # Python 3; only the `> 0` check matters.
        if disheveled != "-" and disheveled > 0:
            prompt.append(f"(disheveled:{round(disheveled, 2)})")

        if prompt_additional != "":
            prompt.append(f"{prompt_additional}")

        if skin_details > 0:
            prompt.append(f"(skin details, skin texture:{round(skin_details, 2)})")

        if skin_pores > 0:
            prompt.append(f"(skin pores:{round(skin_pores, 2)})")

        if skin_imperfections > 0:
            prompt.append(f"(skin imperfections:{round(skin_imperfections, 2)})")

        if skin_acne > 0:
            prompt.append(f"(acne, skin with acne:{round(skin_acne, 2)})")

        # NOTE(review): emits "skin imperfections", not "wrinkles" — looks like
        # a copy-paste of the branch above; confirm against upstream.
        if wrinkles > 0:
            prompt.append(f"(skin imperfections:{round(wrinkles, 2)})")

        if tanned_skin > 0:
            prompt.append(f"(tanned skin:{round(tanned_skin, 2)})")

        if dimples > 0:
            prompt.append(f"(dimples:{round(dimples, 2)})")

        if freckles > 0:
            prompt.append(f"(freckles:{round(freckles, 2)})")

        # NOTE(review): emits "skin pores", not "moles" — confirm intended.
        if moles > 0:
            prompt.append(f"(skin pores:{round(moles, 2)})")

        if eyes_details > 0:
            prompt.append(f"(eyes details:{round(eyes_details, 2)})")

        if iris_details > 0:
            prompt.append(f"(iris details:{round(iris_details, 2)})")

        if circular_iris > 0:
            prompt.append(f"(circular iris:{round(circular_iris, 2)})")

        if circular_pupil > 0:
            prompt.append(f"(circular pupil:{round(circular_pupil, 2)})")

        if facial_asymmetry > 0:
            prompt.append(f"(facial asymmetry, face asymmetry:{round(facial_asymmetry, 2)})")

        if light_type != '-' and light_weight > 0:
            if light_direction != '-':
                prompt.append(f"({light_type} {light_direction}:{round(light_weight, 2)})")
            else:
                prompt.append(f"({light_type}:{round(light_weight, 2)})")

        if prompt_end != "":
            prompt.append(f"{prompt_end}")

        prompt = ", ".join(prompt)
        prompt = prompt.lower()

        if photorealism_improvement == "enable":
            prompt = prompt + ", (professional photo, balanced photo, balanced exposure:1.2), (film grain:1.15)"

        if photorealism_improvement == "enable":
            negative_prompt = negative_prompt + ", (shinny skin, reflections on the skin, skin reflections:1.25)"

        log_node_info("Portrait Master as generate the prompt:", prompt)

        return io.NodeOutput(prompt, negative_prompt)
|
||||
|
||||
# Multi angle
|
||||
class multiAngle(io.ComfyNode):
    """Build camera-angle prompt strings from a list of rotate/vertical/zoom
    presets supplied by the EASY_MULTI_ANGLE front-end widget."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="easy multiAngle",
            category="EasyUse/Prompt",
            inputs=[
                io.Custom(io_type="EASY_MULTI_ANGLE").Input("multi_angle", optional=True),
            ],
            outputs=[
                io.String.Output("prompt", is_output_list=True),
                io.Custom(io_type="EASY_MULTI_ANGLE").Output("params"),
            ],
        )

    @staticmethod
    def _horizontal_direction(h_angle, add_angle_prompt):
        """Map a horizontal angle in [0, 360) to a view description."""
        # Diagonal views carry a " quarter" suffix only in compact mode.
        suffix = "" if add_angle_prompt else " quarter"
        if h_angle < 22.5 or h_angle >= 337.5:
            return "front view"
        if h_angle < 67.5:
            return f"front-right{suffix} view"
        if h_angle < 112.5:
            return "right side view"
        if h_angle < 157.5:
            return f"back-right{suffix} view"
        if h_angle < 202.5:
            return "back view"
        if h_angle < 247.5:
            return f"back-left{suffix} view"
        if h_angle < 292.5:
            return "left side view"
        return f"front-left{suffix} view"

    @staticmethod
    def _vertical_direction(vertical, add_angle_prompt):
        """Map a vertical angle in [-90, 90] to a perspective description."""
        if add_angle_prompt:
            if vertical == -90:
                return "bottom-looking-up perspective, extreme worm's eye view, focus subject bottom"
            if vertical < -75:
                return "bottom-looking-up perspective, extreme worm's eye view"
            if vertical < -45:
                return "ultra-low angle"
            if vertical < -15:
                return "low angle"
            if vertical < 15:
                return "eye level"
            if vertical < 45:
                return "high angle"
            if vertical < 75:
                return "bird's eye view"
        else:
            if vertical < -15:
                return "low-angle shot"
            if vertical < 15:
                return "eye-level shot"
            if vertical < 45:
                return "elevated shot"
            if vertical < 75:
                return "high-angle shot"
        # Top-down endpoints are shared by both modes.
        if vertical < 90:
            return "top-down perspective, looking straight down at the top of the subject"
        return "top-down perspective, looking straight down at the top of the subject, face not visible, focus on subject head"

    @staticmethod
    def _distance(zoom):
        """Map a zoom level in [0, 10] to a shot-distance description.

        (The original had identical mappings duplicated per mode; collapsed
        into one.)
        """
        if zoom < 2:
            return "extreme wide shot"
        if zoom < 4:
            return "wide shot"
        if zoom < 6:
            return "medium shot"
        if zoom < 8:
            return "close-up"
        return "extreme close-up"

    @classmethod
    def execute(cls, multi_angle=None, **kwargs):
        """Return one prompt string per angle preset plus the parsed params."""
        if multi_angle is None:
            return io.NodeOutput([""])

        if isinstance(multi_angle, str):
            try:
                multi_angle = json.loads(multi_angle)
            except json.JSONDecodeError:
                # Narrowed from a bare except: only malformed JSON can occur here.
                raise Exception(f"Invalid multi angle: {multi_angle}")

        prompts = []
        for angle_data in multi_angle:
            rotate = angle_data.get("rotate", 0)
            vertical = angle_data.get("vertical", 0)
            zoom = angle_data.get("zoom", 5)
            add_angle_prompt = angle_data.get("add_angle_prompt", True)

            # Clamp raw values into their supported ranges.
            rotate = max(0, min(360, int(rotate)))
            vertical = max(-90, min(90, int(vertical)))
            zoom = max(0.0, min(10.0, float(zoom)))

            h_direction = cls._horizontal_direction(rotate % 360, add_angle_prompt)
            v_direction = cls._vertical_direction(vertical, add_angle_prompt)
            distance = cls._distance(zoom)

            # Verbose mode appends the numeric parameters for reproducibility.
            if add_angle_prompt:
                prompt = f"{h_direction}, {v_direction}, {distance} (horizontal: {rotate}, vertical: {vertical}, zoom: {zoom:.1f})"
            else:
                prompt = f"{h_direction} {v_direction} {distance}"

            prompts.append(prompt)

        return io.NodeOutput(prompts, multi_angle)
|
||||
|
||||
|
||||
# Registry of node classes, keyed by node id (consumed by ComfyUI at load time).
NODE_CLASS_MAPPINGS = {
    "easy positive": positivePrompt,
    "easy negative": negativePrompt,
    "easy wildcards": wildcardsPrompt,
    "easy wildcardsMatrix": wildcardsPromptMatrix,
    "easy prompt": prompt,
    "easy promptList": promptList,
    "easy promptLine": promptLine,
    "easy promptAwait": promptAwait,
    "easy promptConcat": promptConcat,
    "easy promptReplace": promptReplace,
    "easy stylesSelector": stylesPromptSelector,
    "easy portraitMaster": portraitMaster,
    "easy multiAngle": multiAngle,
}

# Human-readable display names for the UI, keyed by the same node ids.
NODE_DISPLAY_NAME_MAPPINGS = {
    "easy positive": "Positive",
    "easy negative": "Negative",
    "easy wildcards": "Wildcards",
    "easy wildcardsMatrix": "Wildcards Matrix",
    "easy prompt": "Prompt",
    "easy promptList": "PromptList",
    "easy promptLine": "PromptLine",
    "easy promptAwait": "PromptAwait",
    "easy promptConcat": "PromptConcat",
    "easy promptReplace": "PromptReplace",
    "easy stylesSelector": "Styles Selector",
    "easy portraitMaster": "Portrait Master",
    "easy multiAngle": "Multi Angle",
}
|
||||
1360
custom_nodes/ComfyUI-Easy-Use/py/nodes/samplers.py
Normal file
1360
custom_nodes/ComfyUI-Easy-Use/py/nodes/samplers.py
Normal file
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user