Add custom nodes, Civitai loras (LFS), and vast.ai setup script
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Includes 30 custom nodes committed directly, 7 Civitai-exclusive loras stored via Git LFS, and a setup script that installs all dependencies and downloads HuggingFace-hosted models on vast.ai. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
570
custom_nodes/ComfyUI-Easy-Use/py/libs/loader.py
Normal file
570
custom_nodes/ComfyUI-Easy-Use/py/libs/loader.py
Normal file
@@ -0,0 +1,570 @@
|
||||
import re, time, os, psutil
|
||||
import folder_paths
|
||||
import comfy.utils
|
||||
import comfy.sd
|
||||
import comfy.controlnet
|
||||
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from nodes import NODE_CLASS_MAPPINGS
|
||||
from collections import defaultdict
|
||||
from .log import log_node_info, log_node_error
|
||||
from ..modules.dit.pixArt.loader import load_pixart
|
||||
|
||||
diffusion_loaders = ["easy fullLoader", "easy a1111Loader", "easy fluxLoader", "easy comfyLoader", "easy hunyuanDiTLoader", "easy zero123Loader", "easy svdLoader"]
|
||||
stable_cascade_loaders = ["easy cascadeLoader"]
|
||||
dit_loaders = ['easy pixArtLoader']
|
||||
controlnet_loaders = ["easy controlnetLoader", "easy controlnetLoaderADV", "easy controlnetLoader++"]
|
||||
instant_loaders = ["easy instantIDApply", "easy instantIDApplyADV"]
|
||||
cascade_vae_node = ["easy preSamplingCascade", "easy fullCascadeKSampler"]
|
||||
model_merge_node = ["easy XYInputs: ModelMergeBlocks"]
|
||||
lora_widget = ["easy fullLoader", "easy a1111Loader", "easy comfyLoader", "easy fluxLoader"]
|
||||
|
||||
class easyLoader:
|
||||
def __init__(self):
|
||||
self.loaded_objects = {
|
||||
"ckpt": defaultdict(tuple), # {ckpt_name: (model, ...)}
|
||||
"unet": defaultdict(tuple),
|
||||
"clip": defaultdict(tuple),
|
||||
"clip_vision": defaultdict(tuple),
|
||||
"bvae": defaultdict(tuple),
|
||||
"vae": defaultdict(object),
|
||||
"lora": defaultdict(dict), # {lora_name: {UID: (model_lora, clip_lora)}}
|
||||
"controlnet": defaultdict(dict),
|
||||
"t5": defaultdict(tuple),
|
||||
"chatglm3": defaultdict(tuple),
|
||||
}
|
||||
self.memory_threshold = self.determine_memory_threshold(1)
|
||||
self.lora_name_cache = []
|
||||
|
||||
def clean_values(self, values: str):
|
||||
original_values = values.split("; ")
|
||||
cleaned_values = []
|
||||
|
||||
for value in original_values:
|
||||
cleaned_value = value.strip(';').strip()
|
||||
if cleaned_value == "":
|
||||
continue
|
||||
try:
|
||||
cleaned_value = int(cleaned_value)
|
||||
except ValueError:
|
||||
try:
|
||||
cleaned_value = float(cleaned_value)
|
||||
except ValueError:
|
||||
pass
|
||||
cleaned_values.append(cleaned_value)
|
||||
|
||||
return cleaned_values
|
||||
|
||||
def clear_unused_objects(self, desired_names: set, object_type: str):
|
||||
keys = set(self.loaded_objects[object_type].keys())
|
||||
for key in keys - desired_names:
|
||||
del self.loaded_objects[object_type][key]
|
||||
|
||||
def get_input_value(self, entry, key, prompt=None):
|
||||
val = entry["inputs"][key]
|
||||
if isinstance(val, str):
|
||||
return val
|
||||
elif isinstance(val, list):
|
||||
if prompt is not None and val[0]:
|
||||
return prompt[val[0]]['inputs'][key]
|
||||
else:
|
||||
return val[0]
|
||||
else:
|
||||
return str(val)
|
||||
|
||||
def process_pipe_loader(self, entry, desired_ckpt_names, desired_vae_names, desired_lora_names, desired_lora_settings, num_loras=3, suffix=""):
|
||||
for idx in range(1, num_loras + 1):
|
||||
lora_name_key = f"{suffix}lora{idx}_name"
|
||||
desired_lora_names.add(self.get_input_value(entry, lora_name_key))
|
||||
setting = f'{self.get_input_value(entry, lora_name_key)};{entry["inputs"][f"{suffix}lora{idx}_model_strength"]};{entry["inputs"][f"{suffix}lora{idx}_clip_strength"]}'
|
||||
desired_lora_settings.add(setting)
|
||||
|
||||
desired_ckpt_names.add(self.get_input_value(entry, f"{suffix}ckpt_name"))
|
||||
desired_vae_names.add(self.get_input_value(entry, f"{suffix}vae_name"))
|
||||
|
||||
def update_loaded_objects(self, prompt):
|
||||
desired_ckpt_names = set()
|
||||
desired_unet_names = set()
|
||||
desired_clip_names = set()
|
||||
desired_vae_names = set()
|
||||
desired_lora_names = set()
|
||||
desired_lora_settings = set()
|
||||
desired_controlnet_names = set()
|
||||
desired_t5_names = set()
|
||||
desired_glm3_names = set()
|
||||
|
||||
for entry in prompt.values():
|
||||
class_type = entry["class_type"]
|
||||
if class_type in lora_widget:
|
||||
lora_name = self.get_input_value(entry, "lora_name")
|
||||
desired_lora_names.add(lora_name)
|
||||
setting = f'{lora_name};{entry["inputs"]["lora_model_strength"]};{entry["inputs"]["lora_clip_strength"]}'
|
||||
desired_lora_settings.add(setting)
|
||||
|
||||
if class_type in diffusion_loaders:
|
||||
desired_ckpt_names.add(self.get_input_value(entry, "ckpt_name", prompt))
|
||||
desired_vae_names.add(self.get_input_value(entry, "vae_name"))
|
||||
|
||||
elif class_type in ['easy kolorsLoader']:
|
||||
desired_unet_names.add(self.get_input_value(entry, "unet_name"))
|
||||
desired_vae_names.add(self.get_input_value(entry, "vae_name"))
|
||||
desired_glm3_names.add(self.get_input_value(entry, "chatglm3_name"))
|
||||
|
||||
elif class_type in dit_loaders:
|
||||
t5_name = self.get_input_value(entry, "mt5_name") if "mt5_name" in entry["inputs"] else None
|
||||
clip_name = self.get_input_value(entry, "clip_name") if "clip_name" in entry["inputs"] else None
|
||||
model_name = self.get_input_value(entry, "model_name")
|
||||
ckpt_name = self.get_input_value(entry, "ckpt_name", prompt)
|
||||
if t5_name:
|
||||
desired_t5_names.add(t5_name)
|
||||
if clip_name:
|
||||
desired_clip_names.add(clip_name)
|
||||
desired_ckpt_names.add(ckpt_name+'_'+model_name)
|
||||
|
||||
elif class_type in stable_cascade_loaders:
|
||||
desired_unet_names.add(self.get_input_value(entry, "stage_c"))
|
||||
desired_unet_names.add(self.get_input_value(entry, "stage_b"))
|
||||
desired_clip_names.add(self.get_input_value(entry, "clip_name"))
|
||||
desired_vae_names.add(self.get_input_value(entry, "stage_a"))
|
||||
|
||||
elif class_type in cascade_vae_node:
|
||||
encode_vae_name = self.get_input_value(entry, "encode_vae_name")
|
||||
decode_vae_name = self.get_input_value(entry, "decode_vae_name")
|
||||
if encode_vae_name and encode_vae_name != 'None':
|
||||
desired_vae_names.add(encode_vae_name)
|
||||
if decode_vae_name and decode_vae_name != 'None':
|
||||
desired_vae_names.add(decode_vae_name)
|
||||
|
||||
elif class_type in controlnet_loaders:
|
||||
control_net_name = self.get_input_value(entry, "control_net_name", prompt)
|
||||
scale_soft_weights = self.get_input_value(entry, "scale_soft_weights")
|
||||
desired_controlnet_names.add(f'{control_net_name};{scale_soft_weights}')
|
||||
|
||||
elif class_type in instant_loaders:
|
||||
control_net_name = self.get_input_value(entry, "control_net_name", prompt)
|
||||
scale_soft_weights = self.get_input_value(entry, "cn_soft_weights")
|
||||
desired_controlnet_names.add(f'{control_net_name};{scale_soft_weights}')
|
||||
|
||||
elif class_type in model_merge_node:
|
||||
desired_ckpt_names.add(self.get_input_value(entry, "ckpt_name_1"))
|
||||
desired_ckpt_names.add(self.get_input_value(entry, "ckpt_name_2"))
|
||||
vae_use = self.get_input_value(entry, "vae_use")
|
||||
if vae_use != 'Use Model 1' and vae_use != 'Use Model 2':
|
||||
desired_vae_names.add(vae_use)
|
||||
|
||||
object_types = ["ckpt", "unet", "clip", "bvae", "vae", "lora", "controlnet", "t5"]
|
||||
for object_type in object_types:
|
||||
if object_type == 'unet':
|
||||
desired_names = desired_unet_names
|
||||
elif object_type in ["ckpt", "clip", "bvae"]:
|
||||
if object_type == 'clip':
|
||||
desired_names = desired_ckpt_names.union(desired_clip_names)
|
||||
else:
|
||||
desired_names = desired_ckpt_names
|
||||
elif object_type == "vae":
|
||||
desired_names = desired_vae_names
|
||||
elif object_type == "controlnet":
|
||||
desired_names = desired_controlnet_names
|
||||
elif object_type == "t5":
|
||||
desired_names = desired_t5_names
|
||||
elif object_type == "chatglm3":
|
||||
desired_names = desired_glm3_names
|
||||
else:
|
||||
desired_names = desired_lora_names
|
||||
self.clear_unused_objects(desired_names, object_type)
|
||||
|
||||
def add_to_cache(self, obj_type, key, value):
|
||||
"""
|
||||
Add an item to the cache with the current timestamp.
|
||||
"""
|
||||
timestamped_value = (value, time.time())
|
||||
self.loaded_objects[obj_type][key] = timestamped_value
|
||||
|
||||
def determine_memory_threshold(self, percentage=0.8):
|
||||
"""
|
||||
Determines the memory threshold as a percentage of the total available memory.
|
||||
Args:
|
||||
- percentage (float): The fraction of total memory to use as the threshold.
|
||||
Should be a value between 0 and 1. Default is 0.8 (80%).
|
||||
Returns:
|
||||
- memory_threshold (int): Memory threshold in bytes.
|
||||
"""
|
||||
total_memory = psutil.virtual_memory().total
|
||||
memory_threshold = total_memory * percentage
|
||||
return memory_threshold
|
||||
|
||||
def get_memory_usage(self):
|
||||
"""
|
||||
Returns the memory usage of the current process in bytes.
|
||||
"""
|
||||
process = psutil.Process(os.getpid())
|
||||
return process.memory_info().rss
|
||||
|
||||
def eviction_based_on_memory(self):
|
||||
"""
|
||||
Evicts objects from cache based on memory usage and priority.
|
||||
"""
|
||||
current_memory = self.get_memory_usage()
|
||||
if current_memory < self.memory_threshold:
|
||||
return
|
||||
eviction_order = ["vae", "lora", "bvae", "clip", "ckpt", "controlnet", "unet", "t5", "chatglm3"]
|
||||
for obj_type in eviction_order:
|
||||
if current_memory < self.memory_threshold:
|
||||
break
|
||||
# Sort items based on age (using the timestamp)
|
||||
items = list(self.loaded_objects[obj_type].items())
|
||||
items.sort(key=lambda x: x[1][1]) # Sorting by timestamp
|
||||
|
||||
for item in items:
|
||||
if current_memory < self.memory_threshold:
|
||||
break
|
||||
del self.loaded_objects[obj_type][item[0]]
|
||||
current_memory = self.get_memory_usage()
|
||||
|
||||
def load_checkpoint(self, ckpt_name, config_name=None, load_vision=False):
|
||||
cache_name = ckpt_name
|
||||
if config_name not in [None, "Default"]:
|
||||
cache_name = ckpt_name + "_" + config_name
|
||||
if cache_name in self.loaded_objects["ckpt"]:
|
||||
clip_vision = self.loaded_objects["clip_vision"][cache_name][0] if load_vision else None
|
||||
clip = self.loaded_objects["clip"][cache_name][0] if not load_vision else None
|
||||
return self.loaded_objects["ckpt"][cache_name][0], clip, self.loaded_objects["bvae"][cache_name][0], clip_vision
|
||||
|
||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
||||
|
||||
output_clip = False if load_vision else True
|
||||
output_clipvision = True if load_vision else False
|
||||
if config_name not in [None, "Default"]:
|
||||
config_path = folder_paths.get_full_path("configs", config_name)
|
||||
loaded_ckpt = comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
else:
|
||||
model_options = {}
|
||||
if re.search("nf4", ckpt_name):
|
||||
from ..modules.bitsandbytes_NF4 import OPS
|
||||
model_options = {"custom_operations": OPS}
|
||||
loaded_ckpt = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=output_clip, output_clipvision=output_clipvision, embedding_directory=folder_paths.get_folder_paths("embeddings"), model_options=model_options)
|
||||
|
||||
self.add_to_cache("ckpt", cache_name, loaded_ckpt[0])
|
||||
self.add_to_cache("bvae", cache_name, loaded_ckpt[2])
|
||||
|
||||
clip = loaded_ckpt[1]
|
||||
clip_vision = loaded_ckpt[3]
|
||||
if clip:
|
||||
self.add_to_cache("clip", cache_name, clip)
|
||||
if clip_vision:
|
||||
self.add_to_cache("clip_vision", cache_name, clip_vision)
|
||||
|
||||
self.eviction_based_on_memory()
|
||||
|
||||
return loaded_ckpt[0], clip, loaded_ckpt[2], clip_vision
|
||||
|
||||
def load_vae(self, vae_name):
|
||||
if vae_name in self.loaded_objects["vae"]:
|
||||
return self.loaded_objects["vae"][vae_name][0]
|
||||
|
||||
vae_path = folder_paths.get_full_path("vae", vae_name)
|
||||
sd = comfy.utils.load_torch_file(vae_path)
|
||||
loaded_vae = comfy.sd.VAE(sd=sd)
|
||||
self.add_to_cache("vae", vae_name, loaded_vae)
|
||||
self.eviction_based_on_memory()
|
||||
|
||||
return loaded_vae
|
||||
|
||||
def load_unet(self, unet_name):
|
||||
if unet_name in self.loaded_objects["unet"]:
|
||||
log_node_info("Load UNet", f"{unet_name} cached")
|
||||
return self.loaded_objects["unet"][unet_name][0]
|
||||
|
||||
unet_path = folder_paths.get_full_path("unet", unet_name)
|
||||
model = comfy.sd.load_unet(unet_path)
|
||||
self.add_to_cache("unet", unet_name, model)
|
||||
self.eviction_based_on_memory()
|
||||
|
||||
return model
|
||||
|
||||
def load_controlnet(self, control_net_name, scale_soft_weights=1, use_cache=True):
|
||||
unique_id = f'{control_net_name};{str(scale_soft_weights)}'
|
||||
if use_cache and unique_id in self.loaded_objects["controlnet"]:
|
||||
return self.loaded_objects["controlnet"][unique_id][0]
|
||||
if scale_soft_weights < 1:
|
||||
if "ScaledSoftControlNetWeights" in NODE_CLASS_MAPPINGS:
|
||||
soft_weight_cls = NODE_CLASS_MAPPINGS['ScaledSoftControlNetWeights']
|
||||
(weights, timestep_keyframe) = soft_weight_cls().load_weights(scale_soft_weights, False)
|
||||
cn_adv_cls = NODE_CLASS_MAPPINGS['ControlNetLoaderAdvanced']
|
||||
control_net, = cn_adv_cls().load_controlnet(control_net_name, timestep_keyframe)
|
||||
else:
|
||||
raise Exception(f"[Advanced-ControlNet Not Found] you need to install 'COMFYUI-Advanced-ControlNet'")
|
||||
else:
|
||||
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
|
||||
control_net = comfy.controlnet.load_controlnet(controlnet_path)
|
||||
if use_cache:
|
||||
self.add_to_cache("controlnet", unique_id, control_net)
|
||||
self.eviction_based_on_memory()
|
||||
|
||||
return control_net
|
||||
def load_clip(self, clip_name, type='stable_diffusion', load_clip=None):
|
||||
if clip_name in self.loaded_objects["clip"]:
|
||||
return self.loaded_objects["clip"][clip_name][0]
|
||||
|
||||
if type == 'stable_diffusion':
|
||||
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
|
||||
elif type == 'stable_cascade':
|
||||
clip_type = comfy.sd.CLIPType.STABLE_CASCADE
|
||||
elif type == 'sd3':
|
||||
clip_type = comfy.sd.CLIPType.SD3
|
||||
elif type == 'flux':
|
||||
clip_type = comfy.sd.CLIPType.FLUX
|
||||
elif type == 'stable_audio':
|
||||
clip_type = comfy.sd.CLIPType.STABLE_AUDIO
|
||||
clip_path = folder_paths.get_full_path("clip", clip_name)
|
||||
load_clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
|
||||
self.add_to_cache("clip", clip_name, load_clip)
|
||||
self.eviction_based_on_memory()
|
||||
|
||||
return load_clip
|
||||
|
||||
def load_lora(self, lora, model=None, clip=None, type=None , use_cache=True):
|
||||
lora_name = lora["lora_name"]
|
||||
model = model if model is not None else lora["model"]
|
||||
clip = clip if clip is not None else lora["clip"]
|
||||
model_strength = lora["model_strength"]
|
||||
clip_strength = lora["clip_strength"]
|
||||
lbw = lora["lbw"] if "lbw" in lora else None
|
||||
lbw_a = lora["lbw_a"] if "lbw_a" in lora else None
|
||||
lbw_b = lora["lbw_b"] if "lbw_b" in lora else None
|
||||
|
||||
model_hash = str(model)[44:-1]
|
||||
clip_hash = str(clip)[25:-1] if clip else ''
|
||||
|
||||
unique_id = f'{model_hash};{clip_hash};{lora_name};{model_strength};{clip_strength}'
|
||||
|
||||
if use_cache and unique_id in self.loaded_objects["lora"]:
|
||||
log_node_info("Load LORA",f"{lora_name} cached")
|
||||
return self.loaded_objects["lora"][unique_id][0]
|
||||
|
||||
orig_lora_name = lora_name
|
||||
lora_name = self.resolve_lora_name(lora_name)
|
||||
|
||||
if lora_name is not None:
|
||||
lora_path = folder_paths.get_full_path("loras", lora_name)
|
||||
else:
|
||||
lora_path = None
|
||||
|
||||
if lora_path is not None:
|
||||
log_node_info("Load LORA",f"{lora_name}: model={model_strength:.3f}, clip={clip_strength:.3f}, LBW={lbw}, A={lbw_a}, B={lbw_b}")
|
||||
if lbw:
|
||||
lbw = lora["lbw"]
|
||||
lbw_a = lora["lbw_a"]
|
||||
lbw_b = lora["lbw_b"]
|
||||
if 'LoraLoaderBlockWeight //Inspire' not in NODE_CLASS_MAPPINGS:
|
||||
raise Exception('[InspirePack Not Found] you need to install ComfyUI-Inspire-Pack')
|
||||
cls = NODE_CLASS_MAPPINGS['LoraLoaderBlockWeight //Inspire']
|
||||
model, clip, _ = cls().doit(model, clip, lora_name, model_strength, clip_strength, False, 0,
|
||||
lbw_a, lbw_b, "", lbw)
|
||||
else:
|
||||
_lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
|
||||
keys = _lora.keys()
|
||||
if "down_blocks.0.resnets.0.norm1.bias" in keys:
|
||||
print('Using LORA for Resadapter')
|
||||
key_map = {}
|
||||
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
|
||||
mapping_norm = {}
|
||||
|
||||
for key in keys:
|
||||
if ".weight" in key:
|
||||
key_name_in_ori_sd = key_map[key.replace(".weight", "")]
|
||||
mapping_norm[key_name_in_ori_sd] = _lora[key]
|
||||
elif ".bias" in key:
|
||||
key_name_in_ori_sd = key_map[key.replace(".bias", "")]
|
||||
mapping_norm[key_name_in_ori_sd.replace(".weight", ".bias")] = _lora[
|
||||
key
|
||||
]
|
||||
else:
|
||||
print("===>Unexpected key", key)
|
||||
mapping_norm[key] = _lora[key]
|
||||
|
||||
for k in mapping_norm.keys():
|
||||
if k not in model.model.state_dict():
|
||||
print("===>Missing key:", k)
|
||||
model.model.load_state_dict(mapping_norm, strict=False)
|
||||
return (model, clip)
|
||||
|
||||
# PixArt
|
||||
if type is not None and type == 'PixArt':
|
||||
from ..modules.dit.pixArt.loader import load_pixart_lora
|
||||
model = load_pixart_lora(model, _lora, lora_path, model_strength)
|
||||
else:
|
||||
model, clip = comfy.sd.load_lora_for_models(model, clip, _lora, model_strength, clip_strength)
|
||||
|
||||
if use_cache:
|
||||
self.add_to_cache("lora", unique_id, (model, clip))
|
||||
self.eviction_based_on_memory()
|
||||
else:
|
||||
log_node_error(f"LORA NOT FOUND", orig_lora_name)
|
||||
|
||||
return model, clip
|
||||
|
||||
def resolve_lora_name(self, name):
|
||||
if os.path.exists(name):
|
||||
return name
|
||||
else:
|
||||
if len(self.lora_name_cache) == 0:
|
||||
loras = folder_paths.get_filename_list("loras")
|
||||
self.lora_name_cache.extend(loras)
|
||||
for x in self.lora_name_cache:
|
||||
if x.endswith(name):
|
||||
return x
|
||||
|
||||
# 如果刷新网页后新添加的lora走这个逻辑
|
||||
log_node_info("LORA NOT IN CACHE", f"{name}")
|
||||
loras = folder_paths.get_filename_list("loras")
|
||||
for x in loras:
|
||||
if x.endswith(name):
|
||||
self.lora_name_cache.append(x)
|
||||
return x
|
||||
|
||||
return None
|
||||
|
||||
def load_main(self, ckpt_name, config_name, vae_name, lora_name, lora_model_strength, lora_clip_strength, optional_lora_stack, model_override, clip_override, vae_override, prompt, nf4=False):
|
||||
model: ModelPatcher | None = None
|
||||
clip: comfy.sd.CLIP | None = None
|
||||
vae: comfy.sd.VAE | None = None
|
||||
clip_vision = None
|
||||
lora_stack = []
|
||||
|
||||
# Check for model override
|
||||
can_load_lora = True
|
||||
# 判断是否存在 模型或Lora叠加xyplot, 若存在优先缓存第一个模型
|
||||
# Determine whether there is a model or Lora overlapping xyplot, and if there is, prioritize caching the first model.
|
||||
xy_model_id = next((x for x in prompt if str(prompt[x]["class_type"]) in ["easy XYInputs: ModelMergeBlocks",
|
||||
"easy XYInputs: Checkpoint"]), None)
|
||||
# This will find nodes that aren't actively connected to anything, and skip loading lora's for them.
|
||||
xy_lora_id = next((x for x in prompt if str(prompt[x]["class_type"]) == "easy XYInputs: Lora"), None)
|
||||
if xy_lora_id is not None:
|
||||
can_load_lora = False
|
||||
if xy_model_id is not None:
|
||||
node = prompt[xy_model_id]
|
||||
if "ckpt_name_1" in node["inputs"]:
|
||||
ckpt_name_1 = node["inputs"]["ckpt_name_1"]
|
||||
model, clip, vae, clip_vision = self.load_checkpoint(ckpt_name_1)
|
||||
can_load_lora = False
|
||||
elif model_override is not None and clip_override is not None and vae_override is not None:
|
||||
model = model_override
|
||||
clip = clip_override
|
||||
vae = vae_override
|
||||
else:
|
||||
model, clip, vae, clip_vision = self.load_checkpoint(ckpt_name, config_name)
|
||||
if model_override is not None:
|
||||
model = model_override
|
||||
if vae_override is not None:
|
||||
vae = vae_override
|
||||
elif clip_override is not None:
|
||||
clip = clip_override
|
||||
|
||||
|
||||
if optional_lora_stack is not None and can_load_lora:
|
||||
for lora in optional_lora_stack:
|
||||
# This is a subtle bit of code because it uses the model created by the last call, and passes it to the next call.
|
||||
lora = {"lora_name": lora[0], "model": model, "clip": clip, "model_strength": lora[1],
|
||||
"clip_strength": lora[2]}
|
||||
model, clip = self.load_lora(lora)
|
||||
lora['model'] = model
|
||||
lora['clip'] = clip
|
||||
lora_stack.append(lora)
|
||||
|
||||
if lora_name != "None" and can_load_lora:
|
||||
lora = {"lora_name": lora_name, "model": model, "clip": clip, "model_strength": lora_model_strength,
|
||||
"clip_strength": lora_clip_strength}
|
||||
model, clip = self.load_lora(lora)
|
||||
lora_stack.append(lora)
|
||||
|
||||
# Check for custom VAE
|
||||
if vae_name not in ["Baked VAE", "Baked-VAE"]:
|
||||
vae = self.load_vae(vae_name)
|
||||
# CLIP skip
|
||||
if not clip:
|
||||
raise Exception("No CLIP found")
|
||||
|
||||
return model, clip, vae, clip_vision, lora_stack
|
||||
|
||||
# Kolors
|
||||
def load_kolors_unet(self, unet_name):
|
||||
if unet_name in self.loaded_objects["unet"]:
|
||||
log_node_info("Load Kolors UNet", f"{unet_name} cached")
|
||||
return self.loaded_objects["unet"][unet_name][0]
|
||||
else:
|
||||
from ..modules.kolors.loader import applyKolorsUnet
|
||||
with applyKolorsUnet():
|
||||
unet_path = folder_paths.get_full_path("unet", unet_name)
|
||||
sd = comfy.utils.load_torch_file(unet_path)
|
||||
model = comfy.sd.load_unet_state_dict(sd)
|
||||
if model is None:
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
||||
|
||||
self.add_to_cache("unet", unet_name, model)
|
||||
self.eviction_based_on_memory()
|
||||
|
||||
return model
|
||||
|
||||
def load_chatglm3(self, chatglm3_name):
|
||||
from ..modules.kolors.loader import load_chatglm3
|
||||
if chatglm3_name in self.loaded_objects["chatglm3"]:
|
||||
log_node_info("Load ChatGLM3", f"{chatglm3_name} cached")
|
||||
return self.loaded_objects["chatglm3"][chatglm3_name][0]
|
||||
|
||||
chatglm_model = load_chatglm3(model_path=folder_paths.get_full_path("llm", chatglm3_name))
|
||||
self.add_to_cache("chatglm3", chatglm3_name, chatglm_model)
|
||||
self.eviction_based_on_memory()
|
||||
|
||||
return chatglm_model
|
||||
|
||||
|
||||
# DiT
|
||||
def load_dit_ckpt(self, ckpt_name, model_name, **kwargs):
|
||||
if (ckpt_name+'_'+model_name) in self.loaded_objects["ckpt"]:
|
||||
return self.loaded_objects["ckpt"][ckpt_name+'_'+model_name][0]
|
||||
model = None
|
||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
||||
model_type = kwargs['model_type'] if "model_type" in kwargs else 'PixArt'
|
||||
if model_type == 'PixArt':
|
||||
pixart_conf = kwargs['pixart_conf']
|
||||
model_conf = pixart_conf[model_name]
|
||||
model = load_pixart(ckpt_path, model_conf)
|
||||
if model:
|
||||
self.add_to_cache("ckpt", ckpt_name + '_' + model_name, model)
|
||||
self.eviction_based_on_memory()
|
||||
return model
|
||||
|
||||
def load_t5_from_sd3_clip(self, sd3_clip, padding):
|
||||
try:
|
||||
from comfy.text_encoders.sd3_clip import SD3Tokenizer, SD3ClipModel
|
||||
except:
|
||||
from comfy.sd3_clip import SD3Tokenizer, SD3ClipModel
|
||||
import copy
|
||||
|
||||
clip = sd3_clip.clone()
|
||||
assert clip.cond_stage_model.t5xxl is not None, "CLIP must have T5 loaded!"
|
||||
|
||||
# remove transformer
|
||||
transformer = clip.cond_stage_model.t5xxl.transformer
|
||||
clip.cond_stage_model.t5xxl.transformer = None
|
||||
|
||||
# clone object
|
||||
tmp = SD3ClipModel(clip_l=False, clip_g=False, t5=False)
|
||||
tmp.t5xxl = copy.deepcopy(clip.cond_stage_model.t5xxl)
|
||||
# put transformer back
|
||||
clip.cond_stage_model.t5xxl.transformer = transformer
|
||||
tmp.t5xxl.transformer = transformer
|
||||
|
||||
# override special tokens
|
||||
tmp.t5xxl.special_tokens = copy.deepcopy(clip.cond_stage_model.t5xxl.special_tokens)
|
||||
tmp.t5xxl.special_tokens.pop("end") # make sure empty tokens match
|
||||
|
||||
# tokenizer
|
||||
tok = SD3Tokenizer()
|
||||
tok.t5xxl.min_length = padding
|
||||
|
||||
clip.cond_stage_model = tmp
|
||||
clip.tokenizer = tok
|
||||
|
||||
return clip
|
||||
Reference in New Issue
Block a user