Spaces:
Sleeping
Sleeping
ACMCMC
committed on
Commit
·
1b1c01c
1
Parent(s):
3276764
WIP
Browse files
- MATLAB/main.m +1 -1
- MATLAB/visualize_app.mlapp +0 -0
- MATLAB/visualize_connectedNodes_continuous.m +1 -1
- app.py +55 -48
- llm_res.py +21 -21
- utils.py +3 -2
MATLAB/main.m
CHANGED
|
@@ -15,7 +15,7 @@ function display_elementsForKey(connectionsMap, key)
|
|
| 15 |
end
|
| 16 |
end
|
| 17 |
|
| 18 |
-
data = readtable('MGREL.RRF', Delimiter='|', FileType='text', NumHeaderLines=0, VariableNamingRule='preserve');
|
| 19 |
data = renamevars(data,"#CUI1","CUI1");
|
| 20 |
data = data(1:1000,:);
|
| 21 |
|
|
|
|
| 15 |
end
|
| 16 |
end
|
| 17 |
|
| 18 |
+
data = readtable('../MGREL.RRF', Delimiter='|', FileType='text', NumHeaderLines=0, VariableNamingRule='preserve');
|
| 19 |
data = renamevars(data,"#CUI1","CUI1");
|
| 20 |
data = data(1:1000,:);
|
| 21 |
|
MATLAB/visualize_app.mlapp
CHANGED
|
Binary files a/MATLAB/visualize_app.mlapp and b/MATLAB/visualize_app.mlapp differ
|
|
|
MATLAB/visualize_connectedNodes_continuous.m
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
function visualize_connectedNodes_continuous()
|
| 2 |
% Read the data and create the connections map
|
| 3 |
-
data = readtable('MGREL.RRF', 'Delimiter', '|', 'FileType', 'text', 'NumHeaderLines', 0, 'VariableNamingRule', 'preserve');
|
| 4 |
data = renamevars(data, '#CUI1', 'CUI1');
|
| 5 |
data = data(1:10000,:);
|
| 6 |
|
|
|
|
| 1 |
function visualize_connectedNodes_continuous()
|
| 2 |
% Read the data and create the connections map
|
| 3 |
+
data = readtable('../MGREL.RRF', 'Delimiter', '|', 'FileType', 'text', 'NumHeaderLines', 0, 'VariableNamingRule', 'preserve');
|
| 4 |
data = renamevars(data, '#CUI1', 'CUI1');
|
| 5 |
data = data(1:10000,:);
|
| 6 |
|
app.py
CHANGED
|
@@ -105,15 +105,17 @@ with st.container():
|
|
| 105 |
status.divider()
|
| 106 |
# 7. Use an LLM to get a summary of the clinical trials, in plain text format.
|
| 107 |
status.write("Getting a summary of the clinical trials...")
|
| 108 |
-
response
|
| 109 |
disease_overview = response
|
| 110 |
-
|
| 111 |
# 8. Use an LLM to extract numerical data from the clinical trials (e.g. number of patients, number of deaths, etc.). Get summary statistics out of that.
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
|
|
|
|
|
|
| 117 |
# 9. Show the results to the user: graph of the diseases chosen, summary of the clinical trials, summary statistics of the clinical trials, and list of the details of the clinical trials considered
|
| 118 |
status.update(label="Done!", state="complete")
|
| 119 |
status.balloons()
|
|
@@ -187,52 +189,57 @@ with st.container():
|
|
| 187 |
with tabs[i]:
|
| 188 |
render_trial_details(trials[i])
|
| 189 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
get_all_diseases_name(engine))
|
| 194 |
-
|
| 195 |
-
st.write("You selected:", chosen_disease_name)
|
| 196 |
-
chosen_disease_uri = get_uri_from_name(engine, chosen_disease_name)
|
| 197 |
-
|
| 198 |
-
nodes = []
|
| 199 |
-
edges = []
|
| 200 |
|
|
|
|
|
|
|
| 201 |
|
| 202 |
-
nodes.append( Node(id=chosen_disease_uri,
|
| 203 |
-
label=chosen_disease_name,
|
| 204 |
-
size=25,
|
| 205 |
-
shape="circular")
|
| 206 |
-
)
|
| 207 |
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
for uri, name, weight in similar_diseases:
|
| 211 |
-
nodes.append( Node(id=uri,
|
| 212 |
-
label=name,
|
| 213 |
size=25,
|
| 214 |
shape="circular")
|
| 215 |
)
|
| 216 |
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
)
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
status.divider()
|
| 106 |
# 7. Use an LLM to get a summary of the clinical trials, in plain text format.
|
| 107 |
status.write("Getting a summary of the clinical trials...")
|
| 108 |
+
response = get_short_summary_out_of_json_files(json_of_clinical_trials)
|
| 109 |
disease_overview = response
|
| 110 |
+
try:
|
| 111 |
# 8. Use an LLM to extract numerical data from the clinical trials (e.g. number of patients, number of deaths, etc.). Get summary statistics out of that.
|
| 112 |
+
status.write("Getting summary statistics of the clinical trials...")
|
| 113 |
+
response = tagging_insights_from_json(json_of_clinical_trials)
|
| 114 |
+
print(f'Response from LLM tagging: {response}')
|
| 115 |
+
status.write(f'Response from LLM tagging: {response}')
|
| 116 |
+
except Exception as e:
|
| 117 |
+
print(f'Error while extracting numerical data from the clinical trials: {e}')
|
| 118 |
+
status.warning(f'Error while extracting numerical data from the clinical trials. This information will not be shown.')
|
| 119 |
# 9. Show the results to the user: graph of the diseases chosen, summary of the clinical trials, summary statistics of the clinical trials, and list of the details of the clinical trials considered
|
| 120 |
status.update(label="Done!", state="complete")
|
| 121 |
status.balloons()
|
|
|
|
| 189 |
with tabs[i]:
|
| 190 |
render_trial_details(trials[i])
|
| 191 |
|
| 192 |
+
show_graph_of_all_diseases = False
|
| 193 |
+
if show_graph_of_all_diseases:
|
| 194 |
+
# If disease_names is not defined, define it
|
| 195 |
+
if "disease_names" not in st.session_state:
|
| 196 |
+
st.session_state.disease_names = get_all_diseases_name(engine)
|
| 197 |
+
chosen_disease_name = st.selectbox(
|
| 198 |
+
"Choose a disease",
|
| 199 |
+
st.session_state.disease_names,
|
| 200 |
+
)
|
| 201 |
|
| 202 |
+
st.write("You selected:", chosen_disease_name)
|
| 203 |
+
chosen_disease_uri = get_uri_from_name(engine, chosen_disease_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
|
| 205 |
+
nodes = []
|
| 206 |
+
edges = []
|
| 207 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
|
| 209 |
+
nodes.append( Node(id=chosen_disease_uri,
|
| 210 |
+
label=chosen_disease_name,
|
|
|
|
|
|
|
|
|
|
| 211 |
size=25,
|
| 212 |
shape="circular")
|
| 213 |
)
|
| 214 |
|
| 215 |
+
similar_diseases = get_most_similar_diseases_from_uri(engine, chosen_disease_uri, threshold=0.6)
|
| 216 |
+
print(similar_diseases)
|
| 217 |
+
for uri, name, weight in similar_diseases:
|
| 218 |
+
nodes.append( Node(id=uri,
|
| 219 |
+
label=name,
|
| 220 |
+
size=25,
|
| 221 |
+
shape="circular")
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
print(True if float(weight) > 0.7 else False)
|
| 225 |
+
edges.append( Edge(source=chosen_disease_uri,
|
| 226 |
+
target=uri,
|
| 227 |
+
color="red" if float(weight) > 0.7 else "blue",
|
| 228 |
+
weight=float(weight)**10,
|
| 229 |
+
type="CURVE_SMOOTH"
|
| 230 |
+
# type="STRAIGHT"
|
| 231 |
+
)
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
config = Config(width=750,
|
| 235 |
+
height=950,
|
| 236 |
+
directed=False,
|
| 237 |
+
physics=True,
|
| 238 |
+
hierarchical=False,
|
| 239 |
+
collapsible=False,
|
| 240 |
+
# **kwargs
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
return_value = agraph(nodes=nodes,
|
| 244 |
+
edges=edges,
|
| 245 |
+
config=config)
|
llm_res.py
CHANGED
|
@@ -301,7 +301,7 @@ def tagging_insights_from_json(data_json):
|
|
| 301 |
|
| 302 |
Extract the desired information from the following JSON data.
|
| 303 |
|
| 304 |
-
Only extract the properties mentioned in the 'Classification' function.
|
| 305 |
|
| 306 |
JSON data:
|
| 307 |
{input}
|
|
@@ -317,20 +317,20 @@ def tagging_insights_from_json(data_json):
|
|
| 317 |
# status: list = Field(
|
| 318 |
# description="Extract the status of all the clinical trials"
|
| 319 |
# )
|
| 320 |
-
keywords: list = Field(
|
| 321 |
-
|
| 322 |
-
)
|
| 323 |
# interventions: list = Field(
|
| 324 |
# description="describe the interventions for each clinical trial using title, name and description"
|
| 325 |
# )
|
| 326 |
-
primary_outcomes: list = Field(
|
| 327 |
-
|
| 328 |
-
)
|
| 329 |
-
secondary_outcomes: list= Field(description= "get the secondary outcomes of each clinical trial")
|
| 330 |
-
eligibility: list = Field(
|
| 331 |
-
|
| 332 |
-
)
|
| 333 |
-
healthy_volunteers: list= Field(description= "determine whether the clinical trial requires healthy volunteers")
|
| 334 |
minimum_age: list = Field(
|
| 335 |
description="get the minimum age from each experiment"
|
| 336 |
)
|
|
@@ -343,12 +343,12 @@ def tagging_insights_from_json(data_json):
|
|
| 343 |
return {
|
| 344 |
# "project_title": self.project_title,
|
| 345 |
# "status": self.status,
|
| 346 |
-
"keywords": self.keywords,
|
| 347 |
# "interventions": self.interventions,
|
| 348 |
-
"primary_outcomes": self.primary_outcomes,
|
| 349 |
-
"secondary_outcomes": self.secondary_outcomes,
|
| 350 |
# "eligibility": self.eligibility,
|
| 351 |
-
"healthy_volunteers": self.healthy_volunteers,
|
| 352 |
"minimum_age": self.minimum_age,
|
| 353 |
"maximum_age": self.maximum_age,
|
| 354 |
"gender": self.gender
|
|
@@ -370,13 +370,13 @@ def tagging_insights_from_json(data_json):
|
|
| 370 |
|
| 371 |
avg_min_age, avg_max_age, most_common_gender, common_keywords= analyze_data(result_dict)
|
| 372 |
|
| 373 |
-
stats_dict= {'Average Minimum age': avg_min_age,
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
|
| 378 |
print(f"Result_tagging: {result_dict}")
|
| 379 |
-
return result_dict
|
| 380 |
|
| 381 |
|
| 382 |
# clinical_record_info = get_clinical_records_by_ids(['NCT00841061', 'NCT03035123', 'NCT02272751', 'NCT03035123', 'NCT03055377'])
|
|
|
|
| 301 |
|
| 302 |
Extract the desired information from the following JSON data.
|
| 303 |
|
| 304 |
+
Only extract the properties mentioned in the 'Classification' function. Output a list of the extracted properties, starting with [ and ending with ].
|
| 305 |
|
| 306 |
JSON data:
|
| 307 |
{input}
|
|
|
|
| 317 |
# status: list = Field(
|
| 318 |
# description="Extract the status of all the clinical trials"
|
| 319 |
# )
|
| 320 |
+
#keywords: list = Field(
|
| 321 |
+
# description="Extract the most relevant keywords for each clinical trials"
|
| 322 |
+
#)
|
| 323 |
# interventions: list = Field(
|
| 324 |
# description="describe the interventions for each clinical trial using title, name and description"
|
| 325 |
# )
|
| 326 |
+
#primary_outcomes: list = Field(
|
| 327 |
+
# description="get the timeframe of each clinical trial"
|
| 328 |
+
#)
|
| 329 |
+
#secondary_outcomes: list= Field(description= "get the secondary outcomes of each clinical trial")
|
| 330 |
+
#eligibility: list = Field(
|
| 331 |
+
# description="get the timeframe of each clinical trial"
|
| 332 |
+
#)
|
| 333 |
+
# healthy_volunteers: list= Field(description= "determine whether the clinical trial requires healthy volunteers")
|
| 334 |
minimum_age: list = Field(
|
| 335 |
description="get the minimum age from each experiment"
|
| 336 |
)
|
|
|
|
| 343 |
return {
|
| 344 |
# "project_title": self.project_title,
|
| 345 |
# "status": self.status,
|
| 346 |
+
#"keywords": self.keywords,
|
| 347 |
# "interventions": self.interventions,
|
| 348 |
+
#"primary_outcomes": self.primary_outcomes,
|
| 349 |
+
#"secondary_outcomes": self.secondary_outcomes,
|
| 350 |
# "eligibility": self.eligibility,
|
| 351 |
+
# "healthy_volunteers": self.healthy_volunteers,
|
| 352 |
"minimum_age": self.minimum_age,
|
| 353 |
"maximum_age": self.maximum_age,
|
| 354 |
"gender": self.gender
|
|
|
|
| 370 |
|
| 371 |
avg_min_age, avg_max_age, most_common_gender, common_keywords= analyze_data(result_dict)
|
| 372 |
|
| 373 |
+
#stats_dict= {'Average Minimum age': avg_min_age,
|
| 374 |
+
# 'Average Maximum age': avg_max_age,
|
| 375 |
+
# 'Most common gender undergoing the trials': most_common_gender,
|
| 376 |
+
# 'common keywords found in the trials': common_keywords}
|
| 377 |
|
| 378 |
print(f"Result_tagging: {result_dict}")
|
| 379 |
+
return result_dict#, stats_dict
|
| 380 |
|
| 381 |
|
| 382 |
# clinical_record_info = get_clinical_records_by_ids(['NCT00841061', 'NCT03035123', 'NCT02272751', 'NCT03035123', 'NCT03055377'])
|
utils.py
CHANGED
|
@@ -18,15 +18,16 @@ engine = create_engine(CONNECTION_STRING)
|
|
| 18 |
|
| 19 |
|
| 20 |
def get_all_diseases_name(engine) -> List[List[str]]:
|
|
|
|
| 21 |
with engine.connect() as conn:
|
| 22 |
with conn.begin():
|
| 23 |
sql = f"""
|
| 24 |
-
SELECT
|
| 25 |
"""
|
| 26 |
result = conn.execute(text(sql))
|
| 27 |
data = result.fetchall()
|
| 28 |
|
| 29 |
-
all_diseases = [row[
|
| 30 |
return all_diseases
|
| 31 |
|
| 32 |
|
|
|
|
| 18 |
|
| 19 |
|
| 20 |
def get_all_diseases_name(engine) -> List[List[str]]:
|
| 21 |
+
print("Fetching all disease names...")
|
| 22 |
with engine.connect() as conn:
|
| 23 |
with conn.begin():
|
| 24 |
sql = f"""
|
| 25 |
+
SELECT label FROM Test.EntityEmbeddings
|
| 26 |
"""
|
| 27 |
result = conn.execute(text(sql))
|
| 28 |
data = result.fetchall()
|
| 29 |
|
| 30 |
+
all_diseases = [row[0] for row in data if row[0] != "nan"]
|
| 31 |
return all_diseases
|
| 32 |
|
| 33 |
|