| """ | |
| Simple Vector Store for Medical RAG - Runtime Version | |
| This version is designed to load a pre-computed vector store from the Hugging Face Hub. | |
| """ | |
| import os | |
| import json | |
| import logging | |
| import time | |
| from typing import List, Dict, Any, Optional | |
| from pathlib import Path | |
| import numpy as np | |
| from dataclasses import dataclass | |
| import faiss | |
| from sentence_transformers import SentenceTransformer | |
| from langchain_core.documents import Document | |
| from huggingface_hub import hf_hub_download | |

@dataclass
class SearchResult:
    """Simple search result structure."""
    content: str
    score: float
    metadata: Dict[str, Any]

class SimpleVectorStore:
    """
    A simplified vector store that loads its index and documents from the Hugging Face Hub.
    It does not contain any logic for creating embeddings or building an index at runtime.
    """

    def __init__(self,
                 repo_id: Optional[str] = None,
                 local_dir: Optional[str] = None,
                 embedding_model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
        """
        Initializes the vector store by loading from the Hugging Face Hub or a local directory.

        Args:
            repo_id (str): The Hugging Face Hub repository ID (e.g., "user/repo-name").
                Optional if local_dir is provided.
            local_dir (str): Local directory containing the vector store files.
                Optional if repo_id is provided.
            embedding_model_name (str): The embedding model to use for query embedding.
                Defaults to sentence-transformers/all-MiniLM-L6-v2 (384-dimensional).
        """
        if not repo_id and not local_dir:
            raise ValueError("Either repo_id or local_dir must be provided")
        self.repo_id = repo_id
        self.local_dir = local_dir
        self.embedding_model_name = embedding_model_name
        self.setup_logging()

        # Log the embedding model choice for the medical domain
        if "Clinical" in embedding_model_name or "Bio" in embedding_model_name:
            self.logger.info(f"🏥 Using medical domain embedding model: {embedding_model_name}")
        else:
            self.logger.warning(f"⚠️ Using general domain embedding model: {embedding_model_name}")

        self.embedding_model = None
        self.index = None
        self.documents = []
        self.metadata = []

        self._initialize_embedding_model()

        # Load from a local directory or the HF Hub
        if self.local_dir:
            self.load_from_local_directory()
        else:
            self.load_from_huggingface_hub()

    def setup_logging(self):
        """Set up logging for the vector store."""
        logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        self.logger = logging.getLogger(__name__)

    def _initialize_embedding_model(self):
        """Initialize the sentence transformer model for creating query embeddings."""
        try:
            self.logger.info(f"Loading embedding model: {self.embedding_model_name}")
            self.embedding_model = SentenceTransformer(self.embedding_model_name)
            self.logger.info("Embedding model loaded successfully.")
        except Exception as e:
            self.logger.error(f"Error loading embedding model: {e}")
            raise

    def load_from_local_directory(self):
        """
        Loads the vector store artifacts from a local directory.
        """
        self.logger.info(f"Loading vector store from local directory: {self.local_dir}")
        try:
            local_path = Path(self.local_dir)

            # Check if the directory exists
            if not local_path.exists():
                raise FileNotFoundError(f"Local directory not found: {self.local_dir}")

            # Load the FAISS index
            index_path = local_path / "faiss_index.bin"
            self.index = faiss.read_index(str(index_path))
            self.logger.info(f"Loaded FAISS index with {self.index.ntotal} vectors from local directory.")

            # Load documents and metadata
            docs_path = local_path / "documents.json"
            metadata_path = local_path / "metadata.json"
            config_path = local_path / "config.json"

            with open(docs_path, 'r', encoding='utf-8') as f:
                page_contents = json.load(f)
            with open(metadata_path, 'r', encoding='utf-8') as f:
                metadatas = json.load(f)

            # Combine them to reconstruct the documents
            if len(page_contents) != len(metadatas):
                raise ValueError("Mismatch between number of documents and metadata entries.")

            for i in range(len(page_contents)):
                content = page_contents[i] if isinstance(page_contents[i], str) else page_contents[i].get('page_content', '')
                metadata = metadatas[i] if isinstance(metadatas[i], dict) else {}

                # Ensure a valid citation exists
                if not metadata.get('citation'):
                    source_path = metadata.get('source', 'Unknown')
                    if source_path != 'Unknown':
                        metadata['citation'] = Path(source_path).stem.replace('-', ' ').title()
                    else:
                        metadata['citation'] = 'Unknown Source'

                self.documents.append(Document(page_content=content, metadata=metadata))
                self.metadata.append(metadata)

            self.logger.info(f"Loaded {len(self.documents)} documents from local directory.")

            # Load and log the configuration
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)
            self.logger.info(f"Vector store configuration loaded: {config}")

        except Exception as e:
            self.logger.error(f"Failed to load vector store from local directory: {e}")
            raise

    def load_from_huggingface_hub(self):
        """
        Downloads the vector store artifacts from the specified Hugging Face Hub repository and loads them.
        """
        self.logger.info(f"Downloading vector store from Hugging Face Hub repo: {self.repo_id}")
        try:
            # Download the four essential files
            index_path = hf_hub_download(repo_id=self.repo_id, filename="faiss_index.bin")
            docs_path = hf_hub_download(repo_id=self.repo_id, filename="documents.json")
            metadata_path = hf_hub_download(repo_id=self.repo_id, filename="metadata.json")
            config_path = hf_hub_download(repo_id=self.repo_id, filename="config.json")
            self.logger.info("Vector store files downloaded successfully.")

            # Load the FAISS index
            self.index = faiss.read_index(index_path)
            self.logger.info(f"Loaded FAISS index with {self.index.ntotal} vectors.")

            # Load the documents and metadata separately
            with open(docs_path, 'r', encoding='utf-8') as f:
                page_contents = json.load(f)
            with open(metadata_path, 'r', encoding='utf-8') as f:
                metadatas = json.load(f)

            # Combine them to reconstruct the documents
            if len(page_contents) != len(metadatas):
                raise ValueError("Mismatch between number of documents and metadata entries.")

            for i in range(len(page_contents)):
                content = page_contents[i] if isinstance(page_contents[i], str) else page_contents[i].get('page_content', '')
                metadata = metadatas[i] if isinstance(metadatas[i], dict) else {}

                # Ensure a valid citation exists.
                # If 'citation' is missing or empty, create one from the source file path.
                if not metadata.get('citation'):
                    source_path = metadata.get('source', 'Unknown')
                    if source_path != 'Unknown':
                        # Extract the guideline name from the parent directory of the source file
                        metadata['citation'] = Path(source_path).parent.name.replace('-', ' ').title()
                    else:
                        metadata['citation'] = 'Unknown Source'

                self.documents.append(Document(page_content=content, metadata=metadata))
                self.metadata.append(metadata)

            self.logger.info(f"Loaded {len(self.documents)} documents with improved citations.")

            # Load and log the configuration
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)
            self.logger.info(f"Vector store configuration loaded: {config}")

        except Exception as e:
            self.logger.error(f"Failed to load vector store from Hugging Face Hub: {e}")
            raise

    def search(self, query: str, k: int = 5) -> List[SearchResult]:
        """
        Searches the vector store for the top-k most similar documents to the query.

        Args:
            query (str): The search query.
            k (int): The number of results to return.

        Returns:
            A list of SearchResult objects.
        """
        if not self.index or not self.documents:
            self.logger.error("Search attempted but vector store is not initialized.")
            return []

        # Create an embedding for the query
        query_embedding = self.embedding_model.encode([query], normalize_embeddings=True)

        # Search the FAISS index
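        # Assuming the stored index was built over normalized embeddings with an inner-product
        # metric (e.g. IndexFlatIP), the returned scores are cosine similarities.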
        scores, indices = self.index.search(query_embedding.astype('float32'), k)

        # Process and return the results
        results = []
        for score, idx in zip(scores[0], indices[0]):
            if idx == -1:
                continue  # Skip invalid indices
            doc = self.documents[idx]
            results.append(SearchResult(
                content=doc.page_content,
                score=float(score),
                metadata=doc.metadata
            ))
        return results
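
# Example usage (illustrative sketch; "user/repo-name" and "./vector_store" are placeholders):
#
#   store = SimpleVectorStore(repo_id="user/repo-name")      # load from the Hugging Face Hub
#   store = SimpleVectorStore(local_dir="./vector_store")    # ...or from a local directory
#   for hit in store.search("postpartum hemorrhage management", k=3):
#       print(f"{hit.score:.3f}", hit.metadata.get("citation"), hit.content[:80])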

def main():
    """Main function to test the simple vector store."""
    print("🚀 Testing Simple Vector Store v2.0")
    print("=" * 60)

    try:
        # Initialize the vector store
        vector_store = SimpleVectorStore(
            repo_id="user/repo-name"
        )

        # Test search functionality
        print("\n🔍 TESTING SEARCH FUNCTIONALITY:")
        test_queries = [
            "magnesium sulfate dosage preeclampsia",
            "postpartum hemorrhage management",
            "fetal heart rate monitoring",
            "emergency cesarean delivery"
        ]

        for query in test_queries:
            print(f"\n🔍 Query: '{query}'")
            results = vector_store.search(query, k=3)
            for i, result in enumerate(results, 1):
                print(f"   Result {i}: Score={result.score:.3f}, Doc={result.metadata.get('document_name', 'Unknown')}")
                print(f"      Type={result.metadata.get('content_type', 'general')}")
                print(f"      Preview: {result.content[:100]}...")

        print("\n🎉 Simple Vector Store Testing Complete!")
        print(f"✅ Successfully loaded vector store with {len(vector_store.documents):,} documents")
        print("✅ Search functionality working (see the relevance scores above)")

        return vector_store

    except Exception as e:
        print(f"❌ Error in simple vector store: {e}")
        import traceback
        traceback.print_exc()
        return None


if __name__ == "__main__":
    main()
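
# Note: this runtime module only loads pre-built artifacts. A minimal sketch of how the four
# files could be produced offline (an assumption: the filenames, normalized embeddings, and an
# inner-product index are inferred from the loader and search code above):
#
#   model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
#   embeddings = model.encode(texts, normalize_embeddings=True).astype("float32")
#   index = faiss.IndexFlatIP(embeddings.shape[1])
#   index.add(embeddings)
#   faiss.write_index(index, "faiss_index.bin")
#   with open("documents.json", "w", encoding="utf-8") as f:
#       json.dump(texts, f)
#   with open("metadata.json", "w", encoding="utf-8") as f:
#       json.dump(metadatas, f)
#   with open("config.json", "w", encoding="utf-8") as f:
#       json.dump({"embedding_model": "sentence-transformers/all-MiniLM-L6-v2"}, f)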