import os
import torch
import numpy as np
import gradio as gr
import logging
import sys
from typing import Optional, Literal
from pydantic import BaseModel
from transformers import pipeline
from pyannote.audio import Pipeline
from huggingface_hub import HfApi
from torchaudio import functional as F  # For resampling and audio processing

# To run this Gradio demo, first ensure you have Python 3.9+ installed.
# Then, create a virtual environment and install the required packages.
#
# 1. Create and activate a virtual environment (recommended):
#    python3 -m venv venv
#    source venv/bin/activate   # On Linux/macOS
#    venv\Scripts\activate      # On Windows
#
# 2. Install the necessary packages:
#    pip install gradio==4.20.1 \
#                torch==2.2.1 \
#                torchaudio==2.2.1 \
#                transformers==4.38.2 \
#                pyannote-audio==3.1.1 \
#                numpy==1.26.4 \
#                fastapi==0.110.0 \
#                uvicorn==0.27.1 \
#                python-multipart==0.0.9 \
#                pydantic==2.6.1 \
#                soundfile==0.12.1  # Required by torchaudio and pyannote for certain audio formats
#
# #  If you have a CUDA-compatible GPU, install the CUDA version of PyTorch instead:
# #  For CUDA 12.1 (adjust 'cu121' to your CUDA version, e.g., 'cu118' for CUDA 11.8):
# #  pip install torch==2.2.1 torchaudio==2.2.1 --index-url https://download.pytorch.org/whl/cu121
#
# 3. Set your Hugging Face token (required for pyannote/speaker-diarization-3.1).
#    You must accept the user conditions on the model page: https://huggingface.co/pyannote/speaker-diarization-3.1
#    export HF_TOKEN="hf_YOUR_TOKEN_HERE"
# #  Or directly in the script (not recommended for security):
# #  HF_TOKEN = "hf_YOUR_TOKEN_HERE"
#
# 4. Save this file as, for example, `app.py`.
#
# 5. Run the application:
#    python app.py
#
# #  To serve it behind uvicorn instead, mount the Gradio app on a FastAPI instance first
# #  (see the commented sketch just below); `demo` itself is not an ASGI application.
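
# A minimal, untested sketch of the uvicorn route, assuming this file is saved as `app.py`
# and using Gradio's `mount_gradio_app` helper (the wrapper filename `serve_app.py` is just
# an example name):
#
#   # serve_app.py
#   import gradio as gr
#   from fastapi import FastAPI
#   from app import demo          # importing app.py loads the models at import time
#
#   api = FastAPI()
#   api = gr.mount_gradio_app(api, demo, path="/")
#
#   # Run with: uvicorn serve_app:api --host 0.0.0.0 --port 7860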

# Set up logging
logging.basicConfig(level=logging.INFO, stream=sys.stdout, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Configuration ---
# You will need a Hugging Face token for pyannote/speaker-diarization-3.1.
# 1. Go to https://huggingface.co/settings/tokens to create a new token.
# 2. Make sure you have accepted the user conditions on the model page:
#    https://huggingface.co/pyannote/speaker-diarization-3.1
# 3. Set your token as an environment variable before running this script:
#    export HF_TOKEN="hf_YOUR_TOKEN_HERE"
# Alternatively, replace os.getenv("HF_TOKEN") with your actual token string:
#    HF_TOKEN = "hf_YOUR_TOKEN_HERE"
HF_TOKEN = os.getenv("HF_TOKEN")

# Model names
ASR_MODEL = "openai/whisper-large-v3-turbo"  # Pruned-decoder variant of Whisper large-v3: much faster, well suited to a demo
DIARIZATION_MODEL = "pyannote/speaker-diarization-3.1"
# Speculative decoding (assistant model) is explicitly excluded as per requirements.

# --- Inference Configuration (Pydantic Model for validation) ---
class InferenceConfig(BaseModel):
    task: Literal["transcribe", "translate"] = "transcribe"
    batch_size: int = 1
    chunk_length_s: int = 30
    language: Optional[str] = None
    num_speakers: Optional[int] = None
    min_speakers: Optional[int] = None
    max_speakers: Optional[int] = None
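
# Illustrative example (not executed): a typical configuration for a two-speaker English
# recording; any field not listed keeps the defaults above.
#
#   InferenceConfig(task="transcribe", batch_size=4, chunk_length_s=30,
#                   language="english", num_speakers=2)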

# --- Global Models and Device ---
models = {}
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logger.info(f"Using device: {device.type}")
torch_dtype = torch.float32 if device.type == "cpu" else torch.float16  # Use float16 on GPU for efficiency

# --- Model Loading Function ---
def load_models():
    """
    Loads the ASR and Diarization models into the global `models` dictionary.
    Handles device placement and Hugging Face token authentication.
    """
    logger.info("Loading ASR pipeline...")
    # The ASR pipeline can directly take a numpy array for inference.
    models["asr_pipeline"] = pipeline(
        "automatic-speech-recognition",
        model=ASR_MODEL,
        torch_dtype=torch_dtype,
        device=device
    )
    logger.info("ASR pipeline loaded.")

    if DIARIZATION_MODEL:
        logger.info(f"Loading Diarization pipeline: {DIARIZATION_MODEL}...")
        if not HF_TOKEN:
            raise ValueError(
                "HF_TOKEN environment variable or HF_TOKEN constant not set. "
                "Pyannote models require a Hugging Face token for authentication. "
                "Get it from https://huggingface.co/settings/tokens and ensure you accept "
                "the user conditions on the model page: "
                "https://huggingface.co/pyannote/speaker-diarization-3.1"
            )
        try:
            # Verify token and load pyannote pipeline
            HfApi().whoami(token=HF_TOKEN)  # Check token validity
            models["diarization_pipeline"] = Pipeline.from_pretrained(
                checkpoint_path=DIARIZATION_MODEL,
                use_auth_token=HF_TOKEN,
            )
            models["diarization_pipeline"].to(device)
            logger.info("Diarization pipeline loaded.")
        except Exception as e:
            logger.error(f"Failed to load diarization pipeline: {e}")
            raise
    else:
        models["diarization_pipeline"] = None
        logger.info("Diarization model not specified, diarization will be skipped.")

# Load models once when the script starts
try:
    load_models()
except Exception as e:
    logger.critical(f"Failed to load models. Please check your HF_TOKEN and model names. Exiting: {e}")
    sys.exit(1)

# --- Diarization Utility Functions (adapted from the original `model-server/app/utils/diarization_utils.py`) ---
def preprocess_audio_for_diarization(sampling_rate_in: int, audio_array_in: np.ndarray) -> tuple[torch.Tensor, int]:
    """
    Preprocesses audio for the diarization pipeline.
    Resamples to 16kHz and ensures a single-channel float32 torch tensor.
    """
    if audio_array_in is None or audio_array_in.size == 0:
        raise ValueError("Audio array is empty for diarization preprocessing.")
    # Convert to float32 if not already (pyannote expects float32)
    if audio_array_in.dtype != np.float32:
        audio_array_in = audio_array_in.astype(np.float32)
    # If stereo, take one channel (pyannote expects single channel)
    if len(audio_array_in.shape) > 1:
        audio_array_in = audio_array_in[:, 0]  # Take the first channel
    # Resample to 16kHz if necessary, as pyannote models are typically trained on 16kHz audio.
    if sampling_rate_in != 16000:
        audio_array_in = F.resample(
            torch.from_numpy(audio_array_in), sampling_rate_in, 16000
        ).numpy()
        sampling_rate_in = 16000  # Update SR to reflect resampling
    # Diarization model expects a float32 torch tensor of shape `(channels, seq_len)`
    diarizer_inputs = torch.from_numpy(audio_array_in).float()
    diarizer_inputs = diarizer_inputs.unsqueeze(0)  # Add channel dimension (1, seq_len)
    return diarizer_inputs, sampling_rate_in
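
# Illustrative shapes (not executed): a 10-second stereo clip recorded at 44.1 kHz arrives as
# an array of shape (441000, 2); preprocessing keeps channel 0, resamples to 16 kHz, and
# returns a tensor of shape (1, 160000) together with the new sampling rate 16000.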

def diarize_audio(diarizer_inputs: torch.Tensor, diarization_pipeline: Pipeline, parameters: InferenceConfig) -> list:
    """
    Performs diarization using the pyannote pipeline and combines consecutive speaker segments.
    """
    # Run the diarization pipeline
    diarization = diarization_pipeline(
        {"waveform": diarizer_inputs, "sample_rate": 16000},  # Always pass 16kHz to the diarizer
        num_speakers=parameters.num_speakers,
        min_speakers=parameters.min_speakers,
        max_speakers=parameters.max_speakers,
    )
    raw_segments = []
    # pyannote.audio returns segments as `Segment(start=X, end=Y)`
    for segment, _, label in diarization.itertracks(yield_label=True):
        raw_segments.append(
            {
                "segment": {"start": segment.start, "end": segment.end},
                "label": label,
            }
        )

    # Combine consecutive segments from the same speaker
    combined_segments = []
    if not raw_segments:
        return combined_segments

    # Initialize with the first segment
    current_speaker_segment = {
        "speaker": raw_segments[0]["label"],
        "segment": {"start": raw_segments[0]["segment"]["start"], "end": raw_segments[0]["segment"]["end"]},
    }
    for i in range(1, len(raw_segments)):
        next_segment = raw_segments[i]
        # If the speaker changes
        if next_segment["label"] != current_speaker_segment["speaker"]:
            # Add the accumulated segment for the previous speaker
            combined_segments.append(current_speaker_segment)
            # Start a new segment accumulation with the current speaker
            current_speaker_segment = {
                "speaker": next_segment["label"],
                "segment": {"start": next_segment["segment"]["start"], "end": next_segment["segment"]["end"]},
            }
        else:
            # Same speaker, extend the end time of the current accumulated segment
            current_speaker_segment["segment"]["end"] = next_segment["segment"]["end"]
    # Add the very last accumulated segment after the loop finishes
    combined_segments.append(current_speaker_segment)
    return combined_segments
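
# Example of the combined-segment structure returned above (illustrative values):
#   [
#       {"speaker": "SPEAKER_00", "segment": {"start": 0.00, "end": 4.20}},
#       {"speaker": "SPEAKER_01", "segment": {"start": 4.20, "end": 9.80}},
#   ]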

def post_process_segments_and_transcripts(combined_diarization_segments: list, asr_transcript_chunks: list) -> list:
    """
    Aligns combined diarization segments with ASR transcript chunks.
    This logic closely follows the provided `diarization_utils.py`'s `post_process_segments_and_transcripts`
    function, which uses `argmin` for alignment and slicing for chunk consumption.
    """
    if not asr_transcript_chunks:
        return []

    # Get the end timestamps for each ASR chunk
    # Use sys.float_info.max for None to ensure `argmin` works
    asr_end_timestamps = np.array(
        [chunk["timestamp"][1] if chunk["timestamp"][1] is not None else sys.float_info.max for chunk in asr_transcript_chunks]
    )
    # Create mutable copies to slice from
    current_asr_chunks = list(asr_transcript_chunks)
    current_asr_end_timestamps = asr_end_timestamps.copy()
    final_segmented_transcript = []

    for diar_segment in combined_diarization_segments:
        if not current_asr_chunks:
            break  # No more ASR chunks to process
        diar_start = diar_segment["segment"]["start"]
        diar_end = diar_segment["segment"]["end"]
        speaker = diar_segment["speaker"]

        # Find the index of the ASR chunk whose end timestamp is closest to diar_end
        # Ensure argmin operates on a non-empty array
        if current_asr_end_timestamps.size == 0:
            logger.warning("No ASR end timestamps left to align with diarization segment. Breaking alignment.")
            break  # No more ASR chunks to align
        upto_idx_relative = np.argmin(np.abs(current_asr_end_timestamps - diar_end))
        chunks_for_this_diar_segment = current_asr_chunks[:upto_idx_relative + 1]
        if not chunks_for_this_diar_segment:
            logger.warning(f"No ASR chunks selected for diarization segment [{diar_start:.2f}-{diar_end:.2f}] {speaker}. Skipping.")
            continue

        # Initialize with extreme values to find min/max correctly, handling None timestamps
        asr_min_start_val = float('inf')
        asr_max_end_val = float('-inf')
        all_text = []
        for chunk in chunks_for_this_diar_segment:
            all_text.append(chunk["text"])
            if chunk["timestamp"] and chunk["timestamp"][0] is not None:
                asr_min_start_val = min(asr_min_start_val, chunk["timestamp"][0])
            if chunk["timestamp"] and chunk["timestamp"][1] is not None:
                asr_max_end_val = max(asr_max_end_val, chunk["timestamp"][1])
        combined_text = "".join(all_text).strip()

        # If no valid timestamps were found in the selected ASR chunks, fall back to the diarization segment's bounds
        if asr_min_start_val == float('inf'):
            logger.warning(f"No valid start timestamps in ASR chunks for segment [{diar_start:.2f}-{diar_end:.2f}] {speaker}. Using diarization start.")
            asr_min_start_val = diar_start
        if asr_max_end_val == float('-inf'):
            logger.warning(f"No valid end timestamps in ASR chunks for segment [{diar_start:.2f}-{diar_end:.2f}] {speaker}. Using diarization end.")
            asr_max_end_val = diar_end

        # Ensure the final timestamp range makes sense and is clamped by the diarization segment
        final_segment_start = max(diar_start, asr_min_start_val)
        final_segment_end = min(diar_end, asr_max_end_val)
        final_segmented_transcript.append(
            {
                "speaker": speaker,
                "text": combined_text,
                "timestamp": (final_segment_start, final_segment_end),
            }
        )
        # Crop the transcripts and timestamp lists according to the latest timestamp
        current_asr_chunks = current_asr_chunks[upto_idx_relative + 1:]
        current_asr_end_timestamps = current_asr_end_timestamps[upto_idx_relative + 1:]
    return final_segmented_transcript
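
# Illustrative alignment (made-up values): given ASR chunks
#   [{"text": " Hello there.", "timestamp": (0.0, 2.1)},
#    {"text": " How are you?", "timestamp": (2.1, 4.0)}]
# and one diarization segment {"speaker": "SPEAKER_00", "segment": {"start": 0.0, "end": 4.2}},
# the function returns
#   [{"speaker": "SPEAKER_00", "text": "Hello there. How are you?", "timestamp": (0.0, 4.0)}]
# i.e. both chunks are attributed to SPEAKER_00 and the timestamps are clamped to the overlap.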

def diarize_and_align_transcript(diarization_pipeline: Pipeline, original_sampling_rate: int,
                                 audio_numpy_array: np.ndarray, parameters: InferenceConfig, asr_outputs: dict) -> list:
    """
    Orchestrates the entire diarization and transcript alignment process.
    """
    # 1. Preprocess audio for the diarization model (resample to 16kHz, ensure mono, convert to torch.Tensor)
    diarizer_input_tensor, processed_sampling_rate = preprocess_audio_for_diarization(
        original_sampling_rate, audio_numpy_array
    )
    # 2. Perform diarization to get speaker segments.
    #    The copy carries the processed sampling rate alongside the parameters for reference only;
    #    diarize_audio always passes 16 kHz audio to the pipeline and reads just the speaker-count fields.
    diarization_params_for_pipeline = parameters.model_copy(update={"sampling_rate": processed_sampling_rate})
    combined_diarization_segments = diarize_audio(
        diarizer_input_tensor,
        diarization_pipeline,
        diarization_params_for_pipeline
    )
    # 3. Align diarization segments with ASR transcript chunks
    aligned_transcript = post_process_segments_and_transcripts(
        combined_diarization_segments, asr_outputs["chunks"]
    )
    return aligned_transcript

# --- Main Prediction Function for Gradio Interface ---
def predict_audio(
    audio_file_tuple: tuple[int, np.ndarray],
    batch_size: int,
    chunk_length_s: int,
    language: str,
    num_speakers: Optional[int],
    min_speakers: Optional[int],
    max_speakers: Optional[int]
) -> tuple[str, str, str]:
    """
    Gradio-compatible function to perform ASR and optionally speaker diarization.

    Args:
        audio_file_tuple: A tuple (sampling_rate, numpy_array) from Gradio's gr.Audio input.
        batch_size: Batch size for ASR inference.
        chunk_length_s: Chunk length for ASR inference in seconds.
        language: Language for ASR (e.g., "English", "Auto-detect").
        num_speakers: Expected number of speakers for diarization (optional).
        min_speakers: Minimum number of speakers for diarization (optional).
        max_speakers: Maximum number of speakers for diarization (optional).

    Returns:
        A tuple containing:
        - formatted_diarized_text: A string with the diarized transcript.
        - full_transcript_text: A string with the full ASR transcript.
        - status_message: A message indicating success or failure.
    """
    if audio_file_tuple is None:
        msg = "Please upload an audio file."
        gr.Warning(msg)  # gr.Warning shows a toast; the message is also returned as the status text
        return "", "", msg
    sampling_rate, audio_numpy_array = audio_file_tuple
    if audio_numpy_array is None or audio_numpy_array.size == 0:
        msg = "Audio file is empty. Please upload a valid audio file."
        gr.Warning(msg)
        return "", "", msg

    # Ensure audio_numpy_array is float32 in [-1.0, 1.0] as expected by the transformers pipeline.
    # gr.Audio(type="numpy") typically returns int16 PCM, so integer samples are rescaled.
    if audio_numpy_array.dtype != np.float32:
        if np.issubdtype(audio_numpy_array.dtype, np.integer):
            audio_numpy_array = audio_numpy_array.astype(np.float32) / np.iinfo(audio_numpy_array.dtype).max
        else:
            audio_numpy_array = audio_numpy_array.astype(np.float32)
    # If stereo, convert to mono for consistent processing (e.g., take the first channel)
    if len(audio_numpy_array.shape) > 1:
        audio_numpy_array = audio_numpy_array[:, 0]
    # Process speaker parameters: convert 0 or negative values to None for pyannote compatibility
    processed_num_speakers = num_speakers if num_speakers is not None and num_speakers > 0 else None
    processed_min_speakers = min_speakers if min_speakers is not None and min_speakers > 0 else None
    processed_max_speakers = max_speakers if max_speakers is not None and max_speakers > 0 else None
    # Validation logic for min/max speakers
    if processed_min_speakers is not None and processed_max_speakers is not None and processed_min_speakers > processed_max_speakers:
        msg = "Diarization: Min Speakers cannot be greater than Max Speakers."
        gr.Warning(msg)
        return "", "", msg
    if processed_num_speakers is not None:
        if processed_min_speakers is not None and processed_num_speakers < processed_min_speakers:
            msg = "Diarization: Number of Speakers cannot be less than Min Speakers."
            gr.Warning(msg)
            return "", "", msg
        if processed_max_speakers is not None and processed_num_speakers > processed_max_speakers:
            msg = "Diarization: Number of Speakers cannot be greater than Max Speakers."
            gr.Warning(msg)
            return "", "", msg
    # Create an InferenceConfig object from Gradio inputs for internal validation and use.
    try:
        parameters = InferenceConfig(
            batch_size=batch_size,
            chunk_length_s=chunk_length_s,
            # Whisper expects lowercase language names (e.g., "english"); None enables auto-detection.
            language=language.lower() if language != "Auto-detect" else None,
            num_speakers=processed_num_speakers,
            min_speakers=processed_min_speakers,
            max_speakers=processed_max_speakers,
        )
    except Exception as e:
        logger.error(f"Error validating parameters: {e}")
        # gr.Error must be raised (not returned) to surface as an error in the UI.
        raise gr.Error(f"Error validating input parameters: {e}")
    logger.info(f"Inference parameters: {parameters.model_dump_json()}")
    logger.info(f"Audio sampling rate: {sampling_rate} Hz, Audio shape: {audio_numpy_array.shape}")
    asr_pipeline = models.get("asr_pipeline")
    diarization_pipeline = models.get("diarization_pipeline")
    if not asr_pipeline:
        raise gr.Error("ASR model not loaded. Please restart the application.")
    # ASR language and batch size conflict warning
    if parameters.language is None and parameters.batch_size > 1:
        msg = (
            "ASR: 'Auto-detect' language is not supported with batch size > 1. "
            "Please select a specific language or set batch size to 1."
        )
        gr.Warning(msg)
        return "", "", msg
    # Prepare ASR generation arguments
    generate_kwargs = {
        "task": parameters.task,
        "language": parameters.language,
        "assistant_model": None  # Speculative decoding is disabled
    }
    asr_outputs = None
    try:
        logger.info("Starting ASR inference...")
        # Pass the audio together with its sampling rate so the pipeline can resample to the
        # model's expected 16 kHz instead of assuming the array is already at that rate.
        asr_outputs = asr_pipeline(
            {"array": audio_numpy_array, "sampling_rate": sampling_rate},
            chunk_length_s=parameters.chunk_length_s,
            batch_size=parameters.batch_size,
            generate_kwargs=generate_kwargs,
            return_timestamps=True,
        )
        logger.info("ASR inference completed.")
    except Exception as e:
        logger.error(f"ASR inference error: {str(e)}")
        raise gr.Error(f"ASR inference error: {str(e)}")
    final_transcript_data = []
    status_message = ""
    if diarization_pipeline:
        try:
            logger.info("Starting Diarization inference and alignment...")
            final_transcript_data = diarize_and_align_transcript(
                diarization_pipeline, sampling_rate, audio_numpy_array, parameters, asr_outputs
            )
            status_message = "Diarization and ASR successful!"
            logger.info("Diarization and alignment completed.")
        except Exception as e:
            logger.error(f"Diarization inference error: {str(e)}")
            # If diarization fails, still provide the full ASR transcript
            final_transcript_data = []  # Clear any partial diarization
            status_message = f"Diarization failed: {str(e)}. Displaying full ASR transcript only."
    else:
        logger.info("Diarization pipeline not loaded, skipping diarization and returning raw ASR chunks.")
        # If no diarization, format ASR chunks as if they were from a single "Speaker"
        for chunk in asr_outputs["chunks"]:
            final_transcript_data.append({
                "speaker": "Speaker",  # Generic label
                "text": chunk["text"],
                "timestamp": chunk["timestamp"]
            })
        status_message = "Diarization not enabled. Displaying full ASR transcript by chunk."
    # Format the output for Gradio display
    formatted_diarized_text_output = []
    for entry in final_transcript_data:
        start_time = f"{entry['timestamp'][0]:.2f}" if entry['timestamp'][0] is not None else "0.00"
        end_time = f"{entry['timestamp'][1]:.2f}" if entry['timestamp'][1] is not None else "End"
        formatted_diarized_text_output.append(
            f"[{start_time} - {end_time}] {entry['speaker']}: {entry['text'].strip()}"
        )
    full_asr_text_output = asr_outputs["text"] if asr_outputs else "No ASR transcript generated."
    return (
        "\n".join(formatted_diarized_text_output),
        full_asr_text_output,
        status_message
    )
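
# Each diarized line is rendered as, for example:
#   [0.00 - 4.00] SPEAKER_00: Hello there. How are you?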

# --- Gradio Interface Definition ---
# List of languages supported by OpenAI Whisper models
WHISPER_LANGUAGES = [
    "Auto-detect", "English", "Chinese", "German", "Spanish", "Russian", "Korean", "French", "Japanese", "Portuguese",
    "Turkish", "Polish", "Catalan", "Dutch", "Arabic", "Swedish", "Italian", "Indonesian", "Hindi", "Finnish",
    "Vietnamese", "Hebrew", "Ukrainian", "Greek", "Malay", "Czech", "Romanian", "Danish", "Hungarian", "Tamil",
    "Norwegian", "Thai", "Urdu", "Croatian", "Bulgarian", "Lithuanian", "Latin", "Maori", "Malayalam", "Afrikaans",
    "Welsh", "Belarusian", "Gujarati", "Kannada", "Armenian", "Azerbaijani", "Serbian", "Slovenian", "Estonian",
    "Burmese", "Galician", "Mongolian", "Lao", "Kazakh", "Georgian", "Amharic", "Nepali", "Bosnian", "Luxembourgish",
    "Pashto", "Tagalog", "Malagasy", "Albanian", "Sindhi", "Kurdish", "Somali", "Telugu", "Tajik", "Swahili",
    "Kashmiri"
]

demo = gr.Interface(
    fn=predict_audio,
    inputs=[
        gr.Audio(type="numpy", label="Upload Audio File (WAV, MP3, FLAC, etc.)"),
        gr.Slider(minimum=1, maximum=32, value=1, step=1, label="ASR Batch Size"),
        gr.Slider(minimum=1, maximum=30, value=30, step=1, label="ASR Chunk Length (seconds)"),
        gr.Dropdown(WHISPER_LANGUAGES, value="Chinese", label="ASR Language"),
        gr.Number(label="Diarization: Number of Speakers (optional)", value=None, precision=0, info="Expected total number of speakers (positive integer, or leave empty for auto-detect)."),
        gr.Number(label="Diarization: Min Speakers (optional)", value=None, precision=0, info="Minimum number of speakers to detect (positive integer, or leave empty for auto-detect)."),
        gr.Number(label="Diarization: Max Speakers (optional)", value=None, precision=0, info="Maximum number of speakers to detect (positive integer, or leave empty for auto-detect).")
    ],
    outputs=[
        gr.Textbox(label="Diarized Transcript", lines=10, interactive=False),
        gr.Textbox(label="Full ASR Transcript", lines=5, interactive=False),
        gr.Textbox(label="Status Message", lines=1, interactive=False)
    ],
    title="Whisper ASR with Pyannote Speaker Diarization",
    description=(
        "Upload an audio file to get a transcript with speaker diarization. "
        "This demo uses `openai/whisper-large-v3-turbo` for ASR and `pyannote/speaker-diarization-3.1` for diarization. "
        "A Hugging Face token with access to `pyannote/speaker-diarization-3.1` is required. "
        "Please set it as an `HF_TOKEN` environment variable before launching (see the script comments)."
        "<br><b>Note:</b> For long audio files or heavy concurrent usage, consider a GPU and a larger model such as `openai/whisper-large-v3`."
    ),
    allow_flagging="never",  # Disable Gradio's flagging feature
    examples=[
        # Adjust this path if the `model-server/app/tests/` directory is not alongside your `app.py`.
        # For example, if app.py is in the root and the audio is in a tests/ subdirectory, you might use:
        #   ["tests/polyai-minds14-0.wav", 1, 30, "Auto-detect", None, None, None]
        # (batch size 1 is used because "Auto-detect" is rejected with batch sizes > 1 above).
        [os.path.join(os.path.dirname(__file__), "model-server", "app", "tests", "polyai-minds14-0.wav"), 1, 30, "Auto-detect", None, None, None]
    ],
    cache_examples=False  # Disable caching of examples to prevent InvalidPathError
)

if __name__ == "__main__":
    logger.info("Starting Gradio demo...")
    demo.launch()
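
    # Note (optional): demo.launch() uses Gradio's defaults (http://127.0.0.1:7860). To bind a
    # specific host and port, e.g. inside a container, demo.launch(server_name="0.0.0.0",
    # server_port=7860) is the Gradio-native equivalent of the uvicorn flags mentioned above.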