Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Includes 30 custom nodes committed directly, 7 Civitai-exclusive loras stored via Git LFS, and a setup script that installs all dependencies and downloads HuggingFace-hosted models on vast.ai. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
247 lines
10 KiB
Python
247 lines
10 KiB
Python
import os
|
|
import fal_client
|
|
import folder_paths
|
|
import configparser
|
|
import base64
|
|
import io
|
|
from PIL import Image
|
|
import logging
|
|
import json
|
|
import requests
|
|
import numpy as np
|
|
import torch
|
|
|
|
# NOTE(review): basicConfig() here configures the ROOT logger at import time
# with DEBUG verbosity — in a shared ComfyUI process this affects every other
# extension's logging. Confirm this is intentional; a library module would
# normally configure only its own named logger.
logging.basicConfig(level=logging.DEBUG)

# Module-level logger for this node package.
logger = logging.getLogger(__name__)
|
class BaseFalAPIFluxNode:
    """Base ComfyUI node for image generation through the fal.ai FLUX API.

    Responsibilities:
      * load the API key from ``config.ini`` and export it as ``FAL_KEY``
        (the env var ``fal_client`` reads its credentials from),
      * assemble the request payload (``prepare_arguments``),
      * perform the API round trip (``call_api``),
      * convert the returned images into ComfyUI image tensors of shape
        (batch, H, W, 3), float32 in [0, 1] (``process_images``).

    Subclasses select a concrete model by calling :meth:`set_api_endpoint`.
    """

    # ComfyUI node metadata.
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "generate"
    CATEGORY = "image generation"

    def __init__(self):
        self.api_key = self.get_api_key()
        if self.api_key:
            # fal_client reads credentials from the FAL_KEY environment
            # variable. Guarded: os.environ values must be str, so assigning
            # a missing (None) key would raise TypeError here instead of the
            # clearer ValueError raised later by prepare_arguments().
            os.environ['FAL_KEY'] = self.api_key
        # Set by subclasses via set_api_endpoint().
        self.api_endpoint = None

    def get_api_key(self):
        """Return the fal.ai API key from ``config.ini``, or None if unavailable.

        The file is expected one directory above this module, with a
        ``[falai]`` section containing an ``api_key`` option.
        """
        config = configparser.ConfigParser()
        config_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'config.ini')
        if os.path.exists(config_path):
            config.read(config_path)
            return config.get('falai', 'api_key', fallback=None)
        return None

    def set_api_endpoint(self, endpoint):
        """Record the fal.ai model endpoint that call_api() will submit to."""
        self.api_endpoint = endpoint

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the ComfyUI input sockets for this node."""
        return {
            "required": {
                "prompt": ("STRING", {"multiline": True}),
                "width": ("INT", {"default": 1024, "step": 8}),
                "height": ("INT", {"default": 1024, "step": 8}),
                "num_inference_steps": ("INT", {"default": 28, "min": 1, "max": 100}),
                "guidance_scale": ("FLOAT", {"default": 3.5, "min": 0.1, "max": 40.0}),
                "num_images": ("INT", {"default": 1, "min": 1, "max": 4}),
                "enable_safety_checker": ("BOOLEAN", {"default": True}),
            },
            "optional": {
                # 0 (the UI default) means "random seed" and is omitted from
                # the payload by prepare_arguments().
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
            }
        }

    def prepare_arguments(self, prompt, width, height, num_inference_steps, guidance_scale, num_images, enable_safety_checker, seed=None, **kwargs):
        """Build the argument dict for the fal.ai request.

        Args mirror the node inputs declared in INPUT_TYPES; extra keyword
        arguments from subclasses are accepted and ignored here.

        Raises:
            ValueError: if no API key is configured, or width/height is None.
        """
        if not self.api_key:
            raise ValueError("API key is not set. Please check your config.ini file.")

        arguments = {
            "prompt": prompt,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
            "num_images": num_images,
            "enable_safety_checker": enable_safety_checker
        }

        # Handle custom image size
        if width is None or height is None:
            raise ValueError("Width and height must be provided when using custom image size")
        arguments["image_size"] = {
            "width": width,
            "height": height
        }

        # seed == 0 is the UI's "random" sentinel — leave it out so the
        # service picks its own seed.
        if seed is not None and seed != 0:
            arguments["seed"] = seed

        return arguments

    def call_api(self, arguments):
        """Submit ``arguments`` to the configured endpoint and return the raw result.

        Raises:
            ValueError: if set_api_endpoint() was never called.
            RuntimeError: wrapping any error raised by the fal.ai client.
        """
        logger.debug(f"Full API request payload: {json.dumps(arguments, indent=2)}")

        if not self.api_endpoint:
            raise ValueError("API endpoint is not set. Please set it using set_api_endpoint() method.")

        try:
            handler = fal_client.submit(
                self.api_endpoint,
                arguments=arguments,
            )
            # Blocks until the remote job completes.
            result = handler.get()
            logger.debug(f"API response: {json.dumps(result, indent=2)}")
            return result
        except Exception as e:
            logger.error(f"API error details: {str(e)}")
            if hasattr(e, 'response'):
                logger.error(f"API error response: {e.response.text}")
            raise RuntimeError(f"An error occurred when calling the fal.ai API: {str(e)}") from e

    def process_images(self, result):
        """Convert the API ``result`` into a one-element list holding a batched tensor.

        Each entry of ``result["images"]`` may carry a ``data:image`` URI
        (base64-encoded) or a plain HTTP(S) URL. Entries that fail to decode
        are logged and skipped; if every entry fails, RuntimeError is raised.

        Returns:
            [tensor] of shape (N, H, W, 3), float32 in [0, 1].
        """
        if "images" not in result or not result["images"]:
            logger.error("No images were generated by the API.")
            raise RuntimeError("No images were generated by the API.")

        output_images = []
        for index, img_info in enumerate(result["images"]):
            try:
                logger.debug(f"Processing image {index}: {json.dumps(img_info, indent=2)}")
                if not isinstance(img_info, dict) or "url" not in img_info or not img_info["url"]:
                    logger.error(f"Invalid image info for image {index}")
                    continue

                img_url = img_info["url"]
                logger.debug(f"Image URL: {img_url[:100]}...")  # Log the first 100 characters of the URL

                if img_url.startswith("data:image"):
                    # Handle Base64 encoded image
                    try:
                        _, img_data = img_url.split(",", 1)
                        img_data = base64.b64decode(img_data)
                    except ValueError:
                        logger.error(f"Failed to split image URL for image {index}")
                        continue
                else:
                    # Handle regular URL. timeout bounds the request so a
                    # stalled download cannot hang the whole workflow.
                    try:
                        response = requests.get(img_url, timeout=60)
                        response.raise_for_status()
                        img_data = response.content
                    except requests.RequestException as e:
                        logger.error(f"Failed to download image from URL for image {index}: {str(e)}")
                        continue

                logger.debug(f"First 20 bytes of image data: {img_data[:20]}")

                # Try to interpret the data as an image.
                try:
                    img = Image.open(io.BytesIO(img_data))
                    logger.debug(f"Opened image with size: {img.size} and mode: {img.mode}")
                except Exception as e:
                    logger.error(f"Failed to open image data: {str(e)}")
                    # Fallback: interpret the payload as raw pixel data.
                    # NOTE(review): the hard-coded 32x32 reshape below looks
                    # like debug scaffolding — confirm it is still needed.
                    img_np = np.frombuffer(img_data, dtype=np.uint8)
                    logger.debug(f"Interpreted as raw pixel data with shape: {img_np.shape}")

                    if img_np.shape == (1024,):
                        img_np = img_np.reshape(32, 32)  # Reshape to 32x32 image
                    elif img_np.shape == (1, 1, 1024):
                        img_np = img_np.reshape(32, 32)

                    # Min-max normalize to 0-255; guard against a constant
                    # buffer (max == min) which would divide by zero.
                    lo, hi = img_np.min(), img_np.max()
                    if hi > lo:
                        img_np = ((img_np - lo) / (hi - lo) * 255).astype(np.uint8)
                    else:
                        img_np = img_np.astype(np.uint8)

                    img = Image.fromarray(img_np, 'L')  # Create grayscale image
                    img = img.convert('RGB')  # Convert to RGB

                # Ensure image is in RGB mode
                if img.mode != 'RGB':
                    img = img.convert('RGB')

                # Convert PIL Image to float32 NumPy array scaled to [0, 1].
                img_np = np.array(img).astype(np.float32) / 255.0

                # Add the batch dimension: (1, H, W, C).
                img_tensor = torch.from_numpy(img_np).unsqueeze(0)

                output_images.append(img_tensor)
            except Exception as e:
                logger.error(f"Failed to process image {index}: {str(e)}")

        if not output_images:
            logger.error("Failed to process any of the generated images.")
            raise RuntimeError("Failed to process any of the generated images.")

        # Stack all images into a single batch tensor. (The duplicate
        # emptiness check that used to follow here was unreachable after the
        # raise above and has been removed.)
        output_tensor = torch.cat(output_images, dim=0)
        logger.debug(f"Returning batched tensor with shape: {output_tensor.shape}")
        return [output_tensor]

    def upload_image(self, image):
        """Upload a ComfyUI image (tensor or ndarray) to fal.ai and return its URL.

        Accepts (1, H, W, C), (H, W, 3), (H, W, 1), (3, H, W), (1, H, W) and
        a legacy (1, 1, 1536) debug shape; converts to an RGB PNG (downscaled
        to at most 1024 px on the long side) before uploading.

        Raises:
            ValueError: for unsupported array shapes.
        """
        try:
            # Convert PyTorch tensor to numpy array.
            if isinstance(image, torch.Tensor):
                image = image.cpu().numpy()

            if isinstance(image, np.ndarray):
                # Drop a leading batch dimension of 1: (1, H, W, C) -> (H, W, C).
                if image.ndim == 4 and image.shape[0] == 1:
                    image = image.squeeze(0)

                if image.ndim == 3:
                    # The (1, 1, 1536) special case MUST be tested first:
                    # it also matches the generic (1, H, W) branch below,
                    # which previously shadowed it (the special case was
                    # unreachable).
                    if image.shape == (1, 1, 1536):
                        image = image.reshape(32, 48)
                        image = np.repeat(image[..., np.newaxis], 3, axis=2)
                    elif image.shape[2] == 3:  # (H, W, 3) RGB image
                        pass
                    elif image.shape[2] == 1:  # (H, W, 1) grayscale
                        image = np.repeat(image, 3, axis=2)
                    elif image.shape[0] == 3:  # (3, H, W) RGB
                        image = np.transpose(image, (1, 2, 0))
                    elif image.shape[0] == 1:  # (1, H, W) grayscale
                        image = np.repeat(image.squeeze(0)[..., np.newaxis], 3, axis=2)
                    else:
                        raise ValueError(f"Unsupported image shape: {image.shape}")

                # Min-max normalize to 0-255 if not already uint8; guard
                # against a constant image (max == min) dividing by zero.
                if image.dtype != np.uint8:
                    lo, hi = image.min(), image.max()
                    if hi > lo:
                        image = (image - lo) / (hi - lo) * 255
                    image = image.astype(np.uint8)

            image = Image.fromarray(image)

            # Ensure image is in RGB mode.
            if image.mode != 'RGB':
                image = image.convert('RGB')

            # Resize image if it's too large (adjust max_size as needed).
            max_size = 1024
            if max(image.size) > max_size:
                image.thumbnail((max_size, max_size), Image.LANCZOS)

            # Convert PIL Image to PNG bytes.
            buffered = io.BytesIO()
            image.save(buffered, format="PNG")
            img_byte = buffered.getvalue()

            # Upload the image using fal_client.
            url = fal_client.upload(img_byte, "image/png")
            logger.info(f"Image uploaded successfully. URL: {url}")
            return url
        except Exception as e:
            logger.error(f"Failed to process or upload image: {str(e)}")
            raise

    def generate(self, **kwargs):
        """ComfyUI entry point: build arguments, call the API, return image tensors."""
        arguments = self.prepare_arguments(**kwargs)
        result = self.call_api(arguments)
        output_images = self.process_images(result)
        return tuple(output_images)