[API Nodes]: fixes and refactor (#11104)

* chore(api-nodes): applied ruff's pyupgrade(python3.10) to api-nodes client's to folder

* chore(api-nodes): add validate_video_frame_count function from LTX PR

* chore(api-nodes): replace deprecated V1 imports

* fix(api-nodes): the types returned by the "poll_op" function are now correct.
This commit is contained in:
Alexander Piskun
2025-12-05 00:05:28 +02:00
committed by GitHub
parent 9bc893c5bb
commit 3c8456223c
8 changed files with 146 additions and 135 deletions

View File

@@ -4,10 +4,11 @@ import json
import logging
import time
import uuid
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from enum import Enum
from io import BytesIO
from typing import Any, Callable, Iterable, Literal, Optional, Type, TypeVar, Union
from typing import Any, Literal, TypeVar
from urllib.parse import urljoin, urlparse
import aiohttp
@@ -37,8 +38,8 @@ class ApiEndpoint:
path: str,
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "GET",
*,
query_params: Optional[dict[str, Any]] = None,
headers: Optional[dict[str, str]] = None,
query_params: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
):
self.path = path
self.method = method
@@ -52,18 +53,18 @@ class _RequestConfig:
endpoint: ApiEndpoint
timeout: float
content_type: str
data: Optional[dict[str, Any]]
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]]
multipart_parser: Optional[Callable]
data: dict[str, Any] | None
files: dict[str, Any] | list[tuple[str, Any]] | None
multipart_parser: Callable | None
max_retries: int
retry_delay: float
retry_backoff: float
wait_label: str = "Waiting"
monitor_progress: bool = True
estimated_total: Optional[int] = None
final_label_on_success: Optional[str] = "Completed"
progress_origin_ts: Optional[float] = None
price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None
estimated_total: int | None = None
final_label_on_success: str | None = "Completed"
progress_origin_ts: float | None = None
price_extractor: Callable[[dict[str, Any]], float | None] | None = None
@dataclass
@@ -71,10 +72,10 @@ class _PollUIState:
started: float
status_label: str = "Queued"
is_queued: bool = True
price: Optional[float] = None
estimated_duration: Optional[int] = None
price: float | None = None
estimated_duration: int | None = None
base_processing_elapsed: float = 0.0 # sum of completed active intervals
active_since: Optional[float] = None # start time of current active interval (None if queued)
active_since: float | None = None # start time of current active interval (None if queued)
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
@@ -87,20 +88,20 @@ async def sync_op(
cls: type[IO.ComfyNode],
endpoint: ApiEndpoint,
*,
response_model: Type[M],
price_extractor: Optional[Callable[[M], Optional[float]]] = None,
data: Optional[BaseModel] = None,
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
response_model: type[M],
price_extractor: Callable[[M | Any], float | None] | None = None,
data: BaseModel | None = None,
files: dict[str, Any] | list[tuple[str, Any]] | None = None,
content_type: str = "application/json",
timeout: float = 3600.0,
multipart_parser: Optional[Callable] = None,
multipart_parser: Callable | None = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
wait_label: str = "Waiting for server",
estimated_duration: Optional[int] = None,
final_label_on_success: Optional[str] = "Completed",
progress_origin_ts: Optional[float] = None,
estimated_duration: int | None = None,
final_label_on_success: str | None = "Completed",
progress_origin_ts: float | None = None,
monitor_progress: bool = True,
) -> M:
raw = await sync_op_raw(
@@ -131,22 +132,22 @@ async def poll_op(
cls: type[IO.ComfyNode],
poll_endpoint: ApiEndpoint,
*,
response_model: Type[M],
status_extractor: Callable[[M], Optional[Union[str, int]]],
progress_extractor: Optional[Callable[[M], Optional[int]]] = None,
price_extractor: Optional[Callable[[M], Optional[float]]] = None,
completed_statuses: Optional[list[Union[str, int]]] = None,
failed_statuses: Optional[list[Union[str, int]]] = None,
queued_statuses: Optional[list[Union[str, int]]] = None,
data: Optional[BaseModel] = None,
response_model: type[M],
status_extractor: Callable[[M | Any], str | int | None],
progress_extractor: Callable[[M | Any], int | None] | None = None,
price_extractor: Callable[[M | Any], float | None] | None = None,
completed_statuses: list[str | int] | None = None,
failed_statuses: list[str | int] | None = None,
queued_statuses: list[str | int] | None = None,
data: BaseModel | None = None,
poll_interval: float = 5.0,
max_poll_attempts: int = 120,
timeout_per_poll: float = 120.0,
max_retries_per_poll: int = 3,
retry_delay_per_poll: float = 1.0,
retry_backoff_per_poll: float = 2.0,
estimated_duration: Optional[int] = None,
cancel_endpoint: Optional[ApiEndpoint] = None,
estimated_duration: int | None = None,
cancel_endpoint: ApiEndpoint | None = None,
cancel_timeout: float = 10.0,
) -> M:
raw = await poll_op_raw(
@@ -178,22 +179,22 @@ async def sync_op_raw(
cls: type[IO.ComfyNode],
endpoint: ApiEndpoint,
*,
price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
data: Optional[Union[dict[str, Any], BaseModel]] = None,
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
price_extractor: Callable[[dict[str, Any]], float | None] | None = None,
data: dict[str, Any] | BaseModel | None = None,
files: dict[str, Any] | list[tuple[str, Any]] | None = None,
content_type: str = "application/json",
timeout: float = 3600.0,
multipart_parser: Optional[Callable] = None,
multipart_parser: Callable | None = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
wait_label: str = "Waiting for server",
estimated_duration: Optional[int] = None,
estimated_duration: int | None = None,
as_binary: bool = False,
final_label_on_success: Optional[str] = "Completed",
progress_origin_ts: Optional[float] = None,
final_label_on_success: str | None = "Completed",
progress_origin_ts: float | None = None,
monitor_progress: bool = True,
) -> Union[dict[str, Any], bytes]:
) -> dict[str, Any] | bytes:
"""
Make a single network request.
- If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
@@ -229,21 +230,21 @@ async def poll_op_raw(
cls: type[IO.ComfyNode],
poll_endpoint: ApiEndpoint,
*,
status_extractor: Callable[[dict[str, Any]], Optional[Union[str, int]]],
progress_extractor: Optional[Callable[[dict[str, Any]], Optional[int]]] = None,
price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
completed_statuses: Optional[list[Union[str, int]]] = None,
failed_statuses: Optional[list[Union[str, int]]] = None,
queued_statuses: Optional[list[Union[str, int]]] = None,
data: Optional[Union[dict[str, Any], BaseModel]] = None,
status_extractor: Callable[[dict[str, Any]], str | int | None],
progress_extractor: Callable[[dict[str, Any]], int | None] | None = None,
price_extractor: Callable[[dict[str, Any]], float | None] | None = None,
completed_statuses: list[str | int] | None = None,
failed_statuses: list[str | int] | None = None,
queued_statuses: list[str | int] | None = None,
data: dict[str, Any] | BaseModel | None = None,
poll_interval: float = 5.0,
max_poll_attempts: int = 120,
timeout_per_poll: float = 120.0,
max_retries_per_poll: int = 3,
retry_delay_per_poll: float = 1.0,
retry_backoff_per_poll: float = 2.0,
estimated_duration: Optional[int] = None,
cancel_endpoint: Optional[ApiEndpoint] = None,
estimated_duration: int | None = None,
cancel_endpoint: ApiEndpoint | None = None,
cancel_timeout: float = 10.0,
) -> dict[str, Any]:
"""
@@ -261,7 +262,7 @@ async def poll_op_raw(
consumed_attempts = 0 # counts only non-queued polls
progress_bar = utils.ProgressBar(100) if progress_extractor else None
last_progress: Optional[int] = None
last_progress: int | None = None
state = _PollUIState(started=started, estimated_duration=estimated_duration)
stop_ticker = asyncio.Event()
@@ -420,10 +421,10 @@ async def poll_op_raw(
def _display_text(
node_cls: type[IO.ComfyNode],
text: Optional[str],
text: str | None,
*,
status: Optional[Union[str, int]] = None,
price: Optional[float] = None,
status: str | int | None = None,
price: float | None = None,
) -> None:
display_lines: list[str] = []
if status:
@@ -440,13 +441,13 @@ def _display_text(
def _display_time_progress(
node_cls: type[IO.ComfyNode],
status: Optional[Union[str, int]],
status: str | int | None,
elapsed_seconds: int,
estimated_total: Optional[int] = None,
estimated_total: int | None = None,
*,
price: Optional[float] = None,
is_queued: Optional[bool] = None,
processing_elapsed_seconds: Optional[int] = None,
price: float | None = None,
is_queued: bool | None = None,
processing_elapsed_seconds: int | None = None,
) -> None:
if estimated_total is not None and estimated_total > 0 and is_queued is False:
pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
@@ -488,7 +489,7 @@ def _unpack_tuple(t: tuple) -> tuple[str, Any, str]:
raise ValueError("files tuple must be (filename, file[, content_type])")
def _merge_params(endpoint_params: dict[str, Any], method: str, data: Optional[dict[str, Any]]) -> dict[str, Any]:
def _merge_params(endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None) -> dict[str, Any]:
params = dict(endpoint_params or {})
if method.upper() == "GET" and data:
for k, v in data.items():
@@ -534,9 +535,9 @@ def _generate_operation_id(method: str, path: str, attempt: int) -> str:
def _snapshot_request_body_for_logging(
content_type: str,
method: str,
data: Optional[dict[str, Any]],
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]],
) -> Optional[Union[dict[str, Any], str]]:
data: dict[str, Any] | None,
files: dict[str, Any] | list[tuple[str, Any]] | None,
) -> dict[str, Any] | str | None:
if method.upper() == "GET":
return None
if content_type == "multipart/form-data":
@@ -586,13 +587,13 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
attempt = 0
delay = cfg.retry_delay
operation_succeeded: bool = False
final_elapsed_seconds: Optional[int] = None
extracted_price: Optional[float] = None
final_elapsed_seconds: int | None = None
extracted_price: float | None = None
while True:
attempt += 1
stop_event = asyncio.Event()
monitor_task: Optional[asyncio.Task] = None
sess: Optional[aiohttp.ClientSession] = None
monitor_task: asyncio.Task | None = None
sess: aiohttp.ClientSession | None = None
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
@@ -887,7 +888,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
)
def _validate_or_raise(response_model: Type[M], payload: Any) -> M:
def _validate_or_raise(response_model: type[M], payload: Any) -> M:
try:
return response_model.model_validate(payload)
except Exception as e:
@@ -902,9 +903,9 @@ def _validate_or_raise(response_model: Type[M], payload: Any) -> M:
def _wrap_model_extractor(
response_model: Type[M],
extractor: Optional[Callable[[M], Any]],
) -> Optional[Callable[[dict[str, Any]], Any]]:
response_model: type[M],
extractor: Callable[[M], Any] | None,
) -> Callable[[dict[str, Any]], Any] | None:
"""Wrap a typed extractor so it can be used by the dict-based poller.
Validates the dict into `response_model` before invoking `extractor`.
Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating
@@ -929,10 +930,10 @@ def _wrap_model_extractor(
return _wrapped
def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Union[str, int]]:
def _normalize_statuses(values: Iterable[str | int] | None) -> set[str | int]:
if not values:
return set()
out: set[Union[str, int]] = set()
out: set[str | int] = set()
for v in values:
nv = _normalize_status_value(v)
if nv is not None:
@@ -940,7 +941,7 @@ def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Unio
return out
def _normalize_status_value(val: Union[str, int, None]) -> Union[str, int, None]:
def _normalize_status_value(val: str | int | None) -> str | int | None:
if isinstance(val, str):
return val.strip().lower()
return val