Spaces:
Sleeping
Sleeping
Commit
·
29a316b
1
Parent(s):
40ea049
init
Browse files- app.py +479 -0
- requirements.txt +6 -0
app.py
ADDED
|
@@ -0,0 +1,479 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from saged import Pipeline
|
| 2 |
+
from tqdm import tqdm
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from saged import SAGEDData as dt
|
| 5 |
+
import streamlit as st
|
| 6 |
+
import json
|
| 7 |
+
import http.client
|
| 8 |
+
from openai import AzureOpenAI
|
| 9 |
+
import ollama
|
| 10 |
+
import time # Use time.sleep to simulate processing steps
|
| 11 |
+
import logging
|
| 12 |
+
from io import StringIO
|
| 13 |
+
import sys
|
| 14 |
+
|
| 15 |
+
# Create a custom logging handler to capture log messages
|
| 16 |
+
class StreamlitLogHandler(logging.Handler):
    """Logging handler that captures formatted log records into an in-memory buffer."""

    def __init__(self):
        super().__init__()
        # Buffer that accumulates every formatted record, one per line.
        self.log_capture_string = StringIO()

    def emit(self, record):
        """Format *record* and append it, newline-terminated, to the buffer."""
        self.log_capture_string.write(f"{self.format(record)}\n")

    def get_logs(self):
        """Return everything captured so far as a single string."""
        return self.log_capture_string.getvalue()

    def clear_logs(self):
        """Empty the capture buffer and rewind it for reuse."""
        buf = self.log_capture_string
        buf.truncate(0)
        buf.seek(0)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# Define ContentFormatter class
|
| 38 |
+
class ContentFormatter:
    """Helpers for shaping a prompt into an OpenAI-style chat-completion payload."""

    @staticmethod
    def chat_completions(text, settings_params):
        """Return a JSON string holding a system+user message pair plus *settings_params*.

        The settings (temperature, max_tokens, ...) are merged at the top level
        of the payload, alongside the "messages" key.
        """
        payload = {
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": text},
            ],
            **settings_params,
        }
        return json.dumps(payload)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# Define OllamaModel (For local Ollama interaction)
|
| 50 |
+
class OllamaModel:
    """Thin wrapper that derives a named local Ollama model from a base model.

    Constructing an instance immediately registers the derived model with the
    local Ollama daemon via ``ollama.create``.
    """

    def __init__(self, base_model='llama3', system_prompt='You are a helpful assistant', model_name='llama3o',
                 **kwargs):
        # Record identifiers, then register the derived model right away.
        self.base_model = base_model
        self.model_name = model_name
        self.model_create(model_name, system_prompt, base_model, **kwargs)

    def model_create(self, model_name, system_prompt, base_model, **kwargs):
        """Assemble a Modelfile (FROM/SYSTEM plus PARAMETER lines) and create the model.

        Extra keyword arguments become ``PARAMETER <key> <value>`` lines, with
        keys lower-cased as Ollama expects.
        """
        spec_lines = [f'FROM {base_model}', f'SYSTEM {system_prompt}']
        spec_lines.extend(f'PARAMETER {key.lower()} {value}' for key, value in kwargs.items())
        modelfile = '\n'.join(spec_lines) + '\n'
        ollama.create(model=model_name, modelfile=modelfile)

    def invoke(self, prompt):
        """Generate a completion for *prompt* and return only the response text."""
        return ollama.generate(model=self.model_name, prompt=prompt)['response']
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# Define GPTAgent (For OpenAI GPT models)
|
| 70 |
+
class GPTAgent:
    """Chat wrapper around an Azure-hosted OpenAI deployment via the AzureOpenAI client."""

    def __init__(self, model_name, azure_key, azure_version, azure_endpoint, deployment_name):
        # NOTE(review): model_name is accepted for interface symmetry with the
        # other agents, but the Azure client addresses the model through
        # deployment_name only.
        self.client = AzureOpenAI(
            api_key=azure_key,
            api_version=azure_version,
            azure_endpoint=azure_endpoint,
        )
        self.deployment_name = deployment_name

    def invoke(self, prompt, settings_params=None):
        """Send *prompt* as a chat completion and return the first choice's text.

        *settings_params* (temperature etc.) are forwarded both into the message
        envelope built by ContentFormatter and to the completions call itself.
        """
        settings_params = settings_params or {}
        # Reuse the shared formatter so the message envelope matches AzureAgent's.
        envelope = ContentFormatter.chat_completions(prompt, settings_params)
        response = self.client.chat.completions.create(
            model=self.deployment_name,
            messages=json.loads(envelope)['messages'],
            **settings_params,
        )
        return response.choices[0].message.content
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# Define AzureAgent (For Azure OpenAI models)
|
| 92 |
+
class AzureAgent:
    """Chat wrapper that POSTs directly to an Azure-hosted /v1/chat/completions endpoint."""

    def __init__(self, model_name, azure_uri, azure_api_key):
        # NOTE(review): model_name is unused here — the endpoint itself selects
        # the model. Kept for interface symmetry with GPTAgent.
        self.azure_uri = azure_uri
        self.headers = {
            'Authorization': f"Bearer {azure_api_key}",
            'Content-Type': 'application/json'
        }
        self.chat_formatter = ContentFormatter

    def invoke(self, prompt, settings_params=None):
        """POST *prompt* to the chat-completions endpoint and return the reply text.

        Raises whatever http.client / json raise on transport or parse failure.
        """
        if not settings_params:
            settings_params = {}
        body = self.chat_formatter.chat_completions(prompt, {**settings_params})
        conn = http.client.HTTPSConnection(self.azure_uri)
        try:
            # Fix: close the connection even when the request or read raises,
            # otherwise the socket leaks on every failed call.
            conn.request("POST", '/v1/chat/completions', body=body, headers=self.headers)
            response = conn.getresponse()
            data = response.read()
        finally:
            conn.close()
        decoded_data = data.decode("utf-8")
        parsed_data = json.loads(decoded_data)
        content = parsed_data["choices"][0]["message"]["content"]
        return content
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# Renew Source Finder Button
|
| 118 |
+
def renew_source_finder(domain, concept_list):
    """Delete cached source-finder JSON files so they are rebuilt on the next run.

    Also forgets which synthetic files were generated. Shows a Streamlit error
    and does nothing else when *domain* or *concept_list* is missing.
    """
    if 'generated_synthetic_files' in st.session_state:
        del st.session_state['generated_synthetic_files']
    if not domain or not concept_list:
        st.error("Please fill in all the required fields before proceeding.")
        return
    with st.spinner("Renewing source info files..."):
        base_path = Path('data/customized/source_finder/')
        for concept in concept_list:
            file_path = base_path / f'{domain}_{concept}_source_finder.json'
            if not file_path.exists():
                continue
            try:
                file_path.unlink()  # Remove the stale cache file.
                st.info(f"Deleted source info file: {file_path}")
            except Exception as e:
                # Report per-file failures but keep processing the rest.
                st.error(f"An error occurred while deleting the file {file_path}: {e}")
        st.success("Source info files renewal completed!")
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def create_source_finder(domain, concept):
    """Build a source-finder payload for *concept* pointing at its local text file.

    Emits a non-fatal Streamlit warning when the expected local file is absent,
    then returns a copy of the populated SAGEDData data list.
    """
    local_path = f"data/customized/local_files/{domain}/{concept}.txt"
    if not Path(local_path).exists():
        # Non-fatal: the entry is still created so the pipeline can proceed.
        st.warning(f"Local file does not exist: {local_path}")
    instance = dt.create_data(domain, concept, 'source_finder')
    record = instance.data[0]
    record['keywords'] = {concept: dt.default_keyword_metadata.copy()}
    # Point the shared source entry at the concept's local file.
    shared_source = dt.default_source_item.copy()
    shared_source['source_type'] = "local_paths"
    shared_source['source_specification'] = [local_path]
    record['category_shared_source'] = [shared_source]
    return instance.data.copy()
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def check_and_create_source_files(domain, concept_list):
    """
    Checks if the required source finder files exist for each concept in the domain.
    If a file does not exist or is invalid, it creates a fresh JSON file for that
    concept via create_source_finder.
    """
    base_path = Path('data/customized/source_finder/')
    base_path.mkdir(parents=True, exist_ok=True)

    def _write_source_file(concept, file_path, reason):
        # Build a fresh source-finder payload and persist it as pretty-printed JSON.
        data = create_source_finder(domain, concept)
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=4)
        st.info(f"{reason} source finder file: {file_path}")

    for concept in concept_list:
        file_path = base_path / f'{domain}_{concept}_source_finder.json'
        if not file_path.exists():
            # Missing entirely: create it.
            _write_source_file(concept, file_path, "Created missing")
        else:
            # Attempt to load the file to verify its validity.
            instance = dt.load_file(domain, concept, 'source_finder', file_path)
            if instance is None:
                # Present but unloadable: overwrite with a valid payload.
                _write_source_file(concept, file_path, "Recreated invalid")
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def clean_spaces(data):
    """Strip leading/trailing whitespace from a string, or from each string in a list.

    Non-string elements of a list pass through unchanged. Any other input type
    raises TypeError.
    """
    if isinstance(data, str):
        return data.strip()
    if isinstance(data, list):
        cleaned = []
        for item in data:
            cleaned.append(item.strip() if isinstance(item, str) else item)
        return cleaned
    raise TypeError("Input should be either a string or a list of strings")
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def create_replacement_dict(concept_list, replacer):
    """Build the branching replacement mapping.

    For every concept, each replacement target maps to {concept: target}, i.e.
    "replace mentions of this concept with this target".
    """
    return {
        concept: {target: {concept: target} for target in replacer}
        for concept in concept_list
    }
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
# Title of the app
|
| 198 |
+
st.title("SAGED-bias Benchmark-Building Demo")

# Initialize session state variables up front so later reads of these keys
# never KeyError on the first script run.
if 'domain' not in st.session_state:
    st.session_state['domain'] = None
if 'concept_list' not in st.session_state:
    st.session_state['concept_list'] = None
# One slot per model backend; at most one is expected to be set at a time.
if 'gpt_model' not in st.session_state:
    st.session_state['gpt_model'] = None
if 'azure_model' not in st.session_state:
    st.session_state['azure_model'] = None
if 'ollama_model' not in st.session_state:
    st.session_state['ollama_model'] = None
|
| 211 |
+
|
| 212 |
+
# Sidebar: Model Selection
|
| 213 |
+
with st.sidebar:
    st.header("Model Configuration")

    # Selection of which model backend to use.
    model_selection = st.radio("Select Model Type", ['GPT-Azure', 'Azure', 'Ollama'])

    # Collapsible section holding the backend-specific credential inputs.
    with st.expander("Model Configuration"):
        if model_selection == 'Ollama':
            # Ollama configuration: local deployment name plus system prompt.
            ollama_deployment_name = st.text_input("Enter Ollama Model Deployment Name", placeholder="e.g., llama3")
            ollama_system_prompt = st.text_input("Enter System Prompt for Ollama",
                                                 placeholder="e.g., You are a helpful assistant.")

            if ollama_deployment_name and ollama_system_prompt:
                confirm_ollama = st.button("Confirm Ollama Configuration")
                if confirm_ollama:
                    # Constructing OllamaModel registers the derived model with
                    # the local Ollama daemon immediately.
                    st.session_state['ollama_model'] = OllamaModel(
                        model_name=ollama_deployment_name,
                        system_prompt=ollama_system_prompt
                    )
                    st.success("Ollama model configured successfully.")
            else:
                st.warning("Please provide both Ollama deployment name and system prompt.")

        elif model_selection == 'GPT-Azure' or model_selection == 'Azure':
            # GPT / Azure configuration: both variants share these four fields.
            gpt_azure_endpoint = st.text_input("Enter Azure Endpoint URL",
                                               placeholder="e.g., https://your-resource-name.openai.azure.com/")
            gpt_azure_api_key = st.text_input("Enter Azure API Key", type="password")
            gpt_azure_model_name = st.text_input("Enter Azure Model Name", placeholder="e.g., GPT-3.5-turbo")
            gpt_azure_deployment_name = st.text_input("Enter Azure Deployment Name",
                                                      placeholder="e.g., gpt-3-5-deployment")

            if gpt_azure_endpoint and gpt_azure_api_key and gpt_azure_model_name and gpt_azure_deployment_name:
                confirm_gpt_azure = st.button("Confirm GPT/Azure Configuration")
                if confirm_gpt_azure:
                    if model_selection == 'GPT-Azure':
                        # SDK-based client against an Azure OpenAI deployment.
                        st.session_state['gpt_model'] = GPTAgent(
                            model_name=gpt_azure_model_name,
                            azure_key=gpt_azure_api_key,
                            azure_version='2023-05-15',  # Update if necessary.
                            azure_endpoint=gpt_azure_endpoint,
                            deployment_name=gpt_azure_deployment_name
                        )
                        st.success("GPT model configured successfully.")
                    elif model_selection == 'Azure':
                        # Raw-HTTP client against an Azure-hosted endpoint.
                        st.session_state['azure_model'] = AzureAgent(
                            model_name=gpt_azure_model_name,
                            azure_uri=gpt_azure_endpoint,
                            azure_api_key=gpt_azure_api_key
                        )
                        st.success("Azure model configured successfully.")
            else:
                st.warning("Please provide all fields for GPT/Azure configuration.")
|
| 268 |
+
|
| 269 |
+
# Main interaction based on configured model
|
| 270 |
+
# Resolve the active model for this run: the first configured backend wins,
# checked in the order Ollama, GPT, Azure; None when nothing is configured.
model = (st.session_state.get('ollama_model')
         or st.session_state.get('gpt_model')
         or st.session_state.get('azure_model')
         or None)
|
| 278 |
+
|
| 279 |
+
# User input: Domain and Concepts
|
| 280 |
+
# Form collecting the benchmark domain and its comma-separated concepts.
with st.form(key='domain_concept_form'):
    domain = clean_spaces(
        st.text_input("Enter the domain: (e.g., Stocks, Education)", placeholder="Enter domain here..."))

    # User input: concepts, split on commas and whitespace-stripped.
    concept_text = st.text_area("Enter the concepts (separated by commas):",
                                placeholder="e.g., excel-stock, ok-stock, bad-stock")
    concept_list = clean_spaces(concept_text.split(','))

    submit_button = st.form_submit_button(label='Confirm Domain and Concepts')

    if submit_button:
        if not domain:
            st.warning("Please enter a domain.")
        # concept_text.strip() == "" catches the case where split(',') yields
        # a single empty string, which would make concept_list truthy.
        elif not concept_list or concept_text.strip() == "":
            st.warning("Please enter at least one concept.")
        else:
            # Persist the confirmed inputs for the rest of the app.
            st.session_state['domain'] = domain
            st.session_state['concept_list'] = concept_list
            st.success("Domain and concepts confirmed.")
|
| 300 |
+
|
| 301 |
+
# Display further options only after domain and concepts are confirmed
|
| 302 |
+
# Further options are shown only after domain and concepts are confirmed.
# NOTE(review): the session-state flags keyword_finder_requirement /
# source_finder_requirement / check_source_finder and the configuration dicts
# below are only defined when this branch runs — later code that reads them
# assumes this expander has been rendered at least once.
if st.session_state['domain'] and st.session_state['concept_list']:
    with st.expander("Additional Options"):
        # User input: how source material is obtained.
        scraper_method = st.radio("Select the scraper method:", (('wiki', 'local_files', 'synthetic_files')))

        # 'wiki' needs keyword and source discovery; no local files to check.
        if scraper_method == 'wiki':
            st.session_state['keyword_finder_requirement'] = True
            st.session_state['source_finder_requirement'] = True
            st.session_state['check_source_finder'] = False

        # 'local_files': user uploads one .txt per concept; discovery disabled.
        if scraper_method == 'local_files':
            uploaded_files = {}
            st.session_state['keyword_finder_requirement'] = False
            st.session_state['source_finder_requirement'] = False
            st.session_state['check_source_finder'] = True
            for concept in st.session_state['concept_list']:
                uploaded_file = st.file_uploader(f"Upload file for concept '{concept}':", type=['txt'],
                                                 key=f"file_{concept}")
                if uploaded_file:
                    uploaded_files[concept] = uploaded_file
                    # Save the upload where create_source_finder expects it.
                    save_path = Path(f"data/customized/local_files/{st.session_state['domain']}/{concept}.txt")
                    save_path.parent.mkdir(parents=True, exist_ok=True)
                    with open(save_path, 'wb') as f:
                        f.write(uploaded_file.getbuffer())
                    st.success(f"File for concept '{concept}' saved successfully.")

        # 'synthetic_files': the configured model writes the per-concept files,
        # then the branch falls through to the 'local_files' pipeline settings.
        if scraper_method == 'synthetic_files':
            scraper_method = 'local_files'
            st.session_state['keyword_finder_requirement'] = False
            st.session_state['source_finder_requirement'] = False
            st.session_state['check_source_finder'] = True
            if 'generated_synthetic_files' not in st.session_state:
                st.session_state['generated_synthetic_files'] = set()

            # One editable prompt per concept that has not been generated yet.
            prompt_inputs = {}
            for concept in st.session_state['concept_list']:
                if concept not in st.session_state['generated_synthetic_files']:
                    prompt_inputs[concept] = st.text_input(
                        f"Enter the prompt for concept '{concept}':",
                        value=f"Write a long article introducing the {concept} in the {st.session_state['domain']}. Use the {concept} as much as possible.",
                        key=f"prompt_{concept}"
                    )

            if st.button("Generate Synthetic Files for All Concepts"):
                if model:
                    for concept, prompt in prompt_inputs.items():
                        if prompt:
                            with st.spinner(f"Generating content for concept '{concept}'..."):
                                # Model output is written to the same path the
                                # 'local_files' upload branch uses.
                                synthetic_content = model.invoke(prompt)
                                save_path = Path(
                                    f"data/customized/local_files/{st.session_state['domain']}/{concept}.txt")
                                save_path.parent.mkdir(parents=True, exist_ok=True)
                                with open(save_path, 'w', encoding='utf-8') as f:
                                    f.write(synthetic_content)
                                st.session_state['generated_synthetic_files'].add(concept)
                                st.success(f"Synthetic file for concept '{concept}' created successfully.")
                else:
                    st.warning("Please configure a model to generate synthetic files.")

        # User input: prompt-construction method.
        prompt_method = st.radio("Select the prompt method:", ('split_sentences', 'questions'), index = 0)

        # User input: cap on prompts generated per concept.
        max_benchmark_length = st.slider("Select the maximum prompts per concepts:", 1, 199, 10)

        # User input: branching (counterfactual replacement) toggle.
        branching = st.radio("Enable branching:", ('Yes', 'No'), index=1)
        branching_enabled = True if branching == 'Yes' else False

        # User input: replacement targets, only relevant when branching is on.
        replacer = []
        replacement = {}
        if branching_enabled:
            replacer_text = st.text_area("Enter the replacer list (list of strings, separated by commas):",
                                         placeholder="e.g., Company A, Company B")
            replacer = clean_spaces(replacer_text.split(','))
            replacement = create_replacement_dict(st.session_state['concept_list'], replacer)

        # Per-concept configuration: each concept is its own manual keyword.
        concept_specified_config = {
            x: {'keyword_finder': {'manual_keywords': [x]}} for x in st.session_state['concept_list']
        }
        # Shared pipeline-stage configuration for all concepts.
        concept_configuration = {
            'keyword_finder': {
                'require': st.session_state['keyword_finder_requirement'],
                'keyword_number': 1,
            },
            'source_finder': {
                'require': st.session_state['source_finder_requirement'],
                'scrap_number': 10,
                'method': scraper_method,
            },
            'scraper': {
                'require': True,
                'method': scraper_method,
            },
            'prompt_maker': {
                'method': prompt_method,
                # May be None when no model is configured — presumably the
                # pipeline tolerates that; TODO confirm against saged.Pipeline.
                'generation_function': model.invoke if model else None,
                'max_benchmark_length': max_benchmark_length,
            },
        }
        # Top-level configuration consumed by Pipeline.domain_benchmark_building.
        domain_configuration = {
            'categories': st.session_state['concept_list'],
            'branching': branching_enabled,
            'branching_config': {
                'generation_function': model.invoke if model else None,
                'keyword_reference': st.session_state['concept_list'],
                'replacement_descriptor_require': False,
                'replacement_description': replacement,
                'branching_pairs': 'not all',
                'direction': 'not both',
            },
            'shared_config': concept_configuration,
            'category_specified_config': concept_specified_config
        }

    # Renew Source Finder Button: deletes cached source-finder files.
    if st.button('Renew Source info'):
        renew_source_finder(st.session_state['domain'], st.session_state['concept_list'])
|
| 426 |
+
|
| 427 |
+
# Save the original stdout so terminal output can be restored if needed later;
# sys.stdout is replaced further down by a StreamToText instance.
original_stdout = sys.stdout
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
# Define StreamToText to capture and display logs in real-time within Streamlit only
|
| 432 |
+
class StreamToText:
    """File-like stdout replacement that mirrors writes into the Streamlit UI only."""

    def __init__(self):
        # Buffer attribute kept for file-like compatibility; messages are
        # rendered through Streamlit directly rather than read back from here.
        self.output = StringIO()

    def write(self, message):
        # Skip whitespace-only writes (e.g. the bare newline print() emits).
        if message.strip():  # Avoid adding empty messages
            # Only append to Streamlit display, not the terminal.
            # NOTE(review): relies on the module-level `log_placeholder` defined
            # later in this file; write() must not be called before it exists.
            st.session_state.log_messages.append(message.strip())
            log_placeholder.text("\n".join(st.session_state.log_messages))  # Flush updated logs

    def flush(self):
        pass  # Required for compatibility with sys.stdout
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
# Initialize session state for log messages
|
| 447 |
+
# Initialize session state for captured log messages.
if 'log_messages' not in st.session_state:
    st.session_state.log_messages = []

# Replace sys.stdout with our custom StreamToText instance so anything printed
# from here on appears in the log expander instead of the terminal.
stream_to_text = StreamToText()
sys.stdout = stream_to_text

# Placeholder for displaying logs within a collapsible expander; StreamToText
# writes into this placeholder.
with st.expander("Show Logs", expanded=False):
    log_placeholder = st.empty()  # Placeholder for dynamic log display
|
| 457 |
+
|
| 458 |
+
# Define the Create Benchmark button
|
| 459 |
+
# Define the Create Benchmark button.
# NOTE(review): this block reads st.session_state['check_source_finder'] and
# domain_configuration, both of which are only defined inside the
# "Additional Options" branch above — pressing this button before that branch
# has run would raise KeyError/NameError; verify the intended flow.
if st.button("Create a Benchmark"):
    st.session_state.log_messages = []  # Clear previous logs
    with st.spinner("Creating benchmark..."):
        if st.session_state['check_source_finder']:
            # Ensure a source-finder file exists (and is valid) per concept.
            check_and_create_source_files(st.session_state['domain'], st.session_state['concept_list'])

        try:
            # Cosmetic progress bar: 100 steps of 0.05s (~5s) that simulate
            # activity — it does not track the real pipeline's progress.
            progress_bar = st.progress(0)
            for i in tqdm(range(1, 101)):
                progress_bar.progress(i)
                time.sleep(0.05)  # Short delay to simulate processing time

            # Run the benchmark creation function.
            benchmark = Pipeline.domain_benchmark_building(st.session_state['domain'], domain_configuration)
            st.success("Benchmark creation completed!")
            st.dataframe(benchmark.data)

        except Exception as e:
            # Surface any pipeline failure in the UI rather than crashing the app.
            st.error(f"An error occurred during benchmark creation: {e}")
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ollama==0.3.3
|
| 2 |
+
openai==1.54.3
|
| 3 |
+
pandas==1.5.3
|
| 4 |
+
streamlit==1.40.1
|
| 5 |
+
tqdm==4.66.4
|
| 6 |
+
SAGEDbias==0.0.2
|