Add custom nodes, Civitai loras (LFS), and vast.ai setup script
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled

Includes 30 custom nodes committed directly, 7 Civitai-exclusive
loras stored via Git LFS, and a setup script that installs all
dependencies and downloads HuggingFace-hosted models on vast.ai.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-09 00:55:26 +00:00
parent 2b70ab9ad0
commit f09734b0ee
2274 changed files with 748556 additions and 3 deletions

View File

@@ -0,0 +1,162 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

View File

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@@ -0,0 +1,262 @@
import torch
from einops import rearrange
from torch import Tensor
from comfy.ldm.flux.layers import timestep_embedding
import comfy
from .patch_util import PatchKeys
def set_model_dit_patch_replace(model, patch_kwargs, key):
to = model.model_options["transformer_options"]
if "patches_replace" not in to:
to["patches_replace"] = {}
else:
to["patches_replace"] = to["patches_replace"]
if "dit" not in to["patches_replace"]:
to["patches_replace"]["dit"] = {}
else:
to["patches_replace"]["dit"] = to["patches_replace"]["dit"]
if key not in to["patches_replace"]["dit"]:
if "double_block" in key:
if key == ("double_block", 18):
to["patches_replace"]["dit"][key] = LastDitDoubleBlockReplace(pulid_patch, **patch_kwargs)
else:
to["patches_replace"]["dit"][key] = DitDoubleBlockReplace(pulid_patch, **patch_kwargs)
else:
to["patches_replace"]["dit"][key] = DitSingleBlockReplace(pulid_patch, **patch_kwargs)
# model.model_options["transformer_options"] = to
else:
to["patches_replace"]["dit"][key].add(pulid_patch, **patch_kwargs)
def pulid_patch(img, pulid_model=None, ca_idx=None, weight=1.0, embedding=None, mask=None, transformer_options={}):
pulid_img = weight * pulid_model.model.pulid_ca[ca_idx].to(img.device)(embedding, img)
if mask is not None:
pulid_temp_attrs = transformer_options.get(PatchKeys.pulid_patch_key_attrs, {})
latent_image_shape = pulid_temp_attrs.get("latent_image_shape", None)
if latent_image_shape is not None:
bs, c, h, w = latent_image_shape
mask = comfy.sampler_helpers.prepare_mask(mask, (bs, c, h, w), img.device)
patch_size = transformer_options[PatchKeys.running_net_model].patch_size
mask = comfy.ldm.common_dit.pad_to_patch_size(mask, (patch_size, patch_size))
mask = rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
# (b, seq_len, _) =>(b, seq_len, pulid.dim)
mask = mask[..., 0].unsqueeze(-1).repeat(1, 1, pulid_img.shape[-1]).to(dtype=pulid_img.dtype)
del patch_size, latent_image_shape
pulid_img = pulid_img * mask
del mask, pulid_temp_attrs
return pulid_img
class DitDoubleBlockReplace:
def __init__(self, callback, **kwargs):
self.callback = [callback]
self.kwargs = [kwargs]
def add(self, callback, **kwargs):
self.callback.append(callback)
self.kwargs.append(kwargs)
for key, value in kwargs.items():
setattr(self, key, value)
def __call__(self, input_args, extra_options):
transformer_options = extra_options["transformer_options"]
pulid_temp_attrs = transformer_options.get(PatchKeys.pulid_patch_key_attrs, {})
sigma = pulid_temp_attrs["timesteps"][0].detach().cpu().item()
out = extra_options["original_block"](input_args)
img = out['img']
temp_img = img
for i, callback in enumerate(self.callback):
if self.kwargs[i]["sigma_start"] >= sigma >= self.kwargs[i]["sigma_end"]:
img = img + callback(temp_img,
pulid_model=self.kwargs[i]['pulid_model'],
ca_idx=self.kwargs[i]['ca_idx'],
weight=self.kwargs[i]['weight'],
embedding=self.kwargs[i]['embedding'],
mask = self.kwargs[i]['mask'],
transformer_options=transformer_options
)
out['img'] = img
del temp_img, pulid_temp_attrs, sigma, transformer_options, img
return out
class LastDitDoubleBlockReplace(DitDoubleBlockReplace):
def __call__(self, input_args, extra_options):
out = super().__call__(input_args, extra_options)
transformer_options = extra_options["transformer_options"]
pulid_temp_attrs = transformer_options.get(PatchKeys.pulid_patch_key_attrs, {})
pulid_temp_attrs["double_blocks_txt"] = out['txt']
return out
class DitSingleBlockReplace:
def __init__(self, callback, **kwargs):
self.callback = [callback]
self.kwargs = [kwargs]
def add(self, callback, **kwargs):
self.callback.append(callback)
self.kwargs.append(kwargs)
for key, value in kwargs.items():
setattr(self, key, value)
def __call__(self, input_args, extra_options):
transformer_options = extra_options["transformer_options"]
pulid_temp_attrs = transformer_options.get(PatchKeys.pulid_patch_key_attrs, {})
out = extra_options["original_block"](input_args)
sigma = pulid_temp_attrs["timesteps"][0].detach().cpu().item()
img = out['img']
txt = pulid_temp_attrs['double_blocks_txt']
real_img, txt = img[:, txt.shape[1]:, ...], img[:, :txt.shape[1], ...]
temp_img = real_img
for i, callback in enumerate(self.callback):
if self.kwargs[i]["sigma_start"] >= sigma >= self.kwargs[i]["sigma_end"]:
real_img = real_img + callback(temp_img,
pulid_model=self.kwargs[i]['pulid_model'],
ca_idx=self.kwargs[i]['ca_idx'],
weight=self.kwargs[i]['weight'],
embedding=self.kwargs[i]['embedding'],
mask=self.kwargs[i]['mask'],
transformer_options = transformer_options,
)
img = torch.cat((txt, real_img), 1)
out['img'] = img
del temp_img, pulid_temp_attrs, sigma, transformer_options, real_img, img
return out
def pulid_forward_orig(
self,
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
control = None,
transformer_options={},
attn_mask: Tensor = None,
) -> Tensor:
patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
transformer_options[PatchKeys.running_net_model] = self
# running on sequences img
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
if self.params.guidance_embed:
if guidance is not None:
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.double_blocks):
# 0 -> 18
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(img=args["img"],
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
return out
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask
},
{
"original_block": block_wrap,
"transformer_options": transformer_options
})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(img=img,
txt=txt,
vec=vec,
pe=pe,
attn_mask=attn_mask)
if control is not None: # Controlnet
control_i = control.get("input")
if i < len(control_i):
add = control_i[i]
if add is not None:
img += add
img = torch.cat((txt, img), 1)
for i, block in enumerate(self.single_blocks):
# 0 -> 37
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
return out
out = blocks_replace[("single_block", i)]({"img": img,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask
},
{
"original_block": block_wrap,
"transformer_options": transformer_options
})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
if control is not None: # Controlnet
control_o = control.get("output")
if i < len(control_o):
add = control_o[i]
if add is not None:
img[:, txt.shape[1]:, ...] += add
img = img[:, txt.shape[1]:, ...]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
del transformer_options[PatchKeys.running_net_model]
return img
def pulid_enter(img, img_ids, txt, txt_ids, timesteps, y, guidance, control, attn_mask, transformer_options):
pulid_temp_attrs = transformer_options.get(PatchKeys.pulid_patch_key_attrs, {})
pulid_temp_attrs['timesteps'] = timesteps
return img, img_ids, txt, txt_ids, timesteps, y, guidance, control, attn_mask
def pulid_patch_double_blocks_after(img, txt, transformer_options):
pulid_temp_attrs = transformer_options.get(PatchKeys.pulid_patch_key_attrs, {})
pulid_temp_attrs['double_blocks_txt'] = txt
return img, txt

View File

@@ -0,0 +1,155 @@
[中文文档](README_CN.md)
- Solved [ComfyUI-PuLID-Flux](https://github.com/balazik/ComfyUI-PuLID-Flux) model pollution problem.
- **🆕 Commercial-friendly FaceNet implementation** - Alternative to InsightFace for commercial usage without licensing restrictions.
- Supported use with `TeaCache` (Need use with [ComfyUI_Patches_ll](https://github.com/lldacing/ComfyUI_Patches_ll)).
- Supported use with [Comfy-WaveSpeed](https://github.com/chengzeyi/Comfy-WaveSpeed), supported by `Comfy-WaveSpeed` in [commit-36ba3c8](https://github.com/chengzeyi/Comfy-WaveSpeed/commit/36ba3c8b74735d4521828507a4bf323df1a9a9d0).
- Supported simple use with `First Block Cache` (Can use with [ComfyUI_Patches_ll](https://github.com/lldacing/ComfyUI_Patches_ll)).
Must uninstall or disable `ComfyUI-PuLID-Flux` and other PuLID-Flux nodes before install this plugin. Due to certain reasons, I used the same node's name `ApplyPulidFlux`.
Need upgrade ComfyUI Version>=0.3.7
## 🏢 Commercial Usage Ready
This plugin now supports **FaceNet-based face analysis** as an alternative to InsightFace, making it suitable for commercial applications:
- **No ArcFace licensing restrictions** - FaceNet is freely available for commercial use
- **Compatible API** - Drop-in replacement for InsightFace workflows
- **Production ready** - Reliable face detection and embedding generation
Simply use `PulidFluxFaceNetLoader` instead of `PulidFluxInsightFaceLoader` in your workflows.
## Update logs
### 2025.02.19
- Fix: when selecting a face from multiple faces as a reference, embeddings and alignment features maybe not from the same face.
### 2025.02.18
- Supported selecting a face from multiple faces as a reference. [Example workflow](examples/PuLID_select_ref_face.png).
### 2025.01.27
- Changed the model path of facexlib to `ComfyUI/models/facexlib/`.
- When automatically downloading, modify the path of Antelope v2 model to `ComfyUI/models/insightface/models/antelopev2/`.
- Changed the model path of EVA_CLIP_L_14_336 to `ComfyUI/models/clip/`.
## Preview (Image with WorkFlow)
![save api extended](examples/PuLID_with_speedup.png)
![save api extended](examples/PuLID_with_attn_mask.png)
![save api extended](examples/PuLID_select_ref_face.png)
## Install
- Manual
```shell
cd custom_nodes
git clone https://github.com/lldacing/ComfyUI_PuLID_Flux_ll.git
cd ComfyUI_PuLID_Flux_ll
pip install -r requirements.txt
# restart ComfyUI
```
Tips:
- If you use `ComfyUI_windows_portable` and encounter the following error, please see https://github.com/deepinsight/insightface/issues/2576
```
insightface/thirdparty/face3d/mesh/cython/mesh_core_cython.cpp(36): fatal error C1083: 无法打开包括文件: "Python.h": No such file or directory
error: command 'd:\\installed\\Microsoft Visual Studio\\2022\\BuildTools\\VC\\Tools\\MSVC\\14.42.34433\\bin\\HostX86\\x64\\cl.exe' failed with exit code 2
[end of output]
note: This error originates from a subprocess, and is likely not a problem with pip.
ERROR: Failed building wheel for insightface
Failed to build insightface
```
## Models
### Available Flux models
- 32bit/16bit (~22GB VRAM): [model](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors), [encoder](https://huggingface.co/comfyanonymous/flux_text_encoders/blob/main/t5xxl_fp16.safetensors)
- 8bit gguf (~12GB VRAM): [model](https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q8_0.gguf), [encoder](https://huggingface.co/city96/t5-v1_1-xxl-encoder-gguf/blob/main/t5-v1_1-xxl-encoder-Q8_0.gguf)
- 8 bit FP8 e5m2 (~12GB VRAM): [model](https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-dev-fp8-e5m2.safetensors), [encoder](https://huggingface.co/comfyanonymous/flux_text_encoders/blob/main/t5xxl_fp8_e4m3fn.safetensors)
- 8 bit FP8 e4m3fn (~12GB VRAM): [model](https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-dev-fp8-e4m3fn.safetensors), [encoder](https://huggingface.co/comfyanonymous/flux_text_encoders/blob/main/t5xxl_fp8_e4m3fn.safetensors)
- Clip and VAE (for all models): [clip](https://huggingface.co/comfyanonymous/flux_text_encoders/blob/main/clip_l.safetensors), [vae](https://huggingface.co/black-forest-labs/FLUX.1-schnell/blob/main/ae.safetensors)
#### For GGUF models you will need to install [ComfyUI-GGUF](https://github.com/city96/ComfyUI-GGUF)
### PuLID models
- Download [PuLID-Flux](https://huggingface.co/guozinan/PuLID/resolve/main/pulid_flux_v0.9.1.safetensors?download=true) => `ComfyUI/models/pulid/`.
- (Support auto-download) Download [EVA02-CLIP-L-14-336](https://huggingface.co/QuanSun/EVA-CLIP/blob/main/EVA02_CLIP_L_336_psz14_s6B.pt?download=true) => `ComfyUI/models/clip/`.
### Face Analysis Models
**For Commercial Use (Recommended):**
- **FaceNet** - No additional downloads required! Uses `facenet-pytorch` with pre-trained VGGFace2 models
- ✅ Commercial license friendly
- ✅ No external model downloads
- ✅ Automatic model loading
- Use with `PulidFluxFaceNetLoader` node
**For Non-Commercial/Research Use:**
- (Support auto-download) Download all models like `*.onnx` from [AntelopeV2](https://huggingface.co/MonsterMMORPG/tools/tree/main) => `ComfyUI/models/insightface/models/antelopev2/`.
- (Support auto-download) Download [parsing_bisenet](https://github.com/xinntao/facexlib/releases/download/v0.2.0/parsing_bisenet.pth), [parsing_parsenet](https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth) and [Resnet50](https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth) => `ComfyUI/models/facexlib/`.
- Use with `PulidFluxInsightFaceLoader` node
## Nodes
- PulidFluxModelLoader
- See chapter [PuLID models](#pulid-models)
- **PulidFluxFaceNetLoader** 🆕
- **Commercial-friendly face analysis** using FaceNet
- No additional model downloads required
- Compatible with all PuLID workflows
- Supports CPU and CUDA execution
- PulidFluxInsightFaceLoader
- Traditional InsightFace-based analysis
- See chapter [PuLID models](#pulid-models)
- PulidFluxEvaClipLoader
- See chapter [PuLID models](#pulid-models)
- ApplyPulidFlux
- Solved the model pollution problem of the original plugin ComfyUI-PuLID-Flux
- **Works with both FaceNet and InsightFace** - seamless compatibility
- `attn_mask` ~~may not work correctly (I have no idea how to apply it, I have tried multiple methods and the results have been not satisfactory)~~ works now.
- If you want use with [TeaCache](https://github.com/ali-vilab/TeaCache), must put it before node [`FluxForwardOverrider` and `ApplyTeaCachePatch`](https://github.com/lldacing/ComfyUI_Patches_ll).
- If you want use with [Comfy-WaveSpeed](https://github.com/chengzeyi/Comfy-WaveSpeed), must put it before node `ApplyFBCacheOnModel`.
- FixPulidFluxPatch (Deprecated)
- If you want use with [TeaCache](https://github.com/ali-vilab/TeaCache), must link it after node `ApplyPulidFlux`, and link node [`FluxForwardOverrider` and `ApplyTeaCachePatch`](https://github.com/lldacing/ComfyUI_Patches_ll) after it.
- PulidFluxOptions
- `input_faces_order` - Sorting rule for detected bboxes.
- `left-right`: Sort the left boundary of bbox by column from left to right.
- `right-left`: Reverse order of `left-right` (Sort the left boundary of bbox by column from right to left).
- `top-bottom`: Sort the top boundary of bbox by row from top to bottom.
- `bottom-top`: Reverse order of `top-bottom` (Sort the top boundary of bbox by row from bottom to top).
- `small-large`: Sort the area of bbox from small to large.
- `large-small`: Sort the area of bbox from large to small.
- `input_faces_index` - The target index of the sorted bboxes.
- `input_faces_align_mode` - Choose the detection method for aligning facial features.
- `0`: Old version method, When there is a face in an image, the selected facial embedding amount and alignment features maybe not consistent.
- `1`: Keep the selected facial embedding amount and alignment features consistent.
- There is a slight difference between the two mode, with the `align_face` value of `1` resulting smaller area than the `embed_face` value of `0`.
- PulidFluxFaceDetector
- Can check the facial features applied in `ApplyPulidFlux`.
- **Works with both FaceNet and InsightFace** backends
- When `input_faces_align_mode = 0`, the `embed_face` and `align_face` should be the same face, but they are generated by different detectors, and the number detected may be not consistent, so they may be not the same face.
- When `input_faces_align_mode = 1`, the `embed_face` and `align_face` are always the same face, they are generated by same detectors.
- `face_bbox_image` - Draw the detected facial bounding box (the result of the `embed_face`'s detector).
## Usage Examples
### Commercial Workflow (FaceNet)
```
PulidFluxFaceNetLoader -> ApplyPulidFlux
```
### Research Workflow (InsightFace)
```
PulidFluxInsightFaceLoader -> ApplyPulidFlux
```
Both workflows produce compatible results and can be used interchangeably.
## Thanks
[ToTheBeginning/PuLID](https://github.com/ToTheBeginning/PuLID)
[ComfyUI-PuLID-Flux](https://github.com/balazik/ComfyUI-PuLID-Flux)
[TeaCache](https://github.com/ali-vilab/TeaCache)
[Comfy-WaveSpeed](https://github.com/chengzeyi/Comfy-WaveSpeed)
[facenet-pytorch](https://github.com/timesler/facenet-pytorch) - For commercial-friendly face recognition

View File

@@ -0,0 +1,151 @@
[English](README.md)
- 解决插件 [ComfyUI-PuLID-Flux](https://github.com/balazik/ComfyUI-PuLID-Flux) 存在的模型污染问题。
- **🆕 支持商业友好的FaceNet实现** - 作为InsightFace的替代方案用于商业用途无许可证限制。
- 支持使用[TeaCache](https://github.com/ali-vilab/TeaCache)加速(`TeaCache`加速需要配合[ComfyUI_Patches_ll](https://github.com/lldacing/ComfyUI_Patches_ll)使用)。
- 支持使用[Comfy-WaveSpeed](https://github.com/chengzeyi/Comfy-WaveSpeed)加速, Comfy-WaveSpeed在[提交记录-36ba3c8](https://github.com/chengzeyi/Comfy-WaveSpeed/commit/36ba3c8b74735d4521828507a4bf323df1a9a9d0)中提供支持。
- 支持使用简单的`First Block Cache`加速(可以配合[ComfyUI_Patches_ll](https://github.com/lldacing/ComfyUI_Patches_ll)使用)。
在安装此插件之前,必须卸载或禁用`ComfyUI-PuLID-Flux`和其他PuLID Flux节点, 因为由于某些原因,我使用了同样的节点名`ApplyPulidFlux`
ComfyUI主体版本需要>=0.3.7
## 🏢 商业用途支持
该插件现已支持基于**FaceNet的人脸分析**作为InsightFace的替代方案使其适用于商业应用
- **无ArcFace许可证限制** - FaceNet可自由用于商业用途
- **兼容API** - 可直接替换InsightFace工作流
- **生产就绪** - 可靠的人脸检测和嵌入生成
只需在工作流中使用`PulidFluxFaceNetLoader`替代`PulidFluxInsightFaceLoader`即可。
## 更新日志
### 2025.02.19
- 解决多张人脸时选择的人脸嵌入量和对齐特征不是同一个人脸的问题。
### 2025.02.18
- 支持从含有多张脸的图片中选择一张脸作为参考。[示例工作流](examples/PuLID_select_ref_face.png).
### 2025.01.27
- 修改 facexlib 的模型路径为 `ComfyUI/models/facexlib/`.
- 自动下载时 修改 Antelopev2 模型的路径为 `ComfyUI/models/insightface/models/antelopev2/`.
- 修改 EVA_CLIP_L_14_336 的模型路径为 `ComfyUI/models/clip/`.
## 预览 (图片含工作流)
![save api extended](examples/PuLID_with_speedup.png)
![save api extended](examples/PuLID_with_attn_mask.png)
![save api extended](examples/PuLID_select_ref_face.png)
## 安装
- 手动
```shell
cd custom_nodes
git clone https://github.com/lldacing/ComfyUI_PuLID_Flux_ll.git
cd ComfyUI_PuLID_Flux_ll
pip install -r requirements.txt
# 重启 ComfyUI
```
安装问题:
- 如果使用`ComfyUI_windows_portable`并遇到以下错误, 请查看 https://github.com/deepinsight/insightface/issues/2576
```
insightface/thirdparty/face3d/mesh/cython/mesh_core_cython.cpp(36): fatal error C1083: 无法打开包括文件: "Python.h": No such file or directory
error: command 'd:\\installed\\Microsoft Visual Studio\\2022\\BuildTools\\VC\\Tools\\MSVC\\14.42.34433\\bin\\HostX86\\x64\\cl.exe' failed with exit code 2
[end of output]
note: This error originates from a subprocess, and is likely not a problem with pip.
ERROR: Failed building wheel for insightface
Failed to build insightface
```
## 模型
### 可用的 Flux 模型
- 32bit/16bit (~22GB VRAM): [model](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors), [encoder](https://huggingface.co/comfyanonymous/flux_text_encoders/blob/main/t5xxl_fp16.safetensors)
- 8bit gguf (~12GB VRAM): [model](https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q8_0.gguf), [encoder](https://huggingface.co/city96/t5-v1_1-xxl-encoder-gguf/blob/main/t5-v1_1-xxl-encoder-Q8_0.gguf)
- 8 bit FP8 e5m2 (~12GB VRAM): [model](https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-dev-fp8-e5m2.safetensors), [encoder](https://huggingface.co/comfyanonymous/flux_text_encoders/blob/main/t5xxl_fp8_e4m3fn.safetensors)
- 8 bit FP8 e4m3fn (~12GB VRAM): [model](https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-dev-fp8-e4m3fn.safetensors), [encoder](https://huggingface.co/comfyanonymous/flux_text_encoders/blob/main/t5xxl_fp8_e4m3fn.safetensors)
- Clip and VAE (for all models): [clip](https://huggingface.co/comfyanonymous/flux_text_encoders/blob/main/clip_l.safetensors), [vae](https://huggingface.co/black-forest-labs/FLUX.1-schnell/blob/main/ae.safetensors)
#### 若使用 GGUF 需要安装 [ComfyUI-GGUF](https://github.com/city96/ComfyUI-GGUF)
### PuLID 模型
- 下载 [PuLID-Flux](https://huggingface.co/guozinan/PuLID/resolve/main/pulid_flux_v0.9.1.safetensors?download=true) 到目录 `ComfyUI/models/pulid/`
- (支持自动下载)下载 [EVA02-CLIP-L-14-336](https://huggingface.co/QuanSun/EVA-CLIP/blob/main/EVA02_CLIP_L_336_psz14_s6B.pt?download=true) 到目录 `ComfyUI/models/clip/`
### 人脸分析模型
**用于商业用途(推荐):**
- **FaceNet** - 无需额外下载使用带有预训练VGGFace2模型的`facenet-pytorch`
- ✅ 商业许可证友好
- ✅ 无需外部模型下载
- ✅ 自动模型加载
- 配合`PulidFluxFaceNetLoader`节点使用
**用于非商业/研究用途:**
- (支持自动下载)从 [AntelopeV2](https://huggingface.co/MonsterMMORPG/tools/tree/main) 下载所有`*.onnx`模型文件到目录 `ComfyUI/models/insightface/models/antelopev2/`.
- (支持自动下载)下载 [parsing_bisenet](https://github.com/xinntao/facexlib/releases/download/v0.2.0/parsing_bisenet.pth), [parsing_parsenet](https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth) and [Resnet50](https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth) 到目录 `ComfyUI/models/facexlib/`.
- 配合`PulidFluxInsightFaceLoader`节点使用
## 节点
- PulidFluxModelLoader
- **PulidFluxFaceNetLoader** 🆕
- **基于FaceNet的商业友好人脸分析**
- 无需额外模型下载
- 兼容所有PuLID工作流
- 支持CPU和CUDA执行
- PulidFluxInsightFaceLoader
- 传统的基于InsightFace的分析
- PulidFluxEvaClipLoader
- ApplyPulidFlux
- 解决了原插件中模型污染的问题
- **同时支持FaceNet和InsightFace** - 无缝兼容
- `attn_mask`~~可能不能正确工作, 因为我不知道如何实现它, 尝试了多种方式效果都未能达到预期~~,可以正常工作了。
- 使用 [TeaCache](https://github.com/ali-vilab/TeaCache)加速, 必须加在[`FluxForwardOverrider` and `ApplyTeaCachePatch`](https://github.com/lldacing/ComfyUI_Patches_ll)之前.
- 使用 [Comfy-WaveSpeed](https://github.com/chengzeyi/Comfy-WaveSpeed)加速, 必须加在[`ApplyFBCacheOnModel`](https://github.com/lldacing/ComfyUI_Patches_ll)之前.
- FixPulidFluxPatch (已弃用)
- 如果想使用 [TeaCache](https://github.com/ali-vilab/TeaCache)加速, 必须加在 `ApplyPulidFlux` 节点之后, 并在后面连接节点 [`FluxForwardOverrider` and `ApplyTeaCachePatch`](https://github.com/lldacing/ComfyUI_Patches_ll).
- PulidFluxOptions
- `input_faces_order` - 对检测到的脸部边界框的排序规则。
- `left-right`: 按列从左到右对bbox的左边界进行排序。
- `right-left`: `left-right`的倒序按列从右到左对bbox的左边界进行排序
- `top-bottom`: 按行从上到下对bbox的顶部边界进行排序。
- `bottom-top`: `top-bottom`的倒序按行从下到上对bbox的顶部边界进行排序
- `small-large`: 按bbox的面积从小到大排序。
- `large-small`: 按bbox的面积从大到小排序。
- `input_faces_index` - 从排序后的bbox选取的索引号。
- `input_faces_align_mode` - 选择对齐脸部特征的检测方式。
- `0`: 旧版本方式,一张图片中有张脸时选择的脸部嵌入量和对齐特征可能不一致。
- `1`: 保持选择的脸部嵌入量和对齐特征一致。
- 两种出图有细微差别,值`1``align_face`结果图比`0``embed_face`范围小一点。
- PulidFluxFaceDetector
- 用来检查在`ApplyPulidFlux`实际使用的面部特征。
- **同时支持FaceNet和InsightFace**后端
- `input_faces_align_mode = 0`时,`embed_face``align_face` 理论上应该是同一张脸,但它们由不同的检测器产生,可能检测到的数量不一致,因此两张图可能不是同一张脸。
- `input_faces_align_mode = 1`时,`embed_face``align_face` 由相同的检测器产生,两张图始终是同一张脸。
- `face_bbox_image` - 画出检测到的脸部边界框(`embed_face`的检测器结果)。
## 使用示例
### 商业工作流 (FaceNet)
```
PulidFluxFaceNetLoader -> ApplyPulidFlux
```
### 研究工作流 (InsightFace)
```
PulidFluxInsightFaceLoader -> ApplyPulidFlux
```
两种工作流产生兼容的结果,可以互换使用。
## 感谢
[ToTheBeginning/PuLID](https://github.com/ToTheBeginning/PuLID)
[ComfyUI-PuLID-Flux](https://github.com/balazik/ComfyUI-PuLID-Flux)
[TeaCache](https://github.com/ali-vilab/TeaCache)
[Comfy-WaveSpeed](https://github.com/chengzeyi/Comfy-WaveSpeed)
[facenet-pytorch](https://github.com/timesler/facenet-pytorch) - 用于商业友好的人脸识别

View File

@@ -0,0 +1,3 @@
from .pulidflux import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']

View File

@@ -0,0 +1,213 @@
import math
import torch
import torch.nn as nn
# FFN
def FeedForward(dim, mult=4):
inner_dim = int(dim * mult)
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim, bias=False),
nn.GELU(),
nn.Linear(inner_dim, dim, bias=False),
)
def reshape_tensor(x, heads):
bs, length, width = x.shape
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
x = x.view(bs, length, heads, -1)
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
x = x.transpose(1, 2)
# (bs, n_heads, length, dim_per_head) --> (bs, n_heads, length, dim_per_head)
x = x.reshape(bs, heads, length, -1)
return x
class PerceiverAttentionCA(nn.Module):
def __init__(self, *, dim=3072, dim_head=128, heads=16, kv_dim=2048):
super().__init__()
self.scale = dim_head ** -0.5
self.dim_head = dim_head
self.heads = heads
inner_dim = dim_head * heads
self.norm1 = nn.LayerNorm(dim if kv_dim is None else kv_dim)
self.norm2 = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x, latents, mask=None):
"""
Args:
x (torch.Tensor): image features
shape (b, n1, D)
latent (torch.Tensor): latent features
shape (b, n2, D)
"""
x = self.norm1(x)
latents = self.norm2(latents)
b, seq_len, _ = latents.shape
q = self.to_q(latents)
k, v = self.to_kv(x).chunk(2, dim=-1)
q = reshape_tensor(q, self.heads)
k = reshape_tensor(k, self.heads)
v = reshape_tensor(v, self.heads)
# attention
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
# if mask is not None:
# not sure
# weight.shape (bs, n_heads, seq_len, seq_len)
# mask.shape (bs, seq_len, _) -> (bs, 1, 1, seq_len)
# mask = mask[:,:, :1].view(b, 1, 1, -1)
# weight = weight.masked_fill(mask == 0, float('-inf'))
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v
out = out.permute(0, 2, 1, 3).reshape(b, seq_len, -1)
return self.to_out(out)
class PerceiverAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8, kv_dim=None):
super().__init__()
self.scale = dim_head ** -0.5
self.dim_head = dim_head
self.heads = heads
inner_dim = dim_head * heads
self.norm1 = nn.LayerNorm(dim if kv_dim is None else kv_dim)
self.norm2 = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim if kv_dim is None else kv_dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x, latents):
"""
Args:
x (torch.Tensor): image features
shape (b, n1, D)
latent (torch.Tensor): latent features
shape (b, n2, D)
"""
x = self.norm1(x)
latents = self.norm2(latents)
b, seq_len, _ = latents.shape
q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
q = reshape_tensor(q, self.heads)
k = reshape_tensor(k, self.heads)
v = reshape_tensor(v, self.heads)
# attention
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v
out = out.permute(0, 2, 1, 3).reshape(b, seq_len, -1)
return self.to_out(out)
class IDFormer(nn.Module):
"""
- perceiver resampler like arch (compared with previous MLP-like arch)
- we concat id embedding (generated by arcface) and query tokens as latents
- latents will attend each other and interact with vit features through cross-attention
- vit features are multi-scaled and inserted into IDFormer in order, currently, each scale corresponds to two
IDFormer layers
"""
def __init__(
self,
dim=1024,
depth=10,
dim_head=64,
heads=16,
num_id_token=5,
num_queries=32,
output_dim=2048,
ff_mult=4,
):
super().__init__()
self.num_id_token = num_id_token
self.dim = dim
self.num_queries = num_queries
assert depth % 5 == 0
self.depth = depth // 5
scale = dim ** -0.5
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) * scale)
self.proj_out = nn.Parameter(scale * torch.randn(dim, output_dim))
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
FeedForward(dim=dim, mult=ff_mult),
]
)
)
for i in range(5):
setattr(
self,
f'mapping_{i}',
nn.Sequential(
nn.Linear(1024, 1024),
nn.LayerNorm(1024),
nn.LeakyReLU(),
nn.Linear(1024, 1024),
nn.LayerNorm(1024),
nn.LeakyReLU(),
nn.Linear(1024, dim),
),
)
self.id_embedding_mapping = nn.Sequential(
nn.Linear(1280, 1024),
nn.LayerNorm(1024),
nn.LeakyReLU(),
nn.Linear(1024, 1024),
nn.LayerNorm(1024),
nn.LeakyReLU(),
nn.Linear(1024, dim * num_id_token),
)
def forward(self, x, y):
latents = self.latents.repeat(x.size(0), 1, 1)
x = self.id_embedding_mapping(x)
x = x.reshape(-1, self.num_id_token, self.dim)
latents = torch.cat((latents, x), dim=1)
for i in range(5):
vit_feature = getattr(self, f'mapping_{i}')(y[i])
ctx_feature = torch.cat((x, vit_feature), dim=1)
for attn, ff in self.layers[i * self.depth: (i + 1) * self.depth]:
latents = attn(ctx_feature, latents) + latents
latents = ff(latents) + latents
latents = latents[:, :self.num_queries]
latents = latents @ self.proj_out
return latents

View File

@@ -0,0 +1,11 @@
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_transforms
from .factory import list_models, add_model_config, get_model_config, load_checkpoint
from .loss import ClipLoss
from .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg,\
convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
from .openai import load_openai_model, list_openai_models
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\
get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
from .tokenizer import SimpleTokenizer, tokenize
from .transform import image_transform

View File

@@ -0,0 +1,2 @@
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)

View File

@@ -0,0 +1,548 @@
# --------------------------------------------------------
# Adapted from https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------
import math
import os
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
except:
from timm.layers import drop_path, to_2tuple, trunc_normal_
from .transformer import PatchDropout
from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
if os.getenv('ENV_TYPE') == 'deepspeed':
try:
from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
except:
from torch.utils.checkpoint import checkpoint
else:
from torch.utils.checkpoint import checkpoint
try:
import xformers
import xformers.ops as xops
XFORMERS_IS_AVAILBLE = True
except:
XFORMERS_IS_AVAILBLE = False
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
def extra_repr(self) -> str:
return 'p={}'.format(self.drop_prob)
class Mlp(nn.Module):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
drop=0.,
subln=False,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
# x = self.drop(x)
# commit this for the orignal BERT implement
x = self.ffn_ln(x)
x = self.fc2(x)
x = self.drop(x)
return x
class SwiGLU(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
norm_layer=nn.LayerNorm, subln=False):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.w1 = nn.Linear(in_features, hidden_features)
self.w2 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
self.w3 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x1 = self.w1(x)
x2 = self.w2(x)
hidden = self.act(x1) * x2
x = self.ffn_ln(hidden)
x = self.w3(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
proj_drop=0., window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False, norm_layer=nn.LayerNorm):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
self.scale = qk_scale or head_dim ** -0.5
self.subln = subln
if self.subln:
self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
else:
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
else:
self.q_bias = None
self.v_bias = None
if window_size:
self.window_size = window_size
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# cls to token & token 2 cls & cls to cls
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = \
torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer("relative_position_index", relative_position_index)
else:
self.window_size = None
self.relative_position_bias_table = None
self.relative_position_index = None
self.attn_drop = nn.Dropout(attn_drop)
self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
# self.proj = nn.Linear(all_head_dim, all_head_dim)
self.proj = nn.Linear(all_head_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.xattn = xattn
self.xattn_drop = attn_drop
self.rope = rope
def forward(self, x, rel_pos_bias=None, attn_mask=None):
B, N, C = x.shape
if self.subln:
q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C
k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
else:
qkv_bias = None
if self.q_bias is not None:
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C
q, k, v = qkv[0], qkv[1], qkv[2]
if self.rope:
# slightly fast impl
q_t = q[:, :, 1:, :]
ro_q_t = self.rope(q_t)
q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
k_t = k[:, :, 1:, :]
ro_k_t = self.rope(k_t)
k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
if self.xattn:
q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
x = xops.memory_efficient_attention(
q, k, v,
p=self.xattn_drop,
scale=self.scale,
)
x = x.reshape(B, N, -1)
x = self.inner_attn_ln(x)
x = self.proj(x)
x = self.proj_drop(x)
else:
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
if self.relative_position_bias_table is not None:
relative_position_bias = \
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)
if rel_pos_bias is not None:
attn = attn + rel_pos_bias.type_as(attn)
if attn_mask is not None:
attn_mask = attn_mask.bool()
attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.inner_attn_ln(x)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
window_size=None, attn_head_dim=None, xattn=False, rope=None, postnorm=False,
subln=False, naiveswiglu=False):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim,
xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
if naiveswiglu:
self.mlp = SwiGLU(
in_features=dim,
hidden_features=mlp_hidden_dim,
subln=subln,
norm_layer=norm_layer,
)
else:
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
subln=subln,
drop=drop
)
if init_values is not None and init_values > 0:
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
else:
self.gamma_1, self.gamma_2 = None, None
self.postnorm = postnorm
def forward(self, x, rel_pos_bias=None, attn_mask=None):
if self.gamma_1 is None:
if self.postnorm:
x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
x = x + self.drop_path(self.norm2(self.mlp(x)))
else:
x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
x = x + self.drop_path(self.mlp(self.norm2(x)))
else:
if self.postnorm:
x = x + self.drop_path(self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
else:
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
return x
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x, **kwargs):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2)
return x
class RelativePositionBias(nn.Module):
def __init__(self, window_size, num_heads):
super().__init__()
self.window_size = window_size
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# cls to token & token 2 cls & cls to cls
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = \
torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer("relative_position_index", relative_position_index)
def forward(self):
relative_position_bias = \
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
class EVAVisionTransformer(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, patch_dropout=0.,
use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, rope=False,
use_mean_pooling=True, init_scale=0.001, grad_checkpointing=False, xattn=False, postnorm=False,
pt_hw_seq_len=16, intp_freq=False, naiveswiglu=False, subln=False):
super().__init__()
if not XFORMERS_IS_AVAILBLE:
xattn = False
self.image_size = img_size
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
# self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
if use_abs_pos_emb:
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
else:
self.pos_embed = None
self.pos_drop = nn.Dropout(p=drop_rate)
if use_shared_rel_pos_bias:
self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
else:
self.rel_pos_bias = None
if rope:
half_head_dim = embed_dim // num_heads // 2
hw_seq_len = img_size // patch_size
self.rope = VisionRotaryEmbeddingFast(
dim=half_head_dim,
pt_seq_len=pt_hw_seq_len,
ft_seq_len=hw_seq_len if intp_freq else None,
# patch_dropout=patch_dropout
)
else:
self.rope = None
self.naiveswiglu = naiveswiglu
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.use_rel_pos_bias = use_rel_pos_bias
self.blocks = nn.ModuleList([
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
xattn=xattn, rope=self.rope, postnorm=postnorm, subln=subln, naiveswiglu=naiveswiglu)
for i in range(depth)])
self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
if self.pos_embed is not None:
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
# trunc_normal_(self.mask_token, std=.02)
self.apply(self._init_weights)
self.fix_init_weight()
if isinstance(self.head, nn.Linear):
trunc_normal_(self.head.weight, std=.02)
self.head.weight.data.mul_(init_scale)
self.head.bias.data.mul_(init_scale)
# setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
self.grad_checkpointing = grad_checkpointing
def fix_init_weight(self):
def rescale(param, layer_id):
param.div_(math.sqrt(2.0 * layer_id))
for layer_id, layer in enumerate(self.blocks):
rescale(layer.attn.proj.weight.data, layer_id + 1)
if self.naiveswiglu:
rescale(layer.mlp.w3.weight.data, layer_id + 1)
else:
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
def get_cast_dtype(self) -> torch.dtype:
return self.blocks[0].mlp.fc2.weight.dtype
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def get_num_layers(self):
return len(self.blocks)
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
assert unlocked_groups == 0, 'partial locking not currently supported for this model'
for param in self.parameters():
param.requires_grad = False
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x, return_all_features=False, return_hidden=False, shuffle=False):
x = self.patch_embed(x)
batch_size, seq_len, _ = x.size()
if shuffle:
idx = torch.randperm(x.shape[1]) + 1
zero = torch.LongTensor([0, ])
idx = torch.cat([zero, idx])
pos_embed = self.pos_embed[:, idx]
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
if shuffle:
x = x + pos_embed
elif self.pos_embed is not None:
x = x + self.pos_embed
x = self.pos_drop(x)
# a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
if os.getenv('RoPE') == '1':
if self.training and not isinstance(self.patch_dropout, nn.Identity):
x, patch_indices_keep = self.patch_dropout(x)
self.rope.forward = partial(self.rope.forward, patch_indices_keep=patch_indices_keep)
else:
self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
x = self.patch_dropout(x)
else:
x = self.patch_dropout(x)
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
hidden_states = []
for idx, blk in enumerate(self.blocks):
if (0 < idx <= 20) and (idx % 4 == 0) and return_hidden:
hidden_states.append(x)
if self.grad_checkpointing:
x = checkpoint(blk, x, (rel_pos_bias,))
else:
x = blk(x, rel_pos_bias=rel_pos_bias)
if not return_all_features:
x = self.norm(x)
if self.fc_norm is not None:
return self.fc_norm(x.mean(1)), hidden_states
else:
return x[:, 0], hidden_states
return x
def forward(self, x, return_all_features=False, return_hidden=False, shuffle=False):
if return_all_features:
return self.forward_features(x, return_all_features, return_hidden, shuffle)
x, hidden_states = self.forward_features(x, return_all_features, return_hidden, shuffle)
x = self.head(x)
if return_hidden:
return x, hidden_states
return x

View File

@@ -0,0 +1,520 @@
import json
import logging
import os
import pathlib
import re
from copy import deepcopy
from pathlib import Path
from typing import Optional, Tuple, Union, Dict, Any
import torch
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .model import CLIP, CustomCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
get_cast_dtype
from .openai import load_openai_model
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model
from .transform import image_transform
from .tokenizer import HFTokenizer, tokenize
from .utils import resize_clip_pos_embed, resize_evaclip_pos_embed, resize_visual_pos_embed, resize_eva_pos_embed
_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
def _natural_key(string_):
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
def _rescan_model_configs():
global _MODEL_CONFIGS
config_ext = ('.json',)
config_files = []
for config_path in _MODEL_CONFIG_PATHS:
if config_path.is_file() and config_path.suffix in config_ext:
config_files.append(config_path)
elif config_path.is_dir():
for ext in config_ext:
config_files.extend(config_path.glob(f'*{ext}'))
for cf in config_files:
with open(cf, "r", encoding="utf8") as f:
model_cfg = json.load(f)
if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
_MODEL_CONFIGS[cf.stem] = model_cfg
_MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])))
_rescan_model_configs() # initial populate of model config registry
def list_models():
""" enumerate available model architectures based on config files """
return list(_MODEL_CONFIGS.keys())
def add_model_config(path):
""" add model config path or file and update registry """
if not isinstance(path, Path):
path = Path(path)
_MODEL_CONFIG_PATHS.append(path)
_rescan_model_configs()
def get_model_config(model_name):
if model_name in _MODEL_CONFIGS:
return deepcopy(_MODEL_CONFIGS[model_name])
else:
return None
def get_tokenizer(model_name):
config = get_model_config(model_name)
tokenizer = HFTokenizer(config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize
return tokenizer
# loading openai CLIP weights when is_openai=True for training
def load_state_dict(checkpoint_path: str, map_location: str='cpu', model_key: str='model|module|state_dict', is_openai: bool=False, skip_list: list=[]):
if is_openai:
model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
state_dict = model.state_dict()
for key in ["input_resolution", "context_length", "vocab_size"]:
state_dict.pop(key, None)
else:
checkpoint = torch.load(checkpoint_path, map_location=map_location, weights_only=False)
for mk in model_key.split('|'):
if isinstance(checkpoint, dict) and mk in checkpoint:
state_dict = checkpoint[mk]
break
else:
state_dict = checkpoint
if next(iter(state_dict.items()))[0].startswith('module'):
state_dict = {k[7:]: v for k, v in state_dict.items()}
for k in skip_list:
if k in list(state_dict.keys()):
logging.info(f"Removing key {k} from pretrained checkpoint")
del state_dict[k]
if os.getenv('RoPE') == '1':
for k in list(state_dict.keys()):
if 'freqs_cos' in k or 'freqs_sin' in k:
del state_dict[k]
return state_dict
def load_checkpoint(model, checkpoint_path, model_key="model|module|state_dict", strict=True):
state_dict = load_state_dict(checkpoint_path, model_key=model_key, is_openai=False)
# detect old format and make compatible with new format
if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
state_dict = convert_to_custom_text_state_dict(state_dict)
if 'text.logit_scale' in state_dict and hasattr(model, 'logit_scale'):
state_dict['logit_scale'] = state_dict['text.logit_scale']
del state_dict['text.logit_scale']
# resize_clip_pos_embed for CLIP and open CLIP
if 'visual.positional_embedding' in state_dict:
resize_clip_pos_embed(state_dict, model)
# specified to eva_vit_model
elif 'visual.pos_embed' in state_dict:
resize_evaclip_pos_embed(state_dict, model)
# resize_clip_pos_embed(state_dict, model)
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
logging.info(f"incompatible_keys.missing_keys: {incompatible_keys.missing_keys}")
return incompatible_keys
def load_clip_visual_state_dict(checkpoint_path: str, map_location: str='cpu', is_openai: bool=False, skip_list:list=[]):
state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
for k in list(state_dict.keys()):
if not k.startswith('visual.'):
del state_dict[k]
for k in list(state_dict.keys()):
if k.startswith('visual.'):
new_k = k[7:]
state_dict[new_k] = state_dict[k]
del state_dict[k]
return state_dict
def load_clip_text_state_dict(checkpoint_path: str, map_location: str='cpu', is_openai: bool=False, skip_list:list=[]):
state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
for k in list(state_dict.keys()):
if k.startswith('visual.'):
del state_dict[k]
return state_dict
def get_pretrained_tag(pretrained_model):
pretrained_model = pretrained_model.lower()
if "laion" in pretrained_model or "open_clip" in pretrained_model:
return "open_clip"
elif "openai" in pretrained_model:
return "clip"
elif "eva" in pretrained_model and "clip" in pretrained_model:
return "eva_clip"
else:
return "other"
def load_pretrained_checkpoint(
model,
visual_checkpoint_path,
text_checkpoint_path,
strict=True,
visual_model=None,
text_model=None,
model_key="model|module|state_dict",
skip_list=[]):
visual_tag = get_pretrained_tag(visual_model)
text_tag = get_pretrained_tag(text_model)
logging.info(f"num of model state_dict keys: {len(model.state_dict().keys())}")
visual_incompatible_keys, text_incompatible_keys = None, None
if visual_checkpoint_path:
if visual_tag == "eva_clip" or visual_tag == "open_clip":
visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=False, skip_list=skip_list)
elif visual_tag == "clip":
visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=True, skip_list=skip_list)
else:
visual_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list)
# resize_clip_pos_embed for CLIP and open CLIP
if 'positional_embedding' in visual_state_dict:
resize_visual_pos_embed(visual_state_dict, model)
# specified to EVA model
elif 'pos_embed' in visual_state_dict:
resize_eva_pos_embed(visual_state_dict, model)
visual_incompatible_keys = model.visual.load_state_dict(visual_state_dict, strict=strict)
logging.info(f"num of loaded visual_state_dict keys: {len(visual_state_dict.keys())}")
logging.info(f"visual_incompatible_keys.missing_keys: {visual_incompatible_keys.missing_keys}")
if text_checkpoint_path:
if text_tag == "eva_clip" or text_tag == "open_clip":
text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=False, skip_list=skip_list)
elif text_tag == "clip":
text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=True, skip_list=skip_list)
else:
text_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list)
text_incompatible_keys = model.text.load_state_dict(text_state_dict, strict=strict)
logging.info(f"num of loaded text_state_dict keys: {len(text_state_dict.keys())}")
logging.info(f"text_incompatible_keys.missing_keys: {text_incompatible_keys.missing_keys}")
return visual_incompatible_keys, text_incompatible_keys
def create_model(
model_name: str,
pretrained: Optional[str] = None,
precision: str = 'fp32',
device: Union[str, torch.device] = 'cpu',
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_clip: bool = False,
force_patch_dropout: Optional[float] = None,
pretrained_image: str = '',
pretrained_text: str = '',
pretrained_hf: bool = True,
pretrained_visual_model: str = None,
pretrained_text_model: str = None,
cache_dir: Optional[str] = None,
local_dir: Optional[str] = None,
skip_list: list = [],
):
model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
if isinstance(device, str):
device = torch.device(device)
if pretrained and pretrained.lower() == 'openai':
logging.info(f'Loading pretrained {model_name} from OpenAI.')
model = load_openai_model(
model_name,
precision=precision,
device=device,
jit=jit,
cache_dir=cache_dir,
)
else:
model_cfg = get_model_config(model_name)
if model_cfg is not None:
logging.info(f'Loaded {model_name} model config.')
else:
logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
raise RuntimeError(f'Model config for {model_name} not found.')
if 'rope' in model_cfg.get('vision_cfg', {}):
if model_cfg['vision_cfg']['rope']:
os.environ['RoPE'] = "1"
else:
os.environ['RoPE'] = "0"
if force_quick_gelu:
# override for use of QuickGELU on non-OpenAI transformer models
model_cfg["quick_gelu"] = True
if force_patch_dropout is not None:
# override the default patch dropout value
model_cfg['vision_cfg']["patch_dropout"] = force_patch_dropout
cast_dtype = get_cast_dtype(precision)
custom_clip = model_cfg.pop('custom_text', False) or force_custom_clip or ('hf_model_name' in model_cfg['text_cfg'])
if custom_clip:
if 'hf_model_name' in model_cfg.get('text_cfg', {}):
model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
model = CustomCLIP(**model_cfg, cast_dtype=cast_dtype)
else:
model = CLIP(**model_cfg, cast_dtype=cast_dtype)
pretrained_cfg = {}
if pretrained:
checkpoint_path = ''
pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
if pretrained_cfg:
checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir, local_dir=local_dir)
elif os.path.exists(pretrained):
checkpoint_path = pretrained
if checkpoint_path:
logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
load_checkpoint(model,
checkpoint_path,
model_key="model|module|state_dict",
strict=False
)
else:
error_str = (
f'Pretrained weights ({pretrained}) not found for model {model_name}.'
f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
logging.warning(error_str)
raise RuntimeError(error_str)
else:
visual_checkpoint_path = ''
text_checkpoint_path = ''
if pretrained_image:
pretrained_visual_model = pretrained_visual_model.replace('/', '-') # for callers using old naming with / in ViT names
pretrained_image_cfg = get_pretrained_cfg(pretrained_visual_model, pretrained_image)
if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
# pretrained weight loading for timm models set via vision_cfg
model_cfg['vision_cfg']['timm_model_pretrained'] = True
elif pretrained_image_cfg:
visual_checkpoint_path = download_pretrained(pretrained_image_cfg, cache_dir=cache_dir)
elif os.path.exists(pretrained_image):
visual_checkpoint_path = pretrained_image
else:
logging.warning(f'Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.')
raise RuntimeError(f'Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.')
if pretrained_text:
pretrained_text_model = pretrained_text_model.replace('/', '-') # for callers using old naming with / in ViT names
pretrained_text_cfg = get_pretrained_cfg(pretrained_text_model, pretrained_text)
if pretrained_image_cfg:
text_checkpoint_path = download_pretrained(pretrained_text_cfg, cache_dir=cache_dir)
elif os.path.exists(pretrained_text):
text_checkpoint_path = pretrained_text
else:
logging.warning(f'Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.')
raise RuntimeError(f'Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.')
if visual_checkpoint_path:
logging.info(f'Loading pretrained {model_name}.visual weights ({visual_checkpoint_path}).')
if text_checkpoint_path:
logging.info(f'Loading pretrained {model_name}.text weights ({text_checkpoint_path}).')
if visual_checkpoint_path or text_checkpoint_path:
load_pretrained_checkpoint(
model,
visual_checkpoint_path,
text_checkpoint_path,
strict=False,
visual_model=pretrained_visual_model,
text_model=pretrained_text_model,
model_key="model|module|state_dict",
skip_list=skip_list
)
if "fp16" in precision or "bf16" in precision:
logging.info(f'convert precision to {precision}')
model = model.to(torch.bfloat16) if 'bf16' in precision else model.to(torch.float16)
model.to(device=device)
# set image / mean metadata from pretrained_cfg if available, or use default
model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD
if jit:
model = torch.jit.script(model)
return model
def create_model_and_transforms(
model_name: str,
pretrained: Optional[str] = None,
precision: str = 'fp32',
device: Union[str, torch.device] = 'cpu',
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_clip: bool = False,
force_patch_dropout: Optional[float] = None,
pretrained_image: str = '',
pretrained_text: str = '',
pretrained_hf: bool = True,
pretrained_visual_model: str = None,
pretrained_text_model: str = None,
image_mean: Optional[Tuple[float, ...]] = None,
image_std: Optional[Tuple[float, ...]] = None,
cache_dir: Optional[str] = None,
local_dir: Optional[str] = None,
skip_list: list = [],
):
model = create_model(
model_name,
pretrained,
precision=precision,
device=device,
jit=jit,
force_quick_gelu=force_quick_gelu,
force_custom_clip=force_custom_clip,
force_patch_dropout=force_patch_dropout,
pretrained_image=pretrained_image,
pretrained_text=pretrained_text,
pretrained_hf=pretrained_hf,
pretrained_visual_model=pretrained_visual_model,
pretrained_text_model=pretrained_text_model,
cache_dir=cache_dir,
local_dir=local_dir,
skip_list=skip_list,
)
image_mean = image_mean or getattr(model.visual, 'image_mean', None)
image_std = image_std or getattr(model.visual, 'image_std', None)
preprocess_train = image_transform(
model.visual.image_size,
is_train=True,
mean=image_mean,
std=image_std
)
preprocess_val = image_transform(
model.visual.image_size,
is_train=False,
mean=image_mean,
std=image_std
)
return model, preprocess_train, preprocess_val
def create_transforms(
model_name: str,
pretrained: Optional[str] = None,
precision: str = 'fp32',
device: Union[str, torch.device] = 'cpu',
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_clip: bool = False,
force_patch_dropout: Optional[float] = None,
pretrained_image: str = '',
pretrained_text: str = '',
pretrained_hf: bool = True,
pretrained_visual_model: str = None,
pretrained_text_model: str = None,
image_mean: Optional[Tuple[float, ...]] = None,
image_std: Optional[Tuple[float, ...]] = None,
cache_dir: Optional[str] = None,
skip_list: list = [],
):
model = create_model(
model_name,
pretrained,
precision=precision,
device=device,
jit=jit,
force_quick_gelu=force_quick_gelu,
force_custom_clip=force_custom_clip,
force_patch_dropout=force_patch_dropout,
pretrained_image=pretrained_image,
pretrained_text=pretrained_text,
pretrained_hf=pretrained_hf,
pretrained_visual_model=pretrained_visual_model,
pretrained_text_model=pretrained_text_model,
cache_dir=cache_dir,
skip_list=skip_list,
)
image_mean = image_mean or getattr(model.visual, 'image_mean', None)
image_std = image_std or getattr(model.visual, 'image_std', None)
preprocess_train = image_transform(
model.visual.image_size,
is_train=True,
mean=image_mean,
std=image_std
)
preprocess_val = image_transform(
model.visual.image_size,
is_train=False,
mean=image_mean,
std=image_std
)
del model
return preprocess_train, preprocess_val
def create_model_from_pretrained(
model_name: str,
pretrained: str,
precision: str = 'fp32',
device: Union[str, torch.device] = 'cpu',
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_clip: bool = False,
force_patch_dropout: Optional[float] = None,
return_transform: bool = True,
image_mean: Optional[Tuple[float, ...]] = None,
image_std: Optional[Tuple[float, ...]] = None,
cache_dir: Optional[str] = None,
is_frozen: bool = False,
):
if not is_pretrained_cfg(model_name, pretrained) and not os.path.exists(pretrained):
raise RuntimeError(
f'{pretrained} is not a valid pretrained cfg or checkpoint for {model_name}.'
f' Use open_clip.list_pretrained() to find one.')
model = create_model(
model_name,
pretrained,
precision=precision,
device=device,
jit=jit,
force_quick_gelu=force_quick_gelu,
force_custom_clip=force_custom_clip,
force_patch_dropout=force_patch_dropout,
cache_dir=cache_dir,
)
if is_frozen:
for param in model.parameters():
param.requires_grad = False
if not return_transform:
return model
image_mean = image_mean or getattr(model.visual, 'image_mean', None)
image_std = image_std or getattr(model.visual, 'image_std', None)
preprocess = image_transform(
model.visual.image_size,
is_train=False,
mean=image_mean,
std=image_std
)
return model, preprocess

View File

@@ -0,0 +1,57 @@
# HF architecture dict:
arch_dict = {
# https://huggingface.co/docs/transformers/model_doc/roberta#roberta
"roberta": {
"config_names": {
"context_length": "max_position_embeddings",
"vocab_size": "vocab_size",
"width": "hidden_size",
"heads": "num_attention_heads",
"layers": "num_hidden_layers",
"layer_attr": "layer",
"token_embeddings_attr": "embeddings"
},
"pooler": "mean_pooler",
},
# https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
"xlm-roberta": {
"config_names": {
"context_length": "max_position_embeddings",
"vocab_size": "vocab_size",
"width": "hidden_size",
"heads": "num_attention_heads",
"layers": "num_hidden_layers",
"layer_attr": "layer",
"token_embeddings_attr": "embeddings"
},
"pooler": "mean_pooler",
},
# https://huggingface.co/docs/transformers/model_doc/mt5#mt5
"mt5": {
"config_names": {
# unlimited seqlen
# https://github.com/google-research/text-to-text-transfer-transformer/issues/273
# https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
"context_length": "",
"vocab_size": "vocab_size",
"width": "d_model",
"heads": "num_heads",
"layers": "num_layers",
"layer_attr": "block",
"token_embeddings_attr": "embed_tokens"
},
"pooler": "mean_pooler",
},
"bert": {
"config_names": {
"context_length": "max_position_embeddings",
"vocab_size": "vocab_size",
"width": "hidden_size",
"heads": "num_attention_heads",
"layers": "num_hidden_layers",
"layer_attr": "layer",
"token_embeddings_attr": "embeddings"
},
"pooler": "mean_pooler",
}
}

View File

@@ -0,0 +1,248 @@
""" huggingface model adapter
Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
"""
import re
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import TensorType
try:
import transformers
from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, AutoConfig, PretrainedConfig
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
BaseModelOutputWithPoolingAndCrossAttentions
except ImportError as e:
transformers = None
class BaseModelOutput:
pass
class PretrainedConfig:
pass
from .hf_configs import arch_dict
# utils
def _camel2snake(s):
return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
# TODO: ?last - for gpt-like models
_POOLERS = {}
def register_pooler(cls):
"""Decorator registering pooler class"""
_POOLERS[_camel2snake(cls.__name__)] = cls
return cls
@register_pooler
class MeanPooler(nn.Module):
"""Mean pooling"""
def forward(self, x:BaseModelOutput, attention_mask:TensorType):
masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
@register_pooler
class MaxPooler(nn.Module):
"""Max pooling"""
def forward(self, x:BaseModelOutput, attention_mask:TensorType):
masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1), -torch.inf)
return masked_output.max(1).values
@register_pooler
class ClsPooler(nn.Module):
"""CLS token pooling"""
def __init__(self, use_pooler_output=True):
super().__init__()
self.cls_token_position = 0
self.use_pooler_output = use_pooler_output
def forward(self, x:BaseModelOutput, attention_mask:TensorType):
if (self.use_pooler_output and
isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
(x.pooler_output is not None)
):
return x.pooler_output
return x.last_hidden_state[:, self.cls_token_position, :]
class HFTextEncoder(nn.Module):
"""HuggingFace model adapter"""
def __init__(
self,
model_name_or_path: str,
output_dim: int,
tokenizer_name: str = None,
config: PretrainedConfig = None,
pooler_type: str = None,
proj: str = None,
pretrained: bool = True,
masked_language_modeling: bool = False):
super().__init__()
self.output_dim = output_dim
# TODO: find better way to get this information
uses_transformer_pooler = (pooler_type == "cls_pooler")
if transformers is None:
raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
if config is None:
self.config = AutoConfig.from_pretrained(model_name_or_path)
if masked_language_modeling:
create_func, model_args = (AutoModelForMaskedLM.from_pretrained, model_name_or_path) if pretrained else (
AutoModelForMaskedLM.from_config, self.config)
else:
create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
AutoModel.from_config, self.config)
# TODO: do all model configs have this attribute? PretrainedConfig does so yes??
if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
self.transformer = create_func(model_args)
self.transformer = self.transformer.encoder
else:
self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
else:
self.config = config
if masked_language_modeling:
self.transformer = AutoModelForMaskedLM.from_config(config)
else:
self.transformer = AutoModel.from_config(config)
if pooler_type is None: # get default arch pooler
self.pooler = _POOLERS[(arch_dict[self.config.model_type]["pooler"])]()
else:
self.pooler = _POOLERS[pooler_type]()
d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
if (d_model == output_dim) and (proj is None): # do we always need a proj?
self.proj = nn.Identity()
elif proj == 'linear':
self.proj = nn.Linear(d_model, output_dim, bias=False)
elif proj == 'mlp':
hidden_size = (d_model + output_dim) // 2
self.proj = nn.Sequential(
nn.Linear(d_model, hidden_size, bias=False),
nn.GELU(),
nn.Linear(hidden_size, output_dim, bias=False),
)
# self.itm_proj = nn.Linear(d_model, 2, bias=False)
# self.mlm_proj = nn.Linear(d_model, self.config.vocab_size), bias=False)
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
# def forward_itm(self, x:TensorType, image_embeds:TensorType) -> TensorType:
# image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(x.device)
# attn_mask = (x != self.config.pad_token_id).long()
# out = self.transformer(
# input_ids=x,
# attention_mask=attn_mask,
# encoder_hidden_states = image_embeds,
# encoder_attention_mask = image_atts,
# )
# pooled_out = self.pooler(out, attn_mask)
# return self.itm_proj(pooled_out)
def mask(self, input_ids, vocab_size, device, targets=None, masked_indices=None, probability_matrix=None):
if masked_indices is None:
masked_indices = torch.bernoulli(probability_matrix).bool()
masked_indices[input_ids == self.tokenizer.pad_token_id] = False
masked_indices[input_ids == self.tokenizer.cls_token_id] = False
if targets is not None:
targets[~masked_indices] = -100 # We only compute loss on masked tokens
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices
input_ids[indices_replaced] = self.tokenizer.mask_token_id
# 10% of the time, we replace masked input tokens with random word
indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced
random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device)
input_ids[indices_random] = random_words[indices_random]
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
if targets is not None:
return input_ids, targets
else:
return input_ids
def forward_mlm(self, input_ids, image_embeds, mlm_probability=0.25):
labels = input_ids.clone()
attn_mask = (input_ids != self.config.pad_token_id).long()
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(input_ids.device)
vocab_size = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["vocab_size"])
probability_matrix = torch.full(labels.shape, mlm_probability)
input_ids, labels = self.mask(input_ids, vocab_size, input_ids.device, targets=labels,
probability_matrix = probability_matrix)
mlm_output = self.transformer(input_ids,
attention_mask = attn_mask,
encoder_hidden_states = image_embeds,
encoder_attention_mask = image_atts,
return_dict = True,
labels = labels,
)
return mlm_output.loss
# mlm_output = self.transformer(input_ids,
# attention_mask = attn_mask,
# encoder_hidden_states = image_embeds,
# encoder_attention_mask = image_atts,
# return_dict = True,
# ).last_hidden_state
# logits = self.mlm_proj(mlm_output)
# # logits = logits[:, :-1, :].contiguous().view(-1, vocab_size)
# logits = logits[:, 1:, :].contiguous().view(-1, vocab_size)
# labels = labels[:, 1:].contiguous().view(-1)
# mlm_loss = F.cross_entropy(
# logits,
# labels,
# # label_smoothing=0.1,
# )
# return mlm_loss
def forward(self, x:TensorType) -> TensorType:
attn_mask = (x != self.config.pad_token_id).long()
out = self.transformer(input_ids=x, attention_mask=attn_mask)
pooled_out = self.pooler(out, attn_mask)
return self.proj(pooled_out)
def lock(self, unlocked_layers:int=0, freeze_layer_norm:bool=True):
if not unlocked_layers: # full freezing
for n, p in self.transformer.named_parameters():
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
return
encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
embeddings = getattr(
self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
modules = [embeddings, *layer_list][:-unlocked_layers]
# freeze layers
for module in modules:
for n, p in module.named_parameters():
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.transformer.gradient_checkpointing_enable()
def get_num_layers(self):
encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
return len(layer_list)
def init_parameters(self):
pass

View File

@@ -0,0 +1,138 @@
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
try:
import torch.distributed.nn
from torch import distributed as dist
has_distributed = True
except ImportError:
has_distributed = False
try:
import horovod.torch as hvd
except ImportError:
hvd = None
from timm.loss import LabelSmoothingCrossEntropy
def gather_features(
image_features,
text_features,
local_loss=False,
gather_with_grad=False,
rank=0,
world_size=1,
use_horovod=False
):
assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
if use_horovod:
assert hvd is not None, 'Please install horovod'
if gather_with_grad:
all_image_features = hvd.allgather(image_features)
all_text_features = hvd.allgather(text_features)
else:
with torch.no_grad():
all_image_features = hvd.allgather(image_features)
all_text_features = hvd.allgather(text_features)
if not local_loss:
# ensure grads for local rank when all_* features don't have a gradient
gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
gathered_image_features[rank] = image_features
gathered_text_features[rank] = text_features
all_image_features = torch.cat(gathered_image_features, dim=0)
all_text_features = torch.cat(gathered_text_features, dim=0)
else:
# We gather tensors from all gpus
if gather_with_grad:
all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
# all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features, async_op=True), dim=0)
# all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features, async_op=True), dim=0)
else:
gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
dist.all_gather(gathered_image_features, image_features)
dist.all_gather(gathered_text_features, text_features)
if not local_loss:
# ensure grads for local rank when all_* features don't have a gradient
gathered_image_features[rank] = image_features
gathered_text_features[rank] = text_features
all_image_features = torch.cat(gathered_image_features, dim=0)
all_text_features = torch.cat(gathered_text_features, dim=0)
return all_image_features, all_text_features
class ClipLoss(nn.Module):
def __init__(
self,
local_loss=False,
gather_with_grad=False,
cache_labels=False,
rank=0,
world_size=1,
use_horovod=False,
smoothing=0.,
):
super().__init__()
self.local_loss = local_loss
self.gather_with_grad = gather_with_grad
self.cache_labels = cache_labels
self.rank = rank
self.world_size = world_size
self.use_horovod = use_horovod
self.label_smoothing_cross_entropy = LabelSmoothingCrossEntropy(smoothing=smoothing) if smoothing > 0 else None
# cache state
self.prev_num_logits = 0
self.labels = {}
def forward(self, image_features, text_features, logit_scale=1.):
device = image_features.device
if self.world_size > 1:
all_image_features, all_text_features = gather_features(
image_features, text_features,
self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
if self.local_loss:
logits_per_image = logit_scale * image_features @ all_text_features.T
logits_per_text = logit_scale * text_features @ all_image_features.T
else:
logits_per_image = logit_scale * all_image_features @ all_text_features.T
logits_per_text = logits_per_image.T
else:
logits_per_image = logit_scale * image_features @ text_features.T
logits_per_text = logit_scale * text_features @ image_features.T
# calculated ground-truth and cache if enabled
num_logits = logits_per_image.shape[0]
if self.prev_num_logits != num_logits or device not in self.labels:
labels = torch.arange(num_logits, device=device, dtype=torch.long)
if self.world_size > 1 and self.local_loss:
labels = labels + num_logits * self.rank
if self.cache_labels:
self.labels[device] = labels
self.prev_num_logits = num_logits
else:
labels = self.labels[device]
if self.label_smoothing_cross_entropy:
total_loss = (
self.label_smoothing_cross_entropy(logits_per_image, labels) +
self.label_smoothing_cross_entropy(logits_per_text, labels)
) / 2
else:
total_loss = (
F.cross_entropy(logits_per_image, labels) +
F.cross_entropy(logits_per_text, labels)
) / 2
acc = None
i2t_acc = (logits_per_image.argmax(-1) == labels).sum() / len(logits_per_image)
t2i_acc = (logits_per_text.argmax(-1) == labels).sum() / len(logits_per_text)
acc = {"i2t": i2t_acc, "t2i": t2i_acc}
return total_loss, acc

View File

@@ -0,0 +1,439 @@
""" CLIP Model
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""
import os
from dataclasses import dataclass
from typing import Optional, Tuple, Union
from functools import partial
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
try:
from .hf_model import HFTextEncoder
except:
HFTextEncoder = None
from .modified_resnet import ModifiedResNet
from .timm_model import TimmModel
from .eva_vit_model import EVAVisionTransformer
from .transformer import LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
try:
from apex.normalization import FusedLayerNorm
except:
FusedLayerNorm = LayerNorm
print("Nvidia APEX normalization not installed, using PyTorch LayerNorm")
try:
import xformers.ops as xops
except ImportError:
xops = None
#print("Please 'pip install xformers'")
@dataclass
class CLIPVisionCfg:
layers: Union[Tuple[int, int, int, int], int] = 12
width: int = 768
head_width: int = 64
mlp_ratio: float = 4.0
patch_size: int = 16
image_size: Union[Tuple[int, int], int] = 224
ls_init_value: Optional[float] = None # layer scale initial value
patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
drop_path_rate: Optional[float] = None # drop path rate
timm_model_name: str = None # a valid model name overrides layers, width, patch_size
timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
timm_proj_bias: bool = False # enable bias final projection
eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size
qkv_bias: bool = True
fusedLN: bool = False
xattn: bool = False
postnorm: bool = False
rope: bool = False
pt_hw_seq_len: int = 16 # 224/14
intp_freq: bool = False
naiveswiglu: bool = False
subln: bool = False
@dataclass
class CLIPTextCfg:
context_length: int = 77
vocab_size: int = 49408
width: int = 512
heads: int = 8
layers: int = 12
ls_init_value: Optional[float] = None # layer scale initial value
hf_model_name: str = None
hf_tokenizer_name: str = None
hf_model_pretrained: bool = True
proj: str = 'mlp'
pooler_type: str = 'mean_pooler'
masked_language_modeling: bool = False
fusedLN: bool = False
xattn: bool = False
attn_mask: bool = True
def get_cast_dtype(precision: str):
cast_dtype = None
if precision == 'bf16':
cast_dtype = torch.bfloat16
elif precision == 'fp16':
cast_dtype = torch.float16
return cast_dtype
def _build_vision_tower(
embed_dim: int,
vision_cfg: CLIPVisionCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None
):
if isinstance(vision_cfg, dict):
vision_cfg = CLIPVisionCfg(**vision_cfg)
# OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
# memory efficient in recent PyTorch releases (>= 1.10).
# NOTE: timm models always use native GELU regardless of quick_gelu flag.
act_layer = QuickGELU if quick_gelu else nn.GELU
if vision_cfg.eva_model_name:
vision_heads = vision_cfg.width // vision_cfg.head_width
norm_layer = LayerNorm
visual = EVAVisionTransformer(
img_size=vision_cfg.image_size,
patch_size=vision_cfg.patch_size,
num_classes=embed_dim,
use_mean_pooling=vision_cfg.global_average_pool, #False
init_values=vision_cfg.ls_init_value,
patch_dropout=vision_cfg.patch_dropout,
embed_dim=vision_cfg.width,
depth=vision_cfg.layers,
num_heads=vision_heads,
mlp_ratio=vision_cfg.mlp_ratio,
qkv_bias=vision_cfg.qkv_bias,
drop_path_rate=vision_cfg.drop_path_rate,
norm_layer= partial(FusedLayerNorm, eps=1e-6) if vision_cfg.fusedLN else partial(norm_layer, eps=1e-6),
xattn=vision_cfg.xattn,
rope=vision_cfg.rope,
postnorm=vision_cfg.postnorm,
pt_hw_seq_len= vision_cfg.pt_hw_seq_len, # 224/14
intp_freq= vision_cfg.intp_freq,
naiveswiglu= vision_cfg.naiveswiglu,
subln= vision_cfg.subln
)
elif vision_cfg.timm_model_name:
visual = TimmModel(
vision_cfg.timm_model_name,
pretrained=vision_cfg.timm_model_pretrained,
pool=vision_cfg.timm_pool,
proj=vision_cfg.timm_proj,
proj_bias=vision_cfg.timm_proj_bias,
embed_dim=embed_dim,
image_size=vision_cfg.image_size
)
act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models
elif isinstance(vision_cfg.layers, (tuple, list)):
vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
visual = ModifiedResNet(
layers=vision_cfg.layers,
output_dim=embed_dim,
heads=vision_heads,
image_size=vision_cfg.image_size,
width=vision_cfg.width
)
else:
vision_heads = vision_cfg.width // vision_cfg.head_width
norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
visual = VisionTransformer(
image_size=vision_cfg.image_size,
patch_size=vision_cfg.patch_size,
width=vision_cfg.width,
layers=vision_cfg.layers,
heads=vision_heads,
mlp_ratio=vision_cfg.mlp_ratio,
ls_init_value=vision_cfg.ls_init_value,
patch_dropout=vision_cfg.patch_dropout,
global_average_pool=vision_cfg.global_average_pool,
output_dim=embed_dim,
act_layer=act_layer,
norm_layer=norm_layer,
)
return visual
def _build_text_tower(
embed_dim: int,
text_cfg: CLIPTextCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
):
if isinstance(text_cfg, dict):
text_cfg = CLIPTextCfg(**text_cfg)
if text_cfg.hf_model_name:
text = HFTextEncoder(
text_cfg.hf_model_name,
output_dim=embed_dim,
tokenizer_name=text_cfg.hf_tokenizer_name,
proj=text_cfg.proj,
pooler_type=text_cfg.pooler_type,
masked_language_modeling=text_cfg.masked_language_modeling
)
else:
act_layer = QuickGELU if quick_gelu else nn.GELU
norm_layer = LayerNorm
text = TextTransformer(
context_length=text_cfg.context_length,
vocab_size=text_cfg.vocab_size,
width=text_cfg.width,
heads=text_cfg.heads,
layers=text_cfg.layers,
ls_init_value=text_cfg.ls_init_value,
output_dim=embed_dim,
act_layer=act_layer,
norm_layer= FusedLayerNorm if text_cfg.fusedLN else norm_layer,
xattn=text_cfg.xattn,
attn_mask=text_cfg.attn_mask,
)
return text
class CLIP(nn.Module):
def __init__(
self,
embed_dim: int,
vision_cfg: CLIPVisionCfg,
text_cfg: CLIPTextCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
self.transformer = text.transformer
self.vocab_size = text.vocab_size
self.token_embedding = text.token_embedding
self.positional_embedding = text.positional_embedding
self.ln_final = text.ln_final
self.text_projection = text.text_projection
self.register_buffer('attn_mask', text.attn_mask, persistent=False)
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.visual.set_grad_checkpointing(enable)
self.transformer.grad_checkpointing = enable
@torch.jit.ignore
def no_weight_decay(self):
return {'logit_scale'}
def encode_image(self, image, normalize: bool = False):
features = self.visual(image)
return F.normalize(features, dim=-1) if normalize else features
def encode_text(self, text, normalize: bool = False):
cast_dtype = self.transformer.get_cast_dtype()
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding.to(cast_dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x, attn_mask=self.attn_mask)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
return F.normalize(x, dim=-1) if normalize else x
def forward(self, image, text):
image_features = self.encode_image(image, normalize=True)
text_features = self.encode_text(text, normalize=True)
return image_features, text_features, self.logit_scale.exp()
class CustomCLIP(nn.Module):
def __init__(
self,
embed_dim: int,
vision_cfg: CLIPVisionCfg,
text_cfg: CLIPTextCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
itm_task: bool = False,
):
super().__init__()
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
def lock_text_tower(self, unlocked_layers:int=0, freeze_layer_norm:bool=True):
self.text.lock(unlocked_layers, freeze_layer_norm)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.visual.set_grad_checkpointing(enable)
self.text.set_grad_checkpointing(enable)
@torch.jit.ignore
def no_weight_decay(self):
return {'logit_scale'}
def encode_image(self, image, normalize: bool = False):
features = self.visual(image)
return F.normalize(features, dim=-1) if normalize else features
def encode_text(self, text, normalize: bool = False):
features = self.text(text)
return F.normalize(features, dim=-1) if normalize else features
def forward(self, image, text):
image_features = self.encode_image(image, normalize=True)
text_features = self.encode_text(text, normalize=True)
return image_features, text_features, self.logit_scale.exp()
def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
"""Convert applicable model parameters to low-precision (bf16 or fp16)"""
def _convert_weights(l):
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
l.weight.data = l.weight.data.to(dtype)
if l.bias is not None:
l.bias.data = l.bias.data.to(dtype)
if isinstance(l, (nn.MultiheadAttention, Attention)):
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
tensor = getattr(l, attr, None)
if tensor is not None:
tensor.data = tensor.data.to(dtype)
if isinstance(l, nn.Parameter):
l.data = l.data.to(dtype)
for name in ["text_projection", "proj"]:
if hasattr(l, name) and isinstance(l, nn.Parameter):
attr = getattr(l, name, None)
if attr is not None:
attr.data = attr.data.to(dtype)
model.apply(_convert_weights)
convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
# used to maintain checkpoint compatibility
def convert_to_custom_text_state_dict(state_dict: dict):
if 'text_projection' in state_dict:
# old format state_dict, move text tower -> .text
new_state_dict = {}
for k, v in state_dict.items():
if any(k.startswith(p) for p in (
'text_projection',
'positional_embedding',
'token_embedding',
'transformer',
'ln_final',
'logit_scale'
)):
k = 'text.' + k
new_state_dict[k] = v
return new_state_dict
return state_dict
def build_model_from_openai_state_dict(
state_dict: dict,
quick_gelu=True,
cast_dtype=torch.float16,
):
vit = "visual.proj" in state_dict
if vit:
vision_width = state_dict["visual.conv1.weight"].shape[0]
vision_layers = len(
[k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
image_size = vision_patch_size * grid_size
else:
counts: list = [
len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
vision_layers = tuple(counts)
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
vision_patch_size = None
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
image_size = output_width * 32
embed_dim = state_dict["text_projection"].shape[1]
context_length = state_dict["positional_embedding"].shape[0]
vocab_size = state_dict["token_embedding.weight"].shape[0]
transformer_width = state_dict["ln_final.weight"].shape[0]
transformer_heads = transformer_width // 64
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
vision_cfg = CLIPVisionCfg(
layers=vision_layers,
width=vision_width,
patch_size=vision_patch_size,
image_size=image_size,
)
text_cfg = CLIPTextCfg(
context_length=context_length,
vocab_size=vocab_size,
width=transformer_width,
heads=transformer_heads,
layers=transformer_layers
)
model = CLIP(
embed_dim,
vision_cfg=vision_cfg,
text_cfg=text_cfg,
quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
cast_dtype=cast_dtype,
)
for key in ["input_resolution", "context_length", "vocab_size"]:
state_dict.pop(key, None)
convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
model.load_state_dict(state_dict)
return model.eval()
def trace_model(model, batch_size=256, device=torch.device('cpu')):
model.eval()
image_size = model.visual.image_size
example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
model = torch.jit.trace_module(
model,
inputs=dict(
forward=(example_images, example_text),
encode_text=(example_text,),
encode_image=(example_images,)
))
model.visual.image_size = image_size
return model

View File

@@ -0,0 +1,19 @@
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 16,
"eva_model_name": "eva-clip-b-16",
"ls_init_value": 0.1,
"drop_path_rate": 0.0
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}

View File

@@ -0,0 +1,24 @@
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 40,
"width": 1408,
"head_width": 88,
"mlp_ratio": 4.3637,
"patch_size": 14,
"eva_model_name": "eva-clip-g-14-x",
"drop_path_rate": 0,
"xattn": true,
"fusedLN": true
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24,
"xattn": false,
"fusedLN": true
}
}

View File

@@ -0,0 +1,24 @@
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 40,
"width": 1408,
"head_width": 88,
"mlp_ratio": 4.3637,
"patch_size": 14,
"eva_model_name": "eva-clip-g-14-x",
"drop_path_rate": 0.4,
"xattn": true,
"fusedLN": true
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12,
"xattn": false,
"fusedLN": true
}
}

View File

@@ -0,0 +1,29 @@
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"head_width": 64,
"patch_size": 16,
"mlp_ratio": 2.6667,
"eva_model_name": "eva-clip-b-16-X",
"drop_path_rate": 0.0,
"xattn": true,
"fusedLN": true,
"rope": true,
"pt_hw_seq_len": 16,
"intp_freq": true,
"naiveswiglu": true,
"subln": true
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12,
"xattn": true,
"fusedLN": true
}
}

View File

@@ -0,0 +1,29 @@
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 336,
"layers": 24,
"width": 1024,
"drop_path_rate": 0,
"head_width": 64,
"mlp_ratio": 2.6667,
"patch_size": 14,
"eva_model_name": "eva-clip-l-14-336",
"xattn": true,
"fusedLN": true,
"rope": true,
"pt_hw_seq_len": 16,
"intp_freq": true,
"naiveswiglu": true,
"subln": true
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12,
"xattn": false,
"fusedLN": true
}
}

View File

@@ -0,0 +1,29 @@
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 224,
"layers": 24,
"width": 1024,
"drop_path_rate": 0,
"head_width": 64,
"mlp_ratio": 2.6667,
"patch_size": 14,
"eva_model_name": "eva-clip-l-14",
"xattn": true,
"fusedLN": true,
"rope": true,
"pt_hw_seq_len": 16,
"intp_freq": true,
"naiveswiglu": true,
"subln": true
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12,
"xattn": false,
"fusedLN": true
}
}

View File

@@ -0,0 +1,25 @@
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 64,
"width": 1792,
"head_width": 112,
"mlp_ratio": 8.571428571428571,
"patch_size": 14,
"eva_model_name": "eva-clip-4b-14-x",
"drop_path_rate": 0,
"xattn": true,
"postnorm": true,
"fusedLN": true
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1280,
"heads": 20,
"layers": 32,
"xattn": false,
"fusedLN": true
}
}

View File

@@ -0,0 +1,25 @@
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 64,
"width": 1792,
"head_width": 112,
"mlp_ratio": 8.571428571428571,
"patch_size": 14,
"eva_model_name": "eva-clip-4b-14-x",
"drop_path_rate": 0,
"xattn": true,
"postnorm": true,
"fusedLN": true
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24,
"xattn": false,
"fusedLN": true
}
}

View File

@@ -0,0 +1,181 @@
from collections import OrderedDict
import torch
from torch import nn
from torch.nn import functional as F
from .utils import freeze_batch_norm_2d
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1):
super().__init__()
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.act1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.act2 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.act3 = nn.ReLU(inplace=True)
self.downsample = None
self.stride = stride
if stride > 1 or inplanes != planes * Bottleneck.expansion:
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
self.downsample = nn.Sequential(OrderedDict([
("-1", nn.AvgPool2d(stride)),
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
("1", nn.BatchNorm2d(planes * self.expansion))
]))
def forward(self, x: torch.Tensor):
identity = x
out = self.act1(self.bn1(self.conv1(x)))
out = self.act2(self.bn2(self.conv2(out)))
out = self.avgpool(out)
out = self.bn3(self.conv3(out))
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.act3(out)
return out
class AttentionPool2d(nn.Module):
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
super().__init__()
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads
def forward(self, x):
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
x, _ = F.multi_head_attention_forward(
query=x, key=x, value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0.,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False
)
return x[0]
class ModifiedResNet(nn.Module):
"""
A ResNet class that is similar to torchvision's but contains the following changes:
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
- The final pooling layer is a QKV attention instead of an average pool
"""
def __init__(self, layers, output_dim, heads, image_size=224, width=64):
super().__init__()
self.output_dim = output_dim
self.image_size = image_size
# the 3-layer stem
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(width // 2)
self.act1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(width // 2)
self.act2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(width)
self.act3 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(2)
# residual layers
self._inplanes = width # this is a *mutable* variable used during construction
self.layer1 = self._make_layer(width, layers[0])
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
embed_dim = width * 32 # the ResNet feature dimension
self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
self.init_parameters()
def _make_layer(self, planes, blocks, stride=1):
layers = [Bottleneck(self._inplanes, planes, stride)]
self._inplanes = planes * Bottleneck.expansion
for _ in range(1, blocks):
layers.append(Bottleneck(self._inplanes, planes))
return nn.Sequential(*layers)
def init_parameters(self):
if self.attnpool is not None:
std = self.attnpool.c_proj.in_features ** -0.5
nn.init.normal_(self.attnpool.q_proj.weight, std=std)
nn.init.normal_(self.attnpool.k_proj.weight, std=std)
nn.init.normal_(self.attnpool.v_proj.weight, std=std)
nn.init.normal_(self.attnpool.c_proj.weight, std=std)
for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
for name, param in resnet_block.named_parameters():
if name.endswith("bn3.weight"):
nn.init.zeros_(param)
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
assert unlocked_groups == 0, 'partial locking not currently supported for this model'
for param in self.parameters():
param.requires_grad = False
if freeze_bn_stats:
freeze_batch_norm_2d(self)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
# FIXME support for non-transformer
pass
def stem(self, x):
x = self.act1(self.bn1(self.conv1(x)))
x = self.act2(self.bn2(self.conv2(x)))
x = self.act3(self.bn3(self.conv3(x)))
x = self.avgpool(x)
return x
def forward(self, x):
x = self.stem(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.attnpool(x)
return x

View File

@@ -0,0 +1,144 @@
""" OpenAI pretrained model functions
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""
import os
import warnings
from typing import List, Optional, Union
import torch
from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
__all__ = ["list_openai_models", "load_openai_model"]
def list_openai_models() -> List[str]:
"""Returns the names of available CLIP models"""
return list_pretrained_models_by_tag('openai')
def load_openai_model(
name: str,
precision: Optional[str] = None,
device: Optional[Union[str, torch.device]] = None,
jit: bool = True,
cache_dir: Optional[str] = None,
):
"""Load a CLIP model
Parameters
----------
name : str
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
precision: str
Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
device : Union[str, torch.device]
The device to put the loaded model
jit : bool
Whether to load the optimized JIT model (default) or more hackable non-JIT model.
cache_dir : Optional[str]
The directory to cache the downloaded model weights
Returns
-------
model : torch.nn.Module
The CLIP model
preprocess : Callable[[PIL.Image], torch.Tensor]
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
"""
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
if precision is None:
precision = 'fp32' if device == 'cpu' else 'fp16'
if get_pretrained_url(name, 'openai'):
model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
elif os.path.isfile(name):
model_path = name
else:
raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
try:
# loading JIT archive
model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
state_dict = None
except RuntimeError:
# loading saved state dict
if jit:
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
jit = False
state_dict = torch.load(model_path, map_location="cpu")
if not jit:
# Build a non-jit model from the OpenAI jitted model state dict
cast_dtype = get_cast_dtype(precision)
try:
model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
except KeyError:
sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
# model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
model = model.to(device)
if precision.startswith('amp') or precision == 'fp32':
model.float()
elif precision == 'bf16':
convert_weights_to_lp(model, dtype=torch.bfloat16)
return model
# patch the device names
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
def patch_device(module):
try:
graphs = [module.graph] if hasattr(module, "graph") else []
except RuntimeError:
graphs = []
if hasattr(module, "forward1"):
graphs.append(module.forward1.graph)
for graph in graphs:
for node in graph.findAllNodes("prim::Constant"):
if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
node.copyAttributes(device_node)
model.apply(patch_device)
patch_device(model.encode_image)
patch_device(model.encode_text)
# patch dtype to float32 (typically for CPU)
if precision == 'fp32':
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
float_node = float_input.node()
def patch_float(module):
try:
graphs = [module.graph] if hasattr(module, "graph") else []
except RuntimeError:
graphs = []
if hasattr(module, "forward1"):
graphs.append(module.forward1.graph)
for graph in graphs:
for node in graph.findAllNodes("aten::to"):
inputs = list(node.inputs())
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
if inputs[i].node()["value"] == 5:
inputs[i].node().copyAttributes(float_node)
model.apply(patch_float)
patch_float(model.encode_image)
patch_float(model.encode_text)
model.float()
# ensure image_size attr available at consistent location for both jit and non-jit
model.visual.image_size = model.input_resolution.item()
return model

View File

@@ -0,0 +1,340 @@
import hashlib
import os
import urllib
import warnings
from functools import partial
from typing import Dict, Union
from tqdm import tqdm
try:
from huggingface_hub import hf_hub_download
_has_hf_hub = True
except ImportError:
hf_hub_download = None
_has_hf_hub = False
def _pcfg(url='', hf_hub='', filename='', mean=None, std=None):
return dict(
url=url,
hf_hub=hf_hub,
mean=mean,
std=std,
)
_VITB32 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
laion400m_e31=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
laion400m_e32=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
laion2b_e16=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"),
laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/')
)
_VITB32_quickgelu = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
laion400m_e31=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
laion400m_e32=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
)
_VITB16 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"),
laion400m_e31=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"),
laion400m_e32=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"),
laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),
)
_EVAB16 = dict(
eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt'),
eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt'),
eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt'),
eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt'),
)
_VITB16_PLUS_240 = dict(
laion400m_e31=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"),
laion400m_e32=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"),
)
_VITL14 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"),
laion400m_e31=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"),
laion400m_e32=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"),
laion2b_s32b_b82k=_pcfg(
hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
)
_EVAL14 = dict(
eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_L_psz14.pt'),
eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_L_psz14.pt'),
eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt'),
eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt'),
)
_VITL14_336 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"),
)
_EVAL14_336 = dict(
eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt'),
eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt'),
eva_clip_224to336=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt'),
eva02_clip_224to336=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt'),
)
_VITH14 = dict(
laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),
)
_VITg14 = dict(
laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),
laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'),
)
_EVAg14 = dict(
eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/'),
eva01=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_g_psz14.pt'),
eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt'),
eva01_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt'),
)
_EVAg14_PLUS = dict(
eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/'),
eva01=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_g_psz14.pt'),
eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt'),
eva01_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt'),
)
_VITbigG14 = dict(
laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),
)
_EVAbigE14 = dict(
eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt'),
eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt'),
)
_EVAbigE14_PLUS = dict(
eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt'),
eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt'),
)
_PRETRAINED = {
# "ViT-B-32": _VITB32,
"OpenaiCLIP-B-32": _VITB32,
"OpenCLIP-B-32": _VITB32,
# "ViT-B-32-quickgelu": _VITB32_quickgelu,
"OpenaiCLIP-B-32-quickgelu": _VITB32_quickgelu,
"OpenCLIP-B-32-quickgelu": _VITB32_quickgelu,
# "ViT-B-16": _VITB16,
"OpenaiCLIP-B-16": _VITB16,
"OpenCLIP-B-16": _VITB16,
"EVA02-B-16": _EVAB16,
"EVA02-CLIP-B-16": _EVAB16,
# "ViT-B-16-plus-240": _VITB16_PLUS_240,
"OpenCLIP-B-16-plus-240": _VITB16_PLUS_240,
# "ViT-L-14": _VITL14,
"OpenaiCLIP-L-14": _VITL14,
"OpenCLIP-L-14": _VITL14,
"EVA02-L-14": _EVAL14,
"EVA02-CLIP-L-14": _EVAL14,
# "ViT-L-14-336": _VITL14_336,
"OpenaiCLIP-L-14-336": _VITL14_336,
"EVA02-CLIP-L-14-336": _EVAL14_336,
# "ViT-H-14": _VITH14,
# "ViT-g-14": _VITg14,
"OpenCLIP-H-14": _VITH14,
"OpenCLIP-g-14": _VITg14,
"EVA01-CLIP-g-14": _EVAg14,
"EVA01-CLIP-g-14-plus": _EVAg14_PLUS,
# "ViT-bigG-14": _VITbigG14,
"OpenCLIP-bigG-14": _VITbigG14,
"EVA02-CLIP-bigE-14": _EVAbigE14,
"EVA02-CLIP-bigE-14-plus": _EVAbigE14_PLUS,
}
def _clean_tag(tag: str):
# normalize pretrained tags
return tag.lower().replace('-', '_')
def list_pretrained(as_str: bool = False):
""" returns list of pretrained models
Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
"""
return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]
def list_pretrained_models_by_tag(tag: str):
""" return all models having the specified pretrain tag """
models = []
tag = _clean_tag(tag)
for k in _PRETRAINED.keys():
if tag in _PRETRAINED[k]:
models.append(k)
return models
def list_pretrained_tags_by_model(model: str):
""" return all pretrain tags for the specified model architecture """
tags = []
if model in _PRETRAINED:
tags.extend(_PRETRAINED[model].keys())
return tags
def is_pretrained_cfg(model: str, tag: str):
if model not in _PRETRAINED:
return False
return _clean_tag(tag) in _PRETRAINED[model]
def get_pretrained_cfg(model: str, tag: str):
if model not in _PRETRAINED:
return {}
model_pretrained = _PRETRAINED[model]
return model_pretrained.get(_clean_tag(tag), {})
def get_pretrained_url(model: str, tag: str):
cfg = get_pretrained_cfg(model, _clean_tag(tag))
return cfg.get('url', '')
def download_pretrained_from_url(
url: str,
cache_dir: Union[str, None] = None,
local_dir: Union[str, None] = None,
):
cache_dir = local_dir if not local_dir else cache_dir
if not cache_dir:
cache_dir = os.path.expanduser("~/.cache/clip")
os.makedirs(cache_dir, exist_ok=True)
filename = os.path.basename(url)
if 'openaipublic' in url:
expected_sha256 = url.split("/")[-2]
elif 'mlfoundations' in url:
expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
else:
expected_sha256 = ''
download_target = os.path.join(cache_dir, filename)
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
if expected_sha256:
if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
return download_target
else:
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
else:
return download_target
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
return download_target
def has_hf_hub(necessary=False):
if not _has_hf_hub and necessary:
# if no HF Hub module installed, and it is necessary to continue, raise error
raise RuntimeError(
'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
return _has_hf_hub
def download_pretrained_from_hf(
model_id: str,
filename: str = 'open_clip_pytorch_model.bin',
revision=None,
cache_dir: Union[str, None] = None,
local_dir: Union[str, None] = None,
):
has_hf_hub(True)
cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir, local_dir=local_dir)
return cached_file
def download_pretrained(
cfg: Dict,
force_hf_hub: bool = False,
cache_dir: Union[str, None] = None,
local_dir: Union[str, None] = None,
):
target = ''
if not cfg:
return target
download_url = cfg.get('url', '')
download_hf_hub = cfg.get('hf_hub', '')
if download_hf_hub and force_hf_hub:
# use HF hub even if url exists
download_url = ''
if download_url:
target = download_pretrained_from_url(download_url, cache_dir=cache_dir, local_dir=local_dir)
elif download_hf_hub:
has_hf_hub(True)
# we assume the hf_hub entries in pretrained config combine model_id + filename in
# 'org/model_name/filename.pt' form. To specify just the model id w/o filename and
# use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.
model_id, filename = os.path.split(download_hf_hub)
if local_dir is not None:
full_model_path = os.path.join(local_dir, filename)
if os.path.exists(full_model_path):
return full_model_path
if filename:
target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir, local_dir=local_dir)
else:
target = download_pretrained_from_hf(model_id, cache_dir=cache_dir, local_dir=local_dir)
return target

View File

@@ -0,0 +1,137 @@
from math import pi
import torch
from torch import nn
from einops import rearrange, repeat
import logging
def broadcat(tensors, dim = -1):
num_tensors = len(tensors)
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
shape_len = list(shape_lens)[0]
dim = (dim + shape_len) if dim < 0 else dim
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatentation'
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
expanded_dims.insert(dim, (dim, dims[dim]))
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
return torch.cat(tensors, dim = dim)
def rotate_half(x):
x = rearrange(x, '... (d r) -> ... d r', r = 2)
x1, x2 = x.unbind(dim = -1)
x = torch.stack((-x2, x1), dim = -1)
return rearrange(x, '... d r -> ... (d r)')
class VisionRotaryEmbedding(nn.Module):
def __init__(
self,
dim,
pt_seq_len,
ft_seq_len=None,
custom_freqs = None,
freqs_for = 'lang',
theta = 10000,
max_freq = 10,
num_freqs = 1,
):
super().__init__()
if custom_freqs:
freqs = custom_freqs
elif freqs_for == 'lang':
freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
elif freqs_for == 'pixel':
freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
elif freqs_for == 'constant':
freqs = torch.ones(num_freqs).float()
else:
raise ValueError(f'unknown modality {freqs_for}')
if ft_seq_len is None: ft_seq_len = pt_seq_len
t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
freqs_h = torch.einsum('..., f -> ... f', t, freqs)
freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2)
freqs_w = torch.einsum('..., f -> ... f', t, freqs)
freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2)
freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1)
self.register_buffer("freqs_cos", freqs.cos())
self.register_buffer("freqs_sin", freqs.sin())
logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
def forward(self, t, start_index = 0):
rot_dim = self.freqs_cos.shape[-1]
end_index = start_index + rot_dim
assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
return torch.cat((t_left, t, t_right), dim = -1)
class VisionRotaryEmbeddingFast(nn.Module):
def __init__(
self,
dim,
pt_seq_len,
ft_seq_len=None,
custom_freqs = None,
freqs_for = 'lang',
theta = 10000,
max_freq = 10,
num_freqs = 1,
patch_dropout = 0.
):
super().__init__()
if custom_freqs:
freqs = custom_freqs
elif freqs_for == 'lang':
freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
elif freqs_for == 'pixel':
freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
elif freqs_for == 'constant':
freqs = torch.ones(num_freqs).float()
else:
raise ValueError(f'unknown modality {freqs_for}')
if ft_seq_len is None: ft_seq_len = pt_seq_len
t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
freqs = torch.einsum('..., f -> ... f', t, freqs)
freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)
freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
self.patch_dropout = patch_dropout
self.register_buffer("freqs_cos", freqs_cos)
self.register_buffer("freqs_sin", freqs_sin)
logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
def forward(self, t, patch_indices_keep=None):
if patch_indices_keep is not None:
batch = t.size()[0]
batch_indices = torch.arange(batch)
batch_indices = batch_indices[..., None]
freqs_cos = repeat(self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
freqs_sin = repeat(self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j')
freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j')
return t * freqs_cos + rotate_half(t) * freqs_sin
return t * self.freqs_cos + rotate_half(t) * self.freqs_sin

View File

@@ -0,0 +1,122 @@
""" timm model adapter
Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
"""
import logging
from collections import OrderedDict
import torch
import torch.nn as nn
try:
import timm
from timm.models.layers import Mlp, to_2tuple
try:
# old timm imports < 0.8.1
from timm.models.layers.attention_pool2d import RotAttentionPool2d
from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
except ImportError:
# new timm imports >= 0.8.1
from timm.layers import RotAttentionPool2d
from timm.layers import AttentionPool2d as AbsAttentionPool2d
except ImportError:
timm = None
from .utils import freeze_batch_norm_2d
class TimmModel(nn.Module):
""" timm model adapter
# FIXME this adapter is a work in progress, may change in ways that break weight compat
"""
def __init__(
self,
model_name,
embed_dim,
image_size=224,
pool='avg',
proj='linear',
proj_bias=False,
drop=0.,
pretrained=False):
super().__init__()
if timm is None:
raise RuntimeError("Please `pip install timm` to use timm models.")
self.image_size = to_2tuple(image_size)
self.trunk = timm.create_model(model_name, pretrained=pretrained)
feat_size = self.trunk.default_cfg.get('pool_size', None)
feature_ndim = 1 if not feat_size else 2
if pool in ('abs_attn', 'rot_attn'):
assert feature_ndim == 2
# if attn pooling used, remove both classifier and default pool
self.trunk.reset_classifier(0, global_pool='')
else:
# reset global pool if pool config set, otherwise leave as network default
reset_kwargs = dict(global_pool=pool) if pool else {}
self.trunk.reset_classifier(0, **reset_kwargs)
prev_chs = self.trunk.num_features
head_layers = OrderedDict()
if pool == 'abs_attn':
head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
prev_chs = embed_dim
elif pool == 'rot_attn':
head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
prev_chs = embed_dim
else:
assert proj, 'projection layer needed if non-attention pooling is used.'
# NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
if proj == 'linear':
head_layers['drop'] = nn.Dropout(drop)
head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
elif proj == 'mlp':
head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias))
self.head = nn.Sequential(head_layers)
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
""" lock modules
Args:
unlocked_groups (int): leave last n layer groups unlocked (default: 0)
"""
if not unlocked_groups:
# lock full model
for param in self.trunk.parameters():
param.requires_grad = False
if freeze_bn_stats:
freeze_batch_norm_2d(self.trunk)
else:
# NOTE: partial freeze requires latest timm (master) branch and is subject to change
try:
# FIXME import here until API stable and in an official release
from timm.models.helpers import group_parameters, group_modules
except ImportError:
raise RuntimeError(
'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
matcher = self.trunk.group_matcher()
gparams = group_parameters(self.trunk, matcher)
max_layer_id = max(gparams.keys())
max_layer_id = max_layer_id - unlocked_groups
for group_idx in range(max_layer_id + 1):
group = gparams[group_idx]
for param in group:
self.trunk.get_parameter(param).requires_grad = False
if freeze_bn_stats:
gmodules = group_modules(self.trunk, matcher, reverse=True)
gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
freeze_batch_norm_2d(self.trunk, gmodules)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
try:
self.trunk.set_grad_checkpointing(enable)
except Exception as e:
logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')
def forward(self, x):
x = self.trunk(x)
x = self.head(x)
return x

View File

@@ -0,0 +1,201 @@
""" CLIP tokenizer
Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""
import gzip
import html
import os
from functools import lru_cache
from typing import Union, List
import ftfy
import regex as re
import torch
# https://stackoverflow.com/q/62691279
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@lru_cache()
def default_bpe():
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
@lru_cache()
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a signficant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8+n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r'\s+', ' ', text)
text = text.strip()
return text
class SimpleTokenizer(object):
def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
merges = merges[1:49152-256-2+1]
merges = [tuple(merge.split()) for merge in merges]
vocab = list(bytes_to_unicode().values())
vocab = vocab + [v+'</w>' for v in vocab]
for merge in merges:
vocab.append(''.join(merge))
if not special_tokens:
special_tokens = ['<start_of_text>', '<end_of_text>']
else:
special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
vocab.extend(special_tokens)
self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {t:t for t in special_tokens}
special = "|".join(special_tokens)
self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
self.vocab_size = len(self.encoder)
self.all_special_ids = [self.encoder[t] for t in special_tokens]
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
pairs = get_pairs(word)
if not pairs:
return token+'</w>'
while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def encode(self, text):
bpe_tokens = []
text = whitespace_clean(basic_clean(text)).lower()
for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
return text
_tokenizer = SimpleTokenizer()
def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
"""
Returns the tokenized representation of given input string(s)
Parameters
----------
texts : Union[str, List[str]]
An input string or a list of input strings to tokenize
context_length : int
The context length to use; all CLIP models use 77 as the context length
Returns
-------
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
"""
if isinstance(texts, str):
texts = [texts]
sot_token = _tokenizer.encoder["<start_of_text>"]
eot_token = _tokenizer.encoder["<end_of_text>"]
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
for i, tokens in enumerate(all_tokens):
if len(tokens) > context_length:
tokens = tokens[:context_length] # Truncate
tokens[-1] = eot_token
result[i, :len(tokens)] = torch.tensor(tokens)
return result
class HFTokenizer:
"HuggingFace tokenizer wrapper"
def __init__(self, tokenizer_name:str):
from transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
def __call__(self, texts:Union[str, List[str]], context_length:int=77) -> torch.Tensor:
# same cleaning as for default tokenizer, except lowercasing
# adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
if isinstance(texts, str):
texts = [texts]
texts = [whitespace_clean(basic_clean(text)) for text in texts]
input_ids = self.tokenizer(texts, return_tensors='pt', max_length=context_length, padding='max_length', truncation=True).input_ids
return input_ids

View File

@@ -0,0 +1,103 @@
from typing import Optional, Sequence, Tuple
import torch
import torch.nn as nn
import torchvision.transforms.functional as F
from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
CenterCrop
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
class ResizeMaxSize(nn.Module):
def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
super().__init__()
if not isinstance(max_size, int):
raise TypeError(f"Size should be int. Got {type(max_size)}")
self.max_size = max_size
self.interpolation = interpolation
self.fn = min if fn == 'min' else min
self.fill = fill
def forward(self, img):
if isinstance(img, torch.Tensor):
height, width = img.shape[:2]
else:
width, height = img.size
scale = self.max_size / float(max(height, width))
if scale != 1.0:
new_size = tuple(round(dim * scale) for dim in (height, width))
img = F.resize(img, new_size, self.interpolation)
pad_h = self.max_size - new_size[0]
pad_w = self.max_size - new_size[1]
img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)
return img
def _convert_to_rgb(image):
return image.convert('RGB')
# class CatGen(nn.Module):
# def __init__(self, num=4):
# self.num = num
# def mixgen_batch(image, text):
# batch_size = image.shape[0]
# index = np.random.permutation(batch_size)
# cat_images = []
# for i in range(batch_size):
# # image mixup
# image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:]
# # text concat
# text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0]
# text = torch.stack(text)
# return image, text
def image_transform(
image_size: int,
is_train: bool,
mean: Optional[Tuple[float, ...]] = None,
std: Optional[Tuple[float, ...]] = None,
resize_longest_max: bool = False,
fill_color: int = 0,
):
mean = mean or OPENAI_DATASET_MEAN
if not isinstance(mean, (list, tuple)):
mean = (mean,) * 3
std = std or OPENAI_DATASET_STD
if not isinstance(std, (list, tuple)):
std = (std,) * 3
if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
# for square size, pass size as int so that Resize() uses aspect preserving shortest edge
image_size = image_size[0]
normalize = Normalize(mean=mean, std=std)
if is_train:
return Compose([
RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
_convert_to_rgb,
ToTensor(),
normalize,
])
else:
if resize_longest_max:
transforms = [
ResizeMaxSize(image_size, fill=fill_color)
]
else:
transforms = [
Resize(image_size, interpolation=InterpolationMode.BICUBIC),
CenterCrop(image_size),
]
transforms.extend([
_convert_to_rgb,
ToTensor(),
normalize,
])
return Compose(transforms)

View File

@@ -0,0 +1,737 @@
import os
import logging
from collections import OrderedDict
import math
from typing import Callable, Optional, Sequence
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
try:
from timm.models.layers import trunc_normal_
except:
from timm.layers import trunc_normal_
from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
from .utils import to_2tuple
if os.getenv('ENV_TYPE') == 'deepspeed':
try:
import deepspeed
from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
except:
print("Please 'pip install deepspeed'")
deepspeed = None
from torch.utils.checkpoint import checkpoint
else:
from torch.utils.checkpoint import checkpoint
try:
import xformers.ops as xops
except ImportError:
xops = None
print("Please 'pip install xformers'")
class LayerNormFp32(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, x: torch.Tensor):
output = F.layer_norm(
x.float(),
self.normalized_shape,
self.weight.float() if self.weight is not None else None,
self.bias.float() if self.bias is not None else None,
self.eps,
)
return output.type_as(x)
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm (with cast back to input dtype)."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
return x.to(orig_type)
class QuickGELU(nn.Module):
# NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class LayerScale(nn.Module):
def __init__(self, dim, init_values=1e-5, inplace=False):
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x):
return x.mul_(self.gamma) if self.inplace else x * self.gamma
class PatchDropout(nn.Module):
"""
https://arxiv.org/abs/2212.00794
"""
def __init__(self, prob, exclude_first_token=True):
super().__init__()
assert 0 <= prob < 1.
self.prob = prob
self.exclude_first_token = exclude_first_token # exclude CLS token
logging.info(f"os.getenv('RoPE')={os.getenv('RoPE')}")
def forward(self, x):
if not self.training or self.prob == 0.:
return x
if self.exclude_first_token:
cls_tokens, x = x[:, :1], x[:, 1:]
else:
cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
batch = x.size()[0]
num_tokens = x.size()[1]
batch_indices = torch.arange(batch)
batch_indices = batch_indices[..., None]
keep_prob = 1 - self.prob
num_patches_keep = max(1, int(num_tokens * keep_prob))
rand = torch.randn(batch, num_tokens)
patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
x = x[batch_indices, patch_indices_keep]
if self.exclude_first_token:
x = torch.cat((cls_tokens, x), dim=1)
if self.training and os.getenv('RoPE') == '1':
return x, patch_indices_keep
return x
def _in_projection_packed(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
w: torch.Tensor,
b: Optional[torch.Tensor] = None,
):
"""
https://github.com/pytorch/pytorch/blob/db2a237763eb8693a20788be94f8c192e762baa8/torch/nn/functional.py#L4726
"""
E = q.size(-1)
if k is v:
if q is k:
# self-attention
return F.linear(q, w, b).chunk(3, dim=-1)
else:
# encoder-decoder attention
w_q, w_kv = w.split([E, E * 2])
if b is None:
b_q = b_kv = None
else:
b_q, b_kv = b.split([E, E * 2])
return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1)
else:
w_q, w_k, w_v = w.chunk(3)
if b is None:
b_q = b_k = b_v = None
else:
b_q, b_k, b_v = b.chunk(3)
return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
class Attention(nn.Module):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=True,
scaled_cosine=False,
scale_heads=False,
logit_scale_max=math.log(1. / 0.01),
attn_drop=0.,
proj_drop=0.,
xattn=False,
rope=False
):
super().__init__()
self.scaled_cosine = scaled_cosine
self.scale_heads = scale_heads
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.logit_scale_max = logit_scale_max
# keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
if qkv_bias:
self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
else:
self.in_proj_bias = None
if self.scaled_cosine:
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
else:
self.logit_scale = None
self.attn_drop = nn.Dropout(attn_drop)
if self.scale_heads:
self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
else:
self.head_scale = None
self.out_proj = nn.Linear(dim, dim)
self.out_drop = nn.Dropout(proj_drop)
self.xattn = xattn
self.xattn_drop = attn_drop
self.rope = rope
def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
L, N, C = x.shape
q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
if self.xattn:
q = q.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
k = k.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
v = v.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
x = xops.memory_efficient_attention(
q, k, v,
p=self.xattn_drop,
scale=self.scale if self.logit_scale is None else None,
attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None,
)
else:
q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
if self.logit_scale is not None:
attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
attn = attn.view(N, self.num_heads, L, L) * logit_scale
attn = attn.view(-1, L, L)
else:
q = q * self.scale
attn = torch.bmm(q, k.transpose(-1, -2))
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
new_attn_mask.masked_fill_(attn_mask, float("-inf"))
attn_mask = new_attn_mask
attn += attn_mask
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = torch.bmm(attn, v)
if self.head_scale is not None:
x = x.view(N, self.num_heads, L, C) * self.head_scale
x = x.view(-1, L, C)
x = x.transpose(0, 1).reshape(L, N, C)
x = self.out_proj(x)
x = self.out_drop(x)
return x
class CustomAttention(nn.Module):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=True,
scaled_cosine=True,
scale_heads=False,
logit_scale_max=math.log(1. / 0.01),
attn_drop=0.,
proj_drop=0.,
xattn=False
):
super().__init__()
self.scaled_cosine = scaled_cosine
self.scale_heads = scale_heads
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.logit_scale_max = logit_scale_max
# keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
if qkv_bias:
self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
else:
self.in_proj_bias = None
if self.scaled_cosine:
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
else:
self.logit_scale = None
self.attn_drop = nn.Dropout(attn_drop)
if self.scale_heads:
self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
else:
self.head_scale = None
self.out_proj = nn.Linear(dim, dim)
self.out_drop = nn.Dropout(proj_drop)
self.xattn = xattn
self.xattn_drop = attn_drop
def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
q, k, v = _in_projection_packed(query, key, value, self.in_proj_weight, self.in_proj_bias)
N_q, B_q, C_q = q.shape
N_k, B_k, C_k = k.shape
N_v, B_v, C_v = v.shape
if self.xattn:
# B, N, C -> B, N, num_heads, C
q = q.permute(1, 0, 2).reshape(B_q, N_q, self.num_heads, -1)
k = k.permute(1, 0, 2).reshape(B_k, N_k, self.num_heads, -1)
v = v.permute(1, 0, 2).reshape(B_v, N_v, self.num_heads, -1)
x = xops.memory_efficient_attention(
q, k, v,
p=self.xattn_drop,
scale=self.scale if self.logit_scale is None else None,
attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None
)
else:
# B*H, L, C
q = q.contiguous().view(N_q, B_q * self.num_heads, -1).transpose(0, 1)
k = k.contiguous().view(N_k, B_k * self.num_heads, -1).transpose(0, 1)
v = v.contiguous().view(N_v, B_v * self.num_heads, -1).transpose(0, 1)
if self.logit_scale is not None:
# B*H, N_q, N_k
attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
attn = attn.view(B_q, self.num_heads, N_q, N_k) * logit_scale
attn = attn.view(-1, N_q, N_k)
else:
q = q * self.scale
attn = torch.bmm(q, k.transpose(-1, -2))
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
new_attn_mask.masked_fill_(attn_mask, float("-inf"))
attn_mask = new_attn_mask
attn += attn_mask
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = torch.bmm(attn, v)
if self.head_scale is not None:
x = x.view(B_q, self.num_heads, N_q, C_q) * self.head_scale
x = x.view(-1, N_q, C_q)
x = x.transpose(0, 1).reshape(N_q, B_q, C_q)
x = self.out_proj(x)
x = self.out_drop(x)
return x
class CustomResidualAttentionBlock(nn.Module):
def __init__(
self,
d_model: int,
n_head: int,
mlp_ratio: float = 4.0,
ls_init_value: float = None,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
scale_cosine_attn: bool = False,
scale_heads: bool = False,
scale_attn: bool = False,
scale_fc: bool = False,
cross_attn: bool = False,
xattn: bool = False,
):
super().__init__()
self.ln_1 = norm_layer(d_model)
self.ln_1_k = norm_layer(d_model) if cross_attn else self.ln_1
self.ln_1_v = norm_layer(d_model) if cross_attn else self.ln_1
self.attn = CustomAttention(
d_model, n_head,
qkv_bias=True,
attn_drop=0.,
proj_drop=0.,
scaled_cosine=scale_cosine_attn,
scale_heads=scale_heads,
xattn=xattn
)
self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
self.ln_2 = norm_layer(d_model)
mlp_width = int(d_model * mlp_ratio)
self.mlp = nn.Sequential(OrderedDict([
("c_fc", nn.Linear(d_model, mlp_width)),
('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()),
("gelu", act_layer()),
("c_proj", nn.Linear(mlp_width, d_model))
]))
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
q = q + self.ls_1(self.ln_attn(self.attn(self.ln_1(q), self.ln_1_k(k), self.ln_1_v(v), attn_mask=attn_mask)))
q = q + self.ls_2(self.mlp(self.ln_2(q)))
return q
class CustomTransformer(nn.Module):
def __init__(
self,
width: int,
layers: int,
heads: int,
mlp_ratio: float = 4.0,
ls_init_value: float = None,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
scale_cosine_attn: bool = True,
scale_heads: bool = False,
scale_attn: bool = False,
scale_fc: bool = False,
cross_attn: bool = False,
xattn: bool = False,
):
super().__init__()
self.width = width
self.layers = layers
self.grad_checkpointing = False
self.xattn = xattn
self.resblocks = nn.ModuleList([
CustomResidualAttentionBlock(
width,
heads,
mlp_ratio,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
scale_cosine_attn=scale_cosine_attn,
scale_heads=scale_heads,
scale_attn=scale_attn,
scale_fc=scale_fc,
cross_attn=cross_attn,
xattn=xattn)
for _ in range(layers)
])
def get_cast_dtype(self) -> torch.dtype:
return self.resblocks[0].mlp.c_fc.weight.dtype
def forward(self, q: torch.Tensor, k: torch.Tensor = None, v: torch.Tensor = None, attn_mask: Optional[torch.Tensor] = None):
if k is None and v is None:
k = v = q
for r in self.resblocks:
if self.grad_checkpointing and not torch.jit.is_scripting():
q = checkpoint(r, q, k, v, attn_mask)
else:
q = r(q, k, v, attn_mask=attn_mask)
return q
class ResidualAttentionBlock(nn.Module):
def __init__(
self,
d_model: int,
n_head: int,
mlp_ratio: float = 4.0,
ls_init_value: float = None,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
xattn: bool = False,
):
super().__init__()
self.ln_1 = norm_layer(d_model)
if xattn:
self.attn = Attention(d_model, n_head, xattn=True)
else:
self.attn = nn.MultiheadAttention(d_model, n_head)
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
self.ln_2 = norm_layer(d_model)
mlp_width = int(d_model * mlp_ratio)
self.mlp = nn.Sequential(OrderedDict([
("c_fc", nn.Linear(d_model, mlp_width)),
("gelu", act_layer()),
("c_proj", nn.Linear(mlp_width, d_model))
]))
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
self.xattn = xattn
def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None
if self.xattn:
return self.attn(x, attn_mask=attn_mask)
return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
x = x + self.ls_1(self.attention(self.ln_1(x), attn_mask=attn_mask))
x = x + self.ls_2(self.mlp(self.ln_2(x)))
return x
class Transformer(nn.Module):
def __init__(
self,
width: int,
layers: int,
heads: int,
mlp_ratio: float = 4.0,
ls_init_value: float = None,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
xattn: bool = False,
):
super().__init__()
self.width = width
self.layers = layers
self.grad_checkpointing = False
self.resblocks = nn.ModuleList([
ResidualAttentionBlock(
width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, xattn=xattn)
for _ in range(layers)
])
def get_cast_dtype(self) -> torch.dtype:
return self.resblocks[0].mlp.c_fc.weight.dtype
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
for r in self.resblocks:
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint(r, x, attn_mask)
else:
x = r(x, attn_mask=attn_mask)
return x
class VisionTransformer(nn.Module):
def __init__(
self,
image_size: int,
patch_size: int,
width: int,
layers: int,
heads: int,
mlp_ratio: float,
ls_init_value: float = None,
patch_dropout: float = 0.,
global_average_pool: bool = False,
output_dim: int = 512,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
xattn: bool = False,
):
super().__init__()
self.image_size = to_2tuple(image_size)
self.patch_size = to_2tuple(patch_size)
self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1])
self.output_dim = output_dim
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
scale = width ** -0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
# setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
self.ln_pre = norm_layer(width)
self.transformer = Transformer(
width,
layers,
heads,
mlp_ratio,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
xattn=xattn
)
self.global_average_pool = global_average_pool
self.ln_post = norm_layer(width)
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
for param in self.parameters():
param.requires_grad = False
if unlocked_groups != 0:
groups = [
[
self.conv1,
self.class_embedding,
self.positional_embedding,
self.ln_pre,
],
*self.transformer.resblocks[:-1],
[
self.transformer.resblocks[-1],
self.ln_post,
],
self.proj,
]
def _unlock(x):
if isinstance(x, Sequence):
for g in x:
_unlock(g)
else:
if isinstance(x, torch.nn.Parameter):
x.requires_grad = True
else:
for p in x.parameters():
p.requires_grad = True
_unlock(groups[-unlocked_groups:])
def get_num_layers(self):
return self.transformer.layers
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.transformer.grad_checkpointing = enable
@torch.jit.ignore
def no_weight_decay(self):
return {'positional_embedding', 'class_embedding'}
def forward(self, x: torch.Tensor, return_all_features: bool=False):
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x = torch.cat(
[self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
x], dim=1) # shape = [*, grid ** 2 + 1, width]
x = x + self.positional_embedding.to(x.dtype)
# a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
x = self.patch_dropout(x)
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
if not return_all_features:
if self.global_average_pool:
x = x.mean(dim=1) #x = x[:,1:,:].mean(dim=1)
else:
x = x[:, 0]
x = self.ln_post(x)
if self.proj is not None:
x = x @ self.proj
return x
class TextTransformer(nn.Module):
def __init__(
self,
context_length: int = 77,
vocab_size: int = 49408,
width: int = 512,
heads: int = 8,
layers: int = 12,
ls_init_value: float = None,
output_dim: int = 512,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
xattn: bool= False,
attn_mask: bool = True
):
super().__init__()
self.context_length = context_length
self.vocab_size = vocab_size
self.width = width
self.output_dim = output_dim
self.token_embedding = nn.Embedding(vocab_size, width)
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width))
self.transformer = Transformer(
width=width,
layers=layers,
heads=heads,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
xattn=xattn
)
self.xattn = xattn
self.ln_final = norm_layer(width)
self.text_projection = nn.Parameter(torch.empty(width, output_dim))
if attn_mask:
self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
else:
self.attn_mask = None
self.init_parameters()
def init_parameters(self):
nn.init.normal_(self.token_embedding.weight, std=0.02)
nn.init.normal_(self.positional_embedding, std=0.01)
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
attn_std = self.transformer.width ** -0.5
fc_std = (2 * self.transformer.width) ** -0.5
for block in self.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
if self.text_projection is not None:
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.transformer.grad_checkpointing = enable
@torch.jit.ignore
def no_weight_decay(self):
# return {'positional_embedding', 'token_embedding'}
return {'positional_embedding'}
def get_num_layers(self):
return self.transformer.layers
def build_attention_mask(self):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(self.context_length, self.context_length)
mask.fill_(float("-inf"))
mask.triu_(1) # zero out the lower diagonal
return mask
def forward(self, text, return_all_features: bool=False):
cast_dtype = self.transformer.get_cast_dtype()
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding.to(cast_dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x, attn_mask=self.attn_mask)
# x = self.transformer(x) # no attention mask is applied
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x)
if not return_all_features:
# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
return x

View File

@@ -0,0 +1,326 @@
from itertools import repeat
import collections.abc
import logging
import math
import numpy as np
import torch
from torch import nn as nn
from torchvision.ops.misc import FrozenBatchNorm2d
import torch.nn.functional as F
# open CLIP
def resize_clip_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
# Rescale the grid of position embeddings when loading from state_dict
old_pos_embed = state_dict.get('visual.positional_embedding', None)
if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
return
grid_size = to_2tuple(model.visual.grid_size)
extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
if new_seq_len == old_pos_embed.shape[0]:
return
if extra_tokens:
pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
else:
pos_emb_tok, pos_emb_img = None, old_pos_embed
old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
pos_emb_img = F.interpolate(
pos_emb_img,
size=grid_size,
mode=interpolation,
align_corners=True,
)
pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
if pos_emb_tok is not None:
new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
else:
new_pos_embed = pos_emb_img
state_dict['visual.positional_embedding'] = new_pos_embed
def resize_visual_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
# Rescale the grid of position embeddings when loading from state_dict
old_pos_embed = state_dict.get('positional_embedding', None)
if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
return
grid_size = to_2tuple(model.visual.grid_size)
extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
if new_seq_len == old_pos_embed.shape[0]:
return
if extra_tokens:
pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
else:
pos_emb_tok, pos_emb_img = None, old_pos_embed
old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
pos_emb_img = F.interpolate(
pos_emb_img,
size=grid_size,
mode=interpolation,
align_corners=True,
)
pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
if pos_emb_tok is not None:
new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
else:
new_pos_embed = pos_emb_img
state_dict['positional_embedding'] = new_pos_embed
def resize_evaclip_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
all_keys = list(state_dict.keys())
# interpolate position embedding
if 'visual.pos_embed' in state_dict:
pos_embed_checkpoint = state_dict['visual.pos_embed']
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = model.visual.patch_embed.num_patches
num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
# height (== width) for the new position embedding
new_size = int(num_patches ** 0.5)
# class_token and dist_token are kept unchanged
if orig_size != new_size:
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
state_dict['visual.pos_embed'] = new_pos_embed
patch_embed_proj = state_dict['visual.patch_embed.proj.weight']
patch_size = model.visual.patch_embed.patch_size
state_dict['visual.patch_embed.proj.weight'] = torch.nn.functional.interpolate(
patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False)
def resize_eva_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
all_keys = list(state_dict.keys())
# interpolate position embedding
if 'pos_embed' in state_dict:
pos_embed_checkpoint = state_dict['pos_embed']
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = model.visual.patch_embed.num_patches
num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
# height (== width) for the new position embedding
new_size = int(num_patches ** 0.5)
# class_token and dist_token are kept unchanged
if orig_size != new_size:
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
state_dict['pos_embed'] = new_pos_embed
patch_embed_proj = state_dict['patch_embed.proj.weight']
patch_size = model.visual.patch_embed.patch_size
state_dict['patch_embed.proj.weight'] = torch.nn.functional.interpolate(
patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False)
def resize_rel_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
all_keys = list(state_dict.keys())
for key in all_keys:
if "relative_position_index" in key:
state_dict.pop(key)
if "relative_position_bias_table" in key:
rel_pos_bias = state_dict[key]
src_num_pos, num_attn_heads = rel_pos_bias.size()
dst_num_pos, _ = model.visual.state_dict()[key].size()
dst_patch_shape = model.visual.patch_embed.patch_shape
if dst_patch_shape[0] != dst_patch_shape[1]:
raise NotImplementedError()
num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
if src_size != dst_size:
print("Position interpolate for %s from %dx%d to %dx%d" % (
key, src_size, src_size, dst_size, dst_size))
extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
def geometric_progression(a, r, n):
return a * (1.0 - r ** n) / (1.0 - r)
left, right = 1.01, 1.5
while right - left > 1e-6:
q = (left + right) / 2.0
gp = geometric_progression(1, q, src_size // 2)
if gp > dst_size // 2:
right = q
else:
left = q
# if q > 1.090307:
# q = 1.090307
dis = []
cur = 1
for i in range(src_size // 2):
dis.append(cur)
cur += q ** (i + 1)
r_ids = [-_ for _ in reversed(dis)]
x = r_ids + [0] + dis
y = r_ids + [0] + dis
t = dst_size // 2.0
dx = np.arange(-t, t + 0.1, 1.0)
dy = np.arange(-t, t + 0.1, 1.0)
print("Original positions = %s" % str(x))
print("Target positions = %s" % str(dx))
all_rel_pos_bias = []
for i in range(num_attn_heads):
z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
f = F.interpolate.interp2d(x, y, z, kind='cubic')
all_rel_pos_bias.append(
torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))
rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
state_dict[key] = new_rel_pos_bias
# interpolate position embedding
if 'pos_embed' in state_dict:
pos_embed_checkpoint = state_dict['pos_embed']
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = model.visual.patch_embed.num_patches
num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
# height (== width) for the new position embedding
new_size = int(num_patches ** 0.5)
# class_token and dist_token are kept unchanged
if orig_size != new_size:
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
state_dict['pos_embed'] = new_pos_embed
patch_embed_proj = state_dict['patch_embed.proj.weight']
patch_size = model.visual.patch_embed.patch_size
state_dict['patch_embed.proj.weight'] = torch.nn.functional.interpolate(
patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False)
def freeze_batch_norm_2d(module, module_match={}, name=''):
"""
Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
returned. Otherwise, the module is walked recursively and submodules are converted in place.
Args:
module (torch.nn.Module): Any PyTorch module.
module_match (dict): Dictionary of full module names to freeze (all if empty)
name (str): Full module name (prefix)
Returns:
torch.nn.Module: Resulting module
Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
"""
res = module
is_match = True
if module_match:
is_match = name in module_match
if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
res = FrozenBatchNorm2d(module.num_features)
res.num_features = module.num_features
res.affine = module.affine
if module.affine:
res.weight.data = module.weight.data.clone().detach()
res.bias.data = module.bias.data.clone().detach()
res.running_mean.data = module.running_mean.data
res.running_var.data = module.running_var.data
res.eps = module.eps
else:
for child_name, child in module.named_children():
full_child_name = '.'.join([name, child_name]) if name else child_name
new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
if new_child is not child:
res.add_module(child_name, new_child)
return res
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = lambda n, x: _ntuple(n)(x)
def is_logging(args):
def is_global_master(args):
return args.rank == 0
def is_local_master(args):
return args.local_rank == 0
def is_master(args, local=False):
return is_local_master(args) if local else is_global_master(args)
return is_master
class AllGather(torch.autograd.Function):
"""An autograd function that performs allgather on a tensor.
Performs all_gather operation on the provided tensors.
*** Warning ***: torch.distributed.all_gather has no gradient.
"""
@staticmethod
def forward(ctx, tensor, rank, world_size):
tensors_gather = [torch.empty_like(tensor) for _ in range(world_size)]
torch.distributed.all_gather(tensors_gather, tensor)
ctx.rank = rank
ctx.batch_size = tensor.shape[0]
return torch.cat(tensors_gather, 0)
@staticmethod
def backward(ctx, grad_output):
return (
grad_output[ctx.batch_size * ctx.rank: ctx.batch_size * (ctx.rank + 1)],
None,
None
)
allgather = AllGather.apply

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.1 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.1 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.7 MiB

View File

@@ -0,0 +1,456 @@
import cv2
import numpy as np
import os
import torch
from torchvision.transforms.functional import normalize
from torchvision.ops import box_iou
from facexlib.detection import init_detection_model
from facexlib.parsing import init_parsing_model
from facexlib.utils.misc import img2tensor, imwrite
def get_face_by_index(det_faces, face_sort_rule, face_index=0):
if det_faces is None:
return None, None
has_bbox_attr = hasattr(det_faces[0], 'bbox')
# 创建带索引的列表 [(original_index, face), ...]
indexed_faces = list(enumerate(det_faces))
# 定义排序规则
if face_sort_rule == 'left-right':
sorted_faces = sorted(indexed_faces, key=lambda x: x[1].bbox[0] if has_bbox_attr else x[1][0])
elif face_sort_rule == "right-left":
sorted_faces = sorted(indexed_faces, key=lambda x: x[1].bbox[0] if has_bbox_attr else x[1][0], reverse=True)
elif face_sort_rule == "top-bottom":
sorted_faces = sorted(indexed_faces, key=lambda x: x[1].bbox[1] if has_bbox_attr else x[1][1])
elif face_sort_rule == "bottom-top":
sorted_faces = sorted(indexed_faces, key=lambda x: x[1].bbox[1] if has_bbox_attr else x[1][1], reverse=True)
elif face_sort_rule == "small-large":
sorted_faces = sorted(indexed_faces, key=lambda x: (x[1].bbox[2] - x[1].bbox[0]) * (x[1].bbox[3] - x[1].bbox[1]) if has_bbox_attr else (x[1][2] - x[1][0]) * (x[1][3] - x[1][1]))
elif face_sort_rule == "large-small":
sorted_faces = sorted(indexed_faces, key=lambda x: (x[1].bbox[2] - x[1].bbox[0]) * (x[1].bbox[3] - x[1].bbox[1]) if has_bbox_attr else (x[1][2] - x[1][0]) * (x[1][3] - x[1][1]), reverse=True)
else:
sorted_faces = indexed_faces
# 返回原始索引
if not 0 <= face_index < len(sorted_faces):
# 返回第一个
face_index = 0
# 返回选择的脸部、原始索引值和排序后的列表
return sorted_faces[face_index][1], sorted_faces[face_index][0], [face[1] for face in sorted_faces]
def get_largest_face(det_faces, h, w):
def get_location(val, length):
if val < 0:
return 0
elif val > length:
return length
else:
return val
face_areas = []
for det_face in det_faces:
left = get_location(det_face[0], w)
right = get_location(det_face[2], w)
top = get_location(det_face[1], h)
bottom = get_location(det_face[3], h)
face_area = (right - left) * (bottom - top)
face_areas.append(face_area)
largest_idx = face_areas.index(max(face_areas))
return det_faces[largest_idx], largest_idx
def get_center_face(det_faces, h=0, w=0, center=None):
if center is not None:
center = np.array(center)
else:
center = np.array([w / 2, h / 2])
center_dist = []
for det_face in det_faces:
face_center = np.array([(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2])
dist = np.linalg.norm(face_center - center)
center_dist.append(dist)
center_idx = center_dist.index(min(center_dist))
return det_faces[center_idx], center_idx
class FaceRestoreHelper(object):
"""Helper for the face restoration pipeline (base class)."""
def __init__(self,
upscale_factor,
face_size=512,
crop_ratio=(1, 1),
det_model='retinaface_resnet50',
parsing_model='bisenet',
save_ext='png',
template_3points=False,
pad_blur=False,
use_parse=False,
device=None,
model_rootpath=None):
self.template_3points = template_3points # improve robustness
self.upscale_factor = upscale_factor
# the cropped face ratio based on the square face
self.crop_ratio = crop_ratio # (h, w)
assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ration only supports >=1'
self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
if self.template_3points:
self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
else:
# standard 5 landmarks for FFHQ faces with 512 x 512
self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935],
[201.26117, 371.41043], [313.08905, 371.15118]])
self.face_template = self.face_template * (face_size / 512.0)
if self.crop_ratio[0] > 1:
self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
if self.crop_ratio[1] > 1:
self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2
self.save_ext = save_ext
self.pad_blur = pad_blur
if self.pad_blur is True:
self.template_3points = False
self.all_landmarks_5 = []
self.det_faces = []
self.affine_matrices = []
self.inverse_affine_matrices = []
self.cropped_faces = []
self.restored_faces = []
self.pad_input_imgs = []
if device is None:
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
self.device = device
# init face detection model
self.face_det = init_detection_model(det_model, half=False, device=self.device, model_rootpath=model_rootpath)
# init face parsing model
self.use_parse = use_parse
self.face_parse = init_parsing_model(model_name=parsing_model, device=self.device, model_rootpath=model_rootpath)
def set_upscale_factor(self, upscale_factor):
self.upscale_factor = upscale_factor
def read_image(self, img):
"""img can be image path or cv2 loaded image."""
# self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255]
if isinstance(img, str):
img = cv2.imread(img)
if np.max(img) > 256: # 16-bit image
img = img / 65535 * 255
if len(img.shape) == 2: # gray image
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
elif img.shape[2] == 4: # RGBA image with alpha channel
img = img[:, :, 0:3]
self.input_img = img
def get_face_landmarks_5(self,
only_keep_largest=False,
only_center_face=False,
resize=None,
blur_ratio=0.01,
eye_dist_threshold=None,
face_sort_rule=None,
ref_sort_bboxes=None,
face_index=None):
if resize is None:
scale = 1
input_img = self.input_img
else:
h, w = self.input_img.shape[0:2]
scale = min(h, w) / resize
h, w = int(h / scale), int(w / scale)
input_img = cv2.resize(self.input_img, (w, h), interpolation=cv2.INTER_LANCZOS4)
with torch.no_grad():
# use 0.5 (old value is 0.97), keep consistent with Insightface, but still cannot ensure consistent quantity of bboxes.
bboxes = self.face_det.detect_faces(input_img, 0.5) * scale
for bbox in bboxes:
# remove faces with too small eye distance: side faces or too small faces
eye_dist = np.linalg.norm([bbox[5] - bbox[7], bbox[6] - bbox[8]])
if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
continue
if self.template_3points:
landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)])
else:
landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)])
self.all_landmarks_5.append(landmark)
self.det_faces.append(bbox[0:5])
if len(self.det_faces) == 0:
return 0
if ref_sort_bboxes is not None:
if len(self.det_faces) != len(ref_sort_bboxes):
return 0
iou = box_iou(torch.tensor(ref_sort_bboxes), torch.tensor(self.det_faces)[:, :4])
indices = torch.max(iou, dim=1).indices
self.det_faces = [self.det_faces[idx] for idx in indices]
if face_index is None or face_index >= len(self.det_faces):
face_index = 0
self.all_landmarks_5 = [self.all_landmarks_5[indices[face_index]]]
elif face_sort_rule is not None:
self.det_faces, center_idx, _ = get_face_by_index(self.det_faces, face_sort_rule=face_sort_rule, face_index=face_index)
self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]
elif only_keep_largest:
h, w, _ = self.input_img.shape
self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w)
self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
elif only_center_face:
h, w, _ = self.input_img.shape
self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]
# pad blurry images
if self.pad_blur:
self.pad_input_imgs = []
for landmarks in self.all_landmarks_5:
# get landmarks
eye_left = landmarks[0, :]
eye_right = landmarks[1, :]
eye_avg = (eye_left + eye_right) * 0.5
mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
eye_to_eye = eye_right - eye_left
eye_to_mouth = mouth_avg - eye_avg
# Get the oriented crop rectangle
# x: half width of the oriented crop rectangle
x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
# - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
# norm with the hypotenuse: get the direction
x /= np.hypot(*x) # get the hypotenuse of a right triangle
rect_scale = 1.5
x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
# y: half height of the oriented crop rectangle
y = np.flipud(x) * [-1, 1]
# c: center
c = eye_avg + eye_to_mouth * 0.1
# quad: (left_top, left_bottom, right_bottom, right_top)
quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
# qsize: side length of the square
qsize = np.hypot(*x) * 2
border = max(int(np.rint(qsize * 0.1)), 3)
# get pad
# pad: (width_left, height_top, width_right, height_bottom)
pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
int(np.ceil(max(quad[:, 1]))))
pad = [
max(-pad[0] + border, 1),
max(-pad[1] + border, 1),
max(pad[2] - self.input_img.shape[0] + border, 1),
max(pad[3] - self.input_img.shape[1] + border, 1)
]
if max(pad) > 1:
# pad image
pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
# modify landmark coords
landmarks[:, 0] += pad[0]
landmarks[:, 1] += pad[1]
# blur pad images
h, w, _ = pad_img.shape
y, x, _ = np.ogrid[:h, :w, :1]
mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
np.float32(w - 1 - x) / pad[2]),
1.0 - np.minimum(np.float32(y) / pad[1],
np.float32(h - 1 - y) / pad[3]))
blur = int(qsize * blur_ratio)
if blur % 2 == 0:
blur += 1
blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))
# blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0)
pad_img = pad_img.astype('float32')
pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0)
pad_img = np.clip(pad_img, 0, 255) # float32, [0, 255]
self.pad_input_imgs.append(pad_img)
else:
self.pad_input_imgs.append(np.copy(self.input_img))
return len(self.all_landmarks_5)
def align_warp_face(self, save_cropped_path=None, border_mode='constant'):
"""Align and warp faces with face template.
"""
if self.pad_blur:
assert len(self.pad_input_imgs) == len(
self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}'
for idx, landmark in enumerate(self.all_landmarks_5):
# use 5 landmarks to get affine matrix
# use cv2.LMEDS method for the equivalence to skimage transform
# ref: https://blog.csdn.net/yichxi/article/details/115827338
affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0]
self.affine_matrices.append(affine_matrix)
# warp and crop faces
if border_mode == 'constant':
border_mode = cv2.BORDER_CONSTANT
elif border_mode == 'reflect101':
border_mode = cv2.BORDER_REFLECT101
elif border_mode == 'reflect':
border_mode = cv2.BORDER_REFLECT
if self.pad_blur:
input_img = self.pad_input_imgs[idx]
else:
input_img = self.input_img
cropped_face = cv2.warpAffine(
input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132)) # gray
self.cropped_faces.append(cropped_face)
# save the cropped face
if save_cropped_path is not None:
path = os.path.splitext(save_cropped_path)[0]
save_path = f'{path}_{idx:02d}.{self.save_ext}'
imwrite(cropped_face, save_path)
def get_inverse_affine(self, save_inverse_affine_path=None):
"""Get inverse affine matrix."""
for idx, affine_matrix in enumerate(self.affine_matrices):
inverse_affine = cv2.invertAffineTransform(affine_matrix)
inverse_affine *= self.upscale_factor
self.inverse_affine_matrices.append(inverse_affine)
# save inverse affine matrices
if save_inverse_affine_path is not None:
path, _ = os.path.splitext(save_inverse_affine_path)
save_path = f'{path}_{idx:02d}.pth'
torch.save(inverse_affine, save_path)
def add_restored_face(self, face):
self.restored_faces.append(face)
def paste_faces_to_input_image(self, save_path=None, upsample_img=None):
h, w, _ = self.input_img.shape
h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)
if upsample_img is None:
# simply resize the background
upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
else:
upsample_img = cv2.resize(upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
assert len(self.restored_faces) == len(
self.inverse_affine_matrices), ('length of restored_faces and affine_matrices are different.')
for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices):
# Add an offset to inverse affine matrix, for more precise back alignment
if self.upscale_factor > 1:
extra_offset = 0.5 * self.upscale_factor
else:
extra_offset = 0
inverse_affine[:, 2] += extra_offset
inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))
if self.use_parse:
# inference
face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR)
face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True)
normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
face_input = torch.unsqueeze(face_input, 0).to(self.device)
with torch.no_grad():
out = self.face_parse(face_input)[0]
out = out.argmax(dim=1).squeeze().cpu().numpy()
mask = np.zeros(out.shape)
MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
for idx, color in enumerate(MASK_COLORMAP):
mask[out == idx] = color
# blur the mask
mask = cv2.GaussianBlur(mask, (101, 101), 11)
mask = cv2.GaussianBlur(mask, (101, 101), 11)
# remove the black borders
thres = 10
mask[:thres, :] = 0
mask[-thres:, :] = 0
mask[:, :thres] = 0
mask[:, -thres:] = 0
mask = mask / 255.
mask = cv2.resize(mask, restored_face.shape[:2])
mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up), flags=3)
inv_soft_mask = mask[:, :, None]
pasted_face = inv_restored
else: # use square parse maps
mask = np.ones(self.face_size, dtype=np.float32)
inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
# remove the black borders
inv_mask_erosion = cv2.erode(
inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
pasted_face = inv_mask_erosion[:, :, None] * inv_restored
total_face_area = np.sum(inv_mask_erosion) # // 3
# compute the fusion edge based on the area of face
w_edge = int(total_face_area**0.5) // 20
erosion_radius = w_edge * 2
inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
blur_size = w_edge * 2
inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
if len(upsample_img.shape) == 2: # upsample_img is gray image
upsample_img = upsample_img[:, :, None]
inv_soft_mask = inv_soft_mask[:, :, None]
if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4: # alpha channel
alpha = upsample_img[:, :, 3:]
upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
upsample_img = np.concatenate((upsample_img, alpha), axis=2)
else:
upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img
if np.max(upsample_img) > 256: # 16-bit image
upsample_img = upsample_img.astype(np.uint16)
else:
upsample_img = upsample_img.astype(np.uint8)
if save_path is not None:
path = os.path.splitext(save_path)[0]
save_path = f'{path}.{self.save_ext}'
imwrite(upsample_img, save_path)
return upsample_img
def clean_all(self):
self.all_landmarks_5 = []
self.restored_faces = []
self.affine_matrices = []
self.cropped_faces = []
self.inverse_affine_matrices = []
self.det_faces = []
self.pad_input_imgs = []
def draw_on(img, faces):
dimg = img.copy()
for i in range(len(faces)):
face = faces[i]
box = face.bbox.astype(np.int32)
color = (0, 0, 255)
cv2.rectangle(dimg, (box[0], box[1]), (box[2], box[3]), color, 2)
if face.kps is not None:
kps = face.kps.astype(np.int32)
#print(landmark.shape)
for l in range(kps.shape[0]):
color = (0, 0, 255)
if l == 0 or l == 3:
color = (0, 255, 0)
cv2.circle(dimg, (kps[l][0], kps[l][1]), 1, color,
2)
cv2.putText(dimg,'index: %d'%i, (box[0]-1, box[1]-4),cv2.FONT_HERSHEY_COMPLEX,0.7,(0,255,0),1)
# if face.gender is not None and face.age is not None:
# cv2.putText(dimg,'%s,%d'%(face.sex,face.age), (box[0]-1, box[1]-4),cv2.FONT_HERSHEY_COMPLEX,0.7,(0,255,0),1)
#for key, value in face.items():
# if key.startswith('landmark_3d'):
# print(key, value.shape)
# print(value[0:10,:])
# lmk = np.round(value).astype(np.int)
# for l in range(lmk.shape[0]):
# color = (255, 0, 0)
# cv2.circle(dimg, (lmk[l][0], lmk[l][1]), 1, color,
# 2)
return dimg

View File

@@ -0,0 +1,41 @@
class PatchKeys:
################## transformer_options patches ##################
running_net_model = "running_net_model"
options_key = "patches_point"
# patches_point下支持设置的补丁
dit_enter = "patch_dit_enter"
dit_blocks_before = "patch_dit_blocks_before"
dit_double_blocks_replace = "patch_dit_double_blocks_replace"
dit_double_blocks_after = "patch_dit_double_blocks_after"
dit_blocks_transition_replace = "patch_dit_blocks_transition_replace"
dit_single_blocks_before = "patch_dit_single_blocks_before"
dit_single_blocks_replace = "patch_dit_single_blocks_replace"
dit_blocks_after = "patch_dit_blocks_after"
dit_blocks_after_transition_replace = "patch_dit_final_layer_before_replace"
dit_final_layer_before = "patch_dit_final_layer_before"
dit_exit = "patch_dit_exit"
################## transformer_options patches ##################
# pulid
pulid_patch_key_attrs = "pulid_temp_attr"
def set_model_patch(model_patcher, options_key, patch, name):
to = model_patcher.model_options["transformer_options"]
if options_key not in to:
to[options_key] = {}
to[options_key][name] = to[options_key].get(name, []) + [patch]
def set_model_patch_replace(model_patcher, options_key, patch, name):
to = model_patcher.model_options["transformer_options"]
if options_key not in to:
to[options_key] = {}
to[options_key][name] = patch
def add_model_patch_option(model, patch_key):
if 'transformer_options' not in model.model_options:
model.model_options['transformer_options'] = {}
to = model.model_options['transformer_options']
if patch_key not in to:
to[patch_key] = {}
return to[patch_key]

View File

@@ -0,0 +1,878 @@
import types
import zipfile
import cv2
import torch
from insightface.utils.download import download_file
from insightface.utils.storage import BASE_REPO_URL
from insightface.utils import face_align
from torch import nn
from torchvision import transforms
from torchvision.transforms import functional
import os
import logging
import folder_paths
import comfy
from insightface.app import FaceAnalysis
from .face_restoration_helper import FaceRestoreHelper, get_face_by_index, draw_on
from comfy import model_management
from .eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .encoders_flux import IDFormer, PerceiverAttentionCA
from .PulidFluxHook import pulid_forward_orig, set_model_dit_patch_replace, pulid_enter, pulid_patch_double_blocks_after
from .patch_util import PatchKeys, add_model_patch_option, set_model_patch
# facenet implementation
import numpy as np
from PIL import Image
from facenet_pytorch import MTCNN, InceptionResnetV1
def tensor2pil(image):
return Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
def set_extra_config_model_path(extra_config_models_dir_key, models_dir_name:str):
models_dir_default = os.path.join(folder_paths.models_dir, models_dir_name)
if extra_config_models_dir_key not in folder_paths.folder_names_and_paths:
folder_paths.folder_names_and_paths[extra_config_models_dir_key] = (
[os.path.join(folder_paths.models_dir, models_dir_name)], folder_paths.supported_pt_extensions)
else:
if not os.path.exists(models_dir_default):
os.makedirs(models_dir_default, exist_ok=True)
folder_paths.add_model_folder_path(extra_config_models_dir_key, models_dir_default, is_default=True)
set_extra_config_model_path("pulid", "pulid")
set_extra_config_model_path("insightface", "insightface")
set_extra_config_model_path("facexlib", "facexlib")
INSIGHTFACE_DIR = folder_paths.get_folder_paths("insightface")[0]
FACEXLIB_DIR = folder_paths.get_folder_paths("facexlib")[0]
#FACENET_DIR = folder_paths.get_folder_paths("facenet")[0]
# MODELS_DIR = os.path.join(folder_paths.models_dir, "pulid")
# if "pulid" not in folder_paths.folder_names_and_paths:
# current_paths = [MODELS_DIR]
# else:
# current_paths, _ = folder_paths.folder_names_and_paths["pulid"]
# folder_paths.folder_names_and_paths["pulid"] = (current_paths, folder_paths.supported_pt_extensions)
class PulidFluxModel(nn.Module):
def __init__(self):
super().__init__()
self.double_interval = 2
self.single_interval = 4
# Init encoder
self.pulid_encoder = IDFormer()
# Init attention
num_ca = 19 // self.double_interval + 38 // self.single_interval
if 19 % self.double_interval != 0:
num_ca += 1
if 38 % self.single_interval != 0:
num_ca += 1
self.pulid_ca = nn.ModuleList([
PerceiverAttentionCA() for _ in range(num_ca)
])
def from_pretrained(self, path: str):
state_dict = comfy.utils.load_torch_file(path, safe_load=True)
state_dict_dict = {}
for k, v in state_dict.items():
module = k.split('.')[0]
state_dict_dict.setdefault(module, {})
new_k = k[len(module) + 1:]
state_dict_dict[module][new_k] = v
for module in state_dict_dict:
getattr(self, module).load_state_dict(state_dict_dict[module], strict=True)
del state_dict
del state_dict_dict
def get_embeds(self, face_embed, clip_embeds):
return self.pulid_encoder(face_embed, clip_embeds)
def tensor_to_image(tensor):
image = tensor.mul(255).clamp(0, 255).byte().cpu()
image = image[..., [2, 1, 0]].numpy()
return image
def image_to_tensor(image):
tensor = torch.clamp(torch.from_numpy(image).float() / 255., 0, 1)
tensor = tensor[..., [2, 1, 0]]
return tensor
def to_gray(img):
x = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3]
x = x.repeat(1, 3, 1, 1)
return x
"""
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Nodes
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
"""
wrappers_name = "PULID_wrappers"
class PulidFluxModelLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": {"pulid_file": (folder_paths.get_filename_list("pulid"), )}}
RETURN_TYPES = ("PULIDFLUX",)
FUNCTION = "load_model"
CATEGORY = "pulid"
def load_model(self, pulid_file):
model_path = folder_paths.get_full_path("pulid", pulid_file)
# Also initialize the model, takes longer to load but then it doesn't have to be done every time you change parameters in the apply node
offload_device = model_management.unet_offload_device()
load_device = model_management.get_torch_device()
model = PulidFluxModel()
logging.info("Loading PuLID-Flux model.")
model.from_pretrained(path=model_path)
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
del model
return (model_patcher,)
def download_insightface_model(sub_dir, name, force=False, root='~/.insightface'):
# Copied and modified from insightface.utils.storage.download
# Solve https://github.com/deepinsight/insightface/issues/2711
_root = os.path.expanduser(root)
dir_path = os.path.join(_root, sub_dir, name)
if os.path.exists(dir_path) and not force:
return dir_path
print('download_path:', dir_path)
zip_file_path = os.path.join(_root, sub_dir, name + '.zip')
model_url = "%s/%s.zip"%(BASE_REPO_URL, name)
download_file(model_url,
path=zip_file_path,
overwrite=True)
if not os.path.exists(dir_path):
os.makedirs(dir_path)
# zip file has contains ${name}
real_dir_path = os.path.join(_root, sub_dir)
with zipfile.ZipFile(zip_file_path) as zf:
zf.extractall(real_dir_path)
#os.remove(zip_file_path)
return dir_path
class PulidFluxInsightFaceLoader:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"provider": (["CPU", "CUDA", "ROCM"], ),
},
}
RETURN_TYPES = ("FACEANALYSIS",)
FUNCTION = "load_insightface"
CATEGORY = "pulid"
def load_insightface(self, provider):
name = "antelopev2"
download_insightface_model("models", name, root=INSIGHTFACE_DIR)
model = FaceAnalysis(name=name, root=INSIGHTFACE_DIR, providers=[provider + 'ExecutionProvider', ]) # alternative to buffalo_l
model.prepare(ctx_id=0, det_size=(640, 640))
return (model,)
class PulidFluxEvaClipLoader:
@classmethod
def INPUT_TYPES(s):
return {
"required": {},
}
RETURN_TYPES = ("EVA_CLIP",)
FUNCTION = "load_eva_clip"
CATEGORY = "pulid"
def load_eva_clip(self):
from .eva_clip.factory import create_model_and_transforms
clip_file_path = folder_paths.get_full_path("text_encoders", 'EVA02_CLIP_L_336_psz14_s6B.pt')
if clip_file_path is None:
clip_dir = os.path.join(folder_paths.models_dir, "clip")
else:
clip_dir = os.path.dirname(clip_file_path)
model, _, _ = create_model_and_transforms('EVA02-CLIP-L-14-336', 'eva_clip', force_custom_clip=True, local_dir=clip_dir)
model = model.visual
eva_transform_mean = getattr(model, 'image_mean', OPENAI_DATASET_MEAN)
eva_transform_std = getattr(model, 'image_std', OPENAI_DATASET_STD)
if not isinstance(eva_transform_mean, (list, tuple)):
model["image_mean"] = (eva_transform_mean,) * 3
if not isinstance(eva_transform_std, (list, tuple)):
model["image_std"] = (eva_transform_std,) * 3
return (model,)
class ApplyPulidFlux:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL", ),
"pulid_flux": ("PULIDFLUX", ),
"eva_clip": ("EVA_CLIP", ),
"face_analysis": ("FACEANALYSIS", ),
"image": ("IMAGE", ),
"weight": ("FLOAT", {"default": 1.0, "min": -1.0, "max": 5.0, "step": 0.05 }),
"start_at": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001 }),
"end_at": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001 }),
},
"optional": {
"attn_mask": ("MASK", ),
"options": ("OPTIONS",),
},
"hidden": {
"unique_id": "UNIQUE_ID"
},
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "apply_pulid_flux"
CATEGORY = "pulid"
def apply_pulid_flux(self, model, pulid_flux, eva_clip, face_analysis, image, weight, start_at, end_at, attn_mask=None, options={}, unique_id=None):
model = model.clone()
device = comfy.model_management.get_torch_device()
# Why should I care what args say, when the unet model has a different dtype?!
# Am I missing something?!
#dtype = comfy.model_management.unet_dtype()
dtype = model.model.diffusion_model.dtype
# Because of 8bit models we must check what cast type does the unet uses
# ZLUDA (Intel, AMD) & GPUs with compute capability < 8.0 don't support bfloat16 etc.
# Issue: https://github.com/balazik/ComfyUI-PuLID-Flux/issues/6
if model.model.manual_cast_dtype is not None:
dtype = model.model.manual_cast_dtype
eva_clip.to(device, dtype=dtype)
pulid_flux.model.to(dtype=dtype)
model_management.load_models_gpu([pulid_flux], force_full_load=True)
# model_management.load_model_gpu(pulid_flux)
if attn_mask is not None:
if attn_mask.dim() > 3:
attn_mask = attn_mask.squeeze(-1)
elif attn_mask.dim() < 3:
attn_mask = attn_mask.unsqueeze(0)
# attn_mask = attn_mask.to(device, dtype=dtype)
image = tensor_to_image(image)
face_helper = FaceRestoreHelper(
upscale_factor=1,
face_size=512,
crop_ratio=(1, 1),
det_model='retinaface_resnet50',
parsing_model='bisenet',
save_ext='png',
device=device,
model_rootpath=FACEXLIB_DIR
)
bg_label = [0, 16, 18, 7, 8, 9, 14, 15]
cond = []
input_face_sort = options.get('input_faces_order', "large-small")
input_face_index = options.get('input_faces_index', 0)
input_face_align_mode = options.get('input_faces_align_mode', 1)
# Analyse multiple images at multiple sizes and combine largest area embeddings
for i in range(image.shape[0]):
# get insightface embeddings
bboxes = []
iface_embeds = None
for size in [(size, size) for size in range(640, 256, -64)]:
face_analysis.det_model.input_size = size
face_info = face_analysis.get(image[i])
if face_info:
face_info, index, sorted_faces = get_face_by_index(face_info, face_sort_rule=input_face_sort, face_index=input_face_index)
bboxes = [face.bbox for face in sorted_faces]
iface_embeds = torch.from_numpy(face_info.embedding).unsqueeze(0).to(device, dtype=dtype)
break
else:
# No face detected, skip this image
logging.warning(f'Warning: No face detected in image {str(i)}')
continue
if input_face_align_mode == 1:
image_size = 512
#M = face_align.estimate_norm(face_info.kps, image_size=image_size)
kps = np.array(face_info.kps, dtype=np.float64)
if kps.dtype != np.float64:
kps = kps.astype(np.float64)
M = face_align.estimate_norm(kps, image_size=image_size)
align_face = cv2.warpAffine(image[i], M, (image_size, image_size), borderMode=cv2.BORDER_CONSTANT,
borderValue=(135, 133, 132))
# align_face = face_align.norm_crop(image[i], landmark=face_info.kps, image_size=image_size)
del M
else:
# get eva_clip embeddings
face_helper.clean_all()
face_helper.read_image(image[i])
face_helper.get_face_landmarks_5(ref_sort_bboxes=bboxes, face_index=input_face_index)
face_helper.align_warp_face()
if len(face_helper.cropped_faces) == 0:
# No face detected, skip this image
continue
# Get aligned face image
align_face = face_helper.cropped_faces[0]
# Convert bgr face image to tensor
align_face = image_to_tensor(align_face).unsqueeze(0).permute(0, 3, 1, 2).to(device)
parsing_out = face_helper.face_parse(functional.normalize(align_face, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))[0]
parsing_out = parsing_out.argmax(dim=1, keepdim=True)
bg = sum(parsing_out == i for i in bg_label).bool()
white_image = torch.ones_like(align_face)
# Only keep the face features
face_features_image = torch.where(bg, white_image, to_gray(align_face))
# Transform img before sending to eva_clip
# Apparently MPS only supports NEAREST interpolation?
face_features_image = functional.resize(face_features_image, eva_clip.image_size, transforms.InterpolationMode.BICUBIC if 'cuda' in device.type else transforms.InterpolationMode.NEAREST).to(device, dtype=dtype)
face_features_image = functional.normalize(face_features_image, eva_clip.image_mean, eva_clip.image_std)
# eva_clip
id_cond_vit, id_vit_hidden = eva_clip(face_features_image, return_all_features=False, return_hidden=True, shuffle=False)
id_cond_vit = id_cond_vit.to(device, dtype=dtype)
for idx in range(len(id_vit_hidden)):
id_vit_hidden[idx] = id_vit_hidden[idx].to(device, dtype=dtype)
id_cond_vit = torch.div(id_cond_vit, torch.norm(id_cond_vit, 2, 1, True))
# Combine embeddings
id_cond = torch.cat([iface_embeds, id_cond_vit], dim=-1)
# Pulid_encoder
cond.append(pulid_flux.model.get_embeds(id_cond, id_vit_hidden))
eva_clip.to(torch.device('cpu'))
if not cond:
# No faces detected, return the original model
logging.warning("PuLID warning: No faces detected in any of the given images, returning unmodified model.")
del eva_clip, face_analysis, pulid_flux, face_helper, attn_mask
return (model,)
# average embeddings
cond = torch.cat(cond).to(device, dtype=dtype)
if cond.shape[0] > 1:
cond = torch.mean(cond, dim=0, keepdim=True)
sigma_start = model.get_model_object("model_sampling").percent_to_sigma(start_at)
sigma_end = model.get_model_object("model_sampling").percent_to_sigma(end_at)
patch_kwargs = {
"pulid_model": pulid_flux,
"weight": weight,
"embedding": cond,
"sigma_start": sigma_start,
"sigma_end": sigma_end,
"mask": attn_mask
}
ca_idx = 0
for i in range(19):
if i % pulid_flux.model.double_interval == 0:
patch_kwargs["ca_idx"] = ca_idx
set_model_dit_patch_replace(model, patch_kwargs, ("double_block", i))
ca_idx += 1
for i in range(38):
if i % pulid_flux.model.single_interval == 0:
patch_kwargs["ca_idx"] = ca_idx
set_model_dit_patch_replace(model, patch_kwargs, ("single_block", i))
ca_idx += 1
if len(model.get_additional_models_with_key("pulid_flux_model_patcher")) == 0:
model.set_additional_models("pulid_flux_model_patcher", [pulid_flux])
if len(model.get_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, wrappers_name)) == 0:
# Just add it once when connecting in series
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, wrappers_name, pulid_outer_sample_wrappers_with_override)
if len(model.get_wrappers(comfy.patcher_extension.WrappersMP.APPLY_MODEL, wrappers_name)) == 0:
# Just add it once when connecting in series
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.APPLY_MODEL, wrappers_name, pulid_apply_model_wrappers)
del eva_clip, face_analysis, pulid_flux, face_helper, attn_mask
return (model,)
class FixPulidFluxPatch:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
},
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "fix_pulid_patch"
CATEGORY = "pulid"
def fix_pulid_patch(self, model):
model = model.clone()
if len(model.get_wrappers(comfy.patcher_extension.WrappersMP.APPLY_MODEL, wrappers_name)) > 0:
model.remove_wrappers_with_key(comfy.patcher_extension.WrappersMP.APPLY_MODEL, wrappers_name)
if len(model.get_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, wrappers_name)) > 0:
model.remove_wrappers_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, wrappers_name)
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, wrappers_name, pulid_outer_sample_wrappers)
set_model_patch(model, PatchKeys.options_key, pulid_enter, PatchKeys.dit_enter)
set_model_patch(model, PatchKeys.options_key, pulid_patch_double_blocks_after, PatchKeys.dit_double_blocks_after)
return (model,)
class PulidFluxOptions:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"input_faces_order": (
["left-right","right-left","top-bottom","bottom-top","small-large","large-small"],
{
"default": "large-small",
"tooltip": "left-right: Sort the left boundary of bbox by column from left to right.\n"
"right-left: Reverse order of left-right (Sort the left boundary of bbox by column from right to left).\n"
"top-bottom: Sort the top boundary of bbox by row from top to bottom.\n"
"bottom-top: Reverse order of top-bottom (Sort the top boundary of bbox by row from bottom to top).\n"
"small-large: Sort the area of bbox from small to large.\n"
"large-small: Sort the area of bbox from large to small."
}
),
"input_faces_index": ("INT",
{
"default": 0, "min": 0, "max": 1000, "step": 1,
"tooltip": "If the value is greater than the size of bboxes, will set value to 0."
}),
"input_faces_align_mode": ("INT",
{
"default": 1, "min": 0, "max": 1, "step": 1,
"tooltip": "Align face mode.\n"
"0: align_face and embed_face use different detectors. The results maybe different.\n"
"1: align_face and embed_face use the same detector."
}),
}
}
RETURN_TYPES = ("OPTIONS",)
FUNCTION = "execute"
CATEGORY = "pulid"
def execute(self,input_faces_order, input_faces_index, input_faces_align_mode=1):
options: dict = {
"input_faces_order": input_faces_order,
"input_faces_index": input_faces_index,
"input_faces_align_mode": input_faces_align_mode,
}
return (options, )
class PulidFluxFaceDetector:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"face_analysis": ("FACEANALYSIS", ),
"image": ("IMAGE",),
"options": ("OPTIONS",),
}
}
RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE",)
RETURN_NAMES = ("embed_face", "align_face", "face_bbox_image",)
FUNCTION = "execute"
CATEGORY = "pulid"
OUTPUT_IS_LIST = (True, True, True,)
def execute(self, face_analysis, image, options):
device = comfy.model_management.get_torch_device()
input_face_sort = options.get('input_faces_order', "large-small")
input_face_index = options.get('input_faces_index', 0)
input_face_align_mode = options.get('input_faces_align_mode', 1)
if input_face_align_mode == 0:
face_helper = FaceRestoreHelper(
upscale_factor=1,
face_size=512,
crop_ratio=(1, 1),
det_model='retinaface_resnet50',
parsing_model='bisenet',
save_ext='png',
device=device,
model_rootpath=FACEXLIB_DIR
)
# Analyse multiple images at multiple sizes and combine largest area embeddings
embed_faces=[]
align_faces=[]
draw_embed_face_bbox=[]
image = tensor_to_image(image)
for i in range(image.shape[0]):
bboxes = []
for size in [(size, size) for size in range(640, 256, -64)]:
face_analysis.det_model.input_size = size
face_info = face_analysis.get(image[i])
if face_info:
face_info, index, sorted_faces = get_face_by_index(face_info, face_sort_rule=input_face_sort,
face_index=input_face_index)
bboxes = [face.bbox for face in sorted_faces]
embed_faces.append(crop_image(image[i], face_info.bbox, margin=10))
draw_embed_face_bbox.append(image_to_tensor(draw_on(image[i], sorted_faces)).unsqueeze(0))
break
else:
# No face detected, skip this image
logging.warning(f'Warning: No face detected in image {str(i)}')
continue
if input_face_align_mode == 1:
image_size = 512
M = face_align.estimate_norm(face_info.kps, image_size=image_size)
align_face = cv2.warpAffine(image[i], M, (image_size, image_size), borderMode=cv2.BORDER_CONSTANT, borderValue=(135, 133, 132))
# align_face = face_align.norm_crop(image[i], landmark=face_info.kps, image_size=image_size)
del M
else:
# get eva_clip embeddings
face_helper.clean_all()
face_helper.read_image(image[i])
face_helper.get_face_landmarks_5(ref_sort_bboxes=bboxes, face_index=input_face_index)
face_helper.align_warp_face()
if len(face_helper.cropped_faces) == 0:
# No face detected, skip this image
continue
# Get aligned face image
align_face = face_helper.cropped_faces[0]
del face_helper
align_faces.append(image_to_tensor(align_face).unsqueeze(0))
del bboxes, align_face
del image
if len(embed_faces) == 0:
# No face detected, skip this image
logging.warning(f'Warning: No embed face detected in image')
if len(align_faces) == 0:
logging.warning(f'Warning: No align face detected in image')
return embed_faces, align_faces, draw_embed_face_bbox,
def crop_image(image, bbox, margin=0):
if len(image.shape) == 3:
image = image[None, ...]
image = image_to_tensor(image)
x, y, x1, y1 = bbox.astype(int)
w = x1 - x
h = y1 - y
image_height = image.shape[1]
image_width = image.shape[2]
# 左上角坐标
x = min(x, image_width)
y = min(y, image_height)
# 右下角坐标
to_x = min(w + x + margin, image_width)
to_y = min(h + y + margin, image_height)
# 防止越界
x = max(0, x - margin)
y = max(0, y - margin)
to_x = max(0, to_x)
to_y = max(0, to_y)
# 按区域截取图片
crop_img = image[:, y:to_y, x:to_x, :]
return crop_img
def set_hook(diffusion_model, target_forward_orig):
# comfy.ldm.flux.model.Flux.old_forward_orig_for_pulid = comfy.ldm.flux.model.Flux.forward_orig
# comfy.ldm.flux.model.Flux.forward_orig = pulid_forward_orig
diffusion_model.old_forward_orig_for_pulid = diffusion_model.forward_orig
diffusion_model.forward_orig = types.MethodType(target_forward_orig, diffusion_model)
def clean_hook(diffusion_model):
# if hasattr(comfy.ldm.flux.model.Flux, 'old_forward_orig_for_pulid'):
# comfy.ldm.flux.model.Flux.forward_orig = comfy.ldm.flux.model.Flux.old_forward_orig_for_pulid
# del comfy.ldm.flux.model.Flux.old_forward_orig_for_pulid
if hasattr(diffusion_model, 'old_forward_orig_for_pulid'):
diffusion_model.forward_orig = diffusion_model.old_forward_orig_for_pulid
del diffusion_model.old_forward_orig_for_pulid
def pulid_outer_sample_wrappers_with_override(wrapper_executor, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None, **kwargs):
cfg_guider = wrapper_executor.class_obj
PULID_model_patch = add_model_patch_option(cfg_guider, PatchKeys.pulid_patch_key_attrs)
PULID_model_patch['latent_image_shape'] = latent_image.shape
diffusion_model = cfg_guider.model_patcher.model.diffusion_model
set_hook(diffusion_model, pulid_forward_orig)
try :
out = wrapper_executor(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, **kwargs)
finally:
del PULID_model_patch['latent_image_shape']
clean_hook(diffusion_model)
del diffusion_model, cfg_guider
return out
def pulid_outer_sample_wrappers(wrapper_executor, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None, **kwargs):
cfg_guider = wrapper_executor.class_obj
PULID_model_patch = add_model_patch_option(cfg_guider, PatchKeys.pulid_patch_key_attrs)
PULID_model_patch['latent_image_shape'] = latent_image.shape
try:
out = wrapper_executor(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, **kwargs)
finally:
del PULID_model_patch['latent_image_shape']
return out
def pulid_apply_model_wrappers(wrapper_executor, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
base_model = wrapper_executor.class_obj
PULID_model_patch = transformer_options.get(PatchKeys.pulid_patch_key_attrs, {})
PULID_model_patch['timesteps'] = base_model.model_sampling.timestep(t).float()
try:
transformer_options[PatchKeys.running_net_model] = base_model.diffusion_model
out = wrapper_executor(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
finally:
if PatchKeys.running_net_model in transformer_options:
del transformer_options[PatchKeys.running_net_model]
del PULID_model_patch['timesteps'], base_model
return out
#facenet implementation
# ──────────────────────── 2. Model caches ──────────────────────────
MTCNN_CACHE = {}
RESNET_CACHE = {}
def get_models(device: torch.device):
"""Lazy-load / cache MTCNN + InceptionResnetV1 for the chosen device."""
if device not in MTCNN_CACHE:
MTCNN_CACHE[device] = MTCNN(
image_size=160,
margin=14,
keep_all=True, # Keep all faces for compatibility
post_process=False,
device=device
)
if device not in RESNET_CACHE:
RESNET_CACHE[device] = (
InceptionResnetV1(pretrained='vggface2')
.eval()
.to(device)
)
return MTCNN_CACHE[device], RESNET_CACHE[device]
# ──────────────────────── 3. Face Info compatibility class ─────────────────────────
class FaceNetFaceInfo:
"""Compatible face info object that mimics InsightFace's face structure"""
def __init__(self, bbox, kps, embedding, det_score=0.9):
self.bbox = bbox # [x1, y1, x2, y2]
self.kps = kps # 5 keypoints: [[x1,y1], [x2,y2], ...]
self.embedding = embedding # 512-D embedding
self.det_score = det_score
# ──────────────────────── 4. Detection Model compatibility class ─────────────────────────
class FaceNetDetModel:
"""Mimics InsightFace's det_model interface"""
def __init__(self):
self.input_size = (640, 640) # Default size, will be modified by pipeline
# ──────────────────────── 5. Main FaceAnalysis compatibility class ─────────────────────────
class FaceNetAnalysis:
"""
FaceNet-based face analysis that mimics InsightFace's FaceAnalysis interface
"""
def __init__(self, device):
self.device = device
self.det_model = FaceNetDetModel()
self.mtcnn = None
self.resnet = None
self._prepared = False
def prepare(self, ctx_id=0, det_size=(640, 640)):
"""Initialize models - called by downstream nodes"""
self.det_model.input_size = det_size
self._prepared = True
def get(self, image):
"""
Main face detection and embedding method - must return list of face objects
Compatible with: face_info = face_analysis.get(image[i])
"""
if not self._prepared:
self.prepare()
# Lazy load models
if self.mtcnn is None or self.resnet is None:
self.mtcnn, self.resnet = get_models(self.device)
# Handle numpy array input (from ComfyUI)
if isinstance(image, np.ndarray):
# Convert from BGR to RGB if needed
if image.shape[-1] == 3:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
pil_img = Image.fromarray(image)
else:
pil_img = image
# Ensure RGB
if pil_img.mode != 'RGB':
pil_img = pil_img.convert('RGB')
# Resize image based on current input_size setting
input_size = self.det_model.input_size
if isinstance(input_size, (list, tuple)):
target_size = input_size[0] if input_size[0] == input_size[1] else min(input_size)
else:
target_size = input_size
# Resize image for detection
original_size = pil_img.size
if original_size[0] != target_size or original_size[1] != target_size:
pil_img_resized = pil_img.resize((target_size, target_size), Image.Resampling.LANCZOS)
else:
pil_img_resized = pil_img
# Detect faces and get aligned crops
try:
# MTCNN returns: boxes, probs, landmarks (if keep_all=True, returns lists)
boxes, probs, landmarks = self.mtcnn.detect(pil_img_resized, landmarks=True)
if boxes is None or len(boxes) == 0:
return [] # No faces detected
# Get face crops for embedding
aligned_faces = []
face_tensors = []
for i, (box, landmark, prob) in enumerate(zip(boxes, landmarks, probs)):
if prob < 0.9: # Skip low confidence faces
continue
# Extract and align face
try:
face_tensor = self.mtcnn.extract(pil_img_resized, [box], save_path=None)
if face_tensor is not None and len(face_tensor) > 0:
face_tensors.append(face_tensor[0])
aligned_faces.append((box, landmark, prob))
except:
continue
if not face_tensors:
return []
# Get embeddings for all faces
face_tensors = torch.stack(face_tensors).to(self.device)
with torch.no_grad():
embeddings = self.resnet(face_tensors)
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
# Create FaceInfo objects compatible with InsightFace
face_infos = []
scale_x = original_size[0] / target_size
scale_y = original_size[1] / target_size
for i, ((box, landmark, prob), embedding) in enumerate(zip(aligned_faces, embeddings)):
# Scale bbox back to original image size
scaled_bbox = [
box[0] * scale_x, # x1
box[1] * scale_y, # y1
box[2] * scale_x, # x2
box[3] * scale_y # y2
]
# Scale landmarks back to original image size
scaled_kps = landmark.copy()
scaled_kps[:, 0] *= scale_x # x coordinates
scaled_kps[:, 1] *= scale_y # y coordinates
# Convert embedding to numpy for compatibility
embedding_np = embedding.cpu().numpy()
face_info = FaceNetFaceInfo(
bbox=np.array(scaled_bbox),
kps=scaled_kps,
embedding=embedding_np,
det_score=float(prob)
)
face_infos.append(face_info)
return face_infos
except Exception as e:
logging.warning(f"FaceNet face detection failed: {str(e)}")
return []
# ──────────────────────── 6. ComfyUI Node ──────────────────────────
class PulidFluxFaceNetLoader:
"""
FaceNet-based face analysis loader compatible with PuLID-Flux pipeline
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"provider": (["CPU", "CUDA"],),
}
}
RETURN_TYPES = ("FACEANALYSIS",)
FUNCTION = "load_facenet"
CATEGORY = "pulid"
def load_facenet(self, provider: str):
# Map provider to torch device
if provider == "CPU":
device = torch.device("cpu")
elif provider in ["CUDA"]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
device = torch.device("cpu")
# Create FaceNet analysis object
face_analysis = FaceNetAnalysis(device)
# Pre-initialize with default settings
face_analysis.prepare(ctx_id=0, det_size=(640, 640))
return (face_analysis,)
NODE_CLASS_MAPPINGS = {
"PulidFluxModelLoader": PulidFluxModelLoader,
"PulidFluxInsightFaceLoader": PulidFluxInsightFaceLoader,
"PulidFluxEvaClipLoader": PulidFluxEvaClipLoader,
"ApplyPulidFlux": ApplyPulidFlux,
"FixPulidFluxPatch": FixPulidFluxPatch,
"PulidFluxOptions": PulidFluxOptions,
"PulidFluxFaceDetector": PulidFluxFaceDetector,
"PulidFluxFaceNetLoader": PulidFluxFaceNetLoader,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"PulidFluxModelLoader": "Load PuLID Flux Model",
"PulidFluxInsightFaceLoader": "Load InsightFace (PuLID Flux)",
"PulidFluxFaceNetLoader": "Load FaceNet (PuLID Flux)",
"PulidFluxEvaClipLoader": "Load Eva Clip (PuLID Flux)",
"ApplyPulidFlux": "Apply PuLID Flux",
"FixPulidFluxPatch": "Fix PuLID Flux Patch",
"PulidFluxOptions": "Pulid Flux Options",
"PulidFluxFaceDetector": "Pulid Flux Face Detector",
}

View File

@@ -0,0 +1,15 @@
[project]
name = "comfyui_pulid_flux_ll"
description = "The implementation for PuLID-Flux, support use with TeaCache and WaveSpeed, no model pollution."
version = "1.1.4"
license = {file = "LICENSE"}
dependencies = ['cython', 'facexlib', 'insightface', 'onnxruntime', 'onnxruntime-gpu; sys_platform != "darwin" and (platform_machine == "x86_64" or platform_machine == "AMD64")', 'ftfy', 'timm']
[project.urls]
Repository = "https://github.com/lldacing/ComfyUI_PuLID_Flux_ll"
# Used by Comfy Registry https://comfyregistry.org
[tool.comfy]
PublisherId = "lldacing"
DisplayName = "ComfyUI_PuLID_Flux_ll"
Icon = ""

View File

@@ -0,0 +1,8 @@
cython
facexlib
insightface
onnxruntime
onnxruntime-gpu; sys_platform != 'darwin' and (platform_machine == 'x86_64' or platform_machine == 'AMD64')
ftfy
timm
facenet-pytorch