Support the LTXV 2 model. (#11632)
This commit is contained in:
@@ -5,7 +5,9 @@ import comfy.model_management
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from comfy.ldm.hunyuan_video.upsampler import HunyuanVideo15SRModel
|
||||
from comfy.ldm.lightricks.latent_upsampler import LatentUpsampler
|
||||
import folder_paths
|
||||
import json
|
||||
|
||||
class CLIPTextEncodeHunyuanDiT(io.ComfyNode):
|
||||
@classmethod
|
||||
@@ -186,7 +188,7 @@ class LatentUpscaleModelLoader(io.ComfyNode):
|
||||
@classmethod
|
||||
def execute(cls, model_name) -> io.NodeOutput:
|
||||
model_path = folder_paths.get_full_path_or_raise("latent_upscale_models", model_name)
|
||||
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
|
||||
sd, metadata = comfy.utils.load_torch_file(model_path, safe_load=True, return_metadata=True)
|
||||
|
||||
if "blocks.0.block.0.conv.weight" in sd:
|
||||
config = {
|
||||
@@ -197,6 +199,8 @@ class LatentUpscaleModelLoader(io.ComfyNode):
|
||||
"global_residual": False,
|
||||
}
|
||||
model_type = "720p"
|
||||
model = HunyuanVideo15SRModel(model_type, config)
|
||||
model.load_sd(sd)
|
||||
elif "up.0.block.0.conv1.conv.weight" in sd:
|
||||
sd = {key.replace("nin_shortcut", "nin_shortcut.conv", 1): value for key, value in sd.items()}
|
||||
config = {
|
||||
@@ -205,9 +209,12 @@ class LatentUpscaleModelLoader(io.ComfyNode):
|
||||
"block_out_channels": tuple(sd[f"up.{i}.block.0.conv1.conv.weight"].shape[0] for i in range(len([k for k in sd.keys() if k.startswith("up.") and k.endswith(".block.0.conv1.conv.weight")]))),
|
||||
}
|
||||
model_type = "1080p"
|
||||
|
||||
model = HunyuanVideo15SRModel(model_type, config)
|
||||
model.load_sd(sd)
|
||||
model = HunyuanVideo15SRModel(model_type, config)
|
||||
model.load_sd(sd)
|
||||
elif "post_upsample_res_blocks.0.conv2.bias" in sd:
|
||||
config = json.loads(metadata["config"])
|
||||
model = LatentUpsampler.from_config(config).to(dtype=comfy.model_management.vae_dtype(allowed_dtypes=[torch.bfloat16, torch.float32]))
|
||||
model.load_state_dict(sd)
|
||||
|
||||
return io.NodeOutput(model)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user