# EBD_Fest / chatbot_rag.py
# NOTE: removed Hugging Face Hub page residue that had been pasted above the
# code ("minhvtt's picture / Upload 20 files / cb93402 verified") — it was not
# valid Python and broke the module at import time.
from datetime import datetime, timezone
from typing import Dict, List

import gradio as gr
import numpy as np
from huggingface_hub import InferenceClient
from pymongo import MongoClient

from embedding_service import JinaClipEmbeddingService
from qdrant_service import QdrantVectorService
class ChatbotRAG:
    """
    Retrieval-augmented generation chatbot wired from:
    - LLM: GPT-OSS-20B (Hugging Face Inference API)
    - Embeddings: Jina CLIP v2
    - Vector DB: Qdrant
    - Document store: MongoDB
    """

    def __init__(
        self,
        mongodb_uri: str = "mongodb+srv://truongtn7122003:[email protected]/",
        db_name: str = "chatbot_rag",
        collection_name: str = "documents"
    ):
        """
        Connect to MongoDB and set up the embedding and vector services.

        Args:
            mongodb_uri: MongoDB connection string.
                SECURITY: the default value embeds live credentials in source
                control — move the URI to an environment variable / secret
                store and rotate this password.
            db_name: Database name.
            collection_name: Collection name for user documents.
        """
        print("Initializing ChatbotRAG...")
        # MongoDB: one collection for documents, one for chat transcripts.
        self.mongo_client = MongoClient(mongodb_uri)
        self.db = self.mongo_client[db_name]
        self.documents_collection = self.db[collection_name]
        self.chat_history_collection = self.db["chat_history"]
        # Embedding service (Jina CLIP v2) supplies the text vectors used for
        # both indexing and querying.
        self.embedding_service = JinaClipEmbeddingService(
            model_path="jinaai/jina-clip-v2"
        )
        # Qdrant collection sized to the embedding dimension so vectors fit.
        self.qdrant_service = QdrantVectorService(
            collection_name="chatbot_rag_vectors",
            vector_size=self.embedding_service.get_embedding_dimension()
        )
        print("✓ ChatbotRAG initialized successfully")

    def add_document(self, text: str, metadata: Dict = None) -> str:
        """
        Store a document in MongoDB and index its embedding in Qdrant.

        Args:
            text: Document text.
            metadata: Additional metadata, merged into the Qdrant payload.

        Returns:
            The MongoDB document ID (stringified ObjectId).
        """
        # Persist the raw document first so the Qdrant point can reference it.
        doc_data = {
            "text": text,
            "metadata": metadata or {},
            # Timezone-aware UTC; datetime.utcnow() is deprecated (Py 3.12+).
            "created_at": datetime.now(timezone.utc),
        }
        result = self.documents_collection.insert_one(doc_data)
        doc_id = str(result.inserted_id)
        # Embed and index so the document is retrievable immediately.
        embedding = self.embedding_service.encode_text(text)
        self.qdrant_service.index_data(
            doc_id=doc_id,
            embedding=embedding,
            metadata={
                "text": text,
                "source": "user_upload",
                **(metadata or {}),
            },
        )
        return doc_id

    def retrieve_context(self, query: str, top_k: int = 3) -> List[Dict]:
        """
        Retrieve the documents most relevant to a query from the vector DB.

        Args:
            query: User query.
            top_k: Maximum number of results to return.

        Returns:
            List of search hits as returned by the Qdrant service.
        """
        query_embedding = self.embedding_service.encode_text(query)
        # The score threshold drops weak matches entirely rather than ranking
        # them last.
        results = self.qdrant_service.search(
            query_embedding=query_embedding,
            limit=top_k,
            score_threshold=0.5  # Only get relevant results
        )
        return results

    def save_chat_history(self, user_message: str, assistant_response: str, context_used: List[Dict]):
        """
        Persist one chat interaction to MongoDB.

        Args:
            user_message: User's message.
            assistant_response: Assistant's full response.
            context_used: Context documents retrieved for this turn.
        """
        chat_data = {
            "user_message": user_message,
            "assistant_response": assistant_response,
            "context_used": context_used,
            # Timezone-aware UTC; datetime.utcnow() is deprecated (Py 3.12+).
            "timestamp": datetime.now(timezone.utc),
        }
        self.chat_history_collection.insert_one(chat_data)

    def respond(
        self,
        message: str,
        history: List[Dict[str, str]],
        system_message: str,
        max_tokens: int,
        temperature: float,
        top_p: float,
        use_rag: bool,
        hf_token: gr.OAuthToken,
    ):
        """
        Stream a response to `message`, optionally grounded with RAG context.

        Args:
            message: User message.
            history: Prior chat turns in OpenAI message format.
            system_message: System prompt.
            max_tokens: Max tokens to generate.
            temperature: Sampling temperature.
            top_p: Top-p (nucleus) sampling.
            use_rag: Whether to retrieve context before answering.
            hf_token: Hugging Face OAuth token from the Gradio login button.

        Yields:
            The accumulated response text after each streamed chunk.
        """
        # Gradio passes None for hf_token when the user has not logged in;
        # fail with a readable message instead of an AttributeError.
        token = getattr(hf_token, "token", None)
        if not token:
            yield "⚠️ Please sign in with your Hugging Face account to chat."
            return
        client = InferenceClient(token=token, model="openai/gpt-oss-20b")

        # Prepare context from RAG.
        context_text = ""
        context_used = []
        if use_rag:
            retrieved_docs = self.retrieve_context(message, top_k=3)
            context_used = retrieved_docs
            if retrieved_docs:
                context_text = "\n\n**Relevant Context:**\n"
                for i, doc in enumerate(retrieved_docs, 1):
                    doc_text = doc["metadata"].get("text", "")
                    # .get() guards against hits missing a confidence score;
                    # assumes the Qdrant service normally provides one — TODO
                    # confirm against qdrant_service.search().
                    confidence = doc.get("confidence", 0.0)
                    context_text += f"\n[{i}] (Confidence: {confidence:.2f})\n{doc_text}\n"
                # Fold retrieved context into the system prompt.
                system_message = f"{system_message}\n\n{context_text}\n\nPlease use the above context to answer the user's question when relevant."

        # Build the OpenAI-style message list for the LLM.
        messages = [{"role": "system", "content": system_message}]
        messages.extend(history)
        messages.append({"role": "user", "content": message})

        # Stream the generation, yielding the growing response on each chunk.
        response = ""
        try:
            for chunk in client.chat_completion(
                messages,
                max_tokens=max_tokens,
                stream=True,
                temperature=temperature,
                top_p=top_p,
            ):
                choices = chunk.choices
                if choices and choices[0].delta.content:
                    response += choices[0].delta.content
                yield response
            # Record the completed exchange only after streaming succeeded.
            self.save_chat_history(message, response, context_used)
        except Exception as e:
            # Surface the failure in the chat UI instead of crashing the app.
            yield f"Error generating response: {str(e)}"
# Module-level ChatbotRAG singleton shared by all Gradio callbacks below.
# NOTE: constructing it at import time connects to MongoDB and initializes the
# embedding and Qdrant services as side effects.
chatbot_rag = ChatbotRAG()
def respond_wrapper(
    message,
    history,
    system_message,
    max_tokens,
    temperature,
    top_p,
    use_rag,
    hf_token,
):
    """Gradio ChatInterface adapter: stream partial responses from the shared ChatbotRAG."""
    for partial_response in chatbot_rag.respond(
        message=message,
        history=history,
        system_message=system_message,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        use_rag=use_rag,
        hf_token=hf_token,
    ):
        yield partial_response
def add_document_to_rag(text: str) -> str:
    """
    Add a document to the RAG knowledge base.

    Args:
        text: Document text to store in MongoDB and index in Qdrant.

    Returns:
        A human-readable status message (success with the new document ID,
        or an error description).
    """
    # Reject blank submissions before touching MongoDB/Qdrant — indexing an
    # empty document would only pollute retrieval results.
    if not text or not text.strip():
        return "✗ Error adding document: document text is empty"
    try:
        doc_id = chatbot_rag.add_document(text)
        return f"✓ Document added successfully! ID: {doc_id}"
    except Exception as e:
        return f"✗ Error adding document: {str(e)}"
# ---------------------------------------------------------------------------
# Gradio UI: chat interface plus a small knowledge-base management panel.
# ---------------------------------------------------------------------------
with gr.Blocks(title="ChatbotRAG - GPT-OSS-20B + Jina CLIP v2 + MongoDB") as demo:
    gr.Markdown("""
    # 🤖 ChatbotRAG
    **Features:**
    - 💬 LLM: GPT-OSS-20B
    - 🔍 Embeddings: Jina CLIP v2 (Vietnamese support)
    - 📊 Vector DB: Qdrant Cloud
    - 🗄️ Document Store: MongoDB
    **How to use:**
    1. Add documents to knowledge base (optional)
    2. Toggle "Use RAG" to enable context retrieval
    3. Chat with the bot!
    """)
    # Sidebar: HF login (supplies the OAuth token consumed by respond) and the
    # generation settings that are forwarded to the LLM on every turn.
    with gr.Sidebar():
        gr.LoginButton()
        gr.Markdown("### ⚙️ Settings")
        use_rag = gr.Checkbox(
            label="Use RAG",
            value=True,
            info="Enable RAG to retrieve relevant context from knowledge base"
        )
        system_message = gr.Textbox(
            value="You are a helpful AI assistant. Answer questions based on the provided context when available.",
            label="System message",
            lines=3
        )
        max_tokens = gr.Slider(
            minimum=1,
            maximum=2048,
            value=512,
            step=1,
            label="Max new tokens"
        )
        temperature = gr.Slider(
            minimum=0.1,
            maximum=4.0,
            value=0.7,
            step=0.1,
            label="Temperature"
        )
        top_p = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)"
        )
    # Chat interface; the sidebar widgets are passed as additional inputs to
    # respond_wrapper (the OAuth token argument is injected by Gradio).
    chatbot = gr.ChatInterface(
        respond_wrapper,
        type="messages",
        additional_inputs=[
            system_message,
            max_tokens,
            temperature,
            top_p,
            use_rag,
        ],
    )
    # Knowledge-base management: paste text and index it into the RAG store.
    with gr.Accordion("📚 Knowledge Base Management", open=False):
        gr.Markdown("### Add Documents to Knowledge Base")
        doc_text = gr.Textbox(
            label="Document Text",
            placeholder="Enter document text here...",
            lines=5
        )
        add_btn = gr.Button("Add Document", variant="primary")
        output_msg = gr.Textbox(label="Status", interactive=False)
        add_btn.click(
            fn=add_document_to_rag,
            inputs=[doc_text],
            outputs=[output_msg]
        )
    chatbot.render()

if __name__ == "__main__":
    # Bind on all interfaces, port 7860 (the Hugging Face Spaces default).
    demo.launch(server_name="0.0.0.0", server_port=7860)