Add memory estimation function to ltxav text encoder. (#11716)

This commit is contained in:
comfyanonymous
2026-01-07 17:11:22 -08:00
committed by GitHub
parent 3cd19e99c1
commit 25bc1b5b57
2 changed files with 15 additions and 4 deletions

View File

@@ -121,6 +121,14 @@ class LTXAVTEModel(torch.nn.Module):
return self.load_state_dict(sdo, strict=False)
def memory_estimation_function(self, token_weight_pairs, device=None):
    """Estimate the memory (in bytes) needed to encode the given tokens.

    Args:
        token_weight_pairs: dict of tokenizer outputs; only the
            "gemma3_12b" entry (a list of token/weight sequences) is read.
        device: target device, checked for bf16 support.

    Returns:
        Estimated bytes: total token count times a per-token MiB factor.
    """
    # Per-token cost in MiB; halved (6.0 -> 3.0) when bf16 is usable on the device.
    per_token_mib = 3.0 if comfy.model_management.should_use_bf16(device) else 6.0
    sequences = token_weight_pairs.get("gemma3_12b", [])
    total_tokens = sum(len(seq) for seq in sequences)
    return total_tokens * per_token_mib * 1024 * 1024
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
class LTXAVTEModel_(LTXAVTEModel):