Add custom nodes, Civitai loras (LFS), and vast.ai setup script
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Includes 30 custom nodes committed directly, 7 Civitai-exclusive loras stored via Git LFS, and a setup script that installs all dependencies and downloads HuggingFace-hosted models on vast.ai. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
11
custom_nodes/x-flux-comfyui/xflux/src/flux/__init__.py
Normal file
11
custom_nodes/x-flux-comfyui/xflux/src/flux/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
try:
|
||||
from ._version import version as __version__ # type: ignore
|
||||
from ._version import version_tuple
|
||||
except ImportError:
|
||||
__version__ = "unknown (no version information available)"
|
||||
version_tuple = (0, 0, "unknown", "noinfo")
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
PACKAGE = __package__.replace("_", "-")
|
||||
PACKAGE_ROOT = Path(__file__).parent
|
||||
38
custom_nodes/x-flux-comfyui/xflux/src/flux/annotator/util.py
Normal file
38
custom_nodes/x-flux-comfyui/xflux/src/flux/annotator/util.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import numpy as np
|
||||
import cv2
|
||||
import os
|
||||
|
||||
|
||||
annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')
|
||||
|
||||
|
||||
def HWC3(x):
|
||||
assert x.dtype == np.uint8
|
||||
if x.ndim == 2:
|
||||
x = x[:, :, None]
|
||||
assert x.ndim == 3
|
||||
H, W, C = x.shape
|
||||
assert C == 1 or C == 3 or C == 4
|
||||
if C == 3:
|
||||
return x
|
||||
if C == 1:
|
||||
return np.concatenate([x, x, x], axis=2)
|
||||
if C == 4:
|
||||
color = x[:, :, 0:3].astype(np.float32)
|
||||
alpha = x[:, :, 3:4].astype(np.float32) / 255.0
|
||||
y = color * alpha + 255.0 * (1.0 - alpha)
|
||||
y = y.clip(0, 255).astype(np.uint8)
|
||||
return y
|
||||
|
||||
|
||||
def resize_image(input_image, resolution):
|
||||
H, W, C = input_image.shape
|
||||
H = float(H)
|
||||
W = float(W)
|
||||
k = float(resolution) / min(H, W)
|
||||
H *= k
|
||||
W *= k
|
||||
H = int(np.round(H / 64.0)) * 64
|
||||
W = int(np.round(W / 64.0)) * 64
|
||||
img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
|
||||
return img
|
||||
222
custom_nodes/x-flux-comfyui/xflux/src/flux/controlnet.py
Normal file
222
custom_nodes/x-flux-comfyui/xflux/src/flux/controlnet.py
Normal file
@@ -0,0 +1,222 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from einops import rearrange
|
||||
|
||||
from .modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
||||
MLPEmbedder, SingleStreamBlock,
|
||||
timestep_embedding)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FluxParams:
|
||||
in_channels: int
|
||||
vec_in_dim: int
|
||||
context_in_dim: int
|
||||
hidden_size: int
|
||||
mlp_ratio: float
|
||||
num_heads: int
|
||||
depth: int
|
||||
depth_single_blocks: int
|
||||
axes_dim: list[int]
|
||||
theta: int
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
|
||||
def zero_module(module):
|
||||
for p in module.parameters():
|
||||
nn.init.zeros_(p)
|
||||
return module
|
||||
|
||||
|
||||
class ControlNetFlux(nn.Module):
|
||||
"""
|
||||
Transformer model for flow matching on sequences.
|
||||
"""
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(self, params: FluxParams, controlnet_depth=2):
|
||||
super().__init__()
|
||||
|
||||
self.params = params
|
||||
self.in_channels = params.in_channels
|
||||
self.out_channels = self.in_channels
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||
)
|
||||
pe_dim = params.hidden_size // params.num_heads
|
||||
if sum(params.axes_dim) != pe_dim:
|
||||
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
||||
self.hidden_size = params.hidden_size
|
||||
self.num_heads = params.num_heads
|
||||
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
||||
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
|
||||
self.guidance_in = (
|
||||
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
|
||||
)
|
||||
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
DoubleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
)
|
||||
for _ in range(controlnet_depth)
|
||||
]
|
||||
)
|
||||
|
||||
# add ControlNet blocks
|
||||
self.controlnet_blocks = nn.ModuleList([])
|
||||
for _ in range(controlnet_depth):
|
||||
controlnet_block = nn.Linear(self.hidden_size, self.hidden_size)
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_blocks.append(controlnet_block)
|
||||
self.pos_embed_input = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
||||
self.gradient_checkpointing = False
|
||||
self.input_hint_block = nn.Sequential(
|
||||
nn.Conv2d(3, 16, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(16, 16, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(16, 16, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(16, 16, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(16, 16, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(16, 16, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(16, 16, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
zero_module(nn.Conv2d(16, 16, 3, padding=1))
|
||||
)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
|
||||
@property
|
||||
def attn_processors(self):
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
|
||||
if hasattr(module, "set_processor"):
|
||||
processors[f"{name}.processor"] = module.processor
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
def set_attn_processor(self, processor):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img: Tensor,
|
||||
img_ids: Tensor,
|
||||
controlnet_cond: Tensor,
|
||||
txt: Tensor,
|
||||
txt_ids: Tensor,
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
guidance: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
# running on sequences img
|
||||
img = self.img_in(img)
|
||||
controlnet_cond = self.input_hint_block(controlnet_cond)
|
||||
controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
controlnet_cond = self.pos_embed_input(controlnet_cond)
|
||||
img = img + controlnet_cond
|
||||
vec = self.time_in(timestep_embedding(timesteps, 256))
|
||||
if self.params.guidance_embed:
|
||||
if guidance is None:
|
||||
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
||||
vec = vec + self.vector_in(y)
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
block_res_samples = ()
|
||||
|
||||
for block in self.double_blocks:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
img,
|
||||
txt,
|
||||
vec,
|
||||
pe,
|
||||
)
|
||||
else:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
|
||||
block_res_samples = block_res_samples + (img,)
|
||||
|
||||
controlnet_block_res_samples = ()
|
||||
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
|
||||
block_res_sample = controlnet_block(block_res_sample)
|
||||
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
|
||||
|
||||
return controlnet_block_res_samples
|
||||
30
custom_nodes/x-flux-comfyui/xflux/src/flux/math.py
Normal file
30
custom_nodes/x-flux-comfyui/xflux/src/flux/math.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
|
||||
q, k = apply_rope(q, k, pe)
|
||||
|
||||
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
x = rearrange(x, "B H L D -> B L (H D)")
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
assert dim % 2 == 0
|
||||
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
||||
omega = 1.0 / (theta**scale)
|
||||
out = torch.einsum("...n,d->...nd", pos, omega)
|
||||
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
|
||||
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
||||
return out.float()
|
||||
|
||||
|
||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
|
||||
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
||||
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
217
custom_nodes/x-flux-comfyui/xflux/src/flux/model.py
Normal file
217
custom_nodes/x-flux-comfyui/xflux/src/flux/model.py
Normal file
@@ -0,0 +1,217 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from einops import rearrange
|
||||
|
||||
from .modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
||||
MLPEmbedder, SingleStreamBlock,
|
||||
timestep_embedding)
|
||||
from typing import Dict, List, Any
|
||||
|
||||
@dataclass
|
||||
class FluxParams:
|
||||
in_channels: int
|
||||
vec_in_dim: int
|
||||
context_in_dim: int
|
||||
hidden_size: int
|
||||
mlp_ratio: float
|
||||
num_heads: int
|
||||
depth: int
|
||||
depth_single_blocks: int
|
||||
axes_dim: list[int]
|
||||
theta: int
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
|
||||
|
||||
class Flux(nn.Module):
|
||||
"""
|
||||
Transformer model for flow matching on sequences.
|
||||
"""
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(self, params: FluxParams):
|
||||
super().__init__()
|
||||
|
||||
self.params = params
|
||||
self.in_channels = params.in_channels
|
||||
self.out_channels = self.in_channels
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||
)
|
||||
pe_dim = params.hidden_size // params.num_heads
|
||||
if sum(params.axes_dim) != pe_dim:
|
||||
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
||||
self.hidden_size = params.hidden_size
|
||||
self.num_heads = params.num_heads
|
||||
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
||||
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
|
||||
self.guidance_in = (
|
||||
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
|
||||
)
|
||||
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
DoubleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
@property
|
||||
def attn_processors(self):
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
|
||||
if hasattr(module, "set_processor"):
|
||||
processors[f"{name}.processor"] = module.processor
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
def set_attn_processor(self, processor):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img: Tensor,
|
||||
img_ids: Tensor,
|
||||
txt: Tensor,
|
||||
txt_ids: Tensor,
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
block_controlnet_hidden_states=None,
|
||||
guidance: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
# running on sequences img
|
||||
img = self.img_in(img)
|
||||
vec = self.time_in(timestep_embedding(timesteps, 256))
|
||||
if self.params.guidance_embed:
|
||||
if guidance is None:
|
||||
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
||||
vec = vec + self.vector_in(y)
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
if block_controlnet_hidden_states is not None:
|
||||
controlnet_depth = len(block_controlnet_hidden_states)
|
||||
for index_block, block in enumerate(self.double_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
img,
|
||||
txt,
|
||||
vec,
|
||||
pe,
|
||||
)
|
||||
else:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
# controlnet residual
|
||||
if block_controlnet_hidden_states is not None:
|
||||
img = img + block_controlnet_hidden_states[index_block % 2]
|
||||
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
for block in self.single_blocks:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if torch.is_torch_version(">=", "1.11.0") else {}
|
||||
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
img,
|
||||
vec,
|
||||
pe,
|
||||
)
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
@@ -0,0 +1,312 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
@dataclass
|
||||
class AutoEncoderParams:
|
||||
resolution: int
|
||||
in_channels: int
|
||||
ch: int
|
||||
out_ch: int
|
||||
ch_mult: list[int]
|
||||
num_res_blocks: int
|
||||
z_channels: int
|
||||
scale_factor: float
|
||||
shift_factor: float
|
||||
|
||||
|
||||
def swish(x: Tensor) -> Tensor:
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels: int):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
||||
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
||||
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
||||
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
||||
|
||||
def attention(self, h_: Tensor) -> Tensor:
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
b, c, h, w = q.shape
|
||||
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
|
||||
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
|
||||
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
|
||||
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
|
||||
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return x + self.proj_out(self.attention(x))
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
|
||||
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = swish(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
h = self.norm2(h)
|
||||
h = swish(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, in_channels: int):
|
||||
super().__init__()
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
pad = (0, 1, 0, 1)
|
||||
x = nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, in_channels: int):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
resolution: int,
|
||||
in_channels: int,
|
||||
ch: int,
|
||||
ch_mult: list[int],
|
||||
num_res_blocks: int,
|
||||
z_channels: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
# downsampling
|
||||
self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
self.down = nn.ModuleList()
|
||||
block_in = self.ch
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for _ in range(self.num_res_blocks):
|
||||
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
||||
block_in = block_out
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
down.downsample = Downsample(block_in)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||
self.mid.attn_1 = AttnBlock(block_in)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||
|
||||
# end
|
||||
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
||||
self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1])
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h)
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = swish(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
ch: int,
|
||||
out_ch: int,
|
||||
ch_mult: list[int],
|
||||
num_res_blocks: int,
|
||||
in_channels: int,
|
||||
resolution: int,
|
||||
z_channels: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.ffactor = 2 ** (self.num_resolutions - 1)
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||
self.mid.attn_1 = AttnBlock(block_in)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for _ in range(self.num_res_blocks + 1):
|
||||
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
||||
block_in = block_out
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
||||
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, z: Tensor) -> Tensor:
|
||||
# z to block_in
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](h)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = swish(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class DiagonalGaussian(nn.Module):
|
||||
def __init__(self, sample: bool = True, chunk_dim: int = 1):
|
||||
super().__init__()
|
||||
self.sample = sample
|
||||
self.chunk_dim = chunk_dim
|
||||
|
||||
def forward(self, z: Tensor) -> Tensor:
|
||||
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
|
||||
if self.sample:
|
||||
std = torch.exp(0.5 * logvar)
|
||||
return mean + std * torch.randn_like(mean)
|
||||
else:
|
||||
return mean
|
||||
|
||||
|
||||
class AutoEncoder(nn.Module):
|
||||
def __init__(self, params: AutoEncoderParams):
|
||||
super().__init__()
|
||||
self.encoder = Encoder(
|
||||
resolution=params.resolution,
|
||||
in_channels=params.in_channels,
|
||||
ch=params.ch,
|
||||
ch_mult=params.ch_mult,
|
||||
num_res_blocks=params.num_res_blocks,
|
||||
z_channels=params.z_channels,
|
||||
)
|
||||
self.decoder = Decoder(
|
||||
resolution=params.resolution,
|
||||
in_channels=params.in_channels,
|
||||
ch=params.ch,
|
||||
out_ch=params.out_ch,
|
||||
ch_mult=params.ch_mult,
|
||||
num_res_blocks=params.num_res_blocks,
|
||||
z_channels=params.z_channels,
|
||||
)
|
||||
self.reg = DiagonalGaussian()
|
||||
|
||||
self.scale_factor = params.scale_factor
|
||||
self.shift_factor = params.shift_factor
|
||||
|
||||
def encode(self, x: Tensor) -> Tensor:
|
||||
z = self.reg(self.encoder(x))
|
||||
z = self.scale_factor * (z - self.shift_factor)
|
||||
return z
|
||||
|
||||
def decode(self, z: Tensor) -> Tensor:
|
||||
z = z / self.scale_factor + self.shift_factor
|
||||
return self.decoder(z)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.decode(self.encode(x))
|
||||
@@ -0,0 +1,38 @@
|
||||
from torch import Tensor, nn
|
||||
from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
|
||||
T5Tokenizer)
|
||||
|
||||
|
||||
class HFEmbedder(nn.Module):
|
||||
def __init__(self, version: str, max_length: int, **hf_kwargs):
|
||||
super().__init__()
|
||||
self.is_clip = version.startswith("openai")
|
||||
self.max_length = max_length
|
||||
self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
|
||||
|
||||
if self.is_clip:
|
||||
self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
|
||||
self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
|
||||
else:
|
||||
self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
|
||||
self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
|
||||
|
||||
self.hf_module = self.hf_module.eval().requires_grad_(False)
|
||||
|
||||
def forward(self, text: list[str]) -> Tensor:
|
||||
batch_encoding = self.tokenizer(
|
||||
text,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_length=False,
|
||||
return_overflowing_tokens=False,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
outputs = self.hf_module(
|
||||
input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
|
||||
attention_mask=None,
|
||||
output_hidden_states=False,
|
||||
)
|
||||
return outputs[self.output_key]
|
||||
358
custom_nodes/x-flux-comfyui/xflux/src/flux/modules/layers.py
Normal file
358
custom_nodes/x-flux-comfyui/xflux/src/flux/modules/layers.py
Normal file
@@ -0,0 +1,358 @@
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ..math import attention, rope
|
||||
|
||||
|
||||
class EmbedND(nn.Module):
|
||||
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
|
||||
def forward(self, ids: Tensor) -> Tensor:
|
||||
n_axes = ids.shape[-1]
|
||||
emb = torch.cat(
|
||||
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
||||
dim=-3,
|
||||
)
|
||||
|
||||
return emb.unsqueeze(1)
|
||||
|
||||
|
||||
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param t: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an (N, D) Tensor of positional embeddings.
|
||||
"""
|
||||
t = time_factor * t
|
||||
half = dim // 2
|
||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
|
||||
t.device
|
||||
)
|
||||
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
if torch.is_floating_point(t):
|
||||
embedding = embedding.to(t)
|
||||
return embedding
|
||||
|
||||
|
||||
class MLPEmbedder(nn.Module):
|
||||
def __init__(self, in_dim: int, hidden_dim: int):
|
||||
super().__init__()
|
||||
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
|
||||
self.silu = nn.SiLU()
|
||||
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.out_layer(self.silu(self.in_layer(x)))
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int):
|
||||
super().__init__()
|
||||
self.scale = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
x_dtype = x.dtype
|
||||
x = x.float()
|
||||
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
||||
return (x * rrms).to(dtype=x_dtype) * self.scale
|
||||
|
||||
|
||||
class QKNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int):
|
||||
super().__init__()
|
||||
self.query_norm = RMSNorm(dim)
|
||||
self.key_norm = RMSNorm(dim)
|
||||
|
||||
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
|
||||
q = self.query_norm(q)
|
||||
k = self.key_norm(k)
|
||||
return q.to(v), k.to(v)
|
||||
|
||||
class LoRALinearLayer(nn.Module):
|
||||
def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
|
||||
super().__init__()
|
||||
|
||||
self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
|
||||
self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
|
||||
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
|
||||
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
|
||||
self.network_alpha = network_alpha
|
||||
self.rank = rank
|
||||
|
||||
nn.init.normal_(self.down.weight, std=1 / rank)
|
||||
nn.init.zeros_(self.up.weight)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
orig_dtype = hidden_states.dtype
|
||||
dtype = self.down.weight.dtype
|
||||
|
||||
down_hidden_states = self.down(hidden_states.to(dtype))
|
||||
up_hidden_states = self.up(down_hidden_states)
|
||||
|
||||
if self.network_alpha is not None:
|
||||
up_hidden_states *= self.network_alpha / self.rank
|
||||
|
||||
return up_hidden_states.to(orig_dtype)
|
||||
|
||||
class FLuxSelfAttnProcessor:
|
||||
def __call__(self, attn, x, pe, **attention_kwargs):
|
||||
print('2' * 30)
|
||||
|
||||
qkv = attn.qkv(x)
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
q, k = attn.norm(q, k, v)
|
||||
x = attention(q, k, v, pe=pe)
|
||||
x = attn.proj(x)
|
||||
return x
|
||||
|
||||
class LoraFluxAttnProcessor(nn.Module):
|
||||
|
||||
def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
|
||||
super().__init__()
|
||||
self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
|
||||
self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha)
|
||||
self.lora_weight = lora_weight
|
||||
|
||||
|
||||
def __call__(self, attn, x, pe, **attention_kwargs):
|
||||
qkv = attn.qkv(x) + self.qkv_lora(x) * self.lora_weight
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
q, k = attn.norm(q, k, v)
|
||||
x = attention(q, k, v, pe=pe)
|
||||
x = attn.proj(x) + self.proj_lora(x) * self.lora_weight
|
||||
print('1' * 30)
|
||||
print(x.norm(), (self.proj_lora(x) * self.lora_weight).norm(), 'norm')
|
||||
return x
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.norm = QKNorm(head_dim)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
def forward():
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModulationOut:
|
||||
shift: Tensor
|
||||
scale: Tensor
|
||||
gate: Tensor
|
||||
|
||||
|
||||
class Modulation(nn.Module):
|
||||
def __init__(self, dim: int, double: bool):
|
||||
super().__init__()
|
||||
self.is_double = double
|
||||
self.multiplier = 6 if double else 3
|
||||
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
|
||||
|
||||
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
|
||||
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
|
||||
|
||||
return (
|
||||
ModulationOut(*out[:3]),
|
||||
ModulationOut(*out[3:]) if self.is_double else None,
|
||||
)
|
||||
|
||||
class DoubleStreamBlockLoraProcessor(nn.Module):
|
||||
def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
|
||||
super().__init__()
|
||||
self.qkv_lora1 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
|
||||
self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha)
|
||||
self.qkv_lora2 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
|
||||
self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha)
|
||||
self.lora_weight = lora_weight
|
||||
|
||||
def forward(self, attn, img, txt, vec, pe, **attention_kwargs):
|
||||
img_mod1, img_mod2 = attn.img_mod(vec)
|
||||
txt_mod1, txt_mod2 = attn.txt_mod(vec)
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = attn.img_norm1(img)
|
||||
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
||||
img_qkv = attn.img_attn.qkv(img_modulated) + self.qkv_lora1(img_modulated) * self.lora_weight
|
||||
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
|
||||
img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
|
||||
|
||||
# prepare txt for attention
|
||||
txt_modulated = attn.txt_norm1(txt)
|
||||
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
||||
txt_qkv = attn.txt_attn.qkv(txt_modulated) + self.qkv_lora2(txt_modulated) * self.lora_weight
|
||||
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
|
||||
txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
|
||||
# run actual attention
|
||||
q = torch.cat((txt_q, img_q), dim=2)
|
||||
k = torch.cat((txt_k, img_k), dim=2)
|
||||
v = torch.cat((txt_v, img_v), dim=2)
|
||||
|
||||
attn1 = attention(q, k, v, pe=pe)
|
||||
txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
|
||||
|
||||
# calculate the img bloks
|
||||
img = img + img_mod1.gate * attn.img_attn.proj(img_attn) + img_mod1.gate * self.proj_lora1(img_attn) * self.lora_weight
|
||||
img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)
|
||||
|
||||
# calculate the txt bloks
|
||||
txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn) + txt_mod1.gate * self.proj_lora2(txt_attn) * self.lora_weight
|
||||
txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
|
||||
return img, txt
|
||||
|
||||
class DoubleStreamBlockProcessor(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def __call__(self, attn, img, txt, vec, pe, **attention_kwargs):
|
||||
|
||||
img_mod1, img_mod2 = attn.img_mod(vec)
|
||||
txt_mod1, txt_mod2 = attn.txt_mod(vec)
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = attn.img_norm1(img)
|
||||
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
||||
img_qkv = attn.img_attn.qkv(img_modulated)
|
||||
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
|
||||
img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
|
||||
|
||||
# prepare txt for attention
|
||||
txt_modulated = attn.txt_norm1(txt)
|
||||
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
||||
txt_qkv = attn.txt_attn.qkv(txt_modulated)
|
||||
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
|
||||
txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
|
||||
# run actual attention
|
||||
q = torch.cat((txt_q, img_q), dim=2)
|
||||
k = torch.cat((txt_k, img_k), dim=2)
|
||||
v = torch.cat((txt_v, img_v), dim=2)
|
||||
|
||||
attn1 = attention(q, k, v, pe=pe)
|
||||
txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
|
||||
|
||||
# calculate the img bloks
|
||||
img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
|
||||
img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)
|
||||
|
||||
# calculate the txt bloks
|
||||
txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
|
||||
txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
|
||||
return img, txt
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
|
||||
super().__init__()
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.img_mod = Modulation(hidden_size, double=True)
|
||||
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
|
||||
|
||||
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.img_mlp = nn.Sequential(
|
||||
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
||||
nn.GELU(approximate="tanh"),
|
||||
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
||||
)
|
||||
|
||||
self.txt_mod = Modulation(hidden_size, double=True)
|
||||
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
|
||||
|
||||
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.txt_mlp = nn.Sequential(
|
||||
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
||||
nn.GELU(approximate="tanh"),
|
||||
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
||||
)
|
||||
processor = DoubleStreamBlockProcessor()
|
||||
self.set_processor(processor)
|
||||
|
||||
def set_processor(self, processor) -> None:
|
||||
self.processor = processor
|
||||
|
||||
def get_processor(self):
|
||||
return self.processor
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
|
||||
return self.processor(self, img, txt, vec, pe)
|
||||
|
||||
class SingleStreamBlock(nn.Module):
|
||||
"""
|
||||
A DiT block with parallel linear layers as described in
|
||||
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qk_scale: float | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_dim = hidden_size
|
||||
self.num_heads = num_heads
|
||||
head_dim = hidden_size // num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
# qkv and mlp_in
|
||||
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
|
||||
# proj and mlp_out
|
||||
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
|
||||
|
||||
self.norm = QKNorm(head_dim)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
self.modulation = Modulation(hidden_size, double=False)
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
|
||||
mod, _ = self.modulation(vec)
|
||||
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe)
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
return x + mod.gate * output
|
||||
|
||||
|
||||
class LastLayer(nn.Module):
|
||||
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
|
||||
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
|
||||
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
||||
x = self.linear(x)
|
||||
return x
|
||||
248
custom_nodes/x-flux-comfyui/xflux/src/flux/sampling.py
Normal file
248
custom_nodes/x-flux-comfyui/xflux/src/flux/sampling.py
Normal file
@@ -0,0 +1,248 @@
|
||||
import math
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from torch import Tensor
|
||||
|
||||
from .model import Flux
|
||||
from .modules.conditioner import HFEmbedder
|
||||
|
||||
|
||||
def get_noise(
|
||||
num_samples: int,
|
||||
height: int,
|
||||
width: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
):
|
||||
return torch.randn(
|
||||
num_samples,
|
||||
16,
|
||||
# allow for packing
|
||||
2 * math.ceil(height / 16),
|
||||
2 * math.ceil(width / 16),
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
generator=torch.Generator(device=device).manual_seed(seed),
|
||||
)
|
||||
|
||||
|
||||
def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
|
||||
bs, c, h, w = img.shape
|
||||
if bs == 1 and not isinstance(prompt, str):
|
||||
bs = len(prompt)
|
||||
|
||||
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
if img.shape[0] == 1 and bs > 1:
|
||||
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
txt = t5(prompt)
|
||||
if txt.shape[0] == 1 and bs > 1:
|
||||
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
|
||||
txt_ids = torch.zeros(bs, txt.shape[1], 3)
|
||||
|
||||
vec = clip(prompt)
|
||||
if vec.shape[0] == 1 and bs > 1:
|
||||
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
return {
|
||||
"img": img,
|
||||
"img_ids": img_ids.to(img.device),
|
||||
"txt": txt.to(img.device),
|
||||
"txt_ids": txt_ids.to(img.device),
|
||||
"vec": vec.to(img.device),
|
||||
}
|
||||
|
||||
|
||||
def time_shift(mu: float, sigma: float, t: Tensor):
|
||||
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||
|
||||
|
||||
def get_lin_function(
|
||||
x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
|
||||
) -> Callable[[float], float]:
|
||||
m = (y2 - y1) / (x2 - x1)
|
||||
b = y1 - m * x1
|
||||
return lambda x: m * x + b
|
||||
|
||||
|
||||
def get_schedule(
|
||||
num_steps: int,
|
||||
image_seq_len: int,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
shift: bool = True,
|
||||
) -> list[float]:
|
||||
# extra step for zero
|
||||
timesteps = torch.linspace(1, 0, num_steps + 1)
|
||||
|
||||
# shifting the schedule to favor high timesteps for higher signal images
|
||||
if shift:
|
||||
# eastimate mu based on linear estimation between two points
|
||||
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
|
||||
timesteps = time_shift(mu, 1.0, timesteps)
|
||||
|
||||
return timesteps.tolist()
|
||||
|
||||
|
||||
def denoise(
|
||||
model: Flux,
|
||||
# model input
|
||||
img: Tensor,
|
||||
img_ids: Tensor,
|
||||
txt: Tensor,
|
||||
txt_ids: Tensor,
|
||||
vec: Tensor,
|
||||
neg_txt: Tensor,
|
||||
neg_txt_ids: Tensor,
|
||||
neg_vec: Tensor,
|
||||
# sampling parameters
|
||||
timesteps: list[float],
|
||||
guidance: float = 4.0,
|
||||
true_gs = 1,
|
||||
timestep_to_start_cfg=0,
|
||||
image2image_strength=None,
|
||||
orig_image = None,
|
||||
):
|
||||
i = 0
|
||||
|
||||
#init_latents = rearrange(init_latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
if image2image_strength is not None and orig_image is not None:
|
||||
t_idx = int((1 - image2image_strength) * len(timesteps))
|
||||
t = timesteps[t_idx]
|
||||
timesteps = timesteps[t_idx:]
|
||||
img = t * img + (1.0 - t) * orig_image.to(img.dtype)
|
||||
# this is ignored for schnell
|
||||
if hasattr(model, "guidance_in"):
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
else:
|
||||
guidance_vec = None
|
||||
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
|
||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
)
|
||||
if i >= timestep_to_start_cfg:
|
||||
neg_pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=neg_txt,
|
||||
txt_ids=neg_txt_ids,
|
||||
y=neg_vec,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
)
|
||||
pred = neg_pred + true_gs * (pred - neg_pred)
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
i += 1
|
||||
return img
|
||||
|
||||
def denoise_controlnet(
|
||||
model: Flux,
|
||||
controlnet:None,
|
||||
# model input
|
||||
img: Tensor,
|
||||
img_ids: Tensor,
|
||||
txt: Tensor,
|
||||
txt_ids: Tensor,
|
||||
vec: Tensor,
|
||||
neg_txt: Tensor,
|
||||
neg_txt_ids: Tensor,
|
||||
neg_vec: Tensor,
|
||||
controlnet_cond,
|
||||
# sampling parameters
|
||||
timesteps: list[float],
|
||||
guidance: float = 4.0,
|
||||
true_gs = 1,
|
||||
controlnet_gs=0.7,
|
||||
timestep_to_start_cfg=0,
|
||||
image2image_strength=None,
|
||||
orig_image = None,
|
||||
):
|
||||
i = 0
|
||||
|
||||
#init_latents = rearrange(init_latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
if image2image_strength is not None and orig_image is not None:
|
||||
t_idx = int((1 - image2image_strength) * len(timesteps))
|
||||
t = timesteps[t_idx]
|
||||
timesteps = timesteps[t_idx:]
|
||||
img = t * img + (1.0 - t) * orig_image.to(img.dtype)
|
||||
# this is ignored for schnell
|
||||
if hasattr(model, "guidance_in"):
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
else:
|
||||
guidance_vec = None
|
||||
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
|
||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
block_res_samples = controlnet(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
controlnet_cond=controlnet_cond,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
)
|
||||
pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
block_controlnet_hidden_states=[i * controlnet_gs for i in block_res_samples]
|
||||
)
|
||||
if i >= timestep_to_start_cfg:
|
||||
neg_block_res_samples = controlnet(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
controlnet_cond=controlnet_cond,
|
||||
txt=neg_txt,
|
||||
txt_ids=neg_txt_ids,
|
||||
y=neg_vec,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
)
|
||||
neg_pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=neg_txt,
|
||||
txt_ids=neg_txt_ids,
|
||||
y=neg_vec,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
block_controlnet_hidden_states=[i * controlnet_gs for i in neg_block_res_samples]
|
||||
)
|
||||
pred = neg_pred + true_gs * (pred - neg_pred)
|
||||
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
|
||||
i += 1
|
||||
return img
|
||||
|
||||
def unpack(x: Tensor, height: int, width: int) -> Tensor:
|
||||
return rearrange(
|
||||
x,
|
||||
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
|
||||
h=math.ceil(height / 16),
|
||||
w=math.ceil(width / 16),
|
||||
ph=2,
|
||||
pw=2,
|
||||
)
|
||||
350
custom_nodes/x-flux-comfyui/xflux/src/flux/util.py
Normal file
350
custom_nodes/x-flux-comfyui/xflux/src/flux/util.py
Normal file
@@ -0,0 +1,350 @@
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import json
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from huggingface_hub import hf_hub_download
|
||||
from safetensors import safe_open
|
||||
from safetensors.torch import load_file as load_sft
|
||||
|
||||
from .model import Flux, FluxParams
|
||||
from .controlnet import ControlNetFlux
|
||||
from .modules.autoencoder import AutoEncoder, AutoEncoderParams
|
||||
from .modules.conditioner import HFEmbedder
|
||||
|
||||
|
||||
|
||||
def load_safetensors(path):
|
||||
tensors = {}
|
||||
with safe_open(path, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
tensors[key] = f.get_tensor(key)
|
||||
return tensors
|
||||
|
||||
def get_lora_rank(checkpoint):
|
||||
for k in checkpoint.keys():
|
||||
if k.endswith(".down.weight"):
|
||||
return checkpoint[k].shape[0]
|
||||
|
||||
def load_checkpoint(local_path, repo_id, name):
|
||||
if local_path is not None:
|
||||
if '.safetensors' in local_path:
|
||||
print("Loading .safetensors checkpoint...")
|
||||
checkpoint = load_safetensors(local_path)
|
||||
else:
|
||||
print("Loading checkpoint...")
|
||||
checkpoint = torch.load(local_path, map_location='cpu')
|
||||
elif repo_id is not None and name is not None:
|
||||
print("Loading checkpoint from repo id...")
|
||||
checkpoint = load_from_repo_id(repo_id, name)
|
||||
else:
|
||||
raise ValueError(
|
||||
"LOADING ERROR: you must specify local_path or repo_id with name in HF to download"
|
||||
)
|
||||
return checkpoint
|
||||
|
||||
|
||||
def c_crop(image):
|
||||
width, height = image.size
|
||||
new_size = min(width, height)
|
||||
left = (width - new_size) / 2
|
||||
top = (height - new_size) / 2
|
||||
right = (width + new_size) / 2
|
||||
bottom = (height + new_size) / 2
|
||||
return image.crop((left, top, right, bottom))
|
||||
|
||||
|
||||
class Annotator:
|
||||
def __init__(self, name: str, device: str):
|
||||
if name == "canny":
|
||||
processor = CannyDetector()
|
||||
elif name == "openpose":
|
||||
processor = DWposeDetector(device)
|
||||
elif name == "depth":
|
||||
processor = MidasDetector()
|
||||
elif name == "hed":
|
||||
processor = HEDdetector()
|
||||
elif name == "hough":
|
||||
processor = MLSDdetector()
|
||||
elif name == "tile":
|
||||
processor = TileDetector()
|
||||
self.name = name
|
||||
self.processor = processor
|
||||
|
||||
def __call__(self, image: Image, width: int, height: int):
|
||||
image = c_crop(image)
|
||||
image = image.resize((width, height))
|
||||
image = np.array(image)
|
||||
if self.name == "canny":
|
||||
result = self.processor(image, low_threshold=100, high_threshold=200)
|
||||
elif self.name == "hough":
|
||||
result = self.processor(image, thr_v=0.05, thr_d=5)
|
||||
elif self.name == "depth":
|
||||
result = self.processor(image)
|
||||
result, _ = result
|
||||
else:
|
||||
result = self.processor(image)
|
||||
|
||||
if result.ndim != 3:
|
||||
result = result[:, :, None]
|
||||
result = np.concatenate([result, result, result], axis=2)
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelSpec:
|
||||
params: FluxParams
|
||||
ae_params: AutoEncoderParams
|
||||
ckpt_path: str | None
|
||||
ae_path: str | None
|
||||
repo_id: str | None
|
||||
repo_flow: str | None
|
||||
repo_ae: str | None
|
||||
repo_id_ae: str | None
|
||||
|
||||
|
||||
configs = {
|
||||
"flux-dev": ModelSpec(
|
||||
repo_id="black-forest-labs/FLUX.1-dev",
|
||||
repo_id_ae="black-forest-labs/FLUX.1-dev",
|
||||
repo_flow="flux1-dev.safetensors",
|
||||
repo_ae="ae.safetensors",
|
||||
ckpt_path=os.getenv("FLUX_DEV"),
|
||||
params=FluxParams(
|
||||
in_channels=64,
|
||||
vec_in_dim=768,
|
||||
context_in_dim=4096,
|
||||
hidden_size=3072,
|
||||
mlp_ratio=4.0,
|
||||
num_heads=24,
|
||||
depth=19,
|
||||
depth_single_blocks=38,
|
||||
axes_dim=[16, 56, 56],
|
||||
theta=10_000,
|
||||
qkv_bias=True,
|
||||
guidance_embed=True,
|
||||
),
|
||||
ae_path=os.getenv("AE"),
|
||||
ae_params=AutoEncoderParams(
|
||||
resolution=256,
|
||||
in_channels=3,
|
||||
ch=128,
|
||||
out_ch=3,
|
||||
ch_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
z_channels=16,
|
||||
scale_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
),
|
||||
),
|
||||
"flux-dev-fp8": ModelSpec(
|
||||
repo_id="XLabs-AI/flux-dev-fp8",
|
||||
repo_id_ae="black-forest-labs/FLUX.1-dev",
|
||||
repo_flow="flux-dev-fp8.safetensors",
|
||||
repo_ae="ae.safetensors",
|
||||
ckpt_path=os.getenv("FLUX_DEV_FP8"),
|
||||
params=FluxParams(
|
||||
in_channels=64,
|
||||
vec_in_dim=768,
|
||||
context_in_dim=4096,
|
||||
hidden_size=3072,
|
||||
mlp_ratio=4.0,
|
||||
num_heads=24,
|
||||
depth=19,
|
||||
depth_single_blocks=38,
|
||||
axes_dim=[16, 56, 56],
|
||||
theta=10_000,
|
||||
qkv_bias=True,
|
||||
guidance_embed=True,
|
||||
),
|
||||
ae_path=os.getenv("AE"),
|
||||
ae_params=AutoEncoderParams(
|
||||
resolution=256,
|
||||
in_channels=3,
|
||||
ch=128,
|
||||
out_ch=3,
|
||||
ch_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
z_channels=16,
|
||||
scale_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
),
|
||||
),
|
||||
"flux-schnell": ModelSpec(
|
||||
repo_id="black-forest-labs/FLUX.1-schnell",
|
||||
repo_id_ae="black-forest-labs/FLUX.1-dev",
|
||||
repo_flow="flux1-schnell.safetensors",
|
||||
repo_ae="ae.safetensors",
|
||||
ckpt_path=os.getenv("FLUX_SCHNELL"),
|
||||
params=FluxParams(
|
||||
in_channels=64,
|
||||
vec_in_dim=768,
|
||||
context_in_dim=4096,
|
||||
hidden_size=3072,
|
||||
mlp_ratio=4.0,
|
||||
num_heads=24,
|
||||
depth=19,
|
||||
depth_single_blocks=38,
|
||||
axes_dim=[16, 56, 56],
|
||||
theta=10_000,
|
||||
qkv_bias=True,
|
||||
guidance_embed=False,
|
||||
),
|
||||
ae_path=os.getenv("AE"),
|
||||
ae_params=AutoEncoderParams(
|
||||
resolution=256,
|
||||
in_channels=3,
|
||||
ch=128,
|
||||
out_ch=3,
|
||||
ch_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
z_channels=16,
|
||||
scale_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
|
||||
if len(missing) > 0 and len(unexpected) > 0:
|
||||
print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
|
||||
print("\n" + "-" * 79 + "\n")
|
||||
print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
|
||||
elif len(missing) > 0:
|
||||
print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
|
||||
elif len(unexpected) > 0:
|
||||
print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
|
||||
|
||||
def load_from_repo_id(repo_id, checkpoint_name):
|
||||
ckpt_path = hf_hub_download(repo_id, checkpoint_name)
|
||||
sd = load_sft(ckpt_path, device='cpu')
|
||||
return sd
|
||||
|
||||
def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
|
||||
# Loading Flux
|
||||
print("Init model")
|
||||
ckpt_path = configs[name].ckpt_path
|
||||
if (
|
||||
ckpt_path is None
|
||||
and configs[name].repo_id is not None
|
||||
and configs[name].repo_flow is not None
|
||||
and hf_download
|
||||
):
|
||||
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
|
||||
|
||||
with torch.device("meta" if ckpt_path is not None else device):
|
||||
model = Flux(configs[name].params).to(torch.bfloat16)
|
||||
|
||||
if ckpt_path is not None:
|
||||
print("Loading checkpoint")
|
||||
# load_sft doesn't support torch.device
|
||||
sd = load_sft(ckpt_path, device=str(device))
|
||||
missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
|
||||
print_load_warning(missing, unexpected)
|
||||
return model
|
||||
|
||||
def load_flow_model2(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
|
||||
# Loading Flux
|
||||
print("Init model")
|
||||
ckpt_path = configs[name].ckpt_path
|
||||
if (
|
||||
ckpt_path is None
|
||||
and configs[name].repo_id is not None
|
||||
and configs[name].repo_flow is not None
|
||||
and hf_download
|
||||
):
|
||||
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors"))
|
||||
|
||||
with torch.device("meta" if ckpt_path is not None else device):
|
||||
model = Flux(configs[name].params)
|
||||
|
||||
if ckpt_path is not None:
|
||||
print("Loading checkpoint")
|
||||
# load_sft doesn't support torch.device
|
||||
sd = load_sft(ckpt_path, device=str(device))
|
||||
missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
|
||||
print_load_warning(missing, unexpected)
|
||||
return model
|
||||
|
||||
def load_controlnet(name, device, transformer=None):
|
||||
with torch.device(device):
|
||||
controlnet = ControlNetFlux(configs[name].params)
|
||||
if transformer is not None:
|
||||
controlnet.load_state_dict(transformer.state_dict(), strict=False)
|
||||
return controlnet
|
||||
|
||||
def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
|
||||
# max length 64, 128, 256 and 512 should work (if your sequence is short enough)
|
||||
return HFEmbedder("xlabs-ai/xflux_text_encoders", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
|
||||
|
||||
def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
|
||||
return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)
|
||||
|
||||
|
||||
def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
|
||||
ckpt_path = configs[name].ae_path
|
||||
if (
|
||||
ckpt_path is None
|
||||
and configs[name].repo_id is not None
|
||||
and configs[name].repo_ae is not None
|
||||
and hf_download
|
||||
):
|
||||
ckpt_path = hf_hub_download(configs[name].repo_id_ae, configs[name].repo_ae)
|
||||
|
||||
# Loading the autoencoder
|
||||
print("Init AE")
|
||||
with torch.device("meta" if ckpt_path is not None else device):
|
||||
ae = AutoEncoder(configs[name].ae_params)
|
||||
|
||||
if ckpt_path is not None:
|
||||
sd = load_sft(ckpt_path, device=str(device))
|
||||
missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
|
||||
print_load_warning(missing, unexpected)
|
||||
return ae
|
||||
|
||||
|
||||
class WatermarkEmbedder:
|
||||
def __init__(self, watermark):
|
||||
self.watermark = watermark
|
||||
self.num_bits = len(WATERMARK_BITS)
|
||||
self.encoder = WatermarkEncoder()
|
||||
self.encoder.set_watermark("bits", self.watermark)
|
||||
|
||||
def __call__(self, image: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Adds a predefined watermark to the input image
|
||||
|
||||
Args:
|
||||
image: ([N,] B, RGB, H, W) in range [-1, 1]
|
||||
|
||||
Returns:
|
||||
same as input but watermarked
|
||||
"""
|
||||
image = 0.5 * image + 0.5
|
||||
squeeze = len(image.shape) == 4
|
||||
if squeeze:
|
||||
image = image[None, ...]
|
||||
n = image.shape[0]
|
||||
image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1]
|
||||
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
|
||||
# watermarking libary expects input as cv2 BGR format
|
||||
for k in range(image_np.shape[0]):
|
||||
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
|
||||
image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to(
|
||||
image.device
|
||||
)
|
||||
image = torch.clamp(image / 255, min=0.0, max=1.0)
|
||||
if squeeze:
|
||||
image = image[0]
|
||||
image = 2 * image - 1
|
||||
return image
|
||||
|
||||
|
||||
# A fixed 48-bit message that was choosen at random
|
||||
WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
|
||||
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
|
||||
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
|
||||
152
custom_nodes/x-flux-comfyui/xflux/src/flux/xflux_pipeline.py
Normal file
152
custom_nodes/x-flux-comfyui/xflux/src/flux/xflux_pipeline.py
Normal file
@@ -0,0 +1,152 @@
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from src.flux.modules.layers import DoubleStreamBlockLoraProcessor
|
||||
from src.flux.sampling import denoise, denoise_controlnet, get_noise, get_schedule, prepare, unpack
|
||||
from src.flux.util import (load_ae, load_clip, load_flow_model, load_t5, load_controlnet,
|
||||
load_flow_model_quintized, Annotator, get_lora_rank, load_checkpoint)
|
||||
|
||||
|
||||
class XFluxPipeline:
|
||||
def __init__(self, model_type, device, offload: bool = False, seed: int = None):
|
||||
self.device = torch.device(device)
|
||||
self.offload = offload
|
||||
self.seed = seed
|
||||
self.model_type = model_type
|
||||
|
||||
self.clip = load_clip(self.device)
|
||||
self.t5 = load_t5(self.device, max_length=512)
|
||||
self.ae = load_ae(model_type, device="cpu" if offload else self.device)
|
||||
if "fp8" in model_type:
|
||||
self.model = load_flow_model_quintized(model_type, device="cpu" if offload else self.device)
|
||||
else:
|
||||
self.model = load_flow_model(model_type, device="cpu" if offload else self.device)
|
||||
|
||||
self.hf_lora_collection = "XLabs-AI/flux-lora-collection"
|
||||
self.lora_types_to_names = {
|
||||
"realism": "lora.safetensors",
|
||||
}
|
||||
self.controlnet_loaded = False
|
||||
|
||||
def set_lora(self, local_path: str = None, repo_id: str = None,
|
||||
name: str = None, lora_weight: int = 0.7):
|
||||
checkpoint = load_checkpoint(local_path, repo_id, name)
|
||||
self.update_model_with_lora(checkpoint, lora_weight)
|
||||
|
||||
def set_lora_from_collection(self, lora_type: str = "realism", lora_weight: int = 0.7):
|
||||
checkpoint = load_checkpoint(
|
||||
None, self.hf_lora_collection, self.lora_types_to_names[lora_type]
|
||||
)
|
||||
self.update_model_with_lora(checkpoint, lora_weight)
|
||||
|
||||
def update_model_with_lora(self, checkpoint, lora_weight):
|
||||
rank = get_lora_rank(checkpoint)
|
||||
lora_attn_procs = {}
|
||||
|
||||
for name, _ in self.model.attn_processors.items():
|
||||
lora_attn_procs[name] = DoubleStreamBlockLoraProcessor(dim=3072, rank=rank)
|
||||
lora_state_dict = {}
|
||||
for k in checkpoint.keys():
|
||||
if name in k:
|
||||
lora_state_dict[k[len(name) + 1:]] = checkpoint[k] * lora_weight
|
||||
lora_attn_procs[name].load_state_dict(lora_state_dict)
|
||||
lora_attn_procs[name].to(self.device)
|
||||
|
||||
self.model.set_attn_processor(lora_attn_procs)
|
||||
|
||||
def set_controlnet(self, control_type: str, local_path: str = None, repo_id: str = None, name: str = None):
|
||||
self.model.to(self.device)
|
||||
self.controlnet = load_controlnet(self.model_type, self.device).to(torch.bfloat16)
|
||||
|
||||
checkpoint = load_checkpoint(local_path, repo_id, name)
|
||||
self.controlnet.load_state_dict(checkpoint, strict=False)
|
||||
|
||||
if control_type == "depth":
|
||||
self.controlnet_gs = 0.9
|
||||
else:
|
||||
self.controlnet_gs = 0.7
|
||||
self.annotator = Annotator(control_type, self.device)
|
||||
self.controlnet_loaded = True
|
||||
|
||||
def __call__(self,
|
||||
prompt: str,
|
||||
controlnet_image: Image = None,
|
||||
width: int = 512,
|
||||
height: int = 512,
|
||||
guidance: float = 4,
|
||||
num_steps: int = 50,
|
||||
true_gs = 3,
|
||||
neg_prompt: str = '',
|
||||
timestep_to_start_cfg: int = 0,
|
||||
):
|
||||
width = 16 * width // 16
|
||||
height = 16 * height // 16
|
||||
if self.controlnet_loaded:
|
||||
controlnet_image = self.annotator(controlnet_image, width, height)
|
||||
controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1)
|
||||
controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(torch.bfloat16).to(self.device)
|
||||
|
||||
return self.forward(prompt, width, height, guidance, num_steps, controlnet_image,
|
||||
timestep_to_start_cfg=timestep_to_start_cfg, true_gs=true_gs, neg_prompt=neg_prompt)
|
||||
|
||||
def forward(self, prompt, width, height, guidance, num_steps, controlnet_image=None, timestep_to_start_cfg=0, true_gs=3, neg_prompt=""):
|
||||
x = get_noise(
|
||||
1, height, width, device=self.device,
|
||||
dtype=torch.bfloat16, seed=self.seed
|
||||
)
|
||||
timesteps = get_schedule(
|
||||
num_steps,
|
||||
(width // 8) * (height // 8) // (16 * 16),
|
||||
shift=True,
|
||||
)
|
||||
torch.manual_seed(self.seed)
|
||||
with torch.no_grad():
|
||||
if self.offload:
|
||||
self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
|
||||
inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=prompt)
|
||||
neg_inp_cond = prepare(t5=self.t5, clip=self.clip, img=x, prompt=neg_prompt)
|
||||
|
||||
if self.offload:
|
||||
self.offload_model_to_cpu(self.t5, self.clip)
|
||||
self.model = self.model.to(self.device)
|
||||
if self.controlnet_loaded:
|
||||
x = denoise_controlnet(
|
||||
self.model, **inp_cond, controlnet=self.controlnet,
|
||||
timesteps=timesteps, guidance=guidance,
|
||||
controlnet_cond=controlnet_image,
|
||||
timestep_to_start_cfg=timestep_to_start_cfg,
|
||||
neg_txt=neg_inp_cond['txt'],
|
||||
neg_txt_ids=neg_inp_cond['txt_ids'],
|
||||
neg_vec=neg_inp_cond['vec'],
|
||||
true_gs=true_gs,
|
||||
controlnet_gs=self.controlnet_gs,
|
||||
)
|
||||
else:
|
||||
x = denoise(self.model, **inp_cond, timesteps=timesteps, guidance=guidance,
|
||||
timestep_to_start_cfg=timestep_to_start_cfg,
|
||||
neg_txt=neg_inp_cond['txt'],
|
||||
neg_txt_ids=neg_inp_cond['txt_ids'],
|
||||
neg_vec=neg_inp_cond['vec'],
|
||||
true_gs=true_gs
|
||||
)
|
||||
|
||||
if self.offload:
|
||||
self.offload_model_to_cpu(self.model)
|
||||
self.ae.decoder.to(x.device)
|
||||
x = unpack(x.float(), height, width)
|
||||
x = self.ae.decode(x)
|
||||
self.offload_model_to_cpu(self.ae.decoder)
|
||||
|
||||
x1 = x.clamp(-1, 1)
|
||||
x1 = rearrange(x1[-1], "c h w -> h w c")
|
||||
output_img = Image.fromarray((127.5 * (x1 + 1.0)).cpu().byte().numpy())
|
||||
return output_img
|
||||
|
||||
def offload_model_to_cpu(self, *models):
|
||||
if not self.offload: return
|
||||
for model in models:
|
||||
model.cpu()
|
||||
torch.cuda.empty_cache()
|
||||
Reference in New Issue
Block a user