[Weight-adapter/Trainer] Bypass forward mode in Weight adapter system (#11958)
* Add API of bypass forward module * bypass implementation * add bypass fwd into nodes list/trainer
This commit is contained in:
@@ -2,6 +2,7 @@ import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import comfy.model_management
|
||||
from .base import (
|
||||
WeightAdapterBase,
|
||||
@@ -14,7 +15,17 @@ from .base import (
|
||||
class LokrDiff(WeightAdapterTrainBase):
|
||||
def __init__(self, weights):
|
||||
super().__init__()
|
||||
(lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) = weights
|
||||
(
|
||||
lokr_w1,
|
||||
lokr_w2,
|
||||
alpha,
|
||||
lokr_w1_a,
|
||||
lokr_w1_b,
|
||||
lokr_w2_a,
|
||||
lokr_w2_b,
|
||||
lokr_t2,
|
||||
dora_scale,
|
||||
) = weights
|
||||
self.use_tucker = False
|
||||
if lokr_w1_a is not None:
|
||||
_, rank_a = lokr_w1_a.shape[0], lokr_w1_a.shape[1]
|
||||
@@ -57,10 +68,10 @@ class LokrDiff(WeightAdapterTrainBase):
|
||||
if self.w2_rebuild:
|
||||
if self.use_tucker:
|
||||
w2 = torch.einsum(
|
||||
'i j k l, j r, i p -> p r k l',
|
||||
"i j k l, j r, i p -> p r k l",
|
||||
self.lokr_t2,
|
||||
self.lokr_w2_b,
|
||||
self.lokr_w2_a
|
||||
self.lokr_w2_a,
|
||||
)
|
||||
else:
|
||||
w2 = self.lokr_w2_a @ self.lokr_w2_b
|
||||
@@ -69,9 +80,89 @@ class LokrDiff(WeightAdapterTrainBase):
|
||||
return self.lokr_w2
|
||||
|
||||
def __call__(self, w):
|
||||
diff = torch.kron(self.w1, self.w2)
|
||||
w1 = self.w1
|
||||
w2 = self.w2
|
||||
# Unsqueeze w1 to match w2 dims for proper kron product (like LyCORIS make_kron)
|
||||
for _ in range(w2.dim() - w1.dim()):
|
||||
w1 = w1.unsqueeze(-1)
|
||||
diff = torch.kron(w1, w2)
|
||||
return w + diff.reshape(w.shape).to(w)
|
||||
|
||||
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Additive bypass component for LoKr training: efficient Kronecker product.
|
||||
|
||||
Uses w1/w2 properties which handle both direct and decomposed cases.
|
||||
For create_train (direct w1/w2), no alpha scaling in properties.
|
||||
For to_train (decomposed), alpha/rank scaling is in properties.
|
||||
|
||||
Args:
|
||||
x: Input tensor
|
||||
base_out: Output from base forward (unused, for API consistency)
|
||||
"""
|
||||
# Get w1, w2 from properties (handles rebuild vs direct)
|
||||
w1 = self.w1
|
||||
w2 = self.w2
|
||||
|
||||
# Multiplier from bypass injection
|
||||
multiplier = getattr(self, "multiplier", 1.0)
|
||||
|
||||
# Get module info from bypass injection
|
||||
is_conv = getattr(self, "is_conv", False)
|
||||
conv_dim = getattr(self, "conv_dim", 0)
|
||||
kw_dict = getattr(self, "kw_dict", {})
|
||||
|
||||
# Efficient Kronecker application without materializing full weight
|
||||
# kron(w1, w2) @ x can be computed as nested operations
|
||||
# w1: [out_l, in_m], w2: [out_k, in_n, *k_size]
|
||||
# Full weight would be [out_l*out_k, in_m*in_n, *k_size]
|
||||
|
||||
uq = w1.size(1) # in_m - inner grouping dimension
|
||||
|
||||
if is_conv:
|
||||
conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1]
|
||||
|
||||
B, C_in, *spatial = x.shape
|
||||
# Reshape input for grouped application: [B * uq, C_in // uq, *spatial]
|
||||
h_in_group = x.reshape(B * uq, -1, *spatial)
|
||||
|
||||
# Ensure w2 has conv dims
|
||||
if w2.dim() == 2:
|
||||
w2 = w2.view(*w2.shape, *([1] * conv_dim))
|
||||
|
||||
# Apply w2 path with stride/padding
|
||||
hb = conv_fn(h_in_group, w2, **kw_dict)
|
||||
|
||||
# Reshape for cross-group operation
|
||||
hb = hb.view(B, -1, *hb.shape[1:])
|
||||
h_cross = hb.transpose(1, -1)
|
||||
|
||||
# Apply w1 (always 2D, applied as linear on channel dim)
|
||||
hc = F.linear(h_cross, w1)
|
||||
hc = hc.transpose(1, -1)
|
||||
|
||||
# Reshape to output
|
||||
out = hc.reshape(B, -1, *hc.shape[3:])
|
||||
else:
|
||||
# Linear case
|
||||
# Reshape input: [..., in_m * in_n] -> [..., uq (in_m), in_n]
|
||||
h_in_group = x.reshape(*x.shape[:-1], uq, -1)
|
||||
|
||||
# Apply w2: [..., uq, in_n] @ [out_k, in_n].T -> [..., uq, out_k]
|
||||
hb = F.linear(h_in_group, w2)
|
||||
|
||||
# Transpose for w1: [..., uq, out_k] -> [..., out_k, uq]
|
||||
h_cross = hb.transpose(-1, -2)
|
||||
|
||||
# Apply w1: [..., out_k, uq] @ [out_l, uq].T -> [..., out_k, out_l]
|
||||
hc = F.linear(h_cross, w1)
|
||||
|
||||
# Transpose back and flatten: [..., out_k, out_l] -> [..., out_l * out_k]
|
||||
hc = hc.transpose(-1, -2)
|
||||
out = hc.reshape(*hc.shape[:-2], -1)
|
||||
|
||||
return out * multiplier
|
||||
|
||||
def passive_memory_usage(self):
|
||||
return sum(param.numel() * param.element_size() for param in self.parameters())
|
||||
|
||||
@@ -86,16 +177,22 @@ class LoKrAdapter(WeightAdapterBase):
|
||||
@classmethod
|
||||
def create_train(cls, weight, rank=1, alpha=1.0):
|
||||
out_dim = weight.shape[0]
|
||||
in_dim = weight.shape[1:].numel()
|
||||
out1, out2 = factorization(out_dim, rank)
|
||||
in1, in2 = factorization(in_dim, rank)
|
||||
mat1 = torch.empty(out1, in1, device=weight.device, dtype=torch.float32)
|
||||
mat2 = torch.empty(out2, in2, device=weight.device, dtype=torch.float32)
|
||||
in_dim = weight.shape[1] # Just in_channels, not flattened with kernel
|
||||
k_size = weight.shape[2:] if weight.dim() > 2 else ()
|
||||
|
||||
out_l, out_k = factorization(out_dim, rank)
|
||||
in_m, in_n = factorization(in_dim, rank)
|
||||
|
||||
# w1: [out_l, in_m]
|
||||
mat1 = torch.empty(out_l, in_m, device=weight.device, dtype=torch.float32)
|
||||
# w2: [out_k, in_n, *k_size] for conv, [out_k, in_n] for linear
|
||||
mat2 = torch.empty(
|
||||
out_k, in_n, *k_size, device=weight.device, dtype=torch.float32
|
||||
)
|
||||
|
||||
torch.nn.init.kaiming_uniform_(mat2, a=5**0.5)
|
||||
torch.nn.init.constant_(mat1, 0.0)
|
||||
return LokrDiff(
|
||||
(mat1, mat2, alpha, None, None, None, None, None, None)
|
||||
)
|
||||
return LokrDiff((mat1, mat2, alpha, None, None, None, None, None, None))
|
||||
|
||||
def to_train(self):
|
||||
return LokrDiff(self.weights)
|
||||
@@ -154,8 +251,23 @@ class LoKrAdapter(WeightAdapterBase):
|
||||
lokr_t2 = lora[lokr_t2_name]
|
||||
loaded_keys.add(lokr_t2_name)
|
||||
|
||||
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
|
||||
weights = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale)
|
||||
if (
|
||||
(lokr_w1 is not None)
|
||||
or (lokr_w2 is not None)
|
||||
or (lokr_w1_a is not None)
|
||||
or (lokr_w2_a is not None)
|
||||
):
|
||||
weights = (
|
||||
lokr_w1,
|
||||
lokr_w2,
|
||||
alpha,
|
||||
lokr_w1_a,
|
||||
lokr_w1_b,
|
||||
lokr_w2_a,
|
||||
lokr_w2_b,
|
||||
lokr_t2,
|
||||
dora_scale,
|
||||
)
|
||||
return cls(loaded_keys, weights)
|
||||
else:
|
||||
return None
|
||||
@@ -184,23 +296,47 @@ class LoKrAdapter(WeightAdapterBase):
|
||||
|
||||
if w1 is None:
|
||||
dim = w1_b.shape[0]
|
||||
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
|
||||
w1 = torch.mm(
|
||||
comfy.model_management.cast_to_device(
|
||||
w1_a, weight.device, intermediate_dtype
|
||||
),
|
||||
comfy.model_management.cast_to_device(
|
||||
w1_b, weight.device, intermediate_dtype
|
||||
),
|
||||
)
|
||||
else:
|
||||
w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
|
||||
w1 = comfy.model_management.cast_to_device(
|
||||
w1, weight.device, intermediate_dtype
|
||||
)
|
||||
|
||||
if w2 is None:
|
||||
dim = w2_b.shape[0]
|
||||
if t2 is None:
|
||||
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
|
||||
w2 = torch.mm(
|
||||
comfy.model_management.cast_to_device(
|
||||
w2_a, weight.device, intermediate_dtype
|
||||
),
|
||||
comfy.model_management.cast_to_device(
|
||||
w2_b, weight.device, intermediate_dtype
|
||||
),
|
||||
)
|
||||
else:
|
||||
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
|
||||
comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
|
||||
w2 = torch.einsum(
|
||||
"i j k l, j r, i p -> p r k l",
|
||||
comfy.model_management.cast_to_device(
|
||||
t2, weight.device, intermediate_dtype
|
||||
),
|
||||
comfy.model_management.cast_to_device(
|
||||
w2_b, weight.device, intermediate_dtype
|
||||
),
|
||||
comfy.model_management.cast_to_device(
|
||||
w2_a, weight.device, intermediate_dtype
|
||||
),
|
||||
)
|
||||
else:
|
||||
w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
|
||||
w2 = comfy.model_management.cast_to_device(
|
||||
w2, weight.device, intermediate_dtype
|
||||
)
|
||||
|
||||
if len(w2.shape) == 4:
|
||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||
@@ -212,9 +348,134 @@ class LoKrAdapter(WeightAdapterBase):
|
||||
try:
|
||||
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||
weight = weight_decompose(
|
||||
dora_scale,
|
||||
weight,
|
||||
lora_diff,
|
||||
alpha,
|
||||
strength,
|
||||
intermediate_dtype,
|
||||
function,
|
||||
)
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||
return weight
|
||||
|
||||
def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Additive bypass component for LoKr: efficient Kronecker product application.
|
||||
|
||||
Note:
|
||||
Does not access original model weights - bypass mode is designed
|
||||
for quantized models where weights may not be accessible.
|
||||
|
||||
Args:
|
||||
x: Input tensor
|
||||
base_out: Output from base forward (unused, for API consistency)
|
||||
|
||||
Reference: LyCORIS functional/lokr.py bypass_forward_diff
|
||||
"""
|
||||
# FUNC_LIST: [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]
|
||||
FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d]
|
||||
|
||||
v = self.weights
|
||||
# v[0]=w1, v[1]=w2, v[2]=alpha, v[3]=w1_a, v[4]=w1_b, v[5]=w2_a, v[6]=w2_b, v[7]=t2, v[8]=dora
|
||||
w1 = v[0]
|
||||
w2 = v[1]
|
||||
alpha = v[2]
|
||||
w1_a = v[3]
|
||||
w1_b = v[4]
|
||||
w2_a = v[5]
|
||||
w2_b = v[6]
|
||||
t2 = v[7]
|
||||
|
||||
use_w1 = w1 is not None
|
||||
use_w2 = w2 is not None
|
||||
tucker = t2 is not None
|
||||
|
||||
# Use module info from bypass injection, not weight dimension
|
||||
is_conv = getattr(self, "is_conv", False)
|
||||
conv_dim = getattr(self, "conv_dim", 0)
|
||||
kw_dict = getattr(self, "kw_dict", {}) if is_conv else {}
|
||||
|
||||
if is_conv:
|
||||
op = FUNC_LIST[conv_dim + 2]
|
||||
else:
|
||||
op = F.linear
|
||||
|
||||
# Determine rank and scale
|
||||
rank = w1_b.size(0) if not use_w1 else w2_b.size(0) if not use_w2 else alpha
|
||||
scale = (alpha / rank if alpha is not None else 1.0) * getattr(
|
||||
self, "multiplier", 1.0
|
||||
)
|
||||
|
||||
# Build c (w1)
|
||||
if use_w1:
|
||||
c = w1.to(dtype=x.dtype)
|
||||
else:
|
||||
c = w1_a.to(dtype=x.dtype) @ w1_b.to(dtype=x.dtype)
|
||||
uq = c.size(1)
|
||||
|
||||
# Build w2 components
|
||||
if use_w2:
|
||||
ba = w2.to(dtype=x.dtype)
|
||||
else:
|
||||
a = w2_b.to(dtype=x.dtype)
|
||||
b = w2_a.to(dtype=x.dtype)
|
||||
if is_conv:
|
||||
if tucker:
|
||||
# Tucker: a, b get 1s appended (kernel is in t2)
|
||||
if a.dim() == 2:
|
||||
a = a.view(*a.shape, *([1] * conv_dim))
|
||||
if b.dim() == 2:
|
||||
b = b.view(*b.shape, *([1] * conv_dim))
|
||||
else:
|
||||
# Non-tucker conv: b may need 1s appended
|
||||
if b.dim() == 2:
|
||||
b = b.view(*b.shape, *([1] * conv_dim))
|
||||
|
||||
# Reshape input by uq groups
|
||||
if is_conv:
|
||||
B, _, *rest = x.shape
|
||||
h_in_group = x.reshape(B * uq, -1, *rest)
|
||||
else:
|
||||
h_in_group = x.reshape(*x.shape[:-1], uq, -1)
|
||||
|
||||
# Apply w2 path
|
||||
if use_w2:
|
||||
hb = op(h_in_group, ba, **kw_dict)
|
||||
else:
|
||||
if is_conv:
|
||||
if tucker:
|
||||
t = t2.to(dtype=x.dtype)
|
||||
if t.dim() == 2:
|
||||
t = t.view(*t.shape, *([1] * conv_dim))
|
||||
ha = op(h_in_group, a)
|
||||
ht = op(ha, t, **kw_dict)
|
||||
hb = op(ht, b)
|
||||
else:
|
||||
ha = op(h_in_group, a, **kw_dict)
|
||||
hb = op(ha, b)
|
||||
else:
|
||||
ha = op(h_in_group, a)
|
||||
hb = op(ha, b)
|
||||
|
||||
# Reshape and apply c (w1)
|
||||
if is_conv:
|
||||
hb = hb.view(B, -1, *hb.shape[1:])
|
||||
h_cross_group = hb.transpose(1, -1)
|
||||
else:
|
||||
h_cross_group = hb.transpose(-1, -2)
|
||||
|
||||
hc = F.linear(h_cross_group, c)
|
||||
|
||||
if is_conv:
|
||||
hc = hc.transpose(1, -1)
|
||||
out = hc.reshape(B, -1, *hc.shape[3:])
|
||||
else:
|
||||
hc = hc.transpose(-1, -2)
|
||||
out = hc.reshape(*hc.shape[:-2], -1)
|
||||
|
||||
return out * scale
|
||||
|
||||
Reference in New Issue
Block a user