Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Includes 30 custom nodes committed directly, 7 Civitai-exclusive loras stored via Git LFS, and a setup script that installs all dependencies and downloads HuggingFace-hosted models on vast.ai. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
752 lines
26 KiB
Python
752 lines
26 KiB
Python
# SPDX-License-Identifier: AGPL-3.0-only
|
||
# SPDX-FileCopyrightText: 2025 ArtificialSweetener <artificialsweetenerai@proton.me>
|
||
import math
|
||
|
||
import numpy as np
|
||
import torch
|
||
import torch.nn.functional as F
|
||
from PIL import Image
|
||
|
||
|
||
def _to_lin(x):
|
||
return torch.where(
|
||
x <= 0.04045, x / 12.92, ((x + 0.055) / 1.055).clamp(min=0) ** 2.4
|
||
)
|
||
|
||
|
||
def _to_srgb(x):
|
||
return torch.where(
|
||
x <= 0.0031308, 12.92 * x, 1.055 * x.clamp(min=0) ** (1 / 2.4) - 0.055
|
||
)
|
||
|
||
|
||
def _luma(x):
|
||
return 0.2126 * x[..., 0:1] + 0.7152 * x[..., 1:2] + 0.0722 * x[..., 2:3]
|
||
|
||
|
||
def _sobel_mag(y): # y: NHWC 1ch
|
||
kx = (
|
||
torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32)
|
||
.view(1, 1, 3, 3)
|
||
.to(y.device)
|
||
)
|
||
ky = (
|
||
torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32)
|
||
.view(1, 1, 3, 3)
|
||
.to(y.device)
|
||
)
|
||
t = F.pad(y.permute(0, 3, 1, 2), (1, 1, 1, 1), mode="reflect")
|
||
gx = F.conv2d(t, kx)
|
||
gy = F.conv2d(t, ky)
|
||
return torch.sqrt(gx * gx + gy * gy).permute(0, 2, 3, 1).contiguous()
|
||
|
||
|
||
def _gauss1d(sigma, r):
|
||
if r <= 0:
|
||
return torch.tensor([1.0], dtype=torch.float32)
|
||
xs = torch.arange(-r, r + 1, dtype=torch.float32)
|
||
k = torch.exp(-(xs * xs) / (2 * sigma * sigma))
|
||
return (k / k.sum()).contiguous()
|
||
|
||
|
||
def _blur_nhwc(x, sigma):
|
||
if sigma <= 0:
|
||
return x
|
||
N, H, W, C = x.shape
|
||
max_r = max(0, min(H, W) // 2 - 1)
|
||
r = min(int(math.ceil(3.0 * sigma)), max_r)
|
||
if r <= 0:
|
||
return x
|
||
k = _gauss1d(sigma, r)
|
||
kH = k.view(1, 1, -1, 1).repeat(C, 1, 1, 1)
|
||
kW = k.view(1, 1, 1, -1).repeat(C, 1, 1, 1)
|
||
t = x.permute(0, 3, 1, 2).contiguous()
|
||
t = F.conv2d(F.pad(t, (0, 0, r, r), mode="reflect"), kH, groups=C)
|
||
t = F.conv2d(F.pad(t, (r, r, 0, 0), mode="reflect"), kW, groups=C)
|
||
return t.permute(0, 2, 3, 1).contiguous()
|
||
|
||
|
||
def _avgpool_tiles(x1, tile):
|
||
t = x1.permute(0, 3, 1, 2)
|
||
o = F.avg_pool2d(t, kernel_size=tile, stride=tile)
|
||
return o.permute(0, 2, 3, 1)
|
||
|
||
|
||
def _mad_tiles(x1, tile):
|
||
t = x1.permute(0, 3, 1, 2)
|
||
N, C, H, W = t.shape
|
||
th, tw = H // tile, W // tile
|
||
t = t[:, :, : th * tile, : tw * tile]
|
||
patches = F.unfold(t, kernel_size=tile, stride=tile) # (N, C*tile*tile, th*tw)
|
||
patches = patches.transpose(1, 2).reshape(-1, tile * tile) # (N*th*tw, K)
|
||
med = patches.median(dim=1, keepdim=True).values
|
||
mad = (patches - med).abs().median(dim=1).values.view(N, th, tw, 1)
|
||
return mad
|
||
|
||
|
||
def _upsample_mask(mask_tile, H, W, mode="nearest"):
|
||
t = mask_tile.permute(0, 3, 1, 2)
|
||
t = F.interpolate(
|
||
t,
|
||
size=(H, W),
|
||
mode=("bilinear" if mode == "bilinear" else "nearest"),
|
||
align_corners=False if mode == "bilinear" else None,
|
||
)
|
||
return t.permute(0, 2, 3, 1)
|
||
|
||
|
||
def _dilate(mask01, r):
|
||
if r <= 0:
|
||
return mask01
|
||
t = mask01.permute(0, 3, 1, 2)
|
||
t = F.pad(t, (r, r, r, r), mode="replicate")
|
||
t = F.max_pool2d(t, kernel_size=2 * r + 1, stride=1)
|
||
return t.permute(0, 2, 3, 1)
|
||
|
||
|
||
def _resize_lanczos(img01, H, W): # (1,Hr,Wr,C) float CPU -> (1,H,W,C) float CPU
|
||
arr = (img01[0].cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8)
|
||
pil = Image.fromarray(arr, mode="RGB").resize((W, H), resample=Image.LANCZOS)
|
||
out = np.asarray(pil).astype(np.float32) / 255.0
|
||
return torch.from_numpy(out).unsqueeze(0)
|
||
|
||
|
||
class PixelHold:
|
||
@classmethod
|
||
def INPUT_TYPES(cls):
|
||
return {
|
||
"required": {
|
||
"frames": (
|
||
"IMAGE",
|
||
{"tooltip": "Your clip (frames×H×W×C, values 0–1)."},
|
||
),
|
||
"ref_source": (
|
||
["external", "batch_index"],
|
||
{
|
||
"default": "external",
|
||
"tooltip": "Pick the reference: an external image or a frame from this clip.",
|
||
},
|
||
),
|
||
"ref_index": (
|
||
"INT",
|
||
{
|
||
"default": 0,
|
||
"min": 0,
|
||
"max": 999999,
|
||
"tooltip": "If using a frame from this clip, which frame to use as the reference.",
|
||
},
|
||
),
|
||
"reference": (
|
||
"IMAGE",
|
||
{
|
||
"default": None,
|
||
"tooltip": "Optional external reference (1×H×W×C). If sizes differ, it will be resized to match.",
|
||
},
|
||
),
|
||
"linearize": (
|
||
"BOOLEAN",
|
||
{
|
||
"default": True,
|
||
"tooltip": "Work in linear color for steadier results on flat areas.",
|
||
},
|
||
),
|
||
"auto_luma": (
|
||
"BOOLEAN",
|
||
{
|
||
"default": True,
|
||
"tooltip": "Auto sensitivity for brightness changes (adapts per frame).",
|
||
},
|
||
),
|
||
"auto_k": (
|
||
"FLOAT",
|
||
{
|
||
"default": 2.5,
|
||
"min": 0.5,
|
||
"max": 6.0,
|
||
"step": 0.1,
|
||
"tooltip": "Auto strength. Higher = lock more to the reference (2–3 is typical).",
|
||
},
|
||
),
|
||
"tau_luma": (
|
||
"FLOAT",
|
||
{
|
||
"default": 1.5 / 255.0,
|
||
"min": 0.0,
|
||
"max": 4.0 / 255.0,
|
||
"step": 0.0005,
|
||
"tooltip": "Manual brightness threshold when Auto is OFF. Lower = stricter (more locking).",
|
||
},
|
||
),
|
||
"tau_grad": (
|
||
"FLOAT",
|
||
{
|
||
"default": 0.02,
|
||
"min": 0.0,
|
||
"max": 1.0,
|
||
"step": 0.001,
|
||
"tooltip": "How much edge change to allow. Lower protects edges more.",
|
||
},
|
||
),
|
||
"mode": (
|
||
["tile", "pixel"],
|
||
{
|
||
"default": "tile",
|
||
"tooltip": "Tile: fast & robust. Pixel: finer but noisier.",
|
||
},
|
||
),
|
||
"tile_size": (
|
||
"INT",
|
||
{
|
||
"default": 32,
|
||
"min": 8,
|
||
"max": 256,
|
||
"step": 8,
|
||
"tooltip": "Tile size when using Tile mode.",
|
||
},
|
||
),
|
||
"score_mode": (
|
||
["l1_tile", "mad_tile"],
|
||
{
|
||
"default": "l1_tile",
|
||
"tooltip": "How tiles measure change: mean abs diff (fast) or median abs dev (robust).",
|
||
},
|
||
),
|
||
"edge_band": (
|
||
"BOOLEAN",
|
||
{
|
||
"default": True,
|
||
"tooltip": "Protect a belt around strong edges to avoid wobble/stretch.",
|
||
},
|
||
),
|
||
"band_radius": (
|
||
"INT",
|
||
{
|
||
"default": 4,
|
||
"min": 0,
|
||
"max": 64,
|
||
"tooltip": "Width of the protected belt (pixels).",
|
||
},
|
||
),
|
||
"tau_edge_low": (
|
||
"FLOAT",
|
||
{
|
||
"default": 1.5 / 255.0,
|
||
"min": 0.0,
|
||
"max": 0.25,
|
||
"step": 0.0005,
|
||
"tooltip": "Treat as low-motion below this level (edge belt).",
|
||
},
|
||
),
|
||
"tau_edge_high": (
|
||
"FLOAT",
|
||
{
|
||
"default": 6.0 / 255.0,
|
||
"min": 0.0,
|
||
"max": 0.5,
|
||
"step": 0.0005,
|
||
"tooltip": "Treat as high-motion above this level (edge belt).",
|
||
},
|
||
),
|
||
"apply": (
|
||
["all", "lowfreq"],
|
||
{
|
||
"default": "all",
|
||
"tooltip": "Hold the whole image (All) or only its smooth part (Low-freq).",
|
||
},
|
||
),
|
||
"dilate": (
|
||
"INT",
|
||
{
|
||
"default": 1,
|
||
"min": 0,
|
||
"max": 16,
|
||
"tooltip": "Expand the mask (pixels).",
|
||
},
|
||
),
|
||
"feather_sigma": (
|
||
"FLOAT",
|
||
{
|
||
"default": 2.0,
|
||
"min": 0.0,
|
||
"max": 16.0,
|
||
"step": 0.5,
|
||
"tooltip": "Soften mask edges (pixels).",
|
||
},
|
||
),
|
||
"process_on": (
|
||
["auto", "cpu", "gpu"],
|
||
{
|
||
"default": "auto",
|
||
"tooltip": "Choose CPU/GPU. Auto switches to GPU on very large frames.",
|
||
},
|
||
),
|
||
"gpu_clear_every": (
|
||
"INT",
|
||
{
|
||
"default": 0,
|
||
"min": 0,
|
||
"max": 1000,
|
||
"tooltip": "If >0 and using GPU, free memory every N frames.",
|
||
},
|
||
),
|
||
}
|
||
}
|
||
|
||
RETURN_TYPES = ("IMAGE", "IMAGE")
|
||
RETURN_NAMES = ("images", "mask_preview")
|
||
FUNCTION = "apply_hold"
|
||
CATEGORY = "video utils"
|
||
DESCRIPTION = (
|
||
"Locks parts of each frame to a chosen reference (external image or a frame from the clip) whenever changes are small—"
|
||
"useful for stabilizing flat areas or backgrounds while leaving motion to pass through."
|
||
)
|
||
|
||
@torch.no_grad()
|
||
def apply_hold(
|
||
self,
|
||
frames,
|
||
ref_source="external",
|
||
ref_index=0,
|
||
reference=None,
|
||
linearize=True,
|
||
auto_luma=True,
|
||
auto_k=2.5,
|
||
tau_luma=1.5 / 255.0,
|
||
tau_grad=0.02,
|
||
mode="tile",
|
||
tile_size=32,
|
||
score_mode="l1_tile",
|
||
edge_band=True,
|
||
band_radius=4,
|
||
tau_edge_low=1.5 / 255.0,
|
||
tau_edge_high=6.0 / 255.0,
|
||
apply="all",
|
||
dilate=1,
|
||
feather_sigma=2.0,
|
||
process_on="auto",
|
||
gpu_clear_every=0,
|
||
):
|
||
x = frames if isinstance(frames, torch.Tensor) else torch.tensor(frames)
|
||
B, H, W, C = x.shape
|
||
|
||
if str(ref_source) == "external" and reference is not None:
|
||
ref = (
|
||
reference
|
||
if isinstance(reference, torch.Tensor)
|
||
else torch.tensor(reference)
|
||
)
|
||
if ref.shape[1] != H or ref.shape[2] != W:
|
||
ref = _resize_lanczos(ref[:1].to("cpu"), H, W)
|
||
ref = ref[:1].repeat(B, 1, 1, 1)
|
||
else:
|
||
idx = max(0, min(int(ref_index), B - 1))
|
||
ref = x[idx : idx + 1].repeat(B, 1, 1, 1)
|
||
|
||
x_lin = _to_lin(x) if linearize else x
|
||
r_lin = _to_lin(ref) if linearize else ref
|
||
want_gpu = (process_on == "gpu") or (
|
||
process_on == "auto" and torch.cuda.is_available() and (H * W >= 6_000_000)
|
||
)
|
||
dev = torch.device("cuda") if want_gpu else torch.device("cpu")
|
||
|
||
r_lin = r_lin.to(dev)
|
||
y_r = _luma(r_lin)
|
||
g_r = _sobel_mag(y_r)
|
||
|
||
if apply == "lowfreq":
|
||
LF_r = _blur_nhwc(r_lin.to("cpu"), 13.0)
|
||
|
||
out_frames, mask_frames = [], []
|
||
clear_ctr = 0
|
||
|
||
for i in range(B):
|
||
f = x_lin[i : i + 1].to(dev)
|
||
y_f = _luma(f)
|
||
g_f = _sobel_mag(y_f)
|
||
|
||
dY = (y_f - y_r[i : i + 1]).abs()
|
||
dG = (g_f - g_r[i : i + 1]).abs()
|
||
|
||
if auto_luma:
|
||
med = torch.median(dY.view(-1))
|
||
sigma = 1.4826 * med.item()
|
||
tau_luma_eff = max(0.0, min(4.0 / 255.0, float(auto_k) * float(sigma)))
|
||
else:
|
||
tau_luma_eff = float(tau_luma)
|
||
|
||
if mode == "tile":
|
||
sY = (
|
||
_mad_tiles(dY, tile_size)
|
||
if score_mode == "mad_tile"
|
||
else _avgpool_tiles(dY, tile_size)
|
||
)
|
||
sG = (
|
||
_mad_tiles(dG, tile_size)
|
||
if score_mode == "mad_tile"
|
||
else _avgpool_tiles(dG, tile_size)
|
||
)
|
||
mask = (sY < tau_luma_eff).to(torch.float32) * (
|
||
sG < float(tau_grad)
|
||
).to(torch.float32)
|
||
mask = _upsample_mask(mask, H, W, mode="nearest")
|
||
else:
|
||
mask = (dY < tau_luma_eff).to(torch.float32) * (
|
||
dG < float(tau_grad)
|
||
).to(torch.float32)
|
||
|
||
mask = _dilate(mask, int(dilate))
|
||
if feather_sigma > 0:
|
||
mask = (
|
||
_blur_nhwc(mask.to("cpu"), float(feather_sigma))
|
||
.to(dev)
|
||
.clamp_(0.0, 1.0)
|
||
)
|
||
|
||
if edge_band:
|
||
D = (y_f - y_r[i : i + 1]).abs()
|
||
high = (D > float(tau_edge_high)).to(torch.float32)
|
||
low = (D < float(tau_edge_low)).to(torch.float32)
|
||
band = _dilate(high, int(band_radius)) * low
|
||
if feather_sigma > 0:
|
||
band = (
|
||
_blur_nhwc(band.to("cpu"), float(feather_sigma))
|
||
.to(dev)
|
||
.clamp_(0.0, 1.0)
|
||
)
|
||
mask = (mask * (1.0 - band)).clamp_(0.0, 1.0)
|
||
|
||
if apply == "all":
|
||
composed_lin = mask * r_lin[i : i + 1] + (1.0 - mask) * f
|
||
composed_lin = composed_lin.to("cpu")
|
||
else:
|
||
f_cpu = f.to("cpu")
|
||
LF_f = _blur_nhwc(f_cpu, 13.0)
|
||
HF_f = f_cpu - LF_f
|
||
LF_mix = (
|
||
mask.to("cpu") * LF_r[i : i + 1] + (1.0 - mask.to("cpu")) * LF_f
|
||
)
|
||
composed_lin = (HF_f + LF_mix).clamp(0.0, 1.0)
|
||
|
||
out = _to_srgb(composed_lin) if linearize else composed_lin
|
||
mvis = mask.to("cpu").repeat(1, 1, 1, 3).clamp_(0.0, 1.0)
|
||
|
||
out_frames.append(out.clamp(0, 1))
|
||
mask_frames.append(mvis)
|
||
|
||
if dev.type == "cuda" and int(gpu_clear_every) > 0:
|
||
clear_ctr += 1
|
||
if clear_ctr >= int(gpu_clear_every):
|
||
torch.cuda.empty_cache()
|
||
clear_ctr = 0
|
||
|
||
y_out = torch.cat(out_frames, dim=0)
|
||
mask_preview = torch.cat(mask_frames, dim=0)
|
||
return (y_out, mask_preview)
|
||
|
||
|
||
class BlackSpotCleaner:
|
||
@classmethod
|
||
def INPUT_TYPES(cls):
|
||
return {
|
||
"required": {
|
||
"frames": (
|
||
"IMAGE",
|
||
{"tooltip": "Your clip (frames×H×W×C, values 0–1)."},
|
||
),
|
||
"linearize": (
|
||
"BOOLEAN",
|
||
{
|
||
"default": True,
|
||
"tooltip": "Work in linear color for cleaner detection.",
|
||
},
|
||
),
|
||
"detector": (
|
||
["blackhat", "local_floor"],
|
||
{
|
||
"default": "blackhat",
|
||
"tooltip": "blackhat: tiny dark specks • local_floor: larger soft blotches.",
|
||
},
|
||
),
|
||
"radius": (
|
||
"INT",
|
||
{
|
||
"default": 5,
|
||
"min": 1,
|
||
"max": 31,
|
||
"tooltip": "Approximate spot size (pixels). Increase for bigger blotches.",
|
||
},
|
||
),
|
||
"tau_blackhat": (
|
||
"FLOAT",
|
||
{
|
||
"default": 4.0 / 255.0,
|
||
"min": 0.0,
|
||
"max": 0.5,
|
||
"step": 0.0005,
|
||
"tooltip": "Base sensitivity (0–1). Lower = fix more, higher = fix less.",
|
||
},
|
||
),
|
||
"auto_blackhat": (
|
||
"BOOLEAN",
|
||
{
|
||
"default": True,
|
||
"tooltip": "Auto-tune sensitivity from image noise (robust to lighting/texture).",
|
||
},
|
||
),
|
||
"bh_k": (
|
||
"FLOAT",
|
||
{
|
||
"default": 3.0,
|
||
"min": 0.5,
|
||
"max": 8.0,
|
||
"step": 0.1,
|
||
"tooltip": "Auto strength multiplier. Higher = more aggressive fixes.",
|
||
},
|
||
),
|
||
"temporal_gate": (
|
||
"BOOLEAN",
|
||
{
|
||
"default": True,
|
||
"tooltip": "Only fix if darker than neighboring frames (reduces false positives).",
|
||
},
|
||
),
|
||
"temporal_radius": (
|
||
"INT",
|
||
{
|
||
"default": 1,
|
||
"min": 1,
|
||
"max": 3,
|
||
"tooltip": "How many neighbor frames to compare on each side.",
|
||
},
|
||
),
|
||
"grad_guard": (
|
||
"BOOLEAN",
|
||
{
|
||
"default": True,
|
||
"tooltip": "Skip fixes on strong edges/text to avoid halos.",
|
||
},
|
||
),
|
||
"tau_grad_edge": (
|
||
"FLOAT",
|
||
{
|
||
"default": 0.07,
|
||
"min": 0.0,
|
||
"max": 1.0,
|
||
"step": 0.001,
|
||
"tooltip": "Edge strength where fixes are skipped (higher = skip more).",
|
||
},
|
||
),
|
||
"dilate": (
|
||
"INT",
|
||
{
|
||
"default": 1,
|
||
"min": 0,
|
||
"max": 8,
|
||
"tooltip": "Expand the fix mask (pixels).",
|
||
},
|
||
),
|
||
"feather_sigma": (
|
||
"FLOAT",
|
||
{
|
||
"default": 1.5,
|
||
"min": 0.0,
|
||
"max": 16.0,
|
||
"step": 0.5,
|
||
"tooltip": "Soften mask edges (pixels).",
|
||
},
|
||
),
|
||
"process_on": (
|
||
["auto", "cpu", "gpu"],
|
||
{
|
||
"default": "auto",
|
||
"tooltip": "Choose CPU/GPU. Auto switches to GPU on very large frames.",
|
||
},
|
||
),
|
||
"gpu_clear_every": (
|
||
"INT",
|
||
{
|
||
"default": 0,
|
||
"min": 0,
|
||
"max": 1000,
|
||
"tooltip": "If >0 and using GPU, free memory every N frames.",
|
||
},
|
||
),
|
||
},
|
||
"optional": {
|
||
"reference": (
|
||
"IMAGE",
|
||
{
|
||
"tooltip": "Optional external reference floor (1×H×W×C). Resized if needed."
|
||
},
|
||
),
|
||
"ref_source": (
|
||
["none", "external", "batch_index"],
|
||
{
|
||
"default": "none",
|
||
"tooltip": "Choose a floor: none, an external image, or a frame index from this clip.",
|
||
},
|
||
),
|
||
"ref_index": (
|
||
"INT",
|
||
{
|
||
"default": 0,
|
||
"min": 0,
|
||
"max": 999999,
|
||
"tooltip": "If using a frame index as the floor, which one to use.",
|
||
},
|
||
),
|
||
"tau_down": (
|
||
"FLOAT",
|
||
{
|
||
"default": 2.0 / 255.0,
|
||
"min": 0.0,
|
||
"max": 0.5,
|
||
"step": 0.0005,
|
||
"tooltip": "Only lift where the frame is at least this much darker than the floor.",
|
||
},
|
||
),
|
||
},
|
||
}
|
||
|
||
RETURN_TYPES = ("IMAGE", "IMAGE")
|
||
RETURN_NAMES = ("images", "mask_preview")
|
||
FUNCTION = "clean"
|
||
CATEGORY = "video utils"
|
||
DESCRIPTION = "Removes tiny dark specks and soft blotches by gently lifting only the dark outliers—keeps edges and details safe with guards."
|
||
|
||
@torch.no_grad()
|
||
def clean(
|
||
self,
|
||
frames,
|
||
linearize=True,
|
||
detector="blackhat",
|
||
radius=5,
|
||
tau_blackhat=4.0 / 255.0,
|
||
auto_blackhat=True,
|
||
bh_k=3.0,
|
||
temporal_gate=True,
|
||
temporal_radius=1,
|
||
grad_guard=True,
|
||
tau_grad_edge=0.07,
|
||
dilate=1,
|
||
feather_sigma=1.5,
|
||
process_on="auto",
|
||
gpu_clear_every=0,
|
||
reference=None,
|
||
ref_source="none",
|
||
ref_index=0,
|
||
tau_down=2.0 / 255.0,
|
||
):
|
||
x = frames if isinstance(frames, torch.Tensor) else torch.tensor(frames)
|
||
B, H, W, C = x.shape
|
||
|
||
ref = None
|
||
if str(ref_source) == "external" and reference is not None:
|
||
ref = (
|
||
reference
|
||
if isinstance(reference, torch.Tensor)
|
||
else torch.tensor(reference)
|
||
)
|
||
if ref.shape[1] != H or ref.shape[2] != W:
|
||
ref = _resize_lanczos(ref[:1].to("cpu"), H, W)
|
||
ref = ref[:1].repeat(B, 1, 1, 1)
|
||
elif str(ref_source) == "batch_index":
|
||
idx = max(0, min(int(ref_index), B - 1))
|
||
ref = x[idx : idx + 1].repeat(B, 1, 1, 1)
|
||
|
||
xx = _to_lin(x) if linearize else x
|
||
rr = _to_lin(ref) if (ref is not None and linearize) else ref
|
||
|
||
want_gpu = (process_on == "gpu") or (
|
||
process_on == "auto" and torch.cuda.is_available() and (H * W >= 6_000_000)
|
||
)
|
||
dev = torch.device("cuda") if want_gpu else torch.device("cpu")
|
||
|
||
y = _luma(xx).to(dev)
|
||
g = _sobel_mag(y)
|
||
if rr is not None:
|
||
y_ref = _luma(rr).to(device=y.device, dtype=y.dtype) # match y
|
||
assert (
|
||
y_ref.shape[0] == y.shape[0]
|
||
), f"y_ref B={y_ref.shape[0]} vs y B={y.shape[0]}"
|
||
assert (
|
||
y_ref.shape[1:3] == y.shape[1:3]
|
||
), f"spatial mismatch {y_ref.shape[1:3]} vs {y.shape[1:3]}"
|
||
floor = (y_ref - y) > float(tau_down)
|
||
|
||
r = int(radius)
|
||
if detector == "blackhat":
|
||
k = max(1, 2 * r + 1)
|
||
k = min(k, 2 * min(H, W) - 1)
|
||
t = y.permute(0, 3, 1, 2)
|
||
d = F.max_pool2d(
|
||
F.pad(t, (k // 2, k // 2, k // 2, k // 2), mode="replicate"),
|
||
kernel_size=k,
|
||
stride=1,
|
||
)
|
||
e = -F.max_pool2d(
|
||
F.pad(-d, (k // 2, k // 2, k // 2, k // 2), mode="replicate"),
|
||
kernel_size=k,
|
||
stride=1,
|
||
)
|
||
y_close = e.permute(0, 2, 3, 1)
|
||
score = (y_close - y).clamp_min(0)
|
||
else:
|
||
sigma = max(0.5, r / 2.0)
|
||
Bsm = _blur_nhwc(y.to("cpu"), sigma).to(y.device)
|
||
score = (Bsm - y).clamp_min(0)
|
||
tau = float(tau_blackhat)
|
||
if bool(auto_blackhat):
|
||
region = (g < float(tau_grad_edge)).to(torch.float32)
|
||
if region.sum() < 1:
|
||
region = torch.ones_like(region)
|
||
sel = score[region > 0.5].view(-1)
|
||
if sel.numel() > 0:
|
||
med = torch.median(sel)
|
||
sigma_bh = 1.4826 * torch.median((sel - med).abs())
|
||
tau = max(tau, float(bh_k) * float(sigma_bh))
|
||
|
||
mask = (score > tau).to(torch.float32)
|
||
|
||
if rr is not None:
|
||
floor = (y_ref - y) > float(tau_down)
|
||
mask = torch.maximum(mask, floor.to(torch.float32))
|
||
|
||
if temporal_gate and B > 1:
|
||
idxs = []
|
||
for dt in range(1, int(temporal_radius) + 1):
|
||
if dt < B:
|
||
idxs += [
|
||
torch.clamp(torch.arange(B) - dt, 0, B - 1),
|
||
torch.clamp(torch.arange(B) + dt, 0, B - 1),
|
||
]
|
||
neigh = torch.stack([y[i] for i in torch.stack(idxs, dim=0)], dim=0)
|
||
y_med = torch.median(neigh, dim=0).values
|
||
mask = mask * ((y_med - y) > tau).to(torch.float32)
|
||
|
||
if grad_guard:
|
||
guard = (g < float(tau_grad_edge)).to(torch.float32)
|
||
mask = mask * guard
|
||
|
||
mask = _dilate(mask, int(dilate))
|
||
if feather_sigma > 0:
|
||
mask = (
|
||
_blur_nhwc(mask.to("cpu"), float(feather_sigma))
|
||
.to(dev)
|
||
.clamp_(0.0, 1.0)
|
||
)
|
||
|
||
delta = score * mask
|
||
delta3 = delta.repeat(1, 1, 1, 3)
|
||
out_lin = (xx.to(dev) + delta3).clamp(0.0, 1.0)
|
||
if dev.type == "cuda":
|
||
out_lin = out_lin.to("cpu")
|
||
|
||
out = _to_srgb(out_lin) if linearize else out_lin
|
||
mask_preview = mask.to("cpu").repeat(1, 1, 1, 3).clamp_(0.0, 1.0)
|
||
|
||
if dev.type == "cuda" and int(gpu_clear_every) > 0:
|
||
torch.cuda.empty_cache()
|
||
|
||
return (out.clamp(0, 1), mask_preview)
|