import gradio as gr
import hashlib
from typing import List, Dict, Tuple
import os
import time
from document_processor.file_handler import DocumentProcessor
from retriever.builder import RetrieverBuilder
from agents.workflow import AgentWorkflow
from config import constants, settings
from utils.logging import logger
def main():
    """Build and return the DocChat Gradio Blocks demo.

    Wires DocumentProcessor, RetrieverBuilder and AgentWorkflow behind a
    two-pane UI: a dark sidebar for document upload/sync on the left and
    a chat area driven by the agent workflow on the right.

    Returns:
        gr.Blocks: the assembled (but not yet launched) demo.
    """
    processor = DocumentProcessor()
    retriever_builder = RetrieverBuilder()
    workflow = AgentWorkflow()

    # Custom CSS for ChatGPT-like styling with a dark sidebar.
    css = """
    .sidebar {
        background: #202123 !important;
        border-right: 1px solid #343541 !important;
        max-height: 90vh !important;
        height: auto !important;
        overflow-y: auto !important;
        padding: 15px !important;
        color: #ffffff !important;
    }
    .sidebar * {
        color: #ffffff !important;
    }
    .sidebar label {
        color: #d1d5db !important;
    }
    .sidebar input, .sidebar select, .sidebar textarea {
        background: #343541 !important;
        color: #ffffff !important;
        border: 1px solid #565869 !important;
    }
    .main-container {
        max-height: 90vh !important;
        height: auto !important;
        /* fixed: original said "overflow-yL", an invalid property name */
        overflow-y: auto !important;
    }
    .chat-area {
        height: 100vh !important;
        display: flex !important;
        flex-direction: column !important;
        padding: 10px !important;
    }
    .chatbot-container {
        flex: 1 1 auto !important;
        min-height: 300px !important;
        max-height: calc(100vh - 280px) !important;
        overflow: hidden !important;
    }
    .chatbot-container .gradio-chatbot {
        height: 100% !important;
        max-height: calc(100vh - 280px) !important;
    }
    .input-area {
        margin-top: 10px !important;
    }
    .processing-status {
        padding: 8px !important;
        border-radius: 6px !important;
        margin: 8px 0 !important;
        font-size: 0.9em !important;
    }
    .success {
        background: #d4edda !important;
        color: #155724 !important;
        border: 1px solid #c3e6cb !important;
    }
    .error {
        background: #f8d7da !important;
        color: #721c24 !important;
        border: 1px solid #f5c6cb !important;
    }
    .progress-info {
        font-size: 0.85em !important;
        color: #666 !important;
        margin-top: 5px !important;
    }
    """

    with gr.Blocks(theme=gr.themes.Soft(), title="DocChat", css=css) as demo:
        # Per-session state: hashes of the synced files, the built retriever,
        # the raw file list and the running chat history.
        session_state = gr.State({
            "file_hashes": frozenset(),
            "retriever": None,
            "processed_files": [],
            "chat_history": []
        })

        # Main layout: sidebar (file management) + chat.
        with gr.Row(equal_height=True, elem_classes="main-container"):
            # Left sidebar for file management (narrower column).
            with gr.Column(scale=0.7, min_width=250, elem_classes="sidebar"):
                gr.Markdown("## 📁 Upload your documents here", elem_classes="title")

                # File upload component - uploaded files are listed here.
                files = gr.Files(
                    label="Upload Documents",
                    file_types=constants.ALLOWED_TYPES,
                    height=150,
                    show_label=True
                )

                # Sync button and status indicator.
                with gr.Row():
                    sync_btn = gr.Button("🔄 Sync", variant="primary", scale=1,
                                         size="sm", elem_classes=["flex-item"])
                with gr.Row():
                    # NOTE(review): the original inline markup was lost in
                    # extraction; a minimal status <div> is reconstructed here.
                    sync_status_indicator = gr.HTML(
                        '<div class="sync-status">Not synced</div>',
                        visible=True,
                        elem_classes=["flex-item"]
                    )

                # Placeholder for layout-tweak markup (original content lost
                # upstream - TODO restore if the original HTML is recovered).
                gr.HTML("")

                # Processing status (only shown while processing/processed).
                processing_status = gr.Markdown("", elem_classes="processing-status",
                                                visible=False)

                # Verification report section.
                gr.HTML("")
                gr.Markdown('📊 Verification Report', elem_classes="compact-markdown")
                verification_output = gr.Textbox(
                    label="",
                    interactive=False,
                    lines=12,
                    max_lines=12,
                    value="",
                    placeholder="""### Verification Report""",
                    show_label=False
                )

            # Right side: chat interface.
            with gr.Column(scale=4, elem_classes="chat-area"):
                gr.Markdown("# 🤖 GeekBot *Your personal AI*")
                gr.Markdown("*Enter your documents and start chatting about it. Supports ppt,pdf,txt, etc*")

                # Chatbot flexes to fill the available vertical space.
                with gr.Column(elem_classes="chatbot-container"):
                    chatbot = gr.Chatbot(
                        label="",
                        show_label=False,
                        show_copy_button=True,
                        avatar_images=(None, "🤖"),
                        container=True,
                        height=550
                    )

                # Input area.
                with gr.Row(elem_classes="input-area"):
                    msg = gr.Textbox(
                        label="",
                        placeholder="Type your question here...",
                        show_label=False,
                        scale=9,
                        container=False
                    )
                    submit_btn = gr.Button("Send", scale=1, variant="primary")

        def get_sync_status_html(status: str, count: int = 0) -> str:
            """Generate HTML for the sync status indicator.

            NOTE(review): the original markup was stripped by extraction;
            these <div> wrappers are a reconstruction using the CSS classes
            defined above.
            """
            if status == "synced":
                return f'<div class="sync-status success">✅ Synced ({count} chunks)</div>'
            if status == "syncing":
                return '<div class="sync-status">🔄 Syncing...</div>'
            if status == "error":
                return '<div class="sync-status error">❌ Error</div>'
            return '<div class="sync-status">Not synced</div>'

        def handle_file_removal(current_files: List, state: Dict):
            """Handle file removal: rebuild or reset the retriever when the
            uploaded file set changes.

            Returns a 4-tuple matching the event outputs:
            (sync indicator HTML, status markdown, visibility update, state).
            """
            if not current_files:
                # All files removed - drop the retriever (if any) and reset.
                # (The original had two identical return tuples here; merged.)
                if state.get("retriever"):
                    logger.info("All files removed. Resetting retriever.")
                    state.update({
                        "retriever": None,
                        "file_hashes": frozenset(),
                        "processed_files": []
                    })
                return (
                    get_sync_status_html("ready"),
                    "",
                    gr.update(visible=False),
                    state
                )

            # Some files removed - rebuild the retriever from the remainder.
            current_hashes = _get_file_hashes(current_files)
            if state.get("file_hashes") and current_hashes != state["file_hashes"]:
                logger.info("Files were removed. Rebuilding retriever with remaining files...")
                try:
                    chunks = processor.process(current_files)
                    retriever = retriever_builder.build_hybrid_retriever(chunks)
                    state.update({
                        "file_hashes": current_hashes,
                        "retriever": retriever,
                        "processed_files": current_files
                    })
                    status_html = "✅ **Documents resynced**\n"
                    status_html += f"{len(chunks)} chunks indexed.\n"
                    return (
                        get_sync_status_html("synced", len(chunks)),
                        status_html,
                        gr.update(visible=True),
                        state
                    )
                except Exception as e:
                    logger.error(f"Error resyncing after file removal: {e}")
                    return (
                        get_sync_status_html("error"),
                        f"❌ Error: {str(e)}",
                        gr.update(visible=True),
                        state
                    )

            # File set unchanged - keep existing retriever.
            # NOTE(review): this is a file count, not a chunk count.
            return (
                get_sync_status_html("synced", len(state.get("processed_files", []))),
                "",
                gr.update(visible=False),
                state
            )

        def process_files(uploaded_files: List, state: Dict):
            """Process uploaded files and build the hybrid retriever.

            Triggered by the Sync button. Returns the same 4-tuple shape as
            handle_file_removal.
            """
            if not uploaded_files:
                return (
                    get_sync_status_html("ready"),   # sync_status
                    "",                              # processing_status
                    gr.update(visible=False),        # processing_status visibility
                    state
                )
            try:
                current_hashes = _get_file_hashes(uploaded_files)
                # Rebuild only when files are new or changed.
                if state["retriever"] is None or current_hashes != state["file_hashes"]:
                    logger.info("Processing new/changed documents...")
                    chunks = processor.process(uploaded_files)
                    logger.info("Building retriever...")
                    retriever = retriever_builder.build_hybrid_retriever(chunks)
                    logger.info("Retriever built successfully")
                    state.update({
                        "file_hashes": current_hashes,
                        "retriever": retriever,
                        "processed_files": uploaded_files
                    })
                    status_html = "✅ **Documents synced successfully!**\n"
                    status_html += f"{len(chunks)} chunks indexed. Ready for questions!\n"
                    return (
                        get_sync_status_html("synced", len(chunks)),
                        status_html,
                        gr.update(visible=True),
                        state
                    )

                # Files unchanged - already synced.
                status_html = "✅ **Documents already synced**\n"
                status_html += "Files are up to date. Ready for questions!\n"
                # NOTE(review): this is the file count, not the chunk count.
                chunk_count = len(state.get("processed_files", []))
                return (
                    get_sync_status_html("synced", chunk_count),
                    status_html,
                    gr.update(visible=True),
                    state
                )
            except Exception as e:
                # "f" prefix removed from the first literal (no placeholder).
                error_html = "❌ **Error syncing documents**\n"
                error_html += f"{str(e)}\n"
                logger.error(f"File processing error: {str(e)}")
                return (
                    get_sync_status_html("error"),
                    error_html,
                    gr.update(visible=True),
                    state
                )

        def chat_function(message: str, history: List, state: Dict,
                          verification_state: str):
            """Answer a chat message via the agent workflow.

            Returns (history, cleared textbox value, state, verification report).
            """
            try:
                if not message.strip():
                    history.append((message, "Please enter a question."))
                    return history, "", state, verification_state
                if state["retriever"] is None:
                    history.append((message, "❌ No documents uploaded. Please upload documents first."))
                    return history, "", state, verification_state

                # Run the retrieval + answer + verification pipeline.
                result = workflow.full_pipeline(
                    question=message,
                    retriever=state["retriever"]
                )
                answer = result["draft_answer"]
                verification_report = result["verification_report"]

                history.append((message, answer))
                state.setdefault("chat_history", []).append(
                    {"question": message, "answer": answer}
                )
                return history, "", state, verification_report
            except Exception as e:
                logger.error(f"Chat error: {str(e)}")
                history.append((message, f"❌ Error: {str(e)}"))
                return history, "", state, ""

        # Event wiring.
        # Handle file removal - fires whenever the file list changes.
        files.change(
            fn=handle_file_removal,
            inputs=[files, session_state],
            outputs=[sync_status_indicator, processing_status, processing_status, session_state]
        )
        # Sync button processes the uploaded files.
        sync_btn.click(
            fn=process_files,
            inputs=[files, session_state],
            outputs=[sync_status_indicator, processing_status, processing_status, session_state],
            show_progress=True
        )
        # Chat submission via Enter or the Send button.
        msg.submit(
            fn=chat_function,
            inputs=[msg, chatbot, session_state, verification_output],
            outputs=[chatbot, msg, session_state, verification_output]
        )
        submit_btn.click(
            fn=chat_function,
            inputs=[msg, chatbot, session_state, verification_output],
            outputs=[chatbot, msg, session_state, verification_output]
        )

    return demo
def _get_file_hashes(uploaded_files: List) -> frozenset:
    """Return the frozenset of SHA-256 hex digests of the uploaded files.

    Accepts Gradio file objects (which expose the path via ``.name``) or
    plain path strings. Files that cannot be read are logged and skipped,
    so the result may contain fewer digests than inputs.

    Args:
        uploaded_files: Gradio file objects and/or path strings.

    Returns:
        frozenset[str]: hex digests of the readable files' contents.
    """
    hashes = set()
    for file in uploaded_files:
        # Handle both Gradio file objects and string paths.
        file_path = file.name if hasattr(file, 'name') else file
        try:
            # Hash in chunks so large uploads are not read fully into memory.
            digest = hashlib.sha256()
            with open(file_path, "rb") as f:
                for chunk in iter(lambda: f.read(65536), b""):
                    digest.update(chunk)
            hashes.add(digest.hexdigest())
        except OSError as e:
            # Narrowed from Exception: only I/O failures are expected here.
            logger.error(f"Error hashing file {file_path}: {e}")
    return frozenset(hashes)
# Script entry point: build the demo and start the Gradio server.
if __name__ == "__main__":
    main().launch()