Files
ComfyUI/custom_nodes/whiterabbit/interpolation.py
jaidaken f09734b0ee
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Add custom nodes, Civitai loras (LFS), and vast.ai setup script
Includes 30 custom nodes committed directly, 7 Civitai-exclusive
loras stored via Git LFS, and a setup script that installs all
dependencies and downloads HuggingFace-hosted models on vast.ai.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-09 00:56:42 +00:00

1169 lines
41 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# SPDX-License-Identifier: AGPL-3.0-only
# SPDX-FileCopyrightText: 2025 ArtificialSweetener <artificialsweetenerai@proton.me>
#
# Portions adapted from "ComfyUI-Frame-Interpolation" (MIT License)
# Copyright (c) 2023–2025 Fannovel16 and contributors
# See LICENSES/MIT-ComfyUI-Frame-Interpolation.txt for the full text.
# Imports for: RIFE_FPS_Resample, RIFE_SeamTimingAnalyzer_Advanced, RIFE_VFI_Advanced
import sys
from pathlib import Path
# Point Python at the sibling base pack so 'rife' and 'vfi_utils' resolve.
BASE = Path(__file__).resolve().parent.parent / "comfyui-frame-interpolation"
VFI_MODELS = BASE / "vfi_models"
# Prefer an externally installed comfyui-frame-interpolation; fall back to the
# vendored copy under ./vendor inside this repo.
p_models = str(
    VFI_MODELS
)  # external: .../custom_nodes/comfyui-frame-interpolation/vfi_models
p_base = str(BASE)  # external: .../custom_nodes/comfyui-frame-interpolation
p_vendor = str(
    Path(__file__).resolve().parent / "vendor"
)  # repo-local vendor/ fallback
if VFI_MODELS.is_dir() and p_models not in sys.path:
    sys.path.insert(0, p_models)  # exposes external 'rife'
if BASE.is_dir() and p_base not in sys.path:
    sys.path.insert(0, p_base)  # exposes external 'vfi_utils'
if p_vendor not in sys.path:
    # Appended (not inserted) so the external pack wins when both are present.
    sys.path.append(p_vendor)  # fallback: vendored 'rife' and 'vfi_utils'
import math
import re
import typing
from fractions import Fraction
# === Cached RIFE model loader (shared by all nodes) ===
from functools import lru_cache
import torch
import torch.nn.functional as F
# --- Comfy + stdlib used by the three nodes ---
from comfy.model_management import get_torch_device
from comfy.utils import ProgressBar as _ComfyProgressBar
from rife import CKPT_NAME_VER_DICT, MODEL_TYPE
# --- imports from the base pack ---
from vfi_utils import (
InterpolationStateList,
generic_frame_loop,
load_file_from_github_release,
postprocess_frames,
preprocess_frames,
)
@lru_cache(maxsize=4)
def _get_rife_model(ckpt_name: str):
    """Load, cache, and return a RIFE ``IFNet`` ready for inference.

    The checkpoint is fetched (downloading if needed) via the base pack's
    GitHub-release helper, loaded in eval mode onto the Comfy torch device
    with channels-last memory format, and memoized per checkpoint name
    (``lru_cache(maxsize=4)``) so repeated node executions reuse the model.

    Args:
        ckpt_name: Checkpoint file name (e.g. ``"rife47.pth"``); must be a
            key of ``CKPT_NAME_VER_DICT``.

    Returns:
        The loaded ``IFNet`` module on the active torch device.
    """
    # Lazy import: rife_arch is only needed once a model is actually built.
    from rife.rife_arch import IFNet

    model_path = load_file_from_github_release(MODEL_TYPE, ckpt_name)
    arch_ver = CKPT_NAME_VER_DICT[ckpt_name]
    model = IFNet(arch_ver=arch_ver)
    # NOTE(review): torch.load unpickles arbitrary data. Checkpoints come from
    # the base pack's pinned release, but consider weights_only=True
    # (torch >= 1.13) for defense in depth.
    sd = torch.load(model_path, map_location="cpu")
    model.load_state_dict(sd, strict=True)
    model.eval()
    return model.to(get_torch_device()).to(memory_format=torch.channels_last)
# ---------- shared helpers (for all nodes) ----------
def _scale_list(scale_factor: float):
s = float(scale_factor)
return [8.0 / s, 4.0 / s, 2.0 / s, 1.0 / s]
def _prep_frames(frames: torch.Tensor) -> torch.Tensor:
    """Run the base pack's preprocessing; pin host memory when CUDA is available.

    Pinned memory enables asynchronous (non_blocking) host-to-device copies
    performed later by the nodes.
    """
    prepared = preprocess_frames(frames)
    if torch.cuda.is_available() and prepared.device.type == "cpu":
        prepared = prepared.pin_memory()
    return prepared
def _progress(total_ticks: int):
    """Create a Comfy progress bar sized to at least one tick."""
    ticks = max(1, int(total_ticks))
    return _ComfyProgressBar(ticks)
def _count_synth(frames: torch.Tensor, multiplier: int) -> int:
N = int(frames.shape[0])
m = max(0, int(multiplier))
return max(0, N - 1) * m
def _make_rife_callback():
def _cb(frame_0, frame_1, timestep, model, scale_list, ensemble, t_mapper):
if torch.is_tensor(timestep):
t_scalar = float(timestep.reshape(-1)[0].item())
else:
t_scalar = float(timestep)
t = t_mapper(t_scalar) if t_mapper is not None else t_scalar
# fast_mode is a no-op for 4.7/4.9; pass False
return model(frame_0, frame_1, t, scale_list, False, ensemble)
return _cb
def _with_progress(cb, pbar):
if pbar is None:
return cb
def _wrapped(*args, **kwargs):
y = cb(*args, **kwargs)
pbar.update(1)
return y
return _wrapped
# ----------------------------------------------------
class RIFE_VFI_Opt:
    """ComfyUI node: interpolate a clip by an integer multiple with RIFE 4.7/4.9.

    For every consecutive frame pair, (multiplier - 1) evenly spaced
    in-between frames are synthesized via the base pack's
    ``generic_frame_loop``; a multiplier of 1 (or less) is a passthrough.
    """

    DESCRIPTION = (
        "Interpolate a clip by a chosen multiple using RIFE 4.7/4.9 — inserts evenly spaced in-between frames "
        "between every pair (e.g., ×2 adds 1 frame per pair)."
    )

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "ckpt_name": (
                    ["rife47.pth", "rife49.pth"],
                    {
                        "default": "rife47.pth",
                        "tooltip": "Choose the RIFE 4.7 or 4.9 model file from the base pack.",
                    },
                ),
                "frames": (
                    "IMAGE",
                    {"tooltip": "Your input clip: one image per frame."},
                ),
                "multiplier": (
                    "INT",
                    {
                        "default": 2,
                        "min": 1,
                        "tooltip": "Adds extra frames to smooth motion: 2 adds 1 new frame per pair; 4 adds 3.",
                    },
                ),
                "scale_factor": (
                    [0.25, 0.5, 1.0, 2.0, 4.0],
                    {
                        "default": 1.0,
                        "tooltip": "Quality vs speed. 1.0 recommended. Lower = faster/softer; higher = sharper/slower.",
                    },
                ),
                "ensemble": (
                    "BOOLEAN",
                    {
                        "default": True,
                        "tooltip": "Blend forward & backward predictions to reduce artifacts (slower).",
                    },
                ),
                "clear_cache_after_n_frames": (
                    "INT",
                    {
                        "default": 10,
                        "min": 0,
                        "max": 1000,
                        "tooltip": "Free up GPU memory every N generated frames (advanced). Set 0 to never.",
                    },
                ),
            },
            "optional": {
                "optional_interpolation_states": (
                    "INTERPOLATION_STATES",
                    {
                        # Fixed mojibake: "Dont" had lost its apostrophe in transit.
                        "tooltip": "Don't create in-between frames for selected frame pairs (e.g., scene cuts). Timing stays the same."
                    },
                ),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "vfi"
    CATEGORY = "video utils"

    def vfi(
        self,
        ckpt_name: typing.AnyStr,
        frames: torch.Tensor,
        multiplier: int = 2,
        scale_factor: float = 1.0,
        ensemble: bool = True,
        clear_cache_after_n_frames: int = 10,
        optional_interpolation_states: InterpolationStateList = None,
        **kwargs,
    ):
        """Synthesize (multiplier - 1) in-betweens per pair and return the clip.

        Returns a 1-tuple with the interpolated IMAGE batch, as Comfy expects.
        """
        model = _get_rife_model(ckpt_name)
        frames = _prep_frames(frames)
        m_effective = int(multiplier)
        if m_effective <= 1:
            # Nothing to synthesize; hand the (pre/post-processed) input back.
            return (postprocess_frames(frames),)
        scale_list = _scale_list(scale_factor)
        total_synth = _count_synth(frames, m_effective)
        pbar = _progress(total_synth)
        cb = _make_rife_callback()
        cbp = _with_progress(cb, pbar)
        args = [model, scale_list, ensemble, None]  # t_mapper=None → identity
        with torch.inference_mode():
            out = postprocess_frames(
                generic_frame_loop(
                    type(self).__name__,
                    frames,
                    clear_cache_after_n_frames,
                    m_effective,
                    cbp,
                    *args,
                    interpolation_states=optional_interpolation_states,
                    dtype=torch.float32,
                )
            )
        return (out,)
class RIFE_VFI_Advanced:
    """
    Advanced RIFE node exposing timestep controls (by multiple, custom t-schedules).

    Identical to the basic node except that the in-between timesteps can be
    eased (gamma curves), clamped (bounded_linear), or supplied explicitly
    (custom_list via a CSV string).
    """

    DESCRIPTION = (
        "Custom timing for RIFE 4.7/4.9 — still “interpolate by multiple,” but you control where the in-betweens land "
        "(ease in/out, clamps, or your own t-list)."
    )

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "ckpt_name": (
                    ["rife47.pth", "rife49.pth"],
                    {
                        "default": "rife47.pth",
                        "tooltip": "Choose the RIFE 4.7 or 4.9 model file from the base pack.",
                    },
                ),
                "frames": (
                    "IMAGE",
                    {"tooltip": "Your input clip: one image per frame."},
                ),
                "multiplier": (
                    "INT",
                    {
                        "default": 2,
                        "min": 0,
                        "tooltip": "How many new frames to create between each pair. 0 = passthrough (no new frames).",
                    },
                ),
                "t_mode": (
                    [
                        "linear",
                        "gamma_in",
                        "gamma_out",
                        "gamma_in_out",
                        "bounded_linear",
                        "custom_list",
                    ],
                    {
                        "default": "linear",
                        "tooltip": "How to spread the new frames over time: straight line, ease in/out, limit the range, or provide your own list.",
                    },
                ),
                "t_gamma": (
                    "FLOAT",
                    {
                        "default": 1.0,
                        "min": 0.05,
                        "max": 10.0,
                        "step": 0.05,
                        "tooltip": "Easing strength for gamma modes. Higher = more easing.",
                    },
                ),
                "t_min": (
                    "FLOAT",
                    {
                        "default": 0.0,
                        "min": 0.0,
                        "max": 1.0,
                        "step": 0.001,
                        "tooltip": "Earliest allowed position between the two frames (0 = exactly the first frame). Use with bounded_linear.",
                    },
                ),
                "t_max": (
                    "FLOAT",
                    {
                        "default": 1.0,
                        "min": 0.0,
                        "max": 1.0,
                        "step": 0.001,
                        "tooltip": "Latest allowed position between the two frames (1 = exactly the next frame). Use with bounded_linear.",
                    },
                ),
                "scale_factor": (
                    [0.25, 0.5, 1.0, 2.0, 4.0],
                    {
                        "default": 1.0,
                        "tooltip": "Quality vs speed. 1.0 recommended. Lower = faster/softer; higher = sharper/slower.",
                    },
                ),
                "ensemble": (
                    "BOOLEAN",
                    {
                        "default": True,
                        "tooltip": "Blend forward & backward predictions to reduce artifacts (slower).",
                    },
                ),
                "clear_cache_after_n_frames": (
                    "INT",
                    {
                        "default": 10,
                        "min": 0,
                        "max": 1000,
                        "tooltip": "Free up GPU memory every N generated frames (advanced). Set 0 to never.",
                    },
                ),
            },
            "optional": {
                "custom_t_list_csv": (
                    "STRING",
                    {
                        "default": "",
                        # Fixed mojibake: "(01)" had lost its dash in transit.
                        "tooltip": "Exact positions between the two frames (0-1), comma-separated, e.g. 0.18,0.41,0.66. Overrides the schedule.",
                    },
                ),
                "optional_interpolation_states": (
                    "INTERPOLATION_STATES",
                    {
                        # Fixed mojibake: "Dont" had lost its apostrophe in transit.
                        "tooltip": "Don't create in-between frames for selected frame pairs (e.g., scene cuts). Timing stays the same."
                    },
                ),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "vfi_advanced"
    CATEGORY = "video utils"

    def _map_t(
        self, t_scalar, t_mode, t_gamma, t_min, t_max, custom_sorted, m_effective
    ):
        """Map a raw timestep in [0, 1] to the scheduled timestep.

        custom_list picks directly from ``custom_sorted``; gamma modes ease the
        curve; every non-custom mode finally remaps affinely into
        [t_min, t_max] and clamps to [0, 1].
        """
        t = max(0.0, min(1.0, float(t_scalar)))
        if t_mode == "custom_list" and custom_sorted:
            # Recover the slot index from the raw timestep.
            # Assumes raw t values arrive as k/(m_effective+1), k=1..m —
            # TODO(review): confirm against generic_frame_loop in the base pack.
            idx = int(round(t * (m_effective + 1))) - 1
            idx = max(0, min(len(custom_sorted) - 1, idx))
            return float(custom_sorted[idx])
        if t_mode == "gamma_in":
            t = t ** max(1e-6, t_gamma)
        elif t_mode == "gamma_out":
            t = 1.0 - (1.0 - t) ** max(1e-6, t_gamma)
        elif t_mode == "gamma_in_out":
            g = max(1e-6, t_gamma)
            if t < 0.5:
                t = 0.5 * (2.0 * t) ** g
            else:
                t = 1.0 - 0.5 * (2.0 * (1.0 - t)) ** g
        # bounded_linear (and the tail of every mode): affine remap + clamp.
        return max(0.0, min(1.0, t_min + (t_max - t_min) * t))

    def vfi_advanced(
        self,
        ckpt_name,
        frames,
        clear_cache_after_n_frames=10,
        multiplier=2,
        ensemble=True,
        scale_factor=1.0,
        t_mode="linear",
        t_gamma=1.0,
        t_min=0.0,
        t_max=1.0,
        custom_t_list_csv="",
        optional_interpolation_states=None,
        **kwargs,
    ):
        """Interpolate with a custom t-schedule; returns a 1-tuple IMAGE batch.

        Note: the original fetched the model's device/dtype here but never used
        them — those dead locals have been removed.
        """
        model = _get_rife_model(ckpt_name)
        frames = _prep_frames(frames)
        custom_sorted = None
        if custom_t_list_csv and t_mode == "custom_list":
            # Accept commas and/or whitespace as separators; clamp each t to [0, 1].
            toks = [x for x in re.split(r"[,\s]+", custom_t_list_csv.strip()) if x]
            custom_sorted = sorted([max(0.0, min(1.0, float(x))) for x in toks])
        m_effective = int(multiplier)
        if m_effective <= 0:
            return (postprocess_frames(frames),)

        def _adv_t_mapper(t_scalar: float) -> float:
            return self._map_t(
                t_scalar, t_mode, t_gamma, t_min, t_max, custom_sorted, m_effective
            )

        scale_list = _scale_list(scale_factor)
        total_synth = _count_synth(frames, m_effective)
        pbar = _progress(total_synth)
        cb = _make_rife_callback()
        cbp = _with_progress(cb, pbar)
        args = [model, scale_list, ensemble, _adv_t_mapper]
        with torch.inference_mode():
            out = postprocess_frames(
                generic_frame_loop(
                    type(self).__name__,
                    frames,
                    clear_cache_after_n_frames,
                    m_effective,
                    cbp,
                    *args,
                    interpolation_states=optional_interpolation_states,
                    dtype=torch.float32,
                )
            )
        return (out,)
class RIFE_FPS_Resample:
    """ComfyUI node: convert a clip from one FPS to another using RIFE 4.7/4.9.

    Output ticks are mapped onto input time with exact rational arithmetic
    (``fractions.Fraction``): ticks that land exactly on an input frame copy
    it through; all others are synthesized with RIFE. Optional stabilizers
    (linear-light processing, low-frequency guardrail, exposure matching,
    edge-band locking) reduce flicker and protect edges.
    """

    DESCRIPTION = (
        "Convert a clip from one FPS to another using RIFE 4.7/4.9. Non-integer changes synthesize in-betweens; "
        "exact integer downscales just decimate. Includes optional stabilizers to reduce flicker and protect edges."
    )

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "ckpt_name": (
                    ["rife47.pth", "rife49.pth"],
                    {
                        "default": "rife47.pth",
                        "tooltip": "Choose the RIFE 4.7 or 4.9 model file from the base pack.",
                    },
                ),
                "frames": (
                    "IMAGE",
                    {"tooltip": "Your input clip: one image per frame."},
                ),
                "fps_in": (
                    "FLOAT",
                    {
                        "default": 24.0,
                        "min": 1e-6,
                        "max": 1000.0,
                        "step": 0.01,
                        "tooltip": "Current frame rate of your clip (frames per second).",
                    },
                ),
                "fps_out": (
                    "FLOAT",
                    {
                        "default": 60.0,
                        "min": 1e-6,
                        "max": 2000.0,
                        "step": 0.01,
                        "tooltip": "Target frame rate you want (frames per second).",
                    },
                ),
                "scale_factor": (
                    [0.25, 0.5, 1.0, 2.0, 4.0],
                    {
                        "default": 1.0,
                        "tooltip": "Quality vs speed. 1.0 recommended. Lower = faster/softer; higher = sharper/slower.",
                    },
                ),
                "ensemble": (
                    "BOOLEAN",
                    {
                        "default": True,
                        "tooltip": "Blend forward & backward predictions to reduce artifacts (slower).",
                    },
                ),
                "linearize": (
                    "BOOLEAN",
                    {
                        "default": False,
                        "tooltip": "Work in linear light for more accurate brightness and gradients (slower).",
                    },
                ),
                "lf_guardrail": (
                    "BOOLEAN",
                    {
                        "default": False,
                        "tooltip": "Keep overall brightness and gradients close to the originals to reduce flicker.",
                    },
                ),
                "lf_sigma": (
                    "FLOAT",
                    {
                        "default": 13.0,
                        "min": 0.0,
                        "max": 64.0,
                        "step": 0.5,
                        "tooltip": "How strong the low-frequency smoothing is. Higher = smoother changes.",
                    },
                ),
                "source_pair_match": (
                    "BOOLEAN",
                    {
                        "default": False,
                        "tooltip": "Match exposure and contrast to the source pair to reduce flicker.",
                    },
                ),
                "match_a_cap": (
                    "FLOAT",
                    {
                        "default": 0.02,
                        "min": 0.0,
                        "max": 0.2,
                        "step": 0.001,
                        "tooltip": "Maximum change allowed for exposure scale.",
                    },
                ),
                "match_b_cap": (
                    "FLOAT",
                    {
                        "default": 2.0 / 255.0,
                        "min": 0.0,
                        "max": 0.1,
                        "step": 0.0005,
                        "tooltip": "Maximum change allowed for brightness offset.",
                    },
                ),
                "edge_band_lock": (
                    "BOOLEAN",
                    {
                        "default": False,
                        "tooltip": "Protect sharp edges: near edges, mix in more of the nearest real frame to avoid smearing.",
                    },
                ),
                "tau_low": (
                    "FLOAT",
                    {
                        "default": 1.5 / 255.0,
                        "min": 0.0,
                        "max": 0.25,
                        "step": 0.0005,
                        "tooltip": "Edge sensitivity: lower threshold (smaller finds more edges).",
                    },
                ),
                "tau_high": (
                    "FLOAT",
                    {
                        "default": 6.0 / 255.0,
                        "min": 0.0,
                        "max": 0.5,
                        "step": 0.0005,
                        "tooltip": "Edge sensitivity: higher threshold (larger finds only strong edges).",
                    },
                ),
                "band_radius": (
                    "INT",
                    {
                        "default": 4,
                        "min": 0,
                        "max": 64,
                        "tooltip": "Width of the edge protection band (pixels).",
                    },
                ),
                "band_soft_sigma": (
                    "FLOAT",
                    {
                        "default": 2.0,
                        "min": 0.0,
                        "max": 16.0,
                        "step": 0.5,
                        "tooltip": "Soften the edge band. Higher = smoother.",
                    },
                ),
                "clear_cache_after_n_frames": (
                    "INT",
                    {
                        "default": 10,
                        # Widened from 1 to 0: the tooltip (and the sibling
                        # nodes) promise "Set 0 to never", and resample()
                        # already treats 0 as "never clear".
                        "min": 0,
                        "max": 1000,
                        "tooltip": "Free up GPU memory every N output frames (advanced). Set 0 to never.",
                    },
                ),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "resample"
    CATEGORY = "video utils"

    @staticmethod
    def _srgb_to_linear(x):
        """Piecewise sRGB decode: gamma-encoded [0, 1] -> linear light."""
        return torch.where(
            x <= 0.04045, x / 12.92, ((x + 0.055) / 1.055).clamp(min=0) ** 2.4
        )

    @staticmethod
    def _linear_to_srgb(x):
        """Piecewise sRGB encode: linear light -> gamma-encoded [0, 1]."""
        return torch.where(
            x <= 0.0031308, 12.92 * x, 1.055 * x.clamp(min=0) ** (1.0 / 2.4) - 0.055
        )

    @staticmethod
    def _gaussian_kernel1d(sigma, radius):
        """Normalized 1-D Gaussian kernel of length 2*radius+1 (delta if radius<=0)."""
        if radius <= 0:
            k = torch.tensor([1.0], dtype=torch.float32)
            return k / k.sum()
        xs = torch.arange(-radius, radius + 1, dtype=torch.float32)
        k = torch.exp(-(xs * xs) / (2.0 * sigma * sigma))
        return k / k.sum()

    # Kernel cache: (float(sigma), int(radius), int(C)) -> (kH, kW) conv weights.
    _GK_CACHE = {}

    @staticmethod
    def _gaussian_blur_nhwc(x, sigma):
        """Separable Gaussian blur of an NHWC tensor with reflect padding.

        The radius is capped at min(H, W)//2 - 1 so the reflect pad stays
        valid; a sigma (or capped radius) of 0 returns the input unchanged.
        """
        if sigma <= 0:
            return x
        N, H, W, C = x.shape
        max_radius = max(0, min(H, W) // 2 - 1)
        radius = min(int(math.ceil(3.0 * sigma)), max_radius)
        if radius <= 0:
            return x
        key = (float(sigma), int(radius), int(C))
        k = RIFE_FPS_Resample._GK_CACHE.get(key)
        if k is None:
            g = RIFE_FPS_Resample._gaussian_kernel1d(sigma, radius)
            kH = g.view(1, 1, -1, 1).repeat(C, 1, 1, 1)
            kW = g.view(1, 1, 1, -1).repeat(C, 1, 1, 1)
            RIFE_FPS_Resample._GK_CACHE[key] = (kH, kW)
            k = (kH, kW)
        kH, kW = k
        # Depthwise (groups=C) vertical then horizontal pass in NCHW.
        x_nchw = x.permute(0, 3, 1, 2).contiguous()
        x1 = F.pad(x_nchw, (0, 0, radius, radius), mode="reflect")
        x1 = F.conv2d(x1, kH, groups=C)
        x2 = F.pad(x1, (radius, radius, 0, 0), mode="reflect")
        x2 = F.conv2d(x2, kW, groups=C)
        return x2.permute(0, 2, 3, 1).contiguous()

    @staticmethod
    def _luma_linear(x_lin):
        """Rec. 709 luma of a linear-light NHWC tensor; returns NHW."""
        r, g, b = x_lin[..., 0], x_lin[..., 1], x_lin[..., 2]
        return 0.2126 * r + 0.7152 * g + 0.0722 * b

    @staticmethod
    def _exposure_match_ab(out_lin, tgt_lin, a_cap, b_cap):
        """Solve a capped affine luma match out ≈ a*out + b toward tgt.

        Gain ``a`` is the 5th–95th percentile spread ratio (clamped to
        1 ± a_cap); offset ``b`` aligns the medians (clamped to ± b_cap).
        """
        y_o = RIFE_FPS_Resample._luma_linear(out_lin).flatten()
        y_t = RIFE_FPS_Resample._luma_linear(tgt_lin).flatten()
        p = torch.tensor([0.05, 0.5, 0.95], dtype=torch.float32)
        p5_o, p50_o, p95_o = torch.quantile(y_o, p)
        p5_t, p50_t, p95_t = torch.quantile(y_t, p)
        eps = 1e-6
        a = ((p95_t - p5_t) / (p95_o - p5_o + eps)).item()
        a = max(1.0 - a_cap, min(1.0 + a_cap, a))
        b = (p50_t - a * p50_o).item()
        b = max(-b_cap, min(b_cap, b))
        return a, b

    @staticmethod
    def _dilate_mask_nhwc(mask01, radius):
        """Morphological dilation of an NHWC 0/1 mask via max-pooling."""
        if radius <= 0:
            return mask01
        x = mask01.permute(0, 3, 1, 2)
        x = F.pad(x, (radius, radius, radius, radius), mode="replicate")
        x = F.max_pool2d(x, kernel_size=2 * radius + 1, stride=1)
        return x.permute(0, 2, 3, 1)

    @torch.inference_mode()
    def resample(
        self,
        ckpt_name,
        fps_in,
        fps_out,
        frames,
        scale_factor=1.0,
        ensemble=True,
        linearize=False,
        lf_guardrail=False,
        lf_sigma=13.0,
        source_pair_match=False,
        match_a_cap=0.02,
        match_b_cap=2.0 / 255.0,
        edge_band_lock=False,
        tau_low=1.5 / 255.0,
        tau_high=6.0 / 255.0,
        band_radius=4,
        band_soft_sigma=2.0,
        clear_cache_after_n_frames=10,
        **kwargs,
    ):
        """Resample ``frames`` from ``fps_in`` to ``fps_out``; returns (IMAGE,).

        Raises ValueError when either rate is non-positive. Same-rate input is
        returned untouched; exact integer downscales decimate without synthesis.
        """
        if fps_in <= 0 or fps_out <= 0:
            raise ValueError("fps_in and fps_out must be > 0")
        # Exact rational time mapping avoids cumulative float drift.
        F_in = Fraction(str(float(fps_in)))
        F_out = Fraction(str(float(fps_out)))
        same_rate = F_in == F_out
        ratio = F_in / F_out
        integer_downscale = ratio.denominator == 1 and ratio.numerator > 1
        if same_rate:
            return (frames,)
        if integer_downscale:
            # Pure decimation: keep every step-th frame, no synthesis needed.
            step = int(ratio.numerator)
            return (frames[::step].contiguous(),)
        model = _get_rife_model(ckpt_name)
        device = next(model.parameters()).device
        dtype = next(model.parameters()).dtype
        x_nhwc = _prep_frames(frames)
        n_in = int(x_nhwc.shape[0])
        if n_in <= 1:
            return (postprocess_frames(x_nhwc[:1]),)
        # Number of output ticks that fit in the input's time span.
        K_max = int(Fraction(n_in - 1) * F_out / F_in)
        n_out = K_max + 1
        scale_list = _scale_list(scale_factor)

        def to_nchw_on_device(idx: int):
            # One input frame as NCHW on the model's device/dtype.
            f = x_nhwc[idx : idx + 1].permute(0, 3, 1, 2).contiguous()
            return f.to(
                device=device,
                dtype=dtype,
                non_blocking=True,
                memory_format=torch.channels_last,
            )

        cache = {}

        def synth(i: int, t_local: float):
            # Evict device frames from earlier pairs: the pair index only moves
            # forward, so entries below i can never be needed again. (The
            # original kept every frame resident on the GPU for the whole run.)
            for stale in [key for key in cache if key < i]:
                del cache[stale]
            if i not in cache:
                cache[i] = to_nchw_on_device(i)
            if (i + 1) not in cache:
                cache[i + 1] = to_nchw_on_device(i + 1)
            # fast_mode is a no-op for 4.7/4.9; pass False explicitly
            y = model(
                cache[i], cache[i + 1], float(t_local), scale_list, False, ensemble
            )
            return y.permute(0, 2, 3, 1).contiguous().detach().cpu().to(torch.float32)

        to_lin = self._srgb_to_linear if linearize else (lambda z: z)
        to_srgb = self._linear_to_srgb if linearize else (lambda z: z)
        blur = self._gaussian_blur_nhwc
        y0 = x_nhwc[0:1]
        H, W, C = int(y0.shape[1]), int(y0.shape[2]), int(y0.shape[3])
        y_buf = torch.empty((n_out, H, W, C), dtype=torch.float32, device="cpu")

        def _write(idx: int, frame_nhwc_cpu: torch.Tensor):
            y_buf[idx : idx + 1].copy_(frame_nhwc_cpu)

        # Output tick k sits at input time acc/den (in input-frame units): i + frac.
        num = F_in.numerator * F_out.denominator
        den = F_in.denominator * F_out.numerator
        acc = 0
        frames_since_clear = 0
        pbar = _progress(int(n_out))
        for k in range(n_out):
            pbar.update(1)
            i = acc // den
            frac_num = acc % den
            if i >= (n_in - 1):
                # Past the last pair: clamp to the final input frame.
                _write(k, x_nhwc[n_in - 1 : n_in])
            else:
                if frac_num == 0:
                    # Tick lands exactly on an input frame: copy it through.
                    _write(k, x_nhwc[i : i + 1])
                else:
                    frac = float(frac_num) / den
                    y = synth(i, frac)
                    if linearize or lf_guardrail or source_pair_match or edge_band_lock:
                        A = x_nhwc[i : i + 1]
                        B = x_nhwc[i + 1 : i + 2]
                        y_lin, A_lin, B_lin = to_lin(y), to_lin(A), to_lin(B)
                        if lf_guardrail and lf_sigma > 0:
                            # Pin low frequencies to the blend of the real
                            # neighbours; only the synthesized high band survives.
                            HF = y_lin - blur(y_lin, lf_sigma)
                            LF_target = (1.0 - frac) * blur(
                                A_lin, lf_sigma
                            ) + frac * blur(B_lin, lf_sigma)
                            y_lin = HF + LF_target
                        if source_pair_match:
                            # Capped affine luma match toward the A/B blend.
                            tgt = (1.0 - frac) * A_lin + frac * B_lin
                            a, b = self._exposure_match_ab(
                                y_lin, tgt, match_a_cap, match_b_cap
                            )
                            y_lin = (a * y_lin + b).clamp(0.0, 1.0)
                        if edge_band_lock:
                            # Band = strong A/B luma change, dilated, gated by
                            # weak-change pixels; inside it prefer the nearest
                            # real frame to avoid smearing edges.
                            d = (
                                (self._luma_linear(A_lin) - self._luma_linear(B_lin))
                                .abs()
                                .unsqueeze(-1)
                            )
                            high = (d > tau_high).float()
                            low = (d < tau_low).float()
                            band = self._dilate_mask_nhwc(high, band_radius) * low
                            if band_soft_sigma > 0:
                                band = blur(band, band_soft_sigma).clamp(0.0, 1.0)
                            near = A_lin if frac < 0.5 else B_lin
                            y_lin = band * near + (1.0 - band) * y_lin
                        y = to_srgb(y_lin).clamp(0.0, 1.0)
                    _write(k, y)
            frames_since_clear += 1
            if torch.cuda.is_available() and int(clear_cache_after_n_frames) > 0:
                if frames_since_clear >= int(clear_cache_after_n_frames):
                    torch.cuda.empty_cache()
                    frames_since_clear = 0
            acc += num
        return (postprocess_frames(y_buf),)
class RIFE_SeamTimingAnalyzer:
    """
    Batch-in, timings-out. Chooses t-schedule for the seam [last, first] using
    visual step sizes measured from REAL frames in the input batch.

    Inputs:
      - full_clip (IMAGE, NHWC [N,H,W,C], required): real frames (e.g., 1..10)
      - multiplier (INT): number of in-betweens to synthesize for the seam
      - use_first_two (BOOL): include dist(clip[0], clip[1]) as a target
      - use_last_two (BOOL): include dist(clip[-2], clip[-1]) as a target
      - use_global_median (BOOL): include median of ALL adjacent deltas (N>=3)
      - ckpt_name / ensemble / scale_factor: RIFE runtime knobs (only used to test t)
      - calibrate_metric: "MSE" or "L1" visual distance
      - calibrate_iters: binary-search iterations
      - t_min / t_max: clamp t search range (keep t_max < 1.0 to avoid hugging frame 1)

    Outputs:
      - t_list_csv (STRING): e.g. "0.183112, 0.408217, 0.653991, 0.913541"
      - multiplier (INT): echo of input multiplier for wiring convenience
    """

    DESCRIPTION = (
        "Finds a smooth loop timing: measures motion in your clip and solves a set of t-values across the wrap "
        "[last→first] so the seam blends naturally."
    )

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "ckpt_name": (
                    ["rife47.pth", "rife49.pth"],
                    {
                        "default": "rife47.pth",
                        "tooltip": "Choose the RIFE 4.7 or 4.9 model file (used only to test candidate timings).",
                    },
                ),
                "scale_factor": (
                    [0.25, 0.5, 1.0, 2.0, 4.0],
                    {
                        "default": 1.0,
                        "tooltip": "Quality vs speed for the probe renders. 1.0 recommended.",
                    },
                ),
                "ensemble": (
                    "BOOLEAN",
                    {
                        "default": True,
                        "tooltip": "Blend forward & backward predictions to reduce artifacts (slower).",
                    },
                ),
                "full_clip": (
                    "IMAGE",
                    {
                        "tooltip": "Your input clip (≥2 frames). Real motion here decides the loop seam timing."
                    },
                ),
                "multiplier": (
                    "INT",
                    {
                        "default": 4,
                        "min": 0,
                        "tooltip": "How many new frames you plan to create at the loop seam [last→first]. Set 0 to skip.",
                    },
                ),
                "use_first_two": (
                    "BOOLEAN",
                    {
                        "default": True,
                        "tooltip": "Match the motion between the first two frames in your clip.",
                    },
                ),
                "use_last_two": (
                    "BOOLEAN",
                    {
                        "default": True,
                        "tooltip": "Match the motion between the last two frames in your clip.",
                    },
                ),
                "use_global_median": (
                    "BOOLEAN",
                    {
                        "default": False,
                        "tooltip": "Use the median motion across the whole clip (needs ≥3 frames). Helps ignore outliers.",
                    },
                ),
                "calibrate_metric": (
                    ["MSE", "L1"],
                    {
                        "default": "MSE",
                        "tooltip": "How we compare frames while solving: MSE (more sensitive) or L1 (more forgiving).",
                    },
                ),
                "calibrate_iters": (
                    "INT",
                    {
                        "default": 12,
                        "min": 4,
                        "max": 24,
                        "tooltip": "Search depth per solve. Higher = slower, but a tighter match.",
                    },
                ),
                "t_min": (
                    "FLOAT",
                    {
                        "default": 0.00,
                        "min": 0.0,
                        "max": 1.0,
                        "step": 0.001,
                        "tooltip": "Earliest allowed blend point at the seam (0 = exactly the last frame).",
                    },
                ),
                "t_max": (
                    "FLOAT",
                    {
                        "default": 0.96,
                        "min": 0.0,
                        "max": 1.0,
                        "step": 0.001,
                        "tooltip": "Latest allowed blend point at the seam (keep below 1.0 to avoid sticking to the first frame).",
                    },
                ),
            },
            "optional": {
                "auto_tmax": (
                    "BOOLEAN",
                    {
                        "default": False,
                        "tooltip": "Automatically push the upper limit closer to the next frame to hit the target motion step.",
                    },
                ),
                "t_cap": (
                    "FLOAT",
                    {
                        "default": 0.995,
                        "min": 0.5,
                        "max": 0.9999,
                        "step": 0.0001,
                        "tooltip": "Safety cap used with the auto upper limit (keeps it just shy of 1.0).",
                    },
                ),
            },
        }

    RETURN_TYPES = ("STRING", "INT")
    RETURN_NAMES = ("t_list_csv", "multiplier")
    FUNCTION = "analyze_wrapper"
    CATEGORY = "video utils"

    # --- helpers ---
    def _to_nchw(self, x):
        # Convert NHWC (channel-last with 1/3/4 channels) to NCHW; any other
        # layout is passed through unchanged.
        if x.ndim == 4 and x.shape[-1] in (1, 3, 4):
            return x.permute(0, 3, 1, 2).contiguous()
        return x

    def _dist(self, A, B, kind="MSE"):
        # Scalar visual distance between two frame tensors: mean L1 or MSE.
        if kind == "L1":
            return (A - B).abs().mean()
        return ((A - B) ** 2).mean()

    def _adjacent_deltas(self, clip_nchw, kind="MSE"):
        # Distances between each consecutive pair of real frames.
        N = clip_nchw.shape[0]
        ds = []
        for i in range(N - 1):
            ds.append(self._dist(clip_nchw[i : i + 1], clip_nchw[i + 1 : i + 2], kind))
        return ds  # list of scalar tensors

    def analyze_wrapper(self, **kwargs):
        """
        Safe entrypoint for Comfy mapping:
        - If 'multiplier' is omitted: default to 0 (passthrough).
        - If 'multiplier' is invalid/negative: normalize to 0 (passthrough).
        """
        if "multiplier" not in kwargs:
            kwargs["multiplier"] = 0
        try:
            m = int(kwargs["multiplier"])
        except (TypeError, ValueError):
            m = 0
        if m < 0:
            m = 0
        kwargs["multiplier"] = m
        return self.analyze(**kwargs)

    def analyze(
        self,
        full_clip,
        multiplier,
        use_first_two,
        use_last_two,
        use_global_median,
        ckpt_name,
        ensemble,
        scale_factor,
        calibrate_metric,
        calibrate_iters,
        t_min,
        t_max,
        auto_tmax=False,
        t_cap=0.995,
    ):
        """Solve m seam timesteps whose per-step visual distance matches the
        clip's real motion; returns (t_list_csv, multiplier).

        Raises ValueError for clips with < 2 frames, for use_global_median on
        < 3 frames, when no target source is enabled, or for an invalid
        t_min/t_max bracket.
        """
        # Normalize the multiplier defensively; 0 means "skip" (empty CSV).
        try:
            m = int(multiplier)
        except (TypeError, ValueError):
            m = 0
        if m < 0:
            m = 0
        if m == 0:
            return ("", 0)
        clip_nhwc = preprocess_frames(full_clip)
        # Pinned host memory enables the non_blocking device copies below.
        if clip_nhwc.device.type == "cpu" and torch.cuda.is_available():
            clip_nhwc = clip_nhwc.pin_memory()
        N = int(clip_nhwc.shape[0])
        if N < 2:
            raise ValueError("full_clip must contain at least 2 frames.")
        metric = "L1" if calibrate_metric == "L1" else "MSE"
        clip_nchw_cpu = self._to_nchw(clip_nhwc)
        deltas = self._adjacent_deltas(clip_nchw_cpu, metric)
        # Collect candidate target step sizes from the enabled sources.
        chosen = []
        if use_first_two:
            chosen.append(deltas[0])
        if use_last_two:
            chosen.append(deltas[-1])
        if use_global_median:
            if N < 3:
                raise ValueError(
                    "use_global_median requires full_clip with >= 3 frames."
                )
            chosen.append(torch.stack(deltas).median())
        if not chosen:
            raise ValueError(
                "Enable at least one of: use_first_two, use_last_two, use_global_median."
            )
        # Target per-step distance: the single choice, or the median of several.
        d_target = (
            chosen[0] if len(chosen) == 1 else torch.stack(chosen).median()
        ).item()
        iters = int(calibrate_iters)
        m = int(m)
        model = _get_rife_model(ckpt_name)
        device = next(model.parameters()).device
        dtype = next(model.parameters()).dtype  # usually float32
        # Seam endpoints on the model device: last frame (t=0) and first (t=1).
        last_nchw = self._to_nchw(clip_nhwc[N - 1 : N]).to(
            device=device,
            dtype=dtype,
            non_blocking=True,
            memory_format=torch.channels_last,
        )
        first_nchw = self._to_nchw(clip_nhwc[0:1]).to(
            device=device,
            dtype=dtype,
            non_blocking=True,
            memory_format=torch.channels_last,
        )
        t_min = float(max(0.0, min(1.0, t_min)))
        t_max = float(max(0.0, min(1.0, t_max)))
        if not (0.0 <= t_min < t_max <= 1.0):
            raise ValueError("Require 0 <= t_min < t_max <= 1.0")
        hi_eff = t_max
        if auto_tmax:
            # Raise the upper bracket to a safe cap near 1.0 so we can hit the target step size.
            hi_eff = max(hi_eff, min(max(t_min + 1e-6, float(t_cap)), 0.9999))
        # Same 4-level pyramid as _scale_list(scale_factor).
        scale_list = [
            8 / scale_factor,
            4 / scale_factor,
            2 / scale_factor,
            1 / scale_factor,
        ]

        @torch.no_grad()
        def synth_at(t_scalar: float):
            # Probe render at t, clamped into the active search bracket.
            t_scalar = float(max(t_min, min(hi_eff, t_scalar)))
            return model(last_nchw, first_nchw, t_scalar, scale_list, False, ensemble)

        # Binary search: distance to first_nchw shrinks as t grows toward 1,
        # so "too far" (d_mid > d_target) moves the lower bound up.
        lo, hi = t_min, hi_eff
        with torch.inference_mode():
            for _ in range(iters):
                mid = 0.5 * (lo + hi)
                d_mid = float(self._dist(synth_at(mid), first_nchw, metric).item())
                if d_mid > d_target:
                    lo = mid
                else:
                    hi = mid
            # t_last is the step closest to the first frame; earlier steps are
            # then solved backwards, each one target-distance from the previous.
            t_last = 0.5 * (lo + hi)
            f_prev = synth_at(t_last)
            ts = [t_last]
            for _ in range(m - 1):
                lo, hi = t_min, ts[-1]
                for _ in range(iters):
                    mid = 0.5 * (lo + hi)
                    d_mid = float(self._dist(synth_at(mid), f_prev, metric).item())
                    if d_mid > d_target:
                        lo = mid
                    else:
                        hi = mid
                tk = 0.5 * (lo + hi)
                ts.append(tk)
                f_prev = synth_at(tk)
        ts_sorted = sorted(ts)
        csv = ", ".join(f"{t:.6f}" for t in ts_sorted)
        return (csv, m)