doggdad committed on
Commit
ea92a48
·
verified ·
1 Parent(s): 21cbf97

Create app.py

Files changed (1)
  1. src/app.py +414 -0
src/app.py ADDED
@@ -0,0 +1,414 @@
+ from pathlib import Path
+ import os
+ from os import path as osp
+ import gradio as gr
+ from dotenv import load_dotenv
+ from crud.vector_store import MultimodalLanceDB
+ from preprocess.embedding import BridgeTowerEmbeddings
+ from preprocess.preprocessing import extract_and_save_frames_and_metadata
+ from utils import (
+     download_video,
+     get_transcript_vtt,
+     download_youtube_subtitle,
+     get_video_id_from_url,
+     str2time,
+     maintain_aspect_ratio_resize,
+     getSubs,
+     encode_image,
+ )
+ from mistralai import Mistral
+ from langchain_core.runnables import (
+     RunnableParallel,
+     RunnablePassthrough,
+     RunnableLambda,
+ )
+ from PIL import Image
+ import lancedb
+
+ # -------------------------------
+ # 1. Setup - HuggingFace Spaces Configuration
+ # -------------------------------
+ load_dotenv()
+
+ # HuggingFace Spaces specific setup
+ SPACE_ID = os.getenv("SPACE_ID")
+ IS_SPACES = SPACE_ID is not None
+
+ if IS_SPACES:
+     LANCEDB_HOST_FILE = "/tmp/.lancedb"
+     VIDEO_DIR = "/tmp/videos/video1"
+     os.makedirs(VIDEO_DIR, exist_ok=True)  # /tmp itself always exists; create the working dir instead
+ else:
+     LANCEDB_HOST_FILE = "./shared_data/.lancedb"
+     VIDEO_DIR = "./shared_data/videos/video1"
+
+ TBL_NAME = "vectorstore"
+
+ # Initialize components
+ db = lancedb.connect(LANCEDB_HOST_FILE)
+ embedder = BridgeTowerEmbeddings()
+
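+ # Note: BridgeTower produces joint text+image embeddings, so a single LanceDB
+ # table serves both transcript and frame similarity. On Spaces, /tmp is
+ # writable but ephemeral, so the table is rebuilt each session.
+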
+ # -------------------------------
+ # 2. Preprocessing + Storage
+ # -------------------------------
+ def preprocess_and_store(youtube_url: str):
+     """Download video, extract frames+metadata, embed & store in LanceDB"""
+     try:
+         video_url = youtube_url
+         video_dir = VIDEO_DIR
+
+         extracted_frames_path = osp.join(video_dir, 'extracted_frame')
+
+         # create the output folders before downloading anything into them
+         Path(video_dir).mkdir(parents=True, exist_ok=True)
+         Path(extracted_frames_path).mkdir(parents=True, exist_ok=True)
+
+         # download the YouTube video
+         video_filepath = download_video(video_url, video_dir)
+
+         # download the YouTube video's subtitles
+         video_transcript_filepath = download_youtube_subtitle(video_url, video_dir)
+
+         # extract frames and their metadata
+         metadatas = extract_and_save_frames_and_metadata(
+             video_filepath,
+             video_transcript_filepath,
+             extracted_frames_path,
+             video_dir,
+         )
+
+         # collect transcripts and image paths
+         video_trans = [vid['transcript'] for vid in metadatas]
+         video_img_path = [vid['extracted_frame_path'] for vid in metadatas]
+
+         # widen each frame's transcript with a window of n neighbouring segments
+         n = 7
+         updated_video_trans = [
+             ' '.join(video_trans[i - n // 2 : i + n // 2]) if i - n // 2 >= 0
+             else ' '.join(video_trans[0 : i + n // 2])
+             for i in range(len(video_trans))
+         ]
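+         # e.g. with n = 7, frame i is joined with transcript segments i-3 .. i+2;
+         # near the start of the video the window is clipped to begin at index 0,
+         # so short caption fragments still carry enough context to embed well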
+
+         # write the widened transcripts back into the metadata
+         for i in range(len(updated_video_trans)):
+             metadatas[i]['transcript'] = updated_video_trans[i]
+
+         _ = MultimodalLanceDB.from_text_image_pairs(
+             texts=updated_video_trans,
+             image_paths=video_img_path,
+             embedding=embedder,
+             metadatas=metadatas,
+             connection=db,
+             table_name=TBL_NAME,
+             mode="overwrite",  # replace the table for every newly loaded video
+         )
+
+         return f"✅ Video processed and stored: {youtube_url}"
+
+     except Exception as e:
+         return f"❌ Error processing video: {str(e)}"
+
+ # -------------------------------
+ # 3. Retrieval + Prompt Functions
+ # -------------------------------
+ vectorstore = MultimodalLanceDB(
+     uri=LANCEDB_HOST_FILE,
+     embedding=embedder,
+     table_name=TBL_NAME,
+ )
+
+ retriever_module = vectorstore.as_retriever(
+     search_type="similarity",
+     search_kwargs={"k": 3},
+ )
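+
+ # A direct retriever call returns LangChain Documents whose metadata carries
+ # the transcript and extracted_frame_path consumed downstream, e.g.
+ # (hypothetical query):
+ #   docs = retriever_module.invoke("What does the speaker demonstrate?")
+ #   docs[0].metadata["extracted_frame_path"]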
+
+ def prompt_processing(input):
+     retrieved_results = input["retrieved_results"]
+     user_query = input["user_query"]
+
+     if not retrieved_results:
+         return {"prompt": "No relevant content found.", "frame_path": None}
+
+     # keep only the top-ranked frame/transcript pair
+     retrieved_results = retrieved_results[0]
+     prompt_template = (
+         "The transcript associated with the image is '{transcript}'. "
+         "{user_query}"
+     )
+
+     retrieved_metadata = retrieved_results.metadata
+     transcript = retrieved_metadata["transcript"]
+     frame_path = retrieved_metadata["extracted_frame_path"]
+
+     return {
+         "prompt": prompt_template.format(transcript=transcript, user_query=user_query),
+         "frame_path": frame_path,
+     }
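+
+ # The assembled prompt reads, e.g.:
+ #   "The transcript associated with the image is '<window text>'. <question>"
+ # so the vision-language model sees the retrieved caption context alongside
+ # the frame it was extracted with.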
+
+ def lvlm_inference(input):
+     # pull out the prompt and the retrieved frame path first, so the except
+     # block below can still reference frame_path if a later step fails
+     lvlm_prompt = input['prompt']
+     frame_path = input['frame_path']
+
+     if frame_path is None:
+         return "No relevant frame found.", None
+
+     try:
+         # Retrieve the API key from environment variables
+         api_key = os.getenv("MISTRAL_API_KEY")
+         if not api_key:
+             return "❌ MISTRAL_API_KEY not found. Please set it in the environment variables.", frame_path
+
+         # Initialize the Mistral client
+         client = Mistral(api_key=api_key)
+
+         # Pixtral accepts the frame inline as a base64 data URI
+         base64_image = encode_image(frame_path)
+
+         # Define the messages for the chat
+         messages = [
+             {
+                 "role": "user",
+                 "content": [
+                     {
+                         "type": "text",
+                         "text": lvlm_prompt,
+                     },
+                     {
+                         "type": "image_url",
+                         "image_url": f"data:image/jpeg;base64,{base64_image}",
+                     },
+                 ],
+             }
+         ]
+
+         # Get the chat response
+         chat_response = client.chat.complete(
+             model="pixtral-12b-2409",
+             messages=messages,
+         )
+
+         return chat_response.choices[0].message.content, frame_path
+
+     except Exception as e:
+         return f"❌ Error in inference: {str(e)}", frame_path
+
+ # LangChain Runnable chain
+ prompt_processing_module = RunnableLambda(prompt_processing)
+ lvlm_inference_module = RunnableLambda(lvlm_inference)
+
+ mm_rag_chain = (
+     RunnableParallel({"retrieved_results": retriever_module, "user_query": RunnablePassthrough()})
+     | prompt_processing_module
+     | lvlm_inference_module
+ )
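+
+ # The chain fans the raw query out to the retriever and a passthrough, builds
+ # the prompt, then calls Pixtral. Example (assumes a video has been processed):
+ #   answer, frame = mm_rag_chain.invoke("What is shown on the whiteboard?")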
+
+ # -------------------------------
+ # 4. Chat API for Gradio
+ # -------------------------------
+ video_loaded = False
+
+ def load_video(youtube_url):
+     global video_loaded
+     if not youtube_url.strip():
+         return "❌ Please enter a YouTube URL"
+
+     try:
+         status = preprocess_and_store(youtube_url)
+         if "✅" in status:
+             video_loaded = True
+         return status  # surface the status message even when processing failed
+     except Exception as e:
+         return f"❌ Error loading video: {str(e)}"
+
+ def chat_interface(message, history):
+     if not video_loaded:
+         return "", history + [(message, "❌ Please load a video first in the 'Load Video' tab.")], None
+
+     if not message.strip():
+         return "", history, None
+
+     try:
+         final_text_response, frame_path = mm_rag_chain.invoke(message)
+         history.append((message, final_text_response))
+
+         # Load and return the image
+         retrieved_image = None
+         if frame_path:
+             try:
+                 retrieved_image = Image.open(frame_path)
+             except Exception as e:
+                 print(f"Error loading image: {e}")
+
+         return "", history, retrieved_image
+
+     except Exception as e:
+         error_msg = f"❌ Error processing query: {str(e)}"
+         history.append((message, error_msg))
+         return "", history, None
+
+ def clear_chat():
+     return [], None
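+
+ # History is kept as Gradio's (user, assistant) tuple pairs; newer Gradio
+ # releases also accept a "messages"-style dict format for gr.Chatbot.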
+
+ # -------------------------------
+ # 5. Enhanced Gradio Interface
+ # -------------------------------
+ with gr.Blocks(
+     title="Multimodal RAG Video Chat",
+     theme=gr.themes.Default(),
+ ) as demo:
+     gr.Markdown("""
+     # 🎬 Multimodal RAG Video Chat
+
+     Chat with YouTube videos using BridgeTower embeddings + LanceDB + Pixtral Vision-Language Model!
+
+     ⚠️ **Important**: You need to set your `MISTRAL_API_KEY` in the Space settings for this to work.
+     """)
+
+     with gr.Tab("1. Load Video"):
+         with gr.Column():
+             youtube_url = gr.Textbox(
+                 label="YouTube URL",
+                 placeholder="https://www.youtube.com/watch?v=...",
+                 lines=1,
+                 scale=4,
+             )
+             with gr.Row():
+                 load_btn = gr.Button("🔄 Process Video", variant="primary", scale=1)
+             status = gr.Textbox(
+                 label="Status",
+                 interactive=False,
+                 lines=2,
+             )
+
+             load_btn.click(
+                 fn=load_video,
+                 inputs=youtube_url,
+                 outputs=status,
+                 show_progress=True,
+             )
+
+     with gr.Tab("2. Chat with Video"):
+         with gr.Row():
+             with gr.Column(scale=2):
+                 chatbot = gr.Chatbot(
+                     label="Chat about the video",
+                     height=500,
+                 )
+
+             with gr.Column(scale=1):
+                 retrieved_image = gr.Image(
+                     label="Retrieved Frame",
+                     height=400,
+                     show_label=True,
+                     interactive=False,
+                 )
+
+         with gr.Row():
+             with gr.Column(scale=4):
+                 msg = gr.Textbox(
+                     label="Your question",
+                     placeholder="Ask something about the video content...",
+                     lines=2,
+                     container=False,
+                 )
+             with gr.Column(scale=1, min_width=100):
+                 send_btn = gr.Button("📤 Send", variant="primary")
+                 clear_btn = gr.Button("🗑️ Clear", variant="secondary")
+
+         # Event handlers
+         msg.submit(
+             fn=chat_interface,
+             inputs=[msg, chatbot],
+             outputs=[msg, chatbot, retrieved_image],
+             show_progress=True,
+         )
+         send_btn.click(
+             fn=chat_interface,
+             inputs=[msg, chatbot],
+             outputs=[msg, chatbot, retrieved_image],
+             show_progress=True,
+         )
+         clear_btn.click(
+             fn=clear_chat,
+             outputs=[chatbot, retrieved_image],
+         )
+
+     with gr.Tab("📖 Instructions"):
+         gr.Markdown("""
+         ## How to use this Multimodal RAG system:
+
+         ### 🔧 Setup:
+         1. **Set API Key**: Make sure `MISTRAL_API_KEY` is set in your Space settings
+         2. This app uses Pixtral-12B for vision-language understanding
+
+         ### 🎥 Load Video:
+         1. Go to the "Load Video" tab
+         2. Paste a YouTube URL (make sure it's publicly accessible)
+         3. Click "🔄 Process Video" and wait for processing to complete
+         4. Look for the ✅ success message
+
+         ### 💬 Chat with Video:
+         1. Go to the "Chat with Video" tab
+         2. Ask questions about the video content
+         3. The system will retrieve the most relevant frame and provide answers
+         4. The retrieved frame will be displayed on the right side
+
+         ## ✨ Features:
+         - 🎥 **Automatic YouTube Processing**: Downloads and processes YouTube videos
+         - 🧠 **Multimodal Embeddings**: Uses BridgeTower for combined text+image understanding
+         - 💾 **Vector Storage**: Stores data in LanceDB for fast similarity search
+         - 🤖 **Vision-Language AI**: Powered by Mistral's Pixtral model
+         - 🖼️ **Visual Context**: Shows relevant video frames alongside responses
+         - 🔄 **Real-time Processing**: Fast retrieval and inference
+
+         ## ⚠️ Limitations:
+         - Works with publicly accessible YouTube videos only
+         - Processing time depends on video length
+         - Requires stable internet connection for video download
+         - API rate limits apply based on Mistral usage
+
+         ## 🛠️ Technical Stack:
+         - **Embeddings**: BridgeTower (multimodal)
+         - **Vector DB**: LanceDB
+         - **Vision-Language Model**: Pixtral-12B
+         - **Framework**: LangChain + Gradio
+         """)
+
+     with gr.Tab("🔍 About"):
+         gr.Markdown("""
+         ## Multimodal RAG Video Chat System
+
+         This application demonstrates a complete multimodal Retrieval-Augmented Generation (RAG) pipeline that can understand and answer questions about video content.
+
+         ### Architecture:
+         1. **Video Processing**: Downloads YouTube videos and extracts frames with timestamps
+         2. **Multimodal Embedding**: Uses BridgeTower to create embeddings that understand both visual and textual content
+         3. **Vector Storage**: Stores embeddings in LanceDB for efficient similarity search
+         4. **Retrieval**: Finds the most relevant video segments based on user queries
+         5. **Generation**: Uses Pixtral vision-language model to generate contextual responses
+
+         ### Built with:
+         - **Gradio**: For the web interface
+         - **LangChain**: For orchestrating the RAG pipeline
+         - **LanceDB**: For vector storage and retrieval
+         - **BridgeTower**: For multimodal embeddings
+         - **Mistral Pixtral**: For vision-language understanding
+
+         ---
+
+         💡 **Tip**: For best results, ask specific questions about visual content, actions, or scenes in the video.
+         """)
+
+ # -------------------------------
+ # 6. Launch Configuration
+ # -------------------------------
+ if __name__ == "__main__":
+     print('🚀 Starting Multimodal RAG Video Chat App...')
+
+     # Check for required environment variables
+     if not os.getenv("MISTRAL_API_KEY"):
+         print("⚠️ WARNING: MISTRAL_API_KEY not found in environment variables")
+         print("   Please set this in your HuggingFace Space settings")
+
+     # Launch with appropriate settings for HF Spaces
+     if IS_SPACES:
+         demo.launch()  # Spaces provides hosting; share links and custom ports are unnecessary there
+     else:
+         demo.launch(server_name="0.0.0.0", server_port=7860)