Spaces:
Runtime error
Runtime error
| from smolagents import Tool | |
| from langchain.docstore.document import Document | |
| from sentence_transformers import SentenceTransformer | |
| import numpy as np | |
| import datasets | |
| from typing import List | |
class SentenceTransformerRetriever:
    """Retriever that ranks documents by cosine similarity of their embeddings.

    Document texts are embedded once at construction time; each query is
    embedded on demand and compared against the stored document embeddings.
    """

    def __init__(self, docs: "List[Document]", model_name: str = "all-MiniLM-L6-v2"):
        """Embed all documents with a SentenceTransformer model.

        Args:
            docs: Document objects; each one's ``page_content`` is embedded.
            model_name: Name of the SentenceTransformer model to load.
        """
        self.docs = docs
        self.model = SentenceTransformer(model_name)
        self.doc_texts = [doc.page_content for doc in self.docs]
        # Request numpy output so the similarity math below needs no
        # tensor conversion.
        self.doc_embeddings = self.model.encode(self.doc_texts, convert_to_numpy=True)

    def get_relevant_documents(self, query: str, k: int = 3) -> "List[Document]":
        """Return the ``k`` documents most similar to ``query``.

        Args:
            query: Query string to embed and compare against the documents.
            k: Maximum number of documents to return. Values <= 0 yield an
                empty list.

        Returns:
            Up to ``k`` documents, ordered from most to least similar. Ties
            keep the original document order.
        """
        # Nothing to rank: avoid matrix math on an empty embedding array and
        # avoid the surprising "all but the last |k|" result of a negative slice.
        if k <= 0 or len(self.docs) == 0:
            return []

        query_embedding = np.asarray(self.model.encode(query, convert_to_numpy=True))
        doc_matrix = np.asarray(self.doc_embeddings)

        # Cosine similarity for every document in one vectorized pass.
        dot_products = doc_matrix @ query_embedding
        norm_products = np.linalg.norm(doc_matrix, axis=1) * np.linalg.norm(query_embedding)

        # A zero-norm embedding has no direction; score it 0 instead of
        # dividing by zero and producing NaN, which would corrupt the ranking.
        similarities = np.divide(
            dot_products,
            norm_products,
            out=np.zeros_like(dot_products, dtype=float),
            where=norm_products > 0,
        )

        # Negate for descending order; a stable sort keeps ties deterministic.
        top_k_indices = np.argsort(-similarities, kind="stable")[:k]
        return [self.docs[i] for i in top_k_indices]
class GuestInfoRetrieverTool(Tool):
    """Tool that returns guest records semantically matching a query string."""

    name = "guest_info_retriever"
    description = "Retrieves detailed information about gala guests based on their name or relation using semantic search."
    inputs = {
        "query": {
            "type": "string",
            "description": "The name or relation of the guest you want information about."
        }
    }
    output_type = "string"

    def __init__(self, docs, model_name: str = "all-MiniLM-L6-v2"):
        """Build the underlying retriever over ``docs``.

        Args:
            docs: Documents to search; passed straight to
                ``SentenceTransformerRetriever``.
            model_name: Name of the SentenceTransformer model to use.
        """
        # Run the base Tool initializer so any state it sets up exists on
        # this instance, rather than relying on a hand-copied subset of it.
        super().__init__()
        self.is_initialized = False
        self.retriever = SentenceTransformerRetriever(docs, model_name)

    def forward(self, query: str) -> str:
        """Return the page content of up to three matching guests.

        Args:
            query: Name or relation to search for.

        Returns:
            The matching documents' contents separated by blank lines, or a
            fixed "not found" message when the retriever returns nothing.
        """
        results = self.retriever.get_relevant_documents(query)
        if not results:
            return "No matching guest information found."
        # Cap at three entries even if the retriever returns more.
        return "\n\n".join([doc.page_content for doc in results[:3]])
def load_guest_dataset(model_name: str = "all-MiniLM-L6-v2"):
    """Build a guest-information retrieval tool from the invitee dataset.

    Loads the ``train`` split of ``agents-course/unit3-invitees``, turns each
    row into a ``Document`` and wraps the result in a
    ``GuestInfoRetrieverTool``.

    Args:
        model_name: Name of the SentenceTransformer model to use.

    Returns:
        GuestInfoRetrieverTool: A tool for retrieving guest information.
    """
    invitees = datasets.load_dataset("agents-course/unit3-invitees", split="train")

    documents = []
    for entry in invitees:
        # One labelled line per field; the name is also kept as metadata.
        field_lines = [
            f"Name: {entry['name']}",
            f"Relation: {entry['relation']}",
            f"Description: {entry['description']}",
            f"Email: {entry['email']}"
        ]
        record = Document(
            page_content="\n".join(field_lines),
            metadata={"name": entry["name"]}
        )
        documents.append(record)

    return GuestInfoRetrieverTool(documents, model_name=model_name)