Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import time | |
| import os | |
| from spinoza_project.source.backend.llm_utils import ( | |
| get_llm_api, | |
| get_vectorstore_api, | |
| ) | |
| from spinoza_project.source.frontend.utils import ( | |
| init_env, | |
| parse_output_llm_with_sources, | |
| ) | |
| from spinoza_project.source.frontend.gradio_utils import ( | |
| get_sources, | |
| set_prompts, | |
| get_config, | |
| get_prompts, | |
| get_assets, | |
| get_theme, | |
| get_init_prompt, | |
| get_synthesis_prompt, | |
| get_qdrants, | |
| get_qdrants_public, | |
| start_agents, | |
| end_agents, | |
| next_call, | |
| zip_longest_fill, | |
| reformulate, | |
| answer, | |
| ) | |
| from assets.utils_javascript import ( | |
| accordion_trigger, | |
| accordion_trigger_end, | |
| accordion_trigger_spinoza, | |
| accordion_trigger_spinoza_end, | |
| update_footer, | |
| ) | |
init_env()
config = get_config()

## Loading Prompts
print("Loading Prompts")
prompts = get_prompts(config)
chat_qa_prompts, chat_reformulation_prompts = set_prompts(prompts, config)
synthesis_prompt_template = get_synthesis_prompt(config)

## Building LLM
# When an EKI OpenAI deployment is configured, no Groq model name is passed
# to get_llm_api (an empty string is used instead).
print("Building LLM")
groq_model_name = (
    "" if os.getenv("EKI_OPENAI_LLM_DEPLOYMENT_NAME") else config["groq_model_name"]
)
llm = get_llm_api(groq_model_name)

## Loading BDDs
print("Loading Databases")
qdrants = get_qdrants(config)
if not os.getenv("EKI_OPENAI_LLM_DEPLOYMENT_NAME"):
    # Without the EKI deployment: merge in the stores from get_qdrants_public
    # (they win on key collisions) and leave the "presse"/"afp" API vector
    # stores unset.
    qdrants_public = get_qdrants_public(config)
    qdrants = {**qdrants, **qdrants_public}
    bdd_presse = bdd_afp = None
else:
    bdd_presse = get_vectorstore_api("presse")
    bdd_afp = get_vectorstore_api("afp")

## Loading Assets
css, source_information = get_assets()
theme = get_theme()
init_prompt = get_init_prompt()
def reformulate_questions(
    question,
    llm=llm,
    chat_reformulation_prompts=chat_reformulation_prompts,
    config=config,
):
    """Stream one reformulation of ``question`` per configured tab.

    Args:
        question: the user's question, taken from the ``ask`` textbox.
        llm: LLM client handed to ``reformulate``.
        chat_reformulation_prompts: per-tab reformulation prompts.
        config: application configuration; ``config["tabs"]`` is iterated.

    Yields:
        Each element produced by ``zip_longest_fill`` over the per-tab
        ``reformulate`` results, after a 20 ms pause, so Gradio can refresh
        the per-tab question states progressively.
    """
    per_tab_streams = [
        reformulate(llm, chat_reformulation_prompts, question, tab, config=config)
        for tab in config["tabs"]
    ]
    for combined_step in zip_longest_fill(*per_tab_streams):
        time.sleep(0.02)
        yield combined_step
def retrieve_sources(
    *questions,
    qdrants=qdrants,
    bdd_presse=bdd_presse,
    bdd_afp=bdd_afp,
    config=config,
):
    """Retrieve the sources for every agent question.

    Args:
        *questions: one question per tab, in ``config["tabs"]`` order (this
            is how the Gradio ``inputs`` list is wired).
        qdrants: vector stores built at import time.
        bdd_presse, bdd_afp: API vector stores; ``None`` unless
            ``EKI_OPENAI_LLM_DEPLOYMENT_NAME`` was set at import time.
        config: application configuration.

    Returns:
        A tuple: the formatted sources (shown in the Sources HTML panel),
        followed by every element of the text sources returned by
        ``get_sources`` (the event wiring expects one per tab).
    """
    sources_display, per_tab_text = get_sources(
        questions, qdrants, bdd_presse, bdd_afp, config
    )
    return (sources_display,) + tuple(per_tab_text)
def answer_questions(
    *questions_sources, llm=llm, chat_qa_prompts=chat_qa_prompts, config=config
):
    """Stream every agent's answer into its own chatbot.

    Args:
        *questions_sources: a flat sequence made of the per-tab questions
            followed by the per-tab source texts, both in ``config["tabs"]``
            order. The first half are questions, the second half the
            matching sources; this mirrors how the Gradio ``inputs`` list
            is concatenated.
        llm: LLM client handed to ``answer``.
        chat_qa_prompts: per-tab QA prompts.
        config: application configuration; ``config["tabs"]`` is iterated.

    Yields:
        A list with one chatbot history per tab, each a single
        ``(question, parsed_answer)`` turn, refreshed for every element
        produced by ``zip_longest_fill`` (after a 20 ms pause).
    """
    midpoint = len(questions_sources) // 2
    questions = list(questions_sources[:midpoint])
    sources = list(questions_sources[midpoint:])

    answer_streams = [
        answer(llm, chat_qa_prompts, question, source, tab, config)
        for question, source, tab in zip(questions, sources, config["tabs"])
    ]
    for step in zip_longest_fill(*answer_streams):
        time.sleep(0.02)
        yield [
            [(question, parse_output_llm_with_sources(ans))]
            for question, ans in zip(questions, step)
        ]
def get_synthesis(
    question,
    *answers,
    llm=llm,
    synthesis_prompt_template=synthesis_prompt_template,
    config=config,
):
    """Stream the Spinoza synthesis built from the agents' answers.

    Args:
        question: question shown in the Spinoza chat; paragraph tags are
            stripped from the copy sent to the synthesis prompt.
        *answers: current value of each agent chatbot, in ``config["tabs"]``
            order.
        llm: LLM client; ``llm.stream`` is called with the synthesis prompt
            template and the prompt variables.
        synthesis_prompt_template: prompt template for the synthesis.
        config: application configuration; ``config["tabs"]`` is read.

    Yields:
        A single-turn chatbot history ``[(question, text)]``. It is either a
        fixed "no source" message when no agent answer qualifies, or the
        parsed synthesis for every chunk streamed by ``llm.stream``.
    """
    # Keep only agent answers whose string form is at least 100 characters
    # (presumably to skip empty/placeholder chats; TODO confirm), labelled
    # with their tab name and stripped of paragraph tags. The local is not
    # called ``answer`` so the imported ``answer`` helper is not shadowed.
    usable_answers = [
        f"{tab}\n{tab_answer}".replace("<p>", "").replace("</p>\n", "")
        for tab, tab_answer in zip(config["tabs"], answers)
        if len(str(tab_answer)) >= 100
    ]

    if not usable_answers:
        # This function is a generator (it yields below). A plain
        # ``return "<message>"`` only sets StopIteration.value, which Gradio
        # does not display, so the user used to see nothing here. Yield the
        # message as a chat turn, in the same shape as the streaming branch.
        yield [
            (
                question,
                "Aucune source n'a pu être identifiée pour répondre, veuillez modifier votre question",
            )
        ]
        return

    prompt_variables = {
        "question": question.replace("<p>", "").replace("</p>\n", ""),
        "answers": "\n\n".join(usable_answers),
    }
    for chunk in llm.stream(synthesis_prompt_template, prompt_variables):
        time.sleep(0.01)
        yield [(question, parse_output_llm_with_sources(chunk))]
def _per_tab(components):
    """Return ``components[tab]`` for every tab, in ``config["tabs"]`` order.

    A new list is built on each call, so no two Gradio events share the same
    ``inputs``/``outputs`` list object.
    """
    return [components[tab] for tab in config["tabs"]]


with gr.Blocks(
    title="🔍 Spinoza",
    css=css,
    js=update_footer(),
    theme=theme,
) as demo:
    chatbots = {}
    # NOTE(review): question, docs_textbox, component_sources and tab_states
    # are not referenced anywhere in this file; confirm before removing.
    question = gr.State("")
    docs_textbox = gr.State([""])
    agent_questions = {elt: gr.State("") for elt in config["tabs"]}
    component_sources = {elt: gr.State("") for elt in config["tabs"]}
    text_sources = {elt: gr.State("") for elt in config["tabs"]}
    tab_states = {elt: gr.State(elt) for elt in config["tabs"]}

    with gr.Tab("Q&A", elem_id="main-component"):
        with gr.Row(elem_id="chatbot-row"):
            with gr.Column(scale=2, elem_id="center-panel"):
                with gr.Group(elem_id="chatbot-group"):
                    # One accordion + chatbot per configured tab, then a final
                    # "Spinoza" accordion (open by default) for the synthesis.
                    for tab in [*config["tabs"], "Spinoza"]:
                        if tab == "Spinoza":
                            agent_name = "Spinoza"
                            elem_id = f"accordion-{tab}"
                            elem_classes = "accordion accordion-agent spinoza-agent"
                        else:
                            agent_name = f"Agent {config['source_mapping'][tab]}"
                            elem_id = f"accordion-{config['source_mapping'][tab]}"
                            elem_classes = "accordion accordion-agent"
                        is_spinoza = agent_name == "Spinoza"
                        with gr.Accordion(
                            agent_name,
                            open=is_spinoza,
                            elem_id=elem_id,
                            elem_classes=elem_classes,
                        ):
                            chatbots[tab] = gr.Chatbot(
                                value=[(None, init_prompt)] if is_spinoza else None,
                                show_copy_button=True,
                                show_share_button=False,
                                show_label=False,
                                elem_id=f"chatbot-{agent_name.lower().replace(' ', '-')}",
                                layout="panel",
                                avatar_images=(
                                    "./assets/logos/help.png",
                                    "./assets/logos/spinoza.png" if is_spinoza else None,
                                ),
                            )
                with gr.Row(elem_id="input-message"):
                    ask = gr.Textbox(
                        placeholder="Ask me anything here!",
                        show_label=False,
                        scale=7,
                        lines=1,
                        interactive=True,
                        elem_id="input-textbox",
                    )
            with gr.Column(scale=1, variant="panel", elem_id="right-panel"):
                with gr.TabItem("Sources", elem_id="tab-sources", id=0):
                    sources_textbox = gr.HTML(
                        show_label=False, elem_id="sources-textbox"
                    )

    with gr.Tab("Source information", elem_id="source-component"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown(source_information)

    with gr.Tab("Contact", elem_id="contact-component"):
        with gr.Row():
            with gr.Column(scale=1):
                # NOTE(review): "[email protected]" looks like e-mail
                # obfuscation residue from a web copy of this file rather
                # than a real address; restore the actual contact address.
                gr.Markdown("For any issue contact **[email protected]**.")

    # Submit pipeline, in order: start_agents -> per-tab reformulation ->
    # source retrieval -> per-tab answers -> Spinoza synthesis -> end_agents,
    # with the accordion JS triggers run between steps via next_call.
    ask.submit(
        start_agents, inputs=[], outputs=[chatbots["Spinoza"]], js=accordion_trigger()
    ).then(
        fn=reformulate_questions,
        inputs=[ask],
        outputs=_per_tab(agent_questions),
    ).then(
        fn=retrieve_sources,
        inputs=_per_tab(agent_questions),
        outputs=[sources_textbox] + _per_tab(text_sources),
    ).then(
        fn=answer_questions,
        # Flat layout expected by answer_questions: questions, then sources.
        inputs=_per_tab(agent_questions) + _per_tab(text_sources),
        outputs=_per_tab(chatbots),
    ).then(
        fn=next_call, inputs=[], outputs=[], js=accordion_trigger_end()
    ).then(
        fn=next_call, inputs=[], outputs=[], js=accordion_trigger_spinoza()
    ).then(
        fn=get_synthesis,
        # NOTE(review): the synthesis uses the reformulated question of the
        # *second* configured tab, and raises IndexError when fewer than two
        # tabs are configured. Confirm this is intended (the raw ``ask``
        # textbox may be what is wanted).
        inputs=[agent_questions[list(config["tabs"])[1]]] + _per_tab(chatbots),
        outputs=[chatbots["Spinoza"]],
    ).then(
        fn=next_call, inputs=[], outputs=[], js=accordion_trigger_spinoza_end()
    ).then(
        fn=end_agents, inputs=[], outputs=[]
    )
if __name__ == "__main__":
    # Enable the request queue on the Blocks app, then start it in debug mode.
    queued_demo = demo.queue()
    queued_demo.launch(debug=True)