"""lightrag_wrapper.py
Provides a resilient RAG interface for the Vipassana agent.
Now includes:
- Cross-Encoder reranking for improved retrieval precision
- BM25 keyword-based search for hybrid retrieval (BM25 + FAISS)
"""
import os
import json
import pickle
import faiss
import numpy as np
import openai
import re
import unicodedata
from typing import List, Dict, Optional, Tuple, Any
# Conditional imports for RAG components
try:
from sentence_transformers import SentenceTransformer, CrossEncoder
from sklearn.preprocessing import normalize
except ImportError:
SentenceTransformer = None
CrossEncoder = None
normalize = None
try:
from openai import OpenAI
except ImportError:
OpenAI = None
try:
from huggingface_hub import InferenceClient
except ImportError:
InferenceClient = None
try:
from rank_bm25 import BM25Okapi
except ImportError:
BM25Okapi = None
try:
from multilingual_prompts import detect_language, get_system_prompt, get_user_prompt
except ImportError:
print("Warning: multilingual_prompts module not found. Fallback to basic prompts.")
detect_language = None
get_system_prompt = None
get_user_prompt = None
# ============================================================================
# PLUG-AND-PLAY LLM CONFIGURATION
# ============================================================================
# Switch between OpenAI and HuggingFace easily:
# - Set LLM_PROVIDER to "openai" to use OpenAI GPT models (requires OPENAI_API_KEY)
# * Cost: ~$0.002-0.01 per chat (very cheap)
# * Quality: Excellent (GPT-3.5-turbo)
# * Speed: Fast
# - Set LLM_PROVIDER to "huggingface" to use HF Inference API (requires HF_API_TOKEN)
# * Cost: FREE
# * Quality: Good (Qwen2.5-7B-Instruct)
# * Speed: Moderate
#
# For HuggingFace Spaces with OpenAI: Set LLM_PROVIDER=openai + OPENAI_API_KEY
# For HuggingFace Spaces without OpenAI: Set LLM_PROVIDER=huggingface + HF_API_TOKEN
# For Local Development: Set to "openai" if you have OpenAI API key
# ============================================================================
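# Illustrative environment setup (example only; the variable names match the
# os.getenv() calls in this module, but the key values are placeholders):
#
#   # OpenAI (paid, higher quality)
#   LLM_PROVIDER=openai
#   OPENAI_API_KEY=sk-...
#   OPENAI_MODEL=gpt-3.5-turbo
#
#   # HuggingFace Inference API (free)
#   LLM_PROVIDER=huggingface
#   HF_API_TOKEN=hf_...
#   HF_MODEL=Qwen/Qwen2.5-7B-Instruct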
LLM_PROVIDER = os.getenv("LLM_PROVIDER", "openai") # "openai" or "huggingface" - default to openai for better quality
# OpenAI Configuration
OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-3.5-turbo")
# HuggingFace Configuration
# Recommended models for best multilingual support (Hindi, Marathi, English):
# - "Qwen/Qwen2.5-7B-Instruct" (BEST multilingual quality - 18T tokens trained)
# - "mistralai/Mistral-7B-Instruct-v0.2" (good English, weak Hindi/Marathi)
# - "meta-llama/Meta-Llama-3.1-8B-Instruct" (decent multilingual)
# - "google/gemma-2-9b-it" (good multilingual, larger model)
HF_MODEL = os.getenv("HF_MODEL", "Qwen/Qwen2.5-7B-Instruct")
# --- Other Configuration ---
EMBED_MODEL = os.getenv("EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
# Using a powerful, free, public Cross-Encoder for better relevance scoring
RERANK_MODEL = os.getenv("RERANK_MODEL", "cross-encoder/ms-marco-MiniLM-L-6-v2")
INDEX_PATH = os.getenv("VIPASSANA_INDEX_PATH", "data/vector_store/index.faiss")
META_PATH = os.getenv("VIPASSANA_META_PATH", "data/vector_store/meta.json")
BM25_PATH = os.getenv("VIPASSANA_BM25_PATH", "data/vector_store/bm25.pkl")
class VipassanaRAGAgent:
"""Unified RAG wrapper with Multi-Query Search, BM25, and Cross-Encoder Reranking."""
def __init__(self, openai_api_key: Optional[str] = None, hf_api_token: Optional[str] = None):
# Status flags: set to True only after the models and index load successfully.
self.is_ready = False
self.index_loaded = False
self.model = None # Sentence Transformer for embedding
self.reranker = None # Cross Encoder for re-scoring
self.bm25 = None # BM25 for keyword search
self.index = None
self.metadatas = []
# LLM clients (only one will be used based on LLM_PROVIDER)
self.openai_client = None
self.hf_client = None
self.llm_provider = LLM_PROVIDER
# 1. Check for required packages
if SentenceTransformer is None or CrossEncoder is None or normalize is None or faiss is None or np is None:
print("ERROR: Missing required packages (sentence-transformers, scikit-learn, faiss, numpy). Cannot initialize RAG components.")
return
# 2. Setup LLM Client based on LLM_PROVIDER
print(f"[LLM] Initializing LLM provider: {self.llm_provider}")
if self.llm_provider == "openai":
# ========== OPENAI SETUP ==========
if OpenAI is None:
print("ERROR: openai package not installed. Run: pip install openai")
return
try:
# Use provided key or environment key
api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
if not api_key:
print("ERROR: OPENAI_API_KEY not provided. Cannot initialize OpenAI client.")
print("HINT: Set OPENAI_API_KEY in your .env file or HF Spaces secrets")
return
# Initialize with longer timeout for HF Spaces network
self.openai_client = OpenAI(
api_key=api_key,
timeout=60.0, # 60 second timeout
max_retries=3 # Retry up to 3 times on connection errors
)
print(f"[LLM] OpenAI client initialized successfully (model: {OPENAI_MODEL})")
except Exception as e:
print(f"ERROR: Failed to initialize OpenAI client: {e}")
return
elif self.llm_provider == "huggingface":
# ========== HUGGINGFACE SETUP ==========
if InferenceClient is None:
print("ERROR: huggingface_hub package not installed. Run: pip install huggingface_hub")
return
try:
# Use provided token or environment token (optional for public models)
token = hf_api_token or os.getenv("HF_API_TOKEN") or os.getenv("HF_TOKEN")
# Initialize HF Inference Client
self.hf_client = InferenceClient(
model=HF_MODEL,
token=token,
timeout=60.0
)
print(f"[LLM] HuggingFace Inference client initialized successfully")
print(f"[LLM] Using model: {HF_MODEL}")
if not token:
print("[LLM] No HF_API_TOKEN provided - using public inference (may have rate limits)")
except Exception as e:
print(f"ERROR: Failed to initialize HuggingFace client: {e}")
return
else:
print(f"ERROR: Invalid LLM_PROVIDER: {self.llm_provider}. Must be 'openai' or 'huggingface'")
return
# 3. Load Models and Index
try:
print(f"Initializing embedding model: {EMBED_MODEL}")
self.model = SentenceTransformer(EMBED_MODEL)
print(f"Initializing reranking model: {RERANK_MODEL}")
# Reranker model is a CrossEncoder, which expects (query, document) pairs
self.reranker = CrossEncoder(RERANK_MODEL)
self._load_index()
except Exception as e:
print(f"Error loading RAG components (models or index): {e}")
return
# 4. Set final readiness status
if self.index_loaded:
self.is_ready = True
print("VipassanaRAGAgent initialized successfully with Reranking enabled.")
def _load_index(self):
"""Loads the FAISS index, BM25 index, and metadata files."""
print(f"Attempting to load index from {INDEX_PATH}...")
try:
self.index = faiss.read_index(INDEX_PATH)
with open(META_PATH, "r", encoding="utf-8") as f:
self.metadatas = json.load(f)
# Load BM25 index if available
if os.path.exists(BM25_PATH) and BM25Okapi is not None:
with open(BM25_PATH, "rb") as f:
self.bm25 = pickle.load(f)
print(f"BM25 index loaded successfully from {BM25_PATH}")
else:
print(f"Warning: BM25 index not found at {BM25_PATH}. Falling back to FAISS only.")
self.bm25 = None
self.index_loaded = True
print(f"Index loaded successfully with {len(self.metadatas)} chunks.")
except Exception as e:
print(f"Warning: Could not load index from {INDEX_PATH}. Error: {e}")
self.index_loaded = False
self.index = None
self.metadatas = []
self.bm25 = None
def _expand_query(self, query: str) -> List[str]:
"""
Expand the query into a small set of variations (keyword substitutions plus a
simple rephrase heuristic) to improve initial recall before reranking.
"""
vipassana_terms_map = {
'meditation': ['practice', 'technique', 'sadhana'],
'vipassana': ['insight meditation', 'mindfulness', 'awareness'],
'anapana': ['breathing', 'breath', 'respiration'],
'goenka': ['s.n. goenka', 'teacher', 'acharya'],
'dhamma': ['dharma', 'teaching', 'truth'],
'suffering': ['dukkha', 'pain', 'misery', 'stress'],
'anicca': ['impermanence', 'change'],
}
query_lower = query.lower()
query_list = [query_lower]
# 1. Keyword Substitution Queries
for key, terms in vipassana_terms_map.items():
if key in query_lower:
for term in terms:
# Create a query with one key term substituted
query_list.append(query_lower.replace(key, term))
# 2. Simple Rephrase Heuristic
if "?" in query or query_lower.startswith(("what is", "how to", "why is")):
simple_rephrase = re.sub(r"what is|how to|why is", "", query_lower, count=1).strip().replace("?", "")
if simple_rephrase:
query_list.append(simple_rephrase)
# Clean and de-duplicate
unique_queries = []
seen_queries = set()
for q in query_list:
cleaned_q = re.sub(r"\s+", " ", q).strip()
if cleaned_q and cleaned_q not in seen_queries:
unique_queries.append(cleaned_q)
seen_queries.add(cleaned_q)
# Ensure a reasonable limit
return unique_queries[:5]
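# Example (illustrative): _expand_query("What is vipassana?") returns
# ["what is vipassana?", "what is insight meditation?", "what is mindfulness?",
#  "what is awareness?", "vipassana"] - five variants after de-duplication.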
def _rerank_retrieved_items(self, query: str, retrieved_items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Uses a Cross-Encoder model to re-score the relevance of each chunk to the query.
This greatly improves precision.
"""
if not self.reranker:
return retrieved_items # Fallback to original score if reranker isn't loaded
if not retrieved_items:
return []
# Prepare input pairs: [[query, chunk_text], [query, chunk_text], ...]
sentences_to_rank = [[query, item["chunk"]] for item in retrieved_items]
# Calculate new relevance scores (logits)
new_scores = self.reranker.predict(sentences_to_rank)
# Update the score and sort
for i, item in enumerate(retrieved_items):
# Cross-Encoder scores are high for high relevance
item["rerank_score"] = float(new_scores[i])
# Sort by the new, more accurate rerank_score
retrieved_items.sort(key=lambda x: x["rerank_score"], reverse=True)
return retrieved_items
def _bm25_search(self, query: str, top_k: int = 15) -> List[Dict[str, Any]]:
"""
Performs BM25 keyword-based search.
Returns a list of candidate chunks with BM25 scores.
"""
if self.bm25 is None:
return []
# Tokenize query
tokenized_query = query.lower().split()
# Get BM25 scores for all documents
bm25_scores = self.bm25.get_scores(tokenized_query)
# Get top-k indices
top_indices = np.argsort(bm25_scores)[::-1][:top_k]
results = []
for idx in top_indices:
if idx < len(self.metadatas) and bm25_scores[idx] > 0:
metadata = self.metadatas[idx]
results.append({
"score": float(bm25_scores[idx]),
"chunk": metadata.get("chunk", ""),
"source": metadata.get("source"),
"metadata": metadata
})
return results
def retrieve(self, query: str, top_k: int = 5, top_k_initial: int = 30) -> List[Dict[str, Any]]:
"""
Retrieves context using hybrid search (BM25 + FAISS) followed by cross-encoder reranking.
top_k: The final number of chunks passed to the LLM.
top_k_initial: The number of chunks retrieved from FAISS before reranking.
"""
if not self.index_loaded:
return []
# 1. Multi-Query Generation
query_list = self._expand_query(query)
# 2. Search FAISS index with all queries (High Recall)
all_retrieved_items: Dict[str, Dict[str, Any]] = {}
# Distribute the initial search load across all queries (e.g., 30 chunks total)
search_k_per_query = max(5, top_k_initial // len(query_list)) if len(query_list) > 0 else top_k_initial
for single_query in query_list:
if not single_query.strip():
continue
query_embedding = self.model.encode(single_query, convert_to_numpy=True)
query_embedding = normalize(query_embedding.reshape(1, -1))
# Search FAISS index for a large set of candidates
distances, indices = self.index.search(query_embedding.astype("float32"), search_k_per_query)
for i, score in zip(indices[0], distances[0]):
if i >= 0:
metadata = self.metadatas[i]
chunk_text = metadata.get("chunk", "")
if chunk_text:
# De-duplicate chunks, keeping the best original FAISS score
if chunk_text not in all_retrieved_items or score > all_retrieved_items[chunk_text]["score"]:
all_retrieved_items[chunk_text] = {
"score": float(score), # Original FAISS score (used for initial selection)
"chunk": chunk_text,
"source": metadata.get("source"),
"metadata": metadata
}
# 2.5 BM25 Keyword Search (Hybrid Retrieval)
bm25_results = self._bm25_search(query, top_k=15)
for item in bm25_results:
chunk_text = item["chunk"]
if chunk_text and chunk_text not in all_retrieved_items:
all_retrieved_items[chunk_text] = item
initial_candidates = list(all_retrieved_items.values())
# 3. Reranking (High Precision)
if not initial_candidates:
return []
# Rerank all candidates to get the true relevance score
reranked_items = self._rerank_retrieved_items(query, initial_candidates)
# 4. Final Truncation: Return the very best chunks after reranking
return reranked_items[:top_k]
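# Illustrative candidate flow with the defaults (top_k=5, top_k_initial=30): a query
# that expands to 5 variants pulls max(5, 30 // 5) = 6 FAISS hits per variant (up to
# 30 total) plus up to 15 BM25 hits; the de-duplicated pool is cross-encoder reranked
# and only the top 5 chunks are passed on to the LLM.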
def _get_llm_response(self, query: str, context: str, mode: str = "long") -> str:
"""
Unified method to get LLM response from either OpenAI or HuggingFace.
Automatically routes to the correct provider based on LLM_PROVIDER setting.
"""
if not self.openai_client and not self.hf_client:
return "Internal error: LLM client is not initialized."
# Detect language and get appropriate prompts
if detect_language and get_system_prompt and get_user_prompt:
detected_language = detect_language(query)
system_prompt = get_system_prompt(detected_language)
prompt = get_user_prompt(query, context, detected_language)
print(f"[LLM] Detected language: {detected_language}")
else:
# Fallback to basic English prompts if multilingual module not available
detected_language = "english"
system_prompt = (
"You are the Vipassana Guide AI, a compassionate meditation teacher. "
"Use ONLY the provided CONTEXT. Do not add external knowledge.\n\n"
"Format responses with ## headings, **bold**, and *italic* for Pali/Sanskrit terms.\n"
"Be direct and practical. Include [Source: filename] references."
)
prompt = f"""CONTEXT (Vipassana Knowledge Base):
{context}
USER'S QUESTION:
{query}
Generate the response based strictly on the provided CONTEXT."""
# ========== ROUTE TO CORRECT LLM PROVIDER ==========
if self.llm_provider == "openai":
# ========== OPENAI IMPLEMENTATION ==========
try:
response = self.openai_client.chat.completions.create(
model=OPENAI_MODEL,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt}
],
temperature=0.05, # Low temperature ensures the model stays faithful to the context
max_tokens=1200,
top_p=0.9,
timeout=60 # 60 second timeout for this specific request
)
return response.choices[0].message.content
except openai.APIConnectionError as e:
# Network connectivity issues
error_msg = f"Network connection error to OpenAI API. This may be temporary. Please try again in a moment."
print(f"[OpenAI API] Connection Error: {e}")
return error_msg
except openai.APITimeoutError as e:
# Request timed out
error_msg = f"Request to OpenAI timed out. The service may be slow. Please try again."
print(f"[OpenAI API] Timeout Error: {e}")
return error_msg
except openai.AuthenticationError as e:
# API key issue
error_msg = f"OpenAI API authentication failed. Please check your API key configuration."
print(f"[OpenAI API] Authentication Error: {e}")
return error_msg
except openai.RateLimitError as e:
# Rate limit exceeded
error_msg = f"OpenAI API rate limit exceeded. Please wait a moment and try again."
print(f"[OpenAI API] Rate Limit Error: {e}")
return error_msg
except Exception as e:
# Generic error handling
error_msg = f"Error during OpenAI generation: {type(e).__name__} - {str(e)}"
print(f"[OpenAI API] Generic Error: {e}")
return error_msg
elif self.llm_provider == "huggingface":
# ========== HUGGINGFACE IMPLEMENTATION ==========
try:
# Use chat completion API for instruction-tuned models like Mistral
# Format messages similar to OpenAI's chat format
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt}
]
# Call HuggingFace Chat Completion API
response = self.hf_client.chat_completion(
messages=messages,
max_tokens=1200,
temperature=0.05, # Very low temperature for consistency (especially important for multilingual)
top_p=0.85, # Reduced top_p for more focused output
)
# Extract the assistant's response
answer = response.choices[0].message.content
return answer.strip()
except Exception as e:
# HuggingFace error handling
error_msg = f"Error during HuggingFace generation: {type(e).__name__} - {str(e)}"
print(f"[HF API] Error: {e}")
print(f"[HF API] Full error details: {e}")
return error_msg
else:
return f"Error: Unknown LLM provider: {self.llm_provider}"
def answer(self, query: str, top_k: int = 6, mode: str = "long") -> Tuple[str, List[str]]:
"""Main method to retrieve context and generate an answer with optimized performance."""
if not self.is_ready:
return "RAG Agent failed to initialize. Please check the console for errors.", []
# 1. Hybrid retrieval with reranking (smaller initial candidate pool keeps latency down)
retrieved_items = self.retrieve(query, top_k=top_k, top_k_initial=20)
if not retrieved_items:
return "Could not retrieve relevant documents from the knowledge base.", []
# 2. Format context with better organization
context_parts = []
sources = []
for i, item in enumerate(retrieved_items):
chunk = item.get("chunk", "")
source = item.get("source", "")
if chunk:
# Add source information to each chunk
context_parts.append(f"[Source: {source}]\n{chunk}")
if source and source not in sources:
sources.append(source)
context = "\n\n---\n\n".join(context_parts)
# 3. Generate answer using enhanced context (works with both OpenAI and HuggingFace)
answer = self._get_llm_response(query, context, mode)
return answer, sources
def _clean_text(self, text: str) -> str:
"""Basic cleaning for text extracted from PDFs."""
if not isinstance(text, str):
return ""
# Normalize unicode, strip PDF extraction artifacts (glyph refs, bracketed citation
# markers), rejoin hyphenated line breaks, and collapse whitespace.
cleaned = unicodedata.normalize('NFKC', text)
cleaned = re.sub(r"(/c\d+)+", " ", cleaned)
cleaned = re.sub(r"\[\s*\d+\s*\]", " ", cleaned)
cleaned = re.sub(r"\(\s*\d+\s*\)", " ", cleaned)
cleaned = re.sub(r"-\s*\n\s*", "", cleaned)
cleaned = re.sub(r"\n+", " ", cleaned)
cleaned = re.sub(r"[\x00-\x1F\x7F-\x9F]", " ", cleaned)
cleaned = re.sub(r"\s+", " ", cleaned).strip()
return cleaned
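# Example (illustrative input): "anicca/c12/c34 means imperman-\n ence [3]" cleans to
# "anicca means impermanence" - the /cNN artifacts, the bracketed citation marker, the
# hyphenated line break, and extra whitespace are all stripped by the patterns above.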
# NOTE: The initialization logic in app.py will automatically pick up this RAG agent,
# but make sure all required dependencies (sentence-transformers, which also provides
# CrossEncoder, plus scikit-learn, numpy, faiss, and openai or huggingface_hub)
# are installed in the environment where this code runs.
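# Minimal local smoke test (a sketch, not the app's entry point): it assumes the FAISS
# index, metadata, and (optionally) the BM25 pickle already exist at the configured
# paths, and that the API key/token for the selected provider is set in the environment.
if __name__ == "__main__":
    agent = VipassanaRAGAgent()
    if agent.is_ready:
        reply, cited_sources = agent.answer("What is Anapana meditation?")
        print(reply)
        print("Sources:", cited_sources)
    else:
        print("Agent failed to initialize; see the messages above for details.")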