Spaces:
Sleeping
Sleeping
ACMCMC
commited on
Commit
·
4bb7c94
1
Parent(s):
1fb895a
WIP
Browse files- app.py +44 -10
- llm_res.py +49 -36
- utils.py +1 -1
app.py
CHANGED
|
@@ -28,6 +28,7 @@ show_graph = False
|
|
| 28 |
show_analyze_status = False
|
| 29 |
show_overview = False
|
| 30 |
show_details = False
|
|
|
|
| 31 |
|
| 32 |
# IRIS connection
|
| 33 |
username = "demo"
|
|
@@ -66,7 +67,7 @@ with st.container():
|
|
| 66 |
diseases_related_to_the_user_text = get_diseases_related_to_a_textual_description(
|
| 67 |
description_input, encoder
|
| 68 |
)
|
| 69 |
-
status.info(f'
|
| 70 |
status.json(diseases_related_to_the_user_text, expanded=False)
|
| 71 |
status.divider()
|
| 72 |
# 3. Get the similarities of the embeddings of those diseases (cosine similarity of the embeddings of the nodes of such diseases)
|
|
@@ -94,26 +95,37 @@ with st.container():
|
|
| 94 |
clinical_trials_related_to_the_diseases = get_clinical_trials_related_to_diseases(
|
| 95 |
augmented_set_of_diseases, encoder
|
| 96 |
)
|
| 97 |
-
status.info(f'
|
| 98 |
status.json(clinical_trials_related_to_the_diseases, expanded=False)
|
| 99 |
status.divider()
|
| 100 |
status.write("Getting the details of the clinical trials...")
|
| 101 |
json_of_clinical_trials = get_clinical_records_by_ids(
|
| 102 |
[trial["nct_id"] for trial in clinical_trials_related_to_the_diseases]
|
| 103 |
)
|
|
|
|
| 104 |
status.json(json_of_clinical_trials, expanded=False)
|
| 105 |
status.divider()
|
| 106 |
# 7. Use an LLM to get a summary of the clinical trials, in plain text format.
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
try:
|
| 111 |
# 8. Use an LLM to extract numerical data from the clinical trials (e.g. number of patients, number of deaths, etc.). Get summary statistics out of that.
|
| 112 |
status.write("Getting summary statistics of the clinical trials...")
|
| 113 |
response = tagging_insights_from_json(json_of_clinical_trials)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
print(f'Response from LLM tagging: {response}')
|
| 115 |
-
status.
|
| 116 |
except Exception as e:
|
|
|
|
| 117 |
print(f'Error while extracting numerical data from the clinical trials: {e}')
|
| 118 |
status.warning(f'Error while extracting numerical data from the clinical trials. This information will not be shown.')
|
| 119 |
# 9. Show the results to the user: graph of the diseases chosen, summary of the clinical trials, summary statistics of the clinical trials, and list of the details of the clinical trials considered
|
|
@@ -170,10 +182,32 @@ $$"""
|
|
| 170 |
# overview
|
| 171 |
with st.container():
|
| 172 |
if show_overview:
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
|
| 178 |
|
| 179 |
# details
|
|
|
|
| 28 |
show_analyze_status = False
|
| 29 |
show_overview = False
|
| 30 |
show_details = False
|
| 31 |
+
show_metrics = False
|
| 32 |
|
| 33 |
# IRIS connection
|
| 34 |
username = "demo"
|
|
|
|
| 67 |
diseases_related_to_the_user_text = get_diseases_related_to_a_textual_description(
|
| 68 |
description_input, encoder
|
| 69 |
)
|
| 70 |
+
status.info(f'Selected {len(diseases_related_to_the_user_text)} diseases related to the description you entered.')
|
| 71 |
status.json(diseases_related_to_the_user_text, expanded=False)
|
| 72 |
status.divider()
|
| 73 |
# 3. Get the similarities of the embeddings of those diseases (cosine similarity of the embeddings of the nodes of such diseases)
|
|
|
|
| 95 |
clinical_trials_related_to_the_diseases = get_clinical_trials_related_to_diseases(
|
| 96 |
augmented_set_of_diseases, encoder
|
| 97 |
)
|
| 98 |
+
status.info(f'Selected {len(clinical_trials_related_to_the_diseases)} clinical trials related to the diseases.')
|
| 99 |
status.json(clinical_trials_related_to_the_diseases, expanded=False)
|
| 100 |
status.divider()
|
| 101 |
status.write("Getting the details of the clinical trials...")
|
| 102 |
json_of_clinical_trials = get_clinical_records_by_ids(
|
| 103 |
[trial["nct_id"] for trial in clinical_trials_related_to_the_diseases]
|
| 104 |
)
|
| 105 |
+
status.success(f'Details of the clinical trials obtained.')
|
| 106 |
status.json(json_of_clinical_trials, expanded=False)
|
| 107 |
status.divider()
|
| 108 |
# 7. Use an LLM to get a summary of the clinical trials, in plain text format.
|
| 109 |
+
try:
|
| 110 |
+
status.write("Getting a summary of the clinical trials...")
|
| 111 |
+
response = get_short_summary_out_of_json_files(json_of_clinical_trials)
|
| 112 |
+
status.success("Summary of the clinical trials obtained.")
|
| 113 |
+
disease_overview = response
|
| 114 |
+
except Exception as e:
|
| 115 |
+
print(f'Error while getting a summary of the clinical trials: {e}')
|
| 116 |
+
status.warning(f'Error while getting a summary of the clinical trials. This information will not be shown.')
|
| 117 |
try:
|
| 118 |
# 8. Use an LLM to extract numerical data from the clinical trials (e.g. number of patients, number of deaths, etc.). Get summary statistics out of that.
|
| 119 |
status.write("Getting summary statistics of the clinical trials...")
|
| 120 |
response = tagging_insights_from_json(json_of_clinical_trials)
|
| 121 |
+
average_minimum_age = response["avg_min_age"]
|
| 122 |
+
average_maximum_age = response["avg_max_age"]
|
| 123 |
+
most_common_gender = response['most_common_gender']
|
| 124 |
+
|
| 125 |
print(f'Response from LLM tagging: {response}')
|
| 126 |
+
status.success(f'Summary statistics of the clinical trials obtained.')
|
| 127 |
except Exception as e:
|
| 128 |
+
raise e
|
| 129 |
print(f'Error while extracting numerical data from the clinical trials: {e}')
|
| 130 |
status.warning(f'Error while extracting numerical data from the clinical trials. This information will not be shown.')
|
| 131 |
# 9. Show the results to the user: graph of the diseases chosen, summary of the clinical trials, summary statistics of the clinical trials, and list of the details of the clinical trials considered
|
|
|
|
| 182 |
# overview
|
| 183 |
with st.container():
|
| 184 |
if show_overview:
|
| 185 |
+
try:
|
| 186 |
+
st.write("## Overview of Related Clinical Trials")
|
| 187 |
+
st.write(disease_overview)
|
| 188 |
+
time.sleep(1)
|
| 189 |
+
except Exception as e:
|
| 190 |
+
print(f'Error while showing the overview of the clinical trials: {e}')
|
| 191 |
+
finally:
|
| 192 |
+
show_metrics = True
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
with st.container():
|
| 196 |
+
if show_metrics:
|
| 197 |
+
try:
|
| 198 |
+
st.write("## Metrics of the Clinical Trials")
|
| 199 |
+
col1, col2, col3 = st.columns(3)
|
| 200 |
+
with col1:
|
| 201 |
+
st.metric("Average Minimum Age", average_minimum_age)
|
| 202 |
+
with col2:
|
| 203 |
+
st.metric("Average Maximum Age", average_maximum_age)
|
| 204 |
+
with col3:
|
| 205 |
+
st.metric("Most Common Gender", most_common_gender)
|
| 206 |
+
time.sleep(2)
|
| 207 |
+
except Exception as e:
|
| 208 |
+
print(f'Error while showing the metrics: {e}')
|
| 209 |
+
finally:
|
| 210 |
+
show_details = True
|
| 211 |
|
| 212 |
|
| 213 |
# details
|
llm_res.py
CHANGED
|
@@ -24,6 +24,7 @@ from langchain.chains.llm import LLMChain
|
|
| 24 |
from langchain_core.prompts import PromptTemplate
|
| 25 |
from collections import Counter
|
| 26 |
import statistics
|
|
|
|
| 27 |
|
| 28 |
load_dotenv()
|
| 29 |
|
|
@@ -134,11 +135,12 @@ def get_clinical_records_by_ids(clinical_record_ids: List[str]) -> List[Dict[str
|
|
| 134 |
# "eligibility": eligibility,
|
| 135 |
# }
|
| 136 |
# filtered_data.append(filtered_item)
|
| 137 |
-
|
| 138 |
# return filtered_data
|
| 139 |
# # for ele in filtered_data:
|
| 140 |
# # print(ele)
|
| 141 |
|
|
|
|
| 142 |
def process_dictionaty_with_llm_to_generate_response(json_data):
|
| 143 |
# processed_data = process_json_data_for_llm(json_data)
|
| 144 |
# res = tagging_chain.invoke({"input": processed_data})
|
|
@@ -217,9 +219,10 @@ def process_dictionaty_with_llm_to_generate_response(json_data):
|
|
| 217 |
"eligibility": eligibility,
|
| 218 |
}
|
| 219 |
filtered_data.append(filtered_item)
|
| 220 |
-
|
| 221 |
return filtered_data
|
| 222 |
|
|
|
|
| 223 |
def get_short_summary_out_of_json_files(data_json):
|
| 224 |
prompt_template = """You are an expert on clinicial trials and their analysis of their reports.
|
| 225 |
|
|
@@ -272,29 +275,36 @@ General summary:"""
|
|
| 272 |
|
| 273 |
return result
|
| 274 |
|
|
|
|
| 275 |
def analyze_data(data):
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
|
|
|
| 279 |
# primary_timeframe= [int(age.split()[0]) for age in data['[primary_outcome]'] if age]
|
| 280 |
-
|
| 281 |
# Calculate average minimum and maximum ages
|
| 282 |
avg_min_age = statistics.mean(min_ages) if min_ages else None
|
| 283 |
avg_max_age = statistics.mean(max_ages) if max_ages else None
|
| 284 |
-
|
| 285 |
# Find most common gender
|
| 286 |
-
gender_counter = Counter(data[
|
| 287 |
most_common_gender = gender_counter.most_common(1)[0][0]
|
| 288 |
-
|
| 289 |
# Flatten keywords list and find common keywords
|
| 290 |
-
keywords = [keyword for sublist in data[
|
| 291 |
-
common_keywords = [word for word, count in Counter(keywords).most_common()]
|
| 292 |
-
|
| 293 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
|
| 295 |
def tagging_insights_from_json(data_json):
|
| 296 |
-
processed_json= process_dictionaty_with_llm_to_generate_response(data_json)
|
| 297 |
-
|
| 298 |
tagging_prompt = ChatPromptTemplate.from_template(
|
| 299 |
"""
|
| 300 |
You are an expert on clinicial trials and analysis of their reports.
|
|
@@ -307,6 +317,7 @@ def tagging_insights_from_json(data_json):
|
|
| 307 |
{input}
|
| 308 |
"""
|
| 309 |
)
|
|
|
|
| 310 |
class Classification(BaseModel):
|
| 311 |
# description: str = Field(
|
| 312 |
# description="text description grouping all the clinical trials using briefDescription and detailedDescription keys"
|
|
@@ -317,25 +328,25 @@ def tagging_insights_from_json(data_json):
|
|
| 317 |
# status: list = Field(
|
| 318 |
# description="Extract the status of all the clinical trials"
|
| 319 |
# )
|
| 320 |
-
#keywords: list = Field(
|
| 321 |
# description="Extract the most relevant keywords for each clinical trials"
|
| 322 |
-
#)
|
| 323 |
# interventions: list = Field(
|
| 324 |
# description="describe the interventions for each clinical trial using title, name and description"
|
| 325 |
# )
|
| 326 |
-
#primary_outcomes: list = Field(
|
| 327 |
# description="get the timeframe of each clinical trial"
|
| 328 |
-
#)
|
| 329 |
-
#secondary_outcomes: list= Field(description= "get the secondary outcomes of each clinical trial")
|
| 330 |
-
#eligibility: list = Field(
|
| 331 |
# description="get the timeframe of each clinical trial"
|
| 332 |
-
#)
|
| 333 |
# healthy_volunteers: list= Field(description= "determine whether the clinical trial requires healthy volunteers")
|
| 334 |
minimum_age: list = Field(
|
| 335 |
-
|
| 336 |
)
|
| 337 |
maximum_age: list = Field(
|
| 338 |
-
|
| 339 |
)
|
| 340 |
gender: list = Field(description="get the gender from each experiment")
|
| 341 |
|
|
@@ -343,15 +354,15 @@ def tagging_insights_from_json(data_json):
|
|
| 343 |
return {
|
| 344 |
# "project_title": self.project_title,
|
| 345 |
# "status": self.status,
|
| 346 |
-
#"keywords": self.keywords,
|
| 347 |
# "interventions": self.interventions,
|
| 348 |
-
#"primary_outcomes": self.primary_outcomes,
|
| 349 |
-
#"secondary_outcomes": self.secondary_outcomes,
|
| 350 |
# "eligibility": self.eligibility,
|
| 351 |
# "healthy_volunteers": self.healthy_volunteers,
|
| 352 |
"minimum_age": self.minimum_age,
|
| 353 |
"maximum_age": self.maximum_age,
|
| 354 |
-
"gender": self.gender
|
| 355 |
}
|
| 356 |
|
| 357 |
# LLM
|
|
@@ -365,18 +376,20 @@ def tagging_insights_from_json(data_json):
|
|
| 365 |
|
| 366 |
tagging_chain = tagging_prompt | llm
|
| 367 |
|
| 368 |
-
res= tagging_chain.invoke({"input": processed_json})
|
| 369 |
-
|
| 370 |
|
| 371 |
-
|
|
|
|
|
|
|
| 372 |
|
| 373 |
-
#stats_dict= {'Average Minimum age': avg_min_age,
|
| 374 |
# 'Average Maximum age': avg_max_age,
|
| 375 |
# 'Most common gender undergoing the trials': most_common_gender,
|
| 376 |
# 'common keywords found in the trials': common_keywords}
|
| 377 |
-
|
| 378 |
-
print(f"Result_tagging: {
|
| 379 |
-
return
|
| 380 |
|
| 381 |
|
| 382 |
# clinical_record_info = get_clinical_records_by_ids(['NCT00841061', 'NCT03035123', 'NCT02272751', 'NCT03035123', 'NCT03055377'])
|
|
@@ -386,4 +399,4 @@ def tagging_insights_from_json(data_json):
|
|
| 386 |
# json.dump(clinical_record_info, f, indent=4)
|
| 387 |
|
| 388 |
|
| 389 |
-
# tagging_chain = tagging_insights_from_json(json_data)
|
|
|
|
| 24 |
from langchain_core.prompts import PromptTemplate
|
| 25 |
from collections import Counter
|
| 26 |
import statistics
|
| 27 |
+
import regex as re
|
| 28 |
|
| 29 |
load_dotenv()
|
| 30 |
|
|
|
|
| 135 |
# "eligibility": eligibility,
|
| 136 |
# }
|
| 137 |
# filtered_data.append(filtered_item)
|
| 138 |
+
|
| 139 |
# return filtered_data
|
| 140 |
# # for ele in filtered_data:
|
| 141 |
# # print(ele)
|
| 142 |
|
| 143 |
+
|
| 144 |
def process_dictionaty_with_llm_to_generate_response(json_data):
|
| 145 |
# processed_data = process_json_data_for_llm(json_data)
|
| 146 |
# res = tagging_chain.invoke({"input": processed_data})
|
|
|
|
| 219 |
"eligibility": eligibility,
|
| 220 |
}
|
| 221 |
filtered_data.append(filtered_item)
|
| 222 |
+
|
| 223 |
return filtered_data
|
| 224 |
|
| 225 |
+
|
| 226 |
def get_short_summary_out_of_json_files(data_json):
|
| 227 |
prompt_template = """You are an expert on clinicial trials and their analysis of their reports.
|
| 228 |
|
|
|
|
| 275 |
|
| 276 |
return result
|
| 277 |
|
| 278 |
+
|
| 279 |
def analyze_data(data):
|
| 280 |
+
print(f"Data: {data}")
|
| 281 |
+
# Extract minimum and maximum ages: Turn ['18 Years', '20 Years'] into [18, 20]
|
| 282 |
+
min_ages = [int(re.search(r"\d+", age).group()) for age in data["minimum_age"] if age]
|
| 283 |
+
max_ages = [int(re.search(r"\d+", age).group()) for age in data["maximum_age"] if age]
|
| 284 |
# primary_timeframe= [int(age.split()[0]) for age in data['[primary_outcome]'] if age]
|
| 285 |
+
|
| 286 |
# Calculate average minimum and maximum ages
|
| 287 |
avg_min_age = statistics.mean(min_ages) if min_ages else None
|
| 288 |
avg_max_age = statistics.mean(max_ages) if max_ages else None
|
| 289 |
+
|
| 290 |
# Find most common gender
|
| 291 |
+
gender_counter = Counter(data["gender"])
|
| 292 |
most_common_gender = gender_counter.most_common(1)[0][0]
|
| 293 |
+
|
| 294 |
# Flatten keywords list and find common keywords
|
| 295 |
+
#keywords = [keyword for sublist in data["keywords"] for keyword in sublist]
|
| 296 |
+
#common_keywords = [word for word, count in Counter(keywords).most_common()]
|
| 297 |
+
|
| 298 |
+
return {
|
| 299 |
+
"avg_min_age": avg_min_age,
|
| 300 |
+
"avg_max_age": avg_max_age,
|
| 301 |
+
"most_common_gender": most_common_gender
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
|
| 305 |
def tagging_insights_from_json(data_json):
|
| 306 |
+
processed_json = process_dictionaty_with_llm_to_generate_response(data_json)
|
| 307 |
+
|
| 308 |
tagging_prompt = ChatPromptTemplate.from_template(
|
| 309 |
"""
|
| 310 |
You are an expert on clinicial trials and analysis of their reports.
|
|
|
|
| 317 |
{input}
|
| 318 |
"""
|
| 319 |
)
|
| 320 |
+
|
| 321 |
class Classification(BaseModel):
|
| 322 |
# description: str = Field(
|
| 323 |
# description="text description grouping all the clinical trials using briefDescription and detailedDescription keys"
|
|
|
|
| 328 |
# status: list = Field(
|
| 329 |
# description="Extract the status of all the clinical trials"
|
| 330 |
# )
|
| 331 |
+
# keywords: list = Field(
|
| 332 |
# description="Extract the most relevant keywords for each clinical trials"
|
| 333 |
+
# )
|
| 334 |
# interventions: list = Field(
|
| 335 |
# description="describe the interventions for each clinical trial using title, name and description"
|
| 336 |
# )
|
| 337 |
+
# primary_outcomes: list = Field(
|
| 338 |
# description="get the timeframe of each clinical trial"
|
| 339 |
+
# )
|
| 340 |
+
# secondary_outcomes: list= Field(description= "get the secondary outcomes of each clinical trial")
|
| 341 |
+
# eligibility: list = Field(
|
| 342 |
# description="get the timeframe of each clinical trial"
|
| 343 |
+
# )
|
| 344 |
# healthy_volunteers: list= Field(description= "determine whether the clinical trial requires healthy volunteers")
|
| 345 |
minimum_age: list = Field(
|
| 346 |
+
description="get the minimum age from each experiment"
|
| 347 |
)
|
| 348 |
maximum_age: list = Field(
|
| 349 |
+
description="get the maximum age from each experiment"
|
| 350 |
)
|
| 351 |
gender: list = Field(description="get the gender from each experiment")
|
| 352 |
|
|
|
|
| 354 |
return {
|
| 355 |
# "project_title": self.project_title,
|
| 356 |
# "status": self.status,
|
| 357 |
+
# "keywords": self.keywords,
|
| 358 |
# "interventions": self.interventions,
|
| 359 |
+
# "primary_outcomes": self.primary_outcomes,
|
| 360 |
+
# "secondary_outcomes": self.secondary_outcomes,
|
| 361 |
# "eligibility": self.eligibility,
|
| 362 |
# "healthy_volunteers": self.healthy_volunteers,
|
| 363 |
"minimum_age": self.minimum_age,
|
| 364 |
"maximum_age": self.maximum_age,
|
| 365 |
+
"gender": self.gender,
|
| 366 |
}
|
| 367 |
|
| 368 |
# LLM
|
|
|
|
| 376 |
|
| 377 |
tagging_chain = tagging_prompt | llm
|
| 378 |
|
| 379 |
+
res = tagging_chain.invoke({"input": processed_json})
|
| 380 |
+
unprocessed_results_dict = res.get_dict()
|
| 381 |
|
| 382 |
+
results_dict = analyze_data(
|
| 383 |
+
unprocessed_results_dict
|
| 384 |
+
)
|
| 385 |
|
| 386 |
+
# stats_dict= {'Average Minimum age': avg_min_age,
|
| 387 |
# 'Average Maximum age': avg_max_age,
|
| 388 |
# 'Most common gender undergoing the trials': most_common_gender,
|
| 389 |
# 'common keywords found in the trials': common_keywords}
|
| 390 |
+
|
| 391 |
+
print(f"Result_tagging: {results_dict}")
|
| 392 |
+
return results_dict
|
| 393 |
|
| 394 |
|
| 395 |
# clinical_record_info = get_clinical_records_by_ids(['NCT00841061', 'NCT03035123', 'NCT02272751', 'NCT03035123', 'NCT03055377'])
|
|
|
|
| 399 |
# json.dump(clinical_record_info, f, indent=4)
|
| 400 |
|
| 401 |
|
| 402 |
+
# tagging_chain = tagging_insights_from_json(json_data)
|
utils.py
CHANGED
|
@@ -189,7 +189,7 @@ def get_clinical_trials_related_to_diseases(
|
|
| 189 |
with engine.connect() as conn:
|
| 190 |
with conn.begin():
|
| 191 |
sql = f"""
|
| 192 |
-
SELECT TOP
|
| 193 |
FROM Test.ClinicalTrials d
|
| 194 |
ORDER BY distance DESC
|
| 195 |
"""
|
|
|
|
| 189 |
with engine.connect() as conn:
|
| 190 |
with conn.begin():
|
| 191 |
sql = f"""
|
| 192 |
+
SELECT TOP 15 d.nct_id, VECTOR_COSINE(d.embedding, TO_VECTOR('{string_representation}', DOUBLE)) AS distance
|
| 193 |
FROM Test.ClinicalTrials d
|
| 194 |
ORDER BY distance DESC
|
| 195 |
"""
|