Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| VedaMD Enhanced: Sri Lankan Clinical Assistant | |
| Main Gradio Application for Hugging Face Spaces Deployment | |
| Enhanced Medical-Grade RAG System with: | |
| ✅ 5x Enhanced Retrieval (15+ documents vs previous 5) | |
| ✅ Medical Entity Extraction & Clinical Terminology | |
| ✅ Clinical ModernBERT (768d medical embeddings) | |
| ✅ Medical Response Verification & Safety Protocols | |
| ✅ Advanced Re-ranking & Coverage Verification | |
| ✅ Source Traceability & Citation Support | |
| """ | |
| import os | |
| import logging | |
| import gradio as gr | |
| from typing import List, Dict, Optional | |
| import sys | |
| # Add src directory to path for imports | |
| sys.path.append(os.path.join(os.path.dirname(__file__), 'src')) | |
| from src.enhanced_groq_medical_rag import EnhancedGroqMedicalRAG, EnhancedMedicalResponse | |
# Logging: INFO level, each record carries timestamp, logger name and severity.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Security: the API key is read from the environment, never hardcoded.
# On Hugging Face Spaces it is set under Space Settings > Repository secrets.
# A missing or empty key aborts the import with ValueError after logging hints.
if not os.environ.get("CEREBRAS_API_KEY"):
    logger.error("β CEREBRAS_API_KEY not found in environment variables!")
    logger.error("β οΈ For Hugging Face Spaces: Add CEREBRAS_API_KEY in Settings > Repository secrets")
    logger.error("β οΈ Get your free API key at: https://cloud.cerebras.ai")
    raise ValueError("CEREBRAS_API_KEY environment variable is required. Please configure in HF Spaces secrets.")

# Build the module-wide RAG system once at import time. Any construction
# failure is logged and re-raised so the application does not start half-ready.
logger.info("π₯ Initializing VedaMD Enhanced for Hugging Face Spaces...")
try:
    enhanced_rag_system = EnhancedGroqMedicalRAG()
except Exception as init_error:
    logger.error(f"β Failed to initialize system: {init_error}")
    raise
else:
    logger.info("β Enhanced Medical RAG system ready!")
def validate_input(message: str) -> tuple[bool, str]:
    """
    Validate user input for security and quality.

    Checks, in order: the message is a non-blank string, it is at most
    2000 characters long, and it contains none of a small set of
    prompt-injection phrases.

    Args:
        message: Raw text typed by the user.

    Returns:
        (is_valid, error_message). ``error_message`` is the empty string
        when ``is_valid`` is True, otherwise a user-facing explanation.
    """
    max_query_length = 2000
    suspicious_patterns = ('ignore previous', 'ignore above', 'system:', 'disregard')

    # Reject None, non-string payloads and blank text with the same prompt
    # instead of letting ``.strip()`` raise on an unexpected type.
    if not isinstance(message, str) or not message.strip():
        return False, "Please enter a medical question about Sri Lankan clinical guidelines."

    # The limit applies to the raw message, including surrounding whitespace.
    if len(message) > max_query_length:
        return False, "β οΈ Query too long. Please limit your question to 2000 characters."

    # Lower-case and collapse every whitespace run (spaces, tabs, newlines,
    # non-breaking spaces) to one space so that "ignore   previous" or
    # "ignore\nprevious" cannot slip past the substring check.
    normalized = " ".join(message.lower().split())
    if any(pattern in normalized for pattern in suspicious_patterns):
        return False, "β οΈ Invalid query format. Please rephrase your medical question."

    return True, ""
def _history_to_messages(history) -> List[Dict[str, str]]:
    """Flatten Gradio chat history into a list of role/content dicts."""
    messages = []
    for entry in history or []:
        if isinstance(entry, dict):
            # "messages"-style history: each entry is already a dict with
            # "role" and "content" keys. Keep only the two roles that the
            # pair-style conversion below can produce.
            role = entry.get("role")
            content = entry.get("content")
            if role in ("user", "assistant") and content:
                messages.append({"role": role, "content": content})
        elif len(entry) >= 2:
            # Pair-style history: [user_message, assistant_message].
            user_msg, assistant_msg = entry[0], entry[1]
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
    return messages


def process_enhanced_medical_query(message: str, history: List[List[str]]) -> str:
    """
    Process a medical query with the enhanced RAG system and input validation.

    Args:
        message: The user's current question.
        history: Prior conversation as supplied by Gradio, either as
            [user, assistant] pairs or as role/content dicts.

    Returns:
        The formatted answer, the validation error message when the input is
        rejected, or a generic error notice when processing fails. Exceptions
        are not propagated to the caller.
    """
    try:
        # Validate input
        is_valid, error_msg = validate_input(message)
        if not is_valid:
            return error_msg

        # Get enhanced response
        response: EnhancedMedicalResponse = enhanced_rag_system.query(
            query=message,
            history=_history_to_messages(history),
        )

        # Format enhanced response for display
        return format_enhanced_medical_response(response)
    except Exception:
        # Keep the full traceback in the server log, but do not echo the raw
        # exception text to the end user: it can expose internal details.
        logger.exception("Error processing query")
        return (
            "β οΈ **System Error**: The request could not be processed."
            "\n\nPlease try again or contact support if the issue persists."
        )
def format_enhanced_medical_response(response: EnhancedMedicalResponse) -> str:
    """
    Format the enhanced medical response as Markdown for display.

    The output is the stripped answer text, a numbered source list (only when
    the response carries sources), a "Response Analysis" section with optional
    verification details and retrieval metrics, and a closing disclaimer.

    Args:
        response: Object exposing ``answer``, ``sources``,
            ``verification_result``, ``safety_status``,
            ``medical_entities_count``, ``context_adherence_score`` and,
            optionally, ``query_time``.

    Returns:
        The sections joined with newlines into a single string.
    """
    # Resolve sources once so the citation list and the "Guidelines
    # Referenced" count agree, and a missing (None) value counts as zero
    # instead of raising in len().
    sources = response.sources or []

    # Main response from the LLM
    parts = [response.answer.strip()]

    # Clinical sources section with clear numbering (only when sources exist)
    if sources:
        parts.append("\n\n---\n")
        parts.append("### π **Clinical Sources & Citations**")
        parts.append("\nThis response is based on the following Sri Lankan clinical guidelines:")
        # Bold citation number followed by a clear label
        parts.extend(
            f"\n**[{index}]** Source: {source}"
            for index, source in enumerate(sources, 1)
        )

    # Enhanced information section with clear separation
    parts.append("\n\n---\n")
    parts.append("### π **Response Analysis**")

    # Safety and verification info
    verification = response.verification_result
    if verification:
        safety_status = "β SAFE" if response.safety_status == "SAFE" else "β οΈ CAUTION"
        parts.append(f"\n**Medical Safety Status**: {safety_status}")
        parts.append(f"**Verification Score**: {verification.verification_score:.1%}")
        parts.append(
            f"**Verified Medical Claims**: {verification.verified_claims}/{verification.total_claims}"
        )

    # Enhanced retrieval metrics
    parts.append("\n**Medical Information Coverage**:")
    parts.append(f"- π§ Medical Entities: {response.medical_entities_count}")
    parts.append(f"- π― Context Adherence: {response.context_adherence_score:.1%}")
    parts.append(f"- π Guidelines Referenced: {len(sources)}")

    # Processing time is optional: skip it when the attribute is absent or
    # None, since a None value cannot be rendered with the float format.
    query_time = getattr(response, "query_time", None)
    if query_time is not None:
        parts.append(f"- β‘ Processing Time: {query_time:.2f}s")

    # Medical disclaimer with clear separation
    parts.append("\n\n---\n")
    parts.append("*βοΈ This information is derived from Sri Lankan clinical guidelines and is for reference only. Always consult with qualified healthcare professionals for patient care decisions.*")

    return "\n".join(parts)
def create_enhanced_medical_interface():
    """
    Build the Gradio Blocks UI for the Hugging Face Spaces deployment.

    The page consists of, top to bottom: an HTML header banner, a Markdown
    feature description, a chat interface wired to
    ``process_enhanced_medical_query`` with example questions, and a Markdown
    footer with technical notes and a disclaimer.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    page_title = "π₯ VedaMD Enhanced: Sri Lankan Clinical Assistant"

    # Custom CSS for the medical theme: centred container and gradient header.
    custom_css = """
    .gradio-container {
        max-width: 900px !important;
        margin: auto !important;
    }
    .medical-header {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        color: white;
        padding: 20px;
        border-radius: 10px;
        margin-bottom: 20px;
        text-align: center;
    }
    """

    header_html = """
    <div class="medical-header">
    <h1>π₯ VedaMD Enhanced: Sri Lankan Clinical Assistant</h1>
    <h3>Ultra-Fast Medical AI powered by Cerebras Inference</h3>
    <p>β‘ World's Fastest Inference β’ β Medical Verification β’ β Clinical ModernBERT β’ β Free to Use</p>
    </div>
    """

    description_md = """
    **π©Ί Advanced Medical AI Assistant** for Sri Lankan maternal health guidelines with **enhanced safety protocols**:
    π― **Enhanced Features:**
    - **5x Enhanced Retrieval**: 15+ documents analyzed vs previous 5
    - **Medical Entity Extraction**: Advanced clinical terminology recognition
    - **Clinical ModernBERT**: Specialized 768d medical domain embeddings
    - **Medical Response Verification**: 100% source traceability validation
    - **Advanced Re-ranking**: Medical relevance scoring with coverage verification
    - **Safety Protocols**: Comprehensive medical claim verification before delivery
    **Ask me anything about Sri Lankan clinical guidelines with confidence!** π±π°
    """

    # Sample questions offered to the user below the chat box.
    example_questions = [
        "What is the complete management protocol for severe preeclampsia in Sri Lankan guidelines?",
        "How should postpartum hemorrhage be managed according to our local clinical protocols?",
        "What medications are contraindicated during pregnancy based on Sri Lankan guidelines?",
        "What are the evidence-based recommendations for managing gestational diabetes?",
        "How should puerperal sepsis be diagnosed and treated according to our guidelines?",
        "What are the protocols for assisted vaginal delivery in complicated cases?",
        "How should intrapartum fever be managed based on Sri Lankan standards?",
    ]

    footer_md = """
    ---
    **β‘ Powered by**: Cerebras Inference - World's Fastest AI (2000+ tokens/sec with Llama 3.3 70B)
    **π§ Technical Details**: Enhanced RAG with Clinical ModernBERT embeddings, medical entity extraction,
    response verification, and multi-stage retrieval for comprehensive medical information coverage.
    **βοΈ Disclaimer**: This AI assistant is for clinical reference only and does not replace professional medical judgment.
    Always consult with qualified healthcare professionals for patient care decisions.
    """

    # Components are created in display order inside the Blocks context.
    with gr.Blocks(title=page_title, theme=gr.themes.Soft(), css=custom_css) as demo:
        gr.HTML(header_html)
        gr.Markdown(description_md)

        chatbot = gr.ChatInterface(
            fn=process_enhanced_medical_query,
            examples=example_questions,
            cache_examples=False,
        )
        chatbot.api_name = "chat"

        gr.Markdown(footer_md)

    return demo
# Script entry point: build the interface and serve it.
if __name__ == "__main__":
    logger.info("π Launching VedaMD Enhanced for Hugging Face Spaces...")

    demo = create_enhanced_medical_interface()

    # The listening port comes from GRADIO_SERVER_PORT when set, else 7860.
    server_port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))

    # Launch settings for HF Spaces: bind on all interfaces, no public share
    # link, surface errors and the API page, and cap worker threads at 40 to
    # limit concurrent requests for stability.
    launch_settings = {
        "server_name": "0.0.0.0",
        "server_port": server_port,
        "share": False,
        "show_error": True,
        "show_api": True,
        "max_threads": 40,
    }
    demo.launch(**launch_settings)