from pathlib import Path
import os
from os import path as osp

import gradio as gr
from dotenv import load_dotenv

from crud.vector_store import MultimodalLanceDB
from preprocess.embedding import BridgeTowerEmbeddings
from preprocess.preprocessing import extract_and_save_frames_and_metadata
from utils import (
    download_video,
    get_transcript_vtt,
    download_youtube_subtitle,
    get_video_id_from_url,
    str2time,
    maintain_aspect_ratio_resize,
    getSubs,
    encode_image,
)
from mistralai import Mistral
from langchain_core.runnables import (
    RunnableParallel,
    RunnablePassthrough,
    RunnableLambda,
)
from PIL import Image
import lancedb

# -------------------------------
# 1. Setup - HuggingFace Spaces Configuration
# -------------------------------
# Load variables (e.g. MISTRAL_API_KEY) from a .env file into the environment.
load_dotenv()

# A non-None SPACE_ID is treated as "running on HuggingFace Spaces"; that
# switches all generated data (vector DB + downloaded videos) to /tmp.
SPACE_ID = os.getenv("SPACE_ID")
IS_SPACES = SPACE_ID is not None

if IS_SPACES:
    os.makedirs("/tmp", exist_ok=True)
    LANCEDB_HOST_FILE, VIDEO_DIR = "/tmp/.lancedb", "/tmp/videos/video1"
else:
    LANCEDB_HOST_FILE, VIDEO_DIR = (
        "./shared_data/.lancedb",
        "./shared_data/videos/video1",
    )

# Name of the LanceDB table that holds the (transcript, frame) embeddings.
TBL_NAME = "vectorstore"

# Module-wide LanceDB connection and embedding model, shared by ingestion
# and retrieval below.
db = lancedb.connect(LANCEDB_HOST_FILE)
embedder = BridgeTowerEmbeddings()

# -------------------------------
# 2. Preprocessing + Storage
# -------------------------------
def _merge_transcript_windows(transcripts, window: int = 7):
    """Return one widened transcript string per segment.

    Entry ``i`` is the space-joined text of the segments whose indices lie in
    ``[i - window // 2, i + window // 2]``, clipped to the valid range. For
    ``window=7`` that is up to 3 neighbours on each side plus the segment
    itself (fewer at the start and end of the list).

    Args:
        transcripts: sequence of per-segment transcript strings.
        window: total number of segments in a full (unclipped) window.

    Returns:
        A list with the same length as ``transcripts``.
    """
    half = window // 2
    # `i + half + 1` so the slice is symmetric around `i`; `max(0, ...)`
    # clips the left edge instead of letting a negative index wrap around.
    return [
        " ".join(transcripts[max(0, i - half): i + half + 1])
        for i in range(len(transcripts))
    ]


def preprocess_and_store(youtube_url: str):
    """Download a video, extract frames + metadata, embed and store in LanceDB.

    The video and its subtitle file are downloaded into ``VIDEO_DIR``, and
    ``VIDEO_DIR/extracted_frame`` is passed as the frame output folder. Each
    frame's transcript is widened with neighbouring segments (see
    ``_merge_transcript_windows``), the widened text is written back into the
    frame's metadata, and the (text, image) pairs are stored in table
    ``TBL_NAME`` via ``MultimodalLanceDB.from_text_image_pairs`` with
    ``mode="overwrite"``.

    Args:
        youtube_url: URL of the YouTube video to ingest.

    Returns:
        A status string starting with "✅" on success or "❌" on failure.
        Exceptions are not propagated; their message is embedded in the
        returned string.
    """
    video_dir = VIDEO_DIR
    try:
        extracted_frames_path = osp.join(video_dir, 'extracted_frame')
        # Create the output folders *before* anything is downloaded into them.
        Path(video_dir).mkdir(parents=True, exist_ok=True)
        Path(extracted_frames_path).mkdir(parents=True, exist_ok=True)

        # Download the YouTube video and its subtitle file.
        video_filepath = download_video(youtube_url, video_dir)
        video_transcript_filepath = download_youtube_subtitle(youtube_url, video_dir)

        # Extract frames and per-frame metadata (transcript + frame path).
        metadatas = extract_and_save_frames_and_metadata(
            video_filepath,
            video_transcript_filepath,
            extracted_frames_path,
            video_dir,
        )
        if not metadatas:
            return f"❌ Error processing video: no frames were extracted from {youtube_url}"

        video_trans = [vid['transcript'] for vid in metadatas]
        video_img_path = [vid['extracted_frame_path'] for vid in metadatas]

        updated_video_trans = _merge_transcript_windows(video_trans, window=7)
        # Keep the stored metadata consistent with the text that gets embedded.
        for metadata, text in zip(metadatas, updated_video_trans):
            metadata['transcript'] = text

        MultimodalLanceDB.from_text_image_pairs(
            texts=updated_video_trans,
            image_paths=video_img_path,
            embedding=embedder,
            metadatas=metadatas,
            connection=db,
            table_name=TBL_NAME,
            mode="overwrite",
        )
        return f"✅ Video processed and stored: {youtube_url}"
    except Exception as e:
        # UI boundary: report the failure to the user instead of raising.
        return f"❌ Error processing video: {str(e)}"


# -------------------------------
# 3. Retrieval + Prompt Functions
# -------------------------------
vectorstore = MultimodalLanceDB(
    uri=LANCEDB_HOST_FILE,
    embedding=embedder,
    table_name=TBL_NAME,
)
# NOTE(review): this vectorstore/retriever is built once at import time, while
# preprocess_and_store() rewrites the table later through a separate
# MultimodalLanceDB instance. Confirm that searches here see the overwritten
# table (and that construction succeeds before the table exists).
retriever_module = vectorstore.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 3},
)


def prompt_processing(input):
    """Build the vision-language prompt from the first retrieved document.

    Args:
        input: mapping with ``"retrieved_results"`` (list of retrieved
            documents; only the first one is used) and ``"user_query"``
            (the user's question).

    Returns:
        dict with ``"prompt"`` (str) and ``"frame_path"`` (the retrieved
        frame's ``extracted_frame_path`` metadata, or ``None`` when nothing
        was retrieved).
    """
    retrieved_results = input["retrieved_results"]
    user_query = input["user_query"]
    if not retrieved_results:
        return {"prompt": "No relevant content found.", "frame_path": None}

    top_metadata = retrieved_results[0].metadata
    prompt_template = (
        "The transcript associated with the image is '{transcript}'. "
        "{user_query}"
    )
    return {
        "prompt": prompt_template.format(
            transcript=top_metadata["transcript"], user_query=user_query
        ),
        "frame_path": top_metadata["extracted_frame_path"],
    }


def lvlm_inference(input):
    """Send the prompt and the retrieved frame to Pixtral and return its answer.

    Args:
        input: dict produced by ``prompt_processing`` (``"prompt"``,
            ``"frame_path"``).

    Returns:
        ``(answer_text, frame_path)``. Errors are reported in
        ``answer_text`` (prefixed with "❌") rather than raised.
    """
    # Bound before the try so the error handler can always reference it.
    frame_path = None
    try:
        lvlm_prompt = input['prompt']
        frame_path = input['frame_path']
        if frame_path is None:
            return "No relevant frame found.", None

        api_key = os.getenv("MISTRAL_API_KEY")
        if not api_key:
            return "❌ MISTRAL_API_KEY not found. Please set it in the environment variables.", frame_path

        client = Mistral(api_key=api_key)
        base64_image = encode_image(frame_path)

        # One user turn carrying both the text prompt and the frame as a
        # base64 data URL.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": lvlm_prompt},
                    {
                        "type": "image_url",
                        "image_url": f"data:image/jpeg;base64,{base64_image}",
                    },
                ],
            }
        ]

        chat_response = client.chat.complete(
            model="pixtral-12b-2409",
            messages=messages,
        )
        return chat_response.choices[0].message.content, frame_path
    except Exception as e:
        return f"❌ Error in inference: {str(e)}", frame_path


# LangChain Runnable chain: retrieve + pass the query through in parallel,
# build the prompt, then run the vision-language model.
prompt_processing_module = RunnableLambda(prompt_processing)
lvlm_inference_module = RunnableLambda(lvlm_inference)
mm_rag_chain = (
    RunnableParallel({"retrieved_results": retriever_module, "user_query": RunnablePassthrough()})
    | prompt_processing_module
    | lvlm_inference_module
)

# -------------------------------
# 4. Chat API for Gradio
# -------------------------------
# Process-wide flag (shared by all sessions): set once a video was ingested.
video_loaded = False


def load_video(youtube_url):
    """Gradio handler: ingest ``youtube_url`` and return a status string."""
    global video_loaded
    if not youtube_url or not youtube_url.strip():
        return "❌ Please enter a YouTube URL"
    try:
        status = preprocess_and_store(youtube_url)
        if "✅" in status:
            video_loaded = True
        return status
    except Exception as e:
        return f"❌ Error loading video: {str(e)}"


def chat_interface(message, history):
    """Gradio handler: answer ``message`` about the loaded video.

    Args:
        message: the user's question.
        history: chatbot history as a list of ``(user, bot)`` tuples.

    Returns:
        ``("", new_history, image_or_None)``. The empty string clears the
        input box; the image is the retrieved frame, if any.
    """
    # Work on a copy so the incoming history object is never mutated.
    history = list(history or [])

    # Ignore blank input before anything else, so it never creates a chat entry.
    if not message or not message.strip():
        return "", history, None
    if not video_loaded:
        history.append((message, "❌ Please load a video first in the 'Load Video' tab."))
        return "", history, None

    try:
        final_text_response, frame_path = mm_rag_chain.invoke(message)
    except Exception as e:
        history.append((message, f"❌ Error processing query: {str(e)}"))
        return "", history, None
    history.append((message, final_text_response))

    retrieved_image = None
    if frame_path:
        try:
            # Image.open is lazy and keeps the file open; copy() loads the
            # pixels so the handle can be closed by the context manager.
            with Image.open(frame_path) as img:
                retrieved_image = img.copy()
        except Exception as e:
            print(f"Error loading image: {e}")
    return "", history, retrieved_image


def clear_chat():
    """Gradio handler: reset the chatbot history and the retrieved frame."""
    return [], None


# -------------------------------
# 5. Enhanced Gradio Interface
# -------------------------------
with gr.Blocks(
    title="Multimodal RAG Video Chat",
    theme=gr.themes.Default()
) as demo:
    gr.Markdown("""
    # 🎬 Multimodal RAG Video Chat
    Chat with YouTube videos using BridgeTower embeddings + LanceDB + Pixtral Vision-Language Model!

    ⚠️ **Important**: You need to set your `MISTRAL_API_KEY` in the Space settings for this to work.
    """)

    with gr.Tab("1. Load Video"):
        with gr.Column():
            youtube_url = gr.Textbox(
                label="YouTube URL",
                placeholder="https://www.youtube.com/watch?v=...",
                lines=1,
                scale=4
            )
            with gr.Row():
                load_btn = gr.Button("🔄 Process Video", variant="primary", scale=1)
            status = gr.Textbox(
                label="Status",
                interactive=False,
                lines=2
            )

        load_btn.click(
            fn=load_video,
            inputs=youtube_url,
            outputs=status,
            show_progress=True
        )

    with gr.Tab("2. Chat with Video"):
        with gr.Row():
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    label="Chat about the video",
                    height=500
                )
            with gr.Column(scale=1):
                retrieved_image = gr.Image(
                    label="Retrieved Frame",
                    height=400,
                    show_label=True,
                    interactive=False
                )

        with gr.Row():
            with gr.Column(scale=4):
                msg = gr.Textbox(
                    label="Your question",
                    placeholder="Ask something about the video content...",
                    lines=2,
                    container=False
                )
            with gr.Column(scale=1, min_width=100):
                send_btn = gr.Button("📤 Send", variant="primary")
                clear_btn = gr.Button("🗑️ Clear", variant="secondary")

        # Event handlers: Enter and the Send button both run the same query.
        msg.submit(
            fn=chat_interface,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot, retrieved_image],
            show_progress=True
        )
        send_btn.click(
            fn=chat_interface,
            inputs=[msg, chatbot],
            outputs=[msg, chatbot, retrieved_image],
            show_progress=True
        )
        clear_btn.click(
            fn=clear_chat,
            outputs=[chatbot, retrieved_image]
        )

    with gr.Tab("📖 Instructions"):
        gr.Markdown("""
        ## How to use this Multimodal RAG system:

        ### 🔧 Setup:
        1. **Set API Key**: Make sure `MISTRAL_API_KEY` is set in your Space settings
        2. This app uses Pixtral-12B for vision-language understanding

        ### 🎥 Load Video:
        1. Go to the "Load Video" tab
        2. Paste a YouTube URL (make sure it's publicly accessible)
        3. Click "🔄 Process Video" and wait for processing to complete
        4. Look for the ✅ success message

        ### 💬 Chat with Video:
        1. Go to the "Chat with Video" tab
        2. Ask questions about the video content
        3. The system will retrieve the most relevant frame and provide answers
        4. The retrieved frame will be displayed on the right side

        ## ✨ Features:
        - 🎥 **Automatic YouTube Processing**: Downloads and processes YouTube videos
        - 🧠 **Multimodal Embeddings**: Uses BridgeTower for combined text+image understanding
        - 💾 **Vector Storage**: Stores data in LanceDB for fast similarity search
        - 🤖 **Vision-Language AI**: Powered by Mistral's Pixtral model
        - 🖼️ **Visual Context**: Shows relevant video frames alongside responses
        - 🔄 **Real-time Processing**: Fast retrieval and inference

        ## ⚠️ Limitations:
        - Works with publicly accessible YouTube videos only
        - Processing time depends on video length
        - Requires stable internet connection for video download
        - API rate limits apply based on Mistral usage

        ## 🛠️ Technical Stack:
        - **Embeddings**: BridgeTower (multimodal)
        - **Vector DB**: LanceDB
        - **Vision-Language Model**: Pixtral-12B
        - **Framework**: LangChain + Gradio
        """)

    with gr.Tab("🔍 About"):
        gr.Markdown("""
        ## Multimodal RAG Video Chat System

        This application demonstrates a complete multimodal Retrieval-Augmented Generation (RAG) pipeline that can understand and answer questions about video content.

        ### Architecture:
        1. **Video Processing**: Downloads YouTube videos and extracts frames with timestamps
        2. **Multimodal Embedding**: Uses BridgeTower to create embeddings that understand both visual and textual content
        3. **Vector Storage**: Stores embeddings in LanceDB for efficient similarity search
        4. **Retrieval**: Finds the most relevant video segments based on user queries
        5. **Generation**: Uses Pixtral vision-language model to generate contextual responses

        ### Built with:
        - **Gradio**: For the web interface
        - **LangChain**: For orchestrating the RAG pipeline
        - **LanceDB**: For vector storage and retrieval
        - **BridgeTower**: For multimodal embeddings
        - **Mistral Pixtral**: For vision-language understanding

        ---

        💡 **Tip**: For best results, ask specific questions about visual content, actions, or scenes in the video.
        """)

# -------------------------------
# 6. Launch Configuration
# -------------------------------
if __name__ == "__main__":
    print('🚀 Starting Multimodal RAG Video Chat App...')

    # Check for required environment variables
    if not os.getenv("MISTRAL_API_KEY"):
        print("⚠️ WARNING: MISTRAL_API_KEY not found in environment variables")
        print(" Please set this in your HuggingFace Space settings")

    if IS_SPACES:
        # Spaces already serves the app publicly; Gradio share links are not
        # supported there, so launch without share=True.
        demo.launch(server_name="0.0.0.0", server_port=7860)
    else:
        demo.launch(share=True, server_name="0.0.0.0", server_port=7860)