Spaces:
Sleeping
Sleeping
| import argparse | |
| from dataclasses import asdict | |
| import json | |
| import os | |
| import streamlit as st | |
| from data_driven_characters.character import get_character_definition | |
| from data_driven_characters.corpus import ( | |
| get_corpus_summaries, | |
| load_docs, | |
| ) | |
| from data_driven_characters.chatbots import ( | |
| SummaryChatBot, | |
| RetrievalChatBot, | |
| SummaryRetrievalChatBot, | |
| ) | |
| from data_driven_characters.interfaces import CommandLine, Streamlit | |
# Root directory under which create_chatbot places its per-corpus output
# (the summaries cache and the character definitions cache).
OUTPUT_ROOT = "output"
def create_chatbot(corpus, character_name, chatbot_type, retrieval_docs, summary_type):
    """Build a character chatbot grounded in a text corpus.

    The corpus is loaded and summarized, a character definition is derived
    from those summaries, and the definition is printed as indented JSON.
    Summaries and character definitions are given cache directories under
    ``OUTPUT_ROOT/<corpus base name>/summarytype_<summary_type>``.

    Args:
        corpus: Path to the corpus file. Its base name without the extension
            names the output directory.
        character_name: Name passed to ``get_character_definition``.
        chatbot_type: One of ``"summary"``, ``"retrieval"`` or
            ``"summary_retrieval"``.
        retrieval_docs: ``"raw"`` to retrieve from the page contents of the
            corpus re-chunked with ``chunk_size=256``, or ``"summarized"`` to
            retrieve from the corpus summaries. Only used by the two
            retrieval-based chatbot types, but always validated.
        summary_type: Forwarded to ``get_corpus_summaries``; also part of the
            output directory name.

    Returns:
        A ``SummaryChatBot``, ``RetrievalChatBot`` or
        ``SummaryRetrievalChatBot`` instance, depending on ``chatbot_type``.

    Raises:
        ValueError: If ``retrieval_docs`` or ``chatbot_type`` is not one of
            the supported values.
    """
    # Validate both options first so that a bad value fails before any
    # directory is created or any summary/definition is generated
    # (presumably the expensive part - TODO confirm). The check order matches
    # the order in which these errors were historically raised.
    if retrieval_docs not in ("raw", "summarized"):
        raise ValueError(f"Unknown retrieval docs type: {retrieval_docs}")
    if chatbot_type not in ("summary", "retrieval", "summary_retrieval"):
        raise ValueError(f"Unknown chatbot type: {chatbot_type}")

    # logging / cache locations
    corpus_name = os.path.splitext(os.path.basename(corpus))[0]
    output_dir = f"{OUTPUT_ROOT}/{corpus_name}/summarytype_{summary_type}"
    os.makedirs(output_dir, exist_ok=True)
    summaries_dir = f"{output_dir}/summaries"
    character_definitions_dir = f"{output_dir}/character_definitions"
    os.makedirs(character_definitions_dir, exist_ok=True)

    # load docs
    docs = load_docs(corpus_path=corpus, chunk_size=2048, chunk_overlap=64)

    # generate summaries
    corpus_summaries = get_corpus_summaries(
        docs=docs, summary_type=summary_type, cache_dir=summaries_dir
    )

    # get character definition
    character_definition = get_character_definition(
        name=character_name,
        corpus_summaries=corpus_summaries,
        cache_dir=character_definitions_dir,
    )
    print(json.dumps(asdict(character_definition), indent=4))

    # The summary-only chatbot takes no retrieval documents, so return before
    # building them and avoid re-chunking the corpus for nothing.
    if chatbot_type == "summary":
        return SummaryChatBot(character_definition=character_definition)

    # construct retrieval documents
    if retrieval_docs == "raw":
        # Retrieval uses smaller chunks than the ones used for summarization.
        documents = [
            doc.page_content
            for doc in load_docs(corpus_path=corpus, chunk_size=256, chunk_overlap=16)
        ]
    else:  # "summarized" (validated above)
        documents = corpus_summaries

    # initialize the retrieval-based chatbot
    if chatbot_type == "retrieval":
        return RetrievalChatBot(
            character_definition=character_definition,
            documents=documents,
        )
    return SummaryRetrievalChatBot(
        character_definition=character_definition,
        documents=documents,
    )
def main():
    """Parse the command-line options, build the chatbot and run the chosen interface.

    With ``--interface cli`` the chatbot is wrapped in ``CommandLine``; with
    ``--interface streamlit`` it is created through ``st.cache_resource`` and
    wrapped in ``Streamlit`` after a short page header is rendered.

    Raises:
        ValueError: If the interface is neither ``"cli"`` nor ``"streamlit"``.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--corpus", type=str, default="data/everything_everywhere_all_at_once.txt"
    )
    arg_parser.add_argument("--character_name", type=str, default="Evelyn")

    # The remaining options are all string flags restricted to a fixed set of
    # choices; register them in the same order as before so --help is unchanged.
    choice_options = (
        ("--chatbot_type", "summary_retrieval", ["summary", "retrieval", "summary_retrieval"]),
        ("--summary_type", "map_reduce", ["map_reduce", "refine"]),
        ("--retrieval_docs", "summarized", ["raw", "summarized"]),
        ("--interface", "cli", ["cli", "streamlit"]),
    )
    for flag, default_value, allowed in choice_options:
        arg_parser.add_argument(flag, type=str, default=default_value, choices=allowed)

    options = arg_parser.parse_args()

    # Positional arguments shared by both ways of calling create_chatbot.
    chatbot_args = (
        options.corpus,
        options.character_name,
        options.chatbot_type,
        options.retrieval_docs,
        options.summary_type,
    )

    interface = options.interface
    if interface == "cli":
        app = CommandLine(chatbot=create_chatbot(*chatbot_args))
    elif interface == "streamlit":
        # Wrap the factory in st.cache_resource before calling it.
        cached_factory = st.cache_resource(create_chatbot)
        bot = cached_factory(*chatbot_args)

        st.title("Data Driven Characters")
        st.write("Create your own character chatbots, grounded in existing corpora.")
        st.divider()
        st.markdown(f"**chatbot type**: *{options.chatbot_type}*")
        # Only the retrieval-based chatbot types have a retrieval source to show.
        if "retrieval" in options.chatbot_type:
            st.markdown(f"**retrieving from**: *{options.retrieval_docs} corpus*")

        app = Streamlit(chatbot=bot)
    else:
        raise ValueError(f"Unknown interface: {options.interface}")

    app.run()
# Run the entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()