Make old scaled fp8 format use the new mixed quant ops system. (#11000)

This commit is contained in:
comfyanonymous
2025-12-05 11:35:42 -08:00
committed by GitHub
parent 0ec05b1481
commit 43071e3de3
24 changed files with 278 additions and 275 deletions

View File

@@ -126,27 +126,11 @@ class LowVramPatch:
def __init__(self, key, patches, convert_func=None, set_func=None):
self.key = key
self.patches = patches
self.convert_func = convert_func
self.convert_func = convert_func # TODO: remove
self.set_func = set_func
def __call__(self, weight):
intermediate_dtype = weight.dtype
if self.convert_func is not None:
weight = self.convert_func(weight, inplace=False)
if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
intermediate_dtype = torch.float32
out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype)
if self.set_func is None:
return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
else:
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
if self.set_func is not None:
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype)
else:
return out
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3