""" Unified model loading utility supporting ModelScope, HuggingFace and local path loading """ import os import logging import threading from typing import Optional, Dict, Any, Tuple import torch from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig from funasr_detach import AutoModel # Global cache for downloaded models to avoid repeated downloads # Key: (model_path, source) # Value: local_model_path _model_download_cache = {} _download_cache_lock = threading.Lock() class ModelSource: """Model source enumeration""" MODELSCOPE = "modelscope" HUGGINGFACE = "huggingface" LOCAL = "local" AUTO = "auto" # Auto-detect class UnifiedModelLoader: """Unified model loader""" def __init__(self): self.logger = logging.getLogger(__name__) def _cached_snapshot_download(self, model_path: str, source: str, **kwargs) -> str: """ Cached version of snapshot_download to avoid repeated downloads Args: model_path: Model path or ID to download source: Model source ('modelscope' or 'huggingface') **kwargs: Additional arguments for snapshot_download Returns: Local path to downloaded model """ cache_key = (model_path, source, str(sorted(kwargs.items()))) # Check cache first with _download_cache_lock: if cache_key in _model_download_cache: cached_path = _model_download_cache[cache_key] self.logger.info(f"Using cached download for {model_path} from {source}: {cached_path}") return cached_path # Cache miss, need to download if source == ModelSource.MODELSCOPE: from modelscope.hub.snapshot_download import snapshot_download local_path = snapshot_download(model_path, **kwargs) elif source == ModelSource.HUGGINGFACE: from huggingface_hub import snapshot_download local_path = snapshot_download(model_path, **kwargs) else: raise ValueError(f"Unsupported source for cached download: {source}") # Cache the result with _download_cache_lock: _model_download_cache[cache_key] = local_path self.logger.info(f"Downloaded and cached {model_path} from {source}: {local_path}") return local_path def detect_model_source(self, model_path: str) -> str: """ Automatically detect model source Args: model_path: Model path or ID Returns: Model source type """ # Local path detection if os.path.exists(model_path) or os.path.isabs(model_path): return ModelSource.LOCAL # ModelScope format detection (usually includes username/model_name) if "/" in model_path and not model_path.startswith("http"): # If contains modelscope keyword or is known modelscope format if "modelscope" in model_path.lower() or self._is_modelscope_format(model_path): return ModelSource.MODELSCOPE else: # Default to HuggingFace return ModelSource.HUGGINGFACE return ModelSource.LOCAL def _is_modelscope_format(self, model_path: str) -> bool: """Detect if it's ModelScope format model ID""" # Can be judged according to known ModelScope model ID formats # For example: iic/speech_paraformer-large_asr_nat-zh-cantonese-en-16k-vocab8501-online modelscope_patterns = [] return any(pattern in model_path for pattern in modelscope_patterns) def _prepare_quantization_config(self, quantization_config: Optional[str], torch_dtype: Optional[torch.dtype] = None) -> Tuple[Dict[str, Any], bool]: """ Prepare quantization configuration for model loading Args: quantization_config: Quantization type ('int4', 'int8', or None) torch_dtype: PyTorch data type for compute operations Returns: Tuple of (quantization parameters dict, should_set_torch_dtype) """ if not quantization_config: return {}, True quantization_config = quantization_config.lower() if quantization_config == "int8": # 

    def _prepare_quantization_config(
        self,
        quantization_config: Optional[str],
        torch_dtype: Optional[torch.dtype] = None,
    ) -> Tuple[Dict[str, Any], bool]:
        """
        Prepare quantization configuration for model loading.

        Args:
            quantization_config: Quantization type ('int4', 'int8', or None)
            torch_dtype: PyTorch data type for compute operations

        Returns:
            Tuple of (quantization parameters dict, should_set_torch_dtype)
        """
        if not quantization_config:
            return {}, True

        quantization_config = quantization_config.lower()

        if quantization_config == "int8":
            # Use user-specified torch_dtype for compute, default to bfloat16
            compute_dtype = torch_dtype if torch_dtype is not None else torch.bfloat16
            self.logger.info(f"🔧 INT8 quantization: using {compute_dtype} for compute operations")
            bnb_config = BitsAndBytesConfig(
                load_in_8bit=True,
                # Note: bnb_8bit_compute_dtype is not a documented BitsAndBytesConfig
                # field and may be ignored; 8-bit compute dtypes are handled by bitsandbytes.
                bnb_8bit_compute_dtype=compute_dtype,
            )
            # INT8 quantization handles data types automatically, don't set torch_dtype
            return {"quantization_config": bnb_config}, False
        elif quantization_config == "int4":
            # Use user-specified torch_dtype for compute, default to bfloat16
            compute_dtype = torch_dtype if torch_dtype is not None else torch.bfloat16
            self.logger.info(f"🔧 INT4 quantization: using {compute_dtype} for compute operations")
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_use_double_quant=True,
            )
            # INT4 quantization handles torch_dtype internally, don't set it again
            return {"quantization_config": bnb_config}, False
        else:
            raise ValueError(
                f"Unsupported quantization config: {quantization_config}. Supported: 'int4', 'int8'"
            )
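
    # Sketch of the helper's contract (illustrative, derived from the branches above):
    #   _prepare_quantization_config(None)    -> ({}, True)    caller may still set torch_dtype
    #   _prepare_quantization_config("int8")  -> ({"quantization_config": <8-bit BitsAndBytesConfig>}, False)
    #   _prepare_quantization_config("int4")  -> ({"quantization_config": <4-bit NF4 BitsAndBytesConfig>}, False)
    #   _prepare_quantization_config("fp16")  -> raises ValueError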

    def load_transformers_model(
        self,
        model_path: str,
        source: str = ModelSource.AUTO,
        quantization_config: Optional[str] = None,
        **kwargs,
    ) -> Tuple:
        """
        Load a Transformers model (for StepAudioTTS).

        Args:
            model_path: Model path or ID
            source: Model source; 'auto' means auto-detect
            quantization_config: Quantization configuration ('int4', 'int8', or None for no quantization)
            **kwargs: Other parameters (torch_dtype, device_map, etc.)

        Returns:
            (model, tokenizer, model_path) tuple, where model_path is the resolved local path
        """
        if source == ModelSource.AUTO:
            source = self.detect_model_source(model_path)

        self.logger.info(f"Loading Transformers model from {source}: {model_path}")
        if quantization_config:
            self.logger.info(f"🔧 {quantization_config.upper()} quantization enabled")

        # Prepare quantization configuration
        quantization_kwargs, should_set_torch_dtype = self._prepare_quantization_config(
            quantization_config, kwargs.get("torch_dtype")
        )

        try:
            if source == ModelSource.LOCAL:
                # Local loading
                load_kwargs = {
                    "device_map": kwargs.get("device_map", "auto"),
                    "trust_remote_code": True,
                    "local_files_only": True,
                }
                # Add quantization configuration if specified
                load_kwargs.update(quantization_kwargs)
                # Add torch_dtype based on quantization requirements
                if should_set_torch_dtype and kwargs.get("torch_dtype") is not None:
                    load_kwargs["torch_dtype"] = kwargs.get("torch_dtype")

                # Standard loading
                model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
                tokenizer = AutoTokenizer.from_pretrained(
                    model_path, trust_remote_code=True, local_files_only=True
                )

            elif source == ModelSource.MODELSCOPE:
                # Load from ModelScope
                from modelscope import AutoModelForCausalLM as MSAutoModelForCausalLM
                from modelscope import AutoTokenizer as MSAutoTokenizer

                model_path = self._cached_snapshot_download(model_path, ModelSource.MODELSCOPE)

                load_kwargs = {
                    "device_map": kwargs.get("device_map", "auto"),
                    "trust_remote_code": True,
                    "local_files_only": True,
                }
                # Add quantization configuration if specified
                load_kwargs.update(quantization_kwargs)
                # Add torch_dtype based on quantization requirements
                if should_set_torch_dtype and kwargs.get("torch_dtype") is not None:
                    load_kwargs["torch_dtype"] = kwargs.get("torch_dtype")

                # Standard loading
                model = MSAutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
                tokenizer = MSAutoTokenizer.from_pretrained(
                    model_path, trust_remote_code=True, local_files_only=True
                )

            elif source == ModelSource.HUGGINGFACE:
                # Load from HuggingFace
                model_path = self._cached_snapshot_download(model_path, ModelSource.HUGGINGFACE)

                load_kwargs = {
                    "device_map": kwargs.get("device_map", "auto"),
                    "trust_remote_code": True,
                    "local_files_only": True,
                }
                # Add quantization configuration if specified
                load_kwargs.update(quantization_kwargs)
                # Add torch_dtype based on quantization requirements
                if should_set_torch_dtype and kwargs.get("torch_dtype") is not None:
                    load_kwargs["torch_dtype"] = kwargs.get("torch_dtype")

                # Standard loading
                model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
                tokenizer = AutoTokenizer.from_pretrained(
                    model_path, trust_remote_code=True, local_files_only=True
                )

            else:
                raise ValueError(f"Unsupported model source: {source}")

            self.logger.info(f"Successfully loaded model from {source}")
            return model, tokenizer, model_path

        except Exception as e:
            self.logger.error(f"Failed to load model from {source}: {e}")
            raise

    def load_funasr_model(
        self,
        repo_path: str,
        model_path: str,
        source: str = ModelSource.AUTO,
        **kwargs,
    ) -> AutoModel:
        """
        Load a FunASR model (for StepAudioTokenizer).

        Args:
            repo_path: Local repository path passed through to funasr_detach's AutoModel
            model_path: Model path or ID
            source: Model source; 'auto' means auto-detect
            **kwargs: Other parameters

        Returns:
            FunASR AutoModel instance
        """
        if source == ModelSource.AUTO:
            source = self.detect_model_source(model_path)

        self.logger.info(f"Loading FunASR model from {source}: {model_path}")

        try:
            # Extract model_revision to avoid duplicate passing
            model_revision = kwargs.pop("model_revision", "main")

            # Map ModelSource to the model_hub parameter
            if source == ModelSource.LOCAL:
                model_hub = "local"
            elif source == ModelSource.MODELSCOPE:
                model_hub = "ms"
            elif source == ModelSource.HUGGINGFACE:
                model_hub = "hf"
            else:
                raise ValueError(f"Unsupported model source: {source}")

            # Let AutoModel handle download and loading for all hubs via model_hub
            model = AutoModel(
                repo_path=repo_path,
                model=model_path,
                model_hub=model_hub,
                model_revision=model_revision,
                **kwargs,
            )

            self.logger.info(f"Successfully loaded FunASR model from {source}")
            return model

        except Exception as e:
            self.logger.error(f"Failed to load FunASR model from {source}: {e}")
            raise

    def resolve_model_path(
        self,
        base_path: str,
        model_name: str,
        source: str = ModelSource.AUTO,
    ) -> str:
        """
        Resolve a model path.

        Args:
            base_path: Base path
            model_name: Model name
            source: Model source

        Returns:
            Resolved model path
        """
        if source == ModelSource.AUTO:
            # First check the local path
            local_path = os.path.join(base_path, model_name)
            if os.path.exists(local_path):
                return local_path
            # If it doesn't exist locally, return the model name for online download
            return model_name
        elif source == ModelSource.LOCAL:
            return os.path.join(base_path, model_name)
        else:
            # For online sources, return the model name/ID directly
            return model_name


# Global instance
model_loader = UnifiedModelLoader()
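

if __name__ == "__main__":
    # Minimal usage sketch. The base path and model name below are hypothetical
    # placeholders; replace them with the layout of your own deployment.
    logging.basicConfig(level=logging.INFO)

    base_path = "/data/models"  # hypothetical local model root
    resolved = model_loader.resolve_model_path(base_path, "Step-Audio-TTS-3B")
    print(f"Resolved path/ID: {resolved}")
    print(f"Detected source:  {model_loader.detect_model_source(resolved)}")

    # A quantized load would look like the call below; it is commented out here
    # because it downloads and initializes multi-GB model weights.
    # model, tokenizer, local_path = model_loader.load_transformers_model(
    #     resolved,
    #     quantization_config="int4",
    #     torch_dtype=torch.bfloat16,
    # )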