Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| LangChain Medical Agents Architecture - Refactored | |
| A multi-agent system for processing medical transcriptions and documents. | |
| """ | |
| import os | |
| import re | |
| from datetime import datetime | |
| from dotenv import load_dotenv | |
| from langchain_openai import AzureChatOpenAI | |
| from langchain.agents import AgentExecutor, create_openai_tools_agent | |
| from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder | |
| # Import modular components | |
| from models import TemplateAnalysis, MedicalTranscription, SectionContent, InsertSectionsInput | |
| from sftp_agent import create_sftp_downloader_agent, download_model_from_sftp | |
| from template_analyzer import create_template_analyzer_agent, analyze_word_template | |
| from transcription_processor import ( | |
| create_transcription_corrector_chain, | |
| create_medical_analyzer_chain, | |
| create_title_generator_chain, | |
| load_transcription_with_user_id | |
| ) | |
| from section_generator import create_dynamic_section_prompt, fix_section_names | |
| from document_assembler import create_document_assembler_agent | |
| from document_validator import validate_generated_document, create_validation_chain | |
| # Load environment variables | |
| load_dotenv() | |
# Initialize LLM with Azure OpenAI.
# SECURITY FIX: the API key was previously hard-coded in source, which leaks
# the credential into version control. It is now read from the environment
# (populated from .env by the load_dotenv() call above). Deployment name,
# API version and endpoint stay overridable the same way, with the previous
# literals kept as defaults for backward compatibility.
llm = AzureChatOpenAI(
    azure_deployment=os.getenv("AZURE_OPENAI_DEPLOYMENT", "gtp-4o-eastus2"),
    openai_api_version=os.getenv("AZURE_OPENAI_API_VERSION", "2024-02-15-preview"),
    azure_endpoint=os.getenv(
        "AZURE_OPENAI_ENDPOINT", "https://voxist-gpt-eastus2.openai.azure.com/"),
    api_key=os.environ["AZURE_OPENAI_API_KEY"],  # fail fast when not configured
    temperature=0.1  # low temperature: near-deterministic medical text
)
class MedicalDocumentOrchestrator:
    """Main orchestrator that coordinates all agents.

    Turns a raw medical transcription plus a Word template into a
    validated .docx document. The pipeline stages (see
    :meth:`run_full_pipeline`) are:

      0. Download the template model matching the transcription via SFTP.
      1. Analyze the Word template structure.
      2. Load and correct the transcription with the LLM.
      3. Extract structured medical data.
      4. Generate a document title.
      5. Generate the per-section content.
      6. Assemble the final document.
      7. Validate the generated document and print a report.
    """

    def __init__(self, template_path: str = None, transcription_path: str = None,
                 transcriptions_dir: str = "transcriptions"):
        # Inputs. template_path doubles as a fallback when the SFTP
        # download in step 0 fails; on success it is replaced by the
        # downloaded model's local path.
        self.template_path = template_path
        self.transcription_path = transcription_path
        self.transcriptions_dir = transcriptions_dir
        # Intermediate artifacts, populated as the pipeline progresses.
        self.template_analysis = None
        self.corrected_transcription = None
        self.medical_data = None
        self.generated_sections = None
        self.generated_title = None
        self.downloaded_models = None

    def run_full_pipeline(self, output_path: str = None) -> str:
        """Run the complete medical document processing pipeline.

        Args:
            output_path: Destination .docx path. When None, a filename is
                derived from the transcription's user id.

        Returns:
            The path of the generated document.

        Raises:
            Exception: when the SFTP download fails and no fallback
                template_path was supplied at construction time.
        """
        print("π Starting LangChain Medical Document Pipeline...")

        self._download_required_model()           # Step 0
        self._analyze_template()                  # Step 1
        user_id = self._correct_transcription()   # Step 2
        self._analyze_medical_data()              # Step 3
        self._generate_title()                    # Step 4
        self._generate_sections()                 # Step 5

        if output_path is None:
            output_path = self._default_output_path(user_id)
        self._assemble_document(output_path)      # Step 6
        self._validate_document(output_path)      # Step 7
        self._cleanup_model()
        return output_path

    def _download_required_model(self):
        """Step 0: download only the model matching the transcription.

        The model identifier is parsed from the transcription filename
        (pattern ``transcriptions_<id>.rtf_``). On any failure the
        pipeline falls back to the template_path given at construction
        time, or aborts when none was provided.
        """
        print("\nπ₯ Step 0: Downloading model from SFTP for the selected transcription...")
        try:
            transcription_filename = os.path.basename(self.transcription_path)
            match = re.search(r'transcriptions_(.+)\.rtf_',
                              transcription_filename)
            if not match:
                raise ValueError(
                    "Unable to extract the model identifier from the transcription filename.")

            model_id = match.group(1)
            model_filename = f"{model_id}.rtf"
            local_filename = f"{model_id}.doc"
            local_template_path = os.path.join("models", local_filename)
            print(f"π Model identifier for this transcription: {model_id}")

            # A minimal single-tool agent performs the actual download.
            simple_sftp_agent = create_openai_tools_agent(
                llm=llm,
                tools=[download_model_from_sftp],
                prompt=ChatPromptTemplate.from_messages([
                    ("system", "You are an SFTP downloader. Download the specified model file."),
                    ("human", "Download the model file: {model_filename}"),
                    MessagesPlaceholder("agent_scratchpad")
                ])
            )
            simple_sftp_executor = AgentExecutor(
                agent=simple_sftp_agent,
                tools=[download_model_from_sftp],
                verbose=True
            )
            simple_sftp_executor.invoke({"model_filename": model_filename})

            print(
                f"β Model downloaded and available as: {local_template_path}")
            self.template_path = local_template_path
            self.downloaded_models = [{
                'model_id': model_id,
                'model_filename': model_filename,
                'local_filename': local_filename,
                'local_path': local_template_path,
                'status': 'success'
            }]
        except Exception as e:
            print(f"β Error during SFTP download step: {str(e)}")
            if not self.template_path:
                print(
                    "β No template path provided and SFTP download failed. Cannot continue.")
                raise Exception(
                    "Cannot continue without a template. SFTP download failed and no template path was provided.")
            print("β οΈ Continuing with pipeline using the provided template_path...")
            self.downloaded_models = []

    def _analyze_template(self):
        """Step 1: parse the Word template and record its structure."""
        print("\nπ Step 1: Analyzing template...")
        if not self.template_path:
            raise ValueError("No template path available for analysis")
        self.template_analysis = analyze_word_template(self.template_path)
        print(
            f"β Template analyzed: {len(self.template_analysis.get('sections', []))} sections found")

    def _correct_transcription(self) -> str:
        """Step 2: load the transcription and correct it with the LLM.

        Returns:
            The user id extracted from the transcription file (used later
            to derive the default output filename).
        """
        print("\nβοΈ Step 2: Correcting transcription...")
        raw_transcription, user_id = load_transcription_with_user_id(
            self.transcription_path)
        transcription_corrector_chain = create_transcription_corrector_chain(
            llm)
        self.corrected_transcription = transcription_corrector_chain.invoke({
            "transcription": raw_transcription
        }).content
        print("\n===== Transcription après correction =====")
        print(self.corrected_transcription)
        # Fix: the success message was printed twice in the original code.
        print("β Transcription corrected")
        return user_id

    def _analyze_medical_data(self):
        """Step 3: extract structured medical data from the corrected text."""
        print("\n㪠Step 3: Analyzing medical data...")
        medical_analyzer_chain = create_medical_analyzer_chain(llm)
        self.medical_data = medical_analyzer_chain.invoke({
            "corrected_transcription": self.corrected_transcription
        }).content
        print("β Medical data analyzed")

    def _generate_title(self):
        """Step 4: generate the document title from the medical data."""
        print("\nπ Step 4: Generating title...")
        title_generator_chain = create_title_generator_chain(llm)
        self.generated_title = title_generator_chain.invoke({
            "medical_data": self.medical_data
        }).content
        print(f"β Title generated: {self.generated_title}")

    def _extract_template_sections(self) -> list:
        """Identify the section headings from the template analysis.

        Tries, in order: a dict-like ``'sections'`` entry, an agent-style
        ``.output`` dict, then a heuristic scan of the raw response text
        for known headings.
        """
        print("\n--- DEBUG: Type and content of template_analysis ---")
        print(f"Type: {type(self.template_analysis)}")
        print(f"Content: {self.template_analysis}")
        if hasattr(self.template_analysis, '__dict__'):
            print(f"Attributes: {self.template_analysis.__dict__}")
        print("--- END DEBUG ---\n")

        template_sections = []
        analysis = self.template_analysis
        try:
            # A dict and any dict-like object are handled by the same
            # branch (the original had a redundant isinstance(dict) case).
            if hasattr(analysis, 'get') and 'sections' in analysis:
                template_sections = [section['text']
                                     for section in analysis['sections']]
            elif hasattr(analysis, 'output') and isinstance(analysis.output, dict) and 'sections' in analysis.output:
                template_sections = [section['text']
                                     for section in analysis.output['sections']]
        except Exception as e:
            print('Error extracting sections:', e)

        if not template_sections:
            # Fallback: look for known section headings in the raw text.
            response_text = str(self.template_analysis)
            if 'Technique' in response_text and 'RΓ©sultat' in response_text and 'Conclusion' in response_text:
                template_sections = ['Technique\xa0:',
                                     'RΓ©sultat\xa0:', 'Conclusion\xa0:']
            elif 'CONCLUSION' in response_text:
                template_sections = ['CONCLUSION\xa0:']
        return template_sections

    def _generate_sections(self):
        """Step 5: generate the content of each template section."""
        print("\nπ Step 5: Generating sections...")
        template_sections = self._extract_template_sections()

        # Build a prompt tailored to the exact sections of this template.
        dynamic_section_prompt = create_dynamic_section_prompt(
            template_sections)
        section_generator_chain = dynamic_section_prompt | llm
        generated_content = section_generator_chain.invoke({
            "template_sections": template_sections,
            "medical_data": self.medical_data,
            "corrected_transcription": self.corrected_transcription
        }).content

        # Post-process to ensure exact section names are used.
        self.generated_sections = fix_section_names(
            generated_content, template_sections)

        print("\n--- DEBUG: Generated sections ---")
        print(self.generated_sections)
        print("--- END DEBUG ---\n")
        print("\n--- DEBUG: Template sections ---")
        print(template_sections)
        print("--- END DEBUG ---\n")
        print("\n--- DEBUG: Generated title ---")
        print(self.generated_title)
        print("--- END DEBUG ---\n")

    @staticmethod
    def _default_output_path(user_id: str) -> str:
        """Derive the output filename: user id with its last extension
        replaced by .docx (or .docx appended when there is none)."""
        if '.' in user_id:
            parts = user_id.split('.')
            parts[-1] = 'docx'
            return '.'.join(parts)
        return f"{user_id}.docx"

    def _assemble_document(self, output_path: str):
        """Step 6: assemble the final document via the assembler agent."""
        print("\nπ Step 6: Assembling document...")
        document_assembler_executor = create_document_assembler_agent(llm)
        document_assembler_executor.invoke({
            "template_path": self.template_path,
            "sections_text": self.generated_sections,
            "title": self.generated_title,
            "output_path": output_path
        })
        print(f"π Pipeline completed! Document saved: {output_path}")

    def _validate_document(self, output_path: str):
        """Step 7: run structural and AI validation, printing a report."""
        print("\nπ Step 7: Validating document...")
        validation_result = validate_generated_document(
            self.template_path, self.transcription_path, output_path)

        print("\n" + "=" * 60)
        print("π VALIDATION RESULTS")
        print("=" * 60)

        # Overall score, colour-coded by threshold.
        score = validation_result["overall_score"]
        score_emoji = "π’" if score >= 0.8 else "π‘" if score >= 0.6 else "π΄"
        print(f"{score_emoji} Overall Score: {score:.1%}")

        # Structure validation.
        structure_valid = validation_result["structure_valid"]
        structure_emoji = "β " if structure_valid else "β"
        print(f"{structure_emoji} Structure Valid: {structure_valid}")
        if not structure_valid:
            missing = validation_result["missing_sections"]
            print(f" Missing sections: {', '.join(missing)}")

        # Medical-entity coverage (percentage, 80% threshold).
        entities_coverage = validation_result["entities_coverage"]
        entities_emoji = "β " if entities_coverage >= 80 else "β οΈ"
        print(f"{entities_emoji} Medical Entities Coverage: {entities_coverage:.1f}%")
        if entities_coverage < 80:
            missing_entities = validation_result["missing_entities"][:5]
            print(f" Missing entities: {', '.join(missing_entities)}")

        print("\nπ AI Validation Report:")
        print("-" * 40)

        # Local import: python-docx is only needed at validation time.
        from docx import Document
        doc = Document(output_path)
        generated_content = []
        for paragraph in doc.paragraphs:
            text = paragraph.text.strip()
            # Skip date/time boilerplate so the AI judges content only.
            if text and not text.startswith(("Date:", "Heure:")):
                generated_content.append(text)
        generated_text = "\n".join(generated_content)

        validation_chain = create_validation_chain(llm)
        ai_validation = validation_chain.invoke({
            "transcription": self.corrected_transcription,
            "generated_content": generated_text,
            "structure_valid": structure_valid,
            "entities_coverage": entities_coverage,
            "missing_sections": validation_result["missing_sections"],
            "missing_entities": validation_result["missing_entities"]
        })
        print(ai_validation.content)
        print("\n" + "=" * 60)
        print("β Document validated")

    def _cleanup_model(self):
        """Best-effort removal of the locally downloaded model file."""
        try:
            if self.template_path and os.path.exists(self.template_path):
                os.remove(self.template_path)
                print(f"ποΈ Deleted local model file: {self.template_path}")
        except OSError as e:
            # Narrowed from Exception: only filesystem errors are expected.
            print(f"β οΈ Could not delete local model file: {e}")
def main():
    """Entry point: build the orchestrator and run the full pipeline."""
    banner = "=" * 60
    print("π₯ LangChain Medical Document Agents - Refactored")
    print(banner)

    default_template = "default.528.251014072.doc"
    # NOTE(review): "transciption.txt" looks like a typo for
    # "transcription.txt" — confirm against the actual file on disk.
    default_transcription = "transciption.txt"

    pipeline = MedicalDocumentOrchestrator(
        template_path=default_template,
        transcription_path=default_transcription,
    )
    result_path = pipeline.run_full_pipeline()

    print(f"\nβ Final document: {result_path}")
    print("π LangChain pipeline completed successfully!")
| if __name__ == "__main__": | |
| main() | |