Initial ops changes to use comfy_kitchen: Initial nvfp4 checkpoint support. (#11635)

---------

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
This commit is contained in:
comfyanonymous
2026-01-05 18:48:58 -08:00
committed by GitHub
parent 4f3f9e72a9
commit 6da00dd899
8 changed files with 223 additions and 795 deletions

View File

@@ -79,7 +79,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
if input is not None:
if dtype is None:
if isinstance(input, QuantizedTensor):
dtype = input._layout_params["orig_dtype"]
dtype = input.params.orig_dtype
else:
dtype = input.dtype
if bias_dtype is None:
@@ -412,26 +412,34 @@ def fp8_linear(self, input):
return None
input_dtype = input.dtype
input_shape = input.shape
tensor_3d = input.ndim == 3
if input.ndim == 3 or input.ndim == 2:
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
if tensor_3d:
input = input.reshape(-1, input_shape[2])
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
input = torch.clamp(input, min=-448, max=448, out=input)
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
if input.ndim != 2:
return None
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
# Wrap weight in QuantizedTensor - this enables unified dispatch
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
input = torch.clamp(input, min=-448, max=448, out=input)
input_fp8 = input.to(dtype).contiguous()
layout_params_input = TensorCoreFP8Layout.Params(scale=scale_input, orig_dtype=input_dtype, orig_shape=tuple(input_fp8.shape))
quantized_input = QuantizedTensor(input_fp8, TensorCoreFP8Layout, layout_params_input)
uncast_bias_weight(self, w, bias, offload_stream)
return o
# Wrap weight in QuantizedTensor - this enables unified dispatch
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
layout_params_weight = TensorCoreFP8Layout.Params(scale=scale_weight, orig_dtype=input_dtype, orig_shape=tuple(w.shape))
quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight)
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
return None
uncast_bias_weight(self, w, bias, offload_stream)
if tensor_3d:
o = o.reshape((input_shape[0], input_shape[1], w.shape[0]))
return o
class fp8_ops(manual_cast):
class Linear(manual_cast.Linear):
@@ -477,7 +485,12 @@ if CUBLAS_IS_AVAILABLE:
# ==============================================================================
# Mixed Precision Operations
# ==============================================================================
from .quant_ops import QuantizedTensor, QUANT_ALGOS
from .quant_ops import (
QuantizedTensor,
QUANT_ALGOS,
TensorCoreFP8Layout,
get_layout_class,
)
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
@@ -497,14 +510,15 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
) -> None:
super().__init__()
if dtype is None:
dtype = MixedPrecisionOps._compute_dtype
self.factory_kwargs = {"device": device, "dtype": dtype}
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
# self.factory_kwargs = {"device": device, "dtype": dtype}
self.in_features = in_features
self.out_features = out_features
self._has_bias = bias
if bias:
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
else:
self.register_parameter("bias", None)
self.tensor_class = None
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
@@ -512,6 +526,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def reset_parameters(self):
return None
def _load_scale_param(self, state_dict, prefix, param_name, device, manually_loaded_keys, dtype=None):
key = f"{prefix}{param_name}"
value = state_dict.pop(key, None)
if value is not None:
value = value.to(device=device)
if dtype is not None:
value = value.view(dtype=dtype)
manually_loaded_keys.append(key)
return value
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
@@ -529,14 +553,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
layer_conf = json.loads(layer_conf.numpy().tobytes())
if layer_conf is None:
dtype = self.factory_kwargs["dtype"]
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False)
if dtype != MixedPrecisionOps._compute_dtype:
self.comfy_cast_weights = True
if self._has_bias:
self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=dtype))
else:
self.register_parameter("bias", None)
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
else:
self.quant_format = layer_conf.get("format", None)
if not self._full_precision_mm:
@@ -547,31 +564,46 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
qconfig = QUANT_ALGOS[self.quant_format]
self.layout_type = qconfig["comfy_tensor_layout"]
layout_cls = get_layout_class(self.layout_type)
weight_scale_key = f"{prefix}weight_scale"
scale = state_dict.pop(weight_scale_key, None)
if scale is not None:
scale = scale.to(device)
layout_params = {
'scale': scale,
'orig_dtype': MixedPrecisionOps._compute_dtype,
'block_size': qconfig.get("group_size", None),
}
# Load format-specific parameters
if self.quant_format in ["float8_e4m3fn", "float8_e5m2"]:
# FP8: single tensor scale
scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
if scale is not None:
manually_loaded_keys.append(weight_scale_key)
params = layout_cls.Params(
scale=scale,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.out_features, self.in_features),
)
elif self.quant_format == "nvfp4":
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
dtype=torch.float8_e4m3fn)
if tensor_scale is None or block_scale is None:
raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
params = layout_cls.Params(
scale=tensor_scale,
block_scale=block_scale,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.out_features, self.in_features),
)
else:
raise ValueError(f"Unsupported quantization format: {self.quant_format}")
self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params),
QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
requires_grad=False
)
if self._has_bias:
self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=MixedPrecisionOps._compute_dtype))
else:
self.register_parameter("bias", None)
for param_name in qconfig["parameters"]:
if param_name in {"weight_scale", "weight_scale_2"}:
continue # Already handled above
param_key = f"{prefix}{param_name}"
_v = state_dict.pop(param_key, None)
if _v is None:
@@ -588,7 +620,15 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def state_dict(self, *args, destination=None, prefix="", **kwargs):
sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
if isinstance(self.weight, QuantizedTensor):
sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale']
layout_cls = self.weight._layout_cls
# Check if it's any FP8 variant (E4M3 or E5M2)
if layout_cls in ("TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout", "TensorCoreFP8Layout"):
sd["{}weight_scale".format(prefix)] = self.weight._params.scale
elif layout_cls == "TensorCoreNVFP4Layout":
sd["{}weight_scale_2".format(prefix)] = self.weight._params.scale
sd["{}weight_scale".format(prefix)] = self.weight._params.block_scale
quant_conf = {"format": self.quant_format}
if self._full_precision_mm:
quant_conf["full_precision_matrix_mult"] = True
@@ -607,12 +647,33 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def forward(self, input, *args, **kwargs):
run_every_op()
input_shape = input.shape
tensor_3d = input.ndim == 3
if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(input, *args, **kwargs)
if (getattr(self, 'layout_type', None) is not None and
not isinstance(input, QuantizedTensor)):
input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype)
return self._forward(input, self.weight, self.bias)
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
if tensor_3d:
input = input.reshape(-1, input_shape[2])
if input.ndim != 2:
# Fall back to comfy_cast_weights for non-2D tensors
return self.forward_comfy_cast_weights(input.reshape(input_shape), *args, **kwargs)
# dtype is now implicit in the layout class
input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None))
output = self._forward(input, self.weight, self.bias)
# Reshape output back to 3D if input was 3D
if tensor_3d:
output = output.reshape((input_shape[0], input_shape[1], self.weight.shape[0]))
return output
def convert_weight(self, weight, inplace=False, **kwargs):
if isinstance(weight, QuantizedTensor):
@@ -622,7 +683,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
if getattr(self, 'layout_type', None) is not None:
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
# dtype is now implicit in the layout class
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True)
else:
weight = weight.to(self.weight.dtype)
if return_weight: