Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import hashlib | |
| from typing import List, Dict, Tuple | |
| import os | |
| import time | |
| from document_processor.file_handler import DocumentProcessor | |
| from retriever.builder import RetrieverBuilder | |
| from agents.workflow import AgentWorkflow | |
| from config import constants, settings | |
| from utils.logging import logger | |
# Custom CSS for the ChatGPT-like layout: dark sidebar, flexible chat column
# and status banners. Passed to gr.Blocks(css=...) in main().
_CUSTOM_CSS = """
.sidebar {
    background: #202123 !important;
    border-right: 1px solid #343541 !important;
    max-height: 90vh !important;
    height: auto !important;
    overflow-y: auto !important;
    padding: 15px !important;
    color: #ffffff !important;
}
.sidebar * {
    color: #ffffff !important;
}
.sidebar label {
    color: #d1d5db !important;
}
.sidebar input, .sidebar select, .sidebar textarea {
    background: #343541 !important;
    color: #ffffff !important;
    border: 1px solid #565869 !important;
}
.main-container {
    max-height: 90vh !important;
    height: auto !important;
    overflow-y: auto !important;
}
.chat-area {
    height: 100vh !important;
    display: flex !important;
    flex-direction: column !important;
    padding: 10px !important;
}
.chatbot-container {
    flex: 1 1 auto !important;
    min-height: 300px !important;
    max-height: calc(100vh - 280px) !important;
    overflow: hidden !important;
}
.chatbot-container .gradio-chatbot {
    height: 100% !important;
    max-height: calc(100vh - 280px) !important;
}
.input-area {
    margin-top: 10px !important;
}
.processing-status {
    padding: 8px !important;
    border-radius: 6px !important;
    margin: 8px 0 !important;
    font-size: 0.9em !important;
}
.success {
    background: #d4edda !important;
    color: #155724 !important;
    border: 1px solid #c3e6cb !important;
}
.error {
    background: #f8d7da !important;
    color: #721c24 !important;
    border: 1px solid #f5c6cb !important;
}
.progress-info {
    font-size: 0.85em !important;
    color: #666 !important;
    margin-top: 5px !important;
}
"""


def _sync_status_html(status: str, count: int = 0) -> str:
    """Return the HTML banner shown in the sidebar's sync indicator.

    Args:
        status: ``"synced"``, ``"syncing"`` or ``"error"``. Any other value
            (callers pass ``"ready"``) renders the neutral "Not synced" banner.
        count: Number of indexed chunks; only displayed for ``"synced"``.
    """
    if status == "synced":
        return f'<div style="padding: 8px; text-align: center; border-radius: 4px; background: #16a34a; color: #ffffff; font-weight: bold;">✅ Synced ({count} chunks)</div>'
    if status == "syncing":
        return '<div style="padding: 8px; text-align: center; border-radius: 4px; background: #f59e0b; color: #ffffff; font-weight: bold;">🔄 Syncing...</div>'
    if status == "error":
        return '<div style="padding: 8px; text-align: center; border-radius: 4px; background: #dc2626; color: #ffffff; font-weight: bold;">❌ Error</div>'
    return '<div style="padding: 8px; text-align: center; border-radius: 4px; background: #343541; color: #9ca3af;">Not synced</div>'


def _status_markdown(headline: str, detail: str) -> str:
    """Format a processing-status message: a headline plus a muted detail line.

    *detail* is HTML-escaped because it may carry exception text (file names,
    ``<class ...>`` reprs) that would otherwise be parsed as markup.
    """
    import html

    return f"{headline}<br><div class='progress-info'>{html.escape(detail)}</div>"


def main():
    """Build the DocChat Gradio app and return it without launching.

    The sidebar holds the file upload, the Sync button, a sync indicator and
    the verification report of the last answer. The main column is the chat.
    Files are indexed only when Sync is pressed. After a successful sync, any
    change to the upload list rebuilds the retriever from the current files,
    so answers never come from removed documents.

    Returns:
        The ``gr.Blocks`` app.
    """
    processor = DocumentProcessor()
    retriever_builder = RetrieverBuilder()
    workflow = AgentWorkflow()

    def _hidden_status():
        """Update that clears and hides the processing-status Markdown."""
        return gr.update(value="", visible=False)

    def _shown_status(markdown: str):
        """Update that shows *markdown* in the processing-status Markdown."""
        return gr.update(value=markdown, visible=True)

    def _rebuild_retriever(file_list: List, file_hashes: frozenset, state: Dict) -> int:
        """Chunk *file_list*, build a hybrid retriever, record it in *state*.

        Returns the number of chunks indexed. Exceptions from the processor
        or the retriever builder propagate to the caller.
        """
        chunks = processor.process(file_list)
        logger.info("Building retriever...")
        retriever = retriever_builder.build_hybrid_retriever(chunks)
        logger.info("Retriever built successfully")
        state.update({
            "file_hashes": file_hashes,
            "retriever": retriever,
            "processed_files": file_list,
            # Remembered so "already synced" banners report chunks, not files.
            "chunk_count": len(chunks),
        })
        return len(chunks)

    def handle_file_removal(current_files: List, state: Dict):
        """Keep the index consistent with the upload widget (``files.change``).

        Returns ``(sync_indicator_html, processing_status_update, state)``.

        - No files: drop the retriever and reset the indicator.
        - Never synced: leave indexing to the Sync button.
        - Same file contents as the last sync: nothing to do.
        - Different files after a sync: rebuild from the current files.
        """
        if not current_files:
            if state.get("retriever") is not None:
                logger.info("All files removed. Resetting retriever.")
            state.update({
                "retriever": None,
                "file_hashes": frozenset(),
                "processed_files": [],
                "chunk_count": 0,
            })
            return _sync_status_html("ready"), _hidden_status(), state

        if state.get("retriever") is None:
            # Files uploaded but not indexed yet: the user still has to press Sync.
            return _sync_status_html("ready"), _hidden_status(), state

        current_hashes = _get_file_hashes(current_files)
        if current_hashes == state.get("file_hashes"):
            return (
                _sync_status_html("synced", state.get("chunk_count", 0)),
                _hidden_status(),
                state,
            )

        logger.info("Uploaded files changed. Rebuilding retriever with current files...")
        try:
            chunk_count = _rebuild_retriever(current_files, current_hashes, state)
        except Exception as e:
            logger.error(f"Error resyncing after file change: {e}")
            return (
                _sync_status_html("error"),
                _shown_status(_status_markdown("❌ **Error resyncing documents**", str(e))),
                state,
            )
        return (
            _sync_status_html("synced", chunk_count),
            _shown_status(_status_markdown("✅ **Documents resynced**", f"{chunk_count} chunks indexed.")),
            state,
        )

    def process_files(uploaded_files: List, state: Dict):
        """Index the uploaded files (Sync button handler).

        Skips the rebuild when a retriever exists and the file contents are
        unchanged since the last sync. Returns
        ``(sync_indicator_html, processing_status_update, state)``.
        """
        if not uploaded_files:
            return _sync_status_html("ready"), _hidden_status(), state
        try:
            current_hashes = _get_file_hashes(uploaded_files)
            if state.get("retriever") is not None and current_hashes == state.get("file_hashes"):
                status = _status_markdown(
                    "✅ **Documents already synced**",
                    "Files are up to date. Ready for questions!",
                )
                return (
                    _sync_status_html("synced", state.get("chunk_count", 0)),
                    _shown_status(status),
                    state,
                )
            logger.info("Processing new/changed documents...")
            chunk_count = _rebuild_retriever(uploaded_files, current_hashes, state)
        except Exception as e:
            logger.error(f"File processing error: {str(e)}")
            return (
                _sync_status_html("error"),
                _shown_status(_status_markdown("❌ **Error syncing documents**", str(e))),
                state,
            )
        status = _status_markdown(
            "✅ **Documents synced successfully!**",
            f"{chunk_count} chunks indexed. Ready for questions!",
        )
        return _sync_status_html("synced", chunk_count), _shown_status(status), state

    def chat_function(message: str, history: List, state: Dict, verification_state: str):
        """Answer *message* with the agent workflow over the synced documents.

        Returns ``(history, cleared_input, state, verification_report)``.
        If no answer is attempted (empty question, nothing synced), the
        previous verification report is kept. On a pipeline error it is cleared.
        """
        history = list(history or [])
        message = message or ""
        if not message.strip():
            history.append((message, "Please enter a question."))
            return history, "", state, verification_state
        if state.get("retriever") is None:
            history.append((message, "❌ No documents synced. Please upload documents and click Sync first."))
            return history, "", state, verification_state
        try:
            result = workflow.full_pipeline(
                question=message,
                retriever=state["retriever"],
            )
            answer = result["draft_answer"]
            verification_report = result["verification_report"]
        except Exception as e:
            logger.error(f"Chat error: {str(e)}")
            history.append((message, f"❌ Error: {str(e)}"))
            return history, "", state, ""
        history.append((message, answer))
        state.setdefault("chat_history", []).append({"question": message, "answer": answer})
        return history, "", state, verification_report

    with gr.Blocks(theme=gr.themes.Soft(), title="DocChat", css=_CUSTOM_CSS) as demo:
        # Per-session state; Gradio copies this default value for each session.
        session_state = gr.State({
            "file_hashes": frozenset(),
            "retriever": None,
            "processed_files": [],
            "chunk_count": 0,
            "chat_history": [],
        })

        # Main layout: sidebar (file management) + chat column.
        with gr.Row(equal_height=True, elem_classes="main-container"):
            # NOTE(review): Gradio documents `scale` as an int; 0.7 may emit a
            # warning on newer versions — confirm against the pinned version.
            with gr.Column(scale=0.7, min_width=250, elem_classes="sidebar"):
                gr.Markdown("## 📁 Upload your documents here", elem_classes="title")
                files = gr.Files(
                    label="Upload Documents",
                    file_types=constants.ALLOWED_TYPES,
                    height=150,
                    show_label=True,
                )
                with gr.Row():
                    sync_btn = gr.Button(
                        "🔄 Sync", variant="primary", scale=1, size="sm", elem_classes=["flex-item"]
                    )
                with gr.Row():
                    sync_status_indicator = gr.HTML(
                        _sync_status_html("ready"),
                        visible=True,
                        elem_classes=["flex-item"],
                    )
                # Equal-width flex items. `svelte-1ipelgc` is presumably a
                # Gradio-generated class, so this rule may break on upgrades.
                gr.HTML("""
                <style>
                .svelte-1ipelgc.flex-item { flex: 1 1 0 !important; min-width: 0 !important; }
                </style>
                """)
                # Shown only while/after processing.
                processing_status = gr.Markdown("", elem_classes="processing-status", visible=False)
                # Tighter spacing for the compact Markdown heading below.
                gr.HTML("""
                <style>
                .compact-markdown p,
                .compact-markdown h4,
                .compact-markdown h5,
                .compact-markdown h6 {
                    margin-top: 0.25rem !important;
                    margin-bottom: 0.25rem !important;
                }
                </style>
                """)
                gr.Markdown(
                    '<span style="font-size: 1em;">📋 Verification Report</span>',
                    elem_classes="compact-markdown",
                )
                verification_output = gr.Textbox(
                    label="",
                    interactive=False,
                    lines=12,
                    max_lines=12,
                    value="",
                    placeholder="""### Verification Report""",
                    show_label=False,
                )

            # Right side: chat interface.
            with gr.Column(scale=4, elem_classes="chat-area"):
                gr.Markdown("# 🤖 GeekBot *Your personal AI*")
                gr.Markdown("*Upload your documents and start chatting about them. Supports PPT, PDF, TXT, etc.*")
                with gr.Column(elem_classes="chatbot-container"):
                    chatbot = gr.Chatbot(
                        label="",
                        show_label=False,
                        show_copy_button=True,
                        avatar_images=(None, "🤖"),
                        container=True,
                        height=550,
                    )
                with gr.Row(elem_classes="input-area"):
                    msg = gr.Textbox(
                        label="",
                        placeholder="Type your question here...",
                        show_label=False,
                        scale=9,
                        container=False,
                    )
                    submit_btn = gr.Button("Send", scale=1, variant="primary")

        # Event handlers. Each component appears once per outputs list; the
        # processing-status value and visibility travel in one gr.update.
        files.change(
            fn=handle_file_removal,
            inputs=[files, session_state],
            outputs=[sync_status_indicator, processing_status, session_state],
        )
        sync_btn.click(
            fn=process_files,
            inputs=[files, session_state],
            outputs=[sync_status_indicator, processing_status, session_state],
            show_progress=True,
        )
        chat_inputs = [msg, chatbot, session_state, verification_output]
        chat_outputs = [chatbot, msg, session_state, verification_output]
        msg.submit(fn=chat_function, inputs=chat_inputs, outputs=chat_outputs)
        submit_btn.click(fn=chat_function, inputs=chat_inputs, outputs=chat_outputs)

    return demo
def _get_file_hashes(uploaded_files: List, chunk_size: int = 65536) -> frozenset:
    """Return the SHA-256 hex digests of the uploaded files' contents.

    Args:
        uploaded_files: Gradio file objects (anything exposing a ``.name``
            path attribute) or plain path strings. ``None`` or an empty list
            yields an empty frozenset.
        chunk_size: Bytes read per step, so large documents are hashed
            incrementally instead of being loaded into memory at once.

    Returns:
        A frozenset of hex digests. Files with identical contents collapse
        to one digest. A file that cannot be read is logged and skipped
        (best effort), so it contributes no digest.
    """
    hashes = set()
    for file in uploaded_files or ():
        # Handle both Gradio file objects and string paths
        file_path = file.name if hasattr(file, "name") else file
        digest = hashlib.sha256()
        try:
            with open(file_path, "rb") as f:
                for chunk in iter(lambda: f.read(chunk_size), b""):
                    digest.update(chunk)
        except Exception as e:
            logger.error(f"Error hashing file {file_path}: {e}")
            continue
        hashes.add(digest.hexdigest())
    return frozenset(hashes)
if __name__ == "__main__":
    # Script entry point: build the Blocks app, then serve it with Gradio's
    # default launch settings.
    app = main()
    app.launch()