Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Includes 30 custom nodes committed directly, 7 Civitai-exclusive LoRAs stored via Git LFS, and a setup script that installs all dependencies and downloads Hugging Face-hosted models on vast.ai. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
545 lines
19 KiB
Python
545 lines
19 KiB
Python
# SPDX-License-Identifier: AGPL-3.0-only
|
||
# SPDX-FileCopyrightText: 2025 ArtificialSweetener
|
||
|
||
import os
|
||
import random
|
||
import time
|
||
from collections import OrderedDict
|
||
from typing import Dict, List, Optional, Tuple
|
||
|
||
import comfy.utils as comfy_utils
|
||
import folder_paths
|
||
import numpy as np
|
||
import torch
|
||
import torch.nn.functional as F
|
||
from PIL import Image
|
||
from torchlanc import lanczos_resize
|
||
|
||
|
||
def _chunk_spans(n: int, cap: int) -> List[Tuple[int, int]]:
|
||
if cap <= 0 or cap >= n:
|
||
return [(0, n)]
|
||
out = []
|
||
i = 0
|
||
while i < n:
|
||
j = min(n, i + cap)
|
||
out.append((i, j))
|
||
i = j
|
||
return out
|
||
|
||
|
||
def _bhwc_to_nchw(x: torch.Tensor) -> torch.Tensor:
|
||
return x.movedim(-1, -3)
|
||
|
||
|
||
def _nchw_to_bhwc(x: torch.Tensor) -> torch.Tensor:
|
||
return x.movedim(-3, -1)
|
||
|
||
|
||
def _ensure_rgba_nchw(wm: torch.Tensor) -> torch.Tensor:
    """
    Convert a single IMAGE tensor into a channels-first RGBA tensor.

    Args:
        wm: Tensor of shape (1, H, W, C) with C in {1, 3, 4}; values are
            expected in [0, 1] and are clamped to that range.

    Returns:
        A new float32 tensor of shape (4, H, W) on ``wm``'s device. Grayscale
        input is replicated to RGB. When no alpha channel is present, a fully
        opaque alpha (1.0) is appended. ``wm`` itself is never modified.

    Raises:
        ValueError: if ``wm`` is not (1, H, W, C) or C is not 1, 3 or 4.
    """
    if wm.dim() != 4 or wm.shape[0] != 1:
        raise ValueError(
            "watermark must be a single IMAGE tensor of shape (1,H,W,C) in [0,1]."
        )
    _, h, w, c = wm.shape
    if c not in (1, 3, 4):
        raise ValueError(f"Unsupported watermark channel count C={c}. Expected 1, 3 or 4.")

    # Out-of-place clamp: for float32 input, ``.float()`` returns the same
    # storage (a view of ``wm``), so ``clamp_`` would silently rewrite the
    # caller's tensor.
    x = _bhwc_to_nchw(wm[0]).float().clamp(0, 1)  # (C,H,W)
    if c == 4:
        return x

    rgb = x if c == 3 else x.repeat(3, 1, 1)
    alpha = torch.ones(1, h, w, device=x.device, dtype=x.dtype)
    return torch.cat([rgb, alpha], dim=0)
|
||
|
||
|
||
def _load_rgba_from_path(path: str, device: torch.device) -> torch.Tensor:
|
||
"""
|
||
Load an image from disk as RGBA in [0,1] and return (4,H,W) on the target device.
|
||
No rotation or other processing happens here.
|
||
"""
|
||
try:
|
||
with Image.open(path) as im:
|
||
im = im.convert("RGBA")
|
||
arr = np.asarray(im, dtype=np.float32) / 255.0 # (H,W,4) in [0,1]
|
||
except Exception as e:
|
||
raise ValueError(f"Failed to load watermark image from '{path}': {e}")
|
||
t = torch.from_numpy(arr).to(device=device, dtype=torch.float32) # (H,W,4)
|
||
return t.permute(2, 0, 1).contiguous() # (4,H,W)
|
||
|
||
|
||
def _rotate_bicubic_expand(x: torch.Tensor, degrees: float) -> torch.Tensor:
|
||
"""
|
||
x: (N,C,H,W). Rotate around center with bicubic sampling and EXPAND canvas
|
||
(PIL-like `expand=True`). Parts outside input are zero/transparent.
|
||
"""
|
||
deg = float(degrees) % 360.0
|
||
if deg == 0.0:
|
||
return x
|
||
|
||
N, C, H, W = x.shape
|
||
rad = deg * 3.141592653589793 / 180.0
|
||
cosr = float(torch.cos(torch.tensor(rad)))
|
||
sinr = float(torch.sin(torch.tensor(rad)))
|
||
|
||
# Expanded output size (axis-aligned bounding box of the rotated rectangle)
|
||
new_w = int((abs(W * cosr) + abs(H * sinr)) + 0.9999)
|
||
new_h = int((abs(H * cosr) + abs(W * sinr)) + 0.9999)
|
||
new_w = max(1, new_w)
|
||
new_h = max(1, new_h)
|
||
|
||
# Centers in pixel coords
|
||
cx_in = (W - 1) * 0.5
|
||
cy_in = (H - 1) * 0.5
|
||
cx_out = (new_w - 1) * 0.5
|
||
cy_out = (new_h - 1) * 0.5
|
||
|
||
# Output grid in pixel coords
|
||
ys = torch.linspace(0, new_h - 1, new_h, device=x.device, dtype=x.dtype)
|
||
xs = torch.linspace(0, new_w - 1, new_w, device=x.device, dtype=x.dtype)
|
||
gy, gx = torch.meshgrid(ys, xs, indexing="ij")
|
||
|
||
# Inverse rotation: output → input (rotate about centers)
|
||
rx = gx - cx_out
|
||
ry = gy - cy_out
|
||
x_in = cosr * rx + sinr * ry + cx_in
|
||
y_in = -sinr * rx + cosr * ry + cy_in
|
||
|
||
# Normalize to [-1,1] for align_corners=False
|
||
x_norm = (x_in + 0.5) / W * 2.0 - 1.0
|
||
y_norm = (y_in + 0.5) / H * 2.0 - 1.0
|
||
grid = torch.stack((x_norm, y_norm), dim=-1).unsqueeze(0).repeat(N, 1, 1, 1)
|
||
|
||
# Sample
|
||
try:
|
||
return F.grid_sample(
|
||
x, grid, mode="bicubic", padding_mode="zeros", align_corners=False
|
||
)
|
||
except Exception:
|
||
return F.grid_sample(
|
||
x, grid, mode="bilinear", padding_mode="zeros", align_corners=False
|
||
)
|
||
|
||
|
||
def _position_xy(
|
||
position: str,
|
||
base_w: int,
|
||
base_h: int,
|
||
wm_w: int,
|
||
wm_h: int,
|
||
pad_x: int,
|
||
pad_y: int,
|
||
) -> Tuple[int, int]:
|
||
pos = (position or "bottom-right").strip().lower()
|
||
if pos == "center":
|
||
return (base_w - wm_w) // 2, (base_h - wm_h) // 2
|
||
|
||
x = (
|
||
0
|
||
if "left" in pos
|
||
else (base_w - wm_w if "right" in pos else (base_w - wm_w) // 2)
|
||
)
|
||
y = (
|
||
0
|
||
if "top" in pos
|
||
else (base_h - wm_h if "bottom" in pos else (base_h - wm_h) // 2)
|
||
)
|
||
|
||
if "left" in pos:
|
||
x += int(pad_x)
|
||
if "right" in pos:
|
||
x -= int(pad_x)
|
||
if "top" in pos:
|
||
y += int(pad_y)
|
||
if "bottom" in pos:
|
||
y -= int(pad_y)
|
||
return x, y
|
||
|
||
|
||
class _SmallLRU:
|
||
def __init__(self, capacity: int = 6):
|
||
self.capacity = int(max(1, capacity))
|
||
self._m: "OrderedDict[Tuple, Tuple[torch.Tensor, torch.Tensor]]" = OrderedDict()
|
||
|
||
def get(self, key: Tuple):
|
||
v = self._m.get(key)
|
||
if v is not None:
|
||
self._m.move_to_end(key)
|
||
return v
|
||
|
||
def put(self, key: Tuple, value):
|
||
if key in self._m:
|
||
self._m.move_to_end(key)
|
||
self._m[key] = value
|
||
if len(self._m) > self.capacity:
|
||
self._m.popitem(last=False)
|
||
|
||
|
||
class BatchWatermarkSingle:
    """
    Single-position watermark for image batches.

    - Scale uses base image WIDTH × (scale/100); aspect ratio is preserved
    - Rotation uses bicubic resampling and EXPANDS the watermark's own canvas;
      only parts that overhang the base image edges are clipped
    - Padding in pixels (ignored for center), plus optional optical alignment
    - TorchLanc (Lanczos) for the watermark resize
    - Chunked batches to bound device memory
    - Runs on CUDA when available, otherwise on CPU
    """

    @classmethod
    def INPUT_TYPES(cls):
        # Mirror LoadImage: list files from the input directory, allow upload
        input_dir = folder_paths.get_input_directory()
        files = [
            f
            for f in os.listdir(input_dir)
            if os.path.isfile(os.path.join(input_dir, f))
        ]
        files = folder_paths.filter_files_content_types(files, ["image"])

        return {
            "required": {
                "image": (
                    "IMAGE",
                    {
                        "tooltip": "Images to watermark. Accepts (H,W,C) or (B,H,W,C) with values in [0–1]. Processed on GPU when available, otherwise CPU."
                    },
                ),
                "watermark": (
                    sorted(files),
                    {
                        "image_upload": True,
                        "tooltip": "Select or upload the watermark image (PNG recommended). The file’s transparency is preserved.",
                    },
                ),
                "position": (
                    ["bottom-right", "bottom-left", "top-right", "top-left", "center"],
                    {
                        "default": "bottom-right",
                        "tooltip": "Where to place the watermark. Padding is ignored when 'center' is selected. Rotation expands the watermark's canvas; only parts overhanging the image edge are clipped.",
                    },
                ),
                "scale": (
                    "INT",
                    {
                        "default": 70,
                        "min": 1,
                        "max": 100,
                        "step": 1,
                        "tooltip": "Width-based scaling. Target watermark width = image width × (scale/100). Aspect ratio preserved.",
                    },
                ),
                "transparency": (
                    "INT",
                    {
                        "default": 100,
                        "min": 0,
                        "max": 100,
                        "step": 1,
                        "tooltip": "Alpha multiplier for the watermark: 100 = unchanged, 0 = fully transparent.",
                    },
                ),
                "rotation": (
                    "INT",
                    {
                        "default": 0,
                        "min": 0,
                        "max": 359,
                        "step": 1,
                        "tooltip": "Rotate the watermark (degrees) with bicubic resampling. Canvas expands so nothing is clipped (PIL-style).",
                    },
                ),
                "padding_x": (
                    "INT",
                    {
                        "default": 0,
                        "min": 0,
                        "max": 16384,
                        "step": 1,
                        "tooltip": "Extra horizontal padding in pixels from the chosen edge (ignored when position='center').",
                    },
                ),
                "padding_y": (
                    "INT",
                    {
                        "default": 0,
                        "min": 0,
                        "max": 16384,
                        "step": 1,
                        "tooltip": "Extra vertical padding in pixels from the chosen edge (ignored when position='center').",
                    },
                ),
                "optical_padding": (
                    "BOOLEAN",
                    {
                        "default": False,
                        "tooltip": "Adjust placement by the watermark’s visual center so equal padding looks right (optical alignment). Affects corner positions; ignored when position='center'.",
                    },
                ),
                "optical_strength": (
                    "INT",
                    {
                        "default": 40,
                        "min": 0,
                        "max": 100,
                        "step": 5,
                        "tooltip": "How strongly to nudge toward visual centering (0–100). 0 = off. Higher values shift more for wide/rotated marks.",
                    },
                ),
                "max_batch_size": (
                    "INT",
                    {
                        "default": 0,
                        "min": 0,
                        "max": 4096,
                        "step": 1,
                        "tooltip": "Process images in chunks to control VRAM. 0 = process the whole batch at once.",
                    },
                ),
                "sinc_window": (
                    "INT",
                    {
                        "default": 3,
                        "min": 1,
                        "max": 8,
                        "step": 1,
                        "tooltip": "Lanczos window size (a) used when resizing the watermark. Higher = sharper (but more ringing).",
                    },
                ),
                "precision": (
                    ["fp32", "fp16", "bf16"],
                    {
                        "default": "fp32",
                        "tooltip": "Resampling compute dtype. fp32 = safest quality; fp16/bf16 can be faster on many GPUs.",
                    },
                ),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "apply"
    CATEGORY = "image/post"
    DESCRIPTION = "GPU accelerated watermark overlay. TorchLanc resize for quality and speed. Works for single images, but efficient for batches, too!"

    @staticmethod
    def _compute_device() -> torch.device:
        """Prefer CUDA; fall back to CPU so the node also runs on GPU-less hosts."""
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @staticmethod
    def _prepare_overlay(
        wm_rgba: torch.Tensor,
        base_w: int,
        scale: int,
        transparency: int,
        rotation: int,
        sinc_window: int,
        precision: str,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Turn a straight-alpha (4, H, W) watermark into a ready-to-blend overlay.

        The mark is premultiplied and then Lanczos-resized so its width is
        ``base_w * scale / 100`` (aspect preserved). It is faded by
        ``transparency`` percent and rotated by ``rotation`` degrees on an
        expanded canvas.

        Returns:
            ``(pm, alpha)`` of shapes (3, h, w) and (1, h, w), where ``pm`` is
            colour already multiplied by ``alpha``.
        """
        wm_h0, wm_w0 = int(wm_rgba.shape[1]), int(wm_rgba.shape[2])
        target_w = max(1, int(round(base_w * (scale / 100.0))))
        target_h = max(1, int(round(wm_h0 * target_w / max(1, wm_w0))))

        # Premultiply BEFORE resampling so the colour of transparent texels
        # cannot bleed dark fringes into the resized edges.
        alpha0 = wm_rgba[3:4, :, :]
        wm_pm = torch.cat([wm_rgba[:3, :, :] * alpha0, alpha0], dim=0).unsqueeze(0)

        resized = lanczos_resize(
            wm_pm,
            height=target_h,
            width=target_w,
            a=int(sinc_window),
            precision=str(precision),
            clamp=True,
            chunk_size=0,
        )[0]  # (4,h,w)
        # NOTE(review): the output dtype may follow ``precision``. Promote to
        # float32 so rotation and later centroid sums never run in half precision.
        resized = resized.float()

        if transparency != 100:
            # Fading a premultiplied image scales colour and alpha together.
            resized = resized * (float(transparency) / 100.0)

        wm_final = _rotate_bicubic_expand(resized.unsqueeze(0), rotation)[0]
        return wm_final[:3, :, :], wm_final[3:4, :, :]

    @staticmethod
    def _optical_offsets(alpha: torch.Tensor, pos: str, strength: int) -> Tuple[int, int]:
        """Return the (dx, dy) pixel nudge used for optical alignment.

        ``alpha`` is the (h, w) coverage of the final overlay. The nudge is
        ``strength`` percent of the vector from the alpha-weighted centroid to
        the geometric centre. It applies only on axes anchored to an edge:
        "left"/"right" horizontally, "top"/"bottom" vertically.

        Both edges of an axis use the same sign, so mirrored marks at mirrored
        corners get mirrored placement. A mark whose visual mass sits away
        from its anchored edge is moved toward that edge, and vice versa. A
        fully transparent overlay yields (0, 0).
        """
        alpha = alpha.float()  # half-precision sums overflow for large marks
        total = alpha.sum()
        if not float(total) > 1e-8:  # also rejects NaN
            return 0, 0
        h, w = int(alpha.shape[0]), int(alpha.shape[1])
        ys = torch.linspace(0, h - 1, h, device=alpha.device, dtype=alpha.dtype)
        xs = torch.linspace(0, w - 1, w, device=alpha.device, dtype=alpha.dtype)
        cy = float((alpha.sum(dim=1) * ys).sum() / total)
        cx = float((alpha.sum(dim=0) * xs).sum() / total)
        k = float(strength) / 100.0

        dx = 0
        if "left" in pos or "right" in pos:
            dx = int(round(((w - 1) * 0.5 - cx) * k))
        dy = 0
        if "top" in pos or "bottom" in pos:
            dy = int(round(((h - 1) * 0.5 - cy) * k))
        return dx, dy

    @staticmethod
    def _composite_chunk(
        sub: torch.Tensor,
        pm: torch.Tensor,
        alpha: torch.Tensor,
        x0: int,
        y0: int,
        x1: int,
        y1: int,
    ) -> torch.Tensor:
        """Blend the cropped premultiplied overlay "over" the box [y0:y1, x0:x1].

        ``sub`` is an (N, C, H, W) chunk with C in {1, 3, 4}. ``pm`` and
        ``alpha`` are (3, h, w) and (1, h, w) and broadcast across the chunk.

        RGB(A) chunks are written in place and returned; for C == 4 the base
        alpha channel is left unchanged. Single-channel chunks are composited
        in RGB, then reduced back to one channel with Rec. 709 luma weights.
        """
        if sub.shape[1] == 1:
            rgb = sub.repeat(1, 3, 1, 1)
            roi = rgb[:, :, y0:y1, x0:x1]
            rgb[:, :, y0:y1, x0:x1] = roi * (1.0 - alpha) + pm
            luma = 0.2126 * rgb[:, 0:1] + 0.7152 * rgb[:, 1:2] + 0.0722 * rgb[:, 2:3]
            return luma.clamp_(0, 1)

        roi = sub[:, :3, y0:y1, x0:x1]
        sub[:, :3, y0:y1, x0:x1] = roi * (1.0 - alpha) + pm
        return sub

    def apply(
        self,
        image: torch.Tensor,
        watermark: str,
        position: str,
        scale: int,
        transparency: int,
        rotation: int,
        padding_x: int,
        padding_y: int,
        optical_padding: bool,
        optical_strength: int,
        max_batch_size: int,
        sinc_window: int,
        precision: str,
    ):
        """Overlay the selected watermark file onto every image of ``image``.

        Args:
            image: IMAGE tensor, (H,W,C) or (B,H,W,C), with C in {1, 3, 4}.
            watermark: annotated filename resolved through ``folder_paths``.
            position, padding_x, padding_y: placement (see ``INPUT_TYPES``).
            scale: watermark width as a percentage of the image width.
            transparency: 0–100 alpha multiplier (100 = unchanged).
            rotation: degrees; the watermark canvas expands to fit.
            optical_padding, optical_strength: optional centroid-based nudge.
            max_batch_size: images per device chunk (0 = whole batch).
            sinc_window, precision: forwarded to ``lanczos_resize``.

        Returns:
            A 1-tuple holding a CPU float32 (B,H,W,C) tensor clamped to [0, 1].
            The input tensor is never modified.

        Raises:
            TypeError: ``image`` is a list/tuple instead of a batched tensor.
            ValueError: invalid image tensor, missing/invalid watermark file,
                or a watermark image that cannot be decoded.
        """
        # Sequences are checked first; otherwise the generic tensor check below
        # would pre-empt this more helpful message.
        if isinstance(image, (list, tuple)):
            raise TypeError(
                "Expected IMAGE tensor (H,W,C) or (B,H,W,C); got a sequence. Use 'Image Batch' to re-batch."
            )
        if image is None or not isinstance(image, torch.Tensor):
            raise ValueError(
                "image must be a torch.Tensor with shape (H,W,C) or (B,H,W,C) in [0,1]."
            )
        if not isinstance(watermark, str) or not watermark:
            raise ValueError("Select a watermark image from the list (or upload one).")
        if not folder_paths.exists_annotated_filepath(watermark):
            raise ValueError(f"Invalid watermark file: {watermark}")
        watermark_path = folder_paths.get_annotated_filepath(watermark)

        # Accept both single images (H,W,C) and batches (B,H,W,C); normalize to batch
        if image.dim() == 3:
            image = image.unsqueeze(0)  # -> (1,H,W,C)
        elif image.dim() != 4:
            raise ValueError(
                f"Unexpected IMAGE tensor rank {image.dim()}; expected 3 or 4 dims."
            )
        B, H, W, C = image.shape
        if C not in (1, 3, 4):
            raise ValueError(f"Unsupported channel count C={C}. Expected 1, 3 or 4.")

        device = self._compute_device()
        # Normalize once so placement and optical alignment agree on the anchor.
        pos = (position or "bottom-right").strip().lower()
        transparency = max(0, min(100, int(transparency)))
        rotation = int(rotation) % 360
        optical_strength = max(0, min(100, int(optical_strength)))

        # Load RGBA straight from disk to preserve the file's own transparency.
        wm_rgba = _load_rgba_from_path(watermark_path, device)  # (4,hw,ww)
        pbar = comfy_utils.ProgressBar(B)

        # All images in the batch share H×W, so the overlay is built only once.
        pm_final, a_final = self._prepare_overlay(
            wm_rgba, W, int(scale), transparency, rotation, sinc_window, precision
        )
        wm_h, wm_w = int(pm_final.shape[1]), int(pm_final.shape[2])

        x, y = _position_xy(pos, W, H, wm_w, wm_h, int(padding_x), int(padding_y))
        if bool(optical_padding) and pos != "center":
            dx, dy = self._optical_offsets(a_final[0], pos, optical_strength)
            x, y = x + dx, y + dy

        # Intersection of the overlay box with the base image (clip)
        x0, y0 = max(0, x), max(0, y)
        x1, y1 = min(W, x + wm_w), min(H, y + wm_h)
        if x1 <= x0 or y1 <= y0:
            # Overlay lies completely off-canvas: return an untouched copy.
            # ``clamp`` (not ``clamp_``) so the caller's tensor is never written.
            return (image.to("cpu").float().clamp(0, 1).contiguous(),)

        wx0, wy0 = x0 - x, y0 - y
        crop_h, crop_w = y1 - y0, x1 - x0
        pm_crop = pm_final[:, wy0 : wy0 + crop_h, wx0 : wx0 + crop_w].contiguous()
        a_crop = a_final[:, wy0 : wy0 + crop_h, wx0 : wx0 + crop_w].contiguous()

        out_chunks: List[torch.Tensor] = []
        for start, stop in _chunk_spans(B, int(max_batch_size)):
            # The out-of-place clamp yields a private copy. When ``image`` is
            # already a float32 tensor on ``device``, ``.to()``/``.float()``
            # return views of it, and the ROI write would otherwise modify the
            # caller's IMAGE.
            sub = (
                _bhwc_to_nchw(image[start:stop])
                .to(device, non_blocking=True)
                .float()
                .clamp(0, 1)
            )
            sub = self._composite_chunk(sub, pm_crop, a_crop, x0, y0, x1, y1)
            out_chunks.append(_nchw_to_bhwc(sub).to("cpu").clamp_(0, 1))
            pbar.update(stop - start)

        # CPU BHWC chunks → one contiguous CPU float32 BHWC batch
        out = torch.cat(out_chunks, dim=0).to(dtype=torch.float32).contiguous()
        return (out,)
|