Reduce RAM usage, fix VRAM OOMs, and fix Windows shared memory spilling with adaptive model loading (#11845)

This commit is contained in:
rattus
2026-01-31 22:01:11 -08:00
committed by GitHub
parent 873de5f37a
commit f8acd9c402
23 changed files with 1030 additions and 114 deletions

View File

@@ -228,8 +228,10 @@ class CLIP:
self.cond_stage_model.to(offload_device)
logging.warning("Had to shift TE back.")
model_management.archive_model_dtypes(self.cond_stage_model)
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
self.patcher = comfy.model_patcher.CoreModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
#Match torch.float32 hardcode upcast in TE implemention
self.patcher.set_model_compute_dtype(torch.float32)
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
@@ -389,8 +391,18 @@ class CLIP:
def load_sd(self, sd, full_model=False):
if full_model:
return self.cond_stage_model.load_state_dict(sd, strict=False)
return self.cond_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
else:
can_assign = self.patcher.is_dynamic()
self.cond_stage_model.can_assign_sd = can_assign
# The CLIP models are a pretty complex web of wrappers and its
# a bit of an API change to plumb this all the way through.
# So spray paint the model with this flag that the loading
# nn.Module can then inspect for itself.
for m in self.cond_stage_model.modules():
m.can_assign_sd = can_assign
return self.cond_stage_model.load_sd(sd)
def get_sd(self):
@@ -765,12 +777,7 @@ class VAE:
self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval()
m, u = self.first_stage_model.load_state_dict(sd, strict=False)
if len(m) > 0:
logging.warning("Missing VAE keys {}".format(m))
if len(u) > 0:
logging.debug("Leftover VAE keys {}".format(u))
model_management.archive_model_dtypes(self.first_stage_model)
if device is None:
device = model_management.vae_device()
@@ -782,7 +789,18 @@ class VAE:
self.first_stage_model.to(self.vae_dtype)
self.output_device = model_management.intermediate_device()
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
mp = comfy.model_patcher.CoreModelPatcher
if self.disable_offload:
mp = comfy.model_patcher.ModelPatcher
self.patcher = mp(self.first_stage_model, load_device=self.device, offload_device=offload_device)
m, u = self.first_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
if len(m) > 0:
logging.warning("Missing VAE keys {}".format(m))
if len(u) > 0:
logging.debug("Leftover VAE keys {}".format(u))
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
self.model_size()
@@ -897,7 +915,7 @@ class VAE:
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
free_memory = model_management.get_free_memory(self.device)
free_memory = self.patcher.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
@@ -971,7 +989,7 @@ class VAE:
try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
free_memory = model_management.get_free_memory(self.device)
free_memory = self.patcher.get_free_memory(self.device)
batch_number = int(free_memory / max(1, memory_used))
batch_number = max(1, batch_number)
samples = None
@@ -1432,7 +1450,7 @@ def load_gligen(ckpt_path):
model = gligen.load_gligen(data)
if model_management.should_use_fp16():
model = model.half()
return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
return comfy.model_patcher.CoreModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
def model_detection_error_hint(path, state_dict):
filename = os.path.basename(path)
@@ -1520,7 +1538,8 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
if output_model:
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
model.load_model_weights(sd, diffusion_model_prefix)
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())
if output_vae:
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
@@ -1563,7 +1582,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
logging.debug("left over keys: {}".format(left_over))
if output_model:
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
if inital_load_device != torch.device("cpu"):
logging.info("loaded diffusion model directly to GPU")
model_management.load_models_gpu([model_patcher], force_full_load=True)
@@ -1655,13 +1673,14 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
model_config.optimizations["fp8"] = True
model = model_config.get_model(new_sd, "")
model = model.to(offload_device)
model.load_model_weights(new_sd, "")
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=offload_device)
if not model_management.is_device_cpu(offload_device):
model.to(offload_device)
model.load_model_weights(new_sd, "", assign=model_patcher.is_dynamic())
left_over = sd.keys()
if len(left_over) > 0:
logging.info("left over keys in diffusion model: {}".format(left_over))
return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
return model_patcher
def load_diffusion_model(unet_path, model_options={}):
sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
@@ -1692,9 +1711,9 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
if metadata is None:
metadata = {}
model_management.load_models_gpu(load_models, force_patch_weights=True)
model_management.load_models_gpu(load_models)
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
sd = model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
for k in extra_keys:
sd[k] = extra_keys[k]