Add custom nodes, Civitai loras (LFS), and vast.ai setup script
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Includes 30 custom nodes committed directly, 7 Civitai-exclusive loras stored via Git LFS, and a setup script that installs all dependencies and downloads HuggingFace-hosted models on vast.ai. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
527
custom_nodes/comfyui-impact-subpack/modules/subcore.py
Normal file
527
custom_nodes/comfyui-impact-subpack/modules/subcore.py
Normal file
@@ -0,0 +1,527 @@
|
||||
from PIL import Image
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from collections import namedtuple
|
||||
from . import utils
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
|
||||
import pickle
|
||||
import folder_paths
|
||||
|
||||
|
||||
orig_torch_load = torch.load
|
||||
|
||||
|
||||
SEG = namedtuple("SEG",
|
||||
['cropped_image', 'cropped_mask', 'confidence', 'crop_region', 'bbox', 'label', 'control_net_wrapper'],
|
||||
defaults=[None])
|
||||
|
||||
|
||||
# --- Whitelist Configuration ---
|
||||
WHITELIST_DIR = None
|
||||
WHITELIST_FILE_PATH = None
|
||||
|
||||
try:
|
||||
# --- Attempting: Use ComfyUI's folder_paths (Preferred Method) ---
|
||||
user_dir = folder_paths.get_user_directory()
|
||||
if user_dir and os.path.isdir(user_dir):
|
||||
WHITELIST_DIR = os.path.join(user_dir, "default", "ComfyUI-Impact-Subpack")
|
||||
WHITELIST_FILE_PATH = os.path.join(WHITELIST_DIR, "model-whitelist.txt")
|
||||
logging.info(f"[Impact Pack/Subpack] Using folder_paths to determine whitelist path: {WHITELIST_FILE_PATH}")
|
||||
else:
|
||||
logging.warning(f"[Impact Pack/Subpack] folder_paths.get_user_directory() returned invalid path: {user_dir}.")
|
||||
|
||||
# --- Ensure directory exists---
|
||||
if WHITELIST_FILE_PATH: # Check if any method succeeded in setting the path
|
||||
try:
|
||||
# Crucially, create the DIRECTORY first
|
||||
# Use the WHITELIST_DIR determined by one of the methods above
|
||||
os.makedirs(WHITELIST_DIR, exist_ok=True)
|
||||
logging.info(f"[Impact Pack/Subpack] Ensured whitelist directory exists: {WHITELIST_DIR}")
|
||||
except OSError as e:
|
||||
logging.error(f"[Impact Pack/Subpack] Failed to create whitelist directory {WHITELIST_DIR}: {e}. Whitelisting may not function.")
|
||||
WHITELIST_FILE_PATH = None # Indicate failure / disable whitelisting
|
||||
except Exception as e:
|
||||
logging.error(f"[Impact Pack/Subpack] Unexpected error creating whitelist directory: {e}", exc_info=True)
|
||||
WHITELIST_FILE_PATH = None # Indicate failure / disable whitelisting
|
||||
else:
|
||||
# Handle case where path determination failed via all methods
|
||||
logging.error("[Impact Pack/Subpack] Whitelist path determination failed using all methods. Whitelisting disabled.")
|
||||
# WHITELIST_FILE_PATH is already None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
# Catch errors during the whole setup process (e.g., inspect failing)
|
||||
logging.error(f"[Impact Pack/Subpack] Critical error during whitelist path setup: {e}", exc_info=True)
|
||||
WHITELIST_FILE_PATH = None # Disable whitelisting on critical setup error
|
||||
logging.error("[Impact Pack/Subpack] Whitelisting disabled due to critical setup error.")
|
||||
|
||||
|
||||
def load_whitelist(filepath):
|
||||
"""
|
||||
Loads filenames from the whitelist file.
|
||||
Attempts to create the file with instructions if it doesn't exist.
|
||||
Returns a set of approved base filenames.
|
||||
"""
|
||||
approved_files = set()
|
||||
# Check again if filepath is valid before proceeding
|
||||
if filepath is None or not isinstance(filepath, str):
|
||||
# Log was already done if None during setup, avoid duplicate messages
|
||||
# logging.error("[Impact Pack/Subpack] Whitelist file path is invalid. Whitelisting disabled.")
|
||||
return approved_files # Return empty set
|
||||
|
||||
try:
|
||||
# Try reading the existing file
|
||||
with open(filepath, 'r') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
# Store only the base filename for easier matching
|
||||
if line and not line.startswith('#'):
|
||||
approved_files.add(os.path.basename(line))
|
||||
logging.info(f"[Impact Pack/Subpack] Loaded {len(approved_files)} model(s) from whitelist: {filepath}")
|
||||
|
||||
except FileNotFoundError:
|
||||
# This block now runs only if the directory was created successfully but the file is missing
|
||||
logging.warning(f"[Impact Pack/Subpack] Model whitelist file not found at: {filepath}. ")
|
||||
logging.warning(" >> An empty whitelist file will be created.")
|
||||
logging.warning(" >> To allow unsafe loading for specific trusted legacy models (e.g., older .pt),")
|
||||
logging.warning(" >> add their base filenames (one per line) to this file.")
|
||||
try:
|
||||
# Attempt to create the file with comments since it wasn't found
|
||||
# This should now succeed because os.makedirs created the directory
|
||||
with open(filepath, 'w') as f:
|
||||
f.write("# Add base filenames of trusted models (e.g., my_old_yolo.pt) here, one per line.\n")
|
||||
f.write("# This allows loading them with `weights_only=False` if they fail safe loading\n")
|
||||
f.write("# due to errors like 'restricted getattr' in newer PyTorch versions.\n")
|
||||
f.write("# WARNING: Only add files you absolutely trust, as this bypasses a security feature.\n")
|
||||
f.write("# Prefer using .safetensors files whenever possible.\n")
|
||||
logging.info(f"[Impact Pack/Subpack] Created empty whitelist file: {filepath}")
|
||||
except Exception as create_e:
|
||||
# Log error if creating the file fails even after creating the directory
|
||||
logging.error(f"[Impact Pack/Subpack] Failed to create empty whitelist file at {filepath}: {create_e}", exc_info=True)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"[Impact Pack/Subpack] Error loading model whitelist from {filepath}: {e}", exc_info=True)
|
||||
|
||||
return approved_files
|
||||
|
||||
# Now call the function using the dynamically determined (or None) path
|
||||
_MODEL_WHITELIST = load_whitelist(WHITELIST_FILE_PATH)
|
||||
|
||||
# ---------- End of Whitelist Management ----------
|
||||
|
||||
class NO_BBOX_DETECTOR:
|
||||
pass
|
||||
|
||||
|
||||
class NO_SEGM_DETECTOR:
|
||||
pass
|
||||
|
||||
|
||||
def create_segmasks(results):
|
||||
bboxs = results[1]
|
||||
segms = results[2]
|
||||
confidence = results[3]
|
||||
|
||||
results = []
|
||||
for i in range(len(segms)):
|
||||
item = (bboxs[i], segms[i].astype(np.float32), confidence[i])
|
||||
results.append(item)
|
||||
return results
|
||||
|
||||
|
||||
# Limit the commands that can be executed through `getattr` to `ultralytics.nn.modules.head.Detect.forward`.
|
||||
def restricted_getattr(obj, name, *args):
|
||||
if name != "forward":
|
||||
logging.error(f"Access to potentially dangerous attribute '{obj.__module__}.{obj.__name__}.{name}' is blocked.\nIf you believe the use of this code is genuinely safe, please report it.\nhttps://github.com/ltdrdata/ComfyUI-Impact-Subpack/issues")
|
||||
raise RuntimeError(f"Access to potentially dangerous attribute '{obj.__module__}.{obj.__name__}.{name}' is blocked.")
|
||||
|
||||
return getattr(obj, name, *args)
|
||||
|
||||
restricted_getattr.__module__ = 'builtins'
|
||||
restricted_getattr.__name__ = 'getattr'
|
||||
|
||||
|
||||
try:
|
||||
from ultralytics import YOLO
|
||||
from ultralytics.nn.tasks import DetectionModel
|
||||
from ultralytics.nn.tasks import SegmentationModel
|
||||
from ultralytics.utils import IterableSimpleNamespace
|
||||
from ultralytics.utils.tal import TaskAlignedAssigner
|
||||
import ultralytics.nn.modules as modules
|
||||
import ultralytics.nn.modules.block as block_modules
|
||||
import torch.nn.modules as torch_modules
|
||||
import ultralytics.utils.loss as loss_modules
|
||||
import dill._dill
|
||||
from numpy.core.multiarray import scalar
|
||||
try:
|
||||
from numpy import dtype
|
||||
from numpy.dtypes import Float64DType
|
||||
except:
|
||||
logging.error("[Impact Subpack] installed 'numpy' is outdated. Please update 'numpy>=1.26.4'")
|
||||
raise Exception("[Impact Subpack] installed 'numpy' is outdated. Please update 'numpy>=1.26.4'")
|
||||
|
||||
|
||||
torch_whitelist = []
|
||||
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
logging.error("\n!!!!!\n\n[ComfyUI-Impact-Subpack] If this error occurs, please check the following link:\n\thttps://github.com/ltdrdata/ComfyUI-Impact-Pack/blob/Main/troubleshooting/TROUBLESHOOTING.md\n\n!!!!!\n")
|
||||
raise e
|
||||
|
||||
# HOTFIX: https://github.com/ltdrdata/ComfyUI-Impact-Pack/issues/754
|
||||
# importing YOLO breaking original torch.load capabilities
|
||||
|
||||
# --- Start: REPLACE the existing torch_wrapper function ---
|
||||
|
||||
def torch_wrapper(*args, **kwargs):
|
||||
"""
|
||||
Wrapper for torch.load that attempts safe loading (weights_only=True) first.
|
||||
If a specific UnpicklingError related to disallowed globals (like 'getattr')
|
||||
occurs, it checks a user-defined whitelist (_MODEL_WHITELIST). If the file
|
||||
is whitelisted, it retries with weights_only=False. Otherwise, it blocks
|
||||
the unsafe load and raises the error.
|
||||
"""
|
||||
# Use the globally saved original torch.load reference from the top of the file
|
||||
# Check if weights_only was explicitly passed by the caller
|
||||
# Explicitly declare modification of global scope is intended
|
||||
global _MODEL_WHITELIST
|
||||
weights_only_explicit = kwargs.get('weights_only', None) # Read value without popping yet
|
||||
|
||||
# Try to get the filename being loaded (usually the first arg if it's a path)
|
||||
filename = None
|
||||
filename_arg_source = "[unknown source]"
|
||||
if args and isinstance(args[0], str):
|
||||
filename = os.path.basename(args[0]) # Get just the filename part
|
||||
filename_arg_source = args[0]
|
||||
elif 'f' in kwargs and isinstance(kwargs['f'], str):
|
||||
filename = os.path.basename(kwargs['f']) # Get just the filename part
|
||||
filename_arg_source = kwargs['f']
|
||||
# Note: filename might remain None if loading from a file-like object
|
||||
|
||||
# Check if using newer PyTorch with safe_globals attribute (indicates >= 2.6 behavior likely)
|
||||
if hasattr(torch.serialization, 'safe_globals'):
|
||||
|
||||
# Determine the effective weights_only setting for the FIRST attempt
|
||||
load_kwargs = kwargs.copy()
|
||||
|
||||
try:
|
||||
# --- Attempt 1: Default Load ---
|
||||
# Try loading with the determined weights_only setting (usually True)
|
||||
logging.debug(f"[Impact Pack/Subpack] Attempting safe load (weights_only=True) for: {filename_arg_source}")
|
||||
return orig_torch_load(*args, **load_kwargs)
|
||||
|
||||
except pickle.UnpicklingError as e:
|
||||
# --- Handle Specific Load Failure ---
|
||||
# Check if the error is the specific one caused by disallowed globals
|
||||
# like 'getattr' AND we were attempting a safe load (weights_only=True)
|
||||
# Using 'getattr' because it was the specific error reported.
|
||||
is_disallowed_global_error = 'getattr' in str(e)
|
||||
|
||||
if is_disallowed_global_error:
|
||||
# Check the whitelist
|
||||
if filename and filename in _MODEL_WHITELIST:
|
||||
# --- Fallback: Whitelisted Unsafe Load ---
|
||||
logging.warning("##############################################################################")
|
||||
logging.warning(f"[Impact Pack/Subpack] WARNING: Safe load failed for '{filename}' (Reason: {e}).")
|
||||
logging.warning(f" >> FILE IS IN THE WHITELIST: {WHITELIST_FILE_PATH}")
|
||||
logging.warning(" >> This model likely uses legacy Python features blocked by default for security.")
|
||||
logging.warning(" >> RETRYING WITH 'weights_only=False' because it's whitelisted.")
|
||||
logging.warning(" >> SECURITY RISK: Ensure you added this file to the whitelist consciously")
|
||||
logging.warning(f" >> and trust its source: {filename_arg_source}")
|
||||
logging.warning(" >> Prefer using .safetensors files whenever available.")
|
||||
logging.warning("##############################################################################")
|
||||
|
||||
retry_kwargs = kwargs.copy()
|
||||
retry_kwargs['weights_only'] = False
|
||||
# Call the original function again, now unsafely (because whitelisted)
|
||||
return orig_torch_load(*args, **retry_kwargs)
|
||||
|
||||
else:
|
||||
# --- File not in current whitelist, try reloading ---
|
||||
logging.warning(f"[Impact Pack/Subpack] File '{filename}' not found in current whitelist cache.")
|
||||
whitelist_path_msg = WHITELIST_FILE_PATH if WHITELIST_FILE_PATH else "[Path not determined]"
|
||||
logging.info(f"[Impact Pack/Subpack] Attempting to reload whitelist from: {whitelist_path_msg}")
|
||||
try:
|
||||
# Reload the whitelist from the file
|
||||
_MODEL_WHITELIST = load_whitelist(WHITELIST_FILE_PATH)
|
||||
logging.info(f"[Impact Pack/Subpack] Whitelist reloaded. Now contains {len(_MODEL_WHITELIST)} entries.")
|
||||
|
||||
# --- Re-check Whitelist After Reload ---
|
||||
if filename and filename in _MODEL_WHITELIST:
|
||||
logging.warning("##############################################################################")
|
||||
logging.warning(f"[Impact Pack/Subpack] SUCCESS: File '{filename}' FOUND in reloaded whitelist.")
|
||||
logging.warning(" >> Proceeding with whitelisted unsafe load (weights_only=False).")
|
||||
logging.warning(f" >> Ensure you recently added this file to: {whitelist_path_msg}")
|
||||
logging.warning(" >> SECURITY RISK: Ensure you trust its source.")
|
||||
logging.warning("##############################################################################")
|
||||
retry_kwargs = kwargs.copy()
|
||||
retry_kwargs['weights_only'] = False
|
||||
return orig_torch_load(*args, **retry_kwargs)
|
||||
else:
|
||||
# File still not found after reload, proceed with blocking
|
||||
logging.error("[Impact Pack/Subpack] File still not found in whitelist after reload.")
|
||||
# Fall through to the original blocking logic below
|
||||
|
||||
except Exception as reload_e:
|
||||
logging.error(f"[Impact Pack/Subpack] Error occurred during whitelist reload attempt: {reload_e}", exc_info=True)
|
||||
# Fall through to the original blocking logic below if reload fails
|
||||
|
||||
# --- Blocked: Not Whitelisted (Original Logic - runs if reload failed or file still not found) ---
|
||||
logging.error("##############################################################################")
|
||||
logging.error(f"[Impact Pack/Subpack] ERROR: Safe load failed for '{filename_arg_source}' (Reason: {e}).")
|
||||
logging.error(" >> This model likely uses legacy Python features blocked by default for security.")
|
||||
# Updated log message here:
|
||||
logging.error(f" >> UNSAFE LOAD BLOCKED because the file ('{filename or 'unknown'}') is NOT in the whitelist (even after reload attempt).")
|
||||
logging.error(f" >> Whitelist path: {whitelist_path_msg}")
|
||||
if filename:
|
||||
logging.error(" >> To allow loading this specific file (IF YOU TRUST IT), ensure its base name")
|
||||
logging.error(f" >> ('{filename}') is correctly added to the whitelist file (one name per line) and saved.")
|
||||
else:
|
||||
logging.error(" >> Cannot determine filename to check against whitelist.")
|
||||
logging.error(" >> SECURITY RISK: Only whitelist files from sources you absolutely trust.")
|
||||
logging.error(" >> Prefer using .safetensors files whenever available.")
|
||||
logging.error("##############################################################################")
|
||||
raise e # Re-raise the original security-related error
|
||||
|
||||
else:
|
||||
# If it's a different UnpicklingError, re-raise it. Don't attempt unsafe load.
|
||||
logging.error(f"[Impact Pack/Subpack] UnpicklingError during safe load (not 'getattr' related): {e}. Re-raising.")
|
||||
raise e # Re-raise other UnpicklingErrors
|
||||
|
||||
else:
|
||||
# --- Handle Older PyTorch Versions (no safe_globals) ---
|
||||
# Behavior here respects the caller's explicit request or defaults to False
|
||||
load_kwargs = kwargs.copy()
|
||||
effective_weights_only = weights_only_explicit if weights_only_explicit is not None else False # Default False for old torch
|
||||
load_kwargs['weights_only'] = effective_weights_only
|
||||
|
||||
if not effective_weights_only:
|
||||
logging.warning(f"[Impact Pack/Subpack] Older PyTorch version detected. Proceeding with potentially unsafe load (weights_only=False) for: {filename_arg_source}")
|
||||
else:
|
||||
logging.debug(f"[Impact Pack/Subpack] Older PyTorch version detected. Proceeding with explicit weights_only=True for: {filename_arg_source}")
|
||||
|
||||
# Call the original torch.load directly with the determined settings for older PyTorch
|
||||
return orig_torch_load(*args, **load_kwargs)
|
||||
|
||||
# --- End: Replacement block for the torch_wrapper function ---
|
||||
|
||||
torch.load = torch_wrapper
|
||||
|
||||
|
||||
def load_yolo(model_path: str):
|
||||
return YOLO(model_path)
|
||||
|
||||
|
||||
def inference_bbox(
|
||||
model,
|
||||
image: Image.Image,
|
||||
confidence: float = 0.3,
|
||||
device: str = "",
|
||||
):
|
||||
pred = model(image, conf=confidence, device=device)
|
||||
|
||||
bboxes = pred[0].boxes.xyxy.cpu().numpy()
|
||||
cv2_image = np.array(image)
|
||||
if len(cv2_image.shape) == 3:
|
||||
cv2_image = cv2_image[:, :, ::-1].copy() # Convert RGB to BGR for cv2 processing
|
||||
else:
|
||||
# Handle the grayscale image here
|
||||
# For example, you might want to convert it to a 3-channel grayscale image for consistency:
|
||||
cv2_image = cv2.cvtColor(cv2_image, cv2.COLOR_GRAY2BGR)
|
||||
cv2_gray = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
segms = []
|
||||
for x0, y0, x1, y1 in bboxes:
|
||||
cv2_mask = np.zeros(cv2_gray.shape, np.uint8)
|
||||
cv2.rectangle(cv2_mask, (int(x0), int(y0)), (int(x1), int(y1)), 255, -1)
|
||||
cv2_mask_bool = cv2_mask.astype(bool)
|
||||
segms.append(cv2_mask_bool)
|
||||
|
||||
n, m = bboxes.shape
|
||||
if n == 0:
|
||||
return [[], [], [], []]
|
||||
|
||||
results = [[], [], [], []]
|
||||
for i in range(len(bboxes)):
|
||||
results[0].append(pred[0].names[int(pred[0].boxes[i].cls.item())])
|
||||
results[1].append(bboxes[i])
|
||||
results[2].append(segms[i])
|
||||
results[3].append(pred[0].boxes[i].conf.cpu().numpy())
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def inference_segm(
|
||||
model,
|
||||
image: Image.Image,
|
||||
confidence: float = 0.3,
|
||||
device: str = "",
|
||||
):
|
||||
pred = model(image, conf=confidence, device=device)
|
||||
|
||||
bboxes = pred[0].boxes.xyxy.cpu().numpy()
|
||||
n, m = bboxes.shape
|
||||
if n == 0:
|
||||
return [[], [], [], []]
|
||||
|
||||
# NOTE: masks.data will be None when n == 0
|
||||
segms = pred[0].masks.data.cpu().numpy()
|
||||
|
||||
h_segms = segms.shape[1]
|
||||
w_segms = segms.shape[2]
|
||||
h_orig = image.size[1]
|
||||
w_orig = image.size[0]
|
||||
ratio_segms = h_segms / w_segms
|
||||
ratio_orig = h_orig / w_orig
|
||||
|
||||
if ratio_segms == ratio_orig:
|
||||
h_gap = 0
|
||||
w_gap = 0
|
||||
elif ratio_segms > ratio_orig:
|
||||
h_gap = int((ratio_segms - ratio_orig) * h_segms)
|
||||
w_gap = 0
|
||||
else:
|
||||
h_gap = 0
|
||||
ratio_segms = w_segms / h_segms
|
||||
ratio_orig = w_orig / h_orig
|
||||
w_gap = int((ratio_segms - ratio_orig) * w_segms)
|
||||
|
||||
results = [[], [], [], []]
|
||||
for i in range(len(bboxes)):
|
||||
results[0].append(pred[0].names[int(pred[0].boxes[i].cls.item())])
|
||||
results[1].append(bboxes[i])
|
||||
|
||||
mask = torch.from_numpy(segms[i])
|
||||
mask = mask[h_gap:mask.shape[0] - h_gap, w_gap:mask.shape[1] - w_gap]
|
||||
|
||||
scaled_mask = torch.nn.functional.interpolate(mask.unsqueeze(0).unsqueeze(0), size=(image.size[1], image.size[0]),
|
||||
mode='bilinear', align_corners=False)
|
||||
scaled_mask = scaled_mask.squeeze().squeeze()
|
||||
|
||||
results[2].append(scaled_mask.numpy())
|
||||
results[3].append(pred[0].boxes[i].conf.cpu().numpy())
|
||||
|
||||
return results
|
||||
|
||||
|
||||
class UltraBBoxDetector:
|
||||
bbox_model = None
|
||||
|
||||
def __init__(self, bbox_model):
|
||||
self.bbox_model = bbox_model
|
||||
|
||||
def detect(self, image, threshold, dilation, crop_factor, drop_size=1, detailer_hook=None):
|
||||
drop_size = max(drop_size, 1)
|
||||
detected_results = inference_bbox(self.bbox_model, utils.tensor2pil(image), threshold)
|
||||
segmasks = create_segmasks(detected_results)
|
||||
|
||||
if dilation > 0:
|
||||
segmasks = utils.dilate_masks(segmasks, dilation)
|
||||
|
||||
items = []
|
||||
h = image.shape[1]
|
||||
w = image.shape[2]
|
||||
|
||||
for x, label in zip(segmasks, detected_results[0]):
|
||||
item_bbox = x[0]
|
||||
item_mask = x[1]
|
||||
|
||||
y1, x1, y2, x2 = item_bbox
|
||||
|
||||
if x2 - x1 > drop_size and y2 - y1 > drop_size: # minimum dimension must be (2,2) to avoid squeeze issue
|
||||
crop_region = utils.make_crop_region(w, h, item_bbox, crop_factor)
|
||||
|
||||
if detailer_hook is not None:
|
||||
crop_region = detailer_hook.post_crop_region(w, h, item_bbox, crop_region)
|
||||
|
||||
cropped_image = utils.crop_image(image, crop_region)
|
||||
cropped_mask = utils.crop_ndarray2(item_mask, crop_region)
|
||||
confidence = x[2]
|
||||
# bbox_size = (item_bbox[2]-item_bbox[0],item_bbox[3]-item_bbox[1]) # (w,h)
|
||||
|
||||
item = SEG(cropped_image, cropped_mask, confidence, crop_region, item_bbox, label, None)
|
||||
|
||||
items.append(item)
|
||||
|
||||
shape = image.shape[1], image.shape[2]
|
||||
segs = shape, items
|
||||
|
||||
if detailer_hook is not None and hasattr(detailer_hook, "post_detection"):
|
||||
segs = detailer_hook.post_detection(segs)
|
||||
|
||||
return segs
|
||||
|
||||
def detect_combined(self, image, threshold, dilation):
|
||||
detected_results = inference_bbox(self.bbox_model, utils.tensor2pil(image), threshold)
|
||||
segmasks = create_segmasks(detected_results)
|
||||
if dilation > 0:
|
||||
segmasks = utils.dilate_masks(segmasks, dilation)
|
||||
|
||||
return utils.combine_masks(segmasks)
|
||||
|
||||
def setAux(self, x):
|
||||
pass
|
||||
|
||||
|
||||
class UltraSegmDetector:
|
||||
bbox_model = None
|
||||
|
||||
def __init__(self, bbox_model):
|
||||
self.bbox_model = bbox_model
|
||||
|
||||
def detect(self, image, threshold, dilation, crop_factor, drop_size=1, detailer_hook=None):
|
||||
drop_size = max(drop_size, 1)
|
||||
detected_results = inference_segm(self.bbox_model, utils.tensor2pil(image), threshold)
|
||||
segmasks = create_segmasks(detected_results)
|
||||
|
||||
if dilation > 0:
|
||||
segmasks = utils.dilate_masks(segmasks, dilation)
|
||||
|
||||
items = []
|
||||
h = image.shape[1]
|
||||
w = image.shape[2]
|
||||
|
||||
for x, label in zip(segmasks, detected_results[0]):
|
||||
item_bbox = x[0]
|
||||
item_mask = x[1]
|
||||
|
||||
y1, x1, y2, x2 = item_bbox
|
||||
|
||||
if x2 - x1 > drop_size and y2 - y1 > drop_size: # minimum dimension must be (2,2) to avoid squeeze issue
|
||||
crop_region = utils.make_crop_region(w, h, item_bbox, crop_factor)
|
||||
|
||||
if detailer_hook is not None:
|
||||
crop_region = detailer_hook.post_crop_region(w, h, item_bbox, crop_region)
|
||||
|
||||
cropped_image = utils.crop_image(image, crop_region)
|
||||
cropped_mask = utils.crop_ndarray2(item_mask, crop_region)
|
||||
confidence = x[2]
|
||||
# bbox_size = (item_bbox[2]-item_bbox[0],item_bbox[3]-item_bbox[1]) # (w,h)
|
||||
|
||||
item = SEG(cropped_image, cropped_mask, confidence, crop_region, item_bbox, label, None)
|
||||
|
||||
items.append(item)
|
||||
|
||||
shape = image.shape[1], image.shape[2]
|
||||
segs = shape, items
|
||||
|
||||
if detailer_hook is not None and hasattr(detailer_hook, "post_detection"):
|
||||
segs = detailer_hook.post_detection(segs)
|
||||
|
||||
return segs
|
||||
|
||||
def detect_combined(self, image, threshold, dilation):
|
||||
detected_results = inference_segm(self.bbox_model, utils.tensor2pil(image), threshold)
|
||||
segmasks = create_segmasks(detected_results)
|
||||
if dilation > 0:
|
||||
segmasks = utils.dilate_masks(segmasks, dilation)
|
||||
|
||||
return utils.combine_masks(segmasks)
|
||||
|
||||
def setAux(self, x):
|
||||
pass
|
||||
Reference in New Issue
Block a user