Spaces:
Sleeping
Sleeping
ACMCMC
commited on
Commit
·
1f35211
1
Parent(s):
2408e3d
WIP app
Browse files
app.py
CHANGED
|
@@ -3,8 +3,10 @@ from streamlit_agraph import agraph, Node, Edge, Config
|
|
| 3 |
import os
|
| 4 |
from sqlalchemy import create_engine, text
|
| 5 |
import pandas as pd
|
| 6 |
-
from utils import get_all_diseases_name, get_most_similar_diseases_from_uri, get_uri_from_name, get_diseases_related_to_a_textual_description
|
| 7 |
import json
|
|
|
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
username = 'demo'
|
|
@@ -15,11 +17,17 @@ namespace = 'USER'
|
|
| 15 |
CONNECTION_STRING = f"iris://{username}:{password}@{hostname}:{port}/{namespace}"
|
| 16 |
engine = create_engine(CONNECTION_STRING)
|
| 17 |
|
| 18 |
-
def handle_click_on_analyze_button():
|
| 19 |
# 1. Embed the textual description that the user entered using the model
|
| 20 |
-
diseases_related_to_the_user_text = get_diseases_related_to_a_textual_description(description_input)
|
| 21 |
# 2. Get 5 diseases with the highest cosine silimarity from the DB
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
# 3. Get the similarities of the embeddings of those diseases (cosine similarity of the embeddings of the nodes of such diseases)
|
|
|
|
|
|
|
|
|
|
| 23 |
# 4. Potentially filter out the diseases that are not similar enough (e.g. similarity < 0.8)
|
| 24 |
# 5. Augment the set of diseases: add new diseases that are similar to the ones that are already in the set, until we get 10-15 diseases
|
| 25 |
# 6. Query the embeddings of the diseases related to each clinical trial (also in the DB), to get the most similar clinical trials to our set of diseases
|
|
@@ -31,7 +39,10 @@ def handle_click_on_analyze_button():
|
|
| 31 |
|
| 32 |
st.write("# Klìnic")
|
| 33 |
|
| 34 |
-
description_input = st.text_input(label="Enter the disease description 👇")
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
st.write(":red[Here should be the graph]") # TODO remove
|
| 37 |
chart_data = pd.DataFrame(
|
|
|
|
| 3 |
import os
|
| 4 |
from sqlalchemy import create_engine, text
|
| 5 |
import pandas as pd
|
| 6 |
+
from utils import get_all_diseases_name, get_most_similar_diseases_from_uri, get_uri_from_name, get_diseases_related_to_a_textual_description, get_similarities_among_diseases_uris
|
| 7 |
import json
|
| 8 |
+
import numpy as np
|
| 9 |
+
from sentence_transformers import SentenceTransformer
|
| 10 |
|
| 11 |
|
| 12 |
username = 'demo'
|
|
|
|
| 17 |
CONNECTION_STRING = f"iris://{username}:{password}@{hostname}:{port}/{namespace}"
|
| 18 |
engine = create_engine(CONNECTION_STRING)
|
| 19 |
|
| 20 |
+
def handle_click_on_analyze_button(user_text):
|
| 21 |
# 1. Embed the textual description that the user entered using the model
|
|
|
|
| 22 |
# 2. Get 5 diseases with the highest cosine silimarity from the DB
|
| 23 |
+
encoder = SentenceTransformer("allenai-specter")
|
| 24 |
+
diseases_related_to_the_user_text = get_diseases_related_to_a_textual_description(user_text, encoder)
|
| 25 |
+
#for disease_label in diseases_related_to_the_user_text:
|
| 26 |
+
# st.text(disease_label)
|
| 27 |
# 3. Get the similarities of the embeddings of those diseases (cosine similarity of the embeddings of the nodes of such diseases)
|
| 28 |
+
diseases_uris = [disease['uri'] for disease in diseases_related_to_the_user_text]
|
| 29 |
+
get_similarities_among_diseases_uris(diseases_uris)
|
| 30 |
+
print(diseases_related_to_the_user_text)
|
| 31 |
# 4. Potentially filter out the diseases that are not similar enough (e.g. similarity < 0.8)
|
| 32 |
# 5. Augment the set of diseases: add new diseases that are similar to the ones that are already in the set, until we get 10-15 diseases
|
| 33 |
# 6. Query the embeddings of the diseases related to each clinical trial (also in the DB), to get the most similar clinical trials to our set of diseases
|
|
|
|
| 39 |
|
| 40 |
st.write("# Klìnic")
|
| 41 |
|
| 42 |
+
description_input = st.text_input(label="Enter the disease description 👇", placeholder='A disease that causes memory loss and other cognitive impairments.')
|
| 43 |
+
if st.button("Analyze"):
|
| 44 |
+
handle_click_on_analyze_button(description_input)
|
| 45 |
+
# TODO: also when user clicks enter
|
| 46 |
|
| 47 |
st.write(":red[Here should be the graph]") # TODO remove
|
| 48 |
chart_data = pd.DataFrame(
|
utils.py
CHANGED
|
@@ -5,6 +5,15 @@ from sqlalchemy import create_engine, text
|
|
| 5 |
import requests
|
| 6 |
from sentence_transformers import SentenceTransformer
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
def get_all_diseases_name(engine) -> List[List[str]]:
|
| 10 |
with engine.connect() as conn:
|
|
@@ -98,46 +107,48 @@ def get_clinical_records_by_ids(clinical_record_ids: List[str]) -> List[Dict[str
|
|
| 98 |
return clinical_records
|
| 99 |
|
| 100 |
|
| 101 |
-
def
|
| 102 |
-
uri_list
|
|
|
|
|
|
|
| 103 |
with engine.connect() as conn:
|
| 104 |
with conn.begin():
|
| 105 |
sql = f"""
|
| 106 |
SELECT e1.uri AS uri1, e2.uri AS uri2, VECTOR_COSINE(e1.embedding, e2.embedding) AS distance
|
| 107 |
FROM Test.EntityEmbeddings e1, Test.EntityEmbeddings e2
|
| 108 |
-
WHERE e1.uri IN {uri_list} AND e2.uri IN {uri_list} AND e1.uri != e2.uri
|
| 109 |
"""
|
| 110 |
result = conn.execute(text(sql))
|
| 111 |
data = result.fetchall()
|
| 112 |
return data
|
| 113 |
|
| 114 |
|
| 115 |
-
encoder
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
def get_embedding(string: str) -> List[float]:
|
| 119 |
# Embed the string using sentence-transformers
|
| 120 |
vector = encoder.encode(string, show_progress_bar=False)
|
| 121 |
return vector
|
| 122 |
|
| 123 |
|
| 124 |
-
def get_diseases_related_to_a_textual_description(
|
|
|
|
|
|
|
| 125 |
# Embed the description using sentence-transformers
|
| 126 |
-
description_embedding = get_embedding(description)
|
| 127 |
-
print(f
|
| 128 |
string_representation = str(description_embedding.tolist())[1:-1]
|
| 129 |
-
print(f
|
| 130 |
|
| 131 |
with engine.connect() as conn:
|
| 132 |
with conn.begin():
|
| 133 |
sql = f"""
|
| 134 |
-
SELECT TOP 5 uri, VECTOR_COSINE(
|
| 135 |
-
FROM Test.DiseaseDescriptions
|
| 136 |
ORDER BY distance DESC
|
| 137 |
"""
|
| 138 |
result = conn.execute(text(sql))
|
| 139 |
data = result.fetchall()
|
| 140 |
-
|
|
|
|
| 141 |
|
| 142 |
|
| 143 |
if __name__ == "__main__":
|
|
@@ -164,9 +175,29 @@ if __name__ == "__main__":
|
|
| 164 |
clinical_record_info = get_clinical_records_by_ids(["NCT00841061"])
|
| 165 |
print(clinical_record_info)
|
| 166 |
|
| 167 |
-
textual_description =
|
| 168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
for disease in diseases:
|
| 170 |
print(disease)
|
| 171 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
# %%
|
|
|
|
| 5 |
import requests
|
| 6 |
from sentence_transformers import SentenceTransformer
|
| 7 |
|
| 8 |
+
username = "demo"
|
| 9 |
+
password = "demo"
|
| 10 |
+
hostname = os.getenv("IRIS_HOSTNAME", "localhost")
|
| 11 |
+
port = "1972"
|
| 12 |
+
namespace = "USER"
|
| 13 |
+
CONNECTION_STRING = f"iris://{username}:{password}@{hostname}:{port}/{namespace}"
|
| 14 |
+
|
| 15 |
+
engine = create_engine(CONNECTION_STRING)
|
| 16 |
+
|
| 17 |
|
| 18 |
def get_all_diseases_name(engine) -> List[List[str]]:
|
| 19 |
with engine.connect() as conn:
|
|
|
|
| 107 |
return clinical_records
|
| 108 |
|
| 109 |
|
| 110 |
+
def get_similarities_among_diseases_uris(
|
| 111 |
+
uri_list: List[str],
|
| 112 |
+
) -> List[tuple[str, str, float]]:
|
| 113 |
+
uri_list = ", ".join([f"'{uri}'" for uri in uri_list])
|
| 114 |
with engine.connect() as conn:
|
| 115 |
with conn.begin():
|
| 116 |
sql = f"""
|
| 117 |
SELECT e1.uri AS uri1, e2.uri AS uri2, VECTOR_COSINE(e1.embedding, e2.embedding) AS distance
|
| 118 |
FROM Test.EntityEmbeddings e1, Test.EntityEmbeddings e2
|
| 119 |
+
WHERE e1.uri IN ({uri_list}) AND e2.uri IN ({uri_list}) AND e1.uri != e2.uri
|
| 120 |
"""
|
| 121 |
result = conn.execute(text(sql))
|
| 122 |
data = result.fetchall()
|
| 123 |
return data
|
| 124 |
|
| 125 |
|
| 126 |
+
def get_embedding(string: str, encoder) -> List[float]:
|
|
|
|
|
|
|
|
|
|
| 127 |
# Embed the string using sentence-transformers
|
| 128 |
vector = encoder.encode(string, show_progress_bar=False)
|
| 129 |
return vector
|
| 130 |
|
| 131 |
|
| 132 |
+
def get_diseases_related_to_a_textual_description(
|
| 133 |
+
description: str, encoder
|
| 134 |
+
) -> List[str]:
|
| 135 |
# Embed the description using sentence-transformers
|
| 136 |
+
description_embedding = get_embedding(description, encoder)
|
| 137 |
+
print(f"Size of the embedding: {len(description_embedding)}")
|
| 138 |
string_representation = str(description_embedding.tolist())[1:-1]
|
| 139 |
+
print(f"String representation: {string_representation}")
|
| 140 |
|
| 141 |
with engine.connect() as conn:
|
| 142 |
with conn.begin():
|
| 143 |
sql = f"""
|
| 144 |
+
SELECT TOP 5 d.uri, VECTOR_COSINE(d.embedding, TO_VECTOR('{string_representation}', DOUBLE)) AS distance
|
| 145 |
+
FROM Test.DiseaseDescriptions d
|
| 146 |
ORDER BY distance DESC
|
| 147 |
"""
|
| 148 |
result = conn.execute(text(sql))
|
| 149 |
data = result.fetchall()
|
| 150 |
+
|
| 151 |
+
return [{"uri": row[0], "distance": row[1]} for row in data]
|
| 152 |
|
| 153 |
|
| 154 |
if __name__ == "__main__":
|
|
|
|
| 175 |
clinical_record_info = get_clinical_records_by_ids(["NCT00841061"])
|
| 176 |
print(clinical_record_info)
|
| 177 |
|
| 178 |
+
textual_description = (
|
| 179 |
+
"A disease that causes memory loss and other cognitive impairments."
|
| 180 |
+
)
|
| 181 |
+
encoder = SentenceTransformer("allenai-specter")
|
| 182 |
+
diseases = get_diseases_related_to_a_textual_description(
|
| 183 |
+
textual_description, encoder
|
| 184 |
+
)
|
| 185 |
for disease in diseases:
|
| 186 |
print(disease)
|
| 187 |
|
| 188 |
+
try:
|
| 189 |
+
similarities = get_similarities_among_diseases_uris(
|
| 190 |
+
[
|
| 191 |
+
"http://identifiers.org/medgen/C4553765",
|
| 192 |
+
"http://identifiers.org/medgen/C4553176",
|
| 193 |
+
"http://identifiers.org/medgen/C4024935",
|
| 194 |
+
]
|
| 195 |
+
)
|
| 196 |
+
for similarity in similarities:
|
| 197 |
+
print(
|
| 198 |
+
f'{similarity[0].split("/")[-1]} and {similarity[1].split("/")[-1]} have a similarity of {similarity[2]}'
|
| 199 |
+
)
|
| 200 |
+
except Exception as e:
|
| 201 |
+
print(e)
|
| 202 |
+
|
| 203 |
# %%
|