Spaces:
Sleeping
Sleeping
File size: 3,284 Bytes
51349bc 5f540b8 51349bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import os
import numpy as np
from typing import List
from pathlib import Path
from src.constant import BASE_DIR
import chromadb
from langchain.vectorstores import Chroma
from langchain.schema import Document
from uuid import uuid4
# Directory "<BASE_DIR>/data/db"; used below as the default `persist_directory`
# of VectorStore, i.e. where the Chroma database files are kept on disk.
DATA_DIR = os.path.join(BASE_DIR, "data", "db")
class VectorStore:
    """
    Wrapper around Chroma vector database for persistent storage
    and retrieval of document embeddings.

    On construction the persistence directory is created if missing, a
    ``chromadb.PersistentClient`` is opened on it and the named collection
    is fetched (or created). Both are exposed as ``self.client`` and
    ``self.collection``.
    """

    def __init__(self,
                 collection_name: str = "medrag",
                 persist_directory: str = DATA_DIR):
        """
        Open (or create) the persistent collection.

        Args:
            collection_name: Name of the Chroma collection to use.
            persist_directory: Directory holding the Chroma database files.

        Raises:
            Exception: Whatever the directory creation or Chroma raises
                while initializing; the error is printed and re-raised.
        """
        self.collection_name = collection_name
        self.persist_directory = persist_directory
        self.client = None
        self.collection = None
        self._initialize_store()

    def _initialize_store(self):
        """Initialize Chroma client and collection."""
        try:
            # Make sure the target directory exists before Chroma opens it.
            Path(self.persist_directory).mkdir(parents=True, exist_ok=True)
            self.client = chromadb.PersistentClient(path=self.persist_directory)
            self.collection = self.client.get_or_create_collection(
                name=self.collection_name,
                metadata={"description": "RAG collection for biomedical research"},
            )
            print(f"Store initialized successfully: {self.collection_name}")
        except Exception as e:
            # Report the failure, then let the caller decide how to handle it.
            print(f"Error initializing the store: {e}")
            raise

    def get_len(self) -> int:
        """Return number of documents in the collection."""
        return self.collection.count()

    @staticmethod
    def _build_metadata(doc, doc_index: int) -> dict:
        """Copy the document's metadata and add the bookkeeping fields."""
        # Copy so the caller's Document.metadata is never mutated.
        metadata = dict(doc.metadata) if getattr(doc, "metadata", None) else {}
        metadata.update({
            "doc_index": doc_index,
            "content_length": len(doc.page_content),
        })
        return metadata

    def add_documents(self, documents: List[Document], embeddings: np.ndarray, batch_size: int = 5000):
        """
        Add documents and their embeddings to the vector store in batches.

        Every document is stored under a freshly generated ``doc_<uuid4 hex>``
        id. Its metadata is a copy of ``doc.metadata`` extended with
        ``doc_index`` (position of the document within ``documents``) and
        ``content_length`` (``len(doc.page_content)``).

        Args:
            documents: Documents to store; ``page_content`` is used as text.
            embeddings: One embedding per document, in the same order.
                A NumPy array is converted to nested lists first.
            batch_size: Maximum number of documents sent per ``add`` call.

        Raises:
            ValueError: If ``batch_size`` is not positive, or if the number
                of embeddings differs from the number of documents.
        """
        if isinstance(embeddings, np.ndarray):
            embeddings = embeddings.tolist()  # Ensure compatibility

        if batch_size <= 0:
            # A negative step would make range() yield nothing and silently
            # store no documents at all.
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        if len(documents) != len(embeddings):
            # zip() would otherwise silently drop the unmatched tail.
            raise ValueError(
                f"Number of documents ({len(documents)}) does not match "
                f"number of embeddings ({len(embeddings)})"
            )

        for start in range(0, len(documents), batch_size):
            batch_docs = documents[start:start + batch_size]
            batch_embeds = embeddings[start:start + batch_size]

            # Offset by `start` so doc_index is the position in the whole
            # input rather than restarting at 0 for every batch.
            metadatas = [
                self._build_metadata(doc, idx)
                for idx, doc in enumerate(batch_docs, start=start)
            ]

            self.collection.add(
                ids=[f"doc_{uuid4().hex}" for _ in batch_docs],
                documents=[doc.page_content for doc in batch_docs],
                embeddings=list(batch_embeds),
                metadatas=metadatas,
            )

        print(f"Documents and embeddings added to collection: {self.collection_name}")

    def get_retriever(self, embedding_function, search_kwargs: dict = None):
        """
        Return a retriever interface for semantic search.

        Args:
            embedding_function: Embedding function handed to the LangChain
                ``Chroma`` wrapper.
            search_kwargs: Options forwarded to ``as_retriever``; defaults
                to ``{"k": 5}`` when omitted.

        Returns:
            The retriever produced by ``Chroma.as_retriever`` over this
            store's client and collection.
        """
        if search_kwargs is None:
            search_kwargs = {"k": 5}

        vectorstore = Chroma(
            client=self.client,
            collection_name=self.collection_name,
            embedding_function=embedding_function,
        )
        return vectorstore.as_retriever(search_kwargs=search_kwargs)