import os
from datetime import datetime, timezone
from typing import Dict, List, Optional

import gradio as gr
from huggingface_hub import InferenceClient
from pymongo import MongoClient

from embedding_service import JinaClipEmbeddingService
from qdrant_service import QdrantVectorService


class ChatbotRAG:
    """
    RAG chatbot with:
    - LLM: GPT-OSS-20B (Hugging Face)
    - Embeddings: Jina CLIP v2
    - Vector DB: Qdrant
    - Document Store: MongoDB
    """

    def __init__(
        self,
        mongodb_uri: Optional[str] = None,
        db_name: str = "chatbot_rag",
        collection_name: str = "documents",
    ):
        """
        Initialize ChatbotRAG.

        Args:
            mongodb_uri: MongoDB connection string; defaults to the
                MONGODB_URI environment variable (falling back to a local
                instance) so credentials stay out of source control
            db_name: Database name
            collection_name: Collection name for documents
        """
        print("Initializing ChatbotRAG...")

        # MongoDB: raw document store plus chat history
        mongodb_uri = mongodb_uri or os.getenv(
            "MONGODB_URI", "mongodb://localhost:27017/"
        )
        self.mongo_client = MongoClient(mongodb_uri)
        self.db = self.mongo_client[db_name]
        self.documents_collection = self.db[collection_name]
        self.chat_history_collection = self.db["chat_history"]

        # Embedding model: Jina CLIP v2
        self.embedding_service = JinaClipEmbeddingService(
            model_path="jinaai/jina-clip-v2"
        )

        # Vector store: Qdrant collection sized to match the embedding dimension
        self.qdrant_service = QdrantVectorService(
            collection_name="chatbot_rag_vectors",
            vector_size=self.embedding_service.get_embedding_dimension(),
        )

        print("✅ ChatbotRAG initialized successfully")
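
    # Configuration sketch (hypothetical values, not real credentials): set the
    # connection string in the environment before constructing the class, e.g.
    #
    #     os.environ["MONGODB_URI"] = "mongodb+srv://<user>:<password>@<cluster>.mongodb.net/"
    #     bot = ChatbotRAG()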

    def add_document(self, text: str, metadata: Optional[Dict] = None) -> str:
        """
        Add a document to MongoDB and Qdrant.

        Args:
            text: Document text
            metadata: Additional metadata

        Returns:
            Document ID
        """
        # Persist the raw document in MongoDB
        doc_data = {
            "text": text,
            "metadata": metadata or {},
            "created_at": datetime.now(timezone.utc),
        }
        result = self.documents_collection.insert_one(doc_data)
        doc_id = str(result.inserted_id)

        # Embed the text and index the vector in Qdrant, keyed by the same ID
        embedding = self.embedding_service.encode_text(text)
        self.qdrant_service.index_data(
            doc_id=doc_id,
            embedding=embedding,
            metadata={
                "text": text,
                "source": "user_upload",
                **(metadata or {}),
            },
        )

        return doc_id
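
    # Usage sketch (hypothetical values), assuming an initialized instance
    # `bot = ChatbotRAG()`:
    #
    #     doc_id = bot.add_document(
    #         "Qdrant is a vector database for similarity search.",
    #         metadata={"topic": "databases"},
    #     )
    #
    # `doc_id` is the MongoDB ObjectId as a string and is reused as the ID of
    # the vector indexed in Qdrant.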

    def retrieve_context(self, query: str, top_k: int = 3) -> List[Dict]:
        """
        Retrieve relevant context from the vector DB.

        Args:
            query: User query
            top_k: Number of results to retrieve

        Returns:
            List of relevant documents
        """
        # Embed the query and run a similarity search against Qdrant,
        # discarding hits that score below the threshold
        query_embedding = self.embedding_service.encode_text(query)
        results = self.qdrant_service.search(
            query_embedding=query_embedding,
            limit=top_k,
            score_threshold=0.5,
        )
        return results
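
    # Result shape (inferred from how respond() consumes the hits below): each
    # entry is a dict carrying at least
    #
    #     {"metadata": {"text": ...}, "confidence": <similarity score>}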

    def save_chat_history(self, user_message: str, assistant_response: str, context_used: List[Dict]):
        """
        Save a chat interaction to MongoDB.

        Args:
            user_message: User's message
            assistant_response: Assistant's response
            context_used: Context retrieved from RAG
        """
        chat_data = {
            "user_message": user_message,
            "assistant_response": assistant_response,
            "context_used": context_used,
            "timestamp": datetime.now(timezone.utc),
        }
        self.chat_history_collection.insert_one(chat_data)

    def respond(
        self,
        message: str,
        history: List[Dict[str, str]],
        system_message: str,
        max_tokens: int,
        temperature: float,
        top_p: float,
        use_rag: bool,
        hf_token: gr.OAuthToken,
    ):
        """
        Generate a response, optionally augmented with RAG context.

        Args:
            message: User message
            history: Chat history
            system_message: System prompt
            max_tokens: Max tokens to generate
            temperature: Temperature for generation
            top_p: Top-p sampling
            use_rag: Whether to use RAG retrieval
            hf_token: Hugging Face token

        Yields:
            The accumulated response text after each streamed token
        """
        client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")

        # Retrieve context and fold it into the system prompt
        context_text = ""
        context_used = []

        if use_rag:
            retrieved_docs = self.retrieve_context(message, top_k=3)
            context_used = retrieved_docs

            if retrieved_docs:
                context_text = "\n\n**Relevant Context:**\n"
                for i, doc in enumerate(retrieved_docs, 1):
                    doc_text = doc["metadata"].get("text", "")
                    confidence = doc["confidence"]
                    context_text += f"\n[{i}] (Confidence: {confidence:.2f})\n{doc_text}\n"

            system_message = f"{system_message}\n\n{context_text}\n\nPlease use the above context to answer the user's question when relevant."

        # Build the message list: system prompt, prior turns, current message
        messages = [{"role": "system", "content": system_message}]
        messages.extend(history)
        messages.append({"role": "user", "content": message})

        # Stream the completion token by token, yielding the running text
        response = ""
        try:
            for msg in client.chat_completion(
                messages,
                max_tokens=max_tokens,
                stream=True,
                temperature=temperature,
                top_p=top_p,
            ):
                choices = msg.choices
                token = ""
                if choices and choices[0].delta.content:
                    token = choices[0].delta.content

                response += token
                yield response

            # Persist the completed exchange
            self.save_chat_history(message, response, context_used)

        except Exception as e:
            yield f"Error generating response: {str(e)}"
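
# Streaming sketch (hypothetical, outside Gradio): respond() yields the running
# text after every token, so a caller can print only the newest chunk:
#
#     previous = ""
#     for partial in bot.respond(...):
#         print(partial[len(previous):], end="", flush=True)
#         previous = partial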

# Single shared instance used by all Gradio sessions
chatbot_rag = ChatbotRAG()

def respond_wrapper(
    message,
    history,
    system_message,
    max_tokens,
    temperature,
    top_p,
    use_rag,
    hf_token: gr.OAuthToken,
):
    """Wrapper for Gradio ChatInterface; the gr.OAuthToken annotation makes
    Gradio inject the logged-in user's token automatically."""
    yield from chatbot_rag.respond(
        message=message,
        history=history,
        system_message=system_message,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        use_rag=use_rag,
        hf_token=hf_token,
    )


def add_document_to_rag(text: str) -> str:
    """
    Add a document to the RAG knowledge base.

    Args:
        text: Document text

    Returns:
        Status message
    """
    try:
        doc_id = chatbot_rag.add_document(text)
        return f"✅ Document added successfully! ID: {doc_id}"
    except Exception as e:
        return f"❌ Error adding document: {str(e)}"


with gr.Blocks(title="ChatbotRAG - GPT-OSS-20B + Jina CLIP v2 + MongoDB") as demo:
    gr.Markdown("""
    # 🤖 ChatbotRAG

    **Features:**
    - 💬 LLM: GPT-OSS-20B
    - 🔍 Embeddings: Jina CLIP v2 (Vietnamese support)
    - 📊 Vector DB: Qdrant Cloud
    - 🗄️ Document Store: MongoDB

    **How to use:**
    1. Add documents to the knowledge base (optional)
    2. Toggle "Use RAG" to enable context retrieval
    3. Chat with the bot!
    """)

    with gr.Sidebar():
        gr.LoginButton()

        gr.Markdown("### ⚙️ Settings")

        use_rag = gr.Checkbox(
            label="Use RAG",
            value=True,
            info="Enable RAG to retrieve relevant context from the knowledge base",
        )

        system_message = gr.Textbox(
            value="You are a helpful AI assistant. Answer questions based on the provided context when available.",
            label="System message",
            lines=3,
        )

        max_tokens = gr.Slider(
            minimum=1,
            maximum=2048,
            value=512,
            step=1,
            label="Max new tokens",
        )

        temperature = gr.Slider(
            minimum=0.1,
            maximum=4.0,
            value=0.7,
            step=0.1,
            label="Temperature",
        )

        top_p = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        )

    chatbot = gr.ChatInterface(
        respond_wrapper,
        type="messages",
        additional_inputs=[
            system_message,
            max_tokens,
            temperature,
            top_p,
            use_rag,
        ],
    )

    with gr.Accordion("📚 Knowledge Base Management", open=False):
        gr.Markdown("### Add Documents to Knowledge Base")

        doc_text = gr.Textbox(
            label="Document Text",
            placeholder="Enter document text here...",
            lines=5,
        )

        add_btn = gr.Button("Add Document", variant="primary")
        output_msg = gr.Textbox(label="Status", interactive=False)

        add_btn.click(
            fn=add_document_to_rag,
            inputs=[doc_text],
            outputs=[output_msg],
        )

    chatbot.render()


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)