Make old scaled fp8 format use the new mixed quant ops system. (#11000)
This commit is contained in:
@@ -40,7 +40,7 @@ class LuminaModel(sd1_clip.SD1ClipModel):
|
||||
super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"):
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None, model_type="gemma2_2b"):
|
||||
if model_type == "gemma2_2b":
|
||||
model = Gemma2_2BModel
|
||||
elif model_type == "gemma3_4b":
|
||||
@@ -48,9 +48,9 @@ def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"):
|
||||
|
||||
class LuminaTEModel_(LuminaModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(device=device, dtype=dtype, name=model_type, model_options=model_options, clip_model=model)
|
||||
|
||||
Reference in New Issue
Block a user