# medical-agent / Create_db_optimized.py
import os
import re
import json
import numpy as np
from typing import List, Dict, Any, Optional, Tuple, Union
from dataclasses import dataclass
from pathlib import Path
# Core libraries
import torch
from transformers import (
AutoTokenizer, AutoModel, AutoModelForTokenClassification,
TrainingArguments, Trainer, pipeline
)
from torch.utils.data import Dataset
import torch.nn.functional as F
# Vector database
import chromadb
from chromadb.config import Settings
# Utilities
import logging
from tqdm import tqdm
import pandas as pd
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@dataclass
class MedicalEntity:
"""Structure pour les entités médicales extraites par NER"""
exam_types: List[Tuple[str, float]] # (entity, confidence)
specialties: List[Tuple[str, float]]
anatomical_regions: List[Tuple[str, float]]
pathologies: List[Tuple[str, float]]
medical_procedures: List[Tuple[str, float]]
measurements: List[Tuple[str, float]]
medications: List[Tuple[str, float]]
symptoms: List[Tuple[str, float]]
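
# Example shape of one extracted entity set (hypothetical values, shown for
# illustration only):
#   MedicalEntity(exam_types=[("échographie doppler", 0.94)],
#                 specialties=[("angiologie", 0.88)],
#                 anatomical_regions=[("membres inférieurs", 0.91)],
#                 pathologies=[], medical_procedures=[], measurements=[],
#                 medications=[], symptoms=[])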
class AdvancedMedicalNER:
"""NER médical avancé basé sur CamemBERT-Bio fine-tuné"""
def __init__(self, model_name: str = "auto", cache_dir: str = "./models_cache"):
self.cache_dir = Path(cache_dir)
self.cache_dir.mkdir(exist_ok=True)
        # Auto-select the best available medical NER model
self.model_name = self._select_best_model(model_name)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Load the NER model
self._load_ner_model()
        # BIO labels for medical entities
        self.entity_labels = [
            "O",                              # Outside
            "B-EXAM", "I-EXAM",               # Exam types
            "B-SPECIALTY", "I-SPECIALTY",     # Medical specialties
            "B-ANATOMY", "I-ANATOMY",         # Anatomical regions
            "B-PATHOLOGY", "I-PATHOLOGY",     # Pathologies
            "B-PROCEDURE", "I-PROCEDURE",     # Medical procedures
            "B-MEASURE", "I-MEASURE",         # Measurements/values
            "B-MEDICATION", "I-MEDICATION",   # Medications
            "B-SYMPTOM", "I-SYMPTOM"          # Symptoms
        ]
self.id2label = {i: label for i, label in enumerate(self.entity_labels)}
self.label2id = {label: i for i, label in enumerate(self.entity_labels)}
def _select_best_model(self, model_name: str) -> str:
"""Sélection automatique du meilleur modèle NER médical"""
if model_name != "auto":
return model_name
        # Candidate models in order of preference
        preferred_models = [
            "almanach/camembert-bio-base",      # French CamemBERT-Bio
            "Dr-BERT/DrBERT-7GB",               # Specialized DrBERT
            "emilyalsentzer/Bio_ClinicalBERT",  # Bio Clinical BERT
            "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
            "dmis-lab/biobert-base-cased-v1.2", # BioBERT
            "camembert-base"                    # Fallback: standard CamemBERT
        ]
        for model in preferred_models:
            try:
                # Availability check: try to load the tokenizer
                AutoTokenizer.from_pretrained(model, cache_dir=self.cache_dir)
                logger.info(f"Selected model: {model}")
                return model
            except Exception:
                continue
        # Last-resort fallback
        logger.warning("Falling back to the base model camembert-base")
        return "camembert-base"
def _load_ner_model(self):
"""Charge ou crée le modèle NER fine-tuné"""
fine_tuned_path = self.cache_dir / "medical_ner_model"
if fine_tuned_path.exists():
logger.info("Chargement du modèle NER fine-tuné existant")
self.tokenizer = AutoTokenizer.from_pretrained(fine_tuned_path)
self.ner_model = AutoModelForTokenClassification.from_pretrained(fine_tuned_path)
else:
logger.info("Création d'un nouveau modèle NER médical")
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, cache_dir=self.cache_dir)
            # Token-classification (NER) head on top of the base model
self.ner_model = AutoModelForTokenClassification.from_pretrained(
self.model_name,
num_labels=len(self.entity_labels),
id2label=self.id2label,
label2id=self.label2id,
cache_dir=self.cache_dir
)
self.ner_model.to(self.device)
        # NER pipeline
self.ner_pipeline = pipeline(
"token-classification",
model=self.ner_model,
tokenizer=self.tokenizer,
device=0 if torch.cuda.is_available() else -1,
aggregation_strategy="simple"
)
def extract_entities(self, text: str) -> MedicalEntity:
"""Extraction d'entités avec le modèle NER fine-tuné"""
        # NER prediction
try:
ner_results = self.ner_pipeline(text)
except Exception as e:
logger.error(f"Erreur NER: {e}")
return MedicalEntity([], [], [], [], [], [], [], [])
        # Group entities by type
entities = {
"EXAM": [],
"SPECIALTY": [],
"ANATOMY": [],
"PATHOLOGY": [],
"PROCEDURE": [],
"MEASURE": [],
"MEDICATION": [],
"SYMPTOM": []
}
for result in ner_results:
entity_type = result['entity_group'].replace('B-', '').replace('I-', '')
entity_text = result['word']
confidence = result['score']
            if entity_type in entities and confidence > 0.7:  # confidence threshold
entities[entity_type].append((entity_text, confidence))
return MedicalEntity(
exam_types=entities["EXAM"],
specialties=entities["SPECIALTY"],
anatomical_regions=entities["ANATOMY"],
pathologies=entities["PATHOLOGY"],
medical_procedures=entities["PROCEDURE"],
measurements=entities["MEASURE"],
medications=entities["MEDICATION"],
symptoms=entities["SYMPTOM"]
)
def fine_tune_on_templates(self, templates_data: List[Dict],
output_dir: str = None,
epochs: int = 3):
"""Fine-tuning du modèle NER sur des templates médicaux"""
if output_dir is None:
output_dir = self.cache_dir / "medical_ner_model"
logger.info("Début du fine-tuning NER sur templates médicaux")
        # Prepare training data
        # (annotated templates or auto-annotation would be used here)
train_dataset = self._prepare_training_data(templates_data)
        # Training configuration
        training_args = TrainingArguments(
            output_dir=str(output_dir),
            num_train_epochs=epochs,
            per_device_train_batch_size=8,
            per_device_eval_batch_size=8,
            warmup_steps=100,
            weight_decay=0.01,
            logging_dir=f"{output_dir}/logs",
            save_strategy="epoch",
            evaluation_strategy="epoch" if train_dataset.get('eval') else "no",
            # load_best_model_at_end requires an evaluation strategy, so only
            # enable it when an eval split is actually present
            load_best_model_at_end=bool(train_dataset.get('eval')),
            metric_for_best_model="eval_loss" if train_dataset.get('eval') else None,
        )
# Trainer
trainer = Trainer(
model=self.ner_model,
args=training_args,
train_dataset=train_dataset['train'],
eval_dataset=train_dataset.get('eval'),
tokenizer=self.tokenizer,
)
        # Train
trainer.train()
        # Save model and tokenizer
trainer.save_model()
self.tokenizer.save_pretrained(output_dir)
logger.info(f"Fine-tuning terminé, modèle sauvé dans {output_dir}")
def _prepare_training_data(self, templates_data: List[Dict]) -> Dict:
"""Prépare les données d'entraînement pour le NER (auto-annotation intelligente)"""
# Cette fonction pourrait utiliser des techniques d'auto-annotation
# ou des datasets médicaux pré-existants pour créer des labels BIO
# Pour l'exemple, retourner un dataset vide
# En production, on utiliserait des techniques d'annotation automatique
# ou des datasets médicaux annotés comme QUAERO, CAS, etc.
class EmptyDataset(Dataset):
def __len__(self):
return 0
def __getitem__(self, idx):
return {}
return {'train': EmptyDataset()}
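
# Illustrative sketch, not part of the pipeline above: _prepare_training_data
# returns an empty placeholder, and a real implementation needs BIO-labelled
# examples. Assuming pre-annotated examples of the form
# {"tokens": [...], "labels": [...]} (labels drawn from entity_labels) and a
# fast tokenizer (required for word_ids()), a minimal dataset could look like:
class MedicalBIODataset(Dataset):
    def __init__(self, examples: List[Dict], tokenizer, label2id: Dict[str, int],
                 max_length: int = 256):
        self.examples = examples
        self.tokenizer = tokenizer
        self.label2id = label2id
        self.max_length = max_length

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        example = self.examples[idx]
        encoding = self.tokenizer(
            example["tokens"],
            is_split_into_words=True,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt"
        )
        # Align word-level BIO labels to sub-word tokens; -100 masks special
        # and padding positions so the loss ignores them.
        labels = [
            self.label2id[example["labels"][word_id]] if word_id is not None else -100
            for word_id in encoding.word_ids()
        ]
        item = {key: value.squeeze(0) for key, value in encoding.items()}
        item["labels"] = torch.tensor(labels)
        return item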
class AdvancedMedicalEmbedding:
"""Générateur d'embeddings médicaux avancés avec cross-encoder reranking"""
def __init__(self,
base_model: str = "almanach/camembert-bio-base",
cross_encoder_model: str = "auto"):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.base_model_name = base_model
        # Main model for embeddings
self._load_base_model()
        # Cross-encoder for reranking
self._load_cross_encoder(cross_encoder_model)
def _load_base_model(self):
"""Charge le modèle de base pour les embeddings"""
try:
self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name)
self.base_model = AutoModel.from_pretrained(self.base_model_name)
self.base_model.to(self.device)
logger.info(f"Modèle de base chargé: {self.base_model_name}")
except Exception as e:
logger.error(f"Erreur chargement modèle de base: {e}")
raise
def _load_cross_encoder(self, model_name: str):
"""Charge le cross-encoder pour reranking"""
if model_name == "auto":
# Sélection automatique du meilleur cross-encoder médical
cross_encoders = [
"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
"emilyalsentzer/Bio_ClinicalBERT",
self.base_model_name # Fallback
]
            for model in cross_encoders:
                try:
                    self.cross_tokenizer = AutoTokenizer.from_pretrained(model)
                    self.cross_model = AutoModel.from_pretrained(model)
                    self.cross_model.to(self.device)
                    logger.info(f"Cross-encoder loaded: {model}")
                    break
                except Exception:
                    continue
else:
self.cross_tokenizer = AutoTokenizer.from_pretrained(model_name)
self.cross_model = AutoModel.from_pretrained(model_name)
self.cross_model.to(self.device)
def generate_embedding(self, text: str, entities: MedicalEntity = None) -> np.ndarray:
"""Génère un embedding enrichi pour un texte médical"""
        # Tokenization
inputs = self.tokenizer(
text,
padding=True,
truncation=True,
max_length=512,
return_tensors="pt"
).to(self.device)
        # Generate the embedding
with torch.no_grad():
outputs = self.base_model(**inputs)
# Mean pooling
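            # (sum of token embeddings weighted by the attention mask, divided
            # by the count of non-padding tokens, clamped to avoid dividing by
            # zero)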
attention_mask = inputs['attention_mask']
token_embeddings = outputs.last_hidden_state
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
embedding = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        # Enrich with NER entities
if entities:
embedding = self._enrich_with_ner_entities(embedding, entities)
return embedding.cpu().numpy().flatten().astype(np.float32)
def _enrich_with_ner_entities(self, base_embedding: torch.Tensor, entities: MedicalEntity) -> torch.Tensor:
"""Enrichit l'embedding avec les entités extraites par NER"""
        # Collect the important entities together with their confidence scores
entity_texts = []
confidence_weights = []
for entity_list in [entities.exam_types, entities.specialties,
entities.anatomical_regions, entities.pathologies]:
for entity_text, confidence in entity_list:
entity_texts.append(entity_text)
confidence_weights.append(confidence)
if not entity_texts:
return base_embedding
        # Generate embeddings for the entities
entity_text_combined = " [SEP] ".join(entity_texts)
entity_inputs = self.tokenizer(
entity_text_combined,
padding=True,
truncation=True,
max_length=256,
return_tensors="pt"
).to(self.device)
with torch.no_grad():
entity_outputs = self.base_model(**entity_inputs)
entity_embedding = torch.mean(entity_outputs.last_hidden_state, dim=1)
        # Weighted fusion driven by the confidence scores
        avg_confidence = np.mean(confidence_weights) if confidence_weights else 0.5
        fusion_weight = min(0.4, avg_confidence)  # entity contribution capped at 40%
enriched_embedding = (1 - fusion_weight) * base_embedding + fusion_weight * entity_embedding
return enriched_embedding
def cross_encoder_rerank(self,
query: str,
candidates: List[Dict],
top_k: int = 3) -> List[Dict]:
"""Reranking avec cross-encoder pour affiner la sélection"""
if len(candidates) <= top_k:
return candidates
reranked_candidates = []
for candidate in candidates:
            # Build the query-candidate pair
pair_text = f"{query} [SEP] {candidate['document']}"
            # Tokenization
inputs = self.cross_tokenizer(
pair_text,
padding=True,
truncation=True,
max_length=512,
return_tensors="pt"
).to(self.device)
            # Cross-encoder similarity score
with torch.no_grad():
outputs = self.cross_model(**inputs)
                # Heuristic score: sigmoid over the mean of the [CLS] embedding.
                # The loaded models carry no trained pairwise-scoring head, so
                # this is a proxy for relevance rather than a calibrated score.
                cls_embedding = outputs.last_hidden_state[:, 0, :]
                similarity_score = torch.sigmoid(torch.mean(cls_embedding)).item()
candidate_copy = candidate.copy()
candidate_copy['cross_encoder_score'] = similarity_score
candidate_copy['final_score'] = (
0.6 * candidate['similarity_score'] +
0.4 * similarity_score
)
reranked_candidates.append(candidate_copy)
        # Sort by final score
reranked_candidates.sort(key=lambda x: x['final_score'], reverse=True)
return reranked_candidates[:top_k]
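
# Illustrative sketch, not wired into the pipeline: how the two stages above
# compose. The candidate dicts are assumed to carry the 'document' and
# 'similarity_score' keys produced by MedicalTemplateVectorDB.advanced_search
# (defined below).
def _example_embed_and_rerank(query: str, candidates: List[Dict]) -> Tuple[np.ndarray, List[Dict]]:
    embedder = AdvancedMedicalEmbedding()
    # One dense embedding for the query (NER enrichment omitted here)
    query_embedding = embedder.generate_embedding(query)
    # Cross-encoder pass over the shortlist
    reranked = embedder.cross_encoder_rerank(query, candidates, top_k=3)
    return query_embedding, reranked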
class MedicalTemplateVectorDB:
"""Base de données vectorielle optimisée pour templates médicaux"""
def __init__(self, db_path: str = "./medical_vector_db", collection_name: str = "medical_templates"):
self.db_path = db_path
self.collection_name = collection_name
        # ChromaDB with an optimized configuration
self.client = chromadb.PersistentClient(
path=db_path,
settings=Settings(
anonymized_telemetry=False,
allow_reset=True
)
)
        # Collection with an optimized distance metric
        try:
            self.collection = self.client.get_collection(collection_name)
            logger.info(f"Collection '{collection_name}' loaded")
        except Exception:
            self.collection = self.client.create_collection(
                name=collection_name,
                metadata={
                    "hnsw:space": "cosine",
                    "hnsw:M": 32,                 # graph connectivity
                    "hnsw:ef_construction": 200,  # build quality vs. speed
                    "hnsw:ef_search": 50          # search quality vs. speed
                }
            )
            logger.info(f"Collection '{collection_name}' created with HNSW optimizations")
def add_template(self,
template_id: str,
template_text: str,
embedding: np.ndarray,
entities: MedicalEntity,
metadata: Dict[str, Any] = None):
"""Ajoute un template avec métadonnées enrichies par NER"""
# Métadonnées automatiques basées sur NER
auto_metadata = {
"exam_types": [entity[0] for entity in entities.exam_types],
"specialties": [entity[0] for entity in entities.specialties],
"anatomical_regions": [entity[0] for entity in entities.anatomical_regions],
"pathologies": [entity[0] for entity in entities.pathologies],
"procedures": [entity[0] for entity in entities.medical_procedures],
"text_length": len(template_text),
"entity_confidence_avg": np.mean([
entity[1] for entity_list in [
entities.exam_types, entities.specialties,
entities.anatomical_regions, entities.pathologies
] for entity in entity_list
]) if any([entities.exam_types, entities.specialties,
entities.anatomical_regions, entities.pathologies]) else 0.0
}
if metadata:
auto_metadata.update(metadata)
self.collection.add(
embeddings=[embedding.tolist()],
documents=[template_text],
metadatas=[auto_metadata],
ids=[template_id]
)
logger.info(f"Template {template_id} ajouté avec métadonnées NER automatiques")
def advanced_search(self,
query_embedding: np.ndarray,
n_results: int = 10,
entity_filters: Dict[str, List[str]] = None,
confidence_threshold: float = 0.0) -> List[Dict]:
"""Recherche avancée avec filtres basés sur entités NER"""
where_clause = {}
# Filtres basés sur entités NER extraites
if entity_filters:
for entity_type, entity_values in entity_filters.items():
if entity_values:
where_clause[entity_type] = {"$in": entity_values}
# Filtre par confiance moyenne des entités
if confidence_threshold > 0:
where_clause["entity_confidence_avg"] = {"$gte": confidence_threshold}
results = self.collection.query(
query_embeddings=[query_embedding.tolist()],
n_results=n_results,
where=where_clause if where_clause else None,
include=["documents", "metadatas", "distances"]
)
# Formatage des résultats
formatted_results = []
for i in range(len(results['ids'][0])):
formatted_results.append({
'id': results['ids'][0][i],
'document': results['documents'][0][i],
'metadata': results['metadatas'][0][i],
'similarity_score': 1 - results['distances'][0][i],
'distance': results['distances'][0][i]
})
return formatted_results
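
# Illustrative sketch, not wired into the pipeline: a minimal round trip
# through the vector store. The random 768-dim vector (CamemBERT's hidden
# size) and the empty MedicalEntity stand in for a real embedding and real
# NER output.
def _example_vector_db_round_trip() -> List[Dict]:
    db = MedicalTemplateVectorDB(db_path="./demo_vector_db")
    no_entities = MedicalEntity([], [], [], [], [], [], [], [])
    rng = np.random.default_rng(0)
    dummy_embedding = rng.standard_normal(768).astype(np.float32)
    db.add_template(
        template_id="demo_001",
        template_text="Échographie doppler des membres inférieurs.",
        embedding=dummy_embedding,
        entities=no_entities,
        metadata={"source": "demo"}
    )
    # Without entity filters this reduces to a pure vector-similarity query
    return db.advanced_search(query_embedding=dummy_embedding, n_results=1)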
class AdvancedMedicalTemplateProcessor:
"""Processeur avancé avec NER fine-tuné et reranking cross-encoder"""
def __init__(self,
base_model: str = "almanach/camembert-bio-base",
db_path: str = "./advanced_medical_vector_db"):
self.ner_extractor = AdvancedMedicalNER()
self.embedding_generator = AdvancedMedicalEmbedding(base_model)
self.vector_db = MedicalTemplateVectorDB(db_path)
logger.info("Processeur médical avancé initialisé avec NER fine-tuné et cross-encoder reranking")
def process_templates_batch(self,
templates: List[Dict[str, str]],
batch_size: int = 8,
fine_tune_ner: bool = False) -> None:
"""Traitement avancé avec option de fine-tuning NER"""
if fine_tune_ner:
logger.info("Fine-tuning du modèle NER sur les templates...")
self.ner_extractor.fine_tune_on_templates(templates)
logger.info(f"Traitement avancé de {len(templates)} templates")
for i in tqdm(range(0, len(templates), batch_size), desc="Traitement avancé"):
batch = templates[i:i+batch_size]
for template in batch:
try:
template_id = template['id']
template_text = template['text']
metadata = template.get('metadata', {})
                    # Advanced NER extraction
entities = self.ner_extractor.extract_entities(template_text)
                    # Enriched embedding
embedding = self.embedding_generator.generate_embedding(template_text, entities)
                    # Store with NER metadata
self.vector_db.add_template(
template_id=template_id,
template_text=template_text,
embedding=embedding,
entities=entities,
metadata=metadata
)
except Exception as e:
logger.error(f"Erreur traitement template {template.get('id', 'unknown')}: {e}")
continue
def find_best_template_with_reranking(self,
transcription: str,
initial_candidates: int = 10,
final_results: int = 3) -> List[Dict]:
"""Recherche optimale avec reranking cross-encoder"""
        # 1. NER extraction from the transcription
query_entities = self.ner_extractor.extract_entities(transcription)
        # 2. Enriched query embedding
query_embedding = self.embedding_generator.generate_embedding(transcription, query_entities)
        # 3. Automatic filters from the extracted entities
entity_filters = {}
if query_entities.exam_types:
entity_filters['exam_types'] = [entity[0] for entity in query_entities.exam_types]
if query_entities.specialties:
entity_filters['specialties'] = [entity[0] for entity in query_entities.specialties]
if query_entities.anatomical_regions:
entity_filters['anatomical_regions'] = [entity[0] for entity in query_entities.anatomical_regions]
        # 4. Initial vector search
initial_candidates_results = self.vector_db.advanced_search(
query_embedding=query_embedding,
n_results=initial_candidates,
entity_filters=entity_filters,
confidence_threshold=0.6
)
        # 5. Cross-encoder reranking
if len(initial_candidates_results) > final_results:
final_results_reranked = self.embedding_generator.cross_encoder_rerank(
query=transcription,
candidates=initial_candidates_results,
top_k=final_results
)
else:
final_results_reranked = initial_candidates_results
        # 6. Enrich the results with NER details
for result in final_results_reranked:
result['query_entities'] = {
'exam_types': query_entities.exam_types,
'specialties': query_entities.specialties,
'anatomical_regions': query_entities.anatomical_regions,
'pathologies': query_entities.pathologies
}
return final_results_reranked
# Advanced usage example
def main():
"""Exemple d'utilisation du système avancé"""
    # Initialize the advanced processor
processor = AdvancedMedicalTemplateProcessor()
    # Sample templates to process (fine-tuning is optional)
sample_templates = [
{
'id': 'angio_001',
'text': """Échographie et doppler artério-veineux des membres inférieurs.
Exploration de l'incontinence veineuse superficielle...""",
'metadata': {'source': 'angiologie', 'version': '2024'}
}
]
    # Process the templates (NER fine-tuning disabled here)
processor.process_templates_batch(sample_templates, fine_tune_ner=False)
    # Search with reranking
transcription = """madame bacon nicole bilan œdème droit gonalgies ostéophytes
incontinence veineuse modérée portions surale droite crurale gauche saphéniennes"""
best_matches = processor.find_best_template_with_reranking(
transcription=transcription,
initial_candidates=15,
final_results=3
)
    # Display the results
for i, match in enumerate(best_matches):
print(f"\n=== Match {i+1} ===")
print(f"Template ID: {match['id']}")
print(f"Score final: {match.get('final_score', match['similarity_score']):.4f}")
print(f"Score cross-encoder: {match.get('cross_encoder_score', 'N/A')}")
print(f"Entités détectées dans la query:")
for entity_type, entities in match.get('query_entities', {}).items():
if entities:
print(f" - {entity_type}: {[f'{e[0]} ({e[1]:.2f})' for e in entities]}")
if __name__ == "__main__":
main()