# VedaMD-Backend-v2 / src/simple_vector_store.py
# Last commit b4971bd by sniro23: "Production ready: Clean codebase + Cerebras + Automated pipeline"
"""
Simple Vector Store for Medical RAG - Runtime Version
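This version loads a pre-computed FAISS vector store (index, documents, metadata and config) from the Hugging Face Hub or from a local directory; it does not build embeddings or an index at runtime.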
"""
import os
import json
import logging
import time
from typing import List, Dict, Any, Optional
from pathlib import Path
import numpy as np
from dataclasses import dataclass
import faiss
from sentence_transformers import SentenceTransformer
from langchain_core.documents import Document
from huggingface_hub import hf_hub_download
@dataclass()
class SearchResult:
    """
    A single hit produced by SimpleVectorStore.search.

    Attributes:
        content: Page text of the matched document.
        score: Score reported by the FAISS index for this hit, as a Python float.
        metadata: Metadata dict attached to the matched document.
    """

    content: str  # matched document text
    score: float  # similarity/distance value from the FAISS search
    metadata: Dict[str, Any]  # document metadata (e.g. 'citation', 'source')
class SimpleVectorStore:
    """
    A simplified vector store that loads a pre-built FAISS index and its documents
    from the Hugging Face Hub or from a local directory.

    It does not contain any logic for creating document embeddings or building an
    index at runtime; the only embeddings it computes are for search queries.

    Both sources must provide the same four artifacts:
        faiss_index.bin  - serialized FAISS index
        documents.json   - list of page contents (str, or dict with 'page_content')
        metadata.json    - list of metadata dicts, one per entry in documents.json
        config.json      - build configuration; only logged, never interpreted
    """

    # Artifact file names, in the order (index, documents, metadata, config).
    _ARTIFACT_FILENAMES = ("faiss_index.bin", "documents.json", "metadata.json", "config.json")

    def __init__(self,
                 repo_id: Optional[str] = None,
                 local_dir: Optional[str] = None,
                 embedding_model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
        """
        Initializes the vector store by loading from a local directory or the HF Hub.

        Args:
            repo_id (str): The Hugging Face Hub repository ID (e.g., "user/repo-name").
                Optional if local_dir is provided.
            local_dir (str): Local directory containing the vector store files.
                Optional if repo_id is provided; used instead of repo_id when both are given.
            embedding_model_name (str): The embedding model used to embed queries.
                Defaults to sentence-transformers/all-MiniLM-L6-v2 (384d).
                NOTE(review): presumably this must be the model the index was built
                with — confirm against config.json.

        Raises:
            ValueError: If neither repo_id nor local_dir is provided.
            Any error from loading the model or the artifacts is logged and re-raised.
        """
        if not repo_id and not local_dir:
            raise ValueError("Either repo_id or local_dir must be provided")
        self.repo_id = repo_id
        self.local_dir = local_dir
        self.embedding_model_name = embedding_model_name
        self.setup_logging()
        # Name-based heuristic: models mentioning "Clinical" or "Bio" are treated as
        # medical-domain models; anything else only triggers a warning.
        if "Clinical" in embedding_model_name or "Bio" in embedding_model_name:
            self.logger.info(f"πŸ₯ Using medical domain embedding model: {embedding_model_name}")
        else:
            self.logger.warning(f"⚠️ Using general domain embedding model: {embedding_model_name}")
        self.embedding_model = None
        self.index = None
        self.documents = []
        self.metadata = []
        self._initialize_embedding_model()
        # A local directory takes precedence over the Hub.
        if self.local_dir:
            self.load_from_local_directory()
        else:
            self.load_from_huggingface_hub()

    def setup_logging(self):
        """
        Setup logging for the vector store.

        Calls logging.basicConfig at INFO level (a no-op if the root logger already
        has handlers) and binds this module's logger to ``self.logger``.
        """
        logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        self.logger = logging.getLogger(__name__)

    def _initialize_embedding_model(self):
        """Initialize the sentence transformer model for creating query embeddings."""
        try:
            self.logger.info(f"Loading embedding model: {self.embedding_model_name}")
            self.embedding_model = SentenceTransformer(self.embedding_model_name)
            self.logger.info("Embedding model loaded successfully.")
        except Exception as e:
            self.logger.error(f"Error loading embedding model: {e}")
            raise

    def load_from_local_directory(self):
        """
        Loads the vector store artifacts from ``self.local_dir``.

        Missing citations are derived from the source file's stem.
        Any previously loaded contents are replaced, not appended to.

        Raises:
            FileNotFoundError: If the directory or one of the artifact files is missing.
            ValueError: If documents.json and metadata.json differ in length.
        """
        self.logger.info(f"Loading vector store from local directory: {self.local_dir}")
        try:
            local_path = Path(self.local_dir)
            if not local_path.is_dir():
                raise FileNotFoundError(f"Local directory not found: {self.local_dir}")

            index_path, docs_path, metadata_path, config_path = (
                local_path / name for name in self._ARTIFACT_FILENAMES
            )
            # Check every artifact up front so a missing file produces one clear error
            # instead of failing partway through loading.
            missing = [p.name for p in (index_path, docs_path, metadata_path, config_path)
                       if not p.is_file()]
            if missing:
                raise FileNotFoundError(
                    f"Missing vector store file(s) in {self.local_dir}: {', '.join(missing)}")

            index = faiss.read_index(str(index_path))
            self.logger.info(f"Loaded FAISS index with {index.ntotal} vectors from local directory.")

            records = self._build_records(
                self._read_json(docs_path), self._read_json(metadata_path), use_parent_dir=False)
            self._set_store_contents(index, records)
            self.logger.info(f"Loaded {len(self.documents)} documents from local directory.")

            config = self._read_json(config_path)
            self.logger.info(f"Vector store configuration loaded: {config}")
        except Exception as e:
            self.logger.error(f"Failed to load vector store from local directory: {e}")
            raise

    def load_from_huggingface_hub(self):
        """
        Downloads the vector store artifacts from ``self.repo_id`` on the Hugging Face Hub
        and loads them.

        Missing citations are derived from the name of the source file's parent
        directory. Any previously loaded contents are replaced, not appended to.

        Raises:
            ValueError: If documents.json and metadata.json differ in length.
            Any error from downloading, FAISS or JSON parsing is logged and re-raised.
        """
        self.logger.info(f"Downloading vector store from Hugging Face Hub repo: {self.repo_id}")
        try:
            # Download the four essential files, in _ARTIFACT_FILENAMES order.
            index_path, docs_path, metadata_path, config_path = (
                hf_hub_download(repo_id=self.repo_id, filename=name)
                for name in self._ARTIFACT_FILENAMES
            )
            self.logger.info("Vector store files downloaded successfully.")

            index = faiss.read_index(index_path)
            self.logger.info(f"Loaded FAISS index with {index.ntotal} vectors.")

            records = self._build_records(
                self._read_json(docs_path), self._read_json(metadata_path), use_parent_dir=True)
            self._set_store_contents(index, records)
            self.logger.info(f"Loaded {len(self.documents)} documents with improved citations.")

            config = self._read_json(config_path)
            self.logger.info(f"Vector store configuration loaded: {config}")
        except Exception as e:
            self.logger.error(f"Failed to load vector store from Hugging Face Hub: {e}")
            raise

    @staticmethod
    def _read_json(path) -> Any:
        """Read and parse a UTF-8 JSON file."""
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    @staticmethod
    def _citation_from_source(source_path: Any, use_parent_dir: bool) -> str:
        """
        Derive a human-readable citation from a document's source path.

        Dashes become spaces and the result is title-cased. With ``use_parent_dir``
        the parent directory name is used, otherwise the file stem. If the chosen
        name is empty, the stem is used instead; for example, a bare filename has
        no parent directory name. A missing, empty or 'Unknown' source gives
        'Unknown Source'.

        NOTE(review): local loads use the stem while Hub loads use the parent
        directory, so the same store can yield different citations depending on
        where it is loaded from — confirm which is intended.
        """
        if not source_path or source_path == 'Unknown':
            return 'Unknown Source'
        path = Path(str(source_path))
        name = path.parent.name if use_parent_dir else path.stem
        if not name:
            name = path.stem
        return name.replace('-', ' ').title() if name else 'Unknown Source'

    @classmethod
    def _build_records(cls, page_contents, metadatas, use_parent_dir: bool) -> List[tuple]:
        """
        Pair the parsed documents.json and metadata.json entries into (content, metadata).

        A document entry may be a plain string or a dict with 'page_content'.
        A metadata entry that is not a dict is replaced by an empty dict. Every
        metadata dict is given a non-empty 'citation' (see _citation_from_source).

        Raises:
            ValueError: If the two lists differ in length.
            TypeError: If a document entry is neither a str nor a dict.
        """
        if len(page_contents) != len(metadatas):
            raise ValueError("Mismatch between number of documents and metadata entries.")
        records = []
        for raw_content, raw_metadata in zip(page_contents, metadatas):
            if isinstance(raw_content, str):
                content = raw_content
            elif isinstance(raw_content, dict):
                content = raw_content.get('page_content', '')
            else:
                raise TypeError(f"Unsupported document entry type: {type(raw_content).__name__}")
            metadata = raw_metadata if isinstance(raw_metadata, dict) else {}
            # Ensure a valid citation exists.
            if not metadata.get('citation'):
                metadata['citation'] = cls._citation_from_source(
                    metadata.get('source', 'Unknown'), use_parent_dir)
            records.append((content, metadata))
        return records

    def _set_store_contents(self, index, records: List[tuple]) -> None:
        """Install a freshly loaded index and its documents, replacing any earlier contents."""
        self.index = index
        self.documents = [Document(page_content=content, metadata=metadata)
                          for content, metadata in records]
        self.metadata = [metadata for _, metadata in records]
        if index.ntotal != len(self.documents):
            # search() skips hits whose position has no document, but flag it early.
            self.logger.warning(
                f"FAISS index holds {index.ntotal} vectors but {len(self.documents)} documents were loaded.")

    def search(self, query: str, k: int = 5) -> List["SearchResult"]:
        """
        Searches the vector store for the top-k most similar documents to the query.

        Args:
            query (str): The search query.
            k (int): The maximum number of results to return; values <= 0 return [].

        Returns:
            A list of SearchResult objects, in the order returned by the FAISS index.
            The list is empty if the store is not initialized.
        """
        if self.index is None or not self.documents:
            self.logger.error("Search attempted but vector store is not initialized.")
            return []
        if k <= 0:
            return []
        # Create a normalized embedding for the query.
        query_embedding = self.embedding_model.encode([query], normalize_embeddings=True)
        scores, indices = self.index.search(query_embedding.astype('float32'), k)
        results = []
        for score, idx in zip(scores[0], indices[0]):
            # -1 marks an empty result slot. Also skip positions with no loaded document.
            if idx < 0 or idx >= len(self.documents):
                continue
            doc = self.documents[idx]
            results.append(SearchResult(
                content=doc.page_content,
                score=float(score),
                metadata=doc.metadata
            ))
        return results
def main():
    """
    Smoke-test the vector store: load it, run a few sample queries and print the hits.

    Returns:
        The loaded SimpleVectorStore on success, or None if anything raised
        (the traceback is printed in that case).
    """
    banner = "=" * 60
    print("πŸ”„ Testing Simple Vector Store v2.0")
    print(banner)
    try:
        store = SimpleVectorStore(repo_id="user/repo-name")

        print("\nπŸ” TESTING SEARCH FUNCTIONALITY:")
        sample_queries = (
            "magnesium sulfate dosage preeclampsia",
            "postpartum hemorrhage management",
            "fetal heart rate monitoring",
            "emergency cesarean delivery",
        )
        for sample in sample_queries:
            print(f"\nπŸ“ Query: '{sample}'")
            hits = store.search(sample, k=3)
            for rank, hit in enumerate(hits, start=1):
                meta = hit.metadata
                print(f" Result {rank}: Score={hit.score:.3f}, Doc={meta.get('document_name', 'Unknown')}")
                print(f" Type={meta.get('content_type', 'general')}")
                print(f" Preview: {hit.content[:100]}...")

        print("\nπŸŽ‰ Simple Vector Store Testing Complete!")
        print(f"βœ… Successfully loaded vector store with {len(store.documents):,} embeddings")
        print("βœ… Search functionality working with high relevance scores")
        return store
    except Exception as err:
        print(f"❌ Error in simple vector store: {err}")
        import traceback
        traceback.print_exc()
        return None
if __name__ == "__main__":
    # Run the smoke test only when this file is executed as a script, not on import.
    main()