alx-d committed on
Commit 5e4c8fa · verified · 1 Parent(s): 8c804eb

Upload folder using huggingface_hub
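For reference, a folder upload of this kind is typically produced with the `upload_folder` API from `huggingface_hub`. The sketch below is illustrative only; the folder path and repo ID are placeholders, not values taken from this commit:

```python
# Illustrative sketch of how a commit like this is typically created.
# The folder path and repo ID below are placeholders, not taken from this commit.
from huggingface_hub import HfApi

api = HfApi()  # uses the cached token from `huggingface-cli login` or the HF_TOKEN env var
api.upload_folder(
    folder_path="./app",               # local folder containing advanced_rag.py
    repo_id="user/phirag-space",       # placeholder repository ID
    repo_type="space",                 # upload targets a Space repository
    commit_message="Upload folder using huggingface_hub",
)
```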

Files changed (2)
  1. advanced_rag.py +526 -196
  2. advanced_rag.py.bak +610 -0
advanced_rag.py CHANGED
@@ -68,20 +68,205 @@ def count_tokens(text: str) -> int:
68
  return len(text.split())
69
  return len(text.split())
70
 
71
  def truncate_prompt(prompt: str, max_tokens: int = 4096) -> str:
72
  if global_tokenizer:
73
  try:
74
  tokens = global_tokenizer.encode(prompt)
75
  if len(tokens) > max_tokens:
76
- tokens = tokens[-max_tokens:] # keep the last max_tokens tokens
77
- return global_tokenizer.decode(tokens)
78
  except Exception as e:
79
- debug_print("Truncation error: " + str(e))
80
  words = prompt.split()
81
  if len(words) > max_tokens:
82
- return " ".join(words[-max_tokens:])
83
  return prompt
84
 
85
  default_prompt = """\
86
  {conversation_history}
87
  Use the following context to provide a detailed technical answer to the user's question.
@@ -141,123 +326,169 @@ class ElevatedRagChain:
141
  def extract_question(self, input_data):
142
  return input_data["question"]
143
 
 
144
  def create_llm_pipeline(self):
145
  normalized = self.llm_choice.lower()
146
- if "remote" in normalized:
147
- debug_print("Creating remote Meta-Llama-3 pipeline via Hugging Face Inference API...")
148
- from huggingface_hub import InferenceClient
149
- repo_id = "meta-llama/Meta-Llama-3-8B-Instruct"
150
- hf_api_token = os.environ.get("HF_API_TOKEN")
151
- if not hf_api_token:
152
- raise ValueError("Please set the HF_API_TOKEN environment variable to use remote inference.")
153
- client = InferenceClient(token=hf_api_token, timeout=240)
154
-
155
- from huggingface_hub.utils._errors import HfHubHTTPError
156
- def remote_generate(prompt: str) -> str:
157
- max_retries = 5
158
- backoff = 2 # start with 2 seconds
159
- response = None
160
- for attempt in range(max_retries):
161
- try:
162
- response = client.text_generation(
163
- prompt,
164
- model=repo_id,
165
- temperature=self.temperature,
166
- top_p=self.top_p,
167
- repetition_penalty=1.1
168
- )
169
- return response
170
- except HfHubHTTPError as e:
171
- debug_print(f"Attempt {attempt+1} failed with error: {e}")
172
- # if this is the last attempt, re-raise the error
173
- if attempt == max_retries - 1:
174
- raise
175
- time.sleep(backoff)
176
- backoff *= 2 # exponential backoff
177
- return response
178
-
179
- class RemoteLLM(LLM):
180
- @property
181
- def _llm_type(self) -> str:
182
- return "remote_llm"
183
- def _call(self, prompt: str, stop: typing.Optional[List[str]] = None) -> str:
184
- return remote_generate(prompt)
185
- @property
186
- def _identifying_params(self) -> dict:
187
- return {"model": repo_id}
188
- debug_print("Remote Meta-Llama-3 pipeline created successfully.")
189
- return RemoteLLM()
190
- elif "mistral-api" in normalized:
191
- debug_print("Creating Mistral API pipeline...")
192
- mistral_api_key = os.environ.get("MISTRAL_API_KEY")
193
- if not mistral_api_key:
194
- raise ValueError("Please set the MISTRAL_API_KEY environment variable to use Mistral API.")
195
-
196
- class MistralLLM(LLM):
197
- temperature: float = 0.7
198
- top_p: float = 0.95
199
- client: Any = None # Changed from _client PrivateAttr
200
 
201
- def __init__(self, api_key: str, temperature: float = 0.7, top_p: float = 0.95, **kwargs: Any):
202
- super().__init__(temperature=temperature, top_p=top_p, **kwargs)
203
- # Initialize the client as a regular attribute instead of PrivateAttr
204
- self.client = Mistral(api_key=api_key)
205
 
206
- @property
207
- def _llm_type(self) -> str:
208
- return "mistral_llm"
209
 
210
- def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
211
- response = self.client.chat.complete( # Use self.client instead of self._client
212
- model="mistral-small-latest",
213
- messages=[{"role": "user", "content": prompt}],
214
- temperature=self.temperature,
215
- top_p=self.top_p,
216
- max_tokens=32000
217
- )
218
- return response.choices[0].message.content
219
 
220
- @property
221
- def _identifying_params(self) -> dict:
222
- return {"model": "mistral-small-latest"}
 
223
 
224
- mistral_llm = MistralLLM(api_key=mistral_api_key, temperature=self.temperature, top_p=self.top_p)
225
- debug_print("Mistral API pipeline created successfully.")
226
- return mistral_llm
227
- else:
228
- # Default branch: assume Llama
229
- model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
230
- extra_kwargs = {}
231
- if "llama" in normalized or model_id.startswith("meta-llama"):
232
- extra_kwargs["max_length"] = 4096
233
- pipe = pipeline(
234
- "text-generation",
235
- model=model_id,
236
- model_kwargs={"torch_dtype": torch.bfloat16},
237
- do_sample=True,
238
- temperature=self.temperature,
239
- top_p=self.top_p,
240
- device=-1,
241
- **extra_kwargs
242
- )
243
-
244
- class LocalLLM(LLM):
245
  @property
246
  def _llm_type(self) -> str:
247
- return "local_llm"
248
- def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
249
- # Reserve tokens for generation (e.g., 512 tokens)
250
- reserved_gen = 512
251
- max_total = 8192
252
- max_prompt_tokens = max_total - reserved_gen
253
- truncated_prompt = truncate_prompt(prompt, max_tokens=max_prompt_tokens)
254
- generated = pipe(truncated_prompt, max_new_tokens=reserved_gen)[0]["generated_text"]
255
- return generated
256
  @property
257
  def _identifying_params(self) -> dict:
258
- return {"model": model_id, "max_length": extra_kwargs.get("max_length")}
259
- debug_print("Local Llama pipeline created successfully with max_length=4096.")
260
- return LocalLLM()
261
 
262
  def update_llm_pipeline(self, new_model_choice: str, temperature: float, top_p: float, prompt_template: str, bm25_weight: float):
263
  debug_print(f"Updating chain with new model: {new_model_choice}")
@@ -416,31 +647,48 @@ def update_model(new_model: str):
416
  else:
417
  return "No files loaded; please load files first."
418
 
 
 
419
  def submit_query_updated(query):
420
- debug_print("Inside submit_query function.")
421
  if not query:
422
- debug_print("Please enter a non-empty query")
423
  return "Please enter a non-empty query", "", "Input tokens: 0", "Output tokens: 0"
424
 
425
- if not hasattr(rag_chain, 'elevated_rag_chain'):
 
426
  return "Please load files first.", "", "Input tokens: 0", "Output tokens: 0"
427
 
428
  try:
429
- # Collect and truncate conversation history if needed
430
- history_text = ""
431
  if rag_chain.conversation_history:
432
- # Only keep the last 3 conversations to limit context size
433
- recent_history = rag_chain.conversation_history[-3:]
434
- history_text = "\n".join([f"Q: {conv['query']}\nA: {conv['response']}"
435
  for conv in recent_history])
436
 
437
- # Determine max context size based on model
438
- max_context_tokens = 32000 if "mistral" in rag_chain.llm_choice.lower() else 4096
439
- # Reserve 1000 tokens for the question and generation
440
- max_context_tokens -= 1000
441
 
442
  # Truncate context if needed
443
- context = truncate_prompt(rag_chain.context, max_tokens=max_context_tokens)
444
 
445
  prompt_variables = {
446
  "conversation_history": history_text,
@@ -448,12 +696,18 @@ def submit_query_updated(query):
448
  "question": query
449
  }
450
 
 
451
  response = rag_chain.elevated_rag_chain.invoke({"question": query})
452
- rag_chain.conversation_history.append({"query": query, "response": response})
453
 
454
  input_token_count = count_tokens(query)
455
  output_token_count = count_tokens(response)
456
 
 
 
457
  return (
458
  response,
459
  rag_chain.get_current_context(),
@@ -462,7 +716,7 @@ def submit_query_updated(query):
462
  )
463
  except Exception as e:
464
  error_msg = traceback.format_exc()
465
- debug_print("LLM error. Error: " + error_msg)
466
  return (
467
  f"Query error: {str(e)}\n\nTry using a smaller document or simplifying your query.",
468
  "",
@@ -490,8 +744,9 @@ textarea {
490
  }
491
  """
492
 
 
493
  with gr.Blocks(css=custom_css) as app:
494
- gr.Markdown('''# PhiRAG
495
  **PhiRAG** Query Your Data with Advanced RAG Techniques
496
 
497
  **Model Selection & Parameters:** Choose from the following options:
@@ -514,89 +769,164 @@ with gr.Blocks(css=custom_css) as app:
514
  https://www.gutenberg.org/ebooks/8438.txt.utf-8
515
 
516
  **🔍 Query:** Enter your query below.
517
- - Example: Select all parts in each book focusing on moral responsibility in Aristotle philosophy and discuss in a comprehensible way and link the topics to other top world philosophers. Use a structure and bullet points
518
 
519
- The response displays the model used, word count, and current context (with conversation history).
520
  ''')
521
- with gr.Row():
522
- with gr.Column():
523
- model_dropdown = gr.Dropdown(
524
- choices=["🇺🇸 Remote Meta-Llama-3", "🇪🇺 Mistral-API"],
525
- value="🇺🇸 Remote Meta-Llama-3",
526
- label="Select Model"
527
- )
528
- temperature_slider = gr.Slider(
529
- minimum=0.1, maximum=1.0, value=0.5, step=0.1,
530
- label="Randomness (Temperature)"
531
- )
532
- top_p_slider = gr.Slider(
533
- minimum=0.1, maximum=0.99, value=0.95, step=0.05,
534
- label="Word Variety (Top-p)"
535
- )
536
- with gr.Column():
537
- pdf_input = gr.Textbox(
538
- label="Enter your file URLs (one per line)",
539
- placeholder="Enter one URL per line (.pdf or .txt)",
540
- lines=4
541
- )
542
- prompt_input = gr.Textbox(
543
- label="Custom Prompt Template",
544
- placeholder="Enter your custom prompt template here",
545
- lines=8,
546
- value=default_prompt
547
- )
548
- with gr.Column():
549
- bm25_weight_slider = gr.Slider(
550
- minimum=0.0, maximum=1.0, value=0.6, step=0.1,
551
- label="Lexical vs Semantics (BM25 Weight)"
552
- )
553
- load_button = gr.Button("Load Files")
554
-
555
- with gr.Row():
556
- with gr.Column():
557
- query_input = gr.Textbox(
558
- label="Enter your query here",
559
- placeholder="Type your query",
560
- lines=4
561
- )
562
- submit_button = gr.Button("Submit")
563
- with gr.Column():
564
- reset_button = gr.Button("Reset App")
565
-
566
- with gr.Row():
567
- response_output = gr.Textbox(
568
- label="Response",
569
- placeholder="Response will appear here (formatted as Markdown)",
570
- lines=6
571
- )
572
- context_output = gr.Textbox(
573
- label="Current Context",
574
- placeholder="Retrieved context and conversation history will appear here",
575
- lines=6
576
- )
577
-
578
- with gr.Row():
579
- input_tokens = gr.Markdown("Input tokens: 0")
580
- output_tokens = gr.Markdown("Output tokens: 0")
581
- model_output = gr.Markdown("**Current Model**: Not selected")
582
 
 
583
  load_button.click(
584
- load_pdfs_updated,
585
  inputs=[pdf_input, model_dropdown, prompt_input, bm25_weight_slider, temperature_slider, top_p_slider],
586
- outputs=[response_output, context_output, model_output]
587
-
588
  )
589
 
590
  submit_button.click(
591
- submit_query_updated,
592
  inputs=[query_input],
593
- outputs=[response_output, context_output, input_tokens, output_tokens]
594
  )
595
 
596
  reset_button.click(
597
  reset_app_updated,
598
  inputs=[],
599
- outputs=[response_output, context_output, model_output]
600
  )
601
 
602
  model_dropdown.change(
@@ -604,7 +934,7 @@ The response displays the model used, word count, and current context (with conv
604
  inputs=model_dropdown,
605
  outputs=model_output
606
  )
607
-
608
  if __name__ == "__main__":
609
  debug_print("Launching Gradio interface.")
610
  app.launch(share=False)
 
68
  return len(text.split())
69
  return len(text.split())
70
 
71
+
72
+ # Add these imports at the top of your file
73
+ import uuid
74
+ import threading
75
+ import queue
76
+ from typing import Dict, Any, Tuple, Optional
77
+ import time
78
+
79
+ # Global storage for jobs and results
80
+ jobs = {} # Stores job status and results
81
+ results_queue = queue.Queue() # Thread-safe queue for completed jobs
82
+ processing_lock = threading.Lock() # Prevent simultaneous processing of the same job
83
+
84
+ # Function to process tasks in background
85
+ def process_in_background(job_id: str, function, args):
86
+ try:
87
+ result = function(*args)
88
+ results_queue.put((job_id, result))
89
+ except Exception as e:
90
+ error_msg = f"Error: {str(e)}\n\nTraceback: {traceback.format_exc()}"
91
+ debug_print(f"Job {job_id} failed: {error_msg}")
92
+ results_queue.put((job_id, (error_msg, "", "", "")))
93
+
94
+ # Async version of load_pdfs_updated
95
+ def load_pdfs_async(file_links, model_choice, prompt_template, bm25_weight, temperature, top_p):
96
+ if not file_links:
97
+ return "Please enter non-empty URLs", "", "Model used: N/A", "Context: N/A"
98
+
99
+ job_id = str(uuid.uuid4())
100
+ debug_print(f"Starting async job {job_id} for loading files")
101
+
102
+ # Start background thread
103
+ threading.Thread(
104
+ target=process_in_background,
105
+ args=(job_id, load_pdfs_updated, [file_links, model_choice, prompt_template, bm25_weight, temperature, top_p])
106
+ ).start()
107
+
108
+ jobs[job_id] = {
109
+ "status": "processing",
110
+ "type": "load_files",
111
+ "start_time": time.time()
112
+ }
113
+
114
+ return (
115
+ f"Files are being processed in the background (Job ID: {job_id}).\n\n"
116
+ f"Use 'Check Job Status' with this ID to get results.",
117
+ f"Job ID: {job_id}",
118
+ f"Model selected: {model_choice}"
119
+ )
120
+
121
+ # Async version of submit_query_updated
122
+ def submit_query_async(query):
123
+ if not query:
124
+ return "Please enter a non-empty query", "", "Input tokens: 0", "Output tokens: 0"
125
+
126
+ if not hasattr(rag_chain, 'elevated_rag_chain') or not rag_chain.raw_data:
127
+ return "Please load files first", "", "Input tokens: 0", "Output tokens: 0"
128
+
129
+ job_id = str(uuid.uuid4())
130
+ debug_print(f"Starting async job {job_id} for query: {query}")
131
+
132
+ # Start background thread
133
+ threading.Thread(
134
+ target=process_in_background,
135
+ args=(job_id, submit_query_updated, [query])
136
+ ).start()
137
+
138
+ jobs[job_id] = {
139
+ "status": "processing",
140
+ "type": "query",
141
+ "start_time": time.time(),
142
+ "query": query
143
+ }
144
+
145
+ return (
146
+ f"Query submitted and processing in the background (Job ID: {job_id}).\n\n"
147
+ f"Use 'Check Job Status' with this ID to get results.",
148
+ f"Job ID: {job_id}",
149
+ f"Input tokens: {count_tokens(query)}",
150
+ "Output tokens: pending"
151
+ )
152
+
153
+ # Function to check job status
154
+ def check_job_status(job_id):
155
+ if not job_id:
156
+ return "Please enter a job ID", "", "", ""
157
+
158
+ # Process any completed jobs in the queue
159
+ try:
160
+ while not results_queue.empty():
161
+ completed_id, result = results_queue.get_nowait()
162
+ if completed_id in jobs:
163
+ jobs[completed_id]["status"] = "completed"
164
+ jobs[completed_id]["result"] = result
165
+ jobs[completed_id]["end_time"] = time.time()
166
+ debug_print(f"Job {completed_id} completed and stored in jobs dictionary")
167
+ except queue.Empty:
168
+ pass
169
+
170
+ # Check if the requested job exists
171
+ if job_id not in jobs:
172
+ return "Job not found. Please check the ID and try again.", "", "", ""
173
+
174
+ job = jobs[job_id]
175
+
176
+ # If job is still processing
177
+ if job["status"] == "processing":
178
+ elapsed_time = time.time() - job["start_time"]
179
+ job_type = job.get("type", "unknown")
180
+
181
+ if job_type == "load_files":
182
+ return (
183
+ f"Files are still being processed (elapsed: {elapsed_time:.1f}s).\n\n"
184
+ f"Try checking again in a few seconds.",
185
+ f"Job ID: {job_id}",
186
+ f"Status: Processing"
187
+ )
188
+ else: # query job
189
+ return (
190
+ f"Query is still being processed (elapsed: {elapsed_time:.1f}s).\n\n"
191
+ f"Try checking again in a few seconds.",
192
+ f"Job ID: {job_id}",
193
+ f"Input tokens: {count_tokens(job.get('query', ''))}",
194
+ "Output tokens: pending"
195
+ )
196
+
197
+ # If job is completed
198
+ if job["status"] == "completed":
199
+ result = job["result"]
200
+ processing_time = job["end_time"] - job["start_time"]
201
+
202
+ if job.get("type") == "load_files":
203
+ return (
204
+ f"{result[0]}\n\nProcessing time: {processing_time:.1f}s",
205
+ result[1],
206
+ result[2]
207
+ )
208
+ else: # query job
209
+ return (
210
+ f"{result[0]}\n\nProcessing time: {processing_time:.1f}s",
211
+ result[1],
212
+ result[2],
213
+ result[3]
214
+ )
215
+
216
+ # Fallback for unknown status
217
+ return f"Job status: {job['status']}", "", "", ""
218
+
219
+ # Function to clean up old jobs
220
+ def cleanup_old_jobs():
221
+ current_time = time.time()
222
+ to_delete = []
223
+
224
+ for job_id, job in jobs.items():
225
+ # Keep completed jobs for 1 hour, processing jobs for 2 hours
226
+ if job["status"] == "completed" and (current_time - job.get("end_time", 0)) > 3600:
227
+ to_delete.append(job_id)
228
+ elif job["status"] == "processing" and (current_time - job.get("start_time", 0)) > 7200:
229
+ to_delete.append(job_id)
230
+
231
+ for job_id in to_delete:
232
+ del jobs[job_id]
233
+
234
+ debug_print(f"Cleaned up {len(to_delete)} old jobs. {len(jobs)} jobs remaining.")
235
+ return f"Cleaned up {len(to_delete)} old jobs", "", ""
236
+
237
+ # Improve the truncate_prompt function to be more aggressive with limiting context
238
  def truncate_prompt(prompt: str, max_tokens: int = 4096) -> str:
239
+ """Truncate prompt to fit within token limit, preserving the most recent/relevant parts."""
240
+ if not prompt:
241
+ return ""
242
+
243
  if global_tokenizer:
244
  try:
245
  tokens = global_tokenizer.encode(prompt)
246
  if len(tokens) > max_tokens:
247
+ # For prompts, we often want to keep the beginning instructions and the end context
248
+ # So we'll keep the first 20% and the last 80% of the max tokens
249
+ beginning_tokens = int(max_tokens * 0.2)
250
+ ending_tokens = max_tokens - beginning_tokens
251
+
252
+ new_tokens = tokens[:beginning_tokens] + tokens[-(ending_tokens):]
253
+ return global_tokenizer.decode(new_tokens)
254
  except Exception as e:
255
+ debug_print(f"Truncation error: {str(e)}")
256
+
257
+ # Fallback to word-based truncation
258
  words = prompt.split()
259
  if len(words) > max_tokens:
260
+ beginning_words = int(max_tokens * 0.2)
261
+ ending_words = max_tokens - beginning_words
262
+
263
+ return " ".join(words[:beginning_words] + words[-(ending_words):])
264
+
265
  return prompt
266
 
267
+
268
+
269
+
270
  default_prompt = """\
271
  {conversation_history}
272
  Use the following context to provide a detailed technical answer to the user's question.
 
326
  def extract_question(self, input_data):
327
  return input_data["question"]
328
 
329
+ # Improve error handling in the ElevatedRagChain class
330
  def create_llm_pipeline(self):
331
  normalized = self.llm_choice.lower()
332
+ try:
333
+ if "remote" in normalized:
334
+ debug_print("Creating remote Meta-Llama-3 pipeline via Hugging Face Inference API...")
335
+ from huggingface_hub import InferenceClient
336
+ repo_id = "meta-llama/Meta-Llama-3-8B-Instruct"
337
+ hf_api_token = os.environ.get("HF_API_TOKEN")
338
+ if not hf_api_token:
339
+ raise ValueError("Please set the HF_API_TOKEN environment variable to use remote inference.")
340
 
341
+ client = InferenceClient(token=hf_api_token, timeout=120)
342
+
343
+ from huggingface_hub.utils._errors import HfHubHTTPError
344
+ def remote_generate(prompt: str) -> str:
345
+ max_retries = 3
346
+ backoff = 2 # start with 2 seconds
347
+ for attempt in range(max_retries):
348
+ try:
349
+ debug_print(f"Remote generation attempt {attempt+1}")
350
+ response = client.text_generation(
351
+ prompt,
352
+ model=repo_id,
353
+ temperature=self.temperature,
354
+ top_p=self.top_p,
355
+ max_new_tokens=512, # Reduced token count for speed
356
+ timeout=120
357
+ )
358
+ return response
359
+ except Exception as e:
360
+ debug_print(f"Attempt {attempt+1} failed with error: {e}")
361
+ if attempt == max_retries - 1:
362
+ raise
363
+ time.sleep(backoff)
364
+ backoff *= 2 # exponential backoff
365
+ return "Failed to generate response after multiple attempts."
366
+
367
+ class RemoteLLM(LLM):
368
+ @property
369
+ def _llm_type(self) -> str:
370
+ return "remote_llm"
371
 
372
+ def _call(self, prompt: str, stop: typing.Optional[List[str]] = None) -> str:
373
+ return remote_generate(prompt)
 
374
 
375
+ @property
376
+ def _identifying_params(self) -> dict:
377
+ return {"model": repo_id}
378
+
379
+ debug_print("Remote Meta-Llama-3 pipeline created successfully.")
380
+ return RemoteLLM()
381
+
382
+ elif "mistral" in normalized:
383
+ debug_print("Creating Mistral API pipeline...")
384
+ mistral_api_key = os.environ.get("MISTRAL_API_KEY")
385
+ if not mistral_api_key:
386
+ raise ValueError("Please set the MISTRAL_API_KEY environment variable to use Mistral API.")
387
+
388
+ # Import Mistral library with proper error handling
389
+ try:
390
+ from mistralai import Mistral
391
+ from mistralai.exceptions import MistralException
392
+ debug_print("Mistral library imported successfully")
393
+ except ImportError:
394
+ raise ImportError("Mistral client library not found. Install with: pip install mistralai")
395
+
396
+ # Fixed MistralLLM implementation that works with Pydantic v1
397
+ class MistralLLM(LLM):
398
+ client: Optional[Any] = None
399
+ temperature: float = 0.7
400
+ top_p: float = 0.95
401
 
402
+ def __init__(self, api_key: str, temperature: float = 0.7, top_p: float = 0.95, **kwargs: Any):
403
+ super().__init__(temperature=temperature, top_p=top_p, **kwargs)
404
+ self.client = Mistral(api_key=api_key)
405
+ debug_print("Mistral client initialized")
406
 
407
+ @property
408
+ def _llm_type(self) -> str:
409
+ return "mistral_llm"
410
+
411
+ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
412
+ try:
413
+ debug_print("Calling Mistral API...")
414
+ response = self.client.chat.complete(
415
+ model="mistral-small-latest",
416
+ messages=[{"role": "user", "content": prompt}],
417
+ temperature=self.temperature,
418
+ top_p=self.top_p,
419
+ max_tokens=1024 # Limit token count for faster response
420
+ )
421
+ return response.choices[0].message.content
422
+ except Exception as e:
423
+ debug_print(f"Mistral API error: {str(e)}")
424
+ return f"Error generating response: {str(e)}"
425
+
426
+ @property
427
+ def _identifying_params(self) -> dict:
428
+ return {"model": "mistral-small-latest"}
429
+
430
+ debug_print("Creating Mistral LLM instance")
431
+ mistral_llm = MistralLLM(
432
+ api_key=mistral_api_key,
433
+ temperature=self.temperature,
434
+ top_p=self.top_p
435
+ )
436
+ debug_print("Mistral API pipeline created successfully.")
437
+ return mistral_llm
438
+
439
+ else:
440
+ # Default case - use a smaller model that's more likely to work within constraints
441
+ debug_print("Using local/fallback model pipeline")
442
+ model_id = "facebook/opt-350m" # Much smaller model
443
+
444
+ pipe = pipeline(
445
+ "text-generation",
446
+ model=model_id,
447
+ device=-1, # CPU
448
+ max_length=1024
449
+ )
450
+
451
+ class LocalLLM(LLM):
452
+ @property
453
+ def _llm_type(self) -> str:
454
+ return "local_llm"
455
+
456
+ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
457
+ # Aggressively truncate prompt
458
+ truncated_prompt = truncate_prompt(prompt, max_tokens=512)
459
+ try:
460
+ generated = pipe(truncated_prompt, max_new_tokens=256)[0]["generated_text"]
461
+ # Only return the newly generated part
462
+ if generated.startswith(truncated_prompt):
463
+ return generated[len(truncated_prompt):].strip()
464
+ return generated
465
+ except Exception as e:
466
+ debug_print(f"Generation error: {str(e)}")
467
+ return f"Error generating response: {str(e)}"
468
+
469
+ @property
470
+ def _identifying_params(self) -> dict:
471
+ return {"model": model_id}
472
+
473
+ debug_print("Local fallback pipeline created.")
474
+ return LocalLLM()
475
+
476
+ except Exception as e:
477
+ debug_print(f"Error creating LLM pipeline: {str(e)}")
478
+ # Return a dummy LLM that explains the error
479
+ class ErrorLLM(LLM):
480
  @property
481
  def _llm_type(self) -> str:
482
+ return "error_llm"
483
+
484
+ def _call(self, prompt: str, stop: typing.Optional[List[str]] = None) -> str:
485
+ return f"Error initializing LLM: \n\nPlease check your environment variables and try again."
486
+
487
  @property
488
  def _identifying_params(self) -> dict:
489
+ return {"model": "error"}
490
+
491
+ return ErrorLLM()
492
 
493
  def update_llm_pipeline(self, new_model_choice: str, temperature: float, top_p: float, prompt_template: str, bm25_weight: float):
494
  debug_print(f"Updating chain with new model: {new_model_choice}")
 
647
  else:
648
  return "No files loaded; please load files first."
649
 
650
+
651
+ # Update submit_query_updated to better handle context limitation
652
  def submit_query_updated(query):
653
+ debug_print(f"Processing query: {query}")
654
  if not query:
655
+ debug_print("Empty query received")
656
  return "Please enter a non-empty query", "", "Input tokens: 0", "Output tokens: 0"
657
 
658
+ if not hasattr(rag_chain, 'elevated_rag_chain') or not rag_chain.raw_data:
659
+ debug_print("RAG chain not initialized")
660
  return "Please load files first.", "", "Input tokens: 0", "Output tokens: 0"
661
 
662
  try:
663
+ # Determine max context size based on model
664
+ model_name = rag_chain.llm_choice.lower()
665
+ max_context_tokens = 32000 if "mistral" in model_name else 4096
666
+
667
+ # Reserve 20% of tokens for the question and response generation
668
+ reserved_tokens = int(max_context_tokens * 0.2)
669
+ max_context_tokens -= reserved_tokens
670
+
671
+ # Collect conversation history (last 2 only to save tokens)
672
  if rag_chain.conversation_history:
673
+ recent_history = rag_chain.conversation_history[-2:]
674
+ history_text = "\n".join([f"Q: {conv['query']}\nA: {conv['response'][:300]}..."
 
675
  for conv in recent_history])
676
+ else:
677
+ history_text = ""
678
 
679
+ # Get history token count
680
+ history_tokens = count_tokens(history_text)
681
+
682
+ # Adjust context tokens based on history size
683
+ context_tokens = max_context_tokens - history_tokens
684
+
685
+ # Ensure we have some minimum context
686
+ context_tokens = max(context_tokens, 1000)
687
 
688
  # Truncate context if needed
689
+ context = truncate_prompt(rag_chain.context, max_tokens=context_tokens)
690
+
691
+ debug_print(f"Using model: {model_name}, context tokens: {count_tokens(context)}, history tokens: {history_tokens}")
692
 
693
  prompt_variables = {
694
  "conversation_history": history_text,
 
696
  "question": query
697
  }
698
 
699
+ debug_print("Invoking RAG chain")
700
  response = rag_chain.elevated_rag_chain.invoke({"question": query})
701
+
702
+ # Store only a reasonable amount of the response in history
703
+ trimmed_response = response[:1000] + ("..." if len(response) > 1000 else "")
704
+ rag_chain.conversation_history.append({"query": query, "response": trimmed_response})
705
 
706
  input_token_count = count_tokens(query)
707
  output_token_count = count_tokens(response)
708
 
709
+ debug_print(f"Query processed successfully. Output tokens: {output_token_count}")
710
+
711
  return (
712
  response,
713
  rag_chain.get_current_context(),
 
716
  )
717
  except Exception as e:
718
  error_msg = traceback.format_exc()
719
+ debug_print(f"LLM error: {error_msg}")
720
  return (
721
  f"Query error: {str(e)}\n\nTry using a smaller document or simplifying your query.",
722
  "",
 
744
  }
745
  """
746
 
747
+ # Update the Gradio interface to include job status checking
748
  with gr.Blocks(css=custom_css) as app:
749
+ gr.Markdown('''# PhiRAG - Async Version
750
  **PhiRAG** Query Your Data with Advanced RAG Techniques
751
 
752
  **Model Selection & Parameters:** Choose from the following options:
 
769
  https://www.gutenberg.org/ebooks/8438.txt.utf-8
770
 
771
  **🔍 Query:** Enter your query below.
 
772
 
773
+ **⚠️ IMPORTANT: This app now uses asynchronous processing to avoid timeout issues**
774
+ - When you load files or submit a query, you'll receive a Job ID
775
+ - Use the "Check Job Status" tab to monitor and retrieve your results
776
  ''')
777
+
778
+ with gr.Tabs() as tabs:
779
+ with gr.TabItem("Setup & Load Files"):
780
+ with gr.Row():
781
+ with gr.Column():
782
+ model_dropdown = gr.Dropdown(
783
+ choices=["🇺🇸 Remote Meta-Llama-3", "🇪🇺 Mistral-API"],
784
+ value="🇺🇸 Remote Meta-Llama-3",
785
+ label="Select Model"
786
+ )
787
+ temperature_slider = gr.Slider(
788
+ minimum=0.1, maximum=1.0, value=0.5, step=0.1,
789
+ label="Randomness (Temperature)"
790
+ )
791
+ top_p_slider = gr.Slider(
792
+ minimum=0.1, maximum=0.99, value=0.95, step=0.05,
793
+ label="Word Variety (Top-p)"
794
+ )
795
+ with gr.Column():
796
+ pdf_input = gr.Textbox(
797
+ label="Enter your file URLs (one per line)",
798
+ placeholder="Enter one URL per line (.pdf or .txt)",
799
+ lines=4
800
+ )
801
+ prompt_input = gr.Textbox(
802
+ label="Custom Prompt Template",
803
+ placeholder="Enter your custom prompt template here",
804
+ lines=8,
805
+ value=default_prompt
806
+ )
807
+ with gr.Column():
808
+ bm25_weight_slider = gr.Slider(
809
+ minimum=0.0, maximum=1.0, value=0.6, step=0.1,
810
+ label="Lexical vs Semantics (BM25 Weight)"
811
+ )
812
+ load_button = gr.Button("Load Files (Async)")
813
+ load_status = gr.Markdown("Status: Waiting for files")
814
+
815
+ with gr.Row():
816
+ load_response = gr.Textbox(
817
+ label="Load Response",
818
+ placeholder="Response will appear here",
819
+ lines=4
820
+ )
821
+ load_context = gr.Textbox(
822
+ label="Context Info",
823
+ placeholder="Context info will appear here",
824
+ lines=4
825
+ )
826
+
827
+ with gr.Row():
828
+ model_output = gr.Markdown("**Current Model**: Not selected")
829
+
830
+ with gr.TabItem("Submit Query"):
831
+ with gr.Row():
832
+ query_input = gr.Textbox(
833
+ label="Enter your query here",
834
+ placeholder="Type your query",
835
+ lines=4
836
+ )
837
+ submit_button = gr.Button("Submit Query (Async)")
838
+
839
+ with gr.Row():
840
+ query_response = gr.Textbox(
841
+ label="Query Response",
842
+ placeholder="Response will appear here (formatted as Markdown)",
843
+ lines=6
844
+ )
845
+ query_context = gr.Textbox(
846
+ label="Context Information",
847
+ placeholder="Retrieved context and conversation history will appear here",
848
+ lines=6
849
+ )
850
+
851
+ with gr.Row():
852
+ input_tokens = gr.Markdown("Input tokens: 0")
853
+ output_tokens = gr.Markdown("Output tokens: 0")
854
+
855
+ with gr.TabItem("Check Job Status"):
856
+ with gr.Row():
857
+ job_id_input = gr.Textbox(
858
+ label="Enter Job ID",
859
+ placeholder="Paste the Job ID here",
860
+ lines=1
861
+ )
862
+ check_button = gr.Button("Check Status")
863
+ cleanup_button = gr.Button("Cleanup Old Jobs")
864
+
865
+ with gr.Row():
866
+ status_response = gr.Textbox(
867
+ label="Job Result",
868
+ placeholder="Job result will appear here",
869
+ lines=6
870
+ )
871
+ status_context = gr.Textbox(
872
+ label="Context Information",
873
+ placeholder="Context information will appear here",
874
+ lines=6
875
+ )
876
+
877
+ with gr.Row():
878
+ status_tokens1 = gr.Markdown("")
879
+ status_tokens2 = gr.Markdown("")
880
+
881
+ with gr.TabItem("App Management"):
882
+ with gr.Row():
883
+ reset_button = gr.Button("Reset App")
884
+
885
+ with gr.Row():
886
+ reset_response = gr.Textbox(
887
+ label="Reset Response",
888
+ placeholder="Reset confirmation will appear here",
889
+ lines=2
890
+ )
891
+ reset_context = gr.Textbox(
892
+ label="",
893
+ placeholder="",
894
+ lines=2,
895
+ visible=False
896
+ )
897
+
898
+ with gr.Row():
899
+ reset_model = gr.Markdown("")
900
 
901
+ # Connect the buttons to their respective functions
902
  load_button.click(
903
+ load_pdfs_async,
904
  inputs=[pdf_input, model_dropdown, prompt_input, bm25_weight_slider, temperature_slider, top_p_slider],
905
+ outputs=[load_response, load_context, model_output]
 
906
  )
907
 
908
  submit_button.click(
909
+ submit_query_async,
910
  inputs=[query_input],
911
+ outputs=[query_response, query_context, input_tokens, output_tokens]
912
+ )
913
+
914
+ check_button.click(
915
+ check_job_status,
916
+ inputs=[job_id_input],
917
+ outputs=[status_response, status_context, status_tokens1, status_tokens2]
918
+ )
919
+
920
+ cleanup_button.click(
921
+ cleanup_old_jobs,
922
+ inputs=[],
923
+ outputs=[status_response, status_context, status_tokens1]
924
  )
925
 
926
  reset_button.click(
927
  reset_app_updated,
928
  inputs=[],
929
+ outputs=[reset_response, reset_context, reset_model]
930
  )
931
 
932
  model_dropdown.change(
 
934
  inputs=model_dropdown,
935
  outputs=model_output
936
  )
937
+
938
  if __name__ == "__main__":
939
  debug_print("Launching Gradio interface.")
940
  app.launch(share=False)
advanced_rag.py.bak ADDED
@@ -0,0 +1,610 @@
1
+ import os
2
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
3
+ import datetime
4
+ import functools
5
+ import traceback
6
+ from typing import List, Optional, Any, Dict
7
+
8
+ import torch
9
+ import transformers
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
11
+ from langchain_community.llms import HuggingFacePipeline
12
+
13
+ # Other LangChain and community imports
14
+ from langchain_community.document_loaders import OnlinePDFLoader
15
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
16
+ from langchain_community.vectorstores import FAISS
17
+ from langchain.embeddings import HuggingFaceEmbeddings
18
+ from langchain_community.retrievers import BM25Retriever
19
+ from langchain.retrievers import EnsembleRetriever
20
+ from langchain.prompts import ChatPromptTemplate
21
+ from langchain.schema import StrOutputParser, Document
22
+ from langchain_core.runnables import RunnableParallel, RunnableLambda
23
+ from transformers.quantizers.auto import AutoQuantizationConfig
24
+ import gradio as gr
25
+ import requests
26
+ from pydantic import PrivateAttr
27
+ import pydantic
28
+
29
+ from langchain.llms.base import LLM
30
+ from typing import Any, Optional, List
31
+ import typing
32
+ import time
33
+
34
+ print("Pydantic Version: ")
35
+ print(pydantic.__version__)
36
+ # Add Mistral imports with fallback handling
37
+ try:
38
+ from mistralai import Mistral
39
+ MISTRAL_AVAILABLE = True
40
+ debug_print = lambda msg: print(f"[{datetime.datetime.now().isoformat()}] {msg}")
41
+ debug_print("Loaded latest Mistral client library")
42
+ except ImportError:
43
+ MISTRAL_AVAILABLE = False
44
+ debug_print = lambda msg: print(f"[{datetime.datetime.now().isoformat()}] {msg}")
45
+ debug_print("Mistral client library not found. Install with: pip install mistralai")
46
+
47
+ def debug_print(message: str):
48
+ print(f"[{datetime.datetime.now().isoformat()}] {message}")
49
+
50
+ def word_count(text: str) -> int:
51
+ return len(text.split())
52
+
53
+ # Initialize a tokenizer for token counting (using gpt2 as a generic fallback)
54
+ def initialize_tokenizer():
55
+ try:
56
+ return AutoTokenizer.from_pretrained("gpt2")
57
+ except Exception as e:
58
+ debug_print("Failed to initialize tokenizer: " + str(e))
59
+ return None
60
+
61
+ global_tokenizer = initialize_tokenizer()
62
+
63
+ def count_tokens(text: str) -> int:
64
+ if global_tokenizer:
65
+ try:
66
+ return len(global_tokenizer.encode(text))
67
+ except Exception as e:
68
+ return len(text.split())
69
+ return len(text.split())
70
+
71
+ def truncate_prompt(prompt: str, max_tokens: int = 4096) -> str:
72
+ if global_tokenizer:
73
+ try:
74
+ tokens = global_tokenizer.encode(prompt)
75
+ if len(tokens) > max_tokens:
76
+ tokens = tokens[-max_tokens:] # keep the last max_tokens tokens
77
+ return global_tokenizer.decode(tokens)
78
+ except Exception as e:
79
+ debug_print("Truncation error: " + str(e))
80
+ words = prompt.split()
81
+ if len(words) > max_tokens:
82
+ return " ".join(words[-max_tokens:])
83
+ return prompt
84
+
85
+ default_prompt = """\
86
+ {conversation_history}
87
+ Use the following context to provide a detailed technical answer to the user's question.
88
+ Do not include an introduction like "Based on the provided documents, ...". Just answer the question.
89
+ If you don't know the answer, please respond with "I don't know".
90
+
91
+ Context:
92
+ {context}
93
+
94
+ User's question:
95
+ {question}
96
+ """
97
+
98
+ def load_txt_from_url(url: str) -> Document:
99
+ response = requests.get(url)
100
+ if response.status_code == 200:
101
+ text = response.text.strip()
102
+ if not text:
103
+ raise ValueError(f"TXT file at {url} is empty.")
104
+ return Document(page_content=text, metadata={"source": url})
105
+ else:
106
+ raise Exception(f"Failed to load {url} with status {response.status_code}")
107
+
108
+ class ElevatedRagChain:
109
+ def __init__(self, llm_choice: str = "Meta-Llama-3", prompt_template: str = default_prompt,
110
+ bm25_weight: float = 0.6, temperature: float = 0.5, top_p: float = 0.95) -> None:
111
+ debug_print(f"Initializing ElevatedRagChain with model: {llm_choice}")
112
+ self.embed_func = HuggingFaceEmbeddings(
113
+ model_name="sentence-transformers/all-MiniLM-L6-v2",
114
+ model_kwargs={"device": "cpu"}
115
+ )
116
+ self.bm25_weight = bm25_weight
117
+ self.faiss_weight = 1.0 - bm25_weight
118
+ self.top_k = 5
119
+ self.llm_choice = llm_choice
120
+ self.temperature = temperature
121
+ self.top_p = top_p
122
+ self.prompt_template = prompt_template
123
+ self.context = ""
124
+ self.conversation_history: List[Dict[str, str]] = []
125
+ self.raw_data = None
126
+ self.split_data = None
127
+ self.elevated_rag_chain = None
128
+
129
+ # Instance method to capture context and conversation history
130
+ def capture_context(self, result):
131
+ self.context = "\n".join([str(doc) for doc in result["context"]])
132
+ result["context"] = self.context
133
+ history_text = (
134
+ "\n".join([f"Q: {conv['query']}\nA: {conv['response']}" for conv in self.conversation_history])
135
+ if self.conversation_history else ""
136
+ )
137
+ result["conversation_history"] = history_text
138
+ return result
139
+
140
+ # Instance method to extract question from input data
141
+ def extract_question(self, input_data):
142
+ return input_data["question"]
143
+
144
+ def create_llm_pipeline(self):
145
+ normalized = self.llm_choice.lower()
146
+ if "remote" in normalized:
147
+ debug_print("Creating remote Meta-Llama-3 pipeline via Hugging Face Inference API...")
148
+ from huggingface_hub import InferenceClient
149
+ repo_id = "meta-llama/Meta-Llama-3-8B-Instruct"
150
+ hf_api_token = os.environ.get("HF_API_TOKEN")
151
+ if not hf_api_token:
152
+ raise ValueError("Please set the HF_API_TOKEN environment variable to use remote inference.")
153
+ client = InferenceClient(token=hf_api_token, timeout=240)
154
+
155
+ from huggingface_hub.utils._errors import HfHubHTTPError
156
+ def remote_generate(prompt: str) -> str:
157
+ max_retries = 5
158
+ backoff = 2 # start with 2 seconds
159
+ response = None
160
+ for attempt in range(max_retries):
161
+ try:
162
+ response = client.text_generation(
163
+ prompt,
164
+ model=repo_id,
165
+ temperature=self.temperature,
166
+ top_p=self.top_p,
167
+ repetition_penalty=1.1
168
+ )
169
+ return response
170
+ except HfHubHTTPError as e:
171
+ debug_print(f"Attempt {attempt+1} failed with error: {e}")
172
+ # if this is the last attempt, re-raise the error
173
+ if attempt == max_retries - 1:
174
+ raise
175
+ time.sleep(backoff)
176
+ backoff *= 2 # exponential backoff
177
+ return response
178
+
179
+ class RemoteLLM(LLM):
180
+ @property
181
+ def _llm_type(self) -> str:
182
+ return "remote_llm"
183
+ def _call(self, prompt: str, stop: typing.Optional[List[str]] = None) -> str:
184
+ return remote_generate(prompt)
185
+ @property
186
+ def _identifying_params(self) -> dict:
187
+ return {"model": repo_id}
188
+ debug_print("Remote Meta-Llama-3 pipeline created successfully.")
189
+ return RemoteLLM()
190
+ elif "mistral-api" in normalized:
191
+ debug_print("Creating Mistral API pipeline...")
192
+ mistral_api_key = os.environ.get("MISTRAL_API_KEY")
193
+ if not mistral_api_key:
194
+ raise ValueError("Please set the MISTRAL_API_KEY environment variable to use Mistral API.")
195
+
196
+ class MistralLLM(LLM):
197
+ temperature: float = 0.7
198
+ top_p: float = 0.95
199
+ client: Any = None # Changed from _client PrivateAttr
200
+
201
+ def __init__(self, api_key: str, temperature: float = 0.7, top_p: float = 0.95, **kwargs: Any):
202
+ super().__init__(temperature=temperature, top_p=top_p, **kwargs)
203
+ # Initialize the client as a regular attribute instead of PrivateAttr
204
+ self.client = Mistral(api_key=api_key)
205
+
206
+ @property
207
+ def _llm_type(self) -> str:
208
+ return "mistral_llm"
209
+
210
+ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
211
+ response = self.client.chat.complete( # Use self.client instead of self._client
212
+ model="mistral-small-latest",
213
+ messages=[{"role": "user", "content": prompt}],
214
+ temperature=self.temperature,
215
+ top_p=self.top_p,
216
+ max_tokens=32000
217
+ )
218
+ return response.choices[0].message.content
219
+
220
+ @property
221
+ def _identifying_params(self) -> dict:
222
+ return {"model": "mistral-small-latest"}
223
+
224
+ mistral_llm = MistralLLM(api_key=mistral_api_key, temperature=self.temperature, top_p=self.top_p)
225
+ debug_print("Mistral API pipeline created successfully.")
226
+ return mistral_llm
227
+ else:
228
+ # Default branch: assume Llama
229
+ model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
230
+ extra_kwargs = {}
231
+ if "llama" in normalized or model_id.startswith("meta-llama"):
232
+ extra_kwargs["max_length"] = 4096
233
+ pipe = pipeline(
234
+ "text-generation",
235
+ model=model_id,
236
+ model_kwargs={"torch_dtype": torch.bfloat16},
237
+ do_sample=True,
238
+ temperature=self.temperature,
239
+ top_p=self.top_p,
240
+ device=-1,
241
+ **extra_kwargs
242
+ )
243
+
244
+ class LocalLLM(LLM):
245
+ @property
246
+ def _llm_type(self) -> str:
247
+ return "local_llm"
248
+ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
249
+ # Reserve tokens for generation (e.g., 512 tokens)
250
+ reserved_gen = 512
251
+ max_total = 8192
252
+ max_prompt_tokens = max_total - reserved_gen
253
+ truncated_prompt = truncate_prompt(prompt, max_tokens=max_prompt_tokens)
254
+ generated = pipe(truncated_prompt, max_new_tokens=reserved_gen)[0]["generated_text"]
255
+ return generated
256
+ @property
257
+ def _identifying_params(self) -> dict:
258
+ return {"model": model_id, "max_length": extra_kwargs.get("max_length")}
259
+ debug_print("Local Llama pipeline created successfully with max_length=4096.")
260
+ return LocalLLM()
261
+
262
+ def update_llm_pipeline(self, new_model_choice: str, temperature: float, top_p: float, prompt_template: str, bm25_weight: float):
263
+ debug_print(f"Updating chain with new model: {new_model_choice}")
264
+ self.llm_choice = new_model_choice
265
+ self.temperature = temperature
266
+ self.top_p = top_p
267
+ self.prompt_template = prompt_template
268
+ self.bm25_weight = bm25_weight
269
+ self.faiss_weight = 1.0 - bm25_weight
270
+ self.llm = self.create_llm_pipeline()
271
+ def format_response(response: str) -> str:
272
+ input_tokens = count_tokens(self.context + self.prompt_template)
273
+ output_tokens = count_tokens(response)
274
+ formatted = f"### Response\n\n{response}\n\n---\n"
275
+ formatted += f"- **Input tokens:** {input_tokens}\n"
276
+ formatted += f"- **Output tokens:** {output_tokens}\n"
277
+ formatted += f"- **Generated using:** {self.llm_choice}\n"
278
+ formatted += f"\n**Conversation History:** {len(self.conversation_history)} conversation(s) considered.\n"
279
+ return formatted
280
+ base_runnable = RunnableParallel({
281
+ "context": RunnableLambda(self.extract_question) | self.ensemble_retriever,
282
+ "question": RunnableLambda(self.extract_question)
283
+ }) | self.capture_context
284
+ self.elevated_rag_chain = base_runnable | self.rag_prompt | self.llm | format_response
285
+ debug_print("Chain updated successfully with new LLM pipeline.")
286
+
287
+ def add_pdfs_to_vectore_store(self, file_links: List[str]) -> None:
288
+ debug_print(f"Processing files using {self.llm_choice}")
289
+ self.raw_data = []
290
+ for link in file_links:
291
+ if link.lower().endswith(".pdf"):
292
+ debug_print(f"Loading PDF: {link}")
293
+ loaded_docs = OnlinePDFLoader(link).load()
294
+ if loaded_docs:
295
+ self.raw_data.append(loaded_docs[0])
296
+ else:
297
+ debug_print(f"No content found in PDF: {link}")
298
+ elif link.lower().endswith(".txt") or link.lower().endswith(".utf-8"):
299
+ debug_print(f"Loading TXT: {link}")
300
+ try:
301
+ self.raw_data.append(load_txt_from_url(link))
302
+ except Exception as e:
303
+ debug_print(f"Error loading TXT file {link}: {e}")
304
+ else:
305
+ debug_print(f"File type not supported for URL: {link}")
306
+ if not self.raw_data:
307
+ raise ValueError("No files were successfully loaded. Please check the URLs and file formats.")
308
+ debug_print("Files loaded successfully.")
309
+ debug_print("Starting text splitting...")
310
+ self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=100)
311
+ self.split_data = self.text_splitter.split_documents(self.raw_data)
312
+ if not self.split_data:
313
+ raise ValueError("Text splitting resulted in no chunks. Check the file contents.")
314
+ debug_print(f"Text splitting completed. Number of chunks: {len(self.split_data)}")
315
+ debug_print("Creating BM25 retriever...")
316
+ self.bm25_retriever = BM25Retriever.from_documents(self.split_data)
317
+ self.bm25_retriever.k = self.top_k
318
+ debug_print("BM25 retriever created.")
319
+ debug_print("Embedding chunks and creating FAISS vector store...")
320
+ self.vector_store = FAISS.from_documents(self.split_data, self.embed_func)
321
+ self.faiss_retriever = self.vector_store.as_retriever(search_kwargs={"k": self.top_k})
322
+ debug_print("FAISS vector store created successfully.")
323
+ self.ensemble_retriever = EnsembleRetriever(
324
+ retrievers=[self.bm25_retriever, self.faiss_retriever],
325
+ weights=[self.bm25_weight, self.faiss_weight]
326
+ )
327
+ base_runnable = RunnableParallel({
328
+ "context": RunnableLambda(self.extract_question) | self.ensemble_retriever,
329
+ "question": RunnableLambda(self.extract_question)
330
+ }) | self.capture_context
331
+ self.rag_prompt = ChatPromptTemplate.from_template(self.prompt_template)
332
+ self.str_output_parser = StrOutputParser()
333
+ debug_print("Selecting LLM pipeline based on choice: " + self.llm_choice)
334
+ self.llm = self.create_llm_pipeline()
335
+ def format_response(response: str) -> str:
336
+ input_tokens = count_tokens(self.context + self.prompt_template)
337
+ output_tokens = count_tokens(response)
338
+ formatted = f"### Response\n\n{response}\n\n---\n"
339
+ formatted += f"- **Input tokens:** {input_tokens}\n"
340
+ formatted += f"- **Output tokens:** {output_tokens}\n"
341
+ formatted += f"- **Generated using:** {self.llm_choice}\n"
342
+ formatted += f"\n**Conversation History:** {len(self.conversation_history)} conversation(s) considered.\n"
343
+ return formatted
344
+ self.elevated_rag_chain = base_runnable | self.rag_prompt | self.llm | format_response
345
+ debug_print("Elevated RAG chain successfully built and ready to use.")
346
+
347
+ def get_current_context(self) -> str:
348
+ base_context = "\n".join([str(doc) for doc in self.split_data[:3]]) if self.split_data else "No context available."
349
+ history_summary = "\n\n---\n**Recent Conversations (last 3):**\n"
350
+ recent = self.conversation_history[-3:]
351
+ if recent:
352
+ for i, conv in enumerate(recent, 1):
353
+ history_summary += f"**Conversation {i}:**\n- Query: {conv['query']}\n- Response: {conv['response']}\n"
354
+ else:
355
+ history_summary += "No conversation history."
356
+ return base_context + history_summary
357
+
358
+ # ----------------------------
359
+ # Gradio Interface Functions
360
+ # ----------------------------
361
+ global rag_chain
362
+ rag_chain = ElevatedRagChain()
363
+
364
+ def load_pdfs_updated(file_links, model_choice, prompt_template, bm25_weight, temperature, top_p):
365
+ debug_print("Inside load_pdfs function.")
366
+ if not file_links:
367
+ debug_print("Please enter non-empty URLs")
368
+ return "Please enter non-empty URLs", "Word count: N/A", "Model used: N/A", "Context: N/A"
369
+ try:
370
+ links = [link.strip() for link in file_links.split("\n") if link.strip()]
371
+ global rag_chain
372
+ if rag_chain.raw_data:
373
+ rag_chain.update_llm_pipeline(model_choice, temperature, top_p, prompt_template, bm25_weight)
374
+ context_display = rag_chain.get_current_context()
375
+ response_msg = f"Files already loaded. Chain updated with model: {model_choice}"
376
+ return (
377
+ response_msg,
378
+ f"Word count: {word_count(rag_chain.context)}",
379
+ f"Model used: {rag_chain.llm_choice}",
380
+ f"Context:\n{context_display}"
381
+ )
382
+ else:
383
+ rag_chain = ElevatedRagChain(
384
+ llm_choice=model_choice,
385
+ prompt_template=prompt_template,
386
+ bm25_weight=bm25_weight,
387
+ temperature=temperature,
388
+ top_p=top_p
389
+ )
390
+ rag_chain.add_pdfs_to_vectore_store(links)
391
+ context_display = rag_chain.get_current_context()
392
+ response_msg = f"Files loaded successfully. Using model: {model_choice}"
393
+ return (
394
+ response_msg,
395
+ f"Word count: {word_count(rag_chain.context)}",
396
+ f"Model used: {rag_chain.llm_choice}",
397
+ f"Context:\n{context_display}"
398
+ )
399
+ except Exception as e:
400
+ error_msg = traceback.format_exc()
401
+ debug_print("Could not load files. Error: " + error_msg)
402
+ return (
403
+ "Error loading files: " + str(e),
404
+ f"Word count: {word_count('')}",
405
+ f"Model used: {rag_chain.llm_choice}",
406
+ "Context: N/A"
407
+ )
408
+
409
+ def update_model(new_model: str):
410
+ global rag_chain
411
+ if rag_chain and rag_chain.raw_data:
412
+ rag_chain.update_llm_pipeline(new_model, rag_chain.temperature, rag_chain.top_p,
413
+ rag_chain.prompt_template, rag_chain.bm25_weight)
414
+ debug_print(f"Model updated to {rag_chain.llm_choice}")
415
+ return f"Model updated to: {rag_chain.llm_choice}"
416
+ else:
417
+ return "No files loaded; please load files first."
418
+
419
+ def submit_query_updated(query):
420
+ debug_print("Inside submit_query function.")
421
+ if not query:
422
+ debug_print("Please enter a non-empty query")
423
+ return "Please enter a non-empty query", "", "Input tokens: 0", "Output tokens: 0"
424
+
425
+ if not hasattr(rag_chain, 'elevated_rag_chain'):
426
+ return "Please load files first.", "", "Input tokens: 0", "Output tokens: 0"
427
+
428
+ try:
429
+ # Collect and truncate conversation history if needed
430
+ history_text = ""
431
+ if rag_chain.conversation_history:
432
+ # Only keep the last 3 conversations to limit context size
433
+ recent_history = rag_chain.conversation_history[-3:]
434
+ history_text = "\n".join([f"Q: {conv['query']}\nA: {conv['response']}"
435
+ for conv in recent_history])
436
+
437
+ # Determine max context size based on model
438
+ max_context_tokens = 32000 if "mistral" in rag_chain.llm_choice.lower() else 4096
439
+ # Reserve 1000 tokens for the question and generation
440
+ max_context_tokens -= 1000
441
+
442
+ # Truncate context if needed
443
+ context = truncate_prompt(rag_chain.context, max_tokens=max_context_tokens)
444
+
445
+ prompt_variables = {
446
+ "conversation_history": history_text,
447
+ "context": context,
448
+ "question": query
449
+ }
450
+
451
+ response = rag_chain.elevated_rag_chain.invoke({"question": query})
452
+ rag_chain.conversation_history.append({"query": query, "response": response})
453
+
454
+ input_token_count = count_tokens(query)
455
+ output_token_count = count_tokens(response)
456
+
457
+ return (
458
+ response,
459
+ rag_chain.get_current_context(),
460
+ f"Input tokens: {input_token_count}",
461
+ f"Output tokens: {output_token_count}"
462
+ )
463
+ except Exception as e:
464
+ error_msg = traceback.format_exc()
465
+ debug_print("LLM error. Error: " + error_msg)
466
+ return (
467
+ f"Query error: {str(e)}\n\nTry using a smaller document or simplifying your query.",
468
+ "",
469
+ "Input tokens: 0",
470
+ "Output tokens: 0"
471
+ )
472
+
473
+ def reset_app_updated():
474
+ global rag_chain
475
+ rag_chain = ElevatedRagChain()
476
+ debug_print("App reset successfully.")
477
+ return (
478
+ "App reset successfully. You can now load new files",
479
+ "",
480
+ "Model used: Not selected"
481
+ )
482
+
483
+ # ----------------------------
484
+ # Gradio Interface Setup
485
+ # ----------------------------
486
+ custom_css = """
487
+ textarea {
488
+ overflow-y: scroll !important;
489
+ max-height: 200px;
490
+ }
491
+ """
492
+
493
+ with gr.Blocks(css=custom_css) as app:
494
+ gr.Markdown('''# PhiRAG
495
+ **PhiRAG** Query Your Data with Advanced RAG Techniques
496
+
497
+ **Model Selection & Parameters:** Choose from the following options:
498
+ - 🇺🇸 Remote Meta-Llama-3 - has context windows of 8000 tokens
499
+ - 🇪🇺 Mistral-API - has context windows of 32000 tokens
500
+
501
+ **🔥 Randomness (Temperature):** Adjusts output predictability.
502
+ - Example: 0.2 makes the output very deterministic (less creative), while 0.8 introduces more variety and spontaneity.
503
+
504
+ **🎯 Word Variety (Top-p):** Limits word choices to a set probability percentage.
505
+ - Example: 0.5 restricts output to the most likely 50% of token choices for a focused answer; 0.95 allows almost all possibilities for more diverse responses.
506
+
507
+ **⚖️ BM25 Weight:** Adjust Lexical vs Semantics.
508
+ - Example: A value of 0.8 puts more emphasis on exact keyword (lexical) matching, while 0.3 shifts emphasis toward semantic similarity.
509
+
510
+ **✏️ Prompt Template:** Edit as desired.
511
+
512
+ **🔗 File URLs:** Enter one URL per line (.pdf or .txt).\
513
+ - Example: Provide one URL per line, such as
514
+ https://www.gutenberg.org/ebooks/8438.txt.utf-8
515
+
516
+ **🔍 Query:** Enter your query below.
517
+ - Example: Select all parts in each book focusing on moral responsibility in Aristotle philosophy and discuss in a comprehensible way and link the topics to other top world philosophers. Use a structure and bullet points
518
+
519
+ The response displays the model used, word count, and current context (with conversation history).
520
+ ''')
521
+ with gr.Row():
522
+ with gr.Column():
523
+ model_dropdown = gr.Dropdown(
524
+ choices=["🇺🇸 Remote Meta-Llama-3", "🇪🇺 Mistral-API"],
525
+ value="🇺🇸 Remote Meta-Llama-3",
526
+ label="Select Model"
527
+ )
528
+ temperature_slider = gr.Slider(
529
+ minimum=0.1, maximum=1.0, value=0.5, step=0.1,
530
+ label="Randomness (Temperature)"
531
+ )
532
+ top_p_slider = gr.Slider(
533
+ minimum=0.1, maximum=0.99, value=0.95, step=0.05,
534
+ label="Word Variety (Top-p)"
535
+ )
536
+ with gr.Column():
537
+ pdf_input = gr.Textbox(
538
+ label="Enter your file URLs (one per line)",
539
+ placeholder="Enter one URL per line (.pdf or .txt)",
540
+ lines=4
541
+ )
542
+ prompt_input = gr.Textbox(
543
+ label="Custom Prompt Template",
544
+ placeholder="Enter your custom prompt template here",
545
+ lines=8,
546
+ value=default_prompt
547
+ )
548
+ with gr.Column():
549
+ bm25_weight_slider = gr.Slider(
550
+ minimum=0.0, maximum=1.0, value=0.6, step=0.1,
551
+ label="Lexical vs Semantics (BM25 Weight)"
552
+ )
553
+ load_button = gr.Button("Load Files")
554
+
555
+ with gr.Row():
556
+ with gr.Column():
557
+ query_input = gr.Textbox(
558
+ label="Enter your query here",
559
+ placeholder="Type your query",
560
+ lines=4
561
+ )
562
+ submit_button = gr.Button("Submit")
563
+ with gr.Column():
564
+ reset_button = gr.Button("Reset App")
565
+
566
+ with gr.Row():
567
+ response_output = gr.Textbox(
568
+ label="Response",
569
+ placeholder="Response will appear here (formatted as Markdown)",
570
+ lines=6
571
+ )
572
+ context_output = gr.Textbox(
573
+ label="Current Context",
574
+ placeholder="Retrieved context and conversation history will appear here",
575
+ lines=6
576
+ )
577
+
578
+ with gr.Row():
579
+ input_tokens = gr.Markdown("Input tokens: 0")
580
+ output_tokens = gr.Markdown("Output tokens: 0")
581
+ model_output = gr.Markdown("**Current Model**: Not selected")
582
+
583
+ load_button.click(
584
+ load_pdfs_updated,
585
+ inputs=[pdf_input, model_dropdown, prompt_input, bm25_weight_slider, temperature_slider, top_p_slider],
586
+ outputs=[response_output, context_output, model_output]
587
+
588
+ )
589
+
590
+ submit_button.click(
591
+ submit_query_updated,
592
+ inputs=[query_input],
593
+ outputs=[response_output, context_output, input_tokens, output_tokens]
594
+ )
595
+
596
+ reset_button.click(
597
+ reset_app_updated,
598
+ inputs=[],
599
+ outputs=[response_output, context_output, model_output]
600
+ )
601
+
602
+ model_dropdown.change(
603
+ fn=update_model,
604
+ inputs=model_dropdown,
605
+ outputs=model_output
606
+ )
607
+
608
+ if __name__ == "__main__":
609
+ debug_print("Launching Gradio interface.")
610
+ app.launch(share=False)