Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Includes 30 custom nodes committed directly, 7 Civitai-exclusive LoRAs stored via Git LFS, and a setup script that installs all dependencies and downloads Hugging Face-hosted models on Vast.ai. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
528 lines
24 KiB
Python
528 lines
24 KiB
Python
from PIL import Image
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import torch
|
|
from collections import namedtuple
|
|
from . import utils
|
|
import inspect
|
|
import logging
|
|
import os
|
|
|
|
import pickle
|
|
import folder_paths
|
|
|
|
|
|
# Unpatched torch.load, captured before this module rebinds torch.load to
# torch_wrapper; the wrapper forwards every real load to this reference.
orig_torch_load = torch.load

# One detected segment. Only the last field (control_net_wrapper) has a
# default, None; every other field must be supplied by the caller.
SEG = namedtuple(
    "SEG",
    (
        'cropped_image',
        'cropped_mask',
        'confidence',
        'crop_region',
        'bbox',
        'label',
        'control_net_wrapper',
    ),
    defaults=(None,),
)
|
|
|
|
|
|
# --- Whitelist Configuration ---
# Resolve <user dir>/default/ComfyUI-Impact-Subpack/model-whitelist.txt and make
# sure its directory exists. WHITELIST_FILE_PATH is left as None whenever any
# step fails; load_whitelist() treats None as "whitelisting disabled".
WHITELIST_DIR = None
WHITELIST_FILE_PATH = None


def _setup_whitelist_location():
    """
    Work out where the model whitelist lives and create its directory.

    Returns:
        tuple: ``(whitelist_dir, whitelist_file_path)``. Both are None when
        ComfyUI's user directory is unusable. The directory is kept but the
        file path is None when the directory cannot be created.
    """
    user_dir = folder_paths.get_user_directory()
    if not (user_dir and os.path.isdir(user_dir)):
        logging.warning(f"[Impact Pack/Subpack] folder_paths.get_user_directory() returned invalid path: {user_dir}.")
        logging.error("[Impact Pack/Subpack] Whitelist path determination failed using all methods. Whitelisting disabled.")
        return None, None

    target_dir = os.path.join(user_dir, "default", "ComfyUI-Impact-Subpack")
    target_file = os.path.join(target_dir, "model-whitelist.txt")
    logging.info(f"[Impact Pack/Subpack] Using folder_paths to determine whitelist path: {target_file}")

    # The directory has to exist before load_whitelist() can create the file in it.
    try:
        os.makedirs(target_dir, exist_ok=True)
        logging.info(f"[Impact Pack/Subpack] Ensured whitelist directory exists: {target_dir}")
    except OSError as e:
        logging.error(f"[Impact Pack/Subpack] Failed to create whitelist directory {target_dir}: {e}. Whitelisting may not function.")
        return target_dir, None
    except Exception as e:
        logging.error(f"[Impact Pack/Subpack] Unexpected error creating whitelist directory: {e}", exc_info=True)
        return target_dir, None

    return target_dir, target_file


try:
    WHITELIST_DIR, WHITELIST_FILE_PATH = _setup_whitelist_location()
except Exception as e:
    # Anything unexpected (e.g. folder_paths itself failing) disables whitelisting.
    logging.error(f"[Impact Pack/Subpack] Critical error during whitelist path setup: {e}", exc_info=True)
    WHITELIST_FILE_PATH = None
    logging.error("[Impact Pack/Subpack] Whitelisting disabled due to critical setup error.")
|
|
|
|
|
|
def load_whitelist(filepath):
    """
    Load the set of approved model base filenames from the whitelist file.

    Blank lines and lines starting with '#' are ignored. Every other line is
    stripped and reduced to its base filename (``os.path.basename``), so
    entries may be written as bare names or as paths.

    If the file does not exist, an empty whitelist containing instructions is
    created (best effort; failure is logged, not raised).

    Args:
        filepath: Path to the whitelist file (``str`` or ``os.PathLike``).
            ``None`` or any other type means whitelisting is disabled.

    Returns:
        set[str]: Approved base filenames. It is empty when the path is unusable
        or the file is missing/empty. If reading fails part-way, it holds the
        entries read so far.
    """
    approved_files = set()

    if isinstance(filepath, os.PathLike):
        filepath = os.fspath(filepath)
    if not isinstance(filepath, str):
        # None means path setup already failed and was logged there; stay quiet.
        return approved_files

    try:
        # utf-8-sig decodes UTF-8 identically on every platform and drops a
        # leading BOM (added by some editors), which would otherwise become
        # part of the first entry and make it never match.
        with open(filepath, 'r', encoding='utf-8-sig') as f:
            for raw_line in f:
                entry = raw_line.strip()
                if entry and not entry.startswith('#'):
                    # Only the base filename is stored; torch_wrapper matches on base names.
                    approved_files.add(os.path.basename(entry))
        logging.info(f"[Impact Pack/Subpack] Loaded {len(approved_files)} model(s) from whitelist: {filepath}")

    except FileNotFoundError:
        logging.warning(f"[Impact Pack/Subpack] Model whitelist file not found at: {filepath}. ")
        logging.warning(" >> An empty whitelist file will be created.")
        logging.warning(" >> To allow unsafe loading for specific trusted legacy models (e.g., older .pt),")
        logging.warning(" >> add their base filenames (one per line) to this file.")
        try:
            with open(filepath, 'w', encoding='utf-8') as f:
                f.write("# Add base filenames of trusted models (e.g., my_old_yolo.pt) here, one per line.\n")
                f.write("# This allows loading them with `weights_only=False` if they fail safe loading\n")
                f.write("# due to errors like 'restricted getattr' in newer PyTorch versions.\n")
                f.write("# WARNING: Only add files you absolutely trust, as this bypasses a security feature.\n")
                f.write("# Prefer using .safetensors files whenever possible.\n")
            logging.info(f"[Impact Pack/Subpack] Created empty whitelist file: {filepath}")
        except Exception as create_e:
            # Typically the parent directory is missing or not writable.
            logging.error(f"[Impact Pack/Subpack] Failed to create empty whitelist file at {filepath}: {create_e}", exc_info=True)

    except Exception as e:
        logging.error(f"[Impact Pack/Subpack] Error loading model whitelist from {filepath}: {e}", exc_info=True)

    return approved_files
|
|
|
|
# Initial whitelist snapshot, read once at import time. It is an empty set when
# WHITELIST_FILE_PATH is None. torch_wrapper re-reads the file when a blocked
# model is not found in this cached set.
_MODEL_WHITELIST = load_whitelist(filepath=WHITELIST_FILE_PATH)

# ---------- End of Whitelist Management ----------
|
|
|
|
class NO_BBOX_DETECTOR:
    """Empty marker type with no attributes or methods; presumably stands for "no bbox detector" — confirm at call sites."""
|
|
|
|
|
|
class NO_SEGM_DETECTOR:
    """Empty marker type with no attributes or methods; presumably stands for "no segm detector" — confirm at call sites."""
|
|
|
|
|
|
def create_segmasks(results):
    """
    Zip the parallel detection lists into ``(bbox, mask, confidence)`` triples.

    Args:
        results: The ``[labels, bboxes, masks, confidences]`` list produced by
            ``inference_bbox`` / ``inference_segm``. Labels (index 0) are ignored.

    Returns:
        list[tuple]: One ``(bbox, mask_as_float32, confidence)`` per mask, in order.
    """
    boxes, masks, scores = results[1], results[2], results[3]
    return [(boxes[idx], mask.astype(np.float32), scores[idx]) for idx, mask in enumerate(masks)]
|
|
|
|
|
|
def _describe_getattr_target(obj):
    """
    Return a printable ``module.name`` label for *obj* without assuming either attribute exists.

    Module objects have no ``__module__`` and most instances have no ``__name__``.
    Reading them directly would raise AttributeError and hide the intended
    RuntimeError in ``restricted_getattr``.
    """
    owner = getattr(obj, '__name__', None)
    if not isinstance(owner, str):
        owner = type(obj).__name__
    module = getattr(obj, '__module__', None)
    return f"{module}.{owner}" if isinstance(module, str) else owner


# Stand-in for builtins.getattr during unpickling. Only the attribute name
# "forward" is allowed (checked on the name alone, for any object); this is
# meant for `ultralytics.nn.modules.head.Detect.forward`.
def restricted_getattr(obj, name, *args):
    """
    ``getattr`` that only resolves ``forward``.

    Args:
        obj: Object whose attribute is requested.
        name: Attribute name; anything other than ``"forward"`` is refused.
        *args: Optional default, forwarded to the builtin ``getattr``.

    Returns:
        The ``forward`` attribute (or the supplied default if it is missing).

    Raises:
        RuntimeError: for any attribute name other than ``"forward"``.
    """
    if name != "forward":
        target = f"{_describe_getattr_target(obj)}.{name}"
        logging.error(f"Access to potentially dangerous attribute '{target}' is blocked.\nIf you believe the use of this code is genuinely safe, please report it.\nhttps://github.com/ltdrdata/ComfyUI-Impact-Subpack/issues")
        raise RuntimeError(f"Access to potentially dangerous attribute '{target}' is blocked.")

    return getattr(obj, name, *args)


# Present ourselves as builtins.getattr so pickled references to that global resolve here.
restricted_getattr.__module__ = 'builtins'
restricted_getattr.__name__ = 'getattr'
|
|
|
|
|
|
# Guarded third-party imports. Any failure is logged with a pointer to the
# troubleshooting guide and then re-raised, so the node fails loudly at import.
# YOLO is used by load_yolo below; the other names are not referenced in the
# visible code and are presumably kept for compatibility — confirm before removing.
try:
    from ultralytics import YOLO
    from ultralytics.nn.tasks import DetectionModel
    from ultralytics.nn.tasks import SegmentationModel
    from ultralytics.utils import IterableSimpleNamespace
    from ultralytics.utils.tal import TaskAlignedAssigner
    import ultralytics.nn.modules as modules
    import ultralytics.nn.modules.block as block_modules
    import torch.nn.modules as torch_modules
    import ultralytics.utils.loss as loss_modules
    import dill._dill
    from numpy.core.multiarray import scalar

    try:
        from numpy import dtype
        from numpy.dtypes import Float64DType
    except ImportError as numpy_err:
        # numpy.dtypes / Float64DType are missing from older numpy releases.
        # Only import failures mean "numpy too old"; anything else propagates as-is.
        logging.error("[Impact Subpack] installed 'numpy' is outdated. Please update 'numpy>=1.26.4'")
        raise Exception("[Impact Subpack] installed 'numpy' is outdated. Please update 'numpy>=1.26.4'") from numpy_err

    torch_whitelist = []

except Exception as e:
    logging.error(e)
    logging.error("\n!!!!!\n\n[ComfyUI-Impact-Subpack] If this error occurs, please check the following link:\n\thttps://github.com/ltdrdata/ComfyUI-Impact-Pack/blob/Main/troubleshooting/TROUBLESHOOTING.md\n\n!!!!!\n")
    raise
|
|
|
|
# HOTFIX: https://github.com/ltdrdata/ComfyUI-Impact-Pack/issues/754
# importing YOLO breaking original torch.load capabilities

_LOAD_BANNER = "##############################################################################"


def _resolve_load_target(args, kwargs):
    """
    Return ``(base_filename, source_label)`` for the object handed to ``torch.load``.

    The target is the first positional argument, or the ``f`` keyword. ``str``
    and ``os.PathLike`` targets yield their base filename (the whitelist stores
    base names) and the full path as the label. Anything else, e.g. an open
    file object, yields ``(None, "[unknown source]")``.
    """
    for target in (args[0] if args else None, kwargs.get('f')):
        if isinstance(target, os.PathLike):
            target = os.fspath(target)
        if isinstance(target, str):
            return os.path.basename(target), target
    return None, "[unknown source]"


def _unsafe_retry_load(args, kwargs):
    """Re-run the original ``torch.load`` with ``weights_only=False`` (whitelisted files only)."""
    retry_kwargs = dict(kwargs)
    retry_kwargs['weights_only'] = False
    return orig_torch_load(*args, **retry_kwargs)


def _log_blocked_load(filename, filename_arg_source, error, whitelist_path_msg):
    """Log, at ERROR level, why an unsafe load was refused and how to allow it."""
    logging.error(_LOAD_BANNER)
    logging.error(f"[Impact Pack/Subpack] ERROR: Safe load failed for '{filename_arg_source}' (Reason: {error}).")
    logging.error(" >> This model likely uses legacy Python features blocked by default for security.")
    logging.error(f" >> UNSAFE LOAD BLOCKED because the file ('{filename or 'unknown'}') is NOT in the whitelist (even after reload attempt).")
    logging.error(f" >> Whitelist path: {whitelist_path_msg}")
    if filename:
        logging.error(" >> To allow loading this specific file (IF YOU TRUST IT), ensure its base name")
        logging.error(f" >> ('{filename}') is correctly added to the whitelist file (one name per line) and saved.")
    else:
        logging.error(" >> Cannot determine filename to check against whitelist.")
    logging.error(" >> SECURITY RISK: Only whitelist files from sources you absolutely trust.")
    logging.error(" >> Prefer using .safetensors files whenever available.")
    logging.error(_LOAD_BANNER)


def _legacy_torch_load(args, kwargs, filename_arg_source):
    """
    ``torch.load`` path for PyTorch builds without ``torch.serialization.safe_globals``.

    Uses the caller's ``weights_only`` when given (and not None), otherwise ``False``.
    """
    weights_only = kwargs.get('weights_only', None)
    if weights_only is None:
        weights_only = False  # Default False for old torch
    load_kwargs = dict(kwargs)
    load_kwargs['weights_only'] = weights_only

    if not weights_only:
        logging.warning(f"[Impact Pack/Subpack] Older PyTorch version detected. Proceeding with potentially unsafe load (weights_only=False) for: {filename_arg_source}")
    else:
        logging.debug(f"[Impact Pack/Subpack] Older PyTorch version detected. Proceeding with explicit weights_only=True for: {filename_arg_source}")

    return orig_torch_load(*args, **load_kwargs)


def torch_wrapper(*args, **kwargs):
    """
    Drop-in replacement for ``torch.load`` that gates unsafe loads behind a whitelist.

    On PyTorch builds that expose ``torch.serialization.safe_globals``, the
    call is first forwarded unchanged to the original ``torch.load``.

    If that raises an ``UnpicklingError`` whose message mentions ``getattr``
    (the disallowed-global failure seen with legacy checkpoints), the file's
    base name is checked against ``_MODEL_WHITELIST``. When the name is known
    and missing, the whitelist file is re-read once.

    Only a whitelisted file is retried with ``weights_only=False``; errors from
    that retry propagate unchanged. Every other failure is re-raised.

    On older builds the load runs with the caller's ``weights_only``, or
    ``False`` when none was given.

    NOTE(review): the first attempt uses the caller's kwargs as-is. A caller
    passing ``weights_only=False``, or a PyTorch version whose default is
    ``False``, therefore gets an unrestricted load, and the whitelist is never
    consulted.

    Raises:
        pickle.UnpicklingError: when the load fails and the file is not whitelisted.
    """
    # The whitelist may be refreshed from disk below.
    global _MODEL_WHITELIST

    filename, filename_arg_source = _resolve_load_target(args, kwargs)

    if not hasattr(torch.serialization, 'safe_globals'):
        return _legacy_torch_load(args, kwargs, filename_arg_source)

    try:
        logging.debug(f"[Impact Pack/Subpack] Attempting load (weights_only={kwargs.get('weights_only', 'torch default')}) for: {filename_arg_source}")
        return orig_torch_load(*args, **kwargs)

    except pickle.UnpicklingError as e:
        # Only the disallowed-global ('getattr') failure is eligible for a whitelisted retry.
        if 'getattr' not in str(e):
            logging.error(f"[Impact Pack/Subpack] UnpicklingError during safe load (not 'getattr' related): {e}. Re-raising.")
            raise

        if filename and filename in _MODEL_WHITELIST:
            logging.warning(_LOAD_BANNER)
            logging.warning(f"[Impact Pack/Subpack] WARNING: Safe load failed for '{filename}' (Reason: {e}).")
            logging.warning(f" >> FILE IS IN THE WHITELIST: {WHITELIST_FILE_PATH}")
            logging.warning(" >> This model likely uses legacy Python features blocked by default for security.")
            logging.warning(" >> RETRYING WITH 'weights_only=False' because it's whitelisted.")
            logging.warning(" >> SECURITY RISK: Ensure you added this file to the whitelist consciously")
            logging.warning(f" >> and trust its source: {filename_arg_source}")
            logging.warning(" >> Prefer using .safetensors files whenever available.")
            logging.warning(_LOAD_BANNER)
            return _unsafe_retry_load(args, kwargs)

        whitelist_path_msg = WHITELIST_FILE_PATH if WHITELIST_FILE_PATH else "[Path not determined]"

        # The user may have edited the whitelist since startup; re-read it once.
        # Skipped when there is no filename to match.
        if filename:
            logging.warning(f"[Impact Pack/Subpack] File '{filename}' not found in current whitelist cache.")
            logging.info(f"[Impact Pack/Subpack] Attempting to reload whitelist from: {whitelist_path_msg}")
            try:
                _MODEL_WHITELIST = load_whitelist(WHITELIST_FILE_PATH)
                logging.info(f"[Impact Pack/Subpack] Whitelist reloaded. Now contains {len(_MODEL_WHITELIST)} entries.")
            except Exception as reload_e:
                logging.error(f"[Impact Pack/Subpack] Error occurred during whitelist reload attempt: {reload_e}", exc_info=True)
            else:
                if filename in _MODEL_WHITELIST:
                    logging.warning(_LOAD_BANNER)
                    logging.warning(f"[Impact Pack/Subpack] SUCCESS: File '{filename}' FOUND in reloaded whitelist.")
                    logging.warning(" >> Proceeding with whitelisted unsafe load (weights_only=False).")
                    logging.warning(f" >> Ensure you recently added this file to: {whitelist_path_msg}")
                    logging.warning(" >> SECURITY RISK: Ensure you trust its source.")
                    logging.warning(_LOAD_BANNER)
                    # Kept outside the reload try-block: a genuine failure of
                    # the retry must surface as-is, not be mislabelled as a reload error.
                    return _unsafe_retry_load(args, kwargs)
                logging.error("[Impact Pack/Subpack] File still not found in whitelist after reload.")

        _log_blocked_load(filename, filename_arg_source, e, whitelist_path_msg)
        raise  # the original security-related UnpicklingError


# Route every torch.load in the process through the whitelist-aware wrapper.
torch.load = torch_wrapper
|
|
|
|
|
|
def load_yolo(model_path: str):
    """
    Build an ultralytics ``YOLO`` model from ``model_path`` and return it.

    ``YOLO`` comes from the guarded ultralytics import earlier in this module.
    """
    yolo_model = YOLO(model_path)
    return yolo_model
|
|
|
|
|
|
def inference_bbox(
    model,
    image: Image.Image,
    confidence: float = 0.3,
    device: str = "",
):
    """
    Run a detection model on ``image`` and build a filled rectangular mask per box.

    Args:
        model: Detector callable (e.g. an ultralytics ``YOLO`` instance), invoked
            as ``model(image, conf=confidence, device=device)``.
        image: Input PIL image.
        confidence: Confidence threshold forwarded to the model.
        device: Device string forwarded to the model.

    Returns:
        list: ``[labels, bboxes, masks, confidences]``, four parallel lists.
        ``bboxes`` are xyxy rows and ``masks`` are boolean ``(H, W)`` arrays.
        Returns ``[[], [], [], []]`` when nothing is detected.
    """
    pred = model(image, conf=confidence, device=device)
    result = pred[0]

    bboxes = result.boxes.xyxy.cpu().numpy()
    if len(bboxes) == 0:
        return [[], [], [], []]

    # Masks only need the image's (height, width). The first two array dims give
    # that for grayscale and multi-channel images alike, so no colour conversion
    # is needed.
    mask_shape = np.asarray(image).shape[:2]

    results = [[], [], [], []]
    for i, (x0, y0, x1, y1) in enumerate(bboxes):
        box_mask = np.zeros(mask_shape, np.uint8)
        cv2.rectangle(box_mask, (int(x0), int(y0)), (int(x1), int(y1)), 255, -1)  # -1 = filled

        box = result.boxes[i]
        results[0].append(result.names[int(box.cls.item())])
        results[1].append(bboxes[i])
        results[2].append(box_mask.astype(bool))
        results[3].append(box.conf.cpu().numpy())

    return results
|
|
|
|
|
|
def inference_segm(
    model,
    image: Image.Image,
    confidence: float = 0.3,
    device: str = "",
):
    """
    Run a segmentation model on ``image`` and return masks resized to the image.

    Args:
        model: Segmentation callable (e.g. an ultralytics ``YOLO`` instance),
            invoked as ``model(image, conf=confidence, device=device)``.
        image: Input PIL image.
        confidence: Confidence threshold forwarded to the model.
        device: Device string forwarded to the model.

    Returns:
        list: ``[labels, bboxes, masks, confidences]``, four parallel lists.
        ``masks`` are float ``(H, W)`` arrays matching ``image.size``.
        Returns ``[[], [], [], []]`` when nothing is detected.
    """
    pred = model(image, conf=confidence, device=device)
    result = pred[0]

    bboxes = result.boxes.xyxy.cpu().numpy()
    if len(bboxes) == 0:
        # NOTE: masks.data will be None when n == 0
        return [[], [], [], []]

    segms = result.masks.data.cpu().numpy()
    h_segms, w_segms = segms.shape[1], segms.shape[2]
    w_orig, h_orig = image.size

    # Assumes the masks are at the letterboxed inference resolution: the image
    # is scaled uniformly to fit, then padded evenly on both sides of one axis
    # (ultralytics LetterBox). Work out the per-side padding so it can be cropped.
    scale = min(h_segms / h_orig, w_segms / w_orig)
    h_gap = max(round((h_segms - h_orig * scale) / 2), 0)
    w_gap = max(round((w_segms - w_orig * scale) / 2), 0)

    results = [[], [], [], []]
    for i, bbox in enumerate(bboxes):
        box = result.boxes[i]
        results[0].append(result.names[int(box.cls.item())])
        results[1].append(bbox)

        mask = torch.from_numpy(segms[i])
        mask = mask[h_gap:mask.shape[0] - h_gap, w_gap:mask.shape[1] - w_gap]

        scaled_mask = torch.nn.functional.interpolate(mask[None, None], size=(h_orig, w_orig),
                                                      mode='bilinear', align_corners=False)
        # Index instead of squeeze() so a 1-pixel-high/wide image keeps its 2-D shape.
        results[2].append(scaled_mask[0, 0].numpy())
        results[3].append(box.conf.cpu().numpy())

    return results
|
|
|
|
|
|
class UltraBBoxDetector:
    """
    Adapter around an ultralytics bbox model that produces SEGS.

    ``detect`` returns ``((height, width), [SEG, ...])`` with rectangular
    masks. ``detect_combined`` returns ``utils.combine_masks`` of all masks.
    """

    # Class-level default; __init__ sets the per-instance model.
    bbox_model = None

    def __init__(self, bbox_model):
        self.bbox_model = bbox_model

    def detect(self, image, threshold, dilation, crop_factor, drop_size=1, detailer_hook=None):
        """
        Detect boxes in ``image`` and package each kept box as a SEG.

        Args:
            image: Image tensor; dims 1 and 2 are read as height and width
                (presumably ``(batch, H, W, C)`` — TODO confirm).
            threshold: Confidence threshold passed to the model.
            dilation: When positive, masks are dilated via ``utils.dilate_masks``.
            crop_factor: Expansion factor passed to ``utils.make_crop_region``.
            drop_size: Boxes are kept only if both sides exceed this (clamped to >= 1).
            detailer_hook: Optional hook. ``post_crop_region`` adjusts each crop
                region; ``post_detection``, if present, post-processes the result.

        Returns:
            ``((height, width), [SEG, ...])``, or whatever ``post_detection`` returns.
        """
        drop_size = max(drop_size, 1)
        detected = inference_bbox(self.bbox_model, utils.tensor2pil(image), threshold)
        masks = create_segmasks(detected)

        if dilation > 0:
            masks = utils.dilate_masks(masks, dilation)

        height, width = image.shape[1], image.shape[2]
        segments = []

        for entry, label in zip(masks, detected[0]):
            bbox = entry[0]
            mask = entry[1]
            left, top, right, bottom = bbox  # xyxy, as produced by inference_bbox

            # Both sides must exceed drop_size (so at least 2 px) to avoid a squeeze issue downstream.
            if not ((bottom - top) > drop_size and (right - left) > drop_size):
                continue

            region = utils.make_crop_region(width, height, bbox, crop_factor)
            if detailer_hook is not None:
                region = detailer_hook.post_crop_region(width, height, bbox, region)

            segments.append(SEG(utils.crop_image(image, region), utils.crop_ndarray2(mask, region),
                                entry[2], region, bbox, label, None))

        segs = (height, width), segments

        if detailer_hook is not None and hasattr(detailer_hook, "post_detection"):
            segs = detailer_hook.post_detection(segs)

        return segs

    def detect_combined(self, image, threshold, dilation):
        """Detect boxes in ``image`` and return all (optionally dilated) masks merged into one."""
        masks = create_segmasks(inference_bbox(self.bbox_model, utils.tensor2pil(image), threshold))
        if dilation > 0:
            masks = utils.dilate_masks(masks, dilation)
        return utils.combine_masks(masks)

    def setAux(self, x):
        """No-op; presumably present so callers can treat all detectors uniformly."""
|
|
|
|
|
|
class UltraSegmDetector:
    """
    Adapter around an ultralytics segmentation model that produces SEGS.

    ``detect`` returns ``((height, width), [SEG, ...])`` with per-instance
    masks. ``detect_combined`` returns ``utils.combine_masks`` of all masks.
    """

    # Class-level default; __init__ sets the per-instance model.
    bbox_model = None

    def __init__(self, bbox_model):
        self.bbox_model = bbox_model

    def detect(self, image, threshold, dilation, crop_factor, drop_size=1, detailer_hook=None):
        """
        Segment ``image`` and package each kept instance as a SEG.

        Args:
            image: Image tensor; dims 1 and 2 are read as height and width
                (presumably ``(batch, H, W, C)`` — TODO confirm).
            threshold: Confidence threshold passed to the model.
            dilation: When positive, masks are dilated via ``utils.dilate_masks``.
            crop_factor: Expansion factor passed to ``utils.make_crop_region``.
            drop_size: Instances are kept only if both bbox sides exceed this (clamped to >= 1).
            detailer_hook: Optional hook. ``post_crop_region`` adjusts each crop
                region; ``post_detection``, if present, post-processes the result.

        Returns:
            ``((height, width), [SEG, ...])``, or whatever ``post_detection`` returns.
        """
        drop_size = max(drop_size, 1)
        detected = inference_segm(self.bbox_model, utils.tensor2pil(image), threshold)
        masks = create_segmasks(detected)

        if dilation > 0:
            masks = utils.dilate_masks(masks, dilation)

        height, width = image.shape[1], image.shape[2]
        segments = []

        for entry, label in zip(masks, detected[0]):
            bbox = entry[0]
            mask = entry[1]
            left, top, right, bottom = bbox  # xyxy, as produced by inference_segm

            # Both sides must exceed drop_size (so at least 2 px) to avoid a squeeze issue downstream.
            if not ((bottom - top) > drop_size and (right - left) > drop_size):
                continue

            region = utils.make_crop_region(width, height, bbox, crop_factor)
            if detailer_hook is not None:
                region = detailer_hook.post_crop_region(width, height, bbox, region)

            segments.append(SEG(utils.crop_image(image, region), utils.crop_ndarray2(mask, region),
                                entry[2], region, bbox, label, None))

        segs = (height, width), segments

        if detailer_hook is not None and hasattr(detailer_hook, "post_detection"):
            segs = detailer_hook.post_detection(segs)

        return segs

    def detect_combined(self, image, threshold, dilation):
        """Segment ``image`` and return all (optionally dilated) masks merged into one."""
        masks = create_segmasks(inference_segm(self.bbox_model, utils.tensor2pil(image), threshold))
        if dilation > 0:
            masks = utils.dilate_masks(masks, dilation)
        return utils.combine_masks(masks)

    def setAux(self, x):
        """No-op; presumably present so callers can treat all detectors uniformly."""
|