#!/usr/bin/env python3
"""
Medical Response Verification Layer
VedaMD Medical RAG - Phase 2: Task 2.2

This module provides comprehensive medical response verification to ensure:
1. 100% source traceability for all medical claims
2. Context adherence validation against the provided Sri Lankan guidelines
3. Prevention of medical hallucination and external knowledge injection
4. Regulatory compliance for medical device applications

CRITICAL SAFETY PROTOCOL:
- Every medical fact MUST be traceable to the provided source documents
- No medical information is allowed without explicit context support
- Strict verification of dosages, procedures, and protocols
- Comprehensive medical claim validation system
"""
import re
import logging
from typing import List, Dict, Optional
from dataclasses import dataclass, field
from enum import Enum


class VerificationStatus(Enum):
    """Verification status for medical claims"""
    VERIFIED = "verified"
    NOT_FOUND = "not_found"
    PARTIAL_MATCH = "partial_match"
    CONTRADICTED = "contradicted"
    INSUFFICIENT_CONTEXT = "insufficient_context"


class MedicalClaimType(Enum):
    """Types of medical claims to verify"""
    DOSAGE = "dosage"
    MEDICATION = "medication"
    PROCEDURE = "procedure"
    CONDITION = "condition"
    VITAL_SIGN = "vital_sign"
    CONTRAINDICATION = "contraindication"
    INDICATION = "indication"
    PROTOCOL = "protocol"
    EVIDENCE_LEVEL = "evidence_level"


@dataclass
class MedicalClaim:
    """Individual medical claim extracted from an LLM response"""
    text: str
    claim_type: MedicalClaimType
    context: str
    confidence: float
    citation_required: bool = True
    extracted_values: Dict[str, str] = field(default_factory=dict)


@dataclass
class VerificationResult:
    """Result of verifying a single medical claim"""
    claim: MedicalClaim
    status: VerificationStatus
    supporting_sources: List[str]
    confidence_score: float
    verification_details: str
    suggested_correction: Optional[str] = None


@dataclass
class MedicalResponseVerification:
    """Complete medical response verification result"""
    original_response: str
    total_claims: int
    verified_claims: int
    failed_verifications: List[VerificationResult]
    verification_score: float
    is_safe_for_medical_use: bool
    detailed_results: List[VerificationResult]
    safety_warnings: List[str]
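
# Data flow: the verifier extracts MedicalClaim objects from a response, checks each
# claim against the retrieved guideline text to produce a VerificationResult, and
# aggregates those results into a single MedicalResponseVerification report.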


class MedicalResponseVerifier:
    """
    Medical response verification system for context adherence validation
    """

    def __init__(self):
        self.setup_logging()
        self.medical_claim_patterns = self._initialize_medical_patterns()

    def setup_logging(self):
        """Set up logging for medical response verification"""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(__name__)

    def _initialize_medical_patterns(self) -> Dict[MedicalClaimType, List[str]]:
        """Initialize regex patterns for extracting medical claims from responses"""
        return {
            MedicalClaimType.DOSAGE: [
                r'(?:administer|give|prescribe|dose of?)\s+(\d+(?:\.\d+)?)\s*(mg|g|ml|units?|tablets?)',
                r'(\d+(?:\.\d+)?)\s*(mg|g|ml|units?)\s+(?:of |every |per )',
                r'(?:low|moderate|high|maximum|minimum)\s+dose'
            ],
            MedicalClaimType.MEDICATION: [
                r'\b(magnesium sulfate|MgSO4|oxytocin|methyldopa|nifedipine|labetalol|hydralazine)\b',
                r'\b(ampicillin|gentamicin|ceftriaxone|azithromycin|doxycycline)\b',
                r'\b(insulin|metformin|glibenclamide|aspirin|atorvastatin)\b'
            ],
            MedicalClaimType.PROCEDURE: [
                r'\b(cesarean section|C-section|vaginal delivery|assisted delivery)\b',
                r'\b(IV access|urinary catheter|nasogastric tube|blood transfusion)\b',
                r'\b(blood pressure monitoring|fetal monitoring|CTG)\b'
            ],
            MedicalClaimType.CONDITION: [
                r'\b(preeclampsia|eclampsia|HELLP syndrome|gestational hypertension)\b',
                r'\b(postpartum hemorrhage|PPH|retained placenta|uterine atony)\b',
                r'\b(puerperal sepsis|endometritis|wound infection)\b'
            ],
            MedicalClaimType.VITAL_SIGN: [
                r'blood pressure.*?(\d+/\d+)\s*mmHg',
                r'BP.*?([<>≤≥]?\s*\d+/\d+)\s*mmHg',
                r'heart rate.*?(\d+)\s*bpm'
            ],
            MedicalClaimType.CONTRAINDICATION: [
                r'contraindicated|avoid|do not use|should not be given',
                r'not recommended|prohibited|forbidden'
            ],
            MedicalClaimType.INDICATION: [
                r'indicated for|recommended for|used to treat',
                r'first-line treatment|treatment of choice'
            ],
            MedicalClaimType.PROTOCOL: [
                r'according to protocol|standard protocol|clinical protocol',
                r'guideline recommends|evidence-based approach'
            ]
        }
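
    # Illustrative example of how the patterns above are consumed by the extraction
    # below: in the sentence "administer 4 g of magnesium sulfate", the first DOSAGE
    # pattern captures the dose amount and unit as groups ("4", "g"), while the first
    # MEDICATION pattern matches "magnesium sulfate" verbatim.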
    def extract_medical_claims(self, response: str) -> List[MedicalClaim]:
        """
        Extract all medical claims from an LLM response that need verification
        """
        claims = []
        sentences = re.split(r'[.!?]+', response)

        for sentence in sentences:
            sentence = sentence.strip()
            if not sentence:
                continue

            for claim_type, patterns in self.medical_claim_patterns.items():
                for pattern in patterns:
                    matches = re.finditer(pattern, sentence, re.IGNORECASE)
                    for match in matches:
                        # Extract specific captured values (e.g. dose amount and unit) if present
                        extracted_values = {}
                        if match.groups():
                            for i, group in enumerate(match.groups()):
                                extracted_values[f'value_{i}'] = group

                        claim = MedicalClaim(
                            text=match.group(),
                            claim_type=claim_type,
                            context=sentence,
                            confidence=self._calculate_claim_confidence(match.group(), sentence),
                            citation_required=self._requires_citation(claim_type),
                            extracted_values=extracted_values
                        )
                        claims.append(claim)

        # Remove duplicate claims (same text and type), keeping the first occurrence
        unique_claims = []
        seen_claims = set()
        for claim in claims:
            claim_key = (claim.text.lower(), claim.claim_type)
            if claim_key not in seen_claims:
                unique_claims.append(claim)
                seen_claims.add(claim_key)

        self.logger.info(f"Extracted {len(unique_claims)} medical claims for verification")
        return unique_claims

    def verify_claim_against_context(self, claim: MedicalClaim,
                                     provided_context: List[str]) -> VerificationResult:
        """
        Verify a medical claim against the provided source documents
        """
        supporting_sources = []
        verification_details = []
        best_match_score = 0.0

        # Check each context document for supporting evidence
        for source_idx, context_doc in enumerate(provided_context):
            context_lower = context_doc.lower()
            claim_text_lower = claim.text.lower()

            # Direct text match
            if claim_text_lower in context_lower:
                supporting_sources.append(f"Document_{source_idx + 1}")
                verification_details.append("Exact match found in source document")
                best_match_score = max(best_match_score, 1.0)
                continue

            # Semantic verification for specific claim types
            if claim.claim_type == MedicalClaimType.DOSAGE:
                score = self._verify_dosage_claim(claim, context_doc)
                if score > 0.7:
                    supporting_sources.append(f"Document_{source_idx + 1}")
                    verification_details.append(f"Dosage information supported (confidence: {score:.2f})")
                    best_match_score = max(best_match_score, score)
            elif claim.claim_type == MedicalClaimType.MEDICATION:
                score = self._verify_medication_claim(claim, context_doc)
                if score > 0.8:
                    supporting_sources.append(f"Document_{source_idx + 1}")
                    verification_details.append(f"Medication information supported (confidence: {score:.2f})")
                    best_match_score = max(best_match_score, score)
            elif claim.claim_type == MedicalClaimType.PROCEDURE:
                score = self._verify_procedure_claim(claim, context_doc)
                if score > 0.7:
                    supporting_sources.append(f"Document_{source_idx + 1}")
                    verification_details.append(f"Procedure information supported (confidence: {score:.2f})")
                    best_match_score = max(best_match_score, score)

        # Determine verification status
        if best_match_score >= 0.9:
            status = VerificationStatus.VERIFIED
        elif best_match_score >= 0.6:
            status = VerificationStatus.PARTIAL_MATCH
        elif len(supporting_sources) == 0:
            status = VerificationStatus.NOT_FOUND
        else:
            status = VerificationStatus.INSUFFICIENT_CONTEXT

        return VerificationResult(
            claim=claim,
            status=status,
            supporting_sources=supporting_sources,
            confidence_score=best_match_score,
            verification_details=("; ".join(verification_details) if verification_details
                                  else "No supporting evidence found"),
            suggested_correction=self._generate_correction_suggestion(claim, status)
        )
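
    # Illustrative arithmetic for the dosage scorer below: if both captured values of
    # a claim (e.g. "4" and "g") are found verbatim in the context, the value matches
    # alone contribute 0.8, and each of the six dosage keywords found adds a further
    # 0.1 (capped at 1.0 overall). Keyword hits alone top out at 0.6, which stays
    # below the 0.7 acceptance threshold used in verify_claim_against_context.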
    def _verify_dosage_claim(self, claim: MedicalClaim, context: str) -> float:
        """Verify dosage claims against context"""
        confidence = 0.0

        # Credit each extracted value (dose amount, unit) found verbatim in the context
        if claim.extracted_values:
            for value in claim.extracted_values.values():
                if re.search(rf'\b{re.escape(value)}\b', context, re.IGNORECASE):
                    confidence += 0.4

        # Check for dosage-related keywords in the context
        dosage_keywords = ['dose', 'administer', 'give', 'mg', 'g', 'units']
        for keyword in dosage_keywords:
            if keyword in context.lower():
                confidence += 0.1

        return min(confidence, 1.0)

    def _verify_medication_claim(self, claim: MedicalClaim, context: str) -> float:
        """Verify medication claims against context"""
        medication_name = claim.text.lower()
        context_lower = context.lower()

        # Check for the exact medication name
        if medication_name in context_lower:
            return 1.0

        # Check for common medication aliases
        medication_aliases = {
            'mgso4': 'magnesium sulfate',
            'magnesium sulfate': 'mgso4',
            'bp': 'blood pressure'
        }
        for alias, full_name in medication_aliases.items():
            if medication_name == alias and full_name in context_lower:
                return 0.9
            elif medication_name == full_name and alias in context_lower:
                return 0.9

        return 0.0

    def _verify_procedure_claim(self, claim: MedicalClaim, context: str) -> float:
        """Verify procedure claims against context"""
        procedure_name = claim.text.lower()
        context_lower = context.lower()

        if procedure_name in context_lower:
            return 1.0

        # Check for procedure synonyms
        procedure_synonyms = {
            'c-section': 'cesarean section',
            'cesarean section': 'c-section',
            'iv access': 'intravenous access'
        }
        for synonym, standard_name in procedure_synonyms.items():
            if procedure_name == synonym and standard_name in context_lower:
                return 0.9

        return 0.0
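
    # Note on the helpers above: an alias or synonym match scores 0.9, which meets the
    # VERIFIED threshold, so a response that says "MgSO4" can be verified against a
    # guideline that spells out "magnesium sulfate". The method below runs extraction
    # and per-claim verification end to end and aggregates the outcome.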
    def verify_medical_response(self, response: str,
                                provided_context: List[str]) -> MedicalResponseVerification:
        """
        Comprehensive verification of a medical response against the provided context
        """
        self.logger.info("Starting comprehensive medical response verification")

        # Extract all medical claims from the response
        medical_claims = self.extract_medical_claims(response)

        # Verify each claim against the provided context
        verification_results = []
        verified_count = 0
        failed_verifications = []
        safety_warnings = []

        for claim in medical_claims:
            result = self.verify_claim_against_context(claim, provided_context)
            verification_results.append(result)

            if result.status == VerificationStatus.VERIFIED:
                verified_count += 1
            else:
                failed_verifications.append(result)
                # Generate safety warnings for critical failures
                if claim.claim_type in [MedicalClaimType.DOSAGE, MedicalClaimType.MEDICATION,
                                        MedicalClaimType.CONTRAINDICATION]:
                    safety_warnings.append(
                        f"CRITICAL: {claim.claim_type.value} claim not verified - '{claim.text}'"
                    )

        # Calculate the overall verification score
        total_claims = len(medical_claims)
        verification_score = (verified_count / total_claims) if total_claims > 0 else 1.0

        # Determine whether the response is safe for medical use
        is_safe = verification_score >= 0.9 and len(safety_warnings) == 0

        verification_result = MedicalResponseVerification(
            original_response=response,
            total_claims=total_claims,
            verified_claims=verified_count,
            failed_verifications=failed_verifications,
            verification_score=verification_score,
            is_safe_for_medical_use=is_safe,
            detailed_results=verification_results,
            safety_warnings=safety_warnings
        )

        self.logger.info(f"Medical verification complete: {verified_count}/{total_claims} claims verified "
                         f"(Score: {verification_score:.1%}, Safe: {is_safe})")
        return verification_result
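
    # Illustrative arithmetic for the heuristic below: every claim starts at 0.5, gains
    # 0.2 if its text contains a number, and a further 0.2 if the surrounding sentence
    # mentions a clinical indicator word, so a numeric dosage claim in a sentence
    # containing "administer" is assigned a confidence of 0.9.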
    def _calculate_claim_confidence(self, claim_text: str, context: str) -> float:
        """Calculate a confidence score for an extracted medical claim"""
        confidence = 0.5

        # Higher confidence for claims with specific numerical values
        if re.search(r'\d+', claim_text):
            confidence += 0.2

        # Higher confidence for claims that appear in a clinical context
        clinical_indicators = ['patient', 'treatment', 'administer', 'protocol', 'guideline']
        if any(indicator in context.lower() for indicator in clinical_indicators):
            confidence += 0.2

        return min(confidence, 1.0)

    def _requires_citation(self, claim_type: MedicalClaimType) -> bool:
        """Determine whether a claim type requires citation"""
        critical_types = [
            MedicalClaimType.DOSAGE,
            MedicalClaimType.MEDICATION,
            MedicalClaimType.CONTRAINDICATION,
            MedicalClaimType.PROTOCOL
        ]
        return claim_type in critical_types

    def _generate_correction_suggestion(self, claim: MedicalClaim,
                                        status: VerificationStatus) -> Optional[str]:
        """Generate correction suggestions for unverified claims"""
        if status == VerificationStatus.NOT_FOUND:
            return f"Remove claim '{claim.text}' - not supported by provided guidelines"
        elif status == VerificationStatus.INSUFFICIENT_CONTEXT:
            return (f"Add qualification: 'Based on available guidelines, {claim.text.lower()}' "
                    f"or remove if not essential")
        return None
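

# A minimal integration sketch: it assumes the surrounding RAG pipeline can provide
# the retrieved guideline chunks as plain strings. The function name and the fallback
# wording below are illustrative only.
def gate_generated_response(response: str, retrieved_chunks: List[str]) -> str:
    """Return the response unchanged if it passes verification, otherwise a refusal."""
    verifier = MedicalResponseVerifier()
    verification = verifier.verify_medical_response(response, retrieved_chunks)
    if verification.is_safe_for_medical_use:
        return response
    # Surface which claims failed instead of returning a potentially unsafe answer.
    flagged = ", ".join(result.claim.text for result in verification.failed_verifications)
    return ("The generated answer could not be fully verified against the provided "
            f"guidelines (unverified claims: {flagged}) and has been withheld.")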


def test_medical_response_verifier():
    """Test the medical response verification system"""
    print("Testing Medical Response Verification System")

    # Test medical response from the LLM
    test_response = """
    For preeclampsia management, administer magnesium sulfate 4g IV bolus for seizure prophylaxis.
    Control blood pressure with methyldopa 250mg orally every 8 hours.
    Monitor vital signs including blood pressure ≥140/90 mmHg.
    This medication is contraindicated in patients with myasthenia gravis.
    Alternative treatment includes nifedipine 10mg sublingually, though this is not mentioned in current guidelines.
    """

    # Provided context from the Sri Lankan guidelines
    test_context = [
        """
        Preeclampsia Management Protocol:
        - Administer magnesium sulfate (MgSO4) 4g IV bolus for seizure prophylaxis
        - Control BP with methyldopa 250mg orally every 8 hours
        - Monitor blood pressure ≥140/90 mmHg
        - Contraindicated: magnesium sulfate is contraindicated in myasthenia gravis
        """,
        """
        Additional clinical guidelines for severe preeclampsia:
        - Immediate delivery considerations for severe cases
        - Laboratory monitoring requirements
        - Multidisciplinary team involvement
        """
    ]

    verifier = MedicalResponseVerifier()

    # Perform comprehensive verification
    verification = verifier.verify_medical_response(test_response, test_context)

    print("\nVerification Results:")
    print(f"  Total Claims: {verification.total_claims}")
    print(f"  Verified Claims: {verification.verified_claims}")
    print(f"  Verification Score: {verification.verification_score:.1%}")
    print(f"  Safe for Medical Use: {verification.is_safe_for_medical_use}")

    print("\nDetailed Results:")
    for result in verification.detailed_results:
        status_marker = "PASS" if result.status == VerificationStatus.VERIFIED else "FAIL"
        print(f"  [{status_marker}] {result.claim.text} ({result.claim.claim_type.value})")
        print(f"      Status: {result.status.value} | Confidence: {result.confidence_score:.2f}")
        if result.verification_details:
            print(f"      Details: {result.verification_details}")

    if verification.safety_warnings:
        print("\nSafety Warnings:")
        for warning in verification.safety_warnings:
            print(f"  - {warning}")

    print("\nMedical Response Verification Test Completed")


if __name__ == "__main__":
    test_medical_response_verifier()