Better detection if AMD torch compiled with efficient attention. (#11745)
This commit is contained in:
@@ -22,7 +22,6 @@ from enum import Enum
|
|||||||
from comfy.cli_args import args, PerformanceFeature
|
from comfy.cli_args import args, PerformanceFeature
|
||||||
import torch
|
import torch
|
||||||
import sys
|
import sys
|
||||||
import importlib
|
|
||||||
import platform
|
import platform
|
||||||
import weakref
|
import weakref
|
||||||
import gc
|
import gc
|
||||||
@@ -349,10 +348,22 @@ try:
|
|||||||
except:
|
except:
|
||||||
rocm_version = (6, -1)
|
rocm_version = (6, -1)
|
||||||
|
|
||||||
|
def aotriton_supported(gpu_arch):
|
||||||
|
path = torch.__path__[0]
|
||||||
|
path = os.path.join(os.path.join(path, "lib"), "aotriton.images")
|
||||||
|
gfx = set(map(lambda a: a[4:], filter(lambda a: a.startswith("amd-gfx"), os.listdir(path))))
|
||||||
|
if gpu_arch in gfx:
|
||||||
|
return True
|
||||||
|
if "{}x".format(gpu_arch[:-1]) in gfx:
|
||||||
|
return True
|
||||||
|
if "{}xx".format(gpu_arch[:-2]) in gfx:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
logging.info("AMD arch: {}".format(arch))
|
logging.info("AMD arch: {}".format(arch))
|
||||||
logging.info("ROCm version: {}".format(rocm_version))
|
logging.info("ROCm version: {}".format(rocm_version))
|
||||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||||
if importlib.util.find_spec('triton') is not None: # AMD efficient attention implementation depends on triton. TODO: better way of detecting if it's compiled in or not.
|
if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton.
|
||||||
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
|
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
|
||||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
||||||
ENABLE_PYTORCH_ATTENTION = True
|
ENABLE_PYTORCH_ATTENTION = True
|
||||||
|
|||||||
Reference in New Issue
Block a user