#from haystack.components.generators import HuggingFaceTGIGenerator
from llama_index.llms import HuggingFaceInferenceAPI, ChatMessage, MessageRole
from llama_index.prompts import ChatPromptTemplate
from llama_index import VectorStoreIndex, SimpleDirectoryReader, ServiceContext  #, LLMPredictor, StorageContext, load_index_from_storage
import gradio as gr
#import sys
#import logging
#import torch
#from huggingface_hub import InferenceClient
#import tqdm as notebook_tqdm
import requests
import os
import json

#generator = HuggingFaceTGIGenerator("mistralai/Mixtral-8x7B-Instruct-v0.1")
#generator.warm_up()

def download_file(url, filename):
    """
    Download a file from the specified URL and save it locally under content/<filename>.
    The download is skipped if the filename is empty or the file already exists.
    """
    if filename == '':
        return
    os.makedirs('content', exist_ok=True)  # make sure the target folder exists
    if filename in os.listdir('content/'):
        return
    response = requests.get(url, stream=True)
    # Check if the request was successful
    if response.status_code == 200:
        with open('content/' + filename, 'wb') as file:
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:  # filter out keep-alive new chunks
                    file.write(chunk)
        print(f"Download complete: {filename}")
    else:
        print(f"Error: Unable to download file. HTTP status code: {response.status_code}")
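
# Illustrative usage (the URL is the default "File Link" exposed in the UI below):
#   download_file("https://arxiv.org/pdf/2401.10020.pdf", "2401.10020.pdf")
# saves the PDF to content/2401.10020.pdf and skips the request if the file is already present.
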
#def save_answer(prompt, rag_answer, norag_answer):
#    json_dict = dict()
#    json_dict['prompt'] = prompt
#    json_dict['rag_answer'] = rag_answer
#    json_dict['norag_answer'] = norag_answer
#
#    file_path = 'saved_answers.json'
#
#    # Check if the file exists
#    if not os.path.isfile(file_path):
#        with open(file_path, 'w') as f:
#            # Create an empty list in the file to store dictionaries
#            json.dump([], f)
#            f.write('\n')  # Add a newline to separate the list and future entries
#
#    # Open the file in append mode
#    with open(file_path, 'a+') as f:
#        # Read the existing data
#        f.seek(0)
#        data = json.load(f)
#
#        # Append the new dictionary to the list
#        data.append(json_dict)
#
#        # Move the cursor to the beginning of the file
#        f.seek(0)
#
#        # Write the updated list of dictionaries
#        json.dump(data, f)
#        f.write('\n')  # Add a newline to separate the list and future entries
#
#
#def check_answer(prompt):
#    file_path = 'saved_answers.json'
#
#    if not os.path.isfile(file_path):
#        with open(file_path, 'w') as f:
#            # Create an empty list in the file to store dictionaries
#            json.dump([], f)
#            f.write('\n')  # Add a newline to separate the list and future entries
#    with open('saved_answers.json', 'r') as f:
#        data = json.load(f)
#    for entry in data:
#        if entry['prompt'] == prompt:
#            return entry['rag_answer'], entry['norag_answer']
#    return None, None  # Return None if the prompt is not found

def save_answer(prompt, rag_answer, norag_answer):
    file_path = 'saved_answers.jsonl'

    # Create a dictionary for the current answer
    json_dict = {
        'prompt': prompt,
        'rag_answer': rag_answer,
        'norag_answer': norag_answer
    }

    # Check if the file exists, and create it if not
    #if not os.path.isfile(file_path):
    #    with open(file_path, 'w') as f:
    #        # Create an empty list in the file to store dictionaries
    #        json.dump([], f)
    #        f.write('\n')  # Add a newline to separate the list and future entries

    # Load existing data from the file (start fresh if it does not exist yet)
    try:
        existing_data = load_jsonl(file_path)
    except FileNotFoundError:
        existing_data = []

    # Append the new answer to the existing data
    existing_data.append(json_dict)

    # Save the updated data back to the file
    write_to_jsonl(file_path, existing_data)

def check_answer(prompt):
    file_path = 'saved_answers.jsonl'

    ## Check if the file exists, and create it if not
    #if not os.path.isfile(file_path):
    #    with open(file_path, 'w') as f:
    #        # Create an empty list in the file to store dictionaries
    #        json.dump([], f)
    #        f.write('\n')  # Add a newline to separate the list and future entries

    # Load existing data from the file
    try:
        existing_data = load_jsonl(file_path)
    except (FileNotFoundError, json.JSONDecodeError):
        return None, None
    if len(existing_data) == 0:
        return None, None

    # Find the answer for the given prompt, if it exists
    for entry in existing_data:
        if entry['prompt'] == prompt:
            return entry['rag_answer'], entry['norag_answer']

    # Return None if the prompt is not found
    return None, None

# Helper functions
def load_jsonl(file_path):
    data = []
    with open(file_path, 'r') as file:
        for line in file:
            # Each line is a JSON object
            item = json.loads(line)
            data.append(item)
    return data


def write_to_jsonl(file_path, data):
    # Rewrite the whole file so previously saved entries are not duplicated on each save
    with open(file_path, 'w') as file:
        for item in data:
            # Convert Python object to JSON string and write it to the file
            json_line = json.dumps(item)
            file.write(json_line + '\n')
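
# Illustrative cache format: each line of saved_answers.jsonl is one JSON object, e.g.
#   {"prompt": "What are things a repository must have?", "rag_answer": "...", "norag_answer": "..."}
# check_answer() scans these lines so a repeated prompt is answered from the cache
# instead of calling the model again.
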
def generate(prompt, history, rag_only, file_link, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
    # Serve cached answers if this prompt has been asked before
    rag_answer, norag_answer = check_answer(prompt)
    if rag_answer is not None:
        if rag_only:
            return f'* Mixtral + RAG Output:\n{rag_answer}'
        else:
            return f'* Mixtral Output:\n{norag_answer}\n\n* Mixtral + RAG Output:\n{rag_answer}'

    # Mixtral served through the Hugging Face Inference API
    mixtral = HuggingFaceInferenceAPI(
        model_name="mistralai/Mixtral-8x7B-Instruct-v0.1"
        #Mistral-7B-Instruct-v0.2
    )
    service_context = ServiceContext.from_defaults(
        llm=mixtral, embed_model="local:BAAI/bge-small-en-v1.5"
    )

    # Fetch the linked document and build a vector index over everything in content/
    download_file(file_link, file_link.split("/")[-1])
    documents = SimpleDirectoryReader("content/").load_data()
    index = VectorStoreIndex.from_documents(documents, service_context=service_context)

    # Text QA Prompt
    chat_text_qa_msgs = [
        ChatMessage(
            role=MessageRole.SYSTEM,
            content=(
                "Always answer the question, even if the context isn't helpful."
            ),
        ),
        ChatMessage(
            role=MessageRole.USER,
            content=(
                "Context information is below.\n"
                "---------------------\n"
                "{context_str}\n"
                "---------------------\n"
                "Given the context information and not prior knowledge, "
                "answer the question: {query_str}\n"
            ),
        ),
    ]
    text_qa_template = ChatPromptTemplate(chat_text_qa_msgs)

    # Refine Prompt
    chat_refine_msgs = [
        ChatMessage(
            role=MessageRole.SYSTEM,
            content=(
                "Always answer the question, even if the context isn't helpful."
            ),
        ),
        ChatMessage(
            role=MessageRole.USER,
            content=(
                "We have the opportunity to refine the original answer "
                "(only if needed) with some more context below.\n"
                "------------\n"
                "{context_msg}\n"
                "------------\n"
                "Given the new context, refine the original answer to better "
                "answer the question: {query_str}. "
                "If the context isn't useful, output the original answer again.\n"
                "Original Answer: {existing_answer}"
            ),
        ),
    ]
    refine_template = ChatPromptTemplate(chat_refine_msgs)
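
    # Note: when the query engine runs, LlamaIndex fills {context_str} with the retrieved
    # chunks and {query_str} with the user prompt; the refine template is applied only when
    # the answer has to be built up over several chunks.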

    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    stream = index.as_query_engine(
        text_qa_template=text_qa_template, refine_template=refine_template, similarity_top_k=6, temperature=temperature,
        max_new_tokens=max_new_tokens, top_p=top_p, repetition_penalty=repetition_penalty
    ).query(prompt)
    print(str(stream))
    output_rag = str(stream)  #""

    #output_norag = mixtral.complete(prompt, details=True, similarity_top_k=6, temperature=temperature,
    #                                max_new_tokens=max_new_tokens, top_p=top_p, repetition_penalty=repetition_penalty)
    #for response in str(stream):
    #    output += response
    #    yield output
    #print(output_norag)
    #result = generator.run(prompt, generation_kwargs={"max_new_tokens": 350})
    #output_norag = result["replies"][0]

    ### NORAG
    if not rag_only:
        # Baseline answer from the same engine, with prompts that ignore the retrieved context
        chat_text_qa_msgs_nr = [
            ChatMessage(
                role=MessageRole.SYSTEM,
                content=(
                    "Always answer the question"
                ),
            ),
            ChatMessage(
                role=MessageRole.USER,
                content=(
                    "answer the question: {query_str}\n"
                ),
            ),
        ]
        text_qa_template_nr = ChatPromptTemplate(chat_text_qa_msgs_nr)

        # Refine Prompt
        chat_refine_msgs_nr = [
            ChatMessage(
                role=MessageRole.SYSTEM,
                content=(
                    "Always answer the question"
                ),
            ),
            ChatMessage(
                role=MessageRole.USER,
                content=(
                    "answer the question: {query_str}. "
                    "If the context isn't useful, output the original answer again.\n"
                    "Original Answer: {existing_answer}"
                ),
            ),
        ]
        refine_template_nr = ChatPromptTemplate(chat_refine_msgs_nr)

        stream_nr = index.as_query_engine(
            text_qa_template=text_qa_template_nr, refine_template=refine_template_nr, similarity_top_k=6
        ).query(prompt)
        ###

        output_norag = str(stream_nr)
        save_answer(prompt, output_rag, output_norag)
        return f'* Mixtral Output:\n{output_norag}\n\n* Mixtral + RAG Output:\n{output_rag}'

    return f'* Mixtral + RAG Output:\n{output_rag}'

    #for response in formatted_output:
    #    output += response
    #    yield output
    #return formatted_output

def upload_file(files):
    # Helper for a file-upload component (not currently wired into the interface below)
    file_paths = [file.name for file in files]
    return file_paths

additional_inputs = [
    gr.Checkbox(
        label="RAG Only",
        interactive=True,
        value=False
    ),
    gr.Textbox(
        label="File Link",
        max_lines=1,
        interactive=True,
        value="https://arxiv.org/pdf/2401.10020.pdf"
    ),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=1024,
        minimum=0,
        maximum=2048,
        step=64,
        interactive=True,
        info="The maximum number of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    )
]
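
# Note: gr.ChatInterface passes these additional inputs to `generate` positionally, after the
# message and chat history: rag_only, file_link, temperature, max_new_tokens, top_p, repetition_penalty.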

examples = [
    ["What is a trustworthy digital repository, where can you find this information?", None, None, None, None, None, None],
    ["What are things a repository must have?", None, None, None, None, None, None],
    ["What principles should record creators follow?", None, None, None, None, None, None],
    ["Write a very short summary of Data Sanitation Techniques by Edgar Dale, and write a citation in APA style.", None, None, None, None, None, None],
    ["Can you explain how the QuickSort algorithm works and provide a Python implementation?", None, None, None, None, None, None],
    ["What are some unique features of Rust that make it stand out compared to other systems programming languages like C++?", None, None, None, None, None, None],
]

gr.ChatInterface(
    fn=generate,
    chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
    additional_inputs=additional_inputs,
    title="RAG Demo",
    examples=examples,
    #concurrency_limit=20,
).queue().launch(show_api=False, debug=True, share=True)

#iface = gr.Interface(fn=generate, inputs=["text"], outputs=["text", "text"],
#                     additional_inputs=additional_inputs, title="RAG Demo", examples=examples)
#iface.launch()