mm: Fix cast buffers with intel offloading (#12229)
Intel (XPU) devices support offloading, but the new cast buffer code still made direct NVIDIA-only torch.cuda calls. Route synchronization and cache clearing through the backend-aware synchronize() and soft_empty_cache() helpers so the path works on Intel too.
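For reference, a minimal standalone sketch of the dispatch pattern this commit applies; the is_intel_xpu() body below is a hypothetical stand-in for the repository's own helper:

    import torch

    def is_intel_xpu():
        # Hypothetical stand-in for the repo's own backend check.
        return hasattr(torch, "xpu") and torch.xpu.is_available()

    def synchronize():
        # Wait for all queued kernels on whichever backend is active;
        # CPU-only builds fall through and do nothing.
        if is_intel_xpu():
            torch.xpu.synchronize()
        elif torch.cuda.is_available():
            torch.cuda.synchronize()

Callers like get_cast_buffer() and discard_cuda_async_error() then stay backend-agnostic instead of hardcoding torch.cuda.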
@@ -1112,11 +1112,11 @@ def get_cast_buffer(offload_stream, device, size, ref):
             return None
     if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2):
         #I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now
-        torch.cuda.synchronize()
+        synchronize()
         del STREAM_CAST_BUFFERS[offload_stream]
         del cast_buffer
         #FIXME: This doesn't work in Aimdo because mempool cant clear cache
-        torch.cuda.empty_cache()
+        soft_empty_cache()
     with wf_context:
         cast_buffer = torch.empty((size), dtype=torch.int8, device=device)
         STREAM_CAST_BUFFERS[offload_stream] = cast_buffer
@@ -1132,9 +1132,7 @@ def reset_cast_buffers():
     for offload_stream in STREAM_CAST_BUFFERS:
         offload_stream.synchronize()
     STREAM_CAST_BUFFERS.clear()
-    if comfy.memory_management.aimdo_allocator is None:
-        #Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist
-        torch.cuda.empty_cache()
+    soft_empty_cache()
 
 def get_offload_stream(device):
     stream_counter = stream_counters.get(device, 0)
@@ -1284,7 +1282,7 @@ def discard_cuda_async_error():
         a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
         b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
         _ = a + b
-        torch.cuda.synchronize()
+        synchronize()
     except torch.AcceleratorError:
         #Dump it! We already know about it from the synchronous return
         pass
@@ -1688,6 +1686,12 @@ def lora_compute_dtype(device):
     LORA_COMPUTE_DTYPES[device] = dtype
     return dtype
 
+def synchronize():
+    if is_intel_xpu():
+        torch.xpu.synchronize()
+    elif torch.cuda.is_available():
+        torch.cuda.synchronize()
+
 def soft_empty_cache(force=False):
     global cpu_state
     if cpu_state == CPUState.MPS:
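The same dispatch idea covers cache clearing. A hedged sketch of a backend-aware empty-cache helper; the file's real soft_empty_cache() also takes a force flag and consults the cpu_state enum, as the context lines above show:

    import torch

    def soft_empty_cache_sketch():
        # Return cached allocator blocks to the driver on the active backend.
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            torch.xpu.empty_cache()
        elif torch.cuda.is_available():
            torch.cuda.empty_cache()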