Files
ComfyUI/custom_nodes/whiterabbit/scaling.py
jaidaken f09734b0ee
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Add custom nodes, Civitai loras (LFS), and vast.ai setup script
Includes 30 custom nodes committed directly, 7 Civitai-exclusive
loras stored via Git LFS, and a setup script that installs all
dependencies and downloads HuggingFace-hosted models on vast.ai.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-09 00:56:42 +00:00

697 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# SPDX-License-Identifier: AGPL-3.0-only
# SPDX-FileCopyrightText: 2025 ArtificialSweetener <artificialsweetenerai@proton.me>
import math
from typing import List, Optional, Tuple
import comfy.utils as comfy_utils
import torch
import torch.nn.functional as F
from comfy import model_management
from torchlanc import lanczos_resize
class UpscaleWithModelAdvanced:
    """Batch-tunable variant of ComfyUI's "Upscale Image (using Model)" node.

    Adds chunked batching, a configurable starting tile size, optional
    channels-last memory format, and fp16/bf16 autocast so large batches can
    be upscaled without repeatedly hitting slow OOM-fallback paths.
    """

    DESCRIPTION = """Based on Comfy's native "Upscale Image (using Model)", with controls exposed to tune for large batches, avoid slow
OOM fallbacks, and create opportunities to optimize for speed.
Defaults
- Behaves about the same as the original node.
Controls
- max_batch_size > 0: process images in chunks to keep VRAM steady and reduce fallback slowdowns.
- tile_size: choose a starting tile; original node defaults to 512. 0 = auto (falls back 512 → 256 → 128 on OOM).
- channels_last: try ON for a speedup on some systems.
- precision: lower (fp16/bf16) can be faster; may impact quality depending on the model.
"""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "upscale_model": (
                    "UPSCALE_MODEL",
                    {"tooltip": "Pick your ESRGAN model (e.g. 2× / 4×)."},
                ),
                "image": (
                    "IMAGE",
                    {
                        # FIX: tooltip said "[01]"; the expected range is [0,1].
                        "tooltip": "Images to upscale. Accepts a batch: frames×H×W×C with values in [0,1]."
                    },
                ),
            },
            "optional": {
                "max_batch_size": (
                    "INT",
                    {
                        "default": 0,
                        "min": 0,
                        "max": 4096,
                        "step": 1,
                        "tooltip": "How many images to process at once. 0 = all at once. Set >0 if you hit OOM.",
                    },
                ),
                "tile_size": (
                    "INT",
                    {
                        "default": 0,
                        "min": 0,
                        "max": 2048,
                        "step": 32,
                        "tooltip": "How big each tile is. 0 = auto (starts at 512 and halves on OOM). Bigger is faster; smaller is safer.",
                    },
                ),
                "channels_last": (
                    "BOOLEAN",
                    {
                        "default": False,
                        "tooltip": "Try this ON for a small speed boost on some GPUs. If you see no gain, leave it OFF.",
                    },
                ),
                "precision": (
                    ["fp32", "fp16", "bf16"],
                    {
                        "default": "fp32",
                        "tooltip": "Math mode. fp32 = safest. fp16/bf16 can be faster on many GPUs, may impact image quality.",
                    },
                ),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "upscale"
    CATEGORY = "image/upscaling"

    def upscale(
        self,
        upscale_model,
        image,
        max_batch_size=0,
        tile_size=0,
        channels_last=False,
        precision="fp32",
    ):
        """Upscale `image` (B,H,W,C in [0,1]) and return a 1-tuple of the result.

        The batch is processed in chunks of `max_batch_size` (0 = all at once);
        each chunk is tiled, halving the tile size on OOM down to a floor of 128
        before giving up.
        """

        def spans(n, cap):
            # Split [0, n) into consecutive (start, end) chunks of at most `cap`.
            if cap <= 0 or cap >= n:
                return [(0, n)]
            out = []
            i = 0
            while i < n:
                j = min(n, i + cap)
                out.append((i, j))
                i = j
            return out

        device = model_management.get_torch_device()
        upscale_model.to(device)
        # Belt-and-braces: make sure every parameter (and any stale grad)
        # actually landed on the compute device.
        for p in upscale_model.model.parameters():
            if p.device != device:
                p.data = p.data.to(device)
                if p._grad is not None:
                    p._grad.data = p._grad.data.to(device)
        upscale_model.model.eval()

        scale = float(getattr(upscale_model, "scale", 4.0))
        # Rough VRAM estimate (mirrors the native node) so Comfy can evict
        # other models before we start.
        memory_required = model_management.module_size(upscale_model.model)
        memory_required += (
            (512 * 512 * 3) * image.element_size() * max(scale, 1.0) * 384.0
        )
        memory_required += image.nelement() * image.element_size()
        model_management.free_memory(memory_required, device)

        B, H, W, C = image.shape
        out_chunks = []
        for s, e in spans(B, int(max_batch_size)):
            # BHWC -> BCHW for the model.
            sub = image[s:e].movedim(-1, -3).to(device, non_blocking=True)
            if channels_last and device.type == "cuda":
                sub = sub.to(memory_format=torch.channels_last)
            tile = 512 if tile_size in (0, None) else int(tile_size)
            overlap = 32
            oom = True
            while oom:
                try:
                    steps = sub.shape[0] * comfy_utils.get_tiled_scale_steps(
                        sub.shape[3],
                        sub.shape[2],
                        tile_x=tile,
                        tile_y=tile,
                        overlap=overlap,
                    )
                    pbar = comfy_utils.ProgressBar(steps)
                    if device.type == "cuda" and precision in ("fp16", "bf16"):
                        amp_dtype = (
                            torch.float16 if precision == "fp16" else torch.bfloat16
                        )
                        with torch.autocast(
                            device_type="cuda", dtype=amp_dtype
                        ), torch.inference_mode():
                            sr = comfy_utils.tiled_scale(
                                sub,
                                lambda a: upscale_model(a),
                                tile_x=tile,
                                tile_y=tile,
                                overlap=overlap,
                                upscale_amount=scale,
                                pbar=pbar,
                            )
                    else:
                        with torch.inference_mode():
                            sr = comfy_utils.tiled_scale(
                                sub,
                                lambda a: upscale_model(a),
                                tile_x=tile,
                                tile_y=tile,
                                overlap=overlap,
                                upscale_amount=scale,
                                pbar=pbar,
                            )
                    oom = False
                # FIX: was `except ... as e`, which shadowed the chunk loop
                # variable `e` and (via Python's implicit `del`) unbound it
                # after the handler ran.
                except model_management.OOM_EXCEPTION:
                    tile //= 2
                    if tile < 128:
                        raise
            # BCHW -> BHWC, clamp to [0,1], and park the chunk on the CPU.
            out_chunks.append(
                torch.clamp(sr.movedim(-3, -1), 0.0, 1.0).to("cpu", non_blocking=True)
            )
        upscale_model.to("cpu")
        return (torch.cat(out_chunks, dim=0),)
def _chunk_spans(n: int, max_bs: int) -> List[Tuple[int, int]]:
if max_bs <= 0 or max_bs >= n:
return [(0, n)]
s, spans = 0, []
while s < n:
e = min(n, s + max_bs)
spans.append((s, e))
s = e
return spans
def _floor_mul(x: int, k: int) -> int:
if k <= 1:
return max(1, x)
return x - (x % k)
def _ceil_mul(x: int, k: int) -> int:
if k <= 1:
return max(1, x)
r = x % k
return x if r == 0 else x + (k - r)
def _fit_keep_aspect(sw: int, sh: int, tw: int, th: int) -> Tuple[int, int]:
if tw <= 0 and th <= 0:
return sw, sh
if tw <= 0:
r = th / sh
elif th <= 0:
r = tw / sw
else:
r = min(tw / sw, th / sh)
return max(1, int(round(sw * r))), max(1, int(round(sh * r)))
def _fit_keep_ar_divisible(
    sw: int, sh: int, tw: int, th: int, d: int
) -> Tuple[int, int]:
    """Fit (sw, sh) inside (tw, th) with both output sides divisible by d.

    Prefers the largest size with the EXACT source aspect ratio (an integer
    multiple of the reduced ratio scaled by d) that still fits; if no such
    multiple fits, floors each fitted side to a multiple of d independently
    (at least d each).
    """
    if d <= 1:
        return _fit_keep_aspect(sw, sh, tw, th)
    fit_w, fit_h = _fit_keep_aspect(sw, sh, tw, th)
    # Smallest d-divisible box with the exact source aspect ratio.
    g = math.gcd(sw, sh)
    unit_w = d * (sw // g)
    unit_h = d * (sh // g)
    steps = min(fit_w // unit_w, fit_h // unit_h)
    if steps >= 1:
        return unit_w * steps, unit_h * steps
    # Exact-AR multiple does not fit: fall back to per-side flooring.
    return max(d, _floor_mul(fit_w, d)), max(d, _floor_mul(fit_h, d))
def _scale_then_crop_divisible(
    sw: int, sh: int, req_w: int, req_h: int, d: int
) -> Tuple[int, int, int, int]:
    """
    AR Scale + Divisible Crop:
    1) Scale once (keep AR), locking the SOURCE long side to floor(requested_long/d)*d (>0).
    2) Crop ONLY the short side to the largest multiple of d that is ≤ scaled short side and ≤ requested short side (>0).

    Returns (resize_w, resize_h, out_w, out_h): the intermediate resize size
    and the final post-crop size. Raises ValueError when a requested or
    scaled dimension is too small to reach a positive multiple of d.
    """
    # Sanitize inputs: divisor and requested sizes are forced positive ints.
    d = max(1, int(d))
    req_w = max(1, int(req_w))
    req_h = max(1, int(req_h))
    # Requested sizes floored to multiples of d (0 when req < d).
    req_w_div = _floor_mul(req_w, d)
    req_h_div = _floor_mul(req_h, d)
    # Ties (square sources) are treated as "height is the long side".
    src_long_is_h = sh >= sw
    if src_long_is_h:
        if req_h_div == 0:
            raise ValueError(
                f"AR Scale + Divisible Crop: requested height {req_h}px < divisible_by {d}."
            )
        # Lock the long side (height) to its divisible target; width follows AR.
        scale = req_h_div / sh
        rh = req_h_div
        rw = max(1, int(round(sw * scale)))
        if rw < d:
            raise ValueError(
                f"AR Scale + Divisible Crop: scaled width {rw}px < divisible_by {d}."
            )
        if req_w_div == 0:
            raise ValueError(
                f"AR Scale + Divisible Crop: requested width {req_w}px < divisible_by {d}."
            )
        # Crop the short side (width) to a multiple of d, capped by the request.
        out_w = min(req_w_div, _floor_mul(rw, d))
        out_h = rh
    else:
        if req_w_div == 0:
            raise ValueError(
                f"AR Scale + Divisible Crop: requested width {req_w}px < divisible_by {d}."
            )
        # Mirror case: width is the long side; height follows AR and is cropped.
        scale = req_w_div / sw
        rw = req_w_div
        rh = max(1, int(round(sh * scale)))
        if rh < d:
            raise ValueError(
                f"AR Scale + Divisible Crop: scaled height {rh}px < divisible_by {d}."
            )
        if req_h_div == 0:
            raise ValueError(
                f"AR Scale + Divisible Crop: requested height {req_h}px < divisible_by {d}."
            )
        out_h = min(req_h_div, _floor_mul(rh, d))
        out_w = rw
    # NOTE(review): when both scaled sides are already divisible by d, this
    # shortcut returns them uncropped, bypassing the requested-short-side cap
    # (min(req_*_div, ...)) computed above — confirm this is intentional.
    if (rw % d == 0) and (rh % d == 0):
        return rw, rh, rw, rh
    return rw, rh, out_w, out_h
def _cover_keep_aspect(sw: int, sh: int, tw: int, th: int) -> Tuple[int, int]:
r = max(tw / sw, th / sh)
return max(1, int((sw * r) + 0.999999)), max(1, int((sh * r) + 0.999999))
def _pad_sides(pos: str, pad_w: int, pad_h: int) -> Tuple[int, int, int, int]:
lw = pad_w // 2
rw = pad_w - lw
th = pad_h // 2
bh = pad_h - th
if pos in ("top-left", "top", "top-right"):
th, bh = 0, pad_h
if pos in ("bottom-left", "bottom", "bottom-right"):
th, bh = pad_h, 0
if pos in ("top-left", "left", "bottom-left"):
lw, rw = 0, pad_w
if pos in ("top-right", "right", "bottom-right"):
lw, rw = pad_w, 0
return lw, rw, th, bh
def _crop_offsets(
pos: str, in_w: int, in_h: int, out_w: int, out_h: int
) -> Tuple[int, int]:
dx = max(0, in_w - out_w)
dy = max(0, in_h - out_h)
mapx = {
"top-left": "left",
"left": "left",
"bottom-left": "left",
"top": "center",
"center": "center",
"bottom": "center",
"top-right": "right",
"right": "right",
"bottom-right": "right",
}
mapy = {
"top-left": "top",
"top": "top",
"top-right": "top",
"left": "center",
"center": "center",
"right": "center",
"bottom-left": "bottom",
"bottom": "bottom",
"bottom-right": "bottom",
}
lx = {"left": 0, "center": dx // 2, "right": dx}.get(
mapx.get(pos, "center"), dx // 2
)
ly = {"top": 0, "center": dy // 2, "bottom": dy}.get(
mapy.get(pos, "center"), dy // 2
)
return lx, ly
def _parse_pad_color(
s: str, c: int, device: torch.device, dtype: torch.dtype
) -> torch.Tensor:
s = (s or "").strip()
if not s:
return torch.zeros(c, device=device, dtype=dtype)
try:
parts = [int(p.strip()) for p in s.split(",")]
except Exception:
parts = [0, 0, 0]
rgb = [int(max(0, min(255, v))) for v in (parts + [0, 0, 0])[:3]]
v = torch.tensor(
[rgb[0] / 255.0, rgb[1] / 255.0, rgb[2] / 255.0], device=device, dtype=dtype
)
if c == 1:
return v[:1]
if c == 4:
return torch.cat([v, torch.ones(1, device=device, dtype=dtype)])
return v
def _nearest_interp(x: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
try:
return F.interpolate(x, size=size, mode="nearest-exact")
except Exception:
return F.interpolate(x, size=size, mode="nearest")
def _divisible_box(w: int, h: int, d: int) -> Tuple[int, int]:
if d <= 1:
return int(w), int(h)
return _floor_mul(int(w), d), _floor_mul(int(h), d)
def _normalize_mode(mode: str) -> str:
key = (mode or "").strip().lower()
table = {
"keep ar": "keep_ar",
"stretch": "stretch",
"crop (cover + crop)": "crop",
"pad (fit + pad)": "pad",
"ar scale + divisible crop": "ar_scale_crop_divisible",
}
if key not in table:
raise ValueError(
"Unknown resize_mode. Use one of: "
"'Keep AR', 'Stretch', 'Crop (Cover + Crop)', 'Pad (Fit + Pad)', 'AR Scale + Divisible Crop'."
)
return table[key]
class BatchResizeWithLanczos:
    """Batch Lanczos resize node with keep-AR, stretch, crop, and pad modes."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": (
                    "IMAGE",
                    {
                        "tooltip": "Input batch (B,H,W,C) in [0,1] float.\nProcessed on GPU."
                    },
                ),
                "width": (
                    "INT",
                    {
                        "default": 1024,
                        "min": 1,
                        "max": 16384,
                        "step": 1,
                        "tooltip": "Target width (pixels).\n\n"
                        "Notes:\n"
                        "• Keep AR / Pad: maximum width for the fit\n"
                        "• Crop: final output width\n"
                        "• AR Scale + Divisible Crop: requested width before divisibility",
                    },
                ),
                "height": (
                    "INT",
                    {
                        "default": 576,
                        "min": 1,
                        "max": 16384,
                        "step": 1,
                        "tooltip": "Target height (pixels).\n\n"
                        "Notes:\n"
                        "• Keep AR / Pad: maximum height for the fit\n"
                        "• Crop: final output height\n"
                        "• AR Scale + Divisible Crop: requested height before divisibility",
                    },
                ),
                "resize_mode": (
                    [
                        "Keep AR",
                        "Stretch",
                        "Crop (Cover + Crop)",
                        "Pad (Fit + Pad)",
                        "AR Scale + Divisible Crop",
                    ],
                    {
                        "default": "Keep AR",
                        "tooltip": "Modes:\n"
                        "- Keep AR: Fit inside width×height (preserve aspect)\n"
                        "- Stretch: Force to width×height (may distort)\n"
                        "- Crop (Cover + Crop): Scale to cover, then crop to width×height\n"
                        "- Pad (Fit + Pad): Fit inside, then pad to width×height\n"
                        "- AR Scale + Divisible Crop: Scale by SOURCE long side to ≤ requested divisible; crop ONLY the short side to its divisible",
                    },
                ),
                "divisible_by": (
                    "INT",
                    {
                        "default": 1,
                        "min": 1,
                        "max": 4096,
                        "step": 1,
                        "tooltip": "Force output dimensions to multiples of N.\n\n"
                        "Details:\n"
                        "• Keep AR: Fit → then step down to the largest size ≤ requested that keeps AR AND makes both sides divisible\n"
                        "• AR Scale + Divisible Crop: Lock the scaled LONG side to its divisible target; crop ONLY the short side to its divisible\n"
                        "Set to 1 (or 0 in UI) to disable",
                    },
                ),
                "max_batch_size": (
                    "INT",
                    {
                        "default": 0,
                        "min": 0,
                        "max": 4096,
                        "step": 1,
                        "tooltip": "0 = process whole batch\n>0 = chunk the batch to this size",
                    },
                ),
                "sinc_window": (
                    "INT",
                    {
                        "default": 3,
                        "min": 1,
                        "max": 8,
                        "step": 1,
                        "tooltip": "Lanczos window size (a). Higher = sharper (more ringing).",
                    },
                ),
                "pad_color": (
                    "STRING",
                    {
                        "default": "0, 0, 0",
                        "tooltip": "Pad mode only. RGB as 'r, g, b' (0-255).",
                    },
                ),
                "crop_position": (
                    [
                        "center",
                        "top-left",
                        "top",
                        "top-right",
                        "left",
                        "right",
                        "bottom-left",
                        "bottom",
                        "bottom-right",
                    ],
                    {
                        "default": "center",
                        "tooltip": "Where to crop/pad from.\nChoose which edges are preserved for cropping, or where padding is added.",
                    },
                ),
                "precision": (
                    ["fp32", "fp16", "bf16"],
                    {"default": "fp32", "tooltip": "Resampling compute dtype."},
                ),
            },
            "optional": {
                "mask": (
                    "MASK",
                    {
                        "tooltip": "Optional mask (B,H,W) in [0,1].\nResized with nearest.\nFollows the same crop/pad as the image."
                    },
                ),
            },
        }

    RETURN_TYPES = ("IMAGE", "INT", "INT", "MASK")
    RETURN_NAMES = ("IMAGE", "width", "height", "mask")
    FUNCTION = "process"
    CATEGORY = "image/resize"
    DESCRIPTION = (
        "CUDA-accelerated, gamma-correct Lanczos resizer (TorchLanc).\n\n"
        "Modes:\n"
        "• Keep AR\n"
        "• Stretch\n"
        "• Crop (Cover + Crop)\n"
        "• Pad (Fit + Pad)\n"
        "• AR Scale + Divisible Crop\n\n"
        "Node functionality based on Resize nodes by Kijai\n\n"
        "More from me!: https://artificialsweetener.ai"
    )

    def process(
        self,
        image: torch.Tensor,
        width: int,
        height: int,
        resize_mode: str,
        divisible_by: int,
        max_batch_size: int,
        sinc_window: int,
        pad_color: str,
        crop_position: str,
        precision: str,
        mask: Optional[torch.Tensor] = None,
    ):
        """Resize a (B,H,W,C) image batch (and optional mask) per `resize_mode`.

        Returns (images, out_w, out_h, mask); the mask output is all-zeros
        when no mask is supplied. Raises ValueError on bad inputs or modes.
        """
        if image is None or not isinstance(image, torch.Tensor):
            raise ValueError(
                "image must be a torch.Tensor of shape (B,H,W,C) in [0,1]."
            )
        B, H, W, C = image.shape
        if C not in (1, 3, 4):
            raise ValueError(f"Unsupported channel count C={C}. Expected 1, 3 or 4.")
        d = int(divisible_by) if int(divisible_by) > 1 else 1
        mode = _normalize_mode(resize_mode)
        # FIX: was torch.device("cuda"), which crashes on CPU/MPS-only
        # installs; use Comfy's device selection like the rest of the file.
        device = model_management.get_torch_device()
        # FIX: out-of-place clamp. Tensor.float() returns the SAME tensor for
        # float32 input, so clamp_() would mutate the caller's batch in place.
        image = image.float().clamp(0, 1)

        # Compute the intermediate resize size (rw, rh) and the final output
        # size (out_w, out_h) for the selected mode.
        if mode == "stretch":
            tw, th = _divisible_box(width, height, d)
            rw, rh = tw, th
            out_w, out_h = tw, th
        elif mode == "keep_ar":
            rw, rh = _fit_keep_ar_divisible(W, H, int(width), int(height), d)
            out_w, out_h = rw, rh
        elif mode == "ar_scale_crop_divisible":
            if width <= 0 or height <= 0:
                raise ValueError(
                    "AR Scale + Divisible Crop requires non-zero width and height."
                )
            rw, rh, out_w, out_h = _scale_then_crop_divisible(
                W, H, int(width), int(height), d
            )
        elif mode == "crop":
            if width <= 0 or height <= 0:
                raise ValueError("Crop requires non-zero width and height.")
            tw, th = _divisible_box(width, height, d)
            rw, rh = _cover_keep_aspect(W, H, tw, th)
            out_w, out_h = tw, th
        elif mode == "pad":
            if width <= 0 or height <= 0:
                raise ValueError("Pad requires non-zero width and height.")
            tw, th = _divisible_box(width, height, d)
            rw, rh = _fit_keep_aspect(W, H, tw, th)
            out_w, out_h = tw, th
        else:
            raise ValueError(f"Unknown resize_mode: {resize_mode}")

        out_imgs: List[torch.Tensor] = []
        out_masks: List[torch.Tensor] = []
        crop_like = mode in ("crop", "ar_scale_crop_divisible")
        pad_like = mode == "pad"
        # Crop/pad modes resize to the intermediate (rw, rh) first, then crop
        # or pad to (out_w, out_h); other modes resize straight to the output.
        resize_to = (rh, rw) if (crop_like or pad_like) else (out_h, out_w)
        pbar = comfy_utils.ProgressBar(B)
        for s, e in _chunk_spans(B, int(max_batch_size)):
            # BHWC -> BCHW for resampling.
            x = image[s:e].movedim(-1, 1).to(device, non_blocking=True)
            y = lanczos_resize(
                x,
                height=resize_to[0],
                width=resize_to[1],
                a=int(sinc_window),
                precision=str(precision),
                clamp=True,
                chunk_size=0,
            )
            ox = oy = 0
            left = right = top = bottom = 0
            if crop_like:
                ox, oy = _crop_offsets(crop_position, rw, rh, out_w, out_h)
                y = y[:, :, oy : oy + out_h, ox : ox + out_w]
            elif pad_like:
                pad_w = max(0, out_w - rw)
                pad_h = max(0, out_h - rh)
                left, right, top, bottom = _pad_sides(crop_position, pad_w, pad_h)
                if d > 1:
                    # Grow right/bottom padding so the padded canvas itself
                    # stays divisible by d.
                    base_w = rw + left + right
                    base_h = rh + top + bottom
                    right += _ceil_mul(base_w, d) - base_w
                    bottom += _ceil_mul(base_h, d) - base_h
                out_w = rw + left + right
                out_h = rh + top + bottom
                color = _parse_pad_color(pad_color, C, y.device, y.dtype).view(
                    1, C, 1, 1
                )
                canvas = color.expand(y.shape[0], -1, out_h, out_w).clone()
                canvas[:, :, top : top + rh, left : left + rw] = y
                y = canvas
            # Back to BHWC on the CPU.
            out_imgs.append(y.to("cpu", non_blocking=False).movedim(1, -1))
            if isinstance(mask, torch.Tensor):
                # Masks follow the same geometry but use nearest-neighbor and
                # are zero-padded (not pad_color) in pad mode.
                m = mask[s:e].unsqueeze(1).to(device, non_blocking=True)
                m_res = _nearest_interp(m, size=resize_to)
                if crop_like:
                    m_res = m_res[:, :, oy : oy + out_h, ox : ox + out_w]
                elif pad_like:
                    base = torch.zeros(
                        (m_res.shape[0], 1, out_h, out_w),
                        device=m_res.device,
                        dtype=m_res.dtype,
                    )
                    base[:, :, top : top + rh, left : left + rw] = m_res
                    m_res = base
                out_masks.append(m_res.squeeze(1).to("cpu", non_blocking=False))
            pbar.update(e - s)
        images_out = torch.cat(out_imgs, dim=0)
        mask_out = (
            torch.cat(out_masks, dim=0)
            if out_masks
            else torch.zeros((B, out_h, out_w), dtype=torch.float32)
        )
        return images_out, out_w, out_h, mask_out
# ComfyUI node registration tables (class name -> node class / display name).
# NOTE(review): UpscaleWithModelAdvanced is defined above but not registered
# here — confirm whether it is exported elsewhere or should be added.
NODE_CLASS_MAPPINGS = {"BatchResizeWithLanczos": BatchResizeWithLanczos}
NODE_DISPLAY_NAME_MAPPINGS = {"BatchResizeWithLanczos": "Batch Resize with Lanczos"}