Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Includes 30 custom nodes committed directly, 7 Civitai-exclusive loras stored via Git LFS, and a setup script that installs all dependencies and downloads HuggingFace-hosted models on vast.ai. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
301 lines
11 KiB
Python
301 lines
11 KiB
Python
import torch
|
|
import numpy as np
|
|
from typing import Dict, List, Optional, Tuple
|
|
from PIL import Image, ImageDraw
|
|
|
|
def pil2tensor(image):
    """Convert an image into a float32 tensor with a leading batch axis.

    The input is turned into a NumPy array, cast to float32, divided by 255
    and wrapped in a tensor; ``unsqueeze(0)`` then prepends one dimension.
    """
    pixels = np.array(image).astype(np.float32)
    normalized = pixels / 255.0
    tensor = torch.from_numpy(normalized)
    return tensor.unsqueeze(0)
|
|
|
|
class RegionMaskConditioning:
    """Node that packs up to three mask/bbox/conditioning inputs into regions.

    Every region is a dict with the keys ``conditioning``, ``bbox`` (a
    ``[x1, y1, x2, y2]`` list), ``is_active`` and ``strength``. The node
    always returns three region slots; slots that are unused or fail
    validation hold an inactive placeholder. It also returns the number of
    active regions and a preview image with each valid bbox outlined.
    """

    # Number of region slots exposed through RETURN_TYPES.
    _MAX_REGIONS = 3

    # Edge length of the blank preview used when no mask size is available.
    _DEFAULT_PREVIEW_SIZE = 64

    @staticmethod
    def _strength_input(region_idx: int) -> Tuple:
        """Build the FLOAT widget spec shared by all ``strengthN`` inputs."""
        return ("FLOAT", {
            "default": 1.0,
            "min": 0.0,
            "max": 10.0,
            "step": 0.1,
            "display": f"Strength for Region {region_idx}",
        })

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs: region 1 is required, 2 and 3 optional."""
        return {
            "required": {
                "mask1": ("MASK",),
                "bbox1": ("BBOX",),
                "conditioning1": ("CONDITIONING",),
                "number_of_regions": ("INT", {
                    "default": 1,
                    "min": 1,
                    "max": 3,
                    "step": 1,
                    "display": "Number of Regions",
                }),
                "strength1": cls._strength_input(1),
            },
            "optional": {
                "mask2": ("MASK",),
                "bbox2": ("BBOX",),
                "conditioning2": ("CONDITIONING",),
                "strength2": cls._strength_input(2),
                "mask3": ("MASK",),
                "bbox3": ("BBOX",),
                "conditioning3": ("CONDITIONING",),
                "strength3": cls._strength_input(3),
            },
        }

    RETURN_TYPES = ("REGION", "REGION", "REGION", "INT", "IMAGE")
    RETURN_NAMES = ("region1", "region2", "region3",
                    "region_count", "preview_image")
    FUNCTION = "create_conditioned_regions"
    CATEGORY = "ControlAltAI Nodes/Flux Region"

    @staticmethod
    def _empty_region() -> Dict:
        """Return a new inactive placeholder region.

        A fresh dict (and bbox list) is built on every call so that output
        slots never alias each other.
        """
        return {
            "conditioning": None,
            "bbox": [0.0, 0.0, 0.0, 0.0],  # Array format for empty
            "is_active": False,
            "strength": 1.0,
        }

    @staticmethod
    def _mask_size(mask) -> Optional[Tuple[int, int]]:
        """Return ``(height, width)`` from the last two dims of ``mask``.

        Returns None when ``mask`` has no ``shape`` or fewer than two
        dimensions. Using the trailing dims accepts a plain ``(H, W)`` mask
        as well as one with leading axes.
        NOTE(review): assumes height and width are the last two axes.
        """
        shape = getattr(mask, "shape", None)
        if shape is None or len(shape) < 2:
            return None
        return int(shape[-2]), int(shape[-1])

    @classmethod
    def _blank_preview(cls, size: Optional[Tuple[int, int]] = None) -> torch.Tensor:
        """Return an all-black preview shaped ``(1, height, width, 3)``.

        This is the same layout ``pil2tensor`` produces for an RGB image, so
        fallback previews and real previews are interchangeable downstream.
        """
        if size is None:
            size = (cls._DEFAULT_PREVIEW_SIZE, cls._DEFAULT_PREVIEW_SIZE)
        height, width = size
        return torch.zeros((1, height, width, 3), dtype=torch.float32)

    @staticmethod
    def _tensor_stats(tensor: torch.Tensor) -> str:
        """Format min/max/mean of ``tensor`` for the debug log."""
        if tensor.numel() == 0:
            return "empty tensor"
        # Cast so that mean() also works for non-floating tensors.
        values = tensor.detach().float()
        return (f"min={values.min().item():.3f}, "
                f"max={values.max().item():.3f}, "
                f"mean={values.mean().item():.3f}")

    def validate_bbox(self, bbox: Dict) -> bool:
        """Check that ``bbox`` is a well-formed normalized bounding box.

        A valid bbox is a dict holding numeric ``x1``, ``y1``, ``x2``, ``y2``
        values, all within ``[0, 1]``, with ``x1 < x2`` and ``y1 < y2``.

        Returns:
            True when every check passes, otherwise False. The reason for a
            failure is printed.
        """
        print("\n=== Validating BBox ===")
        print(f"Input bbox: {bbox}")

        # isinstance() is False for None, so this also rejects a missing bbox.
        if not isinstance(bbox, dict):
            print("Failed: Invalid bbox type")
            return False

        required_keys = ["x1", "y1", "x2", "y2"]
        if not all(k in bbox for k in required_keys):
            print(f"Failed: Missing keys. Required: {required_keys}")
            return False

        # Validate coordinate values
        if not all(isinstance(bbox[k], (int, float)) for k in required_keys):
            print("Failed: Invalid coordinate types")
            return False

        # Validate coordinate ranges
        if not all(0 <= bbox[k] <= 1.0 for k in required_keys):
            print("Failed: Coordinates out of range [0,1]")
            return False

        # Validate proper ordering (zero-area boxes are rejected too)
        if bbox["x1"] >= bbox["x2"] or bbox["y1"] >= bbox["y2"]:
            print("Failed: Invalid coordinate ordering")
            return False

        print("Passed: BBox validation successful")
        return True

    def scale_conditioning(self, conditioning: List, strength: float) -> List:
        """Multiply every conditioning tensor by ``strength``.

        Args:
            conditioning: Non-empty list of ``[tensor, dict]`` entries.
            strength: Scalar multiplier applied to each entry's tensor.

        Returns:
            A new list with one ``[scaled_tensor, dict_copy]`` entry per
            input entry. The input tensors and dicts are not modified. If
            the input is malformed or scaling fails, the original
            ``conditioning`` is returned unchanged (best effort) and the
            problem is printed.
        """
        print("\n=== Scaling Conditioning ===")
        print(f"Strength: {strength}")

        if not conditioning or not isinstance(conditioning, list):
            print("Failed: Invalid conditioning format")
            return conditioning

        try:
            scaled = []
            for entry in conditioning:
                cond_tensors, cond_dict = entry[0], entry[1]

                print(f"Input tensor shape: {cond_tensors.shape}")
                print(f"Conditioning keys: {list(cond_dict.keys())}")
                print(f"Input tensor stats: {self._tensor_stats(cond_tensors)}")

                # Multiplication allocates a new tensor, so the caller's
                # tensor is left untouched.
                scaled_tensors = cond_tensors * strength
                print(f"Scaled tensor stats: {self._tensor_stats(scaled_tensors)}")

                # Shallow-copy the metadata so later edits to the region do
                # not leak back into the caller's conditioning.
                scaled.append([scaled_tensors, cond_dict.copy()])

            return scaled

        except Exception as e:
            # Deliberate best effort: keep the node usable with unscaled
            # conditioning rather than failing the whole region.
            print(f"Error in scale_conditioning: {str(e)}")
            import traceback
            traceback.print_exc()
            return conditioning

    def create_region(self, mask: Optional[torch.Tensor], bbox: Optional[Dict],
                      conditioning: Optional[List], strength: float, region_idx: int) -> Dict:
        """Create a single region dict from its inputs.

        Args:
            mask: Region mask; only checked for presence here.
            bbox: Normalized bbox dict, validated with ``validate_bbox``.
            conditioning: Conditioning list passed to ``scale_conditioning``.
            strength: Per-region multiplier; ``None`` is treated as 1.0.
            region_idx: 1-based index used in log messages.

        Returns:
            An active region dict, or an inactive placeholder when an input
            is missing, the bbox is invalid, or an error occurs.
        """
        print(f"\n=== Creating Region {region_idx} ===")

        # An unconnected optional strength would otherwise break scaling.
        if strength is None:
            strength = 1.0

        # Debug inputs
        print("Input validation:")
        print(f"- Mask: {type(mask)}, shape={getattr(mask, 'shape', None)}")
        print(f"- BBox: {bbox}")
        print(f"- Conditioning type: {type(conditioning)}")
        print(f"- Strength: {strength}")

        try:
            if mask is None or bbox is None or conditioning is None:
                print(f"Region {region_idx}: Missing components")
                return self._empty_region()

            if not self.validate_bbox(bbox):
                print(f"Region {region_idx}: Invalid bbox")
                return self._empty_region()

            scaled_conditioning = self.scale_conditioning(conditioning, strength)

            # Region output - bbox array, conditioning, and strength
            region = {
                "conditioning": scaled_conditioning,
                "bbox": [bbox["x1"], bbox["y1"], bbox["x2"], bbox["y2"]],  # Array format
                "is_active": True,
                "strength": strength,
            }

            print(f"\nSuccessfully created region {region_idx}")
            return region

        except Exception as e:
            print(f"Error creating region {region_idx}: {str(e)}")
            import traceback
            traceback.print_exc()
            return self._empty_region()

    def create_preview(self, masks: List[torch.Tensor], bboxes: List[Dict],
                       number_of_regions: int) -> torch.Tensor:
        """Render a black image with an outline for every valid region bbox.

        The canvas size is taken from the first mask that provides a usable
        height/width. Regions whose mask is None or whose bbox fails
        validation are skipped.

        Returns:
            The preview converted by ``pil2tensor``, or a blank
            ``(1, 64, 64, 3)`` tensor when no mask size is available.
        """
        print("\n=== Creating Preview ===")

        sizes = (self._mask_size(m) for m in (masks or []))
        size = next((s for s in sizes if s is not None), None)
        if size is None:
            print("No masks provided")
            return self._blank_preview()

        height, width = size
        print(f"Preview dimensions: {width}x{height}")

        # Create PIL Image for preview
        preview = Image.new("RGB", (width, height), (0, 0, 0))
        draw = ImageDraw.Draw(preview)

        # Outline colors, one per region slot
        colors = [
            (255, 0, 0),    # Red
            (0, 255, 0),    # Green
            (255, 255, 0),  # Yellow
        ]

        region_inputs = zip(masks[:number_of_regions], bboxes[:number_of_regions])
        for i, (mask, bbox) in enumerate(region_inputs):
            if mask is None or not self.validate_bbox(bbox):
                continue

            print(f"\nDrawing region {i+1}:")
            # Convert normalized coordinates to pixels
            x1 = int(bbox["x1"] * width)
            y1 = int(bbox["y1"] * height)
            x2 = int(bbox["x2"] * width)
            y2 = int(bbox["y2"] * height)

            print(f"Region {i+1} coordinates: ({x1},{y1}) to ({x2},{y2})")

            # Modulo keeps the lookup safe if more regions than colors arrive.
            outline = colors[i % len(colors)]
            draw.rectangle([x1, y1, x2, y2], outline=outline, width=4)

        return pil2tensor(preview)

    def create_conditioned_regions(self,
                                   mask1: torch.Tensor,
                                   bbox1: Dict,
                                   conditioning1: List,
                                   number_of_regions: int,
                                   strength1: float,
                                   mask2: Optional[torch.Tensor] = None,
                                   bbox2: Optional[Dict] = None,
                                   conditioning2: Optional[List] = None,
                                   strength2: Optional[float] = 1.0,
                                   mask3: Optional[torch.Tensor] = None,
                                   bbox3: Optional[Dict] = None,
                                   conditioning3: Optional[List] = None,
                                   strength3: Optional[float] = 1.0) -> Tuple:
        """Node entry point: build the regions, their count and a preview.

        Only the first ``number_of_regions`` input groups are processed; the
        remaining slots are filled with inactive placeholders.

        Returns:
            ``(region1, region2, region3, active_count, preview)``. If region
            construction fails, three placeholders, a count of 0 and a blank
            preview are returned. If only the preview fails, the regions are
            kept and a blank preview is substituted.
        """
        print("\n=== Creating Conditioned Regions ===")
        print(f"Number of regions: {number_of_regions}")

        try:
            inputs = [
                (mask1, bbox1, conditioning1, strength1),
                (mask2, bbox2, conditioning2, strength2),
                (mask3, bbox3, conditioning3, strength3),
            ]

            # Clamp so that a negative value cannot turn into a "drop the
            # last N" slice.
            requested = max(0, min(number_of_regions, len(inputs)))

            regions = []
            active_count = 0

            # Masks and bboxes are kept for the preview only
            preview_masks = []
            preview_bboxes = []

            for i, (mask, bbox, conditioning, strength) in enumerate(inputs[:requested]):
                region = self.create_region(mask, bbox, conditioning, strength, i + 1)
                if region["is_active"]:
                    active_count += 1
                regions.append(region)
                print(f"Processed region {i+1}: active={region['is_active']}")

                preview_masks.append(mask)
                preview_bboxes.append(bbox)

            # Fill remaining slots, each with its own placeholder dict
            while len(regions) < self._MAX_REGIONS:
                print(f"Adding empty region {len(regions) + 1}")
                regions.append(self._empty_region())

            print(f"\nCreated {active_count} active regions out of {number_of_regions} requested")

        except Exception as e:
            print(f"Error in create_conditioned_regions: {str(e)}")
            import traceback
            traceback.print_exc()

            placeholders = tuple(self._empty_region() for _ in range(self._MAX_REGIONS))
            return (*placeholders, 0, self._blank_preview(self._mask_size(mask1)))

        # The preview is purely informational: a drawing failure must not
        # discard regions that were built successfully.
        try:
            preview = self.create_preview(preview_masks, preview_bboxes, requested)
        except Exception as e:
            print(f"Error creating preview: {str(e)}")
            import traceback
            traceback.print_exc()
            preview = self._blank_preview(self._mask_size(mask1))

        return (*regions, active_count, preview)
|
|
|
|
# Node class mappings: node identifier -> implementing class.
NODE_CLASS_MAPPINGS = {"RegionMaskConditioning": RegionMaskConditioning}
|
|
|
|
# Display names: node identifier -> human-readable title.
NODE_DISPLAY_NAME_MAPPINGS = {"RegionMaskConditioning": "Region Mask Conditioning"}