File size: 10,715 Bytes
01f0120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b4971bd
 
 
 
 
 
 
 
01f0120
 
 
 
 
 
 
 
 
b4971bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
01f0120
 
b4971bd
01f0120
 
b4971bd
 
 
 
01f0120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f0cde84
01f0120
 
 
f0cde84
27da814
f0cde84
 
27da814
f0cde84
 
27da814
 
72e11ad
f0cde84
27da814
 
01f0120
27da814
f0cde84
27da814
01f0120
27da814
01f0120
27da814
 
 
 
01f0120
27da814
 
 
 
 
01f0120
27da814
 
 
 
 
 
 
01f0120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b4971bd
 
01f0120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f0dba11
01f0120
 
 
 
b4971bd
 
 
01f0120
b4971bd
01f0120
 
 
 
 
 
 
 
 
 
 
 
 
 
b4971bd
 
 
01f0120
 
b4971bd
01f0120
 
b4971bd
 
01f0120
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
#!/usr/bin/env python3
"""
VedaMD Enhanced: Sri Lankan Clinical Assistant
Main Gradio Application for Hugging Face Spaces Deployment

Enhanced Medical-Grade RAG System with:
✅ 5x Enhanced Retrieval (15+ documents vs previous 5)
✅ Medical Entity Extraction & Clinical Terminology
✅ Clinical ModernBERT (768d medical embeddings)
✅ Medical Response Verification & Safety Protocols
✅ Advanced Re-ranking & Coverage Verification
✅ Source Traceability & Citation Support
"""

import os
import logging
import gradio as gr
from typing import List, Dict, Optional
import sys

# Add src directory to path for imports
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))

from src.enhanced_groq_medical_rag import EnhancedGroqMedicalRAG, EnhancedMedicalResponse

# Configure logging once at import time; every logger in the process
# (including this module's) uses this timestamped format at INFO level.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Security: Verify API key is loaded from environment (not hardcoded)
# For Hugging Face Spaces: Set CEREBRAS_API_KEY in Space Settings > Repository secrets
# Fail fast at import time: an unset OR empty variable aborts startup with a
# ValueError instead of launching the app without credentials.
if not os.getenv("CEREBRAS_API_KEY"):
    logger.error("❌ CEREBRAS_API_KEY not found in environment variables!")
    logger.error("⚠️ For Hugging Face Spaces: Add CEREBRAS_API_KEY in Settings > Repository secrets")
    logger.error("⚠️ Get your free API key at: https://cloud.cerebras.ai")
    raise ValueError("CEREBRAS_API_KEY environment variable is required. Please configure in HF Spaces secrets.")

# Initialize Enhanced Medical RAG System (module-level singleton used by the
# request handler below).
logger.info("🏥 Initializing VedaMD Enhanced for Hugging Face Spaces...")
try:
    enhanced_rag_system = EnhancedGroqMedicalRAG()
except Exception as e:
    # logger.exception records the full traceback in the app log before the
    # error is re-raised and aborts startup.
    logger.exception("❌ Failed to initialize system: %s", e)
    raise
logger.info("✅ Enhanced Medical RAG system ready!")

# Lower-case phrases that commonly appear in prompt-injection attempts.
# Queries are matched after case-folding and collapsing whitespace runs, so
# variants such as "Ignore   Previous" or "ignore\nprevious" are caught too.
_SUSPICIOUS_PATTERNS = ('ignore previous', 'ignore above', 'system:', 'disregard')

# Maximum accepted query length, in characters.
_MAX_QUERY_CHARS = 2000


def validate_input(message: str) -> tuple[bool, str]:
    """
    Validate user input for security and quality.

    Args:
        message: Raw text typed into the chat box. ``None`` (or any falsy
            value) is treated like an empty message.

    Returns:
        ``(is_valid, error_message)``: ``(True, "")`` when the query may be
        sent to the RAG system, otherwise ``False`` plus a user-facing reason.
    """
    if not message or not message.strip():
        return False, "Please enter a medical question about Sri Lankan clinical guidelines."

    if len(message) > _MAX_QUERY_CHARS:
        return False, f"⚠️ Query too long. Please limit your question to {_MAX_QUERY_CHARS} characters."

    # Normalise before matching: without collapsing whitespace, a double space
    # or a newline between words would slip past the phrase check.
    normalized = " ".join(message.casefold().split())
    if any(pattern in normalized for pattern in _SUSPICIOUS_PATTERNS):
        return False, "⚠️ Invalid query format. Please rephrase your medical question."

    return True, ""

def _format_chat_history(history) -> List[Dict[str, str]]:
    """Convert Gradio chat history into a list of ``{"role", "content"}`` dicts.

    Accepts both layouts ``gr.ChatInterface`` can pass:
    ``[[user_msg, assistant_msg], ...]`` pairs ("tuples" format) and
    ``[{"role": ..., "content": ...}, ...]`` dicts ("messages" format).
    Empty messages are skipped; a falsy ``history`` yields ``[]``.
    """
    formatted: List[Dict[str, str]] = []
    for entry in history or []:
        if isinstance(entry, dict):
            # "messages" format: keep only plain-text user/assistant turns;
            # non-string content (e.g. file payloads) is not chat text.
            role = entry.get("role")
            content = entry.get("content")
            if role in ("user", "assistant") and isinstance(content, str) and content:
                formatted.append({"role": role, "content": content})
        elif len(entry) >= 2:
            # "tuples" format: [user_msg, assistant_msg]; either side may be None.
            user_msg, assistant_msg = entry[0], entry[1]
            if user_msg:
                formatted.append({"role": "user", "content": user_msg})
            if assistant_msg:
                formatted.append({"role": "assistant", "content": assistant_msg})
    return formatted


def process_enhanced_medical_query(message: str, history: List[List[str]]) -> str:
    """
    Process a medical query with the enhanced RAG system and input validation.

    Args:
        message: The user's current question.
        history: Prior conversation as supplied by ``gr.ChatInterface``,
            either ``[user, assistant]`` pairs or role/content dicts.

    Returns:
        Markdown for the chat window: the formatted answer, a validation
        message, or a system-error notice. Exceptions are never propagated.
    """
    try:
        is_valid, error_msg = validate_input(message)
        if not is_valid:
            return error_msg

        response: EnhancedMedicalResponse = enhanced_rag_system.query(
            query=message,
            history=_format_chat_history(history)
        )
        return format_enhanced_medical_response(response)

    except Exception as e:
        # Full traceback goes to the log; the UI only shows the short message.
        logger.exception("Error processing query: %s", e)
        return f"⚠️ **System Error**: {str(e)}\n\nPlease try again or contact support if the issue persists."

def format_enhanced_medical_response(response: EnhancedMedicalResponse) -> str:
    """
    Format the enhanced medical response for display as Markdown.

    Sections, in order: the stripped LLM answer; a numbered "Clinical Sources
    & Citations" list (only when ``response.sources`` is non-empty); a
    "Response Analysis" block with safety/verification info (only when
    ``response.verification_result`` is truthy) and retrieval metrics; and a
    fixed medical disclaimer.

    Args:
        response: Result of ``enhanced_rag_system.query``. Attributes read:
            ``answer``, ``sources``, ``verification_result``,
            ``safety_status``, ``medical_entities_count``,
            ``context_adherence_score`` and, if present, ``query_time``.

    Returns:
        The Markdown string shown in the chat window.
    """
    # Treat missing sources/answer as empty so the metrics section can't crash.
    sources = response.sources or []
    formatted_parts = [(response.answer or "").strip()]

    # Numbered citation list for every source used.
    if sources:
        formatted_parts.append("\n\n---\n")
        formatted_parts.append("### 📋 **Clinical Sources & Citations**")
        formatted_parts.append("\nThis response is based on the following Sri Lankan clinical guidelines:")
        for i, source in enumerate(sources, 1):
            formatted_parts.append(f"\n**[{i}]** Source: {source}")

    formatted_parts.append("\n\n---\n")
    formatted_parts.append("### 📊 **Response Analysis**")

    # Safety and verification info.
    verification = response.verification_result
    if verification:
        safety_status = "✅ SAFE" if response.safety_status == "SAFE" else "⚠️ CAUTION"
        formatted_parts.append(f"\n**Medical Safety Status**: {safety_status}")
        formatted_parts.append(f"**Verification Score**: {verification.verification_score:.1%}")
        formatted_parts.append(f"**Verified Medical Claims**: {verification.verified_claims}/{verification.total_claims}")

    # Retrieval metrics.
    formatted_parts.append("\n**Medical Information Coverage**:")
    formatted_parts.append(f"- 🧠 Medical Entities: {response.medical_entities_count}")
    formatted_parts.append(f"- 🎯 Context Adherence: {response.context_adherence_score:.1%}")
    formatted_parts.append(f"- 📚 Guidelines Referenced: {len(sources)}")

    # Processing time is optional: skip it when absent or None.
    query_time = getattr(response, 'query_time', None)
    if query_time is not None:
        formatted_parts.append(f"- ⚡ Processing Time: {query_time:.2f}s")

    formatted_parts.append("\n\n---\n")
    formatted_parts.append("*⚕️ This information is derived from Sri Lankan clinical guidelines and is for reference only. Always consult with qualified healthcare professionals for patient care decisions.*")

    return "\n".join(formatted_parts)

def create_enhanced_medical_interface():
    """
    Create the enhanced Gradio interface for Hugging Face Spaces.

    Builds a ``gr.Blocks`` page containing an HTML header banner, a feature
    description, a ``gr.ChatInterface`` wired to
    ``process_enhanced_medical_query`` (with example prompts, not cached),
    and a footer with technical details and a disclaimer. The app is only
    built here, not launched.

    Returns:
        The ``gr.Blocks`` app.
    """
    # Custom CSS for the medical theme: a centred 900px column plus the
    # gradient banner style used by the header's ``medical-header`` div.
    custom_css = """
    .gradio-container {
        max-width: 900px !important;
        margin: auto !important;
    }
    .medical-header {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        color: white;
        padding: 20px;
        border-radius: 10px;
        margin-bottom: 20px;
        text-align: center;
    }
    """

    with gr.Blocks(
        title="🏥 VedaMD Enhanced: Sri Lankan Clinical Assistant",
        theme=gr.themes.Soft(),
        css=custom_css
    ) as demo:

        # Header banner
        gr.HTML("""
        <div class="medical-header">
            <h1>🏥 VedaMD Enhanced: Sri Lankan Clinical Assistant</h1>
            <h3>Ultra-Fast Medical AI powered by Cerebras Inference</h3>
            <p>⚡ World's Fastest Inference • ✅ Medical Verification • ✅ Clinical ModernBERT • ✅ Free to Use</p>
        </div>
        """)

        # Description
        gr.Markdown("""
        **🩺 Advanced Medical AI Assistant** for Sri Lankan maternal health guidelines with **enhanced safety protocols**:

        🎯 **Enhanced Features:**
        - **5x Enhanced Retrieval**: 15+ documents analyzed vs previous 5
        - **Medical Entity Extraction**: Advanced clinical terminology recognition  
        - **Clinical ModernBERT**: Specialized 768d medical domain embeddings
        - **Medical Response Verification**: 100% source traceability validation
        - **Advanced Re-ranking**: Medical relevance scoring with coverage verification
        - **Safety Protocols**: Comprehensive medical claim verification before delivery

        **Ask me anything about Sri Lankan clinical guidelines with confidence!** 🇱🇰
        """)

        # Chat interface: every user turn is answered by
        # process_enhanced_medical_query (message + chat history).
        chat_interface = gr.ChatInterface(
            fn=process_enhanced_medical_query,
            examples=[
                "What is the complete management protocol for severe preeclampsia in Sri Lankan guidelines?",
                "How should postpartum hemorrhage be managed according to our local clinical protocols?",
                "What medications are contraindicated during pregnancy based on Sri Lankan guidelines?",
                "What are the evidence-based recommendations for managing gestational diabetes?",
                "How should puerperal sepsis be diagnosed and treated according to our guidelines?",
                "What are the protocols for assisted vaginal delivery in complicated cases?",
                "How should intrapartum fever be managed based on Sri Lankan standards?"
            ],
            cache_examples=False
        )
        # NOTE(review): assigning api_name after construction may not change the
        # API endpoint the ChatInterface already registered; confirm against
        # the installed Gradio version.
        chat_interface.api_name = "chat"

        # Footer with technical info
        gr.Markdown("""
        ---
        **⚡ Powered by**: Cerebras Inference — World's Fastest AI (2000+ tokens/sec with Llama 3.3 70B)

        **🔧 Technical Details**: Enhanced RAG with Clinical ModernBERT embeddings, medical entity extraction,
        response verification, and multi-stage retrieval for comprehensive medical information coverage.

        **⚖️ Disclaimer**: This AI assistant is for clinical reference only and does not replace professional medical judgment.
        Always consult with qualified healthcare professionals for patient care decisions.
        """)

    return demo

# Create and launch the interface
if __name__ == "__main__":
    logger.info("🚀 Launching VedaMD Enhanced for Hugging Face Spaces...")

    # Build the Blocks app (construction only; launched below).
    demo = create_enhanced_medical_interface()

    # Security: enable Gradio's request queue explicitly rather than relying
    # on the installed Gradio version's default; max_threads below caps the
    # worker thread pool for stability.
    demo.queue()

    # Port can be set via GRADIO_SERVER_PORT env variable, defaults to 7860
    server_port = int(os.getenv("GRADIO_SERVER_PORT", "7860"))
    demo.launch(
        server_name="0.0.0.0",  # bind on all network interfaces
        server_port=server_port,
        share=False,  # no Gradio share link
        show_error=True,
        show_api=True,
        max_threads=40,  # Limit concurrent requests for stability
    )