Add custom nodes, Civitai loras (LFS), and vast.ai setup script
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled

Includes 30 custom nodes committed directly, 7 Civitai-exclusive
loras stored via Git LFS, and a setup script that installs all
dependencies and downloads HuggingFace-hosted models on vast.ai.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-09 00:55:26 +00:00
parent 2b70ab9ad0
commit f09734b0ee
2274 changed files with 748556 additions and 3 deletions

View File

@@ -0,0 +1,251 @@
# to be used with https://github.com/a1lazydog/ComfyUI-AudioScheduler
import torch
from torchvision.transforms import functional as TF
from PIL import Image, ImageDraw
import numpy as np
from ..utility.utility import pil2tensor
from nodes import MAX_RESOLUTION
class NormalizedAmplitudeToMask:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"normalized_amp": ("NORMALIZED_AMPLITUDE",),
"width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
"height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}),
"frame_offset": ("INT", {"default": 0,"min": -255, "max": 255, "step": 1}),
"location_x": ("INT", {"default": 256,"min": 0, "max": 4096, "step": 1}),
"location_y": ("INT", {"default": 256,"min": 0, "max": 4096, "step": 1}),
"size": ("INT", {"default": 128,"min": 8, "max": 4096, "step": 1}),
"shape": (
[
'none',
'circle',
'square',
'triangle',
],
{
"default": 'none'
}),
"color": (
[
'white',
'amplitude',
],
{
"default": 'amplitude'
}),
},}
CATEGORY = "KJNodes/audio"
RETURN_TYPES = ("MASK",)
FUNCTION = "convert"
DESCRIPTION = """
Works as a bridge to the AudioScheduler -nodes:
https://github.com/a1lazydog/ComfyUI-AudioScheduler
Creates masks based on the normalized amplitude.
"""
def convert(self, normalized_amp, width, height, frame_offset, shape, location_x, location_y, size, color):
# Ensure normalized_amp is an array and within the range [0, 1]
normalized_amp = np.clip(normalized_amp, 0.0, 1.0)
# Offset the amplitude values by rolling the array
normalized_amp = np.roll(normalized_amp, frame_offset)
# Initialize an empty list to hold the image tensors
out = []
# Iterate over each amplitude value to create an image
for amp in normalized_amp:
# Scale the amplitude value to cover the full range of grayscale values
if color == 'amplitude':
grayscale_value = int(amp * 255)
elif color == 'white':
grayscale_value = 255
# Convert the grayscale value to an RGB format
gray_color = (grayscale_value, grayscale_value, grayscale_value)
finalsize = size * amp
if shape == 'none':
shapeimage = Image.new("RGB", (width, height), gray_color)
else:
shapeimage = Image.new("RGB", (width, height), "black")
draw = ImageDraw.Draw(shapeimage)
if shape == 'circle' or shape == 'square':
# Define the bounding box for the shape
left_up_point = (location_x - finalsize, location_y - finalsize)
right_down_point = (location_x + finalsize,location_y + finalsize)
two_points = [left_up_point, right_down_point]
if shape == 'circle':
draw.ellipse(two_points, fill=gray_color)
elif shape == 'square':
draw.rectangle(two_points, fill=gray_color)
elif shape == 'triangle':
# Define the points for the triangle
left_up_point = (location_x - finalsize, location_y + finalsize) # bottom left
right_down_point = (location_x + finalsize, location_y + finalsize) # bottom right
top_point = (location_x, location_y) # top point
draw.polygon([top_point, left_up_point, right_down_point], fill=gray_color)
shapeimage = pil2tensor(shapeimage)
mask = shapeimage[:, :, :, 0]
out.append(mask)
return (torch.cat(out, dim=0),)
class NormalizedAmplitudeToFloatList:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"normalized_amp": ("NORMALIZED_AMPLITUDE",),
},}
CATEGORY = "KJNodes/audio"
RETURN_TYPES = ("FLOAT",)
FUNCTION = "convert"
DESCRIPTION = """
Works as a bridge to the AudioScheduler -nodes:
https://github.com/a1lazydog/ComfyUI-AudioScheduler
Creates a list of floats from the normalized amplitude.
"""
def convert(self, normalized_amp):
# Ensure normalized_amp is an array and within the range [0, 1]
normalized_amp = np.clip(normalized_amp, 0.0, 1.0)
return (normalized_amp.tolist(),)
class OffsetMaskByNormalizedAmplitude:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"normalized_amp": ("NORMALIZED_AMPLITUDE",),
"mask": ("MASK",),
"x": ("INT", { "default": 0, "min": -4096, "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
"y": ("INT", { "default": 0, "min": -4096, "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
"rotate": ("BOOLEAN", { "default": False }),
"angle_multiplier": ("FLOAT", { "default": 0.0, "min": -1.0, "max": 1.0, "step": 0.001, "display": "number" }),
}
}
RETURN_TYPES = ("MASK",)
RETURN_NAMES = ("mask",)
FUNCTION = "offset"
CATEGORY = "KJNodes/audio"
DESCRIPTION = """
Works as a bridge to the AudioScheduler -nodes:
https://github.com/a1lazydog/ComfyUI-AudioScheduler
Offsets masks based on the normalized amplitude.
"""
def offset(self, mask, x, y, angle_multiplier, rotate, normalized_amp):
# Ensure normalized_amp is an array and within the range [0, 1]
offsetmask = mask.clone()
normalized_amp = np.clip(normalized_amp, 0.0, 1.0)
batch_size, height, width = mask.shape
if rotate:
for i in range(batch_size):
rotation_amp = int(normalized_amp[i] * (360 * angle_multiplier))
rotation_angle = rotation_amp
offsetmask[i] = TF.rotate(offsetmask[i].unsqueeze(0), rotation_angle).squeeze(0)
if x != 0 or y != 0:
for i in range(batch_size):
offset_amp = normalized_amp[i] * 10
shift_x = min(x*offset_amp, width-1)
shift_y = min(y*offset_amp, height-1)
if shift_x != 0:
offsetmask[i] = torch.roll(offsetmask[i], shifts=int(shift_x), dims=1)
if shift_y != 0:
offsetmask[i] = torch.roll(offsetmask[i], shifts=int(shift_y), dims=0)
return offsetmask,
class ImageTransformByNormalizedAmplitude:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"normalized_amp": ("NORMALIZED_AMPLITUDE",),
"zoom_scale": ("FLOAT", { "default": 0.0, "min": -1.0, "max": 1.0, "step": 0.001, "display": "number" }),
"x_offset": ("INT", { "default": 0, "min": (1 -MAX_RESOLUTION), "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
"y_offset": ("INT", { "default": 0, "min": (1 -MAX_RESOLUTION), "max": MAX_RESOLUTION, "step": 1, "display": "number" }),
"cumulative": ("BOOLEAN", { "default": False }),
"image": ("IMAGE",),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "amptransform"
CATEGORY = "KJNodes/audio"
DESCRIPTION = """
Works as a bridge to the AudioScheduler -nodes:
https://github.com/a1lazydog/ComfyUI-AudioScheduler
Transforms image based on the normalized amplitude.
"""
def amptransform(self, image, normalized_amp, zoom_scale, cumulative, x_offset, y_offset):
# Ensure normalized_amp is an array and within the range [0, 1]
normalized_amp = np.clip(normalized_amp, 0.0, 1.0)
transformed_images = []
# Initialize the cumulative zoom factor
prev_amp = 0.0
for i in range(image.shape[0]):
img = image[i] # Get the i-th image in the batch
amp = normalized_amp[i] # Get the corresponding amplitude value
# Incrementally increase the cumulative zoom factor
if cumulative:
prev_amp += amp
amp += prev_amp
# Convert the image tensor from BxHxWxC to CxHxW format expected by torchvision
img = img.permute(2, 0, 1)
# Convert PyTorch tensor to PIL Image for processing
pil_img = TF.to_pil_image(img)
# Calculate the crop size based on the amplitude
width, height = pil_img.size
crop_size = int(min(width, height) * (1 - amp * zoom_scale))
crop_size = max(crop_size, 1)
# Calculate the crop box coordinates (centered crop)
left = (width - crop_size) // 2
top = (height - crop_size) // 2
right = (width + crop_size) // 2
bottom = (height + crop_size) // 2
# Crop and resize back to original size
cropped_img = TF.crop(pil_img, top, left, crop_size, crop_size)
resized_img = TF.resize(cropped_img, (height, width))
# Convert back to tensor in CxHxW format
tensor_img = TF.to_tensor(resized_img)
# Convert the tensor back to BxHxWxC format
tensor_img = tensor_img.permute(1, 2, 0)
# Offset the image based on the amplitude
offset_amp = amp * 10 # Calculate the offset magnitude based on the amplitude
shift_x = min(x_offset * offset_amp, img.shape[1] - 1) # Calculate the shift in x direction
shift_y = min(y_offset * offset_amp, img.shape[0] - 1) # Calculate the shift in y direction
# Apply the offset to the image tensor
if shift_x != 0:
tensor_img = torch.roll(tensor_img, shifts=int(shift_x), dims=1)
if shift_y != 0:
tensor_img = torch.roll(tensor_img, shifts=int(shift_y), dims=0)
# Add to the list
transformed_images.append(tensor_img)
# Stack all transformed images into a batch
transformed_batch = torch.stack(transformed_images)
return (transformed_batch,)

View File

@@ -0,0 +1,768 @@
from ..utility.utility import tensor2pil, pil2tensor
from PIL import Image, ImageDraw, ImageFilter
import numpy as np
import torch
from torchvision.transforms import Resize, CenterCrop, InterpolationMode
import math
#based on nodes from mtb https://github.com/melMass/comfy_mtb
def bbox_to_region(bbox, target_size=None):
bbox = bbox_check(bbox, target_size)
return (bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3])
def bbox_check(bbox, target_size=None):
if not target_size:
return bbox
new_bbox = (
bbox[0],
bbox[1],
min(target_size[0] - bbox[0], bbox[2]),
min(target_size[1] - bbox[1], bbox[3]),
)
return new_bbox
class BatchCropFromMask:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"original_images": ("IMAGE",),
"masks": ("MASK",),
"crop_size_mult": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}),
"bbox_smooth_alpha": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
},
}
RETURN_TYPES = (
"IMAGE",
"IMAGE",
"BBOX",
"INT",
"INT",
)
RETURN_NAMES = (
"original_images",
"cropped_images",
"bboxes",
"width",
"height",
)
FUNCTION = "crop"
CATEGORY = "KJNodes/masking"
def smooth_bbox_size(self, prev_bbox_size, curr_bbox_size, alpha):
if alpha == 0:
return prev_bbox_size
return round(alpha * curr_bbox_size + (1 - alpha) * prev_bbox_size)
def smooth_center(self, prev_center, curr_center, alpha=0.5):
if alpha == 0:
return prev_center
return (
round(alpha * curr_center[0] + (1 - alpha) * prev_center[0]),
round(alpha * curr_center[1] + (1 - alpha) * prev_center[1])
)
def crop(self, masks, original_images, crop_size_mult, bbox_smooth_alpha):
bounding_boxes = []
cropped_images = []
self.max_bbox_width = 0
self.max_bbox_height = 0
# First, calculate the maximum bounding box size across all masks
curr_max_bbox_width = 0
curr_max_bbox_height = 0
for mask in masks:
_mask = tensor2pil(mask)[0]
non_zero_indices = np.nonzero(np.array(_mask))
min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1])
min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0])
width = max_x - min_x
height = max_y - min_y
curr_max_bbox_width = max(curr_max_bbox_width, width)
curr_max_bbox_height = max(curr_max_bbox_height, height)
# Smooth the changes in the bounding box size
self.max_bbox_width = self.smooth_bbox_size(self.max_bbox_width, curr_max_bbox_width, bbox_smooth_alpha)
self.max_bbox_height = self.smooth_bbox_size(self.max_bbox_height, curr_max_bbox_height, bbox_smooth_alpha)
# Apply the crop size multiplier
self.max_bbox_width = round(self.max_bbox_width * crop_size_mult)
self.max_bbox_height = round(self.max_bbox_height * crop_size_mult)
bbox_aspect_ratio = self.max_bbox_width / self.max_bbox_height
# Then, for each mask and corresponding image...
for i, (mask, img) in enumerate(zip(masks, original_images)):
_mask = tensor2pil(mask)[0]
non_zero_indices = np.nonzero(np.array(_mask))
min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1])
min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0])
# Calculate center of bounding box
center_x = np.mean(non_zero_indices[1])
center_y = np.mean(non_zero_indices[0])
curr_center = (round(center_x), round(center_y))
# If this is the first frame, initialize prev_center with curr_center
if not hasattr(self, 'prev_center'):
self.prev_center = curr_center
# Smooth the changes in the center coordinates from the second frame onwards
if i > 0:
center = self.smooth_center(self.prev_center, curr_center, bbox_smooth_alpha)
else:
center = curr_center
# Update prev_center for the next frame
self.prev_center = center
# Create bounding box using max_bbox_width and max_bbox_height
half_box_width = round(self.max_bbox_width / 2)
half_box_height = round(self.max_bbox_height / 2)
min_x = max(0, center[0] - half_box_width)
max_x = min(img.shape[1], center[0] + half_box_width)
min_y = max(0, center[1] - half_box_height)
max_y = min(img.shape[0], center[1] + half_box_height)
# Append bounding box coordinates
bounding_boxes.append((min_x, min_y, max_x - min_x, max_y - min_y))
# Crop the image from the bounding box
cropped_img = img[min_y:max_y, min_x:max_x, :]
# Calculate the new dimensions while maintaining the aspect ratio
new_height = min(cropped_img.shape[0], self.max_bbox_height)
new_width = round(new_height * bbox_aspect_ratio)
# Resize the image
resize_transform = Resize((new_height, new_width))
resized_img = resize_transform(cropped_img.permute(2, 0, 1))
# Perform the center crop to the desired size
crop_transform = CenterCrop((self.max_bbox_height, self.max_bbox_width)) # swap the order here if necessary
cropped_resized_img = crop_transform(resized_img)
cropped_images.append(cropped_resized_img.permute(1, 2, 0))
cropped_out = torch.stack(cropped_images, dim=0)
return (original_images, cropped_out, bounding_boxes, self.max_bbox_width, self.max_bbox_height, )
class BatchUncrop:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"original_images": ("IMAGE",),
"cropped_images": ("IMAGE",),
"bboxes": ("BBOX",),
"border_blending": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01}, ),
"crop_rescale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"border_top": ("BOOLEAN", {"default": True}),
"border_bottom": ("BOOLEAN", {"default": True}),
"border_left": ("BOOLEAN", {"default": True}),
"border_right": ("BOOLEAN", {"default": True}),
}
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "uncrop"
CATEGORY = "KJNodes/masking"
def uncrop(self, original_images, cropped_images, bboxes, border_blending, crop_rescale, border_top, border_bottom, border_left, border_right):
def inset_border(image, border_width, border_color, border_top, border_bottom, border_left, border_right):
draw = ImageDraw.Draw(image)
width, height = image.size
if border_top:
draw.rectangle((0, 0, width, border_width), fill=border_color)
if border_bottom:
draw.rectangle((0, height - border_width, width, height), fill=border_color)
if border_left:
draw.rectangle((0, 0, border_width, height), fill=border_color)
if border_right:
draw.rectangle((width - border_width, 0, width, height), fill=border_color)
return image
if len(original_images) != len(cropped_images):
raise ValueError(f"The number of original_images ({len(original_images)}) and cropped_images ({len(cropped_images)}) should be the same")
# Ensure there are enough bboxes, but drop the excess if there are more bboxes than images
if len(bboxes) > len(original_images):
print(f"Warning: Dropping excess bounding boxes. Expected {len(original_images)}, but got {len(bboxes)}")
bboxes = bboxes[:len(original_images)]
elif len(bboxes) < len(original_images):
raise ValueError("There should be at least as many bboxes as there are original and cropped images")
input_images = tensor2pil(original_images)
crop_imgs = tensor2pil(cropped_images)
out_images = []
for i in range(len(input_images)):
img = input_images[i]
crop = crop_imgs[i]
bbox = bboxes[i]
# uncrop the image based on the bounding box
bb_x, bb_y, bb_width, bb_height = bbox
paste_region = bbox_to_region((bb_x, bb_y, bb_width, bb_height), img.size)
# scale factors
scale_x = crop_rescale
scale_y = crop_rescale
# scaled paste_region
paste_region = (round(paste_region[0]*scale_x), round(paste_region[1]*scale_y), round(paste_region[2]*scale_x), round(paste_region[3]*scale_y))
# rescale the crop image to fit the paste_region
crop = crop.resize((round(paste_region[2]-paste_region[0]), round(paste_region[3]-paste_region[1])))
crop_img = crop.convert("RGB")
if border_blending > 1.0:
border_blending = 1.0
elif border_blending < 0.0:
border_blending = 0.0
blend_ratio = (max(crop_img.size) / 2) * float(border_blending)
blend = img.convert("RGBA")
mask = Image.new("L", img.size, 0)
mask_block = Image.new("L", (paste_region[2]-paste_region[0], paste_region[3]-paste_region[1]), 255)
mask_block = inset_border(mask_block, round(blend_ratio / 2), (0), border_top, border_bottom, border_left, border_right)
mask.paste(mask_block, paste_region)
blend.paste(crop_img, paste_region)
mask = mask.filter(ImageFilter.BoxBlur(radius=blend_ratio / 4))
mask = mask.filter(ImageFilter.GaussianBlur(radius=blend_ratio / 4))
blend.putalpha(mask)
img = Image.alpha_composite(img.convert("RGBA"), blend)
out_images.append(img.convert("RGB"))
return (pil2tensor(out_images),)
class BatchCropFromMaskAdvanced:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"original_images": ("IMAGE",),
"masks": ("MASK",),
"crop_size_mult": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"bbox_smooth_alpha": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
},
}
RETURN_TYPES = (
"IMAGE",
"IMAGE",
"MASK",
"IMAGE",
"MASK",
"BBOX",
"BBOX",
"INT",
"INT",
)
RETURN_NAMES = (
"original_images",
"cropped_images",
"cropped_masks",
"combined_crop_image",
"combined_crop_masks",
"bboxes",
"combined_bounding_box",
"bbox_width",
"bbox_height",
)
FUNCTION = "crop"
CATEGORY = "KJNodes/masking"
def smooth_bbox_size(self, prev_bbox_size, curr_bbox_size, alpha):
return round(alpha * curr_bbox_size + (1 - alpha) * prev_bbox_size)
def smooth_center(self, prev_center, curr_center, alpha=0.5):
return (round(alpha * curr_center[0] + (1 - alpha) * prev_center[0]),
round(alpha * curr_center[1] + (1 - alpha) * prev_center[1]))
def crop(self, masks, original_images, crop_size_mult, bbox_smooth_alpha):
bounding_boxes = []
combined_bounding_box = []
cropped_images = []
cropped_masks = []
cropped_masks_out = []
combined_crop_out = []
combined_cropped_images = []
combined_cropped_masks = []
def calculate_bbox(mask):
non_zero_indices = np.nonzero(np.array(mask))
# handle empty masks
min_x, max_x, min_y, max_y = 0, 0, 0, 0
if len(non_zero_indices[1]) > 0 and len(non_zero_indices[0]) > 0:
min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1])
min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0])
width = max_x - min_x
height = max_y - min_y
bbox_size = max(width, height)
return min_x, max_x, min_y, max_y, bbox_size
combined_mask = torch.max(masks, dim=0)[0]
_mask = tensor2pil(combined_mask)[0]
new_min_x, new_max_x, new_min_y, new_max_y, combined_bbox_size = calculate_bbox(_mask)
center_x = (new_min_x + new_max_x) / 2
center_y = (new_min_y + new_max_y) / 2
half_box_size = round(combined_bbox_size // 2)
new_min_x = max(0, round(center_x - half_box_size))
new_max_x = min(original_images[0].shape[1], round(center_x + half_box_size))
new_min_y = max(0, round(center_y - half_box_size))
new_max_y = min(original_images[0].shape[0], round(center_y + half_box_size))
combined_bounding_box.append((new_min_x, new_min_y, new_max_x - new_min_x, new_max_y - new_min_y))
self.max_bbox_size = 0
# First, calculate the maximum bounding box size across all masks
curr_max_bbox_size = max(calculate_bbox(tensor2pil(mask)[0])[-1] for mask in masks)
# Smooth the changes in the bounding box size
self.max_bbox_size = self.smooth_bbox_size(self.max_bbox_size, curr_max_bbox_size, bbox_smooth_alpha)
# Apply the crop size multiplier
self.max_bbox_size = round(self.max_bbox_size * crop_size_mult)
# Make sure max_bbox_size is divisible by 16, if not, round it upwards so it is
self.max_bbox_size = math.ceil(self.max_bbox_size / 16) * 16
if self.max_bbox_size > original_images[0].shape[0] or self.max_bbox_size > original_images[0].shape[1]:
# max_bbox_size can only be as big as our input's width or height, and it has to be even
self.max_bbox_size = math.floor(min(original_images[0].shape[0], original_images[0].shape[1]) / 2) * 2
# Then, for each mask and corresponding image...
for i, (mask, img) in enumerate(zip(masks, original_images)):
_mask = tensor2pil(mask)[0]
non_zero_indices = np.nonzero(np.array(_mask))
# check for empty masks
if len(non_zero_indices[0]) > 0 and len(non_zero_indices[1]) > 0:
min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1])
min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0])
# Calculate center of bounding box
center_x = np.mean(non_zero_indices[1])
center_y = np.mean(non_zero_indices[0])
curr_center = (round(center_x), round(center_y))
# If this is the first frame, initialize prev_center with curr_center
if not hasattr(self, 'prev_center'):
self.prev_center = curr_center
# Smooth the changes in the center coordinates from the second frame onwards
if i > 0:
center = self.smooth_center(self.prev_center, curr_center, bbox_smooth_alpha)
else:
center = curr_center
# Update prev_center for the next frame
self.prev_center = center
# Create bounding box using max_bbox_size
half_box_size = self.max_bbox_size // 2
min_x = max(0, center[0] - half_box_size)
max_x = min(img.shape[1], center[0] + half_box_size)
min_y = max(0, center[1] - half_box_size)
max_y = min(img.shape[0], center[1] + half_box_size)
# Append bounding box coordinates
bounding_boxes.append((min_x, min_y, max_x - min_x, max_y - min_y))
# Crop the image from the bounding box
cropped_img = img[min_y:max_y, min_x:max_x, :]
cropped_mask = mask[min_y:max_y, min_x:max_x]
# Resize the cropped image to a fixed size
new_size = max(cropped_img.shape[0], cropped_img.shape[1])
resize_transform = Resize(new_size, interpolation=InterpolationMode.NEAREST, max_size=max(img.shape[0], img.shape[1]))
resized_mask = resize_transform(cropped_mask.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
resized_img = resize_transform(cropped_img.permute(2, 0, 1))
# Perform the center crop to the desired size
# Constrain the crop to the smaller of our bbox or our image so we don't expand past the image dimensions.
crop_transform = CenterCrop((min(self.max_bbox_size, resized_img.shape[1]), min(self.max_bbox_size, resized_img.shape[2])))
cropped_resized_img = crop_transform(resized_img)
cropped_images.append(cropped_resized_img.permute(1, 2, 0))
cropped_resized_mask = crop_transform(resized_mask)
cropped_masks.append(cropped_resized_mask)
combined_cropped_img = original_images[i][new_min_y:new_max_y, new_min_x:new_max_x, :]
combined_cropped_images.append(combined_cropped_img)
combined_cropped_mask = masks[i][new_min_y:new_max_y, new_min_x:new_max_x]
combined_cropped_masks.append(combined_cropped_mask)
else:
bounding_boxes.append((0, 0, img.shape[1], img.shape[0]))
cropped_images.append(img)
cropped_masks.append(mask)
combined_cropped_images.append(img)
combined_cropped_masks.append(mask)
cropped_out = torch.stack(cropped_images, dim=0)
combined_crop_out = torch.stack(combined_cropped_images, dim=0)
cropped_masks_out = torch.stack(cropped_masks, dim=0)
combined_crop_mask_out = torch.stack(combined_cropped_masks, dim=0)
return (original_images, cropped_out, cropped_masks_out, combined_crop_out, combined_crop_mask_out, bounding_boxes, combined_bounding_box, self.max_bbox_size, self.max_bbox_size)
class FilterZeroMasksAndCorrespondingImages:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"masks": ("MASK",),
},
"optional": {
"original_images": ("IMAGE",),
},
}
RETURN_TYPES = ("MASK", "IMAGE", "IMAGE", "INDEXES",)
RETURN_NAMES = ("non_zero_masks_out", "non_zero_mask_images_out", "zero_mask_images_out", "zero_mask_images_out_indexes",)
FUNCTION = "filter"
CATEGORY = "KJNodes/masking"
DESCRIPTION = """
Filter out all the empty (i.e. all zero) mask in masks
Also filter out all the corresponding images in original_images by indexes if provide
original_images (optional): If provided, need have same length as masks.
"""
def filter(self, masks, original_images=None):
non_zero_masks = []
non_zero_mask_images = []
zero_mask_images = []
zero_mask_images_indexes = []
masks_num = len(masks)
also_process_images = False
if original_images is not None:
imgs_num = len(original_images)
if len(original_images) == masks_num:
also_process_images = True
else:
print(f"[WARNING] ignore input: original_images, due to number of original_images ({imgs_num}) is not equal to number of masks ({masks_num})")
for i in range(masks_num):
non_zero_num = np.count_nonzero(np.array(masks[i]))
if non_zero_num > 0:
non_zero_masks.append(masks[i])
if also_process_images:
non_zero_mask_images.append(original_images[i])
else:
zero_mask_images.append(original_images[i])
zero_mask_images_indexes.append(i)
non_zero_masks_out = torch.stack(non_zero_masks, dim=0)
non_zero_mask_images_out = zero_mask_images_out = zero_mask_images_out_indexes = None
if also_process_images:
non_zero_mask_images_out = torch.stack(non_zero_mask_images, dim=0)
if len(zero_mask_images) > 0:
zero_mask_images_out = torch.stack(zero_mask_images, dim=0)
zero_mask_images_out_indexes = zero_mask_images_indexes
return (non_zero_masks_out, non_zero_mask_images_out, zero_mask_images_out, zero_mask_images_out_indexes)
class InsertImageBatchByIndexes:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE",),
"images_to_insert": ("IMAGE",),
"insert_indexes": ("INDEXES",),
},
}
RETURN_TYPES = ("IMAGE", )
RETURN_NAMES = ("images_after_insert", )
FUNCTION = "insert"
CATEGORY = "KJNodes/image"
DESCRIPTION = """
This node is designed to be use with node FilterZeroMasksAndCorrespondingImages
It inserts the images_to_insert into images according to insert_indexes
Returns:
images_after_insert: updated original images with origonal sequence order
"""
def insert(self, images, images_to_insert, insert_indexes):
images_after_insert = images
if images_to_insert is not None and insert_indexes is not None:
images_to_insert_num = len(images_to_insert)
insert_indexes_num = len(insert_indexes)
if images_to_insert_num == insert_indexes_num:
images_after_insert = []
i_images = 0
for i in range(len(images) + images_to_insert_num):
if i in insert_indexes:
images_after_insert.append(images_to_insert[insert_indexes.index(i)])
else:
images_after_insert.append(images[i_images])
i_images += 1
images_after_insert = torch.stack(images_after_insert, dim=0)
else:
print(f"[WARNING] skip this node, due to number of images_to_insert ({images_to_insert_num}) is not equal to number of insert_indexes ({insert_indexes_num})")
return (images_after_insert, )
class BatchUncropAdvanced:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"original_images": ("IMAGE",),
"cropped_images": ("IMAGE",),
"cropped_masks": ("MASK",),
"combined_crop_mask": ("MASK",),
"bboxes": ("BBOX",),
"border_blending": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01}, ),
"crop_rescale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"use_combined_mask": ("BOOLEAN", {"default": False}),
"use_square_mask": ("BOOLEAN", {"default": True}),
},
"optional": {
"combined_bounding_box": ("BBOX", {"default": None}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "uncrop"
CATEGORY = "KJNodes/masking"
def uncrop(self, original_images, cropped_images, cropped_masks, combined_crop_mask, bboxes, border_blending, crop_rescale, use_combined_mask, use_square_mask, combined_bounding_box = None):
def inset_border(image, border_width=20, border_color=(0)):
width, height = image.size
bordered_image = Image.new(image.mode, (width, height), border_color)
bordered_image.paste(image, (0, 0))
draw = ImageDraw.Draw(bordered_image)
draw.rectangle((0, 0, width - 1, height - 1), outline=border_color, width=border_width)
return bordered_image
if len(original_images) != len(cropped_images):
raise ValueError(f"The number of original_images ({len(original_images)}) and cropped_images ({len(cropped_images)}) should be the same")
# Ensure there are enough bboxes, but drop the excess if there are more bboxes than images
if len(bboxes) > len(original_images):
print(f"Warning: Dropping excess bounding boxes. Expected {len(original_images)}, but got {len(bboxes)}")
bboxes = bboxes[:len(original_images)]
elif len(bboxes) < len(original_images):
raise ValueError("There should be at least as many bboxes as there are original and cropped images")
crop_imgs = tensor2pil(cropped_images)
input_images = tensor2pil(original_images)
out_images = []
for i in range(len(input_images)):
img = input_images[i]
crop = crop_imgs[i]
bbox = bboxes[i]
if use_combined_mask:
bb_x, bb_y, bb_width, bb_height = combined_bounding_box[0]
paste_region = bbox_to_region((bb_x, bb_y, bb_width, bb_height), img.size)
mask = combined_crop_mask[i]
else:
bb_x, bb_y, bb_width, bb_height = bbox
paste_region = bbox_to_region((bb_x, bb_y, bb_width, bb_height), img.size)
mask = cropped_masks[i]
# scale paste_region
scale_x = scale_y = crop_rescale
paste_region = (round(paste_region[0]*scale_x), round(paste_region[1]*scale_y), round(paste_region[2]*scale_x), round(paste_region[3]*scale_y))
# rescale the crop image to fit the paste_region
crop = crop.resize((round(paste_region[2]-paste_region[0]), round(paste_region[3]-paste_region[1])))
crop_img = crop.convert("RGB")
#border blending
if border_blending > 1.0:
border_blending = 1.0
elif border_blending < 0.0:
border_blending = 0.0
blend_ratio = (max(crop_img.size) / 2) * float(border_blending)
blend = img.convert("RGBA")
if use_square_mask:
mask = Image.new("L", img.size, 0)
mask_block = Image.new("L", (paste_region[2]-paste_region[0], paste_region[3]-paste_region[1]), 255)
mask_block = inset_border(mask_block, round(blend_ratio / 2), (0))
mask.paste(mask_block, paste_region)
else:
original_mask = tensor2pil(mask)[0]
original_mask = original_mask.resize((paste_region[2]-paste_region[0], paste_region[3]-paste_region[1]))
mask = Image.new("L", img.size, 0)
mask.paste(original_mask, paste_region)
mask = mask.filter(ImageFilter.BoxBlur(radius=blend_ratio / 4))
mask = mask.filter(ImageFilter.GaussianBlur(radius=blend_ratio / 4))
blend.paste(crop_img, paste_region)
blend.putalpha(mask)
img = Image.alpha_composite(img.convert("RGBA"), blend)
out_images.append(img.convert("RGB"))
return (pil2tensor(out_images),)
class SplitBboxes:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"bboxes": ("BBOX",),
"index": ("INT", {"default": 0,"min": 0, "max": 99999999, "step": 1}),
},
}
RETURN_TYPES = ("BBOX","BBOX",)
RETURN_NAMES = ("bboxes_a","bboxes_b",)
FUNCTION = "splitbbox"
CATEGORY = "KJNodes/masking"
DESCRIPTION = """
Splits the specified bbox list at the given index into two lists.
"""
def splitbbox(self, bboxes, index):
bboxes_a = bboxes[:index] # Sub-list from the start of bboxes up to (but not including) the index
bboxes_b = bboxes[index:] # Sub-list from the index to the end of bboxes
return (bboxes_a, bboxes_b,)
class BboxToInt:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"bboxes": ("BBOX",),
"index": ("INT", {"default": 0,"min": 0, "max": 99999999, "step": 1}),
},
}
RETURN_TYPES = ("INT","INT","INT","INT","INT","INT",)
RETURN_NAMES = ("x_min","y_min","width","height", "center_x","center_y",)
FUNCTION = "bboxtoint"
CATEGORY = "KJNodes/masking"
DESCRIPTION = """
Returns selected index from bounding box list as integers.
"""
def bboxtoint(self, bboxes, index):
x_min, y_min, width, height = bboxes[index]
center_x = int(x_min + width / 2)
center_y = int(y_min + height / 2)
return (x_min, y_min, width, height, center_x, center_y,)
class BboxVisualize:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE",),
"bboxes": ("BBOX",),
"line_width": ("INT", {"default": 1,"min": 1, "max": 10, "step": 1}),
"bbox_format": (["xywh", "xyxy"], {"default": "xywh"}),
},
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("images",)
FUNCTION = "visualizebbox"
DESCRIPTION = """
Visualizes the specified bbox on the image.
"""
CATEGORY = "KJNodes/masking"
def visualizebbox(self, bboxes, images, line_width, bbox_format):
image_list = []
for image, bbox in zip(images, bboxes):
# Ensure bbox is a sequence of 4 values
if isinstance(bbox, (list, tuple, np.ndarray)) and len(bbox) == 4:
if bbox_format == "xywh":
x_min, y_min, width, height = bbox
elif bbox_format == "xyxy":
x_min, y_min, x_max, y_max = bbox
width = x_max - x_min
height = y_max - y_min
else:
raise ValueError(f"Unknown bbox_format: {bbox_format}")
else:
print("Invalid bbox:", bbox)
continue
# Ensure bbox coordinates are integers
x_min = int(x_min)
y_min = int(y_min)
width = int(width)
height = int(height)
# Permute the image dimensions
image = image.permute(2, 0, 1)
# Clone the image to draw bounding boxes
img_with_bbox = image.clone()
# Define the color for the bbox, e.g., red
color = torch.tensor([1, 0, 0], dtype=torch.float32)
# Ensure color tensor matches the image channels
if color.shape[0] != img_with_bbox.shape[0]:
color = color.unsqueeze(1).expand(-1, line_width)
# Draw lines for each side of the bbox with the specified line width
for lw in range(line_width):
# Top horizontal line
if y_min + lw < img_with_bbox.shape[1]:
img_with_bbox[:, y_min + lw, x_min:x_min + width] = color[:, None]
# Bottom horizontal line
if y_min + height - lw < img_with_bbox.shape[1]:
img_with_bbox[:, y_min + height - lw, x_min:x_min + width] = color[:, None]
# Left vertical line
if x_min + lw < img_with_bbox.shape[2]:
img_with_bbox[:, y_min:y_min + height, x_min + lw] = color[:, None]
# Right vertical line
if x_min + width - lw < img_with_bbox.shape[2]:
img_with_bbox[:, y_min:y_min + height, x_min + width - lw] = color[:, None]
# Permute the image dimensions back
img_with_bbox = img_with_bbox.permute(1, 2, 0).unsqueeze(0)
image_list.append(img_with_bbox)
return (torch.cat(image_list, dim=0),)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,115 @@
import folder_paths
import os
import torch
import torch.nn.functional as F
from comfy.utils import ProgressBar, load_torch_file
import comfy.sample
from nodes import CLIPTextEncode
script_directory = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
folder_paths.add_model_folder_path("intrinsic_loras", os.path.join(script_directory, "intrinsic_loras"))
class Intrinsic_lora_sampling:
def __init__(self):
self.loaded_lora = None
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"lora_name": (folder_paths.get_filename_list("intrinsic_loras"), ),
"task": (
[
'depth map',
'surface normals',
'albedo',
'shading',
],
{
"default": 'depth map'
}),
"text": ("STRING", {"multiline": True, "default": ""}),
"clip": ("CLIP", ),
"vae": ("VAE", ),
"per_batch": ("INT", {"default": 16, "min": 1, "max": 4096, "step": 1}),
},
"optional": {
"image": ("IMAGE",),
"optional_latent": ("LATENT",),
},
}
RETURN_TYPES = ("IMAGE", "LATENT",)
FUNCTION = "onestepsample"
CATEGORY = "KJNodes"
DESCRIPTION = """
Sampler to use the intrinsic loras:
https://github.com/duxiaodan/intrinsic-lora
These LoRAs are tiny and thus included
with this node pack.
"""
def onestepsample(self, model, lora_name, clip, vae, text, task, per_batch, image=None, optional_latent=None):
pbar = ProgressBar(3)
if optional_latent is None:
image_list = []
for start_idx in range(0, image.shape[0], per_batch):
sub_pixels = vae.vae_encode_crop_pixels(image[start_idx:start_idx+per_batch])
image_list.append(vae.encode(sub_pixels[:,:,:,:3]))
sample = torch.cat(image_list, dim=0)
else:
sample = optional_latent["samples"]
noise = torch.zeros(sample.size(), dtype=sample.dtype, layout=sample.layout, device="cpu")
prompt = task + "," + text
positive, = CLIPTextEncode.encode(self, clip, prompt)
negative = positive #negative shouldn't do anything in this scenario
pbar.update(1)
#custom model sampling to pass latent through as it is
class X0_PassThrough(comfy.model_sampling.EPS):
def calculate_denoised(self, sigma, model_output, model_input):
return model_output
def calculate_input(self, sigma, noise):
return noise
sampling_base = comfy.model_sampling.ModelSamplingDiscrete
sampling_type = X0_PassThrough
class ModelSamplingAdvanced(sampling_base, sampling_type):
pass
model_sampling = ModelSamplingAdvanced(model.model.model_config)
#load lora
model_clone = model.clone()
lora_path = folder_paths.get_full_path("intrinsic_loras", lora_name)
lora = load_torch_file(lora_path, safe_load=True)
self.loaded_lora = (lora_path, lora)
model_clone_with_lora = comfy.sd.load_lora_for_models(model_clone, None, lora, 1.0, 0)[0]
model_clone_with_lora.add_object_patch("model_sampling", model_sampling)
samples = {"samples": comfy.sample.sample(model_clone_with_lora, noise, 1, 1.0, "euler", "simple", positive, negative, sample,
denoise=1.0, disable_noise=True, start_step=0, last_step=1,
force_full_denoise=True, noise_mask=None, callback=None, disable_pbar=True, seed=None)}
pbar.update(1)
decoded = []
for start_idx in range(0, samples["samples"].shape[0], per_batch):
decoded.append(vae.decode(samples["samples"][start_idx:start_idx+per_batch]))
image_out = torch.cat(decoded, dim=0)
pbar.update(1)
if task == 'depth map':
imax = image_out.max()
imin = image_out.min()
image_out = (image_out-imin)/(imax-imin)
image_out = torch.max(image_out, dim=3, keepdim=True)[0].repeat(1, 1, 1, 3)
elif task == 'surface normals':
image_out = F.normalize(image_out * 2 - 1, dim=3) / 2 + 0.5
image_out = 1.0 - image_out
else:
image_out = image_out.clamp(-1.,1.)
return (image_out, samples,)

View File

@@ -0,0 +1,583 @@
import torch
import comfy.model_management
import comfy.utils
import folder_paths
import os
import logging
from tqdm import tqdm
import numpy as np
device = comfy.model_management.get_torch_device()
CLAMP_QUANTILE = 0.99
def extract_lora(diff, key, rank, algorithm, lora_type, lowrank_iters=7, adaptive_param=1.0, clamp_quantile=True):
"""
Extracts LoRA weights from a weight difference tensor using SVD.
"""
conv2d = (len(diff.shape) == 4)
kernel_size = None if not conv2d else diff.size()[2:4]
conv2d_3x3 = conv2d and kernel_size != (1, 1)
out_dim, in_dim = diff.size()[0:2]
if conv2d:
if conv2d_3x3:
diff = diff.flatten(start_dim=1)
else:
diff = diff.squeeze()
diff_float = diff.float()
if algorithm == "svd_lowrank":
U, S, V = torch.svd_lowrank(diff_float, q=min(rank, in_dim, out_dim), niter=lowrank_iters)
U = U @ torch.diag(S)
Vh = V.t()
else:
#torch.linalg.svdvals()
U, S, Vh = torch.linalg.svd(diff_float)
# Flexible rank selection logic like locon: https://github.com/KohakuBlueleaf/LyCORIS/blob/main/tools/extract_locon.py
if "adaptive" in lora_type:
if lora_type == "adaptive_ratio":
min_s = torch.max(S) * adaptive_param
lora_rank = torch.sum(S > min_s).item()
elif lora_type == "adaptive_energy":
energy = torch.cumsum(S**2, dim=0)
total_energy = torch.sum(S**2)
threshold = adaptive_param * total_energy # e.g., adaptive_param=0.95 for 95%
lora_rank = torch.sum(energy < threshold).item() + 1
elif lora_type == "adaptive_quantile":
s_cum = torch.cumsum(S, dim=0)
min_cum_sum = adaptive_param * torch.sum(S)
lora_rank = torch.sum(s_cum < min_cum_sum).item()
elif lora_type == "adaptive_fro":
S_squared = S.pow(2)
S_fro_sq = float(torch.sum(S_squared))
sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq
lora_rank = int(torch.searchsorted(sum_S_squared, adaptive_param**2)) + 1
lora_rank = max(1, min(lora_rank, len(S)))
else:
pass # Will print after capping
# Cap adaptive rank by the specified max rank
lora_rank = min(lora_rank, rank)
# Calculate and print actual fro percentage retained after capping
if lora_type == "adaptive_fro":
S_squared = S.pow(2)
s_fro = torch.sqrt(torch.sum(S_squared))
s_red_fro = torch.sqrt(torch.sum(S_squared[:lora_rank]))
fro_percent = float(s_red_fro / s_fro)
print(f"{key} Extracted LoRA rank: {lora_rank}, Frobenius retained: {fro_percent:.1%}")
else:
print(f"{key} Extracted LoRA rank: {lora_rank}")
else:
lora_rank = rank
lora_rank = max(1, lora_rank)
lora_rank = min(out_dim, in_dim, lora_rank)
U = U[:, :lora_rank]
S = S[:lora_rank]
U = U @ torch.diag(S)
Vh = Vh[:lora_rank, :]
if clamp_quantile:
dist = torch.cat([U.flatten(), Vh.flatten()])
if dist.numel() > 100_000:
# Sample 100,000 elements for quantile estimation
idx = torch.randperm(dist.numel(), device=dist.device)[:100_000]
dist_sample = dist[idx]
hi_val = torch.quantile(dist_sample, CLAMP_QUANTILE)
else:
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val
U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)
if conv2d:
U = U.reshape(out_dim, lora_rank, 1, 1)
Vh = Vh.reshape(lora_rank, in_dim, kernel_size[0], kernel_size[1])
return (U, Vh)
def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, algorithm, lowrank_iters, out_dtype, bias_diff=False, adaptive_param=1.0, clamp_quantile=True):
comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
model_diff.model.diffusion_model.cpu()
sd = model_diff.model_state_dict(filter_prefix=prefix_model)
del model_diff
comfy.model_management.soft_empty_cache()
for k, v in sd.items():
if isinstance(v, torch.Tensor):
sd[k] = v.cpu()
# Get total number of keys to process for progress bar
total_keys = len([k for k in sd if k.endswith(".weight") or (bias_diff and k.endswith(".bias"))])
# Create progress bar
progress_bar = tqdm(total=total_keys, desc=f"Extracting LoRA ({prefix_lora.strip('.')})")
comfy_pbar = comfy.utils.ProgressBar(total_keys)
for k in sd:
if k.endswith(".weight"):
weight_diff = sd[k]
if weight_diff.ndim == 5:
logging.info(f"Skipping 5D tensor for key {k}") #skip patch embed
progress_bar.update(1)
comfy_pbar.update(1)
continue
if lora_type != "full":
if weight_diff.ndim < 2:
if bias_diff:
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().to(out_dtype).cpu()
progress_bar.update(1)
comfy_pbar.update(1)
continue
try:
out = extract_lora(weight_diff.to(device), k, rank, algorithm, lora_type, lowrank_iters=lowrank_iters, adaptive_param=adaptive_param, clamp_quantile=clamp_quantile)
output_sd["{}{}.lora_up.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[0].contiguous().to(out_dtype).cpu()
output_sd["{}{}.lora_down.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[1].contiguous().to(out_dtype).cpu()
except Exception as e:
logging.warning(f"Could not generate lora weights for key {k}, error {e}")
else:
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().to(out_dtype).cpu()
progress_bar.update(1)
comfy_pbar.update(1)
elif bias_diff and k.endswith(".bias"):
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().to(out_dtype).cpu()
progress_bar.update(1)
comfy_pbar.update(1)
progress_bar.close()
return output_sd
class LoraExtractKJ:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@classmethod
def INPUT_TYPES(s):
return {"required":
{
"finetuned_model": ("MODEL",),
"original_model": ("MODEL",),
"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
"rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1, "tooltip": "The rank to use for standard LoRA, or maximum rank limit for adaptive methods."}),
"lora_type": (["standard", "full", "adaptive_ratio", "adaptive_quantile", "adaptive_energy", "adaptive_fro"],),
"algorithm": (["svd_linalg", "svd_lowrank"], {"default": "svd_linalg", "tooltip": "SVD algorithm to use, svd_lowrank is faster but less accurate."}),
"lowrank_iters": ("INT", {"default": 7, "min": 1, "max": 100, "step": 1, "tooltip": "The number of subspace iterations for lowrank SVD algorithm."}),
"output_dtype": (["fp16", "bf16", "fp32"], {"default": "fp16"}),
"bias_diff": ("BOOLEAN", {"default": True}),
"adaptive_param": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "For ratio mode, this is the ratio of the maximum singular value. For quantile mode, this is the quantile of the singular values. For fro mode, this is the Frobenius norm retention ratio."}),
"clamp_quantile": ("BOOLEAN", {"default": True}),
},
}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "KJNodes/lora"
def save(self, finetuned_model, original_model, filename_prefix, rank, lora_type, algorithm, lowrank_iters, output_dtype, bias_diff, adaptive_param, clamp_quantile):
if algorithm == "svd_lowrank" and lora_type != "standard":
raise ValueError("svd_lowrank algorithm is only supported for standard LoRA extraction.")
dtype = {"fp8_e4m3fn": torch.float8_e4m3fn, "bf16": torch.bfloat16, "fp16": torch.float16, "fp16_fast": torch.float16, "fp32": torch.float32}[output_dtype]
m = finetuned_model.clone()
kp = original_model.get_key_patches("diffusion_model.")
for k in kp:
m.add_patches({k: kp[k]}, - 1.0, 1.0)
model_diff = m
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
output_sd = {}
if model_diff is not None:
output_sd = calc_lora_model(model_diff, rank, "diffusion_model.", "diffusion_model.", output_sd, lora_type, algorithm, lowrank_iters, dtype, bias_diff=bias_diff, adaptive_param=adaptive_param, clamp_quantile=clamp_quantile)
if "adaptive" in lora_type:
rank_str = f"{lora_type}_{adaptive_param:.2f}"
else:
rank_str = rank
output_checkpoint = f"{filename}_rank_{rank_str}_{output_dtype}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
return {}
NODE_CLASS_MAPPINGS = {
"LoraExtractKJ": LoraExtractKJ
}
NODE_DISPLAY_NAME_MAPPINGS = {
"LoraExtractKJ": "LoraExtractKJ"
}
class LoraReduceRank:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@classmethod
def INPUT_TYPES(s):
return {"required":
{
"lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}),
"new_rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1, "tooltip": "The new rank to resize the LoRA. Acts as max rank when using dynamic_method."}),
"dynamic_method": (["disabled", "sv_ratio", "sv_cumulative", "sv_fro"], {"default": "disabled", "tooltip": "Method to use for dynamically determining new alphas and dims"}),
"dynamic_param": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Method to use for dynamically determining new alphas and dims"}),
"output_dtype": (["match_original", "fp16", "bf16", "fp32"], {"default": "match_original", "tooltip": "Data type to save the LoRA as."}),
"verbose": ("BOOLEAN", {"default": True}),
},
}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
EXPERIMENTAL = True
DESCRIPTION = "Resize a LoRA model by reducing it's rank. Based on kohya's sd-scripts: https://github.com/kohya-ss/sd-scripts/blob/main/networks/resize_lora.py"
CATEGORY = "KJNodes/lora"
def save(self, lora_name, new_rank, output_dtype, dynamic_method, dynamic_param, verbose):
lora_path = folder_paths.get_full_path("loras", lora_name)
lora_sd, metadata = comfy.utils.load_torch_file(lora_path, return_metadata=True)
if output_dtype == "fp16":
save_dtype = torch.float16
elif output_dtype == "bf16":
save_dtype = torch.bfloat16
elif output_dtype == "fp32":
save_dtype = torch.float32
elif output_dtype == "match_original":
first_weight_key = next(k for k in lora_sd if k.endswith(".weight") and isinstance(lora_sd[k], torch.Tensor))
save_dtype = lora_sd[first_weight_key].dtype
new_lora_sd = {}
for k, v in lora_sd.items():
new_lora_sd[k.replace(".default", "")] = v
del lora_sd
print("Resizing Lora...")
output_sd, old_dim, new_alpha, rank_list = resize_lora_model(new_lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose)
# update metadata
if metadata is None:
metadata = {}
comment = metadata.get("ss_training_comment", "")
if dynamic_method == "disabled":
metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {new_rank}; {comment}"
metadata["ss_network_dim"] = str(new_rank)
metadata["ss_network_alpha"] = str(new_alpha)
else:
metadata["ss_training_comment"] = f"Dynamic resize with {dynamic_method}: {dynamic_param} from {old_dim}; {comment}"
metadata["ss_network_dim"] = "Dynamic"
metadata["ss_network_alpha"] = "Dynamic"
# cast to save_dtype before calculating hashes
for key in list(output_sd.keys()):
value = output_sd[key]
if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype:
output_sd[key] = value.to(save_dtype)
output_filename_prefix = "loras/" + lora_name
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(output_filename_prefix, self.output_dir)
output_dtype_str = f"_{output_dtype}" if output_dtype != "match_original" else ""
average_rank = str(int(np.mean(rank_list)))
rank_str = new_rank if dynamic_method == "disabled" else f"dynamic_{average_rank}"
output_checkpoint = f"{filename.replace('.safetensors', '')}_resized_from_{old_dim}_to_{rank_str}{output_dtype_str}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
print(f"Saving resized LoRA to {output_checkpoint}")
comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=metadata)
return {}
NODE_CLASS_MAPPINGS = {
"LoraExtractKJ": LoraExtractKJ
}
NODE_DISPLAY_NAME_MAPPINGS = {
"LoraExtractKJ": "LoraExtractKJ"
}
# Convert LoRA to different rank approximation (should only be used to go to lower rank)
# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
# Thanks to cloneofsimo
# This version is based on
# https://github.com/kohya-ss/sd-scripts/blob/main/networks/resize_lora.py
MIN_SV = 1e-6
LORA_DOWN_UP_FORMATS = [
("lora_down", "lora_up"), # sd-scripts LoRA
("lora_A", "lora_B"), # PEFT LoRA
("down", "up"), # ControlLoRA
]
# Indexing functions
def index_sv_cumulative(S, target):
original_sum = float(torch.sum(S))
cumulative_sums = torch.cumsum(S, dim=0) / original_sum
index = int(torch.searchsorted(cumulative_sums, target)) + 1
index = max(1, min(index, len(S) - 1))
return index
def index_sv_fro(S, target):
S_squared = S.pow(2)
S_fro_sq = float(torch.sum(S_squared))
sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
index = max(1, min(index, len(S) - 1))
return index
def index_sv_ratio(S, target):
max_sv = S[0]
min_sv = max_sv / target
index = int(torch.sum(S > min_sv).item())
index = max(1, min(index, len(S) - 1))
return index
# Modified from Kohaku-blueleaf's extract/merge functions
def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
out_size, in_size, kernel_size, _ = weight.size()
if weight.dtype != torch.float32:
weight = weight.to(torch.float32)
U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
lora_rank = param_dict["new_rank"]
U = U[:, :lora_rank]
S = S[:lora_rank]
U = U @ torch.diag(S)
Vh = Vh[:lora_rank, :]
param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu()
param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu()
del U, S, Vh, weight
return param_dict
def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
out_size, in_size = weight.size()
if weight.dtype != torch.float32:
weight = weight.to(torch.float32)
U, S, Vh = torch.linalg.svd(weight.to(device))
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
lora_rank = param_dict["new_rank"]
U = U[:, :lora_rank]
S = S[:lora_rank]
U = U @ torch.diag(S)
Vh = Vh[:lora_rank, :]
param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu()
param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu()
del U, S, Vh, weight
return param_dict
def merge_conv(lora_down, lora_up, device):
in_rank, in_size, kernel_size, k_ = lora_down.shape
out_size, out_rank, _, _ = lora_up.shape
assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch"
lora_down = lora_down.to(device)
lora_up = lora_up.to(device)
merged = lora_up.reshape(out_size, -1) @ lora_down.reshape(in_rank, -1)
weight = merged.reshape(out_size, in_size, kernel_size, kernel_size)
del lora_up, lora_down
return weight
def merge_linear(lora_down, lora_up, device):
in_rank, in_size = lora_down.shape
out_size, out_rank = lora_up.shape
assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch"
lora_down = lora_down.to(device)
lora_up = lora_up.to(device)
weight = lora_up @ lora_down
del lora_up, lora_down
return weight
# Calculate new rank
def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
param_dict = {}
if dynamic_method == "sv_ratio":
# Calculate new dim and alpha based off ratio
new_rank = index_sv_ratio(S, dynamic_param) + 1
new_alpha = float(scale * new_rank)
elif dynamic_method == "sv_cumulative":
# Calculate new dim and alpha based off cumulative sum
new_rank = index_sv_cumulative(S, dynamic_param) + 1
new_alpha = float(scale * new_rank)
elif dynamic_method == "sv_fro":
# Calculate new dim and alpha based off sqrt sum of squares
new_rank = index_sv_fro(S, dynamic_param) + 1
new_alpha = float(scale * new_rank)
else:
new_rank = rank
new_alpha = float(scale * new_rank)
if S[0] <= MIN_SV: # Zero matrix, set dim to 1
new_rank = 1
new_alpha = float(scale * new_rank)
elif new_rank > rank: # cap max rank at rank
new_rank = rank
new_alpha = float(scale * new_rank)
# Calculate resize info
s_sum = torch.sum(torch.abs(S))
s_rank = torch.sum(torch.abs(S[:new_rank]))
S_squared = S.pow(2)
s_fro = torch.sqrt(torch.sum(S_squared))
s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
fro_percent = float(s_red_fro / s_fro)
param_dict["new_rank"] = new_rank
param_dict["new_alpha"] = new_alpha
param_dict["sum_retained"] = (s_rank) / s_sum
param_dict["fro_retained"] = fro_percent
param_dict["max_ratio"] = S[0] / S[new_rank - 1]
return param_dict
def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
max_old_rank = None
new_alpha = None
verbose_str = "\n"
fro_list = []
rank_list = []
if dynamic_method:
print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")
lora_down_weight = None
lora_up_weight = None
o_lora_sd = lora_sd.copy()
block_down_name = None
block_up_name = None
total_keys = len([k for k in lora_sd if k.endswith(".weight")])
pbar = comfy.utils.ProgressBar(total_keys)
for key, value in tqdm(lora_sd.items(), leave=True, desc="Resizing LoRA weights"):
key_parts = key.split(".")
block_down_name = None
for _format in LORA_DOWN_UP_FORMATS:
# Currently we only match lora_down_name in the last two parts of key
# because ("down", "up") are general words and may appear in block_down_name
if len(key_parts) >= 2 and _format[0] == key_parts[-2]:
block_down_name = ".".join(key_parts[:-2])
lora_down_name = "." + _format[0]
lora_up_name = "." + _format[1]
weight_name = "." + key_parts[-1]
break
if len(key_parts) >= 1 and _format[0] == key_parts[-1]:
block_down_name = ".".join(key_parts[:-1])
lora_down_name = "." + _format[0]
lora_up_name = "." + _format[1]
weight_name = ""
break
if block_down_name is None:
# This parameter is not lora_down
continue
# Now weight_name can be ".weight" or ""
# Find corresponding lora_up and alpha
block_up_name = block_down_name
lora_down_weight = value
lora_up_weight = lora_sd.get(block_up_name + lora_up_name + weight_name, None)
lora_alpha = lora_sd.get(block_down_name + ".alpha", None)
weights_loaded = lora_down_weight is not None and lora_up_weight is not None
if weights_loaded:
conv2d = len(lora_down_weight.size()) == 4
old_rank = lora_down_weight.size()[0]
max_old_rank = max(max_old_rank or 0, old_rank)
# Skip if merged weight would be too large (>100k elements in any dimension)
if conv2d:
in_rank, in_size, kernel_size, _ = lora_down_weight.shape
out_size, out_rank, _, _ = lora_up_weight.shape
merged_size = out_size * in_size * kernel_size * kernel_size
else:
in_rank, in_size = lora_down_weight.shape
out_size, out_rank = lora_up_weight.shape
merged_size = out_size * in_size
if merged_size > 100_000_000: # Skip if >100M elements
logging.warning(f"Skipping {block_down_name}: merged weight too large ({merged_size:,} elements)")
tqdm.write(f"SKIPPED: {block_down_name} - too large ({merged_size:,} elements)")
pbar.update(1)
continue
if lora_alpha is None:
scale = 1.0
else:
scale = lora_alpha / old_rank
if conv2d:
full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
else:
full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
if verbose:
max_ratio = param_dict["max_ratio"]
sum_retained = param_dict["sum_retained"]
fro_retained = param_dict["fro_retained"]
if not np.isnan(fro_retained):
fro_list.append(float(fro_retained))
log_str = f"{block_down_name:75} | sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}, new dim: {param_dict['new_rank']}"
tqdm.write(log_str)
verbose_str += log_str
if verbose and dynamic_method:
verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
else:
verbose_str += "\n"
new_alpha = param_dict["new_alpha"]
o_lora_sd[block_down_name + lora_down_name + weight_name] = param_dict["lora_down"].to(save_dtype).contiguous()
o_lora_sd[block_up_name + lora_up_name + weight_name] = param_dict["lora_up"].to(save_dtype).contiguous()
o_lora_sd[block_down_name + ".alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype)
block_down_name = None
block_up_name = None
lora_down_weight = None
lora_up_weight = None
weights_loaded = False
rank_list.append(param_dict["new_rank"])
del param_dict
pbar.update(1)
if verbose:
print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
return o_lora_sd, max_old_rank, new_alpha, rank_list

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff