Add custom nodes, Civitai loras (LFS), and vast.ai setup script
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Includes 30 custom nodes committed directly, 7 Civitai-exclusive loras stored via Git LFS, and a setup script that installs all dependencies and downloads HuggingFace-hosted models on vast.ai. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
251
custom_nodes/ComfyUI-KJNodes/nodes/audioscheduler_nodes.py
Normal file
251
custom_nodes/ComfyUI-KJNodes/nodes/audioscheduler_nodes.py
Normal file
@@ -0,0 +1,251 @@
|
||||
# to be used with https://github.com/a1lazydog/ComfyUI-AudioScheduler
|
||||
import torch
|
||||
from torchvision.transforms import functional as TF
|
||||
from PIL import Image, ImageDraw
|
||||
import numpy as np
|
||||
from ..utility.utility import pil2tensor
|
||||
from nodes import MAX_RESOLUTION
|
||||
|
||||
class NormalizedAmplitudeToMask:
    """Draws one mask frame per value in a normalized amplitude schedule."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "normalized_amp": ("NORMALIZED_AMPLITUDE",),
                    "width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
                    "height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
                    "frame_offset": ("INT", {"default": 0,"min": -255, "max": 255, "step": 1}),
                    "location_x": ("INT", {"default": 256,"min": 0, "max": 4096, "step": 1}),
                    "location_y": ("INT", {"default": 256,"min": 0, "max": 4096, "step": 1}),
                    "size": ("INT", {"default": 128,"min": 8, "max": 4096, "step": 1}),
                    "shape": (['none', 'circle', 'square', 'triangle'], {"default": 'none'}),
                    "color": (['white', 'amplitude'], {"default": 'amplitude'}),
                },}

    CATEGORY = "KJNodes/audio"
    RETURN_TYPES = ("MASK",)
    FUNCTION = "convert"
    DESCRIPTION = """
Works as a bridge to the AudioScheduler -nodes:
https://github.com/a1lazydog/ComfyUI-AudioScheduler
Creates masks based on the normalized amplitude.
"""

    def convert(self, normalized_amp, width, height, frame_offset, shape, location_x, location_y, size, color):
        """Render one grayscale mask per amplitude value.

        Each amplitude (clamped to [0, 1]) controls the brightness
        (color='amplitude') and, for drawn shapes, the shape's half-extent
        (size * amp). frame_offset rotates the amplitude sequence so masks
        can lead or lag the audio.
        """
        # Clamp into [0, 1], then shift the schedule by the requested offset.
        amps = np.roll(np.clip(normalized_amp, 0.0, 1.0), frame_offset)

        frames = []
        for amp in amps:
            # 'white' is full brightness; 'amplitude' scales gray with the signal.
            gray = 255 if color == 'white' else int(amp * 255)
            fill = (gray, gray, gray)
            # Half-extent of the drawn shape grows with the amplitude.
            extent = size * amp

            if shape == 'none':
                # No shape: the whole canvas carries the gray value.
                canvas = Image.new("RGB", (width, height), fill)
            else:
                # Shapes are drawn on a black background.
                canvas = Image.new("RGB", (width, height), "black")
                painter = ImageDraw.Draw(canvas)
                if shape in ('circle', 'square'):
                    # Bounding box centered on (location_x, location_y).
                    corners = [(location_x - extent, location_y - extent),
                               (location_x + extent, location_y + extent)]
                    if shape == 'circle':
                        painter.ellipse(corners, fill=fill)
                    else:
                        painter.rectangle(corners, fill=fill)
                elif shape == 'triangle':
                    apex = (location_x, location_y)
                    bottom_left = (location_x - extent, location_y + extent)
                    bottom_right = (location_x + extent, location_y + extent)
                    painter.polygon([apex, bottom_left, bottom_right], fill=fill)

            # pil2tensor yields (1, H, W, C); channel 0 becomes the mask.
            frames.append(pil2tensor(canvas)[:, :, :, 0])

        return (torch.cat(frames, dim=0),)
|
||||
|
||||
class NormalizedAmplitudeToFloatList:
    """Converts a normalized amplitude array into a plain list of floats."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
                    "normalized_amp": ("NORMALIZED_AMPLITUDE",),
                },}

    CATEGORY = "KJNodes/audio"
    RETURN_TYPES = ("FLOAT",)
    FUNCTION = "convert"
    DESCRIPTION = """
Works as a bridge to the AudioScheduler -nodes:
https://github.com/a1lazydog/ComfyUI-AudioScheduler
Creates a list of floats from the normalized amplitude.
"""

    def convert(self, normalized_amp):
        """Clamp amplitudes into [0, 1] and emit them as a Python list."""
        clamped = np.clip(normalized_amp, 0.0, 1.0)
        return (clamped.tolist(),)
|
||||
|
||||
class OffsetMaskByNormalizedAmplitude:
    """Shifts and/or rotates each mask frame by its amplitude value."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "normalized_amp": ("NORMALIZED_AMPLITUDE",),
                "mask": ("MASK",),
                "x": ("INT", { "default": 0, "min": -4096, "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
                "y": ("INT", { "default": 0, "min": -4096, "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
                "rotate": ("BOOLEAN", { "default": False }),
                "angle_multiplier": ("FLOAT", { "default": 0.0, "min": -1.0, "max": 1.0, "step": 0.001, "display": "number" }),
            }
        }

    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("mask",)
    FUNCTION = "offset"
    CATEGORY = "KJNodes/audio"
    DESCRIPTION = """
Works as a bridge to the AudioScheduler -nodes:
https://github.com/a1lazydog/ComfyUI-AudioScheduler
Offsets masks based on the normalized amplitude.
"""

    def offset(self, mask, x, y, angle_multiplier, rotate, normalized_amp):
        """Per-frame transform: amplitude scales both the rotation angle
        and the x/y roll distance (amp * 10 pixels per unit of x/y)."""
        result = mask.clone()
        # Amplitudes are expected in [0, 1]; clamp defensively.
        amps = np.clip(normalized_amp, 0.0, 1.0)
        batch_size, height, width = mask.shape

        if rotate:
            for idx in range(batch_size):
                # Full amplitude maps to angle_multiplier * 360 degrees.
                degrees = int(amps[idx] * (360 * angle_multiplier))
                result[idx] = TF.rotate(result[idx].unsqueeze(0), degrees).squeeze(0)

        if x != 0 or y != 0:
            for idx in range(batch_size):
                strength = amps[idx] * 10
                # Clamp the wrap-around roll so it never exceeds the frame.
                dx = min(x * strength, width - 1)
                dy = min(y * strength, height - 1)
                if dx != 0:
                    result[idx] = torch.roll(result[idx], shifts=int(dx), dims=1)
                if dy != 0:
                    result[idx] = torch.roll(result[idx], shifts=int(dy), dims=0)

        return result,
|
||||
|
||||
class ImageTransformByNormalizedAmplitude:
    """Zooms and shifts each image in a batch by its amplitude value."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "normalized_amp": ("NORMALIZED_AMPLITUDE",),
            "zoom_scale": ("FLOAT", { "default": 0.0, "min": -1.0, "max": 1.0, "step": 0.001, "display": "number" }),
            "x_offset": ("INT", { "default": 0, "min": (1 -MAX_RESOLUTION), "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
            "y_offset": ("INT", { "default": 0, "min": (1 -MAX_RESOLUTION), "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
            "cumulative": ("BOOLEAN", { "default": False }),
            "image": ("IMAGE",),
        }}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "amptransform"
    CATEGORY = "KJNodes/audio"
    DESCRIPTION = """
Works as a bridge to the AudioScheduler -nodes:
https://github.com/a1lazydog/ComfyUI-AudioScheduler
Transforms image based on the normalized amplitude.
"""

    def amptransform(self, image, normalized_amp, zoom_scale, cumulative, x_offset, y_offset):
        """Per-frame zoom (center-crop + resize) and wrap-around shift.

        Args:
            image: BxHxWxC image batch.
            normalized_amp: one amplitude per frame, clamped to [0, 1].
            zoom_scale: fraction of the short side cropped away at amp=1.
            cumulative: accumulate amplitude across frames (growing zoom).
            x_offset / y_offset: shift in pixels per 0.1 of amplitude.

        Returns:
            Tuple with the transformed BxHxWxC batch.
        """
        # Ensure normalized_amp is an array and within the range [0, 1]
        normalized_amp = np.clip(normalized_amp, 0.0, 1.0)
        transformed_images = []

        # Running sum used by the cumulative mode
        prev_amp = 0.0

        for i in range(image.shape[0]):
            img = image[i]  # Get the i-th image in the batch
            amp = normalized_amp[i]  # Get the corresponding amplitude value

            # Incrementally increase the cumulative zoom factor
            if cumulative:
                prev_amp += amp
                amp += prev_amp

            # Convert the image tensor from HxWxC to CxHxW format expected by torchvision
            img = img.permute(2, 0, 1)

            # Convert PyTorch tensor to PIL Image for processing
            pil_img = TF.to_pil_image(img)

            # Calculate the crop size based on the amplitude
            width, height = pil_img.size
            crop_size = int(min(width, height) * (1 - amp * zoom_scale))
            crop_size = max(crop_size, 1)

            # Calculate the crop box coordinates (centered crop)
            left = (width - crop_size) // 2
            top = (height - crop_size) // 2

            # Crop and resize back to original size
            cropped_img = TF.crop(pil_img, top, left, crop_size, crop_size)
            resized_img = TF.resize(cropped_img, (height, width))

            # Convert back to tensor and restore HxWxC layout
            tensor_img = TF.to_tensor(resized_img)
            tensor_img = tensor_img.permute(1, 2, 0)

            # Offset the image based on the amplitude.
            # BUGFIX: clamp against the image width/height. The previous code
            # clamped against `img.shape` AFTER the CxHxW permute, so the x
            # shift was capped by the height and the y shift by the channel
            # count (i.e. at most 2 pixels).
            offset_amp = amp * 10  # Calculate the offset magnitude based on the amplitude
            shift_x = min(x_offset * offset_amp, width - 1)
            shift_y = min(y_offset * offset_amp, height - 1)

            # Apply the offset to the image tensor (wrap-around roll)
            if shift_x != 0:
                tensor_img = torch.roll(tensor_img, shifts=int(shift_x), dims=1)
            if shift_y != 0:
                tensor_img = torch.roll(tensor_img, shifts=int(shift_y), dims=0)

            # Add to the list
            transformed_images.append(tensor_img)

        # Stack all transformed images into a batch
        transformed_batch = torch.stack(transformed_images)

        return (transformed_batch,)
|
||||
768
custom_nodes/ComfyUI-KJNodes/nodes/batchcrop_nodes.py
Normal file
768
custom_nodes/ComfyUI-KJNodes/nodes/batchcrop_nodes.py
Normal file
@@ -0,0 +1,768 @@
|
||||
from ..utility.utility import tensor2pil, pil2tensor
|
||||
from PIL import Image, ImageDraw, ImageFilter
|
||||
import numpy as np
|
||||
import torch
|
||||
from torchvision.transforms import Resize, CenterCrop, InterpolationMode
|
||||
import math
|
||||
|
||||
#based on nodes from mtb https://github.com/melMass/comfy_mtb
|
||||
|
||||
def bbox_to_region(bbox, target_size=None):
    """Convert an (x, y, w, h) bbox into an (x0, y0, x1, y1) region.

    When target_size is given, width and height are first clamped so the
    region cannot extend past the target dimensions.
    """
    x, y, w, h = bbox_check(bbox, target_size)
    return (x, y, x + w, y + h)


def bbox_check(bbox, target_size=None):
    """Clamp a bbox's width/height to stay inside target_size, if provided."""
    if not target_size:
        return bbox
    x, y, w, h = bbox
    return (x, y, min(target_size[0] - x, w), min(target_size[1] - y, h))
|
||||
|
||||
class BatchCropFromMask:
    """Crop a batch of images to a shared, mask-derived bounding box.

    The crop dimensions are the maximum mask bbox over the whole batch
    (temporally smoothed, then scaled by crop_size_mult) so every frame is
    cropped to identical dimensions. Each frame's crop window is centered on
    that frame's (smoothed) mask centroid.

    NOTE(review): smoothing state (prev_center) is stored on self, so it
    persists across calls on the same node instance — presumably intentional
    for video workflows; verify before reusing an instance on a new clip.
    Empty masks will raise here (np.min on an empty axis) — assumed callers
    provide non-empty masks.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "original_images": ("IMAGE",),
                "masks": ("MASK",),
                "crop_size_mult": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}),
                "bbox_smooth_alpha": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
            },
        }

    RETURN_TYPES = (
        "IMAGE",
        "IMAGE",
        "BBOX",
        "INT",
        "INT",
    )
    RETURN_NAMES = (
        "original_images",
        "cropped_images",
        "bboxes",
        "width",
        "height",
    )
    FUNCTION = "crop"
    CATEGORY = "KJNodes/masking"

    def smooth_bbox_size(self, prev_bbox_size, curr_bbox_size, alpha):
        """Exponential blend of bbox sizes; alpha=0 keeps the previous size."""
        if alpha == 0:
            return prev_bbox_size
        return round(alpha * curr_bbox_size + (1 - alpha) * prev_bbox_size)

    def smooth_center(self, prev_center, curr_center, alpha=0.5):
        """Exponential blend of (x, y) centers; alpha=0 keeps the previous center."""
        if alpha == 0:
            return prev_center
        return (
            round(alpha * curr_center[0] + (1 - alpha) * prev_center[0]),
            round(alpha * curr_center[1] + (1 - alpha) * prev_center[1])
        )

    def crop(self, masks, original_images, crop_size_mult, bbox_smooth_alpha):
        """Return (original_images, cropped batch, per-frame bboxes, crop width, crop height)."""
        bounding_boxes = []
        cropped_images = []

        self.max_bbox_width = 0
        self.max_bbox_height = 0

        # First, calculate the maximum bounding box size across all masks
        curr_max_bbox_width = 0
        curr_max_bbox_height = 0
        for mask in masks:
            _mask = tensor2pil(mask)[0]
            non_zero_indices = np.nonzero(np.array(_mask))
            # nonzero returns (rows, cols): index 1 holds x, index 0 holds y
            min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1])
            min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0])
            width = max_x - min_x
            height = max_y - min_y
            curr_max_bbox_width = max(curr_max_bbox_width, width)
            curr_max_bbox_height = max(curr_max_bbox_height, height)

        # Smooth the changes in the bounding box size
        self.max_bbox_width = self.smooth_bbox_size(self.max_bbox_width, curr_max_bbox_width, bbox_smooth_alpha)
        self.max_bbox_height = self.smooth_bbox_size(self.max_bbox_height, curr_max_bbox_height, bbox_smooth_alpha)

        # Apply the crop size multiplier
        self.max_bbox_width = round(self.max_bbox_width * crop_size_mult)
        self.max_bbox_height = round(self.max_bbox_height * crop_size_mult)
        # Fixed aspect ratio used when rescaling each per-frame crop below
        bbox_aspect_ratio = self.max_bbox_width / self.max_bbox_height

        # Then, for each mask and corresponding image...
        for i, (mask, img) in enumerate(zip(masks, original_images)):
            _mask = tensor2pil(mask)[0]
            non_zero_indices = np.nonzero(np.array(_mask))
            min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1])
            min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0])

            # Calculate center of bounding box
            center_x = np.mean(non_zero_indices[1])
            center_y = np.mean(non_zero_indices[0])
            curr_center = (round(center_x), round(center_y))

            # If this is the first frame, initialize prev_center with curr_center
            if not hasattr(self, 'prev_center'):
                self.prev_center = curr_center

            # Smooth the changes in the center coordinates from the second frame onwards
            if i > 0:
                center = self.smooth_center(self.prev_center, curr_center, bbox_smooth_alpha)
            else:
                center = curr_center

            # Update prev_center for the next frame
            self.prev_center = center

            # Create bounding box using max_bbox_width and max_bbox_height
            half_box_width = round(self.max_bbox_width / 2)
            half_box_height = round(self.max_bbox_height / 2)
            # Clamp the window to the image bounds (img is HxWxC)
            min_x = max(0, center[0] - half_box_width)
            max_x = min(img.shape[1], center[0] + half_box_width)
            min_y = max(0, center[1] - half_box_height)
            max_y = min(img.shape[0], center[1] + half_box_height)

            # Append bounding box coordinates
            bounding_boxes.append((min_x, min_y, max_x - min_x, max_y - min_y))

            # Crop the image from the bounding box
            cropped_img = img[min_y:max_y, min_x:max_x, :]

            # Calculate the new dimensions while maintaining the aspect ratio
            new_height = min(cropped_img.shape[0], self.max_bbox_height)
            new_width = round(new_height * bbox_aspect_ratio)

            # Resize the image
            resize_transform = Resize((new_height, new_width))
            resized_img = resize_transform(cropped_img.permute(2, 0, 1))

            # Perform the center crop to the desired size
            crop_transform = CenterCrop((self.max_bbox_height, self.max_bbox_width))  # swap the order here if necessary
            cropped_resized_img = crop_transform(resized_img)

            cropped_images.append(cropped_resized_img.permute(1, 2, 0))

        cropped_out = torch.stack(cropped_images, dim=0)

        return (original_images, cropped_out, bounding_boxes, self.max_bbox_width, self.max_bbox_height, )
|
||||
|
||||
class BatchUncrop:
    """Paste cropped images back into their originals at their bbox positions.

    The inverse of BatchCropFromMask: each cropped image is resized to its
    (optionally rescaled) bbox region and alpha-composited over the original,
    with an optional blurred border so the seam blends smoothly.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "original_images": ("IMAGE",),
                "cropped_images": ("IMAGE",),
                "bboxes": ("BBOX",),
                "border_blending": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01}, ),
                "crop_rescale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                "border_top": ("BOOLEAN", {"default": True}),
                "border_bottom": ("BOOLEAN", {"default": True}),
                "border_left": ("BOOLEAN", {"default": True}),
                "border_right": ("BOOLEAN", {"default": True}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "uncrop"

    CATEGORY = "KJNodes/masking"

    def uncrop(self, original_images, cropped_images, bboxes, border_blending, crop_rescale, border_top, border_bottom, border_left, border_right):
        """Composite each cropped image into its original; returns the batch as one tensor."""
        # Paint border_color bands on the selected inside edges of a mask
        # block; after blurring, these become the feathered seam.
        def inset_border(image, border_width, border_color, border_top, border_bottom, border_left, border_right):
            draw = ImageDraw.Draw(image)
            width, height = image.size
            if border_top:
                draw.rectangle((0, 0, width, border_width), fill=border_color)
            if border_bottom:
                draw.rectangle((0, height - border_width, width, height), fill=border_color)
            if border_left:
                draw.rectangle((0, 0, border_width, height), fill=border_color)
            if border_right:
                draw.rectangle((width - border_width, 0, width, height), fill=border_color)
            return image

        if len(original_images) != len(cropped_images):
            raise ValueError(f"The number of original_images ({len(original_images)}) and cropped_images ({len(cropped_images)}) should be the same")

        # Ensure there are enough bboxes, but drop the excess if there are more bboxes than images
        if len(bboxes) > len(original_images):
            print(f"Warning: Dropping excess bounding boxes. Expected {len(original_images)}, but got {len(bboxes)}")
            bboxes = bboxes[:len(original_images)]
        elif len(bboxes) < len(original_images):
            raise ValueError("There should be at least as many bboxes as there are original and cropped images")

        input_images = tensor2pil(original_images)
        crop_imgs = tensor2pil(cropped_images)

        out_images = []
        for i in range(len(input_images)):
            img = input_images[i]
            crop = crop_imgs[i]
            bbox = bboxes[i]

            # uncrop the image based on the bounding box
            bb_x, bb_y, bb_width, bb_height = bbox

            # (x, y, w, h) -> (x0, y0, x1, y1), clamped to the image size
            paste_region = bbox_to_region((bb_x, bb_y, bb_width, bb_height), img.size)

            # scale factors
            scale_x = crop_rescale
            scale_y = crop_rescale

            # scaled paste_region
            paste_region = (round(paste_region[0]*scale_x), round(paste_region[1]*scale_y), round(paste_region[2]*scale_x), round(paste_region[3]*scale_y))

            # rescale the crop image to fit the paste_region
            crop = crop.resize((round(paste_region[2]-paste_region[0]), round(paste_region[3]-paste_region[1])))
            crop_img = crop.convert("RGB")

            # Clamp the blend fraction to [0, 1]
            if border_blending > 1.0:
                border_blending = 1.0
            elif border_blending < 0.0:
                border_blending = 0.0

            # Seam width in pixels, proportional to the crop's larger side
            blend_ratio = (max(crop_img.size) / 2) * float(border_blending)

            blend = img.convert("RGBA")
            mask = Image.new("L", img.size, 0)

            # White block over the paste region, with black inset borders that
            # blur into a soft falloff below
            mask_block = Image.new("L", (paste_region[2]-paste_region[0], paste_region[3]-paste_region[1]), 255)
            mask_block = inset_border(mask_block, round(blend_ratio / 2), (0), border_top, border_bottom, border_left, border_right)

            mask.paste(mask_block, paste_region)
            blend.paste(crop_img, paste_region)

            # Two blur passes soften the seam mask
            mask = mask.filter(ImageFilter.BoxBlur(radius=blend_ratio / 4))
            mask = mask.filter(ImageFilter.GaussianBlur(radius=blend_ratio / 4))

            blend.putalpha(mask)
            img = Image.alpha_composite(img.convert("RGBA"), blend)
            out_images.append(img.convert("RGB"))

        return (pil2tensor(out_images),)
|
||||
|
||||
class BatchCropFromMaskAdvanced:
    """Crop images and masks to a mask-derived square bounding box.

    Like BatchCropFromMask, but also crops the masks themselves, and
    additionally produces a single "combined" crop computed from the union
    (element-wise max) of all masks, so the whole batch can share one fixed
    window. The per-frame crop size is a square, padded to a multiple of 16.

    NOTE(review): smoothing state (prev_center) is stored on self and
    persists across calls on the same node instance.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "original_images": ("IMAGE",),
                "masks": ("MASK",),
                "crop_size_mult": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                "bbox_smooth_alpha": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
            },
        }

    RETURN_TYPES = (
        "IMAGE",
        "IMAGE",
        "MASK",
        "IMAGE",
        "MASK",
        "BBOX",
        "BBOX",
        "INT",
        "INT",
    )
    RETURN_NAMES = (
        "original_images",
        "cropped_images",
        "cropped_masks",
        "combined_crop_image",
        "combined_crop_masks",
        "bboxes",
        "combined_bounding_box",
        "bbox_width",
        "bbox_height",
    )
    FUNCTION = "crop"
    CATEGORY = "KJNodes/masking"

    def smooth_bbox_size(self, prev_bbox_size, curr_bbox_size, alpha):
        """Exponential blend of bbox sizes (no alpha=0 short-circuit here)."""
        return round(alpha * curr_bbox_size + (1 - alpha) * prev_bbox_size)

    def smooth_center(self, prev_center, curr_center, alpha=0.5):
        """Exponential blend of (x, y) centers."""
        return (round(alpha * curr_center[0] + (1 - alpha) * prev_center[0]),
                round(alpha * curr_center[1] + (1 - alpha) * prev_center[1]))

    def crop(self, masks, original_images, crop_size_mult, bbox_smooth_alpha):
        """Return per-frame and combined crops of both images and masks."""
        bounding_boxes = []
        combined_bounding_box = []
        cropped_images = []
        cropped_masks = []
        cropped_masks_out = []
        combined_crop_out = []
        combined_cropped_images = []
        combined_cropped_masks = []

        # Bbox of a single PIL mask: (min_x, max_x, min_y, max_y, square size)
        def calculate_bbox(mask):
            non_zero_indices = np.nonzero(np.array(mask))

            # handle empty masks
            min_x, max_x, min_y, max_y = 0, 0, 0, 0
            if len(non_zero_indices[1]) > 0 and len(non_zero_indices[0]) > 0:
                min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1])
                min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0])

            width = max_x - min_x
            height = max_y - min_y
            bbox_size = max(width, height)
            return min_x, max_x, min_y, max_y, bbox_size

        # Union of all masks -> one shared bounding box for the whole batch
        combined_mask = torch.max(masks, dim=0)[0]
        _mask = tensor2pil(combined_mask)[0]
        new_min_x, new_max_x, new_min_y, new_max_y, combined_bbox_size = calculate_bbox(_mask)
        center_x = (new_min_x + new_max_x) / 2
        center_y = (new_min_y + new_max_y) / 2
        half_box_size = round(combined_bbox_size // 2)
        # Square window centered on the union bbox, clamped to the image
        new_min_x = max(0, round(center_x - half_box_size))
        new_max_x = min(original_images[0].shape[1], round(center_x + half_box_size))
        new_min_y = max(0, round(center_y - half_box_size))
        new_max_y = min(original_images[0].shape[0], round(center_y + half_box_size))

        combined_bounding_box.append((new_min_x, new_min_y, new_max_x - new_min_x, new_max_y - new_min_y))

        self.max_bbox_size = 0

        # First, calculate the maximum bounding box size across all masks
        curr_max_bbox_size = max(calculate_bbox(tensor2pil(mask)[0])[-1] for mask in masks)
        # Smooth the changes in the bounding box size
        self.max_bbox_size = self.smooth_bbox_size(self.max_bbox_size, curr_max_bbox_size, bbox_smooth_alpha)
        # Apply the crop size multiplier
        self.max_bbox_size = round(self.max_bbox_size * crop_size_mult)
        # Make sure max_bbox_size is divisible by 16, if not, round it upwards so it is
        self.max_bbox_size = math.ceil(self.max_bbox_size / 16) * 16

        if self.max_bbox_size > original_images[0].shape[0] or self.max_bbox_size > original_images[0].shape[1]:
            # max_bbox_size can only be as big as our input's width or height, and it has to be even
            self.max_bbox_size = math.floor(min(original_images[0].shape[0], original_images[0].shape[1]) / 2) * 2

        # Then, for each mask and corresponding image...
        for i, (mask, img) in enumerate(zip(masks, original_images)):
            _mask = tensor2pil(mask)[0]
            non_zero_indices = np.nonzero(np.array(_mask))

            # check for empty masks
            if len(non_zero_indices[0]) > 0 and len(non_zero_indices[1]) > 0:
                min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1])
                min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0])

                # Calculate center of bounding box
                center_x = np.mean(non_zero_indices[1])
                center_y = np.mean(non_zero_indices[0])
                curr_center = (round(center_x), round(center_y))

                # If this is the first frame, initialize prev_center with curr_center
                if not hasattr(self, 'prev_center'):
                    self.prev_center = curr_center

                # Smooth the changes in the center coordinates from the second frame onwards
                if i > 0:
                    center = self.smooth_center(self.prev_center, curr_center, bbox_smooth_alpha)
                else:
                    center = curr_center

                # Update prev_center for the next frame
                self.prev_center = center

                # Create bounding box using max_bbox_size
                half_box_size = self.max_bbox_size // 2
                min_x = max(0, center[0] - half_box_size)
                max_x = min(img.shape[1], center[0] + half_box_size)
                min_y = max(0, center[1] - half_box_size)
                max_y = min(img.shape[0], center[1] + half_box_size)

                # Append bounding box coordinates
                bounding_boxes.append((min_x, min_y, max_x - min_x, max_y - min_y))

                # Crop the image from the bounding box
                cropped_img = img[min_y:max_y, min_x:max_x, :]
                cropped_mask = mask[min_y:max_y, min_x:max_x]

                # Resize the cropped image to a fixed size
                new_size = max(cropped_img.shape[0], cropped_img.shape[1])
                resize_transform = Resize(new_size, interpolation=InterpolationMode.NEAREST, max_size=max(img.shape[0], img.shape[1]))
                resized_mask = resize_transform(cropped_mask.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
                resized_img = resize_transform(cropped_img.permute(2, 0, 1))
                # Perform the center crop to the desired size
                # Constrain the crop to the smaller of our bbox or our image so we don't expand past the image dimensions.
                crop_transform = CenterCrop((min(self.max_bbox_size, resized_img.shape[1]), min(self.max_bbox_size, resized_img.shape[2])))

                cropped_resized_img = crop_transform(resized_img)
                cropped_images.append(cropped_resized_img.permute(1, 2, 0))

                cropped_resized_mask = crop_transform(resized_mask)
                cropped_masks.append(cropped_resized_mask)

                # Also crop using the shared (combined) window computed above
                combined_cropped_img = original_images[i][new_min_y:new_max_y, new_min_x:new_max_x, :]
                combined_cropped_images.append(combined_cropped_img)

                combined_cropped_mask = masks[i][new_min_y:new_max_y, new_min_x:new_max_x]
                combined_cropped_masks.append(combined_cropped_mask)
            else:
                # Empty mask: pass the frame through uncropped
                bounding_boxes.append((0, 0, img.shape[1], img.shape[0]))
                cropped_images.append(img)
                cropped_masks.append(mask)
                combined_cropped_images.append(img)
                combined_cropped_masks.append(mask)

        cropped_out = torch.stack(cropped_images, dim=0)
        combined_crop_out = torch.stack(combined_cropped_images, dim=0)
        cropped_masks_out = torch.stack(cropped_masks, dim=0)
        combined_crop_mask_out = torch.stack(combined_cropped_masks, dim=0)

        return (original_images, cropped_out, cropped_masks_out, combined_crop_out, combined_crop_mask_out, bounding_boxes, combined_bounding_box, self.max_bbox_size, self.max_bbox_size)
|
||||
|
||||
class FilterZeroMasksAndCorrespondingImages:
    """Drops all-zero masks from a batch, optionally partitioning images too."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "masks": ("MASK",),
            },
            "optional": {
                "original_images": ("IMAGE",),
            },
        }

    RETURN_TYPES = ("MASK", "IMAGE", "IMAGE", "INDEXES",)
    RETURN_NAMES = ("non_zero_masks_out", "non_zero_mask_images_out", "zero_mask_images_out", "zero_mask_images_out_indexes",)
    FUNCTION = "filter"
    CATEGORY = "KJNodes/masking"
    DESCRIPTION = """
Filter out all the empty (i.e. all zero) mask in masks
Also filter out all the corresponding images in original_images by indexes if provide

original_images (optional): If provided, need have same length as masks.
"""

    def filter(self, masks, original_images=None):
        """Split the batch by mask emptiness.

        Returns:
            non_zero_masks_out: stacked masks that contain any nonzero pixel.
            non_zero_mask_images_out: matching images (None if images absent).
            zero_mask_images_out: images whose masks were empty (or None).
            zero_mask_images_out_indexes: original indexes of those images (or None).
        """
        non_zero_masks = []
        non_zero_mask_images = []
        zero_mask_images = []
        zero_mask_images_indexes = []

        masks_num = len(masks)
        also_process_images = False
        if original_images is not None:
            imgs_num = len(original_images)
            if len(original_images) == masks_num:
                also_process_images = True
            else:
                # Length mismatch: fall back to mask-only filtering
                print(f"[WARNING] ignore input: original_images, due to number of original_images ({imgs_num}) is not equal to number of masks ({masks_num})")

        for i in range(masks_num):
            non_zero_num = np.count_nonzero(np.array(masks[i]))
            if non_zero_num > 0:
                non_zero_masks.append(masks[i])
                if also_process_images:
                    non_zero_mask_images.append(original_images[i])
            else:
                # BUGFIX: only touch original_images when it was actually
                # provided and matched in length; previously this branch
                # indexed original_images unconditionally and crashed with
                # a TypeError when original_images was None.
                if also_process_images:
                    zero_mask_images.append(original_images[i])
                zero_mask_images_indexes.append(i)

        non_zero_masks_out = torch.stack(non_zero_masks, dim=0)
        non_zero_mask_images_out = zero_mask_images_out = zero_mask_images_out_indexes = None

        if also_process_images:
            non_zero_mask_images_out = torch.stack(non_zero_mask_images, dim=0)
            if len(zero_mask_images) > 0:
                zero_mask_images_out = torch.stack(zero_mask_images, dim=0)
                zero_mask_images_out_indexes = zero_mask_images_indexes

        return (non_zero_masks_out, non_zero_mask_images_out, zero_mask_images_out, zero_mask_images_out_indexes)
|
||||
|
||||
class InsertImageBatchByIndexes:
    """Re-insert previously filtered-out images at their original batch positions."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                "images_to_insert": ("IMAGE",),
                "insert_indexes": ("INDEXES",),
            },
        }

    RETURN_TYPES = ("IMAGE", )
    RETURN_NAMES = ("images_after_insert", )
    FUNCTION = "insert"
    CATEGORY = "KJNodes/image"
    DESCRIPTION = """
This node is designed to be use with node FilterZeroMasksAndCorrespondingImages
It inserts the images_to_insert into images according to insert_indexes

Returns:
images_after_insert: updated original images with origonal sequence order
"""

    def insert(self, images, images_to_insert, insert_indexes):
        """Merge the two image batches so each inserted image lands at its recorded index.

        If the insert inputs are missing or their lengths disagree, the
        original batch is returned unchanged (with a warning in the
        mismatch case).
        """
        result = images

        if images_to_insert is not None and insert_indexes is not None:
            images_to_insert_num = len(images_to_insert)
            insert_indexes_num = len(insert_indexes)
            if images_to_insert_num == insert_indexes_num:
                # Walk every output position: take an inserted image when the
                # position is one of insert_indexes, otherwise consume the
                # next original image in order.
                remaining = iter(images)
                total = len(images) + images_to_insert_num
                merged = [
                    images_to_insert[insert_indexes.index(pos)]
                    if pos in insert_indexes else next(remaining)
                    for pos in range(total)
                ]
                result = torch.stack(merged, dim=0)
            else:
                print(f"[WARNING] skip this node, due to number of images_to_insert ({images_to_insert_num}) is not equal to number of insert_indexes ({insert_indexes_num})")

        return (result, )
|
||||
|
||||
class BatchUncropAdvanced:
    """Paste a batch of cropped (processed) images back into their originals.

    Each cropped image is resized to its bounding-box region and
    alpha-composited onto the matching original image, with an optional
    blurred border mask for seamless blending.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "original_images": ("IMAGE",),
                "cropped_images": ("IMAGE",),
                "cropped_masks": ("MASK",),
                "combined_crop_mask": ("MASK",),
                "bboxes": ("BBOX",),
                "border_blending": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01}, ),
                "crop_rescale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                "use_combined_mask": ("BOOLEAN", {"default": False}),
                "use_square_mask": ("BOOLEAN", {"default": True}),
            },
            "optional": {
                "combined_bounding_box": ("BBOX", {"default": None}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "uncrop"
    CATEGORY = "KJNodes/masking"

    def uncrop(self, original_images, cropped_images, cropped_masks, combined_crop_mask, bboxes, border_blending, crop_rescale, use_combined_mask, use_square_mask, combined_bounding_box = None):
        """Composite each cropped image back onto its original at its bbox.

        bboxes are (x, y, width, height) tuples consumed via bbox_to_region.
        Returns a single IMAGE batch tensor of the updated originals.
        """

        def inset_border(image, border_width=20, border_color=(0)):
            # Draw a border_color frame just inside the image edge; used to
            # fade the square paste mask toward its borders before blurring.
            width, height = image.size
            bordered_image = Image.new(image.mode, (width, height), border_color)
            bordered_image.paste(image, (0, 0))
            draw = ImageDraw.Draw(bordered_image)
            draw.rectangle((0, 0, width - 1, height - 1), outline=border_color, width=border_width)
            return bordered_image

        if len(original_images) != len(cropped_images):
            raise ValueError(f"The number of original_images ({len(original_images)}) and cropped_images ({len(cropped_images)}) should be the same")

        # Ensure there are enough bboxes, but drop the excess if there are more bboxes than images
        if len(bboxes) > len(original_images):
            print(f"Warning: Dropping excess bounding boxes. Expected {len(original_images)}, but got {len(bboxes)}")
            bboxes = bboxes[:len(original_images)]
        elif len(bboxes) < len(original_images):
            raise ValueError("There should be at least as many bboxes as there are original and cropped images")

        crop_imgs = tensor2pil(cropped_images)
        input_images = tensor2pil(original_images)
        out_images = []

        for i in range(len(input_images)):
            img = input_images[i]
            crop = crop_imgs[i]
            bbox = bboxes[i]

            if use_combined_mask:
                # NOTE(review): assumes combined_bounding_box was supplied when
                # use_combined_mask is enabled — it is optional and defaults to
                # None, which would raise TypeError here. TODO confirm.
                bb_x, bb_y, bb_width, bb_height = combined_bounding_box[0]
                paste_region = bbox_to_region((bb_x, bb_y, bb_width, bb_height), img.size)
                mask = combined_crop_mask[i]
            else:
                bb_x, bb_y, bb_width, bb_height = bbox
                paste_region = bbox_to_region((bb_x, bb_y, bb_width, bb_height), img.size)
                mask = cropped_masks[i]

            # scale paste_region
            scale_x = scale_y = crop_rescale
            paste_region = (round(paste_region[0]*scale_x), round(paste_region[1]*scale_y), round(paste_region[2]*scale_x), round(paste_region[3]*scale_y))

            # rescale the crop image to fit the paste_region
            crop = crop.resize((round(paste_region[2]-paste_region[0]), round(paste_region[3]-paste_region[1])))
            crop_img = crop.convert("RGB")

            # border blending: clamp the user-supplied strength into [0, 1]
            if border_blending > 1.0:
                border_blending = 1.0
            elif border_blending < 0.0:
                border_blending = 0.0

            # Blur radius scales with both the blend strength and crop size.
            blend_ratio = (max(crop_img.size) / 2) * float(border_blending)
            blend = img.convert("RGBA")

            if use_square_mask:
                # Solid rectangular mask, faded toward its edges by inset_border.
                mask = Image.new("L", img.size, 0)
                mask_block = Image.new("L", (paste_region[2]-paste_region[0], paste_region[3]-paste_region[1]), 255)
                mask_block = inset_border(mask_block, round(blend_ratio / 2), (0))
                mask.paste(mask_block, paste_region)
            else:
                # Use the (resized) per-image or combined crop mask as-is.
                original_mask = tensor2pil(mask)[0]
                original_mask = original_mask.resize((paste_region[2]-paste_region[0], paste_region[3]-paste_region[1]))
                mask = Image.new("L", img.size, 0)
                mask.paste(original_mask, paste_region)

            # Soften the mask edges before compositing.
            mask = mask.filter(ImageFilter.BoxBlur(radius=blend_ratio / 4))
            mask = mask.filter(ImageFilter.GaussianBlur(radius=blend_ratio / 4))

            blend.paste(crop_img, paste_region)
            blend.putalpha(mask)

            img = Image.alpha_composite(img.convert("RGBA"), blend)
            out_images.append(img.convert("RGB"))

        return (pil2tensor(out_images),)
|
||||
|
||||
class SplitBboxes:
    """Split a bounding-box list into two lists at a chosen index."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "bboxes": ("BBOX",),
                "index": ("INT", {"default": 0,"min": 0, "max": 99999999, "step": 1}),
            },
        }

    RETURN_TYPES = ("BBOX","BBOX",)
    RETURN_NAMES = ("bboxes_a","bboxes_b",)
    FUNCTION = "splitbbox"
    CATEGORY = "KJNodes/masking"
    DESCRIPTION = """
Splits the specified bbox list at the given index into two lists.
"""

    def splitbbox(self, bboxes, index):
        """Return (bboxes[:index], bboxes[index:]) as two separate lists."""
        head, tail = bboxes[:index], bboxes[index:]
        return (head, tail,)
|
||||
|
||||
class BboxToInt:
    """Expose one bounding box from a list as individual integer outputs."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "bboxes": ("BBOX",),
                "index": ("INT", {"default": 0,"min": 0, "max": 99999999, "step": 1}),
            },
        }

    RETURN_TYPES = ("INT","INT","INT","INT","INT","INT",)
    RETURN_NAMES = ("x_min","y_min","width","height", "center_x","center_y",)
    FUNCTION = "bboxtoint"
    CATEGORY = "KJNodes/masking"
    DESCRIPTION = """
Returns selected index from bounding box list as integers.
"""
    def bboxtoint(self, bboxes, index):
        """Unpack bbox ``index`` (x, y, w, h) and derive its integer center."""
        x_min, y_min, width, height = bboxes[index]
        center = (int(x_min + width / 2), int(y_min + height / 2))
        return (x_min, y_min, width, height) + center
|
||||
|
||||
class BboxVisualize:
    """Draw each bounding box onto its corresponding batch image."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                "bboxes": ("BBOX",),
                "line_width": ("INT", {"default": 1,"min": 1, "max": 10, "step": 1}),
                "bbox_format": (["xywh", "xyxy"], {"default": "xywh"}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    FUNCTION = "visualizebbox"
    DESCRIPTION = """
Visualizes the specified bbox on the image.
"""

    CATEGORY = "KJNodes/masking"

    def visualizebbox(self, bboxes, images, line_width, bbox_format):
        """Overlay one red rectangle per image, `line_width` pixels thick.

        Invalid bboxes (not a 4-element sequence) are skipped with a print;
        note a skipped bbox also drops its image from the output batch.
        """
        image_list = []
        for image, bbox in zip(images, bboxes):
            # Ensure bbox is a sequence of 4 values
            if isinstance(bbox, (list, tuple, np.ndarray)) and len(bbox) == 4:
                if bbox_format == "xywh":
                    x_min, y_min, width, height = bbox
                elif bbox_format == "xyxy":
                    # Convert corner format to x/y/width/height.
                    x_min, y_min, x_max, y_max = bbox
                    width = x_max - x_min
                    height = y_max - y_min
                else:
                    raise ValueError(f"Unknown bbox_format: {bbox_format}")
            else:
                print("Invalid bbox:", bbox)
                continue

            # Ensure bbox coordinates are integers
            x_min = int(x_min)
            y_min = int(y_min)
            width = int(width)
            height = int(height)

            # Permute the image dimensions (HWC -> CHW for channel-first indexing)
            image = image.permute(2, 0, 1)

            # Clone the image to draw bounding boxes
            img_with_bbox = image.clone()

            # Define the color for the bbox, e.g., red
            color = torch.tensor([1, 0, 0], dtype=torch.float32)

            # Ensure color tensor matches the image channels
            # NOTE(review): this expand produces a (channels, line_width)
            # tensor whose later use as color[:, None] looks shape-incompatible
            # with the row/column slice assignments — verify the non-3-channel
            # path is ever exercised.
            if color.shape[0] != img_with_bbox.shape[0]:
                color = color.unsqueeze(1).expand(-1, line_width)

            # Draw lines for each side of the bbox with the specified line width
            for lw in range(line_width):
                # Top horizontal line
                if y_min + lw < img_with_bbox.shape[1]:
                    img_with_bbox[:, y_min + lw, x_min:x_min + width] = color[:, None]

                # Bottom horizontal line
                if y_min + height - lw < img_with_bbox.shape[1]:
                    img_with_bbox[:, y_min + height - lw, x_min:x_min + width] = color[:, None]

                # Left vertical line
                if x_min + lw < img_with_bbox.shape[2]:
                    img_with_bbox[:, y_min:y_min + height, x_min + lw] = color[:, None]

                # Right vertical line
                if x_min + width - lw < img_with_bbox.shape[2]:
                    img_with_bbox[:, y_min:y_min + height, x_min + width - lw] = color[:, None]

            # Permute the image dimensions back (CHW -> 1HWC with batch dim)
            img_with_bbox = img_with_bbox.permute(1, 2, 0).unsqueeze(0)
            image_list.append(img_with_bbox)

        return (torch.cat(image_list, dim=0),)
|
||||
1645
custom_nodes/ComfyUI-KJNodes/nodes/curve_nodes.py
Normal file
1645
custom_nodes/ComfyUI-KJNodes/nodes/curve_nodes.py
Normal file
File diff suppressed because it is too large
Load Diff
4226
custom_nodes/ComfyUI-KJNodes/nodes/image_nodes.py
Normal file
4226
custom_nodes/ComfyUI-KJNodes/nodes/image_nodes.py
Normal file
File diff suppressed because it is too large
Load Diff
115
custom_nodes/ComfyUI-KJNodes/nodes/intrinsic_lora_nodes.py
Normal file
115
custom_nodes/ComfyUI-KJNodes/nodes/intrinsic_lora_nodes.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import folder_paths
|
||||
import os
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from comfy.utils import ProgressBar, load_torch_file
|
||||
import comfy.sample
|
||||
from nodes import CLIPTextEncode
|
||||
|
||||
script_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
folder_paths.add_model_folder_path("intrinsic_loras", os.path.join(script_directory, "intrinsic_loras"))
|
||||
|
||||
class Intrinsic_lora_sampling:
    """One-step sampler that applies an 'intrinsic' LoRA to extract image
    properties (depth, normals, albedo, shading) from an image or latent."""

    def __init__(self):
        # Cache slot for the last loaded LoRA (path, state_dict).
        self.loaded_lora = None

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model": ("MODEL",),
                "lora_name": (folder_paths.get_filename_list("intrinsic_loras"), ),
                "task": (
                [
                    'depth map',
                    'surface normals',
                    'albedo',
                    'shading',
                ],
                {
                "default": 'depth map'
                }),
                "text": ("STRING", {"multiline": True, "default": ""}),
                "clip": ("CLIP", ),
                "vae": ("VAE", ),
                "per_batch": ("INT", {"default": 16, "min": 1, "max": 4096, "step": 1}),
                },
                "optional": {
                    "image": ("IMAGE",),
                    "optional_latent": ("LATENT",),
                },
        }

    RETURN_TYPES = ("IMAGE", "LATENT",)
    FUNCTION = "onestepsample"
    CATEGORY = "KJNodes"
    DESCRIPTION = """
Sampler to use the intrinsic loras:
https://github.com/duxiaodan/intrinsic-lora
These LoRAs are tiny and thus included
with this node pack.
"""

    def onestepsample(self, model, lora_name, clip, vae, text, task, per_batch, image=None, optional_latent=None):
        """Run a single zero-noise sampling step with the task LoRA applied.

        Either `image` (VAE-encoded in per_batch chunks) or `optional_latent`
        supplies the starting latent. Returns the decoded, task-normalized
        image batch and the raw latent samples.
        """
        pbar = ProgressBar(3)

        if optional_latent is None:
            # Encode the input image batch in chunks to bound VRAM use.
            image_list = []
            for start_idx in range(0, image.shape[0], per_batch):
                sub_pixels = vae.vae_encode_crop_pixels(image[start_idx:start_idx+per_batch])
                image_list.append(vae.encode(sub_pixels[:,:,:,:3]))
            sample = torch.cat(image_list, dim=0)
        else:
            sample = optional_latent["samples"]
        # Zero noise: the "sampling" is a single deterministic model pass.
        noise = torch.zeros(sample.size(), dtype=sample.dtype, layout=sample.layout, device="cpu")
        prompt = task + "," + text
        positive, = CLIPTextEncode.encode(self, clip, prompt)
        negative = positive #negative shouldn't do anything in this scenario

        pbar.update(1)

        #custom model sampling to pass latent through as it is
        # NOTE(review): relies on comfy.model_sampling / comfy.sd being
        # importable via the comfy package already loaded — confirm in runtime env.
        class X0_PassThrough(comfy.model_sampling.EPS):
            def calculate_denoised(self, sigma, model_output, model_input):
                # Treat the raw model output as the denoised result (x0).
                return model_output
            def calculate_input(self, sigma, noise):
                # Feed the latent through unscaled.
                return noise
        sampling_base = comfy.model_sampling.ModelSamplingDiscrete
        sampling_type = X0_PassThrough

        class ModelSamplingAdvanced(sampling_base, sampling_type):
            pass
        model_sampling = ModelSamplingAdvanced(model.model.model_config)

        #load lora
        model_clone = model.clone()
        lora_path = folder_paths.get_full_path("intrinsic_loras", lora_name)
        lora = load_torch_file(lora_path, safe_load=True)
        self.loaded_lora = (lora_path, lora)

        model_clone_with_lora = comfy.sd.load_lora_for_models(model_clone, None, lora, 1.0, 0)[0]

        model_clone_with_lora.add_object_patch("model_sampling", model_sampling)

        # Single euler step, start_step=0 to last_step=1, noise disabled.
        samples = {"samples": comfy.sample.sample(model_clone_with_lora, noise, 1, 1.0, "euler", "simple", positive, negative, sample,
                                                  denoise=1.0, disable_noise=True, start_step=0, last_step=1,
                                                  force_full_denoise=True, noise_mask=None, callback=None, disable_pbar=True, seed=None)}
        pbar.update(1)

        # Decode in chunks, mirroring the encode path.
        decoded = []
        for start_idx in range(0, samples["samples"].shape[0], per_batch):
            decoded.append(vae.decode(samples["samples"][start_idx:start_idx+per_batch]))
        image_out = torch.cat(decoded, dim=0)

        pbar.update(1)

        # Task-specific post-processing of the decoded output.
        if task == 'depth map':
            # Normalize to [0, 1] and collapse to a single repeated channel.
            imax = image_out.max()
            imin = image_out.min()
            image_out = (image_out-imin)/(imax-imin)
            image_out = torch.max(image_out, dim=3, keepdim=True)[0].repeat(1, 1, 1, 3)
        elif task == 'surface normals':
            # Renormalize vectors to unit length, then invert.
            image_out = F.normalize(image_out * 2 - 1, dim=3) / 2 + 0.5
            image_out = 1.0 - image_out
        else:
            image_out = image_out.clamp(-1.,1.)

        return (image_out, samples,)
|
||||
583
custom_nodes/ComfyUI-KJNodes/nodes/lora_nodes.py
Normal file
583
custom_nodes/ComfyUI-KJNodes/nodes/lora_nodes.py
Normal file
@@ -0,0 +1,583 @@
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
import os
|
||||
import logging
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
|
||||
device = comfy.model_management.get_torch_device()
|
||||
|
||||
# Quantile used by extract_lora to clamp the magnitudes of the extracted
# LoRA factors (outlier suppression).
CLAMP_QUANTILE = 0.99
|
||||
def extract_lora(diff, key, rank, algorithm, lora_type, lowrank_iters=7, adaptive_param=1.0, clamp_quantile=True):
    """
    Extracts LoRA weights from a weight difference tensor using SVD.

    Args:
        diff: weight-difference tensor (2D linear weight or 4D conv weight).
        key: state-dict key, used only in log messages.
        rank: target rank, or maximum rank cap for the adaptive modes.
        algorithm: "svd_lowrank" (approximate, faster) or full torch.linalg.svd.
        lora_type: "standard" or one of the adaptive_* rank-selection modes.
        lowrank_iters: subspace iterations for torch.svd_lowrank.
        adaptive_param: threshold, interpreted per adaptive mode.
        clamp_quantile: clamp factor magnitudes at the CLAMP_QUANTILE quantile.

    Returns:
        (U, Vh): lora_up and lora_down factors (reshaped for conv weights).
    """
    conv2d = (len(diff.shape) == 4)
    kernel_size = None if not conv2d else diff.size()[2:4]
    conv2d_3x3 = conv2d and kernel_size != (1, 1)
    out_dim, in_dim = diff.size()[0:2]

    # Flatten conv kernels so the SVD always operates on a 2D matrix.
    if conv2d:
        if conv2d_3x3:
            diff = diff.flatten(start_dim=1)
        else:
            diff = diff.squeeze()

    diff_float = diff.float()
    if algorithm == "svd_lowrank":
        U, S, V = torch.svd_lowrank(diff_float, q=min(rank, in_dim, out_dim), niter=lowrank_iters)
        U = U @ torch.diag(S)
        Vh = V.t()
    else:
        #torch.linalg.svdvals()
        U, S, Vh = torch.linalg.svd(diff_float)
    # Flexible rank selection logic like locon: https://github.com/KohakuBlueleaf/LyCORIS/blob/main/tools/extract_locon.py
    if "adaptive" in lora_type:
        if lora_type == "adaptive_ratio":
            # Keep singular values above a fraction of the largest one.
            min_s = torch.max(S) * adaptive_param
            lora_rank = torch.sum(S > min_s).item()
        elif lora_type == "adaptive_energy":
            # Keep enough values to retain adaptive_param of the total energy.
            energy = torch.cumsum(S**2, dim=0)
            total_energy = torch.sum(S**2)
            threshold = adaptive_param * total_energy # e.g., adaptive_param=0.95 for 95%
            lora_rank = torch.sum(energy < threshold).item() + 1
        elif lora_type == "adaptive_quantile":
            # Keep enough values to retain adaptive_param of the singular-value sum.
            s_cum = torch.cumsum(S, dim=0)
            min_cum_sum = adaptive_param * torch.sum(S)
            lora_rank = torch.sum(s_cum < min_cum_sum).item()
        elif lora_type == "adaptive_fro":
            # Keep enough values to retain adaptive_param of the Frobenius norm.
            S_squared = S.pow(2)
            S_fro_sq = float(torch.sum(S_squared))
            sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq
            lora_rank = int(torch.searchsorted(sum_S_squared, adaptive_param**2)) + 1
            lora_rank = max(1, min(lora_rank, len(S)))
        else:
            pass # Will print after capping

        # Cap adaptive rank by the specified max rank
        lora_rank = min(lora_rank, rank)

        # Calculate and print actual fro percentage retained after capping
        if lora_type == "adaptive_fro":
            S_squared = S.pow(2)
            s_fro = torch.sqrt(torch.sum(S_squared))
            s_red_fro = torch.sqrt(torch.sum(S_squared[:lora_rank]))
            fro_percent = float(s_red_fro / s_fro)
            print(f"{key} Extracted LoRA rank: {lora_rank}, Frobenius retained: {fro_percent:.1%}")
        else:
            print(f"{key} Extracted LoRA rank: {lora_rank}")
    else:
        lora_rank = rank

    # Final rank is bounded by the matrix dimensions.
    lora_rank = max(1, lora_rank)
    lora_rank = min(out_dim, in_dim, lora_rank)

    U = U[:, :lora_rank]
    S = S[:lora_rank]
    # NOTE(review): on the svd_lowrank path U was already multiplied by
    # diag(S) above, so this second multiplication appears to apply the
    # singular values twice for that algorithm — TODO confirm intent.
    U = U @ torch.diag(S)
    Vh = Vh[:lora_rank, :]

    if clamp_quantile:
        # Clamp outliers in both factors to a shared symmetric range.
        dist = torch.cat([U.flatten(), Vh.flatten()])
        if dist.numel() > 100_000:
            # Sample 100,000 elements for quantile estimation
            idx = torch.randperm(dist.numel(), device=dist.device)[:100_000]
            dist_sample = dist[idx]
            hi_val = torch.quantile(dist_sample, CLAMP_QUANTILE)
        else:
            hi_val = torch.quantile(dist, CLAMP_QUANTILE)
        low_val = -hi_val

        U = U.clamp(low_val, hi_val)
        Vh = Vh.clamp(low_val, hi_val)
    if conv2d:
        # Restore conv factor shapes: up is 1x1, down carries the kernel.
        U = U.reshape(out_dim, lora_rank, 1, 1)
        Vh = Vh.reshape(lora_rank, in_dim, kernel_size[0], kernel_size[1])
    return (U, Vh)
|
||||
|
||||
|
||||
def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, algorithm, lowrank_iters, out_dtype, bias_diff=False, adaptive_param=1.0, clamp_quantile=True):
    """Extract LoRA factors for every weight in a model-difference patcher.

    Walks the state dict of `model_diff` (keys filtered by `prefix_model`),
    runs extract_lora on each eligible weight, and stores the results into
    `output_sd` under `prefix_lora`-prefixed keys. With lora_type == "full",
    raw ".diff" tensors are stored instead of factored up/down pairs.

    Returns the populated output_sd.
    """
    # Materialize patched weights, then move the diffusion model off-GPU so
    # the per-layer SVDs (run on `device`) have room.
    comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
    model_diff.model.diffusion_model.cpu()
    sd = model_diff.model_state_dict(filter_prefix=prefix_model)
    del model_diff
    comfy.model_management.soft_empty_cache()
    for k, v in sd.items():
        if isinstance(v, torch.Tensor):
            sd[k] = v.cpu()

    # Get total number of keys to process for progress bar
    total_keys = len([k for k in sd if k.endswith(".weight") or (bias_diff and k.endswith(".bias"))])

    # Create progress bar
    progress_bar = tqdm(total=total_keys, desc=f"Extracting LoRA ({prefix_lora.strip('.')})")
    comfy_pbar = comfy.utils.ProgressBar(total_keys)

    for k in sd:
        if k.endswith(".weight"):
            weight_diff = sd[k]
            if weight_diff.ndim == 5:
                logging.info(f"Skipping 5D tensor for key {k}") #skip patch embed
                progress_bar.update(1)
                comfy_pbar.update(1)
                continue
            if lora_type != "full":
                if weight_diff.ndim < 2:
                    # 1D weights (norm scales etc.) can't be factored; store
                    # them as raw diffs only when bias_diff is requested.
                    if bias_diff:
                        output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().to(out_dtype).cpu()
                    progress_bar.update(1)
                    comfy_pbar.update(1)
                    continue
                try:
                    out = extract_lora(weight_diff.to(device), k, rank, algorithm, lora_type, lowrank_iters=lowrank_iters, adaptive_param=adaptive_param, clamp_quantile=clamp_quantile)
                    output_sd["{}{}.lora_up.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[0].contiguous().to(out_dtype).cpu()
                    output_sd["{}{}.lora_down.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[1].contiguous().to(out_dtype).cpu()
                except Exception as e:
                    # Best-effort: a failed layer is skipped, not fatal.
                    logging.warning(f"Could not generate lora weights for key {k}, error {e}")
            else:
                output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().to(out_dtype).cpu()

            progress_bar.update(1)
            comfy_pbar.update(1)

        elif bias_diff and k.endswith(".bias"):
            output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().to(out_dtype).cpu()
            progress_bar.update(1)
            comfy_pbar.update(1)
    progress_bar.close()
    return output_sd
|
||||
|
||||
class LoraExtractKJ:
    """Node that extracts a LoRA from the difference between two models and
    saves it as a .safetensors file in the output directory."""

    def __init__(self):
        # Base directory for saved LoRA files.
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {
                     "finetuned_model": ("MODEL",),
                     "original_model": ("MODEL",),
                     "filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
                     "rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1, "tooltip": "The rank to use for standard LoRA, or maximum rank limit for adaptive methods."}),
                     "lora_type": (["standard", "full", "adaptive_ratio", "adaptive_quantile", "adaptive_energy", "adaptive_fro"],),
                     "algorithm": (["svd_linalg", "svd_lowrank"], {"default": "svd_linalg", "tooltip": "SVD algorithm to use, svd_lowrank is faster but less accurate."}),
                     "lowrank_iters": ("INT", {"default": 7, "min": 1, "max": 100, "step": 1, "tooltip": "The number of subspace iterations for lowrank SVD algorithm."}),
                     "output_dtype": (["fp16", "bf16", "fp32"], {"default": "fp16"}),
                     "bias_diff": ("BOOLEAN", {"default": True}),
                     "adaptive_param": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "For ratio mode, this is the ratio of the maximum singular value. For quantile mode, this is the quantile of the singular values. For fro mode, this is the Frobenius norm retention ratio."}),
                     "clamp_quantile": ("BOOLEAN", {"default": True}),
                    },

        }
    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True

    CATEGORY = "KJNodes/lora"

    def save(self, finetuned_model, original_model, filename_prefix, rank, lora_type, algorithm, lowrank_iters, output_dtype, bias_diff, adaptive_param, clamp_quantile):
        """Compute finetuned - original weight diffs, factor them into a LoRA
        via calc_lora_model, and save the result to the output folder."""
        if algorithm == "svd_lowrank" and lora_type != "standard":
            raise ValueError("svd_lowrank algorithm is only supported for standard LoRA extraction.")

        # NOTE(review): map contains keys (fp8_e4m3fn, fp16_fast) that the
        # output_dtype dropdown never offers — harmless but dead entries.
        dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp16_fast": torch.float16, "fp32": torch.float32}[output_dtype]
        # Build the difference model: finetuned + (-1.0 * original) patches.
        m = finetuned_model.clone()
        kp = original_model.get_key_patches("diffusion_model.")
        for k in kp:
            m.add_patches({k: kp[k]}, - 1.0, 1.0)
        model_diff = m

        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)

        output_sd = {}
        if model_diff is not None:
            output_sd = calc_lora_model(model_diff, rank, "diffusion_model.", "diffusion_model.", output_sd, lora_type, algorithm, lowrank_iters, dtype, bias_diff=bias_diff, adaptive_param=adaptive_param, clamp_quantile=clamp_quantile)
        if "adaptive" in lora_type:
            rank_str = f"{lora_type}_{adaptive_param:.2f}"
        else:
            rank_str = rank
        # NOTE(review): the literal "(unknown)" looks like it should be the
        # computed `filename` prefix — TODO confirm against upstream source.
        output_checkpoint = f"(unknown)_rank_{rank_str}_{output_dtype}_{counter:05}_.safetensors"
        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)

        comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
        return {}
||||
|
||||
# ComfyUI registration tables for this module.
# NOTE(review): identical mappings are re-assigned again further down this
# file, and LoraReduceRank is registered in neither — confirm whether that
# omission is intentional.
NODE_CLASS_MAPPINGS = {
    "LoraExtractKJ": LoraExtractKJ
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "LoraExtractKJ": "LoraExtractKJ"
}
|
||||
|
||||
class LoraReduceRank:
    """Node that loads an existing LoRA file and re-saves it at a lower
    (possibly dynamically chosen) rank."""

    def __init__(self):
        # Base directory for saved LoRA files.
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {
                     "lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}),
                     "new_rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1, "tooltip": "The new rank to resize the LoRA. Acts as max rank when using dynamic_method."}),
                     "dynamic_method": (["disabled", "sv_ratio", "sv_cumulative", "sv_fro"], {"default": "disabled", "tooltip": "Method to use for dynamically determining new alphas and dims"}),
                     "dynamic_param": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Method to use for dynamically determining new alphas and dims"}),
                     "output_dtype": (["match_original", "fp16", "bf16", "fp32"], {"default": "match_original", "tooltip": "Data type to save the LoRA as."}),
                     "verbose": ("BOOLEAN", {"default": True}),
                    },

        }
    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True
    EXPERIMENTAL = True
    DESCRIPTION = "Resize a LoRA model by reducing it's rank. Based on kohya's sd-scripts: https://github.com/kohya-ss/sd-scripts/blob/main/networks/resize_lora.py"

    CATEGORY = "KJNodes/lora"

    def save(self, lora_name, new_rank, output_dtype, dynamic_method, dynamic_param, verbose):
        """Load the LoRA, resize via resize_lora_model (defined later in this
        file), update its metadata, and save under a derived filename."""

        lora_path = folder_paths.get_full_path("loras", lora_name)
        lora_sd, metadata = comfy.utils.load_torch_file(lora_path, return_metadata=True)

        if output_dtype == "fp16":
            save_dtype = torch.float16
        elif output_dtype == "bf16":
            save_dtype = torch.bfloat16
        elif output_dtype == "fp32":
            save_dtype = torch.float32
        elif output_dtype == "match_original":
            # Use the dtype of the first tensor weight found in the file.
            first_weight_key = next(k for k in lora_sd if k.endswith(".weight") and isinstance(lora_sd[k], torch.Tensor))
            save_dtype = lora_sd[first_weight_key].dtype

        # Strip ".default" key suffixes (PEFT-style naming) before resizing.
        new_lora_sd = {}
        for k, v in lora_sd.items():
            new_lora_sd[k.replace(".default", "")] = v
        del lora_sd
        print("Resizing Lora...")
        output_sd, old_dim, new_alpha, rank_list = resize_lora_model(new_lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose)

        # update metadata
        if metadata is None:
            metadata = {}

        comment = metadata.get("ss_training_comment", "")

        if dynamic_method == "disabled":
            metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {new_rank}; {comment}"
            metadata["ss_network_dim"] = str(new_rank)
            metadata["ss_network_alpha"] = str(new_alpha)
        else:
            metadata["ss_training_comment"] = f"Dynamic resize with {dynamic_method}: {dynamic_param} from {old_dim}; {comment}"
            metadata["ss_network_dim"] = "Dynamic"
            metadata["ss_network_alpha"] = "Dynamic"

        # cast to save_dtype before calculating hashes
        for key in list(output_sd.keys()):
            value = output_sd[key]
            if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype:
                output_sd[key] = value.to(save_dtype)

        output_filename_prefix = "loras/" + lora_name

        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(output_filename_prefix, self.output_dir)
        output_dtype_str = f"_{output_dtype}" if output_dtype != "match_original" else ""
        average_rank = str(int(np.mean(rank_list)))
        rank_str = new_rank if dynamic_method == "disabled" else f"dynamic_{average_rank}"
        output_checkpoint = f"{filename.replace('.safetensors', '')}_resized_from_{old_dim}_to_{rank_str}{output_dtype_str}_{counter:05}_.safetensors"
        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
        print(f"Saving resized LoRA to {output_checkpoint}")

        comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=metadata)
        return {}
|
||||
|
||||
# NOTE(review): duplicate re-assignment of the registration tables defined
# earlier in this file (same content); LoraReduceRank is still not
# registered here — confirm whether one copy should be removed and whether
# LoraReduceRank should be exposed.
NODE_CLASS_MAPPINGS = {
    "LoraExtractKJ": LoraExtractKJ
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "LoraExtractKJ": "LoraExtractKJ"
}
|
||||
|
||||
# Convert LoRA to different rank approximation (should only be used to go to lower rank)
|
||||
# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
|
||||
# Thanks to cloneofsimo
|
||||
|
||||
# This version is based on
|
||||
# https://github.com/kohya-ss/sd-scripts/blob/main/networks/resize_lora.py
|
||||
|
||||
# Floor for singular values kept during resizing (avoids near-zero blow-ups).
MIN_SV = 1e-6

# Recognized (down, up) key-name pairs across LoRA serialization formats.
LORA_DOWN_UP_FORMATS = [
    ("lora_down", "lora_up"), # sd-scripts LoRA
    ("lora_A", "lora_B"), # PEFT LoRA
    ("down", "up"), # ControlLoRA
]
|
||||
|
||||
# Indexing functions
|
||||
def index_sv_cumulative(S, target):
    """Smallest count of leading singular values whose cumulative share of
    sum(S) reaches `target`, clamped to [1, len(S) - 1]."""
    total = float(torch.sum(S))
    shares = torch.cumsum(S, dim=0) / total
    idx = int(torch.searchsorted(shares, target)) + 1
    return max(1, min(idx, len(S) - 1))
|
||||
|
||||
|
||||
def index_sv_fro(S, target):
    """Smallest count of leading singular values retaining a `target`
    fraction of the Frobenius norm, clamped to [1, len(S) - 1]."""
    squared = S.pow(2)
    frac = torch.cumsum(squared, dim=0) / float(torch.sum(squared))
    idx = int(torch.searchsorted(frac, target**2)) + 1
    return max(1, min(idx, len(S) - 1))
|
||||
|
||||
|
||||
def index_sv_ratio(S, target):
    """Count of singular values greater than S[0] / target, clamped to
    [1, len(S) - 1]."""
    cutoff = S[0] / target
    idx = int(torch.sum(S > cutoff).item())
    return max(1, min(idx, len(S) - 1))
|
||||
|
||||
|
||||
# Modified from Kohaku-blueleaf's extract/merge functions
|
||||
# Modified from Kohaku-blueleaf's extract/merge functions
def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
    """SVD-factor a 4D conv weight into LoRA down/up tensors.

    The final rank (and alpha bookkeeping) comes from rank_resize, defined
    later in this file. Results are returned CPU-side in param_dict under
    "lora_down" / "lora_up".
    """
    out_size, in_size, kernel_size, _ = weight.size()
    if weight.dtype != torch.float32:
        weight = weight.to(torch.float32)
    # Flatten spatial dims so the SVD sees a 2D (out, in*k*k) matrix.
    U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))

    param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
    lora_rank = param_dict["new_rank"]

    # Truncate factors to the selected rank; fold S into the up factor.
    U = U[:, :lora_rank]
    S = S[:lora_rank]
    U = U @ torch.diag(S)
    Vh = Vh[:lora_rank, :]

    # Down factor carries the kernel; up factor is a 1x1 conv.
    param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu()
    param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu()
    del U, S, Vh, weight
    return param_dict
|
||||
|
||||
|
||||
def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
    """SVD-decompose a linear weight into low-rank lora_down / lora_up factors.

    Decomposed on *device*, truncated to the rank chosen by rank_resize, and
    the factors (singular values folded into lora_up) are returned on CPU
    inside the rank_resize param dict.
    """
    out_size, in_size = weight.size()

    if weight.dtype != torch.float32:
        # SVD wants float32; upcast lower-precision weights.
        weight = weight.to(torch.float32)
    U, S, Vh = torch.linalg.svd(weight.to(device))

    param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
    rank = param_dict["new_rank"]

    # Fold the retained singular values into the "up" factor.
    up = U[:, :rank] @ torch.diag(S[:rank])
    down = Vh[:rank, :]

    param_dict["lora_down"] = down.reshape(rank, in_size).cpu()
    param_dict["lora_up"] = up.reshape(out_size, rank).cpu()
    del U, S, Vh, weight
    return param_dict
|
||||
|
||||
|
||||
def merge_conv(lora_down, lora_up, device):
    """Reconstruct the full conv weight from its low-rank LoRA factors."""
    rank, in_channels, ksize, ksize_b = lora_down.shape
    out_channels, up_rank, _, _ = lora_up.shape
    assert rank == up_rank and ksize == ksize_b, f"rank {rank} {up_rank} or kernel {ksize} {ksize_b} mismatch"

    down = lora_down.to(device)
    up = lora_up.to(device)

    # Flatten both factors to 2-D, multiply, then restore conv shape.
    product = up.reshape(out_channels, -1) @ down.reshape(rank, -1)
    weight = product.reshape(out_channels, in_channels, ksize, ksize)
    del up, down
    return weight
|
||||
|
||||
|
||||
def merge_linear(lora_down, lora_up, device):
    """Reconstruct the full linear weight (out x in) from its LoRA factors."""
    down_rank = lora_down.shape[0]
    up_rank = lora_up.shape[1]
    assert down_rank == up_rank, f"rank {down_rank} {up_rank} mismatch"

    down = lora_down.to(device)
    up = lora_up.to(device)

    weight = up @ down
    del up, down
    return weight
|
||||
|
||||
|
||||
# Calculate new rank
|
||||
|
||||
|
||||
def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
    """Choose the output rank/alpha and compute retention statistics.

    Given the singular values *S* of a merged LoRA weight, picks the new
    rank either dynamically ("sv_ratio" / "sv_cumulative" / "sv_fro") or as
    the fixed *rank*, caps it at *rank*, and reports how much of the
    singular-value mass the truncation keeps.

    Returns a dict with keys: new_rank, new_alpha, sum_retained,
    fro_retained, max_ratio.
    """
    # Pick the rank according to the requested dynamic method; each dynamic
    # index is bumped by one, matching the original sd-scripts behavior.
    if dynamic_method == "sv_ratio":
        new_rank = index_sv_ratio(S, dynamic_param) + 1
    elif dynamic_method == "sv_cumulative":
        new_rank = index_sv_cumulative(S, dynamic_param) + 1
    elif dynamic_method == "sv_fro":
        new_rank = index_sv_fro(S, dynamic_param) + 1
    else:
        new_rank = rank

    if S[0] <= MIN_SV:
        # Effectively a zero matrix: collapse to rank 1.
        new_rank = 1
    elif new_rank > rank:
        # Never exceed the requested maximum rank.
        new_rank = rank
    # Alpha scales with the final rank in every branch, so compute it once.
    new_alpha = float(scale * new_rank)

    # Retention statistics for the truncation.
    abs_S = torch.abs(S)
    sum_retained = torch.sum(abs_S[:new_rank]) / torch.sum(abs_S)

    squared = S.pow(2)
    fro_retained = float(torch.sqrt(torch.sum(squared[:new_rank])) / torch.sqrt(torch.sum(squared)))

    return {
        "new_rank": new_rank,
        "new_alpha": new_alpha,
        "sum_retained": sum_retained,
        "fro_retained": fro_retained,
        # Spread between the largest and the smallest retained singular value.
        "max_ratio": S[0] / S[new_rank - 1],
    }
|
||||
|
||||
|
||||
def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
    """Resize every LoRA down/up pair in a state dict to a new rank.

    For each matched pair the full weight is reconstructed (merge_conv /
    merge_linear), re-decomposed via SVD at the requested rank — or at a
    rank chosen dynamically by ``dynamic_method``/``dynamic_param`` — and
    written back into a copy of the state dict along with an updated
    ``.alpha`` tensor.

    Args:
        lora_sd: LoRA state dict mapping key -> tensor.
        new_rank: target (maximum) rank for the resized LoRA.
        save_dtype: torch dtype the output tensors are cast to.
        device: device used for the merge/SVD computation.
        dynamic_method: "sv_ratio", "sv_cumulative", "sv_fro", or falsy for
            a fixed rank.
        dynamic_param: threshold parameter for the dynamic method.
        verbose: if True, print per-layer retention statistics.

    Returns:
        Tuple of (resized state dict, max old rank seen, last new alpha,
        list of new ranks — one entry per resized pair).
    """
    max_old_rank = None
    new_alpha = None
    verbose_str = "\n"
    fro_list = []
    rank_list = []

    if dynamic_method:
        print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")

    lora_down_weight = None
    lora_up_weight = None

    # Work on a copy so unmatched keys (e.g. text-encoder extras) pass through.
    o_lora_sd = lora_sd.copy()
    block_down_name = None
    block_up_name = None

    # Progress is tracked against ".weight" keys; note each processed pair
    # covers two such keys (down + up) but ticks the bar once.
    total_keys = len([k for k in lora_sd if k.endswith(".weight")])

    pbar = comfy.utils.ProgressBar(total_keys)
    for key, value in tqdm(lora_sd.items(), leave=True, desc="Resizing LoRA weights"):
        key_parts = key.split(".")
        block_down_name = None
        for _format in LORA_DOWN_UP_FORMATS:
            # Currently we only match lora_down_name in the last two parts of key
            # because ("down", "up") are general words and may appear in block_down_name
            if len(key_parts) >= 2 and _format[0] == key_parts[-2]:
                # Key looks like "<block>.<down_name>.weight".
                block_down_name = ".".join(key_parts[:-2])
                lora_down_name = "." + _format[0]
                lora_up_name = "." + _format[1]
                weight_name = "." + key_parts[-1]
                break
            if len(key_parts) >= 1 and _format[0] == key_parts[-1]:
                # Key looks like "<block>.<down_name>" with no ".weight" suffix.
                block_down_name = ".".join(key_parts[:-1])
                lora_down_name = "." + _format[0]
                lora_up_name = "." + _format[1]
                weight_name = ""
                break

        if block_down_name is None:
            # This parameter is not lora_down
            continue

        # Now weight_name can be ".weight" or ""
        # Find corresponding lora_up and alpha
        block_up_name = block_down_name
        lora_down_weight = value
        lora_up_weight = lora_sd.get(block_up_name + lora_up_name + weight_name, None)
        lora_alpha = lora_sd.get(block_down_name + ".alpha", None)

        weights_loaded = lora_down_weight is not None and lora_up_weight is not None

        if weights_loaded:

            # 4-D down weight means a conv LoRA; 2-D means linear.
            conv2d = len(lora_down_weight.size()) == 4
            old_rank = lora_down_weight.size()[0]
            max_old_rank = max(max_old_rank or 0, old_rank)

            # Skip if the merged weight would be too large (total element
            # count above the 100M cap checked below).
            if conv2d:
                in_rank, in_size, kernel_size, _ = lora_down_weight.shape
                out_size, out_rank, _, _ = lora_up_weight.shape
                merged_size = out_size * in_size * kernel_size * kernel_size
            else:
                in_rank, in_size = lora_down_weight.shape
                out_size, out_rank = lora_up_weight.shape
                merged_size = out_size * in_size

            if merged_size > 100_000_000: # Skip if >100M elements
                logging.warning(f"Skipping {block_down_name}: merged weight too large ({merged_size:,} elements)")
                tqdm.write(f"SKIPPED: {block_down_name} - too large ({merged_size:,} elements)")
                pbar.update(1)
                continue

            # scale = alpha / rank, the effective LoRA strength multiplier;
            # missing alpha is treated as scale 1.
            if lora_alpha is None:
                scale = 1.0
            else:
                scale = lora_alpha / old_rank

            # Merge the pair into a full weight, then re-extract at new rank.
            if conv2d:
                full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
                param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
            else:
                full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
                param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)

            if verbose:
                max_ratio = param_dict["max_ratio"]
                sum_retained = param_dict["sum_retained"]
                fro_retained = param_dict["fro_retained"]
                if not np.isnan(fro_retained):
                    fro_list.append(float(fro_retained))
                log_str = f"{block_down_name:75} | sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}, new dim: {param_dict['new_rank']}"
                tqdm.write(log_str)
                verbose_str += log_str

            if verbose and dynamic_method:
                verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
            else:
                verbose_str += "\n"

            # Write the resized factors and the new alpha back into the copy.
            new_alpha = param_dict["new_alpha"]
            o_lora_sd[block_down_name + lora_down_name + weight_name] = param_dict["lora_down"].to(save_dtype).contiguous()
            o_lora_sd[block_up_name + lora_up_name + weight_name] = param_dict["lora_up"].to(save_dtype).contiguous()
            o_lora_sd[block_down_name + ".alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype)

            # Reset per-pair state before the next iteration.
            block_down_name = None
            block_up_name = None
            lora_down_weight = None
            lora_up_weight = None
            weights_loaded = False
            rank_list.append(param_dict["new_rank"])
            del param_dict
        pbar.update(1)

    if verbose:
        print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
    return o_lora_sd, max_old_rank, new_alpha, rank_list
|
||||
1604
custom_nodes/ComfyUI-KJNodes/nodes/ltxv_nodes.py
Normal file
1604
custom_nodes/ComfyUI-KJNodes/nodes/ltxv_nodes.py
Normal file
File diff suppressed because it is too large
Load Diff
1691
custom_nodes/ComfyUI-KJNodes/nodes/mask_nodes.py
Normal file
1691
custom_nodes/ComfyUI-KJNodes/nodes/mask_nodes.py
Normal file
File diff suppressed because it is too large
Load Diff
2286
custom_nodes/ComfyUI-KJNodes/nodes/model_optimization_nodes.py
Normal file
2286
custom_nodes/ComfyUI-KJNodes/nodes/model_optimization_nodes.py
Normal file
File diff suppressed because it is too large
Load Diff
3311
custom_nodes/ComfyUI-KJNodes/nodes/nodes.py
Normal file
3311
custom_nodes/ComfyUI-KJNodes/nodes/nodes.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user