Files
ComfyUI/custom_nodes/rgthree-comfy/py/power_puter.py
jaidaken f09734b0ee
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Add custom nodes, Civitai loras (LFS), and vast.ai setup script
Includes 30 custom nodes committed directly, 7 Civitai-exclusive
loras stored via Git LFS, and a setup script that installs all
dependencies and downloads HuggingFace-hosted models on vast.ai.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-09 00:56:42 +00:00

843 lines
31 KiB
Python

"""The Power Puter is a powerful node that can compute and evaluate Python-like code safely, allowing
for complex operations on primitives and workflow items for output: from string concatenation, to
math operations, list comprehensions, and node value output.
Originally based off https://github.com/pythongosssss/ComfyUI-Custom-Scripts/blob/aac13aa7ce35b07d43633c3bbe654a38c00d74f5/py/math_expression.py
under an MIT License https://github.com/pythongosssss/ComfyUI-Custom-Scripts/blob/aac13aa7ce35b07d43633c3bbe654a38c00d74f5/LICENSE
"""
import math
import ast
import json
import random
import dataclasses
import re
import time
import operator as op
import datetime
import numpy as np
from typing import Any, Callable, Iterable, Optional, Union
from types import MappingProxyType
from .constants import get_category, get_name
from .utils import ByPassTypeTuple, FlexibleOptionalInputType, any_type, get_dict_value
from .log import log_node_error, log_node_warn, log_node_info
from .power_lora_loader import RgthreePowerLoraLoader
from nodes import ImageBatch
from comfy_extras.nodes_latent import LatentBatch
class LoopBreak(Exception):
    """Signals a `break` statement; the evaluator's loops catch it to stop iterating.

    If it propagates past every loop, its message tells the user `break` was misplaced.
    """

    def __init__(self):
        message = 'Cannot use "break" outside of a loop.'
        super().__init__(message)
class LoopContinue(Exception):
    """Signals a `continue` statement; the evaluator's loops catch it to skip to the next pass.

    If it propagates past every loop, its message tells the user `continue` was misplaced.
    """

    def __init__(self):
        message = 'Cannot use "continue" outside of a loop.'
        super().__init__(message)
@dataclasses.dataclass(frozen=True)  # NOTE: `kw_only=True` would need Python 3.10+.
class Function:
    """Immutable description of a built-in function callable from Power Puter code.

    Attributes:
      name: The name user code calls the function by.
      call: The callable itself (function, builtin, lambda, ...), or the name of a `_Puter` method
        given as a string, which the evaluator looks up on its own instance when called.
      args: `(minimum, maximum)` count of positional arguments; a maximum of None means no limit.
    """

    name: str
    call: Union[Callable, str]
    args: tuple[int, Optional[int]]
def purge_vram(purge_models=True):
    """Purges vram and, optionally, unloads models.

    Runs Python garbage collection and, when CUDA is available, releases cached CUDA memory.

    Args:
      purge_models: When truthy, also calls ComfyUI's `model_management.unload_all_models()` and
        `model_management.soft_empty_cache()`.

    Returns:
      None.
    """
    import gc
    import torch
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    if purge_models:
        # Import the submodule explicitly. `import comfy` alone only exposes
        # `comfy.model_management` if some other code has already imported that submodule.
        import comfy.model_management
        comfy.model_management.unload_all_models()
        comfy.model_management.soft_empty_cache()
def batch(*args):
    """Batches multiple image or latents together.

    The first argument decides the kind: a dict with a 'samples' key is treated as a LATENT and
    anything else as an IMAGE. Every following argument must be of the same kind or a ValueError
    is raised. Values are merged left to right through ComfyUI's `LatentBatch` / `ImageBatch` node.
    """
    def _is_latent(value) -> bool:
        return isinstance(value, dict) and 'samples' in value

    pending = list(args)
    result = pending.pop(0)
    expect_latent = _is_latent(result)
    batcher = LatentBatch() if expect_latent else ImageBatch()
    for other in pending:
        if _is_latent(other) != expect_latent:
            expected, received = ('LATENT', 'IMAGE') if expect_latent else ('IMAGE', 'LATENT')
            raise ValueError(f'batch() error: Expecting "{expected}" but got "{received}".')
        merged = batcher.batch(result, other)
        result = merged[0]
    return result
# Prefix marking a string as an opaque reference to a built-in function.
_BUILTIN_FN_PREFIX = '__rgthreefn.'


def _get_built_in_fn_key(fn: Function) -> str:
    """Builds the opaque lookup key for a built-in `Function`.

    The key is `_BUILTIN_FN_PREFIX` followed by `hash(fn.name)`. NOTE: `str` hashes are salted per
    interpreter process, so keys are only meaningful within the process that built them.
    """
    return _BUILTIN_FN_PREFIX + str(hash(fn.name))
def _get_built_in_fn_by_key(fn_key: str):
    """Resolves a key made by `_get_built_in_fn_key` back to its `Function`.

    Lookup is purposefully by key, not by plain name: a string without the built-in prefix is
    rejected even if it equals a function name.

    Raises:
      ValueError: If `fn_key` lacks the prefix or is not a registered key.
    """
    is_known_key = fn_key.startswith(_BUILTIN_FN_PREFIX) and fn_key in _BUILT_INS_BY_NAME_AND_KEY
    if not is_known_key:
        raise ValueError('No built in function found.')
    return _BUILT_INS_BY_NAME_AND_KEY[fn_key]
# Every built-in function user code may call. `args` is the allowed (min, max) number of
# positional arguments; None as max means unbounded. Names starting with "." cannot be written as
# identifiers in code and are only reachable through the `random` namespace in `_BUILT_INS`.
_BUILT_IN_FNS_LIST = [
    Function(name="round", call=round, args=(1, 2)),
    Function(name="ceil", call=math.ceil, args=(1, 1)),
    Function(name="floor", call=math.floor, args=(1, 1)),
    Function(name="sqrt", call=math.sqrt, args=(1, 1)),
    Function(name="min", call=min, args=(2, None)),
    Function(name="max", call=max, args=(2, None)),
    Function(name=".random_int", call=random.randint, args=(2, 2)),
    Function(name=".random_choice", call=random.choice, args=(1, 1)),
    Function(name=".random_seed", call=random.seed, args=(1, 1)),
    Function(name="re", call=re.compile, args=(1, 1)),
    Function(name="len", call=len, args=(1, 1)),
    # `enumerate(iterable, start)` accepts the start as a second positional argument.
    Function(name="enumerate", call=enumerate, args=(1, 2)),
    Function(name="range", call=range, args=(1, 3)),
    # Casts
    # `int(text, base)` accepts the base as a second positional argument.
    Function(name="int", call=int, args=(1, 2)),
    Function(name="float", call=float, args=(1, 1)),
    Function(name="str", call=str, args=(1, 1)),
    Function(name="bool", call=bool, args=(1, 1)),
    Function(name="list", call=list, args=(1, 1)),
    Function(name="tuple", call=tuple, args=(1, 1)),
    # Special
    Function(name="dir", call=dir, args=(1, 1)),
    Function(name="type", call=type, args=(1, 1)),
    Function(name="print", call=print, args=(0, None)),
    # Comfy Specials (string calls name `_Puter` methods, resolved on the evaluator instance).
    Function(name="node", call='_get_node', args=(0, 1)),
    Function(name="nodes", call='_get_nodes', args=(0, 1)),
    Function(name="input_node", call='_get_input_node', args=(0, 1)),
    Function(name="purge_vram", call=purge_vram, args=(0, 1)),
    Function(name="batch", call=batch, args=(2, None)),
]
# Lookup of built-in `Function`s by both their plain name and their opaque key.
_BUILT_INS_BY_NAME_AND_KEY = {
    **{fn.name: fn for fn in _BUILT_IN_FNS_LIST},
    **{_get_built_in_fn_key(fn): fn for fn in _BUILT_IN_FNS_LIST},
}

# The read-only namespace of built-ins visible to user code. Names map to opaque keys (not to
# the callables), so evaluating a bare built-in name yields its key string; call sites resolve it
# via `_get_built_in_fn_by_key`. `random` is a nested read-only namespace so code can write
# `random.int(...)`, `random.choice(...)` and `random.seed(...)`.
_BUILT_INS = MappingProxyType({
    **{fn.name: _get_built_in_fn_key(fn) for fn in _BUILT_IN_FNS_LIST},
    'random': MappingProxyType({
        sub_name: _get_built_in_fn_key(_BUILT_INS_BY_NAME_AND_KEY[f'.random_{sub_name}'])
        for sub_name in ('int', 'choice', 'seed')
    }),
})
# A dict of types to blocked attributes/methods. Used to disallow file system access or other
# invocations we may want to block. Necessary for any instance type that is possible to create from
# the code or standard ComfyUI inputs.
#
# For instance, a user does not have access to the numpy module directly, so they cannot invoke
# `numpy.save`. However, a user can access a numpy.ndarray instance from a tensor and, from there,
# an attempt to call `tofile` or `dump` etc. would need to be blocked.
_BLOCKED_METHODS_OR_ATTRS = MappingProxyType({np.ndarray: ['tofile', 'dump']})
# Extra callables exposed as attributes on prompt nodes, keyed first by the node's `class_type`
# and then by the attribute name used in code. The evaluator consults them only when a prompt-node
# dict has no such key or attribute, and calls them with that node dict.
_SPECIAL_FUNCTIONS = {
    RgthreePowerLoraLoader.NAME: dict(
        # The enabled loras of a Power Lora Loader prompt node.
        loras=RgthreePowerLoraLoader.get_enabled_loras_from_prompt_node,
        # Presumably the trigger words of the enabled loras (per the helper's name; confirm).
        triggers=RgthreePowerLoraLoader.get_enabled_triggers_from_prompt_node,
    ),
}
# Series of regex checks for usage of a non-deterministic function. Using these is fine, but means
# the output can't be cached because it's either random, or is associated with another node that is
# not connected to ours (like looking up a node in the prompt). Using these means downstream nodes
# would always be run; that is fine for something like a final JSON output, but less so for a prompt
# text.
_NON_DETERMINISTIC_FUNCTION_CHECKS = [r'(?<!input_)(nodes?)\(',]
_OPERATORS = {
# operator
ast.Add: op.add,
ast.Sub: op.sub,
ast.Mult: op.mul,
ast.MatMult: op.matmul,
ast.Div: op.truediv,
ast.Mod: op.mod,
ast.Pow: op.pow,
ast.RShift: op.rshift,
ast.LShift: op.lshift,
ast.BitOr: op.or_,
ast.BitXor: op.xor,
ast.BitAnd: op.and_,
ast.FloorDiv: op.floordiv,
# boolop
ast.And: lambda a, b: a and b,
ast.Or: lambda a, b: a or b,
# unaryop
ast.Invert: op.invert,
ast.Not: lambda a: 0 if a else 1,
ast.USub: op.neg,
# cmpop
ast.Eq: op.eq,
ast.NotEq: op.ne,
ast.Lt: op.lt,
ast.LtE: op.le,
ast.Gt: op.gt,
ast.GtE: op.ge,
ast.Is: op.is_,
ast.IsNot: op.is_not,
ast.In: lambda a, b: a in b,
ast.NotIn: lambda a, b: a not in b,
}
# The node's name as produced by `get_name`; used as the node class NAME and as the log prefix.
_NODE_NAME = get_name("Power Puter")
def _update_code(code: str, unique_id: str, log=False):
"""Updates the code to either newer syntax or general cleaning."""
# Change usage of `input_node` so the passed variable is a string, if it isn't. So, instead of
# `input_node(a)` it needs to be `input_node('a')`
code = re.sub(r'input_node\(([^\'"].*?)\)', r'input_node("\1")', code)
# Update use of `random_int` to `random.int`
srch = re.compile(r'random_int\(')
if re.search(srch, code):
if log:
log_node_warn(
_NODE_NAME, f"Power Puter node #{unique_id} should update to use the `random.int`"
" built-in instead of `random_int`."
)
code = re.sub(srch, 'random.int(', code)
# Update use of `random_choice` to `random.choice`
srch = re.compile(r'random_choice\(')
if re.search(srch, code):
if log:
log_node_warn(
_NODE_NAME, f"Power Puter node #{unique_id} should update to use the `random.choice`"
" built-in instead of `random_choice`."
)
code = re.sub(srch, 'random.choice(', code)
return code
class RgthreePowerPuter:
    """A powerful node that can compute and evaluate expressions and output as various types."""

    NAME = _NODE_NAME
    CATEGORY = get_category()

    @classmethod
    def INPUT_TYPES(cls):  # pylint: disable = invalid-name, missing-function-docstring
        return {
            "required": {},
            "optional": FlexibleOptionalInputType(any_type),
            "hidden": {
                "unique_id": "UNIQUE_ID",
                "extra_pnginfo": "EXTRA_PNGINFO",
                "prompt": "PROMPT",
                "dynprompt": "DYNPROMPT",
            },
        }

    RETURN_TYPES = ByPassTypeTuple((any_type,))
    RETURN_NAMES = ByPassTypeTuple(("*",))
    FUNCTION = "main"

    @classmethod
    def IS_CHANGED(cls, **kwargs):
        """Forces a changed state if we could be unaware of data changes (like using `node()`).

        Returns `time.time()` (i.e. always changed) when the code calls `node()` / `nodes()`, or
        calls `random.int` / `random.choice` without a `random.seed` call first. Otherwise returns
        the constant 42 so results can be cached.
        """
        unique_id = kwargs['unique_id']
        code = _update_code(kwargs['code'], unique_id=unique_id)
        # Strip string literals and comments so text inside them can't match the checks below.
        code = re.sub(r"'[^']+?'", "''", code)
        code = re.sub(r'"[^"]+?"', '""', code)
        # `$` with MULTILINE also strips a comment on the final line. The previous `#.*\n` pattern
        # missed it because that line has no trailing newline.
        code = re.sub(r'#.*$', '', code, flags=re.MULTILINE)

        # If we have a non-deterministic function, then we'll always consider ourself changed since
        # we cannot be sure that the data would be the same (random, another unconnected node, etc).
        for check in _NON_DETERMINISTIC_FUNCTION_CHECKS:
            matches = re.search(check, code)
            if matches:
                log_node_warn(
                    _NODE_NAME,
                    f"Note, Power Puter (node #{unique_id}) cannot be cached b/c it's using a"
                    f" non-deterministic function call. Matches function call for '{matches.group(1)}'."
                )
                return time.time()

        # Advanced checks: random calls are only cacheable when a `random.seed` call precedes them.
        has_rand_seed = re.search(r'random\.seed\(', code)
        has_rand_int_or_choice = re.search(r'(?<!\.)(random\.(int|choice))\(', code)
        if has_rand_int_or_choice:
            if not has_rand_seed or has_rand_seed.span()[0] > has_rand_int_or_choice.span()[0]:
                log_node_warn(
                    _NODE_NAME,
                    f"Note, Power Puter (node #{unique_id}) cannot be cached b/c it's using a"
                    " non-deterministic function call. Matches function call for"
                    f" `{has_rand_int_or_choice.group(1)}`."
                )
                return time.time()
            # Reaching here means a seed call exists and comes first.
            log_node_info(
                _NODE_NAME,
                f"Power Puter node #{unique_id} WILL be cached eventhough it's using"
                f" a non-deterministic random call `{has_rand_int_or_choice.group(1)}` because it also"
                f" calls `random.seed` first. NOTE: Please ensure that the seed value is deterministic."
            )
        return 42

    def main(self, **kwargs):
        """Does the nodes' work: evaluates the code and casts the value(s) to the output types.

        Inputs `a` through `z` are exposed to the code as variables (None when not connected).

        Raises:
          ValueError: If several outputs are configured but the code's value is not a tuple.
        """
        unique_id = kwargs['unique_id']
        # `extra_pnginfo` may be None (when the queued prompt carries no extra_pnginfo). Treat that
        # like a missing workflow instead of failing on a membership check against None.
        pnginfo = kwargs.get('extra_pnginfo') or {}
        workflow = pnginfo.get("workflow", {"nodes": []})
        prompt = kwargs.get('prompt')
        dynprompt = kwargs.get('dynprompt')

        outputs = get_dict_value(kwargs, 'outputs.outputs', None)
        if not outputs:
            outputs = [kwargs.get('output', None) or 'STRING']

        # Set variable names, defaulting to None instead of KeyErrors.
        ctx = {c: kwargs.get(c) for c in 'abcdefghijklmnopqrstuvwxyz'}

        code = _update_code(kwargs['code'], unique_id=unique_id, log=True)
        eva = _Puter(
            code=code,
            ctx=ctx,
            workflow=workflow,
            prompt=prompt,
            dynprompt=dynprompt,
            unique_id=unique_id,
        )
        values = eva.execute()

        # Check if we have multiple outputs that the returned value is a tuple and raise if not.
        if len(outputs) > 1 and not isinstance(values, tuple):
            t = re.sub(r'^<[a-z]*\s(.*?)>$', r'\1', str(type(values)))
            msg = (
                f"When using multiple node outputs, the value from the code should be a 'tuple' with the"
                f" number of items equal to the number of outputs. But value from code was of type {t}."
            )
            log_node_error(_NODE_NAME, f'{msg}\n')
            raise ValueError(msg)

        if len(outputs) == 1:
            values = (values,)

        if len(values) > len(outputs):
            log_node_warn(
                _NODE_NAME,
                f"Expected value from code to be tuple with {len(outputs)} items, but value from code had"
                f" {len(values)} items. Extra values will be dropped."
            )
        elif len(values) < len(outputs):
            log_node_warn(
                _NODE_NAME,
                f"Expected value from code to be tuple with {len(outputs)} items, but value from code had"
                f" {len(values)} items. Extra outputs will be null."
            )

        # Now, we'll go over our return tuple, and cast as the output types.
        response = []
        for i, output in enumerate(outputs):
            value = values[i] if len(values) > i else None
            if value is not None:
                if output == 'INT':
                    value = int(value)
                elif output == 'FLOAT':
                    value = float(value)
                # Accidentally defined "BOOL" when should have been "BOOLEAN."
                # TODO: Can prob get rid of BOOL after a bit when UIs would be updated from sending
                # BOOL incorrectly.
                elif output in ('BOOL', 'BOOLEAN'):
                    value = bool(value)
                elif output == 'STRING':
                    if isinstance(value, (dict, list)):
                        value = json.dumps(value, indent=2)
                    else:
                        value = str(value)
                # Any other output type ('*' included) is passed through as-is; it's up to the user
                # to control the intended output, like passing through an input value.
            response.append(value)
        return tuple(response)
class _Puter:
    """The main computation evaluator: parses the code with `ast.parse` and walks the tree.

    See https://www.basicexamples.com/example/python/ast for examples.
    """

    def __init__(self, *, code: str, ctx: dict[str, Any], workflow, prompt, dynprompt, unique_id):
        ctx = ctx or {}
        self._ctx = {**ctx}
        self._code = code
        self._workflow = workflow
        self._prompt = prompt
        self._unique_id = unique_id
        self._dynprompt = dynprompt
        # These are expanded lazily when needed.
        self._prompt_nodes = None
        self._prompt_node = None

    def execute(self, code: Optional[str] = None) -> Any:
        """Evaluates a code block.

        Args:
          code: The code to evaluate; defaults to the code given to the constructor. (The old
            default was the truthy `typing.Optional[str]` object, and the argument was ignored.)

        Returns:
          The value of an executed `return` (even one nested in an `if` or a loop), otherwise the
          value of the last top-level statement.
        """
        # Always store the random state and seed from the clock; the state is restored in
        # `finally`, on success or error, so this seeding doesn't leak out of the evaluation.
        initial_random_state = random.getstate()
        random.seed(datetime.datetime.now().timestamp())
        try:
            tree = ast.parse(code or self._code)
            ctx = {**self._ctx}
            last_value = None
            for body in tree.body:
                last_value = self._eval_statement(body, ctx)
                # If we got a return, then that's it folks.
                if '__returned__' in ctx:
                    return ctx['__returned__']
            return last_value
        finally:
            random.setstate(initial_random_state)

    def _get_prompt_nodes(self):
        """Expands the prompt nodes lazily from the dynamic prompt.

        Each entry is the dict from `dynprompt.get_node(id)` plus an `'id'` key; the list is empty
        when there is no dynprompt.
        https://github.com/comfyanonymous/ComfyUI/blob/fc657f471a29d07696ca16b566000e8e555d67d1/comfy_execution/graph.py#L22
        """
        if self._prompt_nodes is None:
            self._prompt_nodes = []
            if self._dynprompt:
                all_ids = self._dynprompt.all_node_ids()
                self._prompt_nodes = [{'id': k} | {**self._dynprompt.get_node(k)} for k in all_ids]
        return self._prompt_nodes

    def _get_prompt_node(self):
        """Returns this node's own prompt node, or None if it is not in the prompt."""
        # Cache on `_prompt_node` itself. The old check of `_prompt_nodes` meant that once
        # `nodes()` had expanded the prompt nodes, this lookup never ran and None came back.
        if self._prompt_node is None:
            self._prompt_node = next(
                (n for n in self._get_prompt_nodes() if n['id'] == self._unique_id), None
            )
        return self._prompt_node

    def _get_nodes(self, node_id: Union[int, str, re.Pattern, None] = None) -> list[Any]:
        """Get a list of the nodes that match the node_id, or all the nodes in the prompt.

        A compiled regex is searched for in each node's `_meta.title`. Otherwise an all-digit id
        matches node ids, falling back to an exact `_meta.title` match when no id matches.
        """
        nodes = self._get_prompt_nodes().copy()
        if not node_id:
            return nodes
        if isinstance(node_id, re.Pattern):
            return [n for n in nodes if re.search(node_id, get_dict_value(n, '_meta.title', ''))]
        node_id = str(node_id)
        found = None
        if re.match(r'\d+$', node_id):
            found = [n for n in nodes if node_id == n['id']]
        if not found:
            found = [n for n in nodes if node_id == get_dict_value(n, '_meta.title', '')]
        return found

    def _get_node(self, node_id: Union[int, str, re.Pattern, None] = None) -> Union[Any, None]:
        """Returns a prompt-node from the hidden prompt (this node's own when `node_id` is None)."""
        if node_id is None:
            return self._get_prompt_node()
        nodes = self._get_nodes(node_id)
        if nodes and len(nodes) > 1:
            log_node_warn(_NODE_NAME, f"More than one node found for '{node_id}'. Returning first.")
        return nodes[0] if nodes else None

    def _get_input_node(self, input_name, node=None):
        """Gets the node connected to an input of a node (defaults to this power puter).

        NOTE(review): the original doc says "non-muted"; presumably muted nodes are simply absent
        from the dynprompt. Confirm.
        """
        node = node if node else self._get_prompt_node()
        try:
            connected_node_id = node['inputs'][input_name][0]
            return [n for n in self._get_prompt_nodes() if n['id'] == connected_node_id][0]
        except (TypeError, IndexError, KeyError):
            log_node_warn(_NODE_NAME, f'No input node found for "{input_name}". ')
            return None

    def _eval_statement(self, stmt: ast.AST, ctx: dict, prev_stmt: Union[ast.AST, None] = None):
        """Evaluates an AST node (statement or expression) against `ctx` and returns its value.

        Args:
          stmt: The node to evaluate.
          ctx: The variable scope; assignments write into it. The `__returned__` key is set once a
            `return` has executed.
          prev_stmt: The parent node when relevant. An `ast.Call` parent lets attribute access on
            a prompt-node dict return `(callable, node)` for special functions.

        Raises:
          TypeError: For unsupported node types.
        """
        # After a `return`, every further evaluation short-circuits to the returned value.
        if '__returned__' in ctx:
            return ctx['__returned__']

        if isinstance(stmt, ast.Expr):
            return self._eval_statement(stmt.value, ctx=ctx)
        if isinstance(stmt, ast.FormattedValue):
            return self._eval_formatted_value(stmt, ctx)
        # `ast.Constant` covers every literal on Python 3.8+; `ast.Num` / `.n` were deprecated
        # aliases and are removed in Python 3.14.
        if isinstance(stmt, ast.Constant):
            return stmt.value
        if isinstance(stmt, ast.BinOp):
            left = self._eval_statement(stmt.left, ctx=ctx)
            right = self._eval_statement(stmt.right, ctx=ctx)
            return _OPERATORS[type(stmt.op)](left, right)
        if isinstance(stmt, ast.BoolOp):
            is_and = isinstance(stmt.op, ast.And)
            is_or = isinstance(stmt.op, ast.Or)
            stmt_value_eval = None
            for stmt_value in stmt.values:
                stmt_value_eval = self._eval_statement(stmt_value, ctx=ctx)
                # If we're an and operator and have a falsy value, then we stop and return.
                # Likewise, if we're an or operator and have a truthy value.
                if (is_and and not stmt_value_eval) or (is_or and stmt_value_eval):
                    return stmt_value_eval
            # Always return the last if we made it here w/o success.
            return stmt_value_eval
        if isinstance(stmt, ast.UnaryOp):
            return _OPERATORS[type(stmt.op)](self._eval_statement(stmt.operand, ctx=ctx))
        if isinstance(stmt, (ast.Attribute, ast.Subscript)):
            return self._eval_attribute_or_subscript(stmt, ctx, prev_stmt)
        if isinstance(stmt, (ast.List, ast.Tuple)):
            values = [self._eval_statement(elt, ctx=ctx) for elt in stmt.elts]
            return tuple(values) if isinstance(stmt, ast.Tuple) else values
        if isinstance(stmt, ast.Dict):
            return self._eval_dict(stmt, ctx)
        # f-strings: https://www.basicexamples.com/example/python/ast-JoinedStr
        if isinstance(stmt, ast.JoinedStr):
            return ''.join(str(self._eval_statement(v, ctx=ctx)) for v in stmt.values)
        if isinstance(stmt, ast.Slice):
            # Any of lower/upper/step may be omitted, like `a[1:]`, `a[:3]` or `a[::2]`.
            bounds = [
                None if part is None else self._eval_statement(part, ctx=ctx)
                for part in (stmt.lower, stmt.upper, stmt.step)
            ]
            return slice(*bounds)
        if isinstance(stmt, ast.Name):
            if stmt.id in ctx:
                return ctx[stmt.id]
            if stmt.id in _BUILT_INS:
                return _BUILT_INS[stmt.id]
            raise NameError(f"Name not found: {stmt.id}")
        if isinstance(stmt, (ast.For, ast.While)):
            return self._eval_loop(stmt, ctx)
        if isinstance(stmt, ast.ListComp):
            return self._eval_list_comp(stmt, ctx)
        if isinstance(stmt, ast.Call):
            return self._eval_call(stmt, ctx)
        if isinstance(stmt, ast.Compare):
            return self._eval_compare(stmt, ctx)
        if isinstance(stmt, (ast.If, ast.IfExp)):
            value = self._eval_statement(stmt.test, ctx=ctx)
            if value:
                # ast.If has a list body, ast.IfExp a single expression.
                bodies = stmt.body if isinstance(stmt.body, list) else [stmt.body]
                for body in bodies:
                    value = self._eval_statement(body, ctx=ctx)
            elif stmt.orelse:
                # For ast.If, `orelse` lists the statements of an `else:` block (or holds a single
                # nested ast.If for `elif`); for ast.IfExp it is a single expression.
                orelses = stmt.orelse if isinstance(stmt.orelse, list) else [stmt.orelse]
                for orelse in orelses:
                    value = self._eval_statement(orelse, ctx=ctx)
            return value
        if isinstance(stmt, (ast.Assign, ast.AugAssign)):
            return self._eval_assign(stmt, ctx)
        # For assigning a var, e.g. in a list comprehension.
        # Like [name for node in node_list if (name := node.name)]
        if isinstance(stmt, ast.NamedExpr):
            value = self._eval_statement(stmt.value, ctx=ctx)
            ctx[stmt.target.id] = value
            return value
        if isinstance(stmt, ast.Return):
            value = None if stmt.value is None else self._eval_statement(stmt.value, ctx=ctx)
            # Mark that we have a return value, as we may be deeper in evaluation, like going
            # through an if condition's body.
            ctx['__returned__'] = value
            return value
        # Raise an error for break or continue, which loops catch and handle; outside of a loop
        # the error propagates (which is desired).
        if isinstance(stmt, ast.Break):
            raise LoopBreak()
        if isinstance(stmt, ast.Continue):
            raise LoopContinue()
        # Literally nothing.
        if isinstance(stmt, ast.Pass):
            return None
        raise TypeError(stmt)

    def _eval_formatted_value(self, stmt: ast.FormattedValue, ctx: dict) -> Any:
        """Evaluates an f-string placeholder, honoring `!r`/`!s`/`!a` and a format spec.

        Without a conversion or spec, the raw value is returned and the enclosing f-string
        `str()`s it, which keeps the previous output for plain `{x}` placeholders.
        """
        value = self._eval_statement(stmt.value, ctx=ctx)
        if stmt.conversion == -1 and stmt.format_spec is None:
            return value
        if stmt.conversion == ord('r'):
            value = repr(value)
        elif stmt.conversion == ord('a'):
            value = ascii(value)
        elif stmt.conversion == ord('s'):
            value = str(value)
        # The spec is itself an ast.JoinedStr (it may contain nested placeholders).
        spec = '' if stmt.format_spec is None else self._eval_statement(stmt.format_spec, ctx=ctx)
        return format(value, spec)

    def _eval_attribute_or_subscript(self, stmt, ctx: dict, prev_stmt) -> Any:
        """Evaluates `item.attr` or `item[key]`. Both try `item[key]` first, then `getattr`.

        Like: node(14).inputs.sampler_name (Attribute)
        Like: node(14)['inputs']['sampler_name'] (Subscript)
        """
        item = self._eval_statement(stmt.value, ctx=ctx)
        if isinstance(stmt, ast.Attribute):
            attr = stmt.attr
        else:
            # The slice could be a name, a constant or an ast.Slice; evaluate it.
            attr = self._eval_statement(stmt.slice, ctx=ctx)
        # Check if we're blocking access to this attribute/method on this item type.
        for typ, names in _BLOCKED_METHODS_OR_ATTRS.items():
            if isinstance(item, typ) and isinstance(attr, str) and attr in names:
                raise ValueError(f'Disallowed access to "{attr}" for type {typ}.')
        try:
            return item[attr]
        except (TypeError, IndexError, KeyError):
            pass
        # Dunder attributes (`__class__`, `__subclasses__`, `__globals__`, ...) are the usual way to
        # escape an expression sandbox, so attribute lookup never reaches them.
        if isinstance(attr, str) and attr.startswith('__'):
            raise ValueError(f'Disallowed access to "{attr}".')
        try:
            return getattr(item, attr)
        except AttributeError:
            if not isinstance(item, dict):
                raise
        # A dict (e.g. a prompt node) without the key returns None instead of erroring, unless its
        # class_type provides a matching special function.
        class_type = get_dict_value(item, "class_type")
        if class_type in _SPECIAL_FUNCTIONS and attr in _SPECIAL_FUNCTIONS[class_type]:
            special_fn = _SPECIAL_FUNCTIONS[class_type][attr]
            # As the callee of an ast.Call, hand back `(callable, item)` so the call site passes
            # the item as the first argument, followed by the call's own arguments.
            if isinstance(prev_stmt, ast.Call):
                return (special_fn, item)
            return special_fn(item)
        return None

    def _eval_dict(self, stmt: ast.Dict, ctx: dict) -> dict:
        """Evaluates a dict literal, including `**mapping` unpacking (which has a None key)."""
        if len(stmt.keys) != len(stmt.values):
            raise ValueError('Expected same number of keys as values for dict.')
        the_dict = {}
        for key_node, value_node in zip(stmt.keys, stmt.values):
            if key_node is None:
                the_dict.update(self._eval_statement(value_node, ctx=ctx))
                continue
            item_key = self._eval_statement(key_node, ctx=ctx)
            the_dict[item_key] = self._eval_statement(value_node, ctx=ctx)
        return the_dict

    def _eval_loop(self, stmt: Union[ast.For, ast.While], ctx: dict) -> None:
        """Runs a `for` / `while` loop, then its `else:` block if the loop wasn't stopped.

        A `return` inside the body now stops the loop. Previously a `while` loop kept going
        forever when the returned value was truthy, because its test evaluated to that value.
        """
        if isinstance(stmt, ast.For):
            for item in self._eval_statement(stmt.iter, ctx=ctx):
                # Set the for var(s).
                if isinstance(stmt.target, ast.Name):
                    ctx[stmt.target.id] = item
                elif isinstance(stmt.target, ast.Tuple):  # like `for k, v in d.items()`
                    for i, elt in enumerate(stmt.target.elts):
                        ctx[elt.id] = item[i]
                if self._eval_loop_body(stmt.body, ctx):
                    return None
        else:
            while self._eval_statement(stmt.test, ctx=ctx):
                if self._eval_loop_body(stmt.body, ctx):
                    return None
        for orelse in stmt.orelse:
            self._eval_statement(orelse, ctx=ctx)
        return None

    def _eval_loop_body(self, body: list, ctx: dict) -> bool:
        """Runs one loop iteration; returns True when the loop must stop (break or return)."""
        for body_stmt in body:
            try:
                self._eval_statement(body_stmt, ctx=ctx)
            except LoopContinue:
                return False
            except LoopBreak:
                return True
            if '__returned__' in ctx:
                return True
        return False

    def _eval_list_comp(self, stmt: ast.ListComp, ctx: dict) -> list:
        """Evaluates a list comprehension in a copy of `ctx` so its variables don't leak.

        Like: [v.lora for name, v in node(19).inputs.items() if name.startswith('lora_')]
        Like: [v.lower() for v in l if v.startswith('B') or v.startswith('F')]
        Like: [l for n in nodes(re('Loras')).values() if (l := n.loras)]
        """
        final_list = []
        gen_ctx = {**ctx}
        generators = stmt.generators

        def handle_gen(index: int):
            gen = generators[index]
            if isinstance(gen.target, ast.Name):
                gen_ctx[gen.target.id] = None
            elif isinstance(gen.target, ast.Tuple):  # like `for k, v in d.items()`
                for elt in gen.target.elts:
                    gen_ctx[elt.id] = None
            else:
                raise ValueError('Na')
            # Any iterable expression works: a call, a name, a literal, a subscript like `l[1:]`...
            gen_iters = self._eval_statement(gen.iter, ctx=gen_ctx)
            if not isinstance(gen_iters, Iterable):
                raise ValueError('No iteraors found for list comprehension')
            for gen_iter in gen_iters:
                if_ctx = {**gen_ctx}
                if isinstance(gen.target, ast.Tuple):
                    for i, elt in enumerate(gen.target.elts):
                        if_ctx[elt.id] = gen_iter[i]
                else:
                    if_ctx[gen.target.id] = gen_iter
                if not all(self._eval_statement(ifcall, ctx=if_ctx) for ifcall in gen.ifs):
                    continue
                # Keep the bindings (including any `:=` made by the ifs) for the element.
                gen_ctx.update(if_ctx)
                if index + 1 < len(generators):
                    handle_gen(index + 1)
                else:
                    final_list.append(self._eval_statement(stmt.elt, gen_ctx))

        handle_gen(0)
        return final_list

    def _eval_call(self, stmt: ast.Call, ctx: dict) -> Any:
        """Evaluates a call to a built-in, a special function, or a method of an evaluated value.

        Raises:
          ValueError: If nothing callable is found.
          SyntaxError: If a built-in gets a positional-argument count outside its allowed range.
        """
        call = None
        args = []
        kwargs = {}
        if isinstance(stmt.func, ast.Attribute):
            call = self._eval_statement(stmt.func, prev_stmt=stmt, ctx=ctx)
            # A special function comes back as `(callable, item)`; the item is the first argument.
            if isinstance(call, tuple):
                args.append(call[1])
                call = call[0]
            if not call:
                raise ValueError(f'No call for ast.Call {stmt.func}')
        name = ''
        if isinstance(stmt.func, ast.Name):
            name = stmt.func.id
            if name in _BUILT_INS:
                call = _BUILT_INS[name]
        # Built-ins are referenced by opaque key strings; resolve them to the real callable.
        if isinstance(call, str) and call.startswith(_BUILTIN_FN_PREFIX):
            fn = _get_built_in_fn_by_key(call)
            call = fn.call
            if isinstance(call, str):
                call = getattr(self, call)
            num_args = len(stmt.args)
            if num_args < fn.args[0] or (fn.args[1] is not None and num_args > fn.args[1]):
                to_err = " or more" if fn.args[1] is None else f" to {fn.args[1]}"
                raise SyntaxError(
                    f"Invalid function call: {fn.name} requires {fn.args[0]}{to_err} args"
                )
        if not call:
            raise ValueError(f'No call for ast.Call {name}')
        for arg in stmt.args:
            args.append(self._eval_statement(arg, ctx=ctx))
        for kwarg in stmt.keywords:
            if kwarg.arg is None:  # `**mapping` in the call.
                kwargs.update(self._eval_statement(kwarg.value, ctx=ctx))
            else:
                kwargs[kwarg.arg] = self._eval_statement(kwarg.value, ctx=ctx)
        return call(*args, **kwargs)

    def _eval_compare(self, stmt: ast.Compare, ctx: dict) -> int:
        """Evaluates a (possibly chained) comparison, returning 1 or 0.

        Chains like `a < b < c` short-circuit as in Python; previously only the first operator was
        applied, and `not in` raised NotImplementedError.
        """
        left = self._eval_statement(stmt.left, ctx=ctx)
        for cmp_op, comparator in zip(stmt.ops, stmt.comparators):
            op_fn = _OPERATORS.get(type(cmp_op))
            if op_fn is None:
                raise NotImplementedError("Operator " + cmp_op.__class__.__name__ + " not supported.")
            right = self._eval_statement(comparator, ctx=ctx)
            if not op_fn(left, right):
                return 0
            left = right
        return 1

    def _eval_assign(self, stmt: Union[ast.Assign, ast.AugAssign], ctx: dict) -> Any:
        """Assigns a variable (or `name[key]` item) into `ctx` and returns the assigned value.

        Raises:
          ValueError: For multiple targets (`a = b = 1`) or unsupported target types.
        """
        if isinstance(stmt, ast.AugAssign):
            left = self._eval_statement(stmt.target, ctx=ctx)
            right = self._eval_statement(stmt.value, ctx=ctx)
            value = _OPERATORS[type(stmt.op)](left, right)
            target = stmt.target
        else:
            value = self._eval_statement(stmt.value, ctx=ctx)
            if len(stmt.targets) != 1:
                raise ValueError('Expected length of assign targets to be 1')
            target = stmt.targets[0]
        if isinstance(target, ast.Tuple):  # like `a, z = (1,2)` (ast.Assign only)
            for i, elt in enumerate(target.elts):
                ctx[elt.id] = value[i]
        elif isinstance(target, ast.Name):  # like `a = 1`
            ctx[target.id] = value
        elif isinstance(target, ast.Subscript) and isinstance(target.value, ast.Name):  # `a[0] = 1`
            ctx[target.value.id][self._eval_statement(target.slice, ctx=ctx)] = value
        else:
            raise ValueError('Unhandled target type for Assign.')
        return value