import gradio as gr
from huggingface_hub import InferenceClient
from pymongo import MongoClient
from datetime import datetime, timezone
from typing import List, Dict, Optional

from embedding_service import JinaClipEmbeddingService
from qdrant_service import QdrantVectorService
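# NOTE (assumed interface): embedding_service and qdrant_service are local
# modules not shown in this file. From how they are used below, they are
# expected to provide roughly:
#   JinaClipEmbeddingService(model_path).encode_text(text) -> embedding vector
#   JinaClipEmbeddingService.get_embedding_dimension() -> int
#   QdrantVectorService(collection_name, vector_size).index_data(doc_id, embedding, metadata)
#   QdrantVectorService.search(query_embedding, limit, score_threshold) -> List[Dict]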


class ChatbotRAG:
    """
    RAG chatbot with:
    - LLM: GPT-OSS-20B (Hugging Face)
    - Embeddings: Jina CLIP v2
    - Vector DB: Qdrant
    - Document Store: MongoDB
    """

    def __init__(
        self,
        # NOTE: hardcoding credentials in source is unsafe; prefer reading the
        # connection string from an environment variable in production.
        mongodb_uri: str = "mongodb+srv://truongtn7122003:[email protected]/",
        db_name: str = "chatbot_rag",
        collection_name: str = "documents",
    ):
        """
        Initialize ChatbotRAG

        Args:
            mongodb_uri: MongoDB connection string
            db_name: Database name
            collection_name: Collection name for documents
        """
        print("Initializing ChatbotRAG...")

        # MongoDB client
        self.mongo_client = MongoClient(mongodb_uri)
        self.db = self.mongo_client[db_name]
        self.documents_collection = self.db[collection_name]
        self.chat_history_collection = self.db["chat_history"]

        # Embedding service (Jina CLIP v2)
        self.embedding_service = JinaClipEmbeddingService(
            model_path="jinaai/jina-clip-v2"
        )

        # Qdrant vector service
        self.qdrant_service = QdrantVectorService(
            collection_name="chatbot_rag_vectors",
            vector_size=self.embedding_service.get_embedding_dimension(),
        )

        print("✅ ChatbotRAG initialized successfully")

    def add_document(self, text: str, metadata: Optional[Dict] = None) -> str:
        """
        Add document to MongoDB and Qdrant

        Args:
            text: Document text
            metadata: Additional metadata

        Returns:
            Document ID
        """
        # Save to MongoDB
        doc_data = {
            "text": text,
            "metadata": metadata or {},
            "created_at": datetime.now(timezone.utc),
        }
        result = self.documents_collection.insert_one(doc_data)
        doc_id = str(result.inserted_id)

        # Generate embedding
        embedding = self.embedding_service.encode_text(text)

        # Index to Qdrant
        self.qdrant_service.index_data(
            doc_id=doc_id,
            embedding=embedding,
            metadata={
                "text": text,
                "source": "user_upload",
                **(metadata or {}),
            },
        )

        return doc_id
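
    # Example (hypothetical values):
    #   doc_id = bot.add_document(
    #       "Qdrant is an open-source vector database.",
    #       metadata={"source": "faq", "lang": "en"},
    #   )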

    def retrieve_context(self, query: str, top_k: int = 3) -> List[Dict]:
        """
        Retrieve relevant context from vector DB

        Args:
            query: User query
            top_k: Number of results to retrieve

        Returns:
            List of relevant documents
        """
        # Generate query embedding
        query_embedding = self.embedding_service.encode_text(query)

        # Search in Qdrant
        results = self.qdrant_service.search(
            query_embedding=query_embedding,
            limit=top_k,
            score_threshold=0.5,  # Only return reasonably relevant results
        )

        return results
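
    # Each result is assumed to be shaped like the following (inferred from how
    # respond() consumes it; the exact schema lives in qdrant_service.py):
    #   {"metadata": {"text": "...", "source": "..."}, "confidence": 0.87}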

    def save_chat_history(self, user_message: str, assistant_response: str, context_used: List[Dict]):
        """
        Save chat interaction to MongoDB

        Args:
            user_message: User's message
            assistant_response: Assistant's response
            context_used: Context retrieved from RAG
        """
        chat_data = {
            "user_message": user_message,
            "assistant_response": assistant_response,
            "context_used": context_used,
            "timestamp": datetime.now(timezone.utc),
        }
        self.chat_history_collection.insert_one(chat_data)
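
    # Saved interactions can be inspected with standard pymongo queries, e.g.
    # the five most recent exchanges of a ChatbotRAG instance `bot`:
    #   for chat in bot.chat_history_collection.find().sort("timestamp", -1).limit(5):
    #       print(chat["user_message"], "->", chat["assistant_response"])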

    def respond(
        self,
        message: str,
        history: List[Dict[str, str]],
        system_message: str,
        max_tokens: int,
        temperature: float,
        top_p: float,
        use_rag: bool,
        hf_token: gr.OAuthToken,
    ):
        """
        Generate response with RAG

        Args:
            message: User message
            history: Chat history
            system_message: System prompt
            max_tokens: Max tokens to generate
            temperature: Temperature for generation
            top_p: Top-p sampling
            use_rag: Whether to use RAG retrieval
            hf_token: Hugging Face OAuth token (None when the user is not logged in)

        Yields:
            Generated response
        """
        if hf_token is None or not hf_token.token:
            yield "Please log in with your Hugging Face account to use the chatbot."
            return

        # Initialize LLM client
        client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")

        # Prepare context from RAG
        context_text = ""
        context_used = []

        if use_rag:
            # Retrieve relevant context
            retrieved_docs = self.retrieve_context(message, top_k=3)
            context_used = retrieved_docs

            if retrieved_docs:
                context_text = "\n\n**Relevant Context:**\n"
                for i, doc in enumerate(retrieved_docs, 1):
                    doc_text = doc["metadata"].get("text", "")
                    confidence = doc["confidence"]
                    context_text += f"\n[{i}] (Confidence: {confidence:.2f})\n{doc_text}\n"

                # Add context to system message
                system_message = (
                    f"{system_message}\n\n{context_text}\n\n"
                    "Please use the above context to answer the user's question when relevant."
                )

        # Build messages for LLM
        messages = [{"role": "system", "content": system_message}]
        messages.extend(history)
        messages.append({"role": "user", "content": message})

        # Generate response, streaming the accumulated text chunk by chunk
        response = ""
        try:
            for msg in client.chat_completion(
                messages,
                max_tokens=max_tokens,
                stream=True,
                temperature=temperature,
                top_p=top_p,
            ):
                choices = msg.choices
                if choices and choices[0].delta.content:
                    response += choices[0].delta.content
                yield response

            # Save to chat history
            self.save_chat_history(message, response, context_used)
        except Exception as e:
            yield f"Error generating response: {e}"
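
    # With type="messages" (set on the ChatInterface below), Gradio passes
    # `history` as a list of dicts in the same format the LLM expects, e.g.:
    #   [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]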


# Initialize ChatbotRAG
chatbot_rag = ChatbotRAG()


def respond_wrapper(
    message,
    history,
    system_message,
    max_tokens,
    temperature,
    top_p,
    use_rag,
    # The gr.OAuthToken annotation is required for Gradio to inject the
    # logged-in user's token; without it this argument is never filled.
    hf_token: gr.OAuthToken | None = None,
):
    """Wrapper for Gradio ChatInterface"""
    yield from chatbot_rag.respond(
        message=message,
        history=history,
        system_message=system_message,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        use_rag=use_rag,
        hf_token=hf_token,
    )
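
# Parameter order matters: gr.ChatInterface calls the function with
# (message, history) first, then the additional_inputs in the order they are
# listed when the interface is built below.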


def add_document_to_rag(text: str) -> str:
    """
    Add document to RAG knowledge base

    Args:
        text: Document text

    Returns:
        Success message
    """
    try:
        doc_id = chatbot_rag.add_document(text)
        return f"✅ Document added successfully! ID: {doc_id}"
    except Exception as e:
        return f"❌ Error adding document: {e}"


# Create Gradio interface
with gr.Blocks(title="ChatbotRAG - GPT-OSS-20B + Jina CLIP v2 + MongoDB") as demo:
    gr.Markdown("""
    # 🤖 ChatbotRAG

    **Features:**
    - 💬 LLM: GPT-OSS-20B
    - 🔍 Embeddings: Jina CLIP v2 (Vietnamese support)
    - 📊 Vector DB: Qdrant Cloud
    - 🗄️ Document Store: MongoDB

    **How to use:**
    1. Add documents to knowledge base (optional)
    2. Toggle "Use RAG" to enable context retrieval
    3. Chat with the bot!
    """)

    with gr.Sidebar():
        gr.LoginButton()
        gr.Markdown("### ⚙️ Settings")

        use_rag = gr.Checkbox(
            label="Use RAG",
            value=True,
            info="Enable RAG to retrieve relevant context from knowledge base"
        )
        system_message = gr.Textbox(
            value="You are a helpful AI assistant. Answer questions based on the provided context when available.",
            label="System message",
            lines=3
        )
        max_tokens = gr.Slider(
            minimum=1,
            maximum=2048,
            value=512,
            step=1,
            label="Max new tokens"
        )
        temperature = gr.Slider(
            minimum=0.1,
            maximum=4.0,
            value=0.7,
            step=0.1,
            label="Temperature"
        )
        top_p = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)"
        )

    # Chat interface
    chatbot = gr.ChatInterface(
        respond_wrapper,
        type="messages",
        additional_inputs=[
            system_message,
            max_tokens,
            temperature,
            top_p,
            use_rag,
        ],
    )

    # Document management
    with gr.Accordion("📚 Knowledge Base Management", open=False):
        gr.Markdown("### Add Documents to Knowledge Base")
        doc_text = gr.Textbox(
            label="Document Text",
            placeholder="Enter document text here...",
            lines=5
        )
        add_btn = gr.Button("Add Document", variant="primary")
        output_msg = gr.Textbox(label="Status", interactive=False)

        add_btn.click(
            fn=add_document_to_rag,
            inputs=[doc_text],
            outputs=[output_msg]
        )

    chatbot.render()


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
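
# To run locally (assuming this file is saved as app.py and the two service
# modules plus MongoDB/Qdrant credentials are available):
#   python app.py
# then open http://localhost:7860 in a browser.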