Make old scaled fp8 format use the new mixed quant ops system. (#11000)
This commit is contained in:
@@ -6,7 +6,7 @@ from transformers import LlamaTokenizerFast
|
||||
import torch
|
||||
import os
|
||||
import numbers
|
||||
|
||||
import comfy.utils
|
||||
|
||||
def llama_detect(state_dict, prefix=""):
|
||||
out = {}
|
||||
@@ -14,12 +14,9 @@ def llama_detect(state_dict, prefix=""):
|
||||
if t5_key in state_dict:
|
||||
out["dtype_llama"] = state_dict[t5_key].dtype
|
||||
|
||||
scaled_fp8_key = "{}scaled_fp8".format(prefix)
|
||||
if scaled_fp8_key in state_dict:
|
||||
out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
|
||||
|
||||
if "_quantization_metadata" in state_dict:
|
||||
out["llama_quantization_metadata"] = state_dict["_quantization_metadata"]
|
||||
quant = comfy.utils.detect_layer_quantization(state_dict, prefix)
|
||||
if quant is not None:
|
||||
out["llama_quantization_metadata"] = quant
|
||||
|
||||
return out
|
||||
|
||||
@@ -31,10 +28,10 @@ class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
|
||||
|
||||
class LLAMAModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}, special_tokens={"start": 128000, "pad": 128258}):
|
||||
llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None)
|
||||
if llama_scaled_fp8 is not None:
|
||||
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
|
||||
textmodel_json_config = {}
|
||||
vocab_size = model_options.get("vocab_size", None)
|
||||
@@ -161,11 +158,11 @@ class HunyuanVideoClipModel(torch.nn.Module):
|
||||
return self.llama.load_sd(sd)
|
||||
|
||||
|
||||
def hunyuan_video_clip(dtype_llama=None, llama_scaled_fp8=None):
|
||||
def hunyuan_video_clip(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class HunyuanVideoClipModel_(HunyuanVideoClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["llama_scaled_fp8"] = llama_scaled_fp8
|
||||
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
|
||||
return HunyuanVideoClipModel_
|
||||
|
||||
Reference in New Issue
Block a user