Spaces:
Sleeping
Sleeping
| from saged import Pipeline | |
| from tqdm import tqdm | |
| from pathlib import Path | |
| from saged import SAGEDData as dt | |
| import streamlit as st | |
| import json | |
| import http.client | |
| from openai import AzureOpenAI | |
| import ollama | |
| import time # Use time.sleep to simulate processing steps | |
| import logging | |
| from io import StringIO | |
| import sys | |
# Create a custom logging handler to capture log messages
class StreamlitLogHandler(logging.Handler):
    """Logging handler that accumulates formatted records in an in-memory buffer.

    Each record is formatted with the handler's formatter and appended to a
    ``StringIO`` as one newline-terminated line, so the collected text can be
    read back with ``get_logs`` (e.g. to render it in the UI).
    """

    def __init__(self):
        super().__init__()
        self.log_capture_string = StringIO()

    def emit(self, record):
        """Format *record* and append it, newline-terminated, to the buffer.

        Any exception raised while formatting or writing is routed to
        ``Handler.handleError`` (the contract documented for ``emit``) instead
        of propagating into the code that issued the logging call.
        """
        try:
            message = self.format(record)
            self.log_capture_string.write(message + "\n")
        except Exception:
            self.handleError(record)

    def get_logs(self):
        """Return all text captured since construction or the last ``clear_logs``."""
        return self.log_capture_string.getvalue()

    def clear_logs(self):
        """Discard all captured text and rewind the buffer to the start."""
        self.log_capture_string.truncate(0)
        self.log_capture_string.seek(0)
# Define ContentFormatter class
class ContentFormatter:
    """Builds JSON request bodies for chat-completions style endpoints."""

    # A static method so it works both as ContentFormatter.chat_completions(...)
    # and when called through an instance; without it an instance call would
    # bind the instance to `text` and fail with a TypeError.
    @staticmethod
    def chat_completions(text, settings_params):
        """Return a JSON request body with a system + user message pair.

        Args:
            text: Content of the user message.
            settings_params: Mapping of extra top-level request fields. It is
                merged after ``messages``, so a ``"messages"`` key here would
                replace the generated message list.

        Returns:
            ``json.dumps({"messages": [...], **settings_params})``.
        """
        message = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": text},
        ]
        data = {"messages": message, **settings_params}
        return json.dumps(data)
# Define OllamaModel (For local Ollama interaction)
class OllamaModel:
    """Registers a custom Ollama model from a generated Modelfile and queries it.

    Construction builds a Modelfile from ``base_model`` and ``system_prompt``,
    adds one ``PARAMETER`` line per extra keyword argument (name lower-cased),
    and registers it under ``model_name`` via ``ollama.create``.
    """

    def __init__(self, base_model='llama3', system_prompt='You are a helpful assistant', model_name='llama3o',
                 **kwargs):
        self.base_model = base_model
        self.model_name = model_name
        # Registration happens eagerly, as a side effect of construction.
        self.model_create(model_name, system_prompt, base_model, **kwargs)

    def model_create(self, model_name, system_prompt, base_model, **kwargs):
        """Assemble the Modelfile text and register it with Ollama as *model_name*."""
        directives = [f'FROM {base_model}', f'SYSTEM {system_prompt}']
        directives.extend(f'PARAMETER {name.lower()} {setting}' for name, setting in kwargs.items())
        modelfile = ''.join(directive + '\n' for directive in directives)
        ollama.create(model=model_name, modelfile=modelfile)

    def invoke(self, prompt):
        """Run a single generation with the registered model and return its text."""
        result = ollama.generate(model=self.model_name, prompt=prompt)
        return result['response']
# Define GPTAgent (chat completions via the openai SDK's AzureOpenAI client)
class GPTAgent:
    """Chat-completion client backed by ``openai.AzureOpenAI``.

    ``model_name`` is accepted for signature symmetry with ``AzureAgent`` but is
    not used; every request is addressed to ``deployment_name``.
    """

    def __init__(self, model_name, azure_key, azure_version, azure_endpoint, deployment_name):
        self.client = AzureOpenAI(
            api_key=azure_key,
            api_version=azure_version,
            azure_endpoint=azure_endpoint,
        )
        self.deployment_name = deployment_name

    def invoke(self, prompt, settings_params=None):
        """Send *prompt* as the user turn and return the first choice's text.

        ``settings_params`` (optional) is forwarded unchanged as extra keyword
        arguments to ``chat.completions.create``.
        """
        options = settings_params or {}
        # ContentFormatter owns the system/user message layout; reuse it here.
        payload = json.loads(ContentFormatter.chat_completions(prompt, options))
        completion = self.client.chat.completions.create(
            model=self.deployment_name,
            messages=payload['messages'],
            **options,
        )
        first_choice = completion.choices[0]
        return first_choice.message.content
# Define AzureAgent (raw HTTPS chat-completions endpoint with Bearer-token auth)
class AzureAgent:
    """Minimal chat-completions client that POSTs JSON to ``/v1/chat/completions``.

    NOTE(review): the Bearer header and ``/v1/chat/completions`` path suggest an
    Azure AI serverless/ML-style endpoint rather than an Azure OpenAI resource
    (which ``GPTAgent`` covers) — confirm.
    """

    def __init__(self, model_name, azure_uri, azure_api_key):
        # The UI asks for a full endpoint URL (e.g. "https://name.example.com/"),
        # but http.client.HTTPSConnection only accepts "host[:port]" and raises
        # InvalidURL on a scheme, so reduce the input to its host part here.
        self.azure_uri = self._host_from_uri(azure_uri)
        self.headers = {
            'Authorization': f"Bearer {azure_api_key}",
            'Content-Type': 'application/json'
        }
        self.chat_formatter = ContentFormatter

    @staticmethod
    def _host_from_uri(azure_uri):
        """Return the ``host[:port]`` portion of *azure_uri*.

        URLs with a scheme are reduced to their network location; for bare
        hosts, surrounding whitespace and anything from the first ``/`` on are
        dropped. Falsy input is returned unchanged.
        """
        from urllib.parse import urlsplit

        if not azure_uri:
            return azure_uri
        uri = azure_uri.strip()
        if '://' in uri:
            return urlsplit(uri).netloc
        return uri.split('/', 1)[0]

    def invoke(self, prompt, settings_params=None):
        """POST *prompt* as a chat-completions request and return the reply text.

        Args:
            prompt: Content of the user message.
            settings_params: Optional mapping of extra top-level request fields
                merged into the JSON body.

        Returns:
            ``choices[0].message.content`` from the JSON response.

        Raises:
            RuntimeError: if the endpoint answers with a non-2xx HTTP status.
        """
        if not settings_params:
            settings_params = {}
        body = self.chat_formatter.chat_completions(prompt, {**settings_params})
        conn = http.client.HTTPSConnection(self.azure_uri)
        try:
            conn.request("POST", '/v1/chat/completions', body=body, headers=self.headers)
            response = conn.getresponse()
            data = response.read()
        finally:
            # Close even when the request or read fails so the socket is not leaked.
            conn.close()
        decoded_data = data.decode("utf-8")
        if not 200 <= response.status < 300:
            # A non-2xx reply is not a completion; report status and body rather
            # than failing later on a missing "choices" key.
            raise RuntimeError(
                f"Chat completion request failed with HTTP {response.status}: {decoded_data[:500]}"
            )
        parsed_data = json.loads(decoded_data)
        content = parsed_data["choices"][0]["message"]["content"]
        return content
# Renew Source Finder: remove cached source-finder JSON files for the given concepts
def renew_source_finder(domain, concept_list):
    """Delete the per-concept source-finder JSON files for *domain*.

    Always forgets which synthetic files were generated in this session, even
    when the inputs are invalid. If *domain* or *concept_list* is empty, shows
    an error and stops. Otherwise deletes each existing
    ``data/customized/source_finder/{domain}_{concept}_source_finder.json`` and
    reports every deletion or failure in the UI.
    """
    if 'generated_synthetic_files' in st.session_state:
        del st.session_state['generated_synthetic_files']

    if not (domain and concept_list):
        st.error("Please fill in all the required fields before proceeding.")
        return

    with st.spinner("Renewing source info files..."):
        finder_dir = Path('data/customized/source_finder/')
        for concept in concept_list:
            target = finder_dir / f'{domain}_{concept}_source_finder.json'
            if not target.exists():
                continue
            try:
                target.unlink()  # Delete the file
                st.info(f"Deleted source info file: {target}")
            except Exception as err:
                st.error(f"An error occurred while deleting the file {target}: {err}")
        st.success("Source info files renewal completed!")
def create_source_finder(domain, concept):
    """Build source-finder data that points *concept* at its local text file.

    Shows a UI warning (but still proceeds) when
    ``data/customized/local_files/{domain}/{concept}.txt`` does not exist.
    The first entry of the new SAGEDData instance's ``data`` gets:

    - ``keywords`` = ``{concept: copy of dt.default_keyword_metadata}``;
    - ``category_shared_source`` = one ``local_paths`` source whose
      specification is that file path.

    Returns a shallow copy of ``instance.data``.
    """
    local_file = f"data/customized/local_files/{domain}/{concept}.txt"
    if not Path(local_file).exists():
        st.warning(f"Local file does not exist: {local_file}")

    instance = dt.create_data(domain, concept, 'source_finder')
    entry = instance.data[0]
    entry['keywords'] = {concept: dt.default_keyword_metadata.copy()}

    shared_source = dt.default_source_item.copy()
    shared_source['source_type'] = "local_paths"
    shared_source['source_specification'] = [local_file]
    entry['category_shared_source'] = [shared_source]

    return instance.data.copy()
def check_and_create_source_files(domain, concept_list):
    """
    Ensure a usable source-finder JSON file exists for each concept in *domain*.

    For every concept, ``data/customized/source_finder/{domain}_{concept}_source_finder.json``
    is written with the data returned by ``create_source_finder`` (not an empty
    file) in two cases:

    - the file is missing;
    - ``dt.load_file`` returns ``None`` for it, which is treated as invalid.

    Files that load successfully are left untouched. The target directory is
    created if needed.
    """
    base_path = Path('data/customized/source_finder/')
    base_path.mkdir(parents=True, exist_ok=True)

    def _write_source_finder(file_path, concept):
        # Serialize a freshly built source-finder record for `concept` to `file_path`.
        data = create_source_finder(domain, concept)
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=4)

    for concept in concept_list:
        file_path = base_path / f'{domain}_{concept}_source_finder.json'
        if not file_path.exists():
            _write_source_finder(file_path, concept)
            st.info(f"Created missing source finder file: {file_path}")
        elif dt.load_file(domain, concept, 'source_finder', file_path) is None:
            # The file exists but could not be loaded, so replace it with a valid one.
            _write_source_finder(file_path, concept)
            st.info(f"Recreated invalid source finder file: {file_path}")
def clean_spaces(data):
    """
    Strip leading/trailing whitespace from a string, or from each string in a list.

    Non-string list elements are passed through as-is. A new list is returned
    and the input list is left unmodified.

    Raises:
        TypeError: if *data* is neither a ``str`` nor a ``list``.
    """
    if isinstance(data, str):
        return data.strip()
    if not isinstance(data, list):
        raise TypeError("Input should be either a string or a list of strings")

    def _strip_if_str(item):
        return item.strip() if isinstance(item, str) else item

    return list(map(_strip_if_str, data))
def create_replacement_dict(concept_list, replacer):
    """Map each concept to ``{replacer_item: {concept: replacer_item}}``.

    Example: ``create_replacement_dict(["a"], ["X", "Y"])`` returns
    ``{"a": {"X": {"a": "X"}, "Y": {"a": "Y"}}}``. Repeated concepts or
    replacer items collapse into a single key.
    """
    return {
        concept: {company: {concept: company} for company in replacer}
        for concept in concept_list
    }
# Title of the app
st.title("SAGED-bias Benchmark-Building Demo")

# Initialize session state variables; only keys that are still missing are set,
# so values survive Streamlit reruns.
for _state_key in ('domain', 'concept_list', 'gpt_model', 'azure_model', 'ollama_model'):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None
# Sidebar: Model Selection
with st.sidebar:
    st.header("Model Configuration")
    # Selection of which model to use
    model_selection = st.radio("Select Model Type", ['GPT-Azure', 'Azure', 'Ollama'])
    # Collapsible Additional Configuration Section
    with st.expander("Model Configuration"):
        if model_selection == 'Ollama':
            # Ollama Configuration
            ollama_deployment_name = st.text_input("Enter Ollama Model Deployment Name", placeholder="e.g., llama3")
            ollama_system_prompt = st.text_input("Enter System Prompt for Ollama",
                                                 placeholder="e.g., You are a helpful assistant.")
            if ollama_deployment_name and ollama_system_prompt:
                confirm_ollama = st.button("Confirm Ollama Configuration")
                if confirm_ollama:
                    try:
                        # OllamaModel calls ollama.create while being constructed, which may
                        # raise (e.g. if the Ollama service is unreachable — assumption); show
                        # the error instead of aborting the whole script run.
                        st.session_state['ollama_model'] = OllamaModel(
                            model_name=ollama_deployment_name,
                            system_prompt=ollama_system_prompt
                        )
                    except Exception as e:
                        st.error(f"Failed to configure Ollama model: {e}")
                    else:
                        st.success("Ollama model configured successfully.")
            else:
                st.warning("Please provide both Ollama deployment name and system prompt.")
        elif model_selection in ('GPT-Azure', 'Azure'):
            # GPT / Azure Configuration
            gpt_azure_endpoint = st.text_input("Enter Azure Endpoint URL",
                                               placeholder="e.g., https://your-resource-name.openai.azure.com/")
            gpt_azure_api_key = st.text_input("Enter Azure API Key", type="password")
            gpt_azure_model_name = st.text_input("Enter Azure Model Name", placeholder="e.g., GPT-3.5-turbo")
            gpt_azure_deployment_name = st.text_input("Enter Azure Deployment Name",
                                                      placeholder="e.g., gpt-3-5-deployment")
            if gpt_azure_endpoint and gpt_azure_api_key and gpt_azure_model_name and gpt_azure_deployment_name:
                confirm_gpt_azure = st.button("Confirm GPT/Azure Configuration")
                if confirm_gpt_azure:
                    try:
                        if model_selection == 'GPT-Azure':
                            st.session_state['gpt_model'] = GPTAgent(
                                model_name=gpt_azure_model_name,
                                azure_key=gpt_azure_api_key,
                                azure_version='2023-05-15',  # Update if necessary
                                azure_endpoint=gpt_azure_endpoint,
                                deployment_name=gpt_azure_deployment_name
                            )
                            st.success("GPT model configured successfully.")
                        elif model_selection == 'Azure':
                            st.session_state['azure_model'] = AzureAgent(
                                model_name=gpt_azure_model_name,
                                azure_uri=gpt_azure_endpoint,
                                azure_api_key=gpt_azure_api_key
                            )
                            st.success("Azure model configured successfully.")
                    except Exception as e:
                        # Client construction can reject bad settings; report rather than crash.
                        st.error(f"Failed to configure {model_selection} model: {e}")
            else:
                st.warning("Please provide all fields for GPT/Azure configuration.")
# Main interaction based on configured model.
# Prefer the configured model matching the type currently selected in the sidebar,
# so switching type (after configuring more than one) actually takes effect; fall
# back to any configured model in the fixed order Ollama, GPT, Azure.
_selected_model_key = {
    'Ollama': 'ollama_model',
    'GPT-Azure': 'gpt_model',
    'Azure': 'azure_model',
}.get(model_selection)

if _selected_model_key and st.session_state.get(_selected_model_key):
    model = st.session_state[_selected_model_key]
elif st.session_state.get('ollama_model'):
    model = st.session_state['ollama_model']
elif st.session_state.get('gpt_model'):
    model = st.session_state['gpt_model']
elif st.session_state.get('azure_model'):
    model = st.session_state['azure_model']
else:
    model = None
# User input: Domain and Concepts
with st.form(key='domain_concept_form'):
    domain = clean_spaces(
        st.text_input("Enter the domain: (e.g., Stocks, Education)", placeholder="Enter domain here..."))
    # User input: Concepts
    concept_text = st.text_area("Enter the concepts (separated by commas):",
                                placeholder="e.g., excel-stock, ok-stock, bad-stock")
    # Drop blank entries (empty box, "a,,b", trailing comma) and duplicates, keeping
    # first-occurrence order. Blank names would produce bogus file paths, and duplicates
    # would give two widgets the same key (f"file_{concept}" / f"prompt_{concept}") later.
    concept_list = list(dict.fromkeys(c for c in clean_spaces(concept_text.split(',')) if c))
    submit_button = st.form_submit_button(label='Confirm Domain and Concepts')

if submit_button:
    if not domain:
        st.warning("Please enter a domain.")
    elif not concept_list:
        st.warning("Please enter at least one concept.")
    else:
        st.session_state['domain'] = domain
        st.session_state['concept_list'] = concept_list
        st.success("Domain and concepts confirmed.")
# Display further options only after domain and concepts are confirmed
if st.session_state['domain'] and st.session_state['concept_list']:
    with st.expander("Additional Options"):
        # User input: Method
        scraper_method = st.radio("Select the scraper method:", ('wiki', 'local_files', 'synthetic_files'))

        # Initiate the source_finder_requirement and keyword_finder_requirement if 'wiki' is selected
        if scraper_method == 'wiki':
            st.session_state['keyword_finder_requirement'] = True
            st.session_state['source_finder_requirement'] = True
            st.session_state['check_source_finder'] = False

        # File upload for each concept if 'local_files' is selected
        if scraper_method == 'local_files':
            uploaded_files = {}
            st.session_state['keyword_finder_requirement'] = False
            st.session_state['source_finder_requirement'] = False
            st.session_state['check_source_finder'] = True
            for concept in st.session_state['concept_list']:
                uploaded_file = st.file_uploader(f"Upload file for concept '{concept}':", type=['txt'],
                                                 key=f"file_{concept}")
                if uploaded_file:
                    uploaded_files[concept] = uploaded_file
                    # Save uploaded file
                    save_path = Path(f"data/customized/local_files/{st.session_state['domain']}/{concept}.txt")
                    save_path.parent.mkdir(parents=True, exist_ok=True)
                    with open(save_path, 'wb') as f:
                        f.write(uploaded_file.getbuffer())
                    st.success(f"File for concept '{concept}' saved successfully.")

        # Generate synthetic files if 'synthetic_files' is selected
        if scraper_method == 'synthetic_files':
            # Generated text is saved to the same local_files location as uploads, so the
            # pipeline then scrapes it with the 'local_files' method.
            scraper_method = 'local_files'
            st.session_state['keyword_finder_requirement'] = False
            st.session_state['source_finder_requirement'] = False
            st.session_state['check_source_finder'] = True
            if 'generated_synthetic_files' not in st.session_state:
                st.session_state['generated_synthetic_files'] = set()
            prompt_inputs = {}
            for concept in st.session_state['concept_list']:
                if concept not in st.session_state['generated_synthetic_files']:
                    prompt_inputs[concept] = st.text_input(
                        f"Enter the prompt for concept '{concept}':",
                        value=f"Write a long article introducing the {concept} in the {st.session_state['domain']}. Use the {concept} as much as possible.",
                        key=f"prompt_{concept}"
                    )
            if st.button("Generate Synthetic Files for All Concepts"):
                if model:
                    for concept, prompt in prompt_inputs.items():
                        if not prompt:
                            continue
                        with st.spinner(f"Generating content for concept '{concept}'..."):
                            try:
                                synthetic_content = model.invoke(prompt)
                            except Exception as e:
                                # One failed model call should not abort the other concepts
                                # (or the rest of the page).
                                st.error(f"Failed to generate content for concept '{concept}': {e}")
                                continue
                            if not synthetic_content:
                                # A chat reply's content can be None/empty; f.write(None) would raise.
                                st.warning(f"Model returned no content for concept '{concept}'.")
                                continue
                            save_path = Path(
                                f"data/customized/local_files/{st.session_state['domain']}/{concept}.txt")
                            save_path.parent.mkdir(parents=True, exist_ok=True)
                            with open(save_path, 'w', encoding='utf-8') as f:
                                f.write(synthetic_content)
                            st.session_state['generated_synthetic_files'].add(concept)
                            st.success(f"Synthetic file for concept '{concept}' created successfully.")
                else:
                    st.warning("Please configure a model to generate synthetic files.")

        # User input: Prompt Method
        prompt_method = st.radio("Select the prompt method:", ('split_sentences', 'questions'), index=0)

        # User input: Max Benchmark Length
        max_benchmark_length = st.slider("Select the maximum prompts per concepts:", 1, 199, 10)

        # User input: Branching
        branching = st.radio("Enable branching:", ('Yes', 'No'), index=1)
        branching_enabled = branching == 'Yes'

        # User input: Replacer (only if branching is enabled)
        replacer = []
        replacement = {}
        if branching_enabled:
            replacer_text = st.text_area("Enter the replacer list (list of strings, separated by commas):",
                                         placeholder="e.g., Company A, Company B")
            # Ignore blank entries (empty box, "A,,B", trailing comma); otherwise they become
            # replacement entries mapping a concept to the empty string.
            replacer = [item for item in clean_spaces(replacer_text.split(',')) if item]
            replacement = create_replacement_dict(st.session_state['concept_list'], replacer)

    # Configuration
    concept_specified_config = {
        x: {'keyword_finder': {'manual_keywords': [x]}} for x in st.session_state['concept_list']
    }
    concept_configuration = {
        'keyword_finder': {
            'require': st.session_state['keyword_finder_requirement'],
            'keyword_number': 1,
        },
        'source_finder': {
            'require': st.session_state['source_finder_requirement'],
            'scrap_number': 10,
            'method': scraper_method,
        },
        'scraper': {
            'require': True,
            'method': scraper_method,
        },
        'prompt_maker': {
            'method': prompt_method,
            'generation_function': model.invoke if model else None,
            'max_benchmark_length': max_benchmark_length,
        },
    }
    domain_configuration = {
        'categories': st.session_state['concept_list'],
        'branching': branching_enabled,
        'branching_config': {
            'generation_function': model.invoke if model else None,
            'keyword_reference': st.session_state['concept_list'],
            'replacement_descriptor_require': False,
            'replacement_description': replacement,
            'branching_pairs': 'not all',
            'direction': 'not both',
        },
        'shared_config': concept_configuration,
        'category_specified_config': concept_specified_config
    }

    # Renew Source Finder Button
    if st.button('Renew Source info'):
        renew_source_finder(st.session_state['domain'], st.session_state['concept_list'])
from contextlib import redirect_stdout

# Keep a handle on the stdout that is active before any redirection in this run.
original_stdout = sys.stdout


# Define StreamToText to capture and display logs in real-time within Streamlit only
class StreamToText:
    """File-like writer that mirrors non-blank writes into the Streamlit log panel.

    Each non-blank message is stripped, appended to
    ``st.session_state.log_messages``, and the whole list is re-rendered into
    ``log_placeholder``.
    """

    def write(self, message):
        if message.strip():  # Avoid adding empty messages
            # Only append to Streamlit display, not the terminal
            st.session_state.log_messages.append(message.strip())
            log_placeholder.text("\n".join(st.session_state.log_messages))  # Flush updated logs

    def flush(self):
        pass  # Required for compatibility with sys.stdout


# Initialize session state for log messages
if 'log_messages' not in st.session_state:
    st.session_state.log_messages = []

# sys.stdout is process-wide, so it is redirected only while the benchmark is being
# built (below) and restored afterwards, instead of being swapped permanently on every rerun.
stream_to_text = StreamToText()

# Placeholder for displaying logs within a collapsible expander
with st.expander("Show Logs", expanded=False):
    log_placeholder = st.empty()  # Placeholder for dynamic log display

# Define the Create Benchmark button
if st.button("Create a Benchmark"):
    if not (st.session_state['domain'] and st.session_state['concept_list']):
        # domain_configuration and the scraper flags only exist once the domain and
        # concepts have been confirmed.
        st.warning("Please confirm the domain and concepts before creating a benchmark.")
    else:
        st.session_state.log_messages = []  # Clear previous logs
        with st.spinner("Creating benchmark..."):
            try:
                # Route print() output from the steps below into the log panel.
                with redirect_stdout(stream_to_text):
                    if st.session_state.get('check_source_finder'):
                        # Check for relevant materials
                        check_and_create_source_files(st.session_state['domain'],
                                                      st.session_state['concept_list'])
                    # Progress bar: advances on a timer before the real work starts (cosmetic).
                    progress_bar = st.progress(0)
                    for i in tqdm(range(1, 101)):
                        progress_bar.progress(i)
                        time.sleep(0.05)  # Short delay to simulate processing time
                    # Run the benchmark creation function
                    benchmark = Pipeline.domain_benchmark_building(st.session_state['domain'],
                                                                   domain_configuration)
                st.success("Benchmark creation completed!")
                st.dataframe(benchmark.data)
            except Exception as e:
                st.error(f"An error occurred during benchmark creation: {e}")