Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Includes 30 custom nodes committed directly, 7 Civitai-exclusive loras stored via Git LFS, and a setup script that installs all dependencies and downloads HuggingFace-hosted models on vast.ai. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
194 lines
9.3 KiB
Python
194 lines
9.3 KiB
Python
import yaml
|
|
from .sam2.modeling.sam2_base import SAM2Base
|
|
from .sam2.modeling.backbones.image_encoder import ImageEncoder
|
|
from .sam2.modeling.backbones.hieradet import Hiera
|
|
from .sam2.modeling.backbones.image_encoder import FpnNeck
|
|
from .sam2.modeling.position_encoding import PositionEmbeddingSine
|
|
from .sam2.modeling.memory_attention import MemoryAttention, MemoryAttentionLayer
|
|
from .sam2.modeling.sam.transformer import RoPEAttention
|
|
from .sam2.modeling.memory_encoder import MemoryEncoder, MaskDownSampler, Fuser, CXBlock
|
|
|
|
from .sam2.sam2_image_predictor import SAM2ImagePredictor
|
|
from .sam2.sam2_video_predictor import SAM2VideoPredictor
|
|
from .sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
|
|
from comfy.utils import load_torch_file
|
|
|
|
# Segmentor modes accepted by load_model().
_SUPPORTED_SEGMENTORS = ('single_image', 'video', 'automaskgenerator')

# Hiera options forwarded from the trunk config; any other trunk keys are
# ignored and Hiera's own defaults apply.
_HIERA_TRUNK_KEYS = ('embed_dim', 'num_heads', 'global_att_blocks',
                     'window_pos_embed_bkg_spatial_size', 'stages')

# Constructor options copied verbatim from the YAML ``model`` section; every
# key here is required (a missing one raises KeyError). Each keyword argument
# has the same name as its config key.
_SAM2_CONFIG_KEYS = (
    'num_maskmem',
    'image_size',
    'sigmoid_scale_for_mem_enc',
    'sigmoid_bias_for_mem_enc',
    'use_mask_input_as_output_without_sam',
    'directly_add_no_mem_embed',
    'use_high_res_features_in_sam',
    'multimask_output_in_sam',
    'iou_prediction_use_sigmoid',
    'use_obj_ptrs_in_encoder',
    'add_tpos_enc_to_obj_ptrs',
    'only_obj_ptrs_in_the_past_for_eval',
    'pred_obj_scores',
    'pred_obj_scores_mlp',
    'fixed_no_obj_ptr',
    'multimask_output_for_tracking',
    'use_multimask_token_for_obj_ptr',
    'compile_image_encoder',
    'multimask_min_pt_num',
    'multimask_max_pt_num',
    'use_mlp_for_obj_ptr_proj',
    'proj_tpos_enc_in_obj_ptrs',
    'no_obj_embed_spatial',
    'use_signed_tpos_enc_to_obj_ptrs',
)


def _build_position_encoding(pe_config):
    """Build a PositionEmbeddingSine from a ``position_encoding`` config section.

    Used for both the image-encoder neck and the memory encoder, whose
    sections share the same four keys.
    """
    return PositionEmbeddingSine(
        num_pos_feats=pe_config['num_pos_feats'],
        normalize=pe_config['normalize'],
        scale=pe_config['scale'],
        temperature=pe_config['temperature'],
    )


def _build_image_encoder(image_encoder_config):
    """Build the ImageEncoder (FPN neck + Hiera trunk) from ``model.image_encoder``."""
    neck_config = image_encoder_config['neck']
    neck = FpnNeck(
        position_encoding=_build_position_encoding(neck_config['position_encoding']),
        d_model=neck_config['d_model'],
        backbone_channel_list=neck_config['backbone_channel_list'],
        fpn_top_down_levels=neck_config['fpn_top_down_levels'],
        fpn_interp_model=neck_config['fpn_interp_model'],
    )

    trunk_config = image_encoder_config['trunk']
    trunk = Hiera(**{key: trunk_config[key] for key in _HIERA_TRUNK_KEYS if key in trunk_config})

    return ImageEncoder(
        scalp=image_encoder_config['scalp'],
        trunk=trunk,
        neck=neck,
    )


def _build_memory_attention(memory_attention_config):
    """Build MemoryAttention from ``model.memory_attention``.

    The ``layer`` section describes a single MemoryAttentionLayer built from a
    RoPE self-attention and a RoPE cross-attention; ``num_layers`` is passed
    through to MemoryAttention unchanged.
    """
    layer_config = memory_attention_config['layer']
    self_attention_config = layer_config['self_attention']
    cross_attention_config = layer_config['cross_attention']

    self_attention = RoPEAttention(
        rope_theta=self_attention_config['rope_theta'],
        feat_sizes=self_attention_config['feat_sizes'],
        embedding_dim=self_attention_config['embedding_dim'],
        num_heads=self_attention_config['num_heads'],
        downsample_rate=self_attention_config['downsample_rate'],
        dropout=self_attention_config['dropout'],
    )
    # Only the cross-attention is given rope_k_repeat and kv_in_dim; the
    # self-attention keeps RoPEAttention's defaults for both.
    cross_attention = RoPEAttention(
        rope_theta=cross_attention_config['rope_theta'],
        feat_sizes=cross_attention_config['feat_sizes'],
        rope_k_repeat=cross_attention_config['rope_k_repeat'],
        embedding_dim=cross_attention_config['embedding_dim'],
        num_heads=cross_attention_config['num_heads'],
        downsample_rate=cross_attention_config['downsample_rate'],
        dropout=cross_attention_config['dropout'],
        kv_in_dim=cross_attention_config['kv_in_dim'],
    )

    layer = MemoryAttentionLayer(
        activation=layer_config['activation'],
        dim_feedforward=layer_config['dim_feedforward'],
        dropout=layer_config['dropout'],
        pos_enc_at_attn=layer_config['pos_enc_at_attn'],
        self_attention=self_attention,
        d_model=layer_config['d_model'],
        pos_enc_at_cross_attn_keys=layer_config['pos_enc_at_cross_attn_keys'],
        pos_enc_at_cross_attn_queries=layer_config['pos_enc_at_cross_attn_queries'],
        cross_attention=cross_attention,
    )

    return MemoryAttention(
        d_model=memory_attention_config['d_model'],
        pos_enc_at_input=memory_attention_config['pos_enc_at_input'],
        layer=layer,
        num_layers=memory_attention_config['num_layers'],
    )


def _build_memory_encoder(memory_encoder_config):
    """Build MemoryEncoder (position encoding, mask downsampler, CXBlock fuser)
    from ``model.memory_encoder``."""
    position_encoding = _build_position_encoding(memory_encoder_config['position_encoding'])

    downsampler_config = memory_encoder_config['mask_downsampler']
    mask_downsampler = MaskDownSampler(
        kernel_size=downsampler_config['kernel_size'],
        stride=downsampler_config['stride'],
        padding=downsampler_config['padding'],
    )

    fuser_config = memory_encoder_config['fuser']
    fuser_layer_config = fuser_config['layer']
    fuser_layer = CXBlock(
        dim=fuser_layer_config['dim'],
        kernel_size=fuser_layer_config['kernel_size'],
        padding=fuser_layer_config['padding'],
        # PyYAML follows YAML 1.1, which resolves exponent literals without a
        # decimal point (e.g. `1e-6`) as strings, so coerce explicitly.
        layer_scale_init_value=float(fuser_layer_config['layer_scale_init_value']),
    )
    fuser = Fuser(
        num_layers=fuser_config['num_layers'],
        layer=fuser_layer,
    )

    return MemoryEncoder(
        position_encoding=position_encoding,
        mask_downsampler=mask_downsampler,
        fuser=fuser,
        out_dim=memory_encoder_config['out_dim'],
    )


def _initialize_model(model_class, model_config, image_encoder, memory_attention,
                      memory_encoder, binarize_mask_from_pts_for_mem_enc, dtype, device):
    """Instantiate ``model_class`` with the pre-built components and YAML options.

    The returned module has been cast to ``dtype``, moved to ``device`` and put
    in eval mode; no weights are loaded here.
    """
    # Fixed multimask-stability settings for the SAM mask decoder (not read
    # from the YAML). A fresh dict is built per call.
    sam_mask_decoder_extra_args = {
        "dynamic_multimask_via_stability": True,
        "dynamic_multimask_stability_delta": 0.05,
        "dynamic_multimask_stability_thresh": 0.98,
    }
    options = {key: model_config[key] for key in _SAM2_CONFIG_KEYS}
    return model_class(
        image_encoder=image_encoder,
        memory_attention=memory_attention,
        memory_encoder=memory_encoder,
        sam_mask_decoder_extra_args=sam_mask_decoder_extra_args,
        binarize_mask_from_pts_for_mem_enc=binarize_mask_from_pts_for_mem_enc,
        **options,
    ).to(dtype).to(device).eval()


def load_model(model_path, model_cfg_path, segmentor, dtype, device):
    """Build a SAM2 model from a YAML config and load its checkpoint weights.

    Args:
        model_path: Checkpoint path, read with ``comfy.utils.load_torch_file``.
        model_cfg_path: Path to a YAML config with a top-level ``model`` section.
        segmentor: One of ``'single_image'``, ``'video'`` or ``'automaskgenerator'``.
        dtype: Passed to ``.to()`` on the constructed model.
        device: Passed to ``.to()`` on the constructed model.

    Returns:
        ``SAM2ImagePredictor`` wrapping a ``SAM2Base`` for ``'single_image'``;
        a ``SAM2VideoPredictor`` for ``'video'``; ``SAM2AutomaticMaskGenerator``
        wrapping a ``SAM2Base`` for ``'automaskgenerator'``.

    Raises:
        ValueError: ``segmentor`` is not supported. This is checked before any
            file is read, so a bad value never costs a checkpoint load.
        KeyError: a required config key is missing.
        Errors from ``load_state_dict`` (called with default arguments) propagate.
    """
    if segmentor not in _SUPPORTED_SEGMENTORS:
        raise ValueError(f"Segmentor {segmentor} not supported")

    with open(model_cfg_path, 'r', encoding='utf-8') as file:
        config = yaml.safe_load(file)
    model_config = config['model']

    image_encoder = _build_image_encoder(model_config['image_encoder'])
    memory_attention = _build_memory_attention(model_config['memory_attention'])
    memory_encoder = _build_memory_encoder(model_config['memory_encoder'])

    sd = load_torch_file(model_path)

    # 'video' instantiates SAM2VideoPredictor directly with the same options;
    # the image modes build a bare SAM2Base and wrap it after loading weights.
    is_video = segmentor == 'video'
    model_class = SAM2VideoPredictor if is_video else SAM2Base
    model = _initialize_model(
        model_class, model_config, image_encoder, memory_attention, memory_encoder,
        binarize_mask_from_pts_for_mem_enc=is_video,
        dtype=dtype, device=device,
    )
    model.load_state_dict(sd)

    if segmentor == 'single_image':
        return SAM2ImagePredictor(model)
    if segmentor == 'automaskgenerator':
        return SAM2AutomaticMaskGenerator(model)
    return model