Spaces:
Runtime error
Runtime error
Upload utils.py
Browse files
utils.py
CHANGED
|
@@ -8,11 +8,35 @@ Date: Feb 26, 2025
|
|
| 8 |
import torch
|
| 9 |
import gc
|
| 10 |
import os
|
|
|
|
| 11 |
from PIL import Image, ImageDraw, ImageFont
|
| 12 |
|
| 13 |
# Disable HF transfer to avoid download issues
|
| 14 |
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
def load_models(device="cuda"):
|
| 18 |
"""
|
|
@@ -20,7 +44,10 @@ def load_models(device="cuda"):
|
|
| 20 |
:param device: (str) Device to load models on ('cuda', 'mps', or 'cpu')
|
| 21 |
:return: (tuple) (vae, tokenizer, text_encoder, unet, scheduler, pipe)
|
| 22 |
"""
|
| 23 |
-
#
|
|
|
|
|
|
|
|
|
|
| 24 |
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel, StableDiffusionPipeline
|
| 25 |
from transformers import CLIPTokenizer, CLIPTextModel
|
| 26 |
|
|
|
|
| 8 |
import torch
|
| 9 |
import gc
|
| 10 |
import os
|
| 11 |
+
import sys
|
| 12 |
from PIL import Image, ImageDraw, ImageFont
|
| 13 |
|
| 14 |
# Disable HF transfer to avoid download issues
|
| 15 |
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
|
| 16 |
|
| 17 |
+
# Create a monkey patch for the cached_download function
|
| 18 |
+
# This is needed because newer versions of huggingface_hub
|
| 19 |
+
# removed cached_download but diffusers still tries to import it
|
| 20 |
+
def apply_huggingface_patch():
|
| 21 |
+
import importlib
|
| 22 |
+
import huggingface_hub
|
| 23 |
+
|
| 24 |
+
# Check if cached_download is already available
|
| 25 |
+
if hasattr(huggingface_hub, 'cached_download'):
|
| 26 |
+
return # No need to patch
|
| 27 |
+
|
| 28 |
+
# Create a wrapper around hf_hub_download to mimic the old cached_download
|
| 29 |
+
def cached_download(*args, **kwargs):
|
| 30 |
+
# Forward to the new function with appropriate args
|
| 31 |
+
return huggingface_hub.hf_hub_download(*args, **kwargs)
|
| 32 |
+
|
| 33 |
+
# Add the function to the huggingface_hub module
|
| 34 |
+
setattr(huggingface_hub, 'cached_download', cached_download)
|
| 35 |
+
|
| 36 |
+
# Make sure diffusers.utils.dynamic_modules_utils sees the patched module
|
| 37 |
+
if 'diffusers.utils.dynamic_modules_utils' in sys.modules:
|
| 38 |
+
del sys.modules['diffusers.utils.dynamic_modules_utils']
|
| 39 |
+
|
| 40 |
|
| 41 |
def load_models(device="cuda"):
|
| 42 |
"""
|
|
|
|
| 44 |
:param device: (str) Device to load models on ('cuda', 'mps', or 'cpu')
|
| 45 |
:return: (tuple) (vae, tokenizer, text_encoder, unet, scheduler, pipe)
|
| 46 |
"""
|
| 47 |
+
# Apply the patch before importing diffusers
|
| 48 |
+
apply_huggingface_patch()
|
| 49 |
+
|
| 50 |
+
# Now we can safely import from diffusers
|
| 51 |
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel, StableDiffusionPipeline
|
| 52 |
from transformers import CLIPTokenizer, CLIPTextModel
|
| 53 |
|