# %%
import os
from typing import Any, Dict, List

import pandas as pd
import requests
import streamlit as st
from sentence_transformers import SentenceTransformer
from sqlalchemy import create_engine, text

username = "demo"
password = "demo"
hostname = os.getenv("IRIS_HOSTNAME", "localhost")
port = "1972"
namespace = "USER"
CONNECTION_STRING = f"iris://{username}:{password}@{hostname}:{port}/{namespace}"
engine = create_engine(CONNECTION_STRING)
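
# Assumption: the iris:// URL is resolved by an InterSystems IRIS SQLAlchemy dialect
# (e.g. the sqlalchemy-iris package); the demo credentials above can be overridden via
# the IRIS_HOSTNAME environment variable.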

def get_all_diseases_name(engine) -> List[str]:
    print("Fetching all disease names...")
    with engine.connect() as conn:
        with conn.begin():
            sql = """
            SELECT label FROM Test.EntityEmbeddings
            """
            result = conn.execute(text(sql))
            data = result.fetchall()
    all_diseases = [row[0] for row in data if row[0] != "nan"]
    return all_diseases

def get_uri_from_name(engine, name: str) -> str:
    with engine.connect() as conn:
        with conn.begin():
            sql = f"""
            SELECT uri FROM Test.EntityEmbeddings
            WHERE label = '{name}'
            """
            result = conn.execute(text(sql))
            data = result.fetchall()
    # URIs look like http://identifiers.org/medgen/<id>; return only the trailing id.
    return data[0][0].split("/")[-1]

def get_most_similar_diseases_from_uri(
    engine, original_disease_uri: str, threshold: float = 0.8
) -> List[tuple[str, str, float]]:
    with engine.connect() as conn:
        with conn.begin():
            sql = f"""
            SELECT TOP 10 e1.uri AS uri1, e2.uri AS uri2, e1.label AS label1, e2.label AS label2,
                VECTOR_COSINE(e1.embedding, e2.embedding) AS distance
            FROM Test.EntityEmbeddings e1, Test.EntityEmbeddings e2
            WHERE e1.uri = 'http://identifiers.org/medgen/{original_disease_uri}'
                AND VECTOR_COSINE(e1.embedding, e2.embedding) > {threshold}
                AND e1.uri != e2.uri
            ORDER BY distance DESC
            """
            result = conn.execute(text(sql))
            data = result.fetchall()
    similar_diseases = [
        (row[1].split("/")[-1], row[3], row[4]) for row in data if row[3] != "nan"
    ]
    return similar_diseases
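
# VECTOR_COSINE is used here as a similarity (higher = closer), so ORDER BY distance DESC
# returns the nearest neighbours first; each result is a (medgen_id, label, similarity) tuple.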

def get_clinical_record_info(clinical_record_id: str) -> Dict[str, Any]:
    # Equivalent request:
    # curl -X GET "https://clinicaltrials.gov/api/v2/studies/NCT00841061" \
    #      -H "accept: application/json"
    request_url = f"https://clinicaltrials.gov/api/v2/studies/{clinical_record_id}"
    response = requests.get(request_url, headers={"accept": "application/json"})
    return response.json()

def get_clinical_records_by_ids(clinical_record_ids: List[str]) -> List[Dict[str, Any]]:
    clinical_records = []
    for clinical_record_id in clinical_record_ids:
        clinical_record_info = get_clinical_record_info(clinical_record_id)
        clinical_records.append(clinical_record_info)
    return clinical_records
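
# Each id triggers one GET against the ClinicalTrials.gov v2 "studies" endpoint; the parsed
# JSON study records are returned as-is and can be rendered by render_trial_details below.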

def get_similarities_among_diseases_uris(
    uri_list: List[str],
) -> List[Dict[str, Any]]:
    joined_uris = ", ".join([f"'{uri}'" for uri in uri_list])
    with engine.connect() as conn:
        with conn.begin():
            sql = f"""
            SELECT e1.uri AS uri1, e2.uri AS uri2, VECTOR_COSINE(e1.embedding, e2.embedding) AS distance
            FROM Test.EntityEmbeddings e1, Test.EntityEmbeddings e2
            WHERE e1.uri IN ({joined_uris}) AND e2.uri IN ({joined_uris}) AND e1.uri != e2.uri
            """
            result = conn.execute(text(sql))
            data = result.fetchall()
    return [
        {
            "uri1": row[0].split("/")[-1],
            "uri2": row[1].split("/")[-1],
            "distance": float(row[2]),
        }
        for row in data
    ]
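
# Pairs come back as dicts keyed by "uri1", "uri2" (trailing MedGen ids) and "distance"
# (the cosine similarity); get_similarities_df below pivots them into a square matrix.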

def augment_the_set_of_diseaces(diseases: List[str]) -> List[str]:
    augmented_diseases = diseases.copy()
    for _ in range(10 - len(augmented_diseases)):
        with engine.connect() as conn:
            with conn.begin():
                sql = f"""
                SELECT TOP 1 e2.uri AS new_disease, (SUM(VECTOR_COSINE(e1.embedding, e2.embedding)) / {len(augmented_diseases)}) AS score
                FROM Test.EntityEmbeddings e1, Test.EntityEmbeddings e2
                WHERE e1.uri IN ({','.join([f"'{disease}'" for disease in augmented_diseases])})
                    AND e2.uri NOT IN ({','.join([f"'{disease}'" for disease in augmented_diseases])})
                    AND e2.label != 'nan'
                GROUP BY e2.label
                ORDER BY score DESC
                """
                result = conn.execute(text(sql))
                data = result.fetchall()
        augmented_diseases.append(data[0][0])
    return augmented_diseases
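
# Greedy augmentation: each iteration adds the candidate disease with the highest average
# cosine similarity to the current set, stopping once the set contains 10 diseases.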

def get_embedding(string: str, encoder) -> List[float]:
    # Embed the string using sentence-transformers
    vector = encoder.encode(string, show_progress_bar=False)
    return vector

def get_diseases_related_to_a_textual_description(
    description: str, encoder
) -> List[Dict[str, Any]]:
    # Embed the description using sentence-transformers
    description_embedding = get_embedding(description, encoder)
    string_representation = str(description_embedding.tolist())[1:-1]
    with engine.connect() as conn:
        with conn.begin():
            sql = f"""
            SELECT TOP 10 d.uri, VECTOR_COSINE(d.embedding, TO_VECTOR('{string_representation}', DOUBLE)) AS distance
            FROM Test.DiseaseDescriptions d
            ORDER BY distance DESC
            """
            result = conn.execute(text(sql))
            data = result.fetchall()
    return [
        {"uri": row[0], "distance": float(row[1])}
        for row in data
        if float(row[1]) > 0.8
    ]
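
# The free-text description is embedded client-side with sentence-transformers and compared
# in-database via TO_VECTOR + VECTOR_COSINE against Test.DiseaseDescriptions; only matches
# with a similarity above 0.8 are kept.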

def get_clinical_trials_related_to_diseases(
    diseases: List[str], encoder
) -> List[Dict[str, Any]]:
    # Embed the diseases using sentence-transformers
    diseases_string = ", ".join(diseases)
    disease_embedding = get_embedding(diseases_string, encoder)
    string_representation = str(disease_embedding.tolist())[1:-1]
    with engine.connect() as conn:
        with conn.begin():
            sql = f"""
            SELECT TOP 20 d.nct_id, VECTOR_COSINE(d.embedding, TO_VECTOR('{string_representation}', DOUBLE)) AS distance
            FROM Test.ClinicalTrials d
            ORDER BY distance DESC
            """
            result = conn.execute(text(sql))
            data = result.fetchall()
    return [{"nct_id": row[0], "distance": row[1]} for row in data]

def get_similarities_df(diseases: List[Dict[str, Any]]) -> pd.DataFrame:
    # Build a square disease-by-disease similarity matrix from the pairwise results
    # produced by get_similarities_among_diseases_uris.
    df_diseases_similarities = pd.DataFrame(diseases)
    # Pivot so uri1 becomes the index and uri2 the columns; the values are the similarities.
    df_diseases_similarities = df_diseases_similarities.pivot(
        index="uri1", columns="uri2", values="distance"
    )
    # The pivot leaves NaN on the diagonal (no self-pairs are returned); fill it with 1.0.
    df_diseases_similarities = df_diseases_similarities.fillna(1.0)
    return df_diseases_similarities

def filter_out_less_promising_diseases(
    info_dicts: List[Dict[str, Any]],
) -> tuple[List[str], pd.DataFrame]:
    df_diseases_similarities = get_similarities_df(info_dicts)
    # Keep the diseases whose mean similarity to the rest of the set is no more than
    # 0.2 standard deviations below the overall mean; drop the rest.
    mean = df_diseases_similarities.mean().mean()
    std = df_diseases_similarities.mean().std()
    filtered_diseases = df_diseases_similarities.mean()[
        df_diseases_similarities.mean() > mean - 0.2 * std
    ].index.tolist()
    return (
        [f"http://identifiers.org/medgen/{d}" for d in filtered_diseases],
        df_diseases_similarities,
    )

def get_labels_of_diseases_from_uris(uris: List[str]) -> List[str]:
    with engine.connect() as conn:
        with conn.begin():
            joined_uris = ", ".join([f"'{uri}'" for uri in uris])
            sql = f"""
            SELECT label FROM Test.EntityEmbeddings
            WHERE uri IN ({joined_uris})
            """
            print(text(sql))
            result = conn.execute(text(sql))
            data = result.fetchall()
    return [row[0] for row in data]

def to_capitalized_case(string: str) -> str:
    string = string.replace("_", " ")
    if string.isupper():
        return string[0] + string[1:].lower()
    # Mixed-case values are returned unchanged.
    return string

def list_to_capitalized_case(strings: List[str]) -> str:
    strings = [to_capitalized_case(s) for s in strings]
    return ", ".join(strings)

def render_trial_details(trial: dict) -> None:
    # TODO: handle key errors for all cases (→ do not render)
    official_title = trial["protocolSection"]["identificationModule"]["officialTitle"]
    st.write(f"##### {official_title}")

    try:
        st.write(trial["protocolSection"]["descriptionModule"]["briefSummary"])
    except KeyError:
        try:
            st.write(
                trial["protocolSection"]["descriptionModule"]["detailedDescription"]
            )
        except KeyError:
            st.error("No description available.")

    st.write("###### Status")
    try:
        status_module = {
            "Status": to_capitalized_case(
                trial["protocolSection"]["statusModule"]["overallStatus"]
            ),
            "Status Date": trial["protocolSection"]["statusModule"][
                "statusVerifiedDate"
            ],
            "Has Results": trial["hasResults"],
        }
        st.table(status_module)
    except KeyError:
        st.info("No status information available.")

    st.write("###### Design")
    try:
        design_module = {
            "Study Type": to_capitalized_case(
                trial["protocolSection"]["designModule"]["studyType"]
            ),
            "Phases": list_to_capitalized_case(
                trial["protocolSection"]["designModule"]["phases"]
            ),
            "Allocation": to_capitalized_case(
                trial["protocolSection"]["designModule"]["designInfo"]["allocation"]
            ),
            "Primary Purpose": to_capitalized_case(
                trial["protocolSection"]["designModule"]["designInfo"]["primaryPurpose"]
            ),
            "Participants": trial["protocolSection"]["designModule"]["enrollmentInfo"][
                "count"
            ],
            "Masking": to_capitalized_case(
                trial["protocolSection"]["designModule"]["designInfo"]["maskingInfo"][
                    "masking"
                ]
            ),
            "Who Masked": list_to_capitalized_case(
                trial["protocolSection"]["designModule"]["designInfo"]["maskingInfo"][
                    "whoMasked"
                ]
            ),
        }
        st.table(design_module)
    except KeyError:
        st.info("No design information available.")

    st.write("###### Interventions")
    try:
        interventions_module = {}
        for intervention in trial["protocolSection"]["armsInterventionsModule"][
            "interventions"
        ]:
            name = intervention["name"]
            desc = intervention["description"]
            interventions_module[name] = desc
        st.table(interventions_module)
    except KeyError:
        st.info("No interventions information available.")

    # Link to the official page of the trial on ClinicalTrials.gov.
    st.markdown(
        f"See more in [ClinicalTrials.gov](https://clinicaltrials.gov/study/{trial['protocolSection']['identificationModule']['nctId']})"
    )

if __name__ == "__main__":
    username = "demo"
    password = "demo"
    hostname = os.getenv("IRIS_HOSTNAME", "localhost")
    port = "1972"
    namespace = "USER"
    CONNECTION_STRING = f"iris://{username}:{password}@{hostname}:{port}/{namespace}"

    try:
        engine = create_engine(CONNECTION_STRING)
        diseases = get_most_similar_diseases_from_uri(engine, "C1843013")
        for disease in diseases:
            print(disease)
    except Exception as e:
        print(e)

    try:
        print(get_uri_from_name(engine, "Alzheimer disease 3"))
    except Exception as e:
        print(e)

    clinical_record_info = get_clinical_records_by_ids(["NCT00841061"])
    print(clinical_record_info)

    textual_description = (
        "A disease that causes memory loss and other cognitive impairments."
    )
    encoder = SentenceTransformer("allenai-specter")
    diseases = get_diseases_related_to_a_textual_description(
        textual_description, encoder
    )
    for disease in diseases:
        print(disease)

    try:
        similarities = get_similarities_among_diseases_uris(
            [
                "http://identifiers.org/medgen/C4553765",
                "http://identifiers.org/medgen/C4553176",
                "http://identifiers.org/medgen/C4024935",
            ]
        )
        for similarity in similarities:
            print(
                f"{similarity['uri1']} and {similarity['uri2']} have a similarity of {similarity['distance']}"
            )
    except Exception as e:
        print(e)

# %%