Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Includes 30 custom nodes committed directly, 7 Civitai-exclusive loras stored via Git LFS, and a setup script that installs all dependencies and downloads HuggingFace-hosted models on vast.ai. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
230 lines
8.6 KiB
Python
230 lines
8.6 KiB
Python
import torch
|
|
import torch.nn.functional as F
|
|
from typing import Tuple, Dict, Optional, List
|
|
import numpy as np
|
|
from PIL import Image, ImageDraw
|
|
|
|
def pil2tensor(image):
|
|
"""Convert a PIL image to a PyTorch tensor in the expected format."""
|
|
return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
|
|
|
|
class RegionMaskProcessor:
|
|
@classmethod
|
|
def INPUT_TYPES(cls):
|
|
return {
|
|
"required": {
|
|
"mask1": ("MASK",),
|
|
"bbox1": ("BBOX",),
|
|
"blur_radius": ("INT", {
|
|
"default": 5,
|
|
"min": 0,
|
|
"max": 32,
|
|
"step": 1,
|
|
"display": "Blur Radius"
|
|
}),
|
|
"threshold": ("FLOAT", {
|
|
"default": 0.5,
|
|
"min": 0.0,
|
|
"max": 1.0,
|
|
"step": 0.1,
|
|
"display": "Mask Threshold"
|
|
}),
|
|
"feather_edges": ("BOOLEAN", {
|
|
"default": True,
|
|
"display": "Feather Edges"
|
|
}),
|
|
"number_of_regions": ("INT", {
|
|
"default": 1,
|
|
"min": 1,
|
|
"max": 3,
|
|
"display": "Number of Regions"
|
|
}),
|
|
},
|
|
"optional": {
|
|
"mask2": ("MASK",),
|
|
"bbox2": ("BBOX",),
|
|
"mask3": ("MASK",),
|
|
"bbox3": ("BBOX",),
|
|
}
|
|
}
|
|
|
|
RETURN_TYPES = ("MASK", "BBOX", "MASK", "BBOX", "MASK", "BBOX", "IMAGE", "INT")
|
|
RETURN_NAMES = ("processed_mask1", "processed_bbox1",
|
|
"processed_mask2", "processed_bbox2",
|
|
"processed_mask3", "processed_bbox3",
|
|
"preview_image", "region_count")
|
|
FUNCTION = "process_regions"
|
|
CATEGORY = "ControlAltAI Nodes/Flux Region"
|
|
|
|
def apply_gaussian_blur(self, mask: torch.Tensor, radius: int) -> torch.Tensor:
|
|
"""Apply gaussian blur to mask edges"""
|
|
if radius <= 0:
|
|
return mask
|
|
|
|
kernel_size = 2 * radius + 1
|
|
sigma = radius / 3.0
|
|
|
|
if len(mask.shape) == 2:
|
|
mask = mask.unsqueeze(0).unsqueeze(0)
|
|
|
|
kernel_1d = torch.exp(torch.linspace(-radius, radius, kernel_size).pow(2) / (-2 * sigma ** 2))
|
|
kernel_1d = kernel_1d / kernel_1d.sum()
|
|
|
|
padding = radius
|
|
kernel_h = kernel_1d.unsqueeze(0).unsqueeze(0).unsqueeze(0).to(mask.device)
|
|
kernel_v = kernel_1d.unsqueeze(0).unsqueeze(0).unsqueeze(-1).to(mask.device)
|
|
|
|
mask = F.pad(mask, (padding, padding, 0, 0), mode='reflect')
|
|
mask = F.conv2d(mask, kernel_h)
|
|
mask = F.pad(mask, (0, 0, padding, padding), mode='reflect')
|
|
mask = F.conv2d(mask, kernel_v)
|
|
|
|
return mask.squeeze()
|
|
|
|
def apply_feathering(self, mask: torch.Tensor, bbox: Dict, radius: int) -> Tuple[torch.Tensor, Dict]:
|
|
"""Apply feathering to mask edges while preserving bbox boundaries"""
|
|
if radius <= 0 or not bbox["active"]:
|
|
return mask, bbox
|
|
|
|
height, width = mask.shape
|
|
x1 = int(bbox["x1"] * width)
|
|
y1 = int(bbox["y1"] * height)
|
|
x2 = int(bbox["x2"] * width)
|
|
y2 = int(bbox["y2"] * height)
|
|
|
|
inner_mask = torch.zeros_like(mask)
|
|
inner_mask[y1+radius:y2-radius, x1+radius:x2-radius] = 1.0
|
|
edge_mask = mask - inner_mask
|
|
|
|
if edge_mask.any():
|
|
blurred = self.apply_gaussian_blur(mask, radius)
|
|
result = mask.clone()
|
|
result[edge_mask > 0] = blurred[edge_mask > 0]
|
|
else:
|
|
result = mask
|
|
|
|
return result, bbox
|
|
|
|
def process_single_region(self,
|
|
mask: torch.Tensor,
|
|
bbox: Dict,
|
|
blur_radius: int,
|
|
threshold: float,
|
|
feather_edges: bool) -> Tuple[torch.Tensor, Dict]:
|
|
"""Process a single mask-bbox pair"""
|
|
if mask is None or not bbox["active"]:
|
|
return mask, bbox
|
|
|
|
try:
|
|
processed = (mask > threshold).float()
|
|
|
|
if feather_edges and blur_radius > 0:
|
|
processed, bbox = self.apply_feathering(processed, bbox, blur_radius)
|
|
elif blur_radius > 0:
|
|
processed = self.apply_gaussian_blur(processed, blur_radius)
|
|
|
|
return processed, bbox
|
|
|
|
except Exception as e:
|
|
print(f"Error processing region: {str(e)}")
|
|
return mask, bbox
|
|
|
|
def create_preview(self, masks: List[torch.Tensor], bboxes: List[Dict],
|
|
number_of_regions: int) -> torch.Tensor:
|
|
"""Create preview of processed regions with PIL for consistent coloring"""
|
|
if not masks:
|
|
return torch.zeros((3, 64, 64), dtype=torch.float32)
|
|
|
|
height, width = masks[0].shape
|
|
|
|
# Create PIL Image for preview
|
|
preview = Image.new("RGB", (width, height), (0, 0, 0))
|
|
|
|
colors = [
|
|
(255, 0, 0), # Red - Region 1
|
|
(0, 255, 0), # Green - Region 2
|
|
(255, 255, 0), # Yellow - Region 3
|
|
]
|
|
|
|
# Store regions for ordered preview
|
|
preview_regions = []
|
|
for i in range(number_of_regions):
|
|
if bboxes[i]["active"] and masks[i] is not None:
|
|
mask_np = masks[i].cpu().numpy() > 0.5
|
|
preview_regions.append((i, mask_np))
|
|
|
|
# Draw regions in reverse order (Region 3 first, Region 1 last)
|
|
for i, mask_np in sorted(preview_regions, reverse=True):
|
|
color_array = np.zeros((height, width, 3), dtype=np.uint8)
|
|
color_array[mask_np] = colors[i]
|
|
|
|
# Convert to PIL and composite
|
|
region_img = Image.fromarray(color_array, 'RGB')
|
|
preview = Image.alpha_composite(
|
|
preview.convert('RGBA'),
|
|
Image.merge('RGBA', (*region_img.split(), Image.fromarray((mask_np * 255).astype(np.uint8))))
|
|
)
|
|
|
|
return pil2tensor(preview.convert('RGB'))
|
|
|
|
def process_regions(self,
|
|
mask1: torch.Tensor,
|
|
bbox1: Dict,
|
|
blur_radius: int,
|
|
threshold: float,
|
|
feather_edges: bool,
|
|
number_of_regions: int,
|
|
mask2: Optional[torch.Tensor] = None,
|
|
bbox2: Optional[Dict] = None,
|
|
mask3: Optional[torch.Tensor] = None,
|
|
bbox3: Optional[Dict] = None) -> Tuple:
|
|
try:
|
|
# Process each mask-bbox pair
|
|
mask_bbox_pairs = [
|
|
(mask1, bbox1),
|
|
(mask2, bbox2) if mask2 is not None else (None, None),
|
|
(mask3, bbox3) if mask3 is not None else (None, None),
|
|
]
|
|
|
|
processed_masks = []
|
|
processed_bboxes = []
|
|
active_count = 0
|
|
|
|
for i, (mask, bbox) in enumerate(mask_bbox_pairs):
|
|
if i < number_of_regions and mask is not None and bbox is not None:
|
|
proc_mask, proc_bbox = self.process_single_region(
|
|
mask, bbox, blur_radius, threshold, feather_edges
|
|
)
|
|
if proc_bbox["active"]:
|
|
active_count += 1
|
|
processed_masks.append(proc_mask)
|
|
processed_bboxes.append(proc_bbox)
|
|
else:
|
|
empty_mask = torch.zeros_like(mask1)
|
|
empty_bbox = {"x1": 0.0, "y1": 0.0, "x2": 0.0, "y2": 0.0, "active": False}
|
|
processed_masks.append(empty_mask)
|
|
processed_bboxes.append(empty_bbox)
|
|
|
|
# Create preview
|
|
preview = self.create_preview(processed_masks, processed_bboxes, number_of_regions)
|
|
|
|
return (*[item for pair in zip(processed_masks, processed_bboxes) for item in pair],
|
|
preview, active_count)
|
|
|
|
except Exception as e:
|
|
print(f"Error processing regions: {str(e)}")
|
|
empty_mask = torch.zeros_like(mask1)
|
|
empty_bbox = {"x1": 0.0, "y1": 0.0, "x2": 0.0, "y2": 0.0, "active": False}
|
|
empty_preview = torch.zeros((3, mask1.shape[0], mask1.shape[1]), dtype=torch.float32)
|
|
return (empty_mask, empty_bbox, empty_mask, empty_bbox,
|
|
empty_mask, empty_bbox,
|
|
empty_preview, 0)
|
|
|
|
# Node class mappings
|
|
NODE_CLASS_MAPPINGS = {
|
|
"RegionMaskProcessor": RegionMaskProcessor
|
|
}
|
|
|
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
|
"RegionMaskProcessor": "Region Mask Processor"
|
|
} |