Spaces:
Sleeping
Sleeping
import json
import os
from urllib.parse import quote_plus

import numpy as np
import pandas as pd
import streamlit as st
from sentence_transformers import SentenceTransformer
from sqlalchemy import create_engine, text
from streamlit_agraph import agraph, Node, Edge, Config

from utils import get_all_diseases_name, get_most_similar_diseases_from_uri, get_uri_from_name, get_diseases_related_to_a_textual_description, get_similarities_among_diseases_uris
# --- Database connection settings -------------------------------------------
# Every setting can be overridden through an environment variable (the same
# mechanism already used for the hostname). The defaults are the demo values,
# so running without any of these variables set behaves as before.
username = os.getenv('IRIS_USERNAME', 'demo')
password = os.getenv('IRIS_PASSWORD', 'demo')
hostname = os.getenv('IRIS_HOSTNAME', 'localhost')
port = os.getenv('IRIS_PORT', '1972')
namespace = os.getenv('IRIS_NAMESPACE', 'USER')

# Credentials are percent-encoded so that characters such as '@', ':' or '/'
# in a user name or password cannot corrupt the URL structure.
CONNECTION_STRING = (
    f"iris://{quote_plus(username)}:{quote_plus(password)}"
    f"@{hostname}:{port}/{namespace}"
)

# Module-level SQLAlchemy engine built from the connection string above.
engine = create_engine(CONNECTION_STRING)
@st.cache_resource
def _load_encoder():
    """Return the "allenai-specter" SentenceTransformer, loading it only once.

    Streamlit re-executes the script on every interaction; caching the model
    as a resource avoids re-loading it each time the button is clicked.
    """
    return SentenceTransformer("allenai-specter")


def handle_click_on_analyze_button(user_text):
    """Run the analysis pipeline for the description entered by the user.

    Only steps 1-3 below are implemented so far; steps 4-9 are planned work.

    Args:
        user_text: Free-text disease description typed into the input box.

    Returns:
        None. The related diseases are currently only printed to stdout.
    """
    # Nothing to analyse: skip the model and the database round-trips.
    if not user_text or not user_text.strip():
        st.warning("Please enter a disease description first.")
        return

    # 1. Embed the textual description that the user entered using the model
    # 2. Get 5 diseases with the highest cosine similarity from the DB
    encoder = _load_encoder()
    diseases_related_to_the_user_text = get_diseases_related_to_a_textual_description(user_text, encoder)

    # 3. Get the similarities of the embeddings of those diseases (cosine similarity of the embeddings of the nodes of such diseases)
    diseases_uris = [disease['uri'] for disease in diseases_related_to_the_user_text]
    # NOTE: the returned similarities are not used yet (see steps 4-5).
    get_similarities_among_diseases_uris(diseases_uris)
    print(diseases_related_to_the_user_text)

    # 4. Potentially filter out the diseases that are not similar enough (e.g. similarity < 0.8)
    # 5. Augment the set of diseases: add new diseases that are similar to the ones that are already in the set, until we get 10-15 diseases
    # 6. Query the embeddings of the diseases related to each clinical trial (also in the DB), to get the most similar clinical trials to our set of diseases
    # 7. Use an LLM to get a summary of the clinical trials, in plain text format
    # 8. Use an LLM to extract numerical data from the clinical trials (e.g. number of patients, number of deaths, etc.). Get summary statistics out of that.
    # 9. Show the results to the user: graph of the diseases chosen, summary of the clinical trials, summary statistics of the clinical trials, and list of the details of the clinical trials considered
# --- Page header and input ---------------------------------------------------
st.write("# Klìnic")

description_input = st.text_input(
    label="Enter the disease description 👇",
    placeholder='A disease that causes memory loss and other cognitive impairments.',
)
if st.button("Analyze"):
    handle_click_on_analyze_button(description_input)
    # TODO: also when user clicks enter

# --- Graph placeholder (random scatter data until the real graph exists) -----
st.write(":red[Here should be the graph]")  # TODO remove
chart_data = pd.DataFrame(
    np.random.randn(20, 3), columns=["a", "b", "c"]
)  # TODO remove
st.scatter_chart(chart_data)  # TODO remove

# --- Disease overview --------------------------------------------------------
st.write("## Disease Overview")
disease_overview = ":red[lorem ipsum]"  # TODO
st.write(disease_overview)

# --- Clinical trials ---------------------------------------------------------
st.write("## Clinical Trials Details")

# TODO replace mock data
# Explicit UTF-8 so decoding does not depend on the platform's default
# encoding.
with open("mock_trial.json", encoding="utf-8") as f:
    d = json.load(f)
# The same mock trial is listed five times (all entries share one dict).
trials = [d] * 5

for trial in trials:
    protocol = trial["protocolSection"]
    identification = protocol["identificationModule"]

    # One collapsible section per trial, labelled with its NCT identifier.
    with st.expander(f"{identification['nctId']}"):
        official_title = identification["officialTitle"]
        st.write(f"##### {official_title}")

        brief_summary = protocol["descriptionModule"]["briefSummary"]
        st.write(brief_summary)

        status = protocol["statusModule"]
        status_module = {
            "Status": status["overallStatus"],
            "Status Date": status["statusVerifiedDate"],
        }
        st.write("###### Status")
        st.table(status_module)

        design = protocol["designModule"]
        design_module = {
            "Study Type": design["studyType"],
            # "Phases": design["phases"],  # breaks formatting because it is an array
            "Allocation": design["designInfo"]["allocation"],
            "Participants": design["enrollmentInfo"]["count"],
        }
        st.write("###### Design")
        st.table(design_module)

        # TODO more modules?