|
|
"""lightrag_wrapper.py |
|
|
|
|
|
Provides a resilient RAG interface for the Vipassana agent. |
|
|
Now includes: |
|
|
- Cross-Encoder Reranking for improved precision and accuracy |
|
|
- BM25 keyword-based search for hybrid retrieval (BM25 + FAISS) |
|
|
""" |
|
|
import os |
|
|
import json |
|
|
import pickle |
|
|
try:
    import faiss
    import numpy as np
except ImportError:
    # Missing core dependencies are detected and reported in __init__.
    faiss = None
    np = None
|
|
import openai |
|
|
import re |
|
|
import unicodedata |
|
|
from typing import List, Dict, Optional, Tuple, Any |
|
|
|
|
|
|
|
|
try: |
|
|
from sentence_transformers import SentenceTransformer, CrossEncoder |
|
|
from sklearn.preprocessing import normalize |
|
|
except ImportError: |
|
|
SentenceTransformer = None |
|
|
CrossEncoder = None |
|
|
normalize = None |
|
|
|
|
|
try: |
|
|
from openai import OpenAI |
|
|
except ImportError: |
|
|
OpenAI = None |
|
|
|
|
|
try: |
|
|
from huggingface_hub import InferenceClient |
|
|
except ImportError: |
|
|
InferenceClient = None |
|
|
|
|
|
try: |
|
|
from rank_bm25 import BM25Okapi |
|
|
except ImportError: |
|
|
BM25Okapi = None |
|
|
|
|
|
try: |
|
|
from multilingual_prompts import detect_language, get_system_prompt, get_user_prompt |
|
|
except ImportError:
    print("Warning: multilingual_prompts module not found. Falling back to basic prompts.")
    detect_language = None
    get_system_prompt = None
    get_user_prompt = None


LLM_PROVIDER = os.getenv("LLM_PROVIDER", "openai")
OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-3.5-turbo")
HF_MODEL = os.getenv("HF_MODEL", "Qwen/Qwen2.5-7B-Instruct")
EMBED_MODEL = os.getenv("EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
RERANK_MODEL = os.getenv("RERANK_MODEL", "cross-encoder/ms-marco-MiniLM-L-6-v2")
INDEX_PATH = os.getenv("VIPASSANA_INDEX_PATH", "data/vector_store/index.faiss")
META_PATH = os.getenv("VIPASSANA_META_PATH", "data/vector_store/meta.json")
BM25_PATH = os.getenv("VIPASSANA_BM25_PATH", "data/vector_store/bm25.pkl")
|
|
|
|
|
|
|
|
class VipassanaRAGAgent: |
|
|
"""Unified RAG wrapper with Multi-Query Search, BM25, and Cross-Encoder Reranking.""" |
|
|
|
|
|
def __init__(self, openai_api_key: Optional[str] = None, hf_api_token: Optional[str] = None): |
|
|
|
|
|
self.is_ready = False |
|
|
self.index_loaded = False |
|
|
|
|
|
|
|
|
self.model = None |
|
|
self.reranker = None |
|
|
self.bm25 = None |
|
|
self.index = None |
|
|
self.metadatas = [] |
|
|
|
|
|
|
|
|
self.openai_client = None |
|
|
self.hf_client = None |
|
|
self.llm_provider = LLM_PROVIDER |
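        # Core retrieval dependencies are required up front; without them the
        # agent stays in a "not ready" state and answer() reports the failure
        # instead of crashing later.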
|
|
|
|
|
|
|
|
if SentenceTransformer is None or CrossEncoder is None or faiss is None or np is None: |
|
|
print("ERROR: Missing required packages (sentence-transformers, faiss, numpy). Cannot initialize RAG components.") |
|
|
return |
|
|
|
|
|
|
|
|
print(f"[LLM] Initializing LLM provider: {self.llm_provider}") |
|
|
|
|
|
if self.llm_provider == "openai": |
|
|
|
|
|
if OpenAI is None: |
|
|
print("ERROR: openai package not installed. Run: pip install openai") |
|
|
return |
|
|
|
|
|
try: |
|
|
|
|
|
api_key = openai_api_key or os.getenv("OPENAI_API_KEY") |
|
|
if not api_key: |
|
|
print("ERROR: OPENAI_API_KEY not provided. Cannot initialize OpenAI client.") |
|
|
print("HINT: Set OPENAI_API_KEY in your .env file or HF Spaces secrets") |
|
|
return |
|
|
|
|
|
|
|
|
self.openai_client = OpenAI( |
|
|
api_key=api_key, |
|
|
timeout=60.0, |
|
|
max_retries=3 |
|
|
) |
|
|
print(f"[LLM] OpenAI client initialized successfully (model: {OPENAI_MODEL})") |
|
|
|
|
|
except Exception as e: |
|
|
print(f"ERROR: Failed to initialize OpenAI client: {e}") |
|
|
return |
|
|
|
|
|
elif self.llm_provider == "huggingface": |
|
|
|
|
|
if InferenceClient is None: |
|
|
print("ERROR: huggingface_hub package not installed. Run: pip install huggingface_hub") |
|
|
return |
|
|
|
|
|
try: |
|
|
|
|
|
token = hf_api_token or os.getenv("HF_API_TOKEN") or os.getenv("HF_TOKEN") |
|
|
|
|
|
|
|
|
self.hf_client = InferenceClient( |
|
|
model=HF_MODEL, |
|
|
token=token, |
|
|
timeout=60.0 |
|
|
) |
|
|
print(f"[LLM] HuggingFace Inference client initialized successfully") |
|
|
print(f"[LLM] Using model: {HF_MODEL}") |
|
|
if not token: |
|
|
print("[LLM] No HF_API_TOKEN provided - using public inference (may have rate limits)") |
|
|
|
|
|
except Exception as e: |
|
|
print(f"ERROR: Failed to initialize HuggingFace client: {e}") |
|
|
return |
|
|
else: |
|
|
print(f"ERROR: Invalid LLM_PROVIDER: {self.llm_provider}. Must be 'openai' or 'huggingface'") |
|
|
return |
|
|
|
|
|
|
|
|
try: |
|
|
print(f"Initializing embedding model: {EMBED_MODEL}") |
|
|
self.model = SentenceTransformer(EMBED_MODEL) |
|
|
|
|
|
print(f"Initializing reranking model: {RERANK_MODEL}") |
|
|
|
|
|
self.reranker = CrossEncoder(RERANK_MODEL) |
|
|
|
|
|
self._load_index() |
|
|
|
|
|
except Exception as e: |
|
|
print(f"Error loading RAG components (models or index): {e}") |
|
|
return |
|
|
|
|
|
|
|
|
if self.index_loaded: |
|
|
self.is_ready = True |
|
|
print("VipassanaRAGAgent initialized successfully with Reranking enabled.") |
|
|
|
|
|
|
|
|
def _load_index(self): |
|
|
"""Loads the FAISS index, BM25 index, and metadata files.""" |
|
|
print(f"Attempting to load index from {INDEX_PATH}...") |
|
|
try: |
|
|
self.index = faiss.read_index(INDEX_PATH) |
|
|
with open(META_PATH, "r", encoding="utf-8") as f: |
|
|
self.metadatas = json.load(f) |
|
|
|
|
|
|
|
|
if os.path.exists(BM25_PATH) and BM25Okapi is not None: |
|
|
with open(BM25_PATH, "rb") as f: |
|
|
self.bm25 = pickle.load(f) |
|
|
print(f"BM25 index loaded successfully from {BM25_PATH}") |
|
|
else: |
|
|
print(f"Warning: BM25 index not found at {BM25_PATH}. Falling back to FAISS only.") |
|
|
self.bm25 = None |
|
|
|
|
|
self.index_loaded = True |
|
|
print(f"Index loaded successfully with {len(self.metadatas)} chunks.") |
|
|
except Exception as e: |
|
|
print(f"Warning: Could not load index from {INDEX_PATH}. Error: {e}") |
|
|
self.index_loaded = False |
|
|
self.index = None |
|
|
self.metadatas = [] |
|
|
self.bm25 = None |
|
|
|
|
|
def _expand_query(self, query: str) -> List[str]: |
|
|
""" |
|
|
Expand query into a list of query variations for better initial recall. |
|
|
(Same logic as previous revision, focused on getting *more* relevant candidates). |
|
|
""" |
|
|
vipassana_terms_map = { |
|
|
'meditation': ['practice', 'technique', 'sadhana'], |
|
|
'vipassana': ['insight meditation', 'mindfulness', 'awareness'], |
|
|
'anapana': ['breathing', 'breath', 'respiration'], |
|
|
'goenka': ['s.n. goenka', 'teacher', 'acharya'], |
|
|
'dhamma': ['dharma', 'teaching', 'truth'], |
|
|
'suffering': ['dukkha', 'pain', 'misery', 'stress'], |
|
|
'anicca': ['impermanence', 'change'], |
|
|
} |
|
|
|
|
|
query_lower = query.lower() |
|
|
query_list = [query_lower] |
|
|
|
|
|
|
|
|
for key, terms in vipassana_terms_map.items(): |
|
|
if key in query_lower: |
|
|
for term in terms: |
|
|
|
|
|
query_list.append(query_lower.replace(key, term)) |
|
|
|
|
|
|
|
|
if "?" in query or query_lower.startswith(("what is", "how to", "why is")): |
|
|
            simple_rephrase = re.sub(r"what is|how to|why is", "", query_lower, count=1).strip().replace("?", "")
|
|
if simple_rephrase: |
|
|
query_list.append(simple_rephrase) |
|
|
|
|
|
|
|
|
unique_queries = [] |
|
|
seen_queries = set() |
|
|
for q in query_list: |
|
|
cleaned_q = re.sub(r"\s+", " ", q).strip() |
|
|
if cleaned_q and cleaned_q not in seen_queries: |
|
|
unique_queries.append(cleaned_q) |
|
|
seen_queries.add(cleaned_q) |
|
|
|
|
|
|
|
|
return unique_queries[:5] |
|
|
|
|
|
def _rerank_retrieved_items(self, query: str, retrieved_items: List[Dict[str, Any]]) -> List[Dict[str, Any]]: |
|
|
""" |
|
|
Uses a Cross-Encoder model to re-score the relevance of each chunk to the query. |
|
|
This greatly improves precision. |
|
|
""" |
|
|
if not self.reranker: |
|
|
return retrieved_items |
|
|
|
|
|
if not retrieved_items: |
|
|
return [] |
|
|
|
|
|
|
|
|
sentences_to_rank = [[query, item["chunk"]] for item in retrieved_items] |
|
|
|
|
|
|
|
|
new_scores = self.reranker.predict(sentences_to_rank) |
|
|
|
|
|
|
|
|
for i, item in enumerate(retrieved_items): |
|
|
|
|
|
item["rerank_score"] = float(new_scores[i]) |
|
|
|
|
|
|
|
|
retrieved_items.sort(key=lambda x: x["rerank_score"], reverse=True) |
|
|
|
|
|
return retrieved_items |
|
|
|
|
|
|
|
|
def _bm25_search(self, query: str, top_k: int = 15) -> List[Dict[str, Any]]: |
|
|
""" |
|
|
Performs BM25 keyword-based search. |
|
|
Returns a list of candidate chunks with BM25 scores. |
|
|
""" |
|
|
if self.bm25 is None: |
|
|
return [] |
|
|
|
|
|
|
|
|
tokenized_query = query.lower().split() |
|
|
|
|
|
|
|
|
bm25_scores = self.bm25.get_scores(tokenized_query) |
|
|
|
|
|
|
|
|
top_indices = np.argsort(bm25_scores)[::-1][:top_k] |
|
|
|
|
|
results = [] |
|
|
for idx in top_indices: |
|
|
if idx < len(self.metadatas) and bm25_scores[idx] > 0: |
|
|
metadata = self.metadatas[idx] |
|
|
results.append({ |
|
|
"score": float(bm25_scores[idx]), |
|
|
"chunk": metadata.get("chunk", ""), |
|
|
"source": metadata.get("source"), |
|
|
"metadata": metadata |
|
|
}) |
|
|
|
|
|
return results |
|
|
|
|
|
def retrieve(self, query: str, top_k: int = 5, top_k_initial: int = 30) -> List[Dict[str, Any]]: |
|
|
""" |
|
|
Retrieves context using hybrid search (BM25 + FAISS) followed by cross-encoder reranking. |
|
|
|
|
|
top_k: The final number of chunks passed to the LLM. |
|
|
top_k_initial: The number of chunks retrieved from FAISS before reranking. |
|
|
""" |
|
|
if not self.index_loaded: |
|
|
return [] |
|
|
|
|
|
|
|
|
query_list = self._expand_query(query) |
|
|
|
|
|
|
|
|
all_retrieved_items: Dict[str, Dict[str, Any]] = {} |
|
|
|
|
|
|
|
|
search_k_per_query = max(5, top_k_initial // len(query_list)) if len(query_list) > 0 else top_k_initial |
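        # Dense (FAISS) pass: embed each expanded query and merge the hits,
        # keeping the best similarity score seen for every unique chunk.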
|
|
|
|
|
for single_query in query_list: |
|
|
if not single_query.strip(): |
|
|
continue |
|
|
|
|
|
query_embedding = self.model.encode(single_query, convert_to_numpy=True) |
|
|
query_embedding = normalize(query_embedding.reshape(1, -1)) |
|
|
|
|
|
|
|
|
distances, indices = self.index.search(query_embedding.astype("float32"), search_k_per_query) |
|
|
|
|
|
for i, score in zip(indices[0], distances[0]): |
|
|
if i >= 0: |
|
|
metadata = self.metadatas[i] |
|
|
chunk_text = metadata.get("chunk", "") |
|
|
|
|
|
if chunk_text: |
|
|
|
|
|
if chunk_text not in all_retrieved_items or score > all_retrieved_items[chunk_text]["score"]: |
|
|
all_retrieved_items[chunk_text] = { |
|
|
"score": float(score), |
|
|
"chunk": chunk_text, |
|
|
"source": metadata.get("source"), |
|
|
"metadata": metadata |
|
|
} |
|
|
|
|
|
|
|
|
        # Keyword (BM25) pass: add chunks the dense search may have missed.
        # Chunks already retrieved by FAISS keep their dense score.
        bm25_results = self._bm25_search(query, top_k=15)
        for item in bm25_results:
            chunk_text = item["chunk"]
            if chunk_text and chunk_text not in all_retrieved_items:
                all_retrieved_items[chunk_text] = item
|
|
|
|
|
initial_candidates = list(all_retrieved_items.values()) |
|
|
|
|
|
|
|
|
if not initial_candidates: |
|
|
return [] |
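        # Precision pass: the cross-encoder re-scores (query, chunk) pairs and
        # only the top_k highest-scoring chunks go forward to the LLM.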
|
|
|
|
|
|
|
|
reranked_items = self._rerank_retrieved_items(query, initial_candidates) |
|
|
|
|
|
|
|
|
return reranked_items[:top_k] |
|
|
|
|
|
|
|
|
def _get_llm_response(self, query: str, context: str, mode: str = "long") -> str: |
|
|
""" |
|
|
Unified method to get LLM response from either OpenAI or HuggingFace. |
|
|
Automatically routes to the correct provider based on LLM_PROVIDER setting. |
|
|
""" |
|
|
if not self.openai_client and not self.hf_client: |
|
|
return "Internal error: LLM client is not initialized." |
|
|
|
|
|
|
|
|
if detect_language and get_system_prompt and get_user_prompt: |
|
|
detected_language = detect_language(query) |
|
|
system_prompt = get_system_prompt(detected_language) |
|
|
prompt = get_user_prompt(query, context, detected_language) |
|
|
print(f"[LLM] Detected language: {detected_language}") |
|
|
else: |
|
|
|
|
|
detected_language = "english" |
|
|
system_prompt = ( |
|
|
"You are the Vipassana Guide AI, a compassionate meditation teacher. " |
|
|
"Use ONLY the provided CONTEXT. Do not add external knowledge.\n\n" |
|
|
"Format responses with ## headings, **bold**, and *italic* for Pali/Sanskrit terms.\n" |
|
|
"Be direct and practical. Include [Source: filename] references." |
|
|
) |
|
|
prompt = f"""CONTEXT (Vipassana Knowledge Base): |
|
|
{context} |
|
|
|
|
|
USER'S QUESTION: |
|
|
{query} |
|
|
|
|
|
Generate the response based strictly on the provided CONTEXT.""" |
|
|
|
|
|
|
|
|
|
|
|
if self.llm_provider == "openai": |
|
|
|
|
|
try: |
|
|
response = self.openai_client.chat.completions.create( |
|
|
model=OPENAI_MODEL, |
|
|
messages=[ |
|
|
{"role": "system", "content": system_prompt}, |
|
|
{"role": "user", "content": prompt} |
|
|
], |
|
|
temperature=0.05, |
|
|
max_tokens=1200, |
|
|
top_p=0.9, |
|
|
timeout=60 |
|
|
) |
|
|
return response.choices[0].message.content |
|
|
|
|
|
except openai.APIConnectionError as e: |
|
|
|
|
|
error_msg = f"Network connection error to OpenAI API. This may be temporary. Please try again in a moment." |
|
|
print(f"[OpenAI API] Connection Error: {e}") |
|
|
return error_msg |
|
|
except openai.APITimeoutError as e: |
|
|
|
|
|
error_msg = f"Request to OpenAI timed out. The service may be slow. Please try again." |
|
|
print(f"[OpenAI API] Timeout Error: {e}") |
|
|
return error_msg |
|
|
except openai.AuthenticationError as e: |
|
|
|
|
|
error_msg = f"OpenAI API authentication failed. Please check your API key configuration." |
|
|
print(f"[OpenAI API] Authentication Error: {e}") |
|
|
return error_msg |
|
|
except openai.RateLimitError as e: |
|
|
|
|
|
error_msg = f"OpenAI API rate limit exceeded. Please wait a moment and try again." |
|
|
print(f"[OpenAI API] Rate Limit Error: {e}") |
|
|
return error_msg |
|
|
except Exception as e: |
|
|
|
|
|
error_msg = f"Error during OpenAI generation: {type(e).__name__} - {str(e)}" |
|
|
print(f"[OpenAI API] Generic Error: {e}") |
|
|
return error_msg |
|
|
|
|
|
elif self.llm_provider == "huggingface": |
|
|
|
|
|
try: |
|
|
|
|
|
|
|
|
messages = [ |
|
|
{"role": "system", "content": system_prompt}, |
|
|
{"role": "user", "content": prompt} |
|
|
] |
|
|
|
|
|
|
|
|
response = self.hf_client.chat_completion( |
|
|
messages=messages, |
|
|
max_tokens=1200, |
|
|
temperature=0.05, |
|
|
top_p=0.85, |
|
|
) |
|
|
|
|
|
|
|
|
answer = response.choices[0].message.content |
|
|
return answer.strip() |
|
|
|
|
|
except Exception as e: |
|
|
|
|
|
error_msg = f"Error during HuggingFace generation: {type(e).__name__} - {str(e)}" |
|
|
print(f"[HF API] Error: {e}") |
|
|
print(f"[HF API] Full error details: {e}") |
|
|
return error_msg |
|
|
|
|
|
else: |
|
|
return f"Error: Unknown LLM provider: {self.llm_provider}" |
|
|
|
|
|
def answer(self, query: str, top_k: int = 6, mode: str = "long") -> Tuple[str, List[str]]: |
|
|
"""Main method to retrieve context and generate an answer with optimized performance.""" |
|
|
if not self.is_ready: |
|
|
return "RAG Agent failed to initialize. Please check the console for errors.", [] |
|
|
|
|
|
|
|
|
retrieved_items = self.retrieve(query, top_k=top_k, top_k_initial=20) |
|
|
|
|
|
if not retrieved_items: |
|
|
return "Could not retrieve relevant documents from the knowledge base.", [] |
|
|
|
|
|
|
|
|
context_parts = [] |
|
|
sources = [] |
|
|
|
|
|
for i, item in enumerate(retrieved_items): |
|
|
chunk = item.get("chunk", "") |
|
|
source = item.get("source", "") |
|
|
|
|
|
if chunk: |
|
|
|
|
|
context_parts.append(f"[Source: {source}]\n{chunk}") |
|
|
|
|
|
if source and source not in sources: |
|
|
sources.append(source) |
|
|
|
|
|
context = "\n\n---\n\n".join(context_parts) |
|
|
|
|
|
|
|
|
answer = self._get_llm_response(query, context, mode) |
|
|
|
|
|
return answer, sources |
|
|
|
|
|
def _clean_text(self, text: str) -> str: |
|
|
"""Basic cleaning for text extracted from PDFs.""" |
|
|
if not isinstance(text, str): |
|
|
return "" |
|
|
|
|
|
|
|
|
        # Normalize unicode (ligatures, full-width forms, etc.).
        cleaned = unicodedata.normalize('NFKC', text)
        # Drop glyph-code artifacts such as "/c123" left over from PDF extraction.
        cleaned = re.sub(r"(/c\d+)+", " ", cleaned)
        # Remove bracketed or parenthesised reference numbers, e.g. [12] or (3).
        cleaned = re.sub(r"\[\s*\d+\s*\]", " ", cleaned)
        cleaned = re.sub(r"\(\s*\d+\s*\)", " ", cleaned)
        # Re-join words hyphenated across line breaks.
        cleaned = re.sub(r"-\s*\n\s*", "", cleaned)
        # Flatten remaining newlines and strip control characters.
        cleaned = re.sub(r"\n+", " ", cleaned)
        cleaned = re.sub(r"[\x00-\x1F\x7F-\x9F]", " ", cleaned)
        # Collapse runs of whitespace.
        cleaned = re.sub(r"\s+", " ", cleaned).strip()
|
|
|
|
|
return cleaned |
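

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not invoked by the application).
# It assumes the FAISS index, metadata JSON, and optional BM25 pickle already
# exist at the paths configured above, and that OPENAI_API_KEY (or HF_API_TOKEN
# with LLM_PROVIDER=huggingface) is set in the environment.
if __name__ == "__main__":
    agent = VipassanaRAGAgent()
    if agent.is_ready:
        reply, cited_sources = agent.answer("What is anapana meditation?")
        print(reply)
        print("Sources:", cited_sources)
    else:
        print("Agent not ready - see the messages above for missing dependencies or index files.")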
|