Files
jaidaken f09734b0ee
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Add custom nodes, Civitai loras (LFS), and vast.ai setup script
Includes 30 custom nodes committed directly, 7 Civitai-exclusive
loras stored via Git LFS, and a setup script that installs all
dependencies and downloads HuggingFace-hosted models on vast.ai.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-09 00:56:42 +00:00

75 lines
3.1 KiB
Python

from ..utils import common_annotator_call
import comfy.model_management as model_management
import torch
import numpy as np
from einops import rearrange
import torch.nn.functional as F
class Unimatch_OptFlowPreprocessor:
    """ComfyUI node that estimates optical flow between consecutive frames.

    Every adjacent pair of frames in the input batch is passed to
    ``UnimatchDetector`` (from ``custom_controlnet_aux.unimatch``); the
    per-pair flow fields and their visualisations are stacked into the two
    node outputs.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Declare the node inputs: frames, checkpoint name and direction flags."""
        return {
            "required": dict(
                image=("IMAGE",),
                ckpt_name=(
                    ["gmflow-scale1-mixdata.pth", "gmflow-scale2-mixdata.pth", "gmflow-scale2-regrefine6-mixdata.pth"],
                    {"default": "gmflow-scale2-regrefine6-mixdata.pth"}
                ),
                backward_flow=("BOOLEAN", {"default": False}),
                bidirectional_flow=("BOOLEAN", {"default": False})
            )
        }

    RETURN_TYPES = ("OPTICAL_FLOW", "IMAGE")
    RETURN_NAMES = ("OPTICAL_FLOW", "PREVIEW_IMAGE")
    FUNCTION = "estimate"
    CATEGORY = "ControlNet Preprocessors/Optical Flow"

    @staticmethod
    def _frames_to_uint8(image):
        """Scale a float frame tensor by 255 and return it as a uint8 array.

        The scaled values are clamped to [0, 255] before the cast so that
        out-of-range inputs saturate instead of wrapping around.
        Assumes frame values are nominally in [0, 1] -- TODO confirm.
        """
        scaled = (image.cpu() * 255.).clamp(0., 255.)
        return scaled.numpy().astype(np.uint8)

    def estimate(self, image, ckpt_name, backward_flow=False, bidirectional_flow=False):
        """Run the Unimatch detector on each pair of consecutive frames.

        Args:
            image: Batch of frames, indexed along the first dimension.
            ckpt_name: Checkpoint filename given to
                ``UnimatchDetector.from_pretrained``.
            backward_flow: Forwarded to the detector as ``pred_bwd_flow``.
            bidirectional_flow: Forwarded to the detector as
                ``pred_bidir_flow``.

        Returns:
            A ``(flows, previews)`` tuple of tensors, each stacking one entry
            per consecutive frame pair along dimension 0; previews are
            divided by 255.

        Raises:
            ValueError: If fewer than two frames are supplied.
        """
        # Raise explicitly: an ``assert`` would be stripped under ``python -O``.
        if len(image) < 2:
            raise ValueError("[Unimatch] Requiring at least two frames as an optical flow estimator. Only use this node on video input.")
        from custom_controlnet_aux.unimatch import UnimatchDetector

        # Convert the whole batch once; converting per pair would transfer and
        # cast every interior frame twice.
        frames = self._frames_to_uint8(image)
        model = UnimatchDetector.from_pretrained(filename=ckpt_name).to(model_management.get_torch_device())
        flows, vis_flows = [], []
        try:
            for image0, image1 in zip(frames[:-1], frames[1:]):
                flow, vis_flow = model(image0, image1, output_type="np", pred_bwd_flow=backward_flow, pred_bidir_flow=bidirectional_flow)
                flows.append(torch.from_numpy(flow).float())
                vis_flows.append(torch.from_numpy(vis_flow).float() / 255.)
        finally:
            # Drop the local reference to the detector even when inference fails.
            del model
        return (torch.stack(flows, dim=0), torch.stack(vis_flows, dim=0))
class MaskOptFlow:
    """ComfyUI node that multiplies an optical flow and its preview by a mask."""

    @classmethod
    def INPUT_TYPES(s):
        """Declare the node inputs: an optical flow batch and a mask batch."""
        return {
            "required": dict(optical_flow=("OPTICAL_FLOW",), mask=("MASK",))
        }

    RETURN_TYPES = ("OPTICAL_FLOW", "IMAGE")
    RETURN_NAMES = ("OPTICAL_FLOW", "PREVIEW_IMAGE")
    FUNCTION = "mask_opt_flow"
    CATEGORY = "ControlNet Preprocessors/Optical Flow"

    @staticmethod
    def _prepare_mask(mask, optical_flow):
        """Trim and resize ``mask`` so it broadcasts against ``optical_flow``.

        Args:
            mask: Tensor shaped ``(n, h, w)`` or ``(n, 1, h, w)``.
            optical_flow: Flow tensor whose dimensions 1 and 2 give the target
                spatial size.

        Returns:
            A tensor shaped ``(len(optical_flow), H, W, 1)``.

        Raises:
            ValueError: If there are fewer masks than flow entries, or the
                mask does not have one of the accepted shapes.
        """
        # Raise explicitly: an ``assert`` would be stripped under ``python -O``.
        if len(mask) < len(optical_flow):
            raise ValueError(f"Not enough masks to mask optical flow: {len(mask)} vs {len(optical_flow)}")
        mask = mask[:optical_flow.shape[0]]
        if mask.ndim == 3:
            # A 2-D resize needs an explicit channel axis; add one for
            # channel-less (n, h, w) masks.
            mask = mask.unsqueeze(1)
        if mask.ndim != 4 or mask.shape[1] != 1:
            raise ValueError(f"Expected mask of shape (n, h, w) or (n, 1, h, w), got {tuple(mask.shape)}")
        mask = F.interpolate(mask, tuple(optical_flow.shape[1:3]))
        # (n, 1, h, w) -> (n, h, w, 1)
        return mask.permute(0, 2, 3, 1)

    def mask_opt_flow(self, optical_flow, mask):
        """Mask the flow and build a masked preview image of it.

        Args:
            optical_flow: Batch of flow fields; iterated entry by entry to
                render previews with ``flow_to_image``.
            mask: Batch of masks with at least as many entries as the flow.

        Returns:
            A ``(masked_flow, masked_preview)`` tuple. The inputs are left
            unmodified.

        Raises:
            ValueError: See ``_prepare_mask``.
        """
        from custom_controlnet_aux.unimatch import flow_to_image

        mask = self._prepare_mask(mask, optical_flow)
        # ``.cpu()`` because ``Tensor.numpy()`` only works on CPU tensors.
        vis_flows = torch.stack([torch.from_numpy(flow_to_image(flow)).float() / 255. for flow in optical_flow.cpu().numpy()], dim=0)
        # Multiply out of place: an in-place ``*=`` would overwrite the
        # caller's flow tensor.
        masked_flow = optical_flow * mask.to(optical_flow.device)
        masked_vis = vis_flows * mask.to(vis_flows.device)
        return (masked_flow, masked_vis)
# Node identifier -> class implementing that node.
NODE_CLASS_MAPPINGS = dict(
    Unimatch_OptFlowPreprocessor=Unimatch_OptFlowPreprocessor,
    MaskOptFlow=MaskOptFlow,
)
# Node identifier -> human-readable title for that node.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    Unimatch_OptFlowPreprocessor="Unimatch Optical Flow",
    MaskOptFlow="Mask Optical Flow (DragNUWA)",
)