Spaces:
Sleeping
Sleeping
| from haystack.nodes import WebRetriever | |
| from haystack.schema import Document | |
| from typing import List | |
| from haystack.document_stores import FAISSDocumentStore | |
| from haystack.nodes import AnswerParser, PromptNode, PromptTemplate | |
| from haystack import Pipeline | |
| from haystack.nodes import DensePassageRetriever | |
| import os | |
| from dotenv import load_dotenv | |
def initialize_documents(serp_key, nl_query):
    """
    Fetch documents for a natural language query through the SERP API.

    Args:
        serp_key (str): API key for the SERP API.
        nl_query (str): Natural language query to retrieve documents for.

    Returns:
        List[Document]: Whatever the WebRetriever returns for the query.
    """
    # The retriever is configured for "preprocessed_documents" mode and asks
    # for up to 100 results.
    web_retriever = WebRetriever(
        api_key=serp_key,
        mode="preprocessed_documents",
        top_k=100,
    )

    # Run the query and hand the results straight back to the caller.
    fetched: List[Document] = web_retriever.retrieve(query=nl_query)
    return fetched
def initialize_faiss_document_store(documents):
    """
    Build a FAISS document store holding the given documents, plus a DPR retriever.

    Args:
        documents (List[Document]): Documents to write into the document store.

    Returns:
        document_store (FAISSDocumentStore): Store containing the documents
            and their embeddings.
        retriever (DensePassageRetriever): Retriever bound to that store.
    """
    # Flat index; embeddings are returned together with the documents.
    store = FAISSDocumentStore(
        faiss_index_factory_str="Flat",
        return_embedding=True,
    )

    # DPR uses separate encoders for queries and for passages.
    dpr = DensePassageRetriever(
        document_store=store,
        query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
        passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
        use_gpu=True,
        embed_title=True,
    )

    # Clear out anything already in the store before writing the new documents.
    store.delete_documents()
    store.write_documents(documents)

    # Compute embeddings for the stored documents with the DPR retriever.
    store.update_embeddings(retriever=dpr)

    return store, dpr
def initialize_rag_pipeline(retriever, openai_key):
    """
    Initialize a pipeline for RAG-based question answering.

    The pipeline has two nodes: the given retriever, which receives the query,
    and a "gpt-4" PromptNode, which receives the retriever's output.

    Args:
        retriever (DensePassageRetriever): Dense passage retriever.
        openai_key (str): API key for OpenAI.

    Returns:
        query_pipeline (Pipeline): Pipeline for RAG-based question answering.
    """
    # The prompt literal used to open with four double quotes, which placed a
    # stray, unbalanced '"' in front of "Answer" in the text sent to the model.
    # The wording and line breaks of the prompt are otherwise kept as they were.
    prompt_template = PromptTemplate(
        prompt=(
            "Answer the following query based on the provided context. If the context does\n"
            "not include an answer, reply with 'The data does not contain information related to the question'.\n\n"
            "Query: {query}\n\n"
            "Documents: {join(documents)}\n"
            "Answer:\n"
        ),
        output_parser=AnswerParser(),
    )

    prompt_node = PromptNode(
        model_name_or_path="gpt-4",
        api_key=openai_key,
        default_prompt_template=prompt_template,
        max_length=500,
        model_kwargs={"stream": True},
    )

    # Wire the nodes: Query -> Retriever -> PromptNode.
    query_pipeline = Pipeline()
    query_pipeline.add_node(component=retriever, name="Retriever", inputs=["Query"])
    query_pipeline.add_node(component=prompt_node, name="PromptNode", inputs=["Retriever"])
    return query_pipeline