acmc committed on
Commit
93e1b64
·
0 Parent(s):

First commit in HuggingFace

Browse files
.DS_Store ADDED
Binary file (6.15 kB). View file
 
.assets/architecture.png ADDED
.gitattributes ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *.csv filter=lfs diff=lfs merge=lfs -text
2
+ *.RRF filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
161
+
162
+ *.tsv
163
+ doctests/
164
+
165
+ file_db/
166
+ clinical_trials.csv
MGCONSO.RRF ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fe1722c72d4e2eef72f4b14468ce1fb05bd22952b20cda7249bce4d17ef0607f
3
+ size 74127292
MGDEF.RRF ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cb47c3f4e1731560571844b9520747987352d24baf37b8907aefe817383d9e18
3
+ size 17450055
MGREL.RRF ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6ae7bd453b114dad737adbc65aa23028aaf02370d6c7d37877251629e8162ae0
3
+ size 94519907
README.md ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Klìnic
3
+ emoji: 👍🏻
4
+ colorFrom: pink
5
+ colorTo: blue
6
+ sdk: streamlit
7
+ sdk_version: 1.34.0
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ # hackupc-24
13
+
14
+ ## Architecture
15
+ ![Architecture diagram](.assets/architecture.png "Architecture")
16
+
17
+ ## Setup
18
+
19
+ 1. Start the IRIS Docker container:
20
+ ```Shell
21
+ docker-compose up
22
+ ```
23
+ 2. Start a Jupyter Notebook
24
+ 3. Navigate to http://localhost:52773/csp/sys/UtilHome.csp to access IRIS and log in with username: `demo`, password: `demo`
25
+ - You can execute SQL queries at 'System Explorer' → 'SQL'
26
+ 4. Run the [vector_search.ipynb](./vector_search.ipynb) notebook
27
+ 5. Run the frontend: `streamlit run app.py`
app.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json
import os

import numpy as np
import pandas as pd
import streamlit as st
from sqlalchemy import create_engine, text
from streamlit_agraph import agraph, Node, Edge, Config

from utils import get_all_diseases_name, get_most_similar_diseases_from_uri, get_uri_from_name
8
+
9
+
10
# Connection settings for the IRIS database. Every value can be overridden
# through an environment variable, following the pattern already used for
# IRIS_HOSTNAME; the defaults are unchanged from the previously hard-coded
# values (the `demo` / `demo` login mentioned in the README).
username = os.getenv('IRIS_USERNAME', 'demo')
password = os.getenv('IRIS_PASSWORD', 'demo')
hostname = os.getenv('IRIS_HOSTNAME', 'localhost')
port = os.getenv('IRIS_PORT', '1972')
namespace = os.getenv('IRIS_NAMESPACE', 'USER')
# NOTE(review): the values are interpolated into the URL verbatim, so a
# username or password containing URL-reserved characters would need quoting.
CONNECTION_STRING = f"iris://{username}:{password}@{hostname}:{port}/{namespace}"
# Module-level SQLAlchemy engine built from the connection string above.
engine = create_engine(CONNECTION_STRING)
17
+
18
def handle_click_on_analyze_button():
    """Placeholder for the "analyze" action; nothing is implemented yet.

    Planned pipeline (carried over from the original TODO notes):

    1. Embed the textual description that the user entered using the model
       (the notes leave the model unspecified).
    2. Get the 5 diseases with the highest cosine similarity from the DB.
    3. Get the similarities of the embeddings of those diseases (cosine
       similarity of the embeddings of the nodes of such diseases).
    4. Potentially filter out the diseases that are not similar enough
       (e.g. similarity < 0.8).
    5. Augment the set of diseases: add new diseases that are similar to the
       ones already in the set, until we get 10-15 diseases.
    6. Query the embeddings of the diseases related to each clinical trial
       (also in the DB), to get the clinical trials most similar to our set
       of diseases.
    7. Use an LLM to get a summary of the clinical trials, in plain text.
    8. Use an LLM to extract numerical data from the clinical trials (e.g.
       number of patients, number of deaths, etc.) and derive summary
       statistics from it.
    9. Show the results to the user: graph of the diseases chosen, summary of
       the clinical trials, summary statistics of the clinical trials, and
       the list of details of the clinical trials considered.

    Returns:
        None. The function currently has no effect.
    """
    # TODO: implement the pipeline described in the docstring.
    return None
29
+
30
+
31
# ---------------------------------------------------------------------------
# Streamlit page layout. Everything below runs top-to-bottom on each rerun.
# NOTE(review): indentation was reconstructed from the statement structure;
# confirm that all per-trial widgets are meant to live inside the expander.
# ---------------------------------------------------------------------------
st.write("# Klìnic")

description_input = st.text_input(label="Enter the disease description 👇")

# Placeholder chart with random data until the disease graph is available.
st.write(":red[Here should be the graph]")  # TODO remove
chart_data = pd.DataFrame(
    np.random.randn(20, 3), columns=["a", "b", "c"]
)  # TODO remove
st.scatter_chart(chart_data)  # TODO remove

st.write("## Disease Overview")
disease_overview = ":red[lorem ipsum]"  # TODO
st.write(disease_overview)

st.write("## Clinical Trials Details")
trials = []
# TODO replace mock data
with open("mock_trial.json", encoding="utf-8") as f:
    d = json.load(f)
# The same mock trial is shown five times to exercise the list rendering.
for _ in range(5):
    trials.append(d)

for trial in trials:
    protocol = trial["protocolSection"]
    # One collapsible section per trial, titled with its NCT identifier.
    with st.expander(f"{protocol['identificationModule']['nctId']}"):
        official_title = protocol["identificationModule"]["officialTitle"]
        st.write(f"##### {official_title}")

        brief_summary = protocol["descriptionModule"]["briefSummary"]
        st.write(brief_summary)

        status_module = {
            "Status": protocol["statusModule"]["overallStatus"],
            "Status Date": protocol["statusModule"]["statusVerifiedDate"],
        }
        st.write("###### Status")
        st.table(status_module)

        design_module = {
            "Study Type": protocol["designModule"]["studyType"],
            # "Phases" is left out: it is an array and breaks the table formatting.
            "Allocation": protocol["designModule"]["designInfo"]["allocation"],
            "Participants": protocol["designModule"]["enrollmentInfo"]["count"],
        }
        st.write("###### Design")
        st.table(design_module)

        # TODO more modules?
calculate_smilar_nodes.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %%
2
def transe_distance(head, tail, relation, entity_embeddings, relation_embeddings):
    """Return the TransE translation residual ``head + relation - tail``.

    Despite the name, the result is the residual *vector*, not a scalar;
    callers in this file take ``.norm()`` of it to obtain a distance.

    Args:
        head: Key of the head entity in ``entity_embeddings``.
        tail: Key of the tail entity in ``entity_embeddings``.
        relation: Key of the relation in ``relation_embeddings``.
        entity_embeddings: Indexable container mapping entity keys to
            embedding vectors.
        relation_embeddings: Indexable container mapping relation keys to
            embedding vectors.

    Returns:
        The element-wise value of ``head_embedding + relation_embedding -
        tail_embedding``, of whatever array type the embeddings have.
    """
    head_embedding = entity_embeddings[head]
    tail_embedding = entity_embeddings[tail]
    # Use a distinct local name instead of rebinding the container parameter.
    relation_embedding = relation_embeddings[relation]
    return head_embedding + relation_embedding - tail_embedding
8
+
9
def calculate_similar_nodes(node, entity_embeddings, relation_embeddings, top_n=10, relation=0):
    """Rank entities by the norm of their TransE residual from ``node``.

    For every entity key ``k`` the residual ``transe_distance(node, k,
    relation, ...)`` is computed, and the ``top_n`` entries with the smallest
    residual norm are returned. ``node`` itself is not excluded.

    Args:
        node: Key of the reference (head) entity.
        entity_embeddings: Container mapping entity keys to embedding vectors
            that expose ``.norm().item()`` (e.g. torch tensors).
        relation_embeddings: Container mapping relation keys to embeddings.
        top_n: Maximum number of results to return.
        relation: Key of the relation used for the translation; defaults to
            ``0``, the value that was previously hard-coded.

    Returns:
        A list of ``(entity_key, residual_vector)`` tuples sorted by
        ascending residual norm, truncated to ``top_n`` entries.
    """
    # Iterate over the container's own keys when it has them (dict, pandas
    # Series) so non-contiguous labels work; fall back to positions for
    # plain sequences, which matches the previous behaviour.
    if hasattr(entity_embeddings, "keys"):
        candidates = entity_embeddings.keys()
    else:
        candidates = range(len(entity_embeddings))

    distances = [
        (
            candidate,
            transe_distance(
                node, candidate, relation, entity_embeddings, relation_embeddings
            ),
        )
        for candidate in candidates
    ]
    distances.sort(key=lambda pair: pair[1].norm().item())
    return distances[:top_n]
16
+
17
+ # %%
18
import ast

import pandas as pd
import torch

# Load the entity embeddings from the CSV file
entity_embeddings = pd.read_csv("entity_embeddings.csv", index_col=0)
# The embedding column is stored as text, so convert each cell to a tensor.
# ast.literal_eval only accepts Python literals, so CSV content cannot
# execute arbitrary code the way eval() would allow.
# NOTE(review): assumes each cell is a plain list literal of numbers - confirm.
entity_embeddings["embedding"] = entity_embeddings["embedding"].apply(
    lambda x: torch.tensor(ast.literal_eval(x))
)
entity_embeddings.head()

# Now, load the relation embeddings
relation_embeddings = pd.read_csv("relation_embeddings.csv", index_col=0)
relation_embeddings["embedding"] = relation_embeddings["embedding"].apply(
    lambda x: torch.tensor(ast.literal_eval(x))
)
# NOTE(review): display() is never imported here; presumably provided by the
# interactive (IPython) session this `# %%` cell is run in.
display(relation_embeddings.head())
36
+ # %%
37
# Example: residual between two specific MedGen entities via relation 0.
# Find the index of the entity with the uri "http://identifiers.org/medgen/C0002395"
head = entity_embeddings[
    entity_embeddings["uri"] == "http://identifiers.org/medgen/C0002395"
].index[0]
# Find the index of the entity with the uri "http://identifiers.org/medgen/C1843013"
tail = entity_embeddings[
    entity_embeddings["uri"] == "http://identifiers.org/medgen/C1843013"
].index[0]
relation = 0
distance = transe_distance(
    head,
    tail,
    relation,
    entity_embeddings["embedding"],
    relation_embeddings["embedding"],
)
# Pull the labels out first so the message itself stays readable.
head_label = entity_embeddings["label"][head]
tail_label = entity_embeddings["label"][tail]
relation_label = relation_embeddings["label"][relation]
print(
    f"Distance between {head_label} ({head}) and {tail_label} ({tail}) "
    f"via relation {relation_label} is {distance.norm().item()}"
)
56
+ # %%
57
# Calculate similar nodes to the head
similar_nodes = calculate_similar_nodes(
    head, entity_embeddings["embedding"], relation_embeddings["embedding"]
)
print(f"Similar nodes to {entity_embeddings['label'][head]} ({head}):")
# Print the similar nodes; loop names are distinct from the module-level
# `distance` so that value is not overwritten.
for rank, (similar_node, residual) in enumerate(similar_nodes):
    print(
        f"{rank}: {entity_embeddings['label'][similar_node]} ({similar_node}) "
        f"with distance {residual.norm().item()}"
    )
63
+ # %%
clinical_trials_diseases.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d2efed4939576aa55aa19b07949b58b4c3fc5022629e96f0f665581431c85b4a
3
+ size 86845185
clinical_trials_embeddings.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e6f1cce06d76a6e83695fa6e1583292bcd85b7846488cde3c43ec2067409a91c
3
+ size 2111058864
clinical_trials_embeddings.ipynb ADDED
@@ -0,0 +1,1714 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "9db3b813-22dc-4209-86d2-42e935f5f5dd",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from langchain_community.document_loaders.csv_loader import CSVLoader\n",
11
+ "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
12
+ "import pandas as pd\n",
13
+ "import langchain\n",
14
+ "import os\n",
15
+ "import openai\n",
16
+ "import ast\n",
17
+ "from langchain import OpenAI\n",
18
+ "from langchain.chains.qa_with_sources.loading import load_qa_with_sources_chain\n",
19
+ "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
20
+ "from langchain.document_loaders import UnstructuredURLLoader\n",
21
+ "from langchain.embeddings import OpenAIEmbeddings\n",
22
+ "from langchain.vectorstores import FAISS"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": 2,
28
+ "id": "fb48dccd-37e5-484a-a2fc-c482839b9ed9",
29
+ "metadata": {},
30
+ "outputs": [],
31
+ "source": [
32
+ "# loader = CSVLoader(file_path=\"trials/brief_summaries.csv\")\n",
33
+ "# data = loader.load()"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": 3,
39
+ "id": "cdcc3107-e6a8-47ca-bd89-e381fbcf9b9e",
40
+ "metadata": {},
41
+ "outputs": [],
42
+ "source": [
43
+ "# df= pd.read_csv(\"trials/brief_summaries.txt\", delimiter=\"|\")\n",
44
+ "# df.shape"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": 4,
50
+ "id": "c6a94c14-b197-4eff-9941-1ca52069cd5c",
51
+ "metadata": {},
52
+ "outputs": [],
53
+ "source": [
54
+ "# df.head(20)"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": 5,
60
+ "id": "95c000ff-bf0c-4489-8643-93a238db41dc",
61
+ "metadata": {},
62
+ "outputs": [],
63
+ "source": [
64
+ "os.environ[\"OPENAI_API_KEY\"] = (\n",
65
+ " \"sk-proj-CG2E98bSWs53X2eWO0Z4T3BlbkFJLm7H1vfkbua0zP548CKQ\"\n",
66
+ ")"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": 6,
72
+ "id": "e2e03936-fcce-4287-bfe6-e31f1b69f693",
73
+ "metadata": {},
74
+ "outputs": [
75
+ {
76
+ "name": "stderr",
77
+ "output_type": "stream",
78
+ "text": [
79
+ "/home/aldan/miniconda3/envs/hackupc/lib/python3.10/site-packages/langchain_core/_api/deprecation.py:119: LangChainDeprecationWarning: The class `OpenAI` was deprecated in LangChain 0.0.10 and will be removed in 0.2.0. An updated version of the class exists in the langchain-openai package and should be used instead. To use it run `pip install -U langchain-openai` and import as `from langchain_openai import OpenAI`.\n",
80
+ " warn_deprecated(\n"
81
+ ]
82
+ }
83
+ ],
84
+ "source": [
85
+ "llm = OpenAI(temperature=0.6, max_tokens=500)"
86
+ ]
87
+ },
88
+ {
89
+ "cell_type": "code",
90
+ "execution_count": 7,
91
+ "id": "aede3d75-2441-44f6-b3cf-2f86f050da24",
92
+ "metadata": {},
93
+ "outputs": [
94
+ {
95
+ "ename": "FileNotFoundError",
96
+ "evalue": "[Errno 2] No such file or directory: 'clinical_trials.csv'",
97
+ "output_type": "error",
98
+ "traceback": [
99
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
100
+ "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
101
+ "Cell \u001b[0;32mIn[7], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m df_trials\u001b[38;5;241m=\u001b[39m \u001b[43mpd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mread_csv\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mclinical_trials.csv\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28mprint\u001b[39m(df_trials\u001b[38;5;241m.\u001b[39mshape)\n\u001b[1;32m 3\u001b[0m df_trials\u001b[38;5;241m.\u001b[39mhead()\n",
102
+ "File \u001b[0;32m~/miniconda3/envs/hackupc/lib/python3.10/site-packages/pandas/io/parsers/readers.py:1024\u001b[0m, in \u001b[0;36mread_csv\u001b[0;34m(filepath_or_buffer, sep, delimiter, header, names, index_col, usecols, dtype, engine, converters, true_values, false_values, skipinitialspace, skiprows, skipfooter, nrows, na_values, keep_default_na, na_filter, verbose, skip_blank_lines, parse_dates, infer_datetime_format, keep_date_col, date_parser, date_format, dayfirst, cache_dates, iterator, chunksize, compression, thousands, decimal, lineterminator, quotechar, quoting, doublequote, escapechar, comment, encoding, encoding_errors, dialect, on_bad_lines, delim_whitespace, low_memory, memory_map, float_precision, storage_options, dtype_backend)\u001b[0m\n\u001b[1;32m 1011\u001b[0m kwds_defaults \u001b[38;5;241m=\u001b[39m _refine_defaults_read(\n\u001b[1;32m 1012\u001b[0m dialect,\n\u001b[1;32m 1013\u001b[0m delimiter,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1020\u001b[0m dtype_backend\u001b[38;5;241m=\u001b[39mdtype_backend,\n\u001b[1;32m 1021\u001b[0m )\n\u001b[1;32m 1022\u001b[0m kwds\u001b[38;5;241m.\u001b[39mupdate(kwds_defaults)\n\u001b[0;32m-> 1024\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_read\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfilepath_or_buffer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwds\u001b[49m\u001b[43m)\u001b[49m\n",
103
+ "File \u001b[0;32m~/miniconda3/envs/hackupc/lib/python3.10/site-packages/pandas/io/parsers/readers.py:618\u001b[0m, in \u001b[0;36m_read\u001b[0;34m(filepath_or_buffer, kwds)\u001b[0m\n\u001b[1;32m 615\u001b[0m _validate_names(kwds\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnames\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m))\n\u001b[1;32m 617\u001b[0m \u001b[38;5;66;03m# Create the parser.\u001b[39;00m\n\u001b[0;32m--> 618\u001b[0m parser \u001b[38;5;241m=\u001b[39m \u001b[43mTextFileReader\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfilepath_or_buffer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwds\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 620\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m chunksize \u001b[38;5;129;01mor\u001b[39;00m iterator:\n\u001b[1;32m 621\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m parser\n",
104
+ "File \u001b[0;32m~/miniconda3/envs/hackupc/lib/python3.10/site-packages/pandas/io/parsers/readers.py:1618\u001b[0m, in \u001b[0;36mTextFileReader.__init__\u001b[0;34m(self, f, engine, **kwds)\u001b[0m\n\u001b[1;32m 1615\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptions[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhas_index_names\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m kwds[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhas_index_names\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 1617\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhandles: IOHandles \u001b[38;5;241m|\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m-> 1618\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_engine \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_make_engine\u001b[49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mengine\u001b[49m\u001b[43m)\u001b[49m\n",
105
+ "File \u001b[0;32m~/miniconda3/envs/hackupc/lib/python3.10/site-packages/pandas/io/parsers/readers.py:1878\u001b[0m, in \u001b[0;36mTextFileReader._make_engine\u001b[0;34m(self, f, engine)\u001b[0m\n\u001b[1;32m 1876\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mb\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m mode:\n\u001b[1;32m 1877\u001b[0m mode \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mb\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m-> 1878\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhandles \u001b[38;5;241m=\u001b[39m \u001b[43mget_handle\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1879\u001b[0m \u001b[43m \u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1880\u001b[0m \u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1881\u001b[0m \u001b[43m \u001b[49m\u001b[43mencoding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptions\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mencoding\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1882\u001b[0m \u001b[43m \u001b[49m\u001b[43mcompression\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptions\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcompression\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1883\u001b[0m \u001b[43m \u001b[49m\u001b[43mmemory_map\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptions\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmemory_map\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1884\u001b[0m \u001b[43m \u001b[49m\u001b[43mis_text\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mis_text\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1885\u001b[0m \u001b[43m \u001b[49m\u001b[43merrors\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptions\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mencoding_errors\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mstrict\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1886\u001b[0m \u001b[43m \u001b[49m\u001b[43mstorage_options\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptions\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mstorage_options\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1887\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1888\u001b[0m 
\u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhandles \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1889\u001b[0m f \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhandles\u001b[38;5;241m.\u001b[39mhandle\n",
106
+ "File \u001b[0;32m~/miniconda3/envs/hackupc/lib/python3.10/site-packages/pandas/io/common.py:873\u001b[0m, in \u001b[0;36mget_handle\u001b[0;34m(path_or_buf, mode, encoding, compression, memory_map, is_text, errors, storage_options)\u001b[0m\n\u001b[1;32m 868\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(handle, \u001b[38;5;28mstr\u001b[39m):\n\u001b[1;32m 869\u001b[0m \u001b[38;5;66;03m# Check whether the filename is to be opened in binary mode.\u001b[39;00m\n\u001b[1;32m 870\u001b[0m \u001b[38;5;66;03m# Binary mode does not support 'encoding' and 'newline'.\u001b[39;00m\n\u001b[1;32m 871\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m ioargs\u001b[38;5;241m.\u001b[39mencoding \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mb\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m ioargs\u001b[38;5;241m.\u001b[39mmode:\n\u001b[1;32m 872\u001b[0m \u001b[38;5;66;03m# Encoding\u001b[39;00m\n\u001b[0;32m--> 873\u001b[0m handle \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\n\u001b[1;32m 874\u001b[0m \u001b[43m \u001b[49m\u001b[43mhandle\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 875\u001b[0m \u001b[43m \u001b[49m\u001b[43mioargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 876\u001b[0m \u001b[43m \u001b[49m\u001b[43mencoding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mioargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencoding\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 877\u001b[0m \u001b[43m \u001b[49m\u001b[43merrors\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43merrors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 878\u001b[0m \u001b[43m \u001b[49m\u001b[43mnewline\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 879\u001b[0m \u001b[43m 
\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 880\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 881\u001b[0m \u001b[38;5;66;03m# Binary mode\u001b[39;00m\n\u001b[1;32m 882\u001b[0m handle \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mopen\u001b[39m(handle, ioargs\u001b[38;5;241m.\u001b[39mmode)\n",
107
+ "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'clinical_trials.csv'"
108
+ ]
109
+ }
110
+ ],
111
+ "source": [
112
+ "df_trials = pd.read_csv(\"clinical_trials.csv\")\n",
113
+ "print(df_trials.shape)\n",
114
+ "df_trials.head()"
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "execution_count": 15,
120
+ "id": "65589fc4-4f55-4b72-9935-c3b163fde0e2",
121
+ "metadata": {},
122
+ "outputs": [
123
+ {
124
+ "data": {
125
+ "text/html": [
126
+ "<div>\n",
127
+ "<style scoped>\n",
128
+ " .dataframe tbody tr th:only-of-type {\n",
129
+ " vertical-align: middle;\n",
130
+ " }\n",
131
+ "\n",
132
+ " .dataframe tbody tr th {\n",
133
+ " vertical-align: top;\n",
134
+ " }\n",
135
+ "\n",
136
+ " .dataframe thead th {\n",
137
+ " text-align: right;\n",
138
+ " }\n",
139
+ "</style>\n",
140
+ "<table border=\"1\" class=\"dataframe\">\n",
141
+ " <thead>\n",
142
+ " <tr style=\"text-align: right;\">\n",
143
+ " <th></th>\n",
144
+ " <th>desease_condition</th>\n",
145
+ " </tr>\n",
146
+ " </thead>\n",
147
+ " <tbody>\n",
148
+ " <tr>\n",
149
+ " <th>0</th>\n",
150
+ " <td>['marijuana abuse', 'substance-related disorde...</td>\n",
151
+ " </tr>\n",
152
+ " <tr>\n",
153
+ " <th>1</th>\n",
154
+ " <td>['marijuana abuse', 'substance-related disorde...</td>\n",
155
+ " </tr>\n",
156
+ " <tr>\n",
157
+ " <th>2</th>\n",
158
+ " <td>['tuberculosis', 'latent tuberculosis', 'infec...</td>\n",
159
+ " </tr>\n",
160
+ " <tr>\n",
161
+ " <th>3</th>\n",
162
+ " <td>['heart failure', 'heart diseases', 'cardiovas...</td>\n",
163
+ " </tr>\n",
164
+ " <tr>\n",
165
+ " <th>4</th>\n",
166
+ " <td>['lymphoma', 'neoplasms by histologic type', '...</td>\n",
167
+ " </tr>\n",
168
+ " </tbody>\n",
169
+ "</table>\n",
170
+ "</div>"
171
+ ],
172
+ "text/plain": [
173
+ " desease_condition\n",
174
+ "0 ['marijuana abuse', 'substance-related disorde...\n",
175
+ "1 ['marijuana abuse', 'substance-related disorde...\n",
176
+ "2 ['tuberculosis', 'latent tuberculosis', 'infec...\n",
177
+ "3 ['heart failure', 'heart diseases', 'cardiovas...\n",
178
+ "4 ['lymphoma', 'neoplasms by histologic type', '..."
179
+ ]
180
+ },
181
+ "execution_count": 15,
182
+ "metadata": {},
183
+ "output_type": "execute_result"
184
+ }
185
+ ],
186
+ "source": [
187
+ "df_trials_filtered = df_trials[[\"desease_condition\"]]\n",
188
+ "df_trials_filtered.head()"
189
+ ]
190
+ },
191
+ {
192
+ "cell_type": "code",
193
+ "execution_count": 16,
194
+ "id": "88e28056-e340-416c-a1a9-4a6c29556dc7",
195
+ "metadata": {},
196
+ "outputs": [
197
+ {
198
+ "data": {
199
+ "text/plain": [
200
+ "\"['marijuana abuse', 'substance-related disorders', 'chemically-induced disorders', 'mental disorders']\""
201
+ ]
202
+ },
203
+ "execution_count": 16,
204
+ "metadata": {},
205
+ "output_type": "execute_result"
206
+ }
207
+ ],
208
+ "source": [
209
+ "df_trials_filtered[\"desease_condition\"].iloc[0]"
210
+ ]
211
+ },
212
+ {
213
+ "cell_type": "code",
214
+ "execution_count": 17,
215
+ "id": "5bd8f876-0480-40a5-a32f-ca7ec137a70f",
216
+ "metadata": {},
217
+ "outputs": [
218
+ {
219
+ "name": "stderr",
220
+ "output_type": "stream",
221
+ "text": [
222
+ "C:\\Users\\ariji\\AppData\\Local\\Temp\\ipykernel_22340\\16068817.py:4: SettingWithCopyWarning: \n",
223
+ "A value is trying to be set on a copy of a slice from a DataFrame.\n",
224
+ "Try using .loc[row_indexer,col_indexer] = value instead\n",
225
+ "\n",
226
+ "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
227
+ " df_trials_filtered['desease_condition']= df_trials_filtered['desease_condition'].apply(list_to_string)\n"
228
+ ]
229
+ },
230
+ {
231
+ "data": {
232
+ "text/html": [
233
+ "<div>\n",
234
+ "<style scoped>\n",
235
+ " .dataframe tbody tr th:only-of-type {\n",
236
+ " vertical-align: middle;\n",
237
+ " }\n",
238
+ "\n",
239
+ " .dataframe tbody tr th {\n",
240
+ " vertical-align: top;\n",
241
+ " }\n",
242
+ "\n",
243
+ " .dataframe thead th {\n",
244
+ " text-align: right;\n",
245
+ " }\n",
246
+ "</style>\n",
247
+ "<table border=\"1\" class=\"dataframe\">\n",
248
+ " <thead>\n",
249
+ " <tr style=\"text-align: right;\">\n",
250
+ " <th></th>\n",
251
+ " <th>desease_condition</th>\n",
252
+ " </tr>\n",
253
+ " </thead>\n",
254
+ " <tbody>\n",
255
+ " <tr>\n",
256
+ " <th>0</th>\n",
257
+ " <td>marijuana abuse, substance-related disorders, ...</td>\n",
258
+ " </tr>\n",
259
+ " <tr>\n",
260
+ " <th>1</th>\n",
261
+ " <td>marijuana abuse, substance-related disorders, ...</td>\n",
262
+ " </tr>\n",
263
+ " <tr>\n",
264
+ " <th>2</th>\n",
265
+ " <td>tuberculosis, latent tuberculosis, infections,...</td>\n",
266
+ " </tr>\n",
267
+ " <tr>\n",
268
+ " <th>3</th>\n",
269
+ " <td>heart failure, heart diseases, cardiovascular ...</td>\n",
270
+ " </tr>\n",
271
+ " <tr>\n",
272
+ " <th>4</th>\n",
273
+ " <td>lymphoma, neoplasms by histologic type, neopla...</td>\n",
274
+ " </tr>\n",
275
+ " <tr>\n",
276
+ " <th>...</th>\n",
277
+ " <td>...</td>\n",
278
+ " </tr>\n",
279
+ " <tr>\n",
280
+ " <th>440512</th>\n",
281
+ " <td>obesity, overweight, overnutrition, nutrition ...</td>\n",
282
+ " </tr>\n",
283
+ " <tr>\n",
284
+ " <th>440513</th>\n",
285
+ " <td>obesity, overweight, overnutrition, nutrition ...</td>\n",
286
+ " </tr>\n",
287
+ " <tr>\n",
288
+ " <th>440514</th>\n",
289
+ " <td>obesity, overweight, overnutrition, nutrition ...</td>\n",
290
+ " </tr>\n",
291
+ " <tr>\n",
292
+ " <th>440515</th>\n",
293
+ " <td>autistic disorder, autism spectrum disorder, c...</td>\n",
294
+ " </tr>\n",
295
+ " <tr>\n",
296
+ " <th>440516</th>\n",
297
+ " <td>autistic disorder, autism spectrum disorder, c...</td>\n",
298
+ " </tr>\n",
299
+ " </tbody>\n",
300
+ "</table>\n",
301
+ "<p>440517 rows × 1 columns</p>\n",
302
+ "</div>"
303
+ ],
304
+ "text/plain": [
305
+ " desease_condition\n",
306
+ "0 marijuana abuse, substance-related disorders, ...\n",
307
+ "1 marijuana abuse, substance-related disorders, ...\n",
308
+ "2 tuberculosis, latent tuberculosis, infections,...\n",
309
+ "3 heart failure, heart diseases, cardiovascular ...\n",
310
+ "4 lymphoma, neoplasms by histologic type, neopla...\n",
311
+ "... ...\n",
312
+ "440512 obesity, overweight, overnutrition, nutrition ...\n",
313
+ "440513 obesity, overweight, overnutrition, nutrition ...\n",
314
+ "440514 obesity, overweight, overnutrition, nutrition ...\n",
315
+ "440515 autistic disorder, autism spectrum disorder, c...\n",
316
+ "440516 autistic disorder, autism spectrum disorder, c...\n",
317
+ "\n",
318
+ "[440517 rows x 1 columns]"
319
+ ]
320
+ },
321
+ "execution_count": 17,
322
+ "metadata": {},
323
+ "output_type": "execute_result"
324
+ }
325
+ ],
326
+ "source": [
327
+ "def list_to_string(disease_list):\n",
328
+ " disease_list = ast.literal_eval(disease_list)\n",
329
+ " return \", \".join(disease_list)\n",
330
+ "\n",
331
+ "\n",
332
+ "df_trials_filtered[\"desease_condition\"] = df_trials_filtered[\"desease_condition\"].apply(\n",
333
+ " list_to_string\n",
334
+ ")\n",
335
+ "df_trials_filtered"
336
+ ]
337
+ },
338
+ {
339
+ "cell_type": "code",
340
+ "execution_count": 18,
341
+ "id": "bbbb22e6-4883-4869-8ccc-95696bc67b1b",
342
+ "metadata": {},
343
+ "outputs": [
344
+ {
345
+ "data": {
346
+ "text/plain": [
347
+ "0 marijuana abuse, substance-related disorders, ...\n",
348
+ "1 marijuana abuse, substance-related disorders, ...\n",
349
+ "2 tuberculosis, latent tuberculosis, infections,...\n",
350
+ "3 heart failure, heart diseases, cardiovascular ...\n",
351
+ "4 lymphoma, neoplasms by histologic type, neopla...\n",
352
+ "Name: desease_condition, dtype: object"
353
+ ]
354
+ },
355
+ "execution_count": 18,
356
+ "metadata": {},
357
+ "output_type": "execute_result"
358
+ }
359
+ ],
360
+ "source": [
361
+ "df_trials_filtered[\"desease_condition\"].head()"
362
+ ]
363
+ },
364
+ {
365
+ "cell_type": "code",
366
+ "execution_count": 19,
367
+ "id": "7f4e7ceb-8bfd-4294-a850-8935f88b6555",
368
+ "metadata": {},
369
+ "outputs": [],
370
+ "source": [
371
+ "df_trials_filtered.to_csv(\"diseases.csv\", index=False)"
372
+ ]
373
+ },
374
+ {
375
+ "cell_type": "code",
376
+ "execution_count": 20,
377
+ "id": "af1c5c2b-24a0-44a1-9e5d-7ee89ca4cccf",
378
+ "metadata": {},
379
+ "outputs": [
380
+ {
381
+ "data": {
382
+ "text/plain": [
383
+ "440517"
384
+ ]
385
+ },
386
+ "execution_count": 20,
387
+ "metadata": {},
388
+ "output_type": "execute_result"
389
+ }
390
+ ],
391
+ "source": [
392
+ "loader = CSVLoader(file_path=\"./diseases.csv\", encoding=\"utf-8\")\n",
393
+ "data = loader.load()\n",
394
+ "len(data)"
395
+ ]
396
+ },
397
+ {
398
+ "cell_type": "code",
399
+ "execution_count": null,
400
+ "id": "cab89218-41ca-4048-886d-bc2c1c9b30bc",
401
+ "metadata": {},
402
+ "outputs": [],
403
+ "source": [
404
+ "embeddings = OpenAIEmbeddings()\n",
405
+ "vectorstore = FAISS.from_documents(data, embeddings)"
406
+ ]
407
+ },
408
+ {
409
+ "cell_type": "code",
410
+ "execution_count": null,
411
+ "id": "225ade4a-d004-44cc-a5ff-22ce2bfcac32",
412
+ "metadata": {},
413
+ "outputs": [],
414
+ "source": [
415
+ "file_path = \"vector_index.pkl\"\n",
416
+ "with open(file_path, \"wb\") as f:\n",
417
+ " pickle.dump(vectorstore, f)"
418
+ ]
419
+ },
420
+ {
421
+ "cell_type": "code",
422
+ "execution_count": 98,
423
+ "id": "11912a93-ad02-41cb-8bce-2750c947fa74",
424
+ "metadata": {},
425
+ "outputs": [
426
+ {
427
+ "name": "stdout",
428
+ "output_type": "stream",
429
+ "text": [
430
+ "(440517, 2)\n"
431
+ ]
432
+ },
433
+ {
434
+ "data": {
435
+ "text/html": [
436
+ "<div>\n",
437
+ "<style scoped>\n",
438
+ " .dataframe tbody tr th:only-of-type {\n",
439
+ " vertical-align: middle;\n",
440
+ " }\n",
441
+ "\n",
442
+ " .dataframe tbody tr th {\n",
443
+ " vertical-align: top;\n",
444
+ " }\n",
445
+ "\n",
446
+ " .dataframe thead th {\n",
447
+ " text-align: right;\n",
448
+ " }\n",
449
+ "</style>\n",
450
+ "<table border=\"1\" class=\"dataframe\">\n",
451
+ " <thead>\n",
452
+ " <tr style=\"text-align: right;\">\n",
453
+ " <th></th>\n",
454
+ " <th>desease_condition</th>\n",
455
+ " <th>text</th>\n",
456
+ " </tr>\n",
457
+ " </thead>\n",
458
+ " <tbody>\n",
459
+ " <tr>\n",
460
+ " <th>0</th>\n",
461
+ " <td>['marijuana abuse', 'substance-related disorde...</td>\n",
462
+ " <td>nct_id: NCT03055377\\nsummary: This is a 12-wee...</td>\n",
463
+ " </tr>\n",
464
+ " <tr>\n",
465
+ " <th>1</th>\n",
466
+ " <td>['marijuana abuse', 'substance-related disorde...</td>\n",
467
+ " <td>nct_id: NCT03055377\\nsummary: This is a 12-wee...</td>\n",
468
+ " </tr>\n",
469
+ " <tr>\n",
470
+ " <th>2</th>\n",
471
+ " <td>['tuberculosis', 'latent tuberculosis', 'infec...</td>\n",
472
+ " <td>nct_id: NCT03042754\\nsummary: Early diagnosis ...</td>\n",
473
+ " </tr>\n",
474
+ " <tr>\n",
475
+ " <th>3</th>\n",
476
+ " <td>['heart failure', 'heart diseases', 'cardiovas...</td>\n",
477
+ " <td>nct_id: NCT03035123\\nsummary: The EduStra-HF s...</td>\n",
478
+ " </tr>\n",
479
+ " <tr>\n",
480
+ " <th>4</th>\n",
481
+ " <td>['lymphoma', 'neoplasms by histologic type', '...</td>\n",
482
+ " <td>nct_id: NCT02272751\\nsummary: This study will ...</td>\n",
483
+ " </tr>\n",
484
+ " </tbody>\n",
485
+ "</table>\n",
486
+ "</div>"
487
+ ],
488
+ "text/plain": [
489
+ " desease_condition \\\n",
490
+ "0 ['marijuana abuse', 'substance-related disorde... \n",
491
+ "1 ['marijuana abuse', 'substance-related disorde... \n",
492
+ "2 ['tuberculosis', 'latent tuberculosis', 'infec... \n",
493
+ "3 ['heart failure', 'heart diseases', 'cardiovas... \n",
494
+ "4 ['lymphoma', 'neoplasms by histologic type', '... \n",
495
+ "\n",
496
+ " text \n",
497
+ "0 nct_id: NCT03055377\\nsummary: This is a 12-wee... \n",
498
+ "1 nct_id: NCT03055377\\nsummary: This is a 12-wee... \n",
499
+ "2 nct_id: NCT03042754\\nsummary: Early diagnosis ... \n",
500
+ "3 nct_id: NCT03035123\\nsummary: The EduStra-HF s... \n",
501
+ "4 nct_id: NCT02272751\\nsummary: This study will ... "
502
+ ]
503
+ },
504
+ "execution_count": 98,
505
+ "metadata": {},
506
+ "output_type": "execute_result"
507
+ }
508
+ ],
509
+ "source": [
510
+ "df_trials = pd.read_csv(\"clinical_trials.csv\")\n",
511
+ "print(df_trials.shape)\n",
512
+ "df_trials.head()"
513
+ ]
514
+ },
515
+ {
516
+ "cell_type": "code",
517
+ "execution_count": null,
518
+ "id": "25b31e55-2961-474d-92d8-5963f2c6bf84",
519
+ "metadata": {},
520
+ "outputs": [],
521
+ "source": []
522
+ },
523
+ {
524
+ "cell_type": "code",
525
+ "execution_count": null,
526
+ "id": "918c078c-46fe-4d7b-9748-88c52a5b004a",
527
+ "metadata": {},
528
+ "outputs": [],
529
+ "source": []
530
+ },
531
+ {
532
+ "cell_type": "code",
533
+ "execution_count": null,
534
+ "id": "36e83202-97ad-425d-95ae-075a1e26a34e",
535
+ "metadata": {},
536
+ "outputs": [],
537
+ "source": []
538
+ },
539
+ {
540
+ "cell_type": "code",
541
+ "execution_count": null,
542
+ "id": "5d705875-5dd7-4c71-8d94-99c101020ac0",
543
+ "metadata": {},
544
+ "outputs": [],
545
+ "source": []
546
+ },
547
+ {
548
+ "cell_type": "code",
549
+ "execution_count": 2,
550
+ "id": "f2818abe-1a43-4d7d-92a7-7562812bf43d",
551
+ "metadata": {},
552
+ "outputs": [
553
+ {
554
+ "data": {
555
+ "text/html": [
556
+ "<div>\n",
557
+ "<style scoped>\n",
558
+ " .dataframe tbody tr th:only-of-type {\n",
559
+ " vertical-align: middle;\n",
560
+ " }\n",
561
+ "\n",
562
+ " .dataframe tbody tr th {\n",
563
+ " vertical-align: top;\n",
564
+ " }\n",
565
+ "\n",
566
+ " .dataframe thead th {\n",
567
+ " text-align: right;\n",
568
+ " }\n",
569
+ "</style>\n",
570
+ "<table border=\"1\" class=\"dataframe\">\n",
571
+ " <thead>\n",
572
+ " <tr style=\"text-align: right;\">\n",
573
+ " <th></th>\n",
574
+ " <th>desease_condition</th>\n",
575
+ " </tr>\n",
576
+ " </thead>\n",
577
+ " <tbody>\n",
578
+ " <tr>\n",
579
+ " <th>0</th>\n",
580
+ " <td>marijuana abuse, substance-related disorders, ...</td>\n",
581
+ " </tr>\n",
582
+ " <tr>\n",
583
+ " <th>1</th>\n",
584
+ " <td>marijuana abuse, substance-related disorders, ...</td>\n",
585
+ " </tr>\n",
586
+ " <tr>\n",
587
+ " <th>2</th>\n",
588
+ " <td>tuberculosis, latent tuberculosis, infections,...</td>\n",
589
+ " </tr>\n",
590
+ " <tr>\n",
591
+ " <th>3</th>\n",
592
+ " <td>heart failure, heart diseases, cardiovascular ...</td>\n",
593
+ " </tr>\n",
594
+ " <tr>\n",
595
+ " <th>4</th>\n",
596
+ " <td>lymphoma, neoplasms by histologic type, neopla...</td>\n",
597
+ " </tr>\n",
598
+ " <tr>\n",
599
+ " <th>...</th>\n",
600
+ " <td>...</td>\n",
601
+ " </tr>\n",
602
+ " <tr>\n",
603
+ " <th>440512</th>\n",
604
+ " <td>obesity, overweight, overnutrition, nutrition ...</td>\n",
605
+ " </tr>\n",
606
+ " <tr>\n",
607
+ " <th>440513</th>\n",
608
+ " <td>obesity, overweight, overnutrition, nutrition ...</td>\n",
609
+ " </tr>\n",
610
+ " <tr>\n",
611
+ " <th>440514</th>\n",
612
+ " <td>obesity, overweight, overnutrition, nutrition ...</td>\n",
613
+ " </tr>\n",
614
+ " <tr>\n",
615
+ " <th>440515</th>\n",
616
+ " <td>autistic disorder, autism spectrum disorder, c...</td>\n",
617
+ " </tr>\n",
618
+ " <tr>\n",
619
+ " <th>440516</th>\n",
620
+ " <td>autistic disorder, autism spectrum disorder, c...</td>\n",
621
+ " </tr>\n",
622
+ " </tbody>\n",
623
+ "</table>\n",
624
+ "<p>440517 rows × 1 columns</p>\n",
625
+ "</div>"
626
+ ],
627
+ "text/plain": [
628
+ " desease_condition\n",
629
+ "0 marijuana abuse, substance-related disorders, ...\n",
630
+ "1 marijuana abuse, substance-related disorders, ...\n",
631
+ "2 tuberculosis, latent tuberculosis, infections,...\n",
632
+ "3 heart failure, heart diseases, cardiovascular ...\n",
633
+ "4 lymphoma, neoplasms by histologic type, neopla...\n",
634
+ "... ...\n",
635
+ "440512 obesity, overweight, overnutrition, nutrition ...\n",
636
+ "440513 obesity, overweight, overnutrition, nutrition ...\n",
637
+ "440514 obesity, overweight, overnutrition, nutrition ...\n",
638
+ "440515 autistic disorder, autism spectrum disorder, c...\n",
639
+ "440516 autistic disorder, autism spectrum disorder, c...\n",
640
+ "\n",
641
+ "[440517 rows x 1 columns]"
642
+ ]
643
+ },
644
+ "execution_count": 2,
645
+ "metadata": {},
646
+ "output_type": "execute_result"
647
+ }
648
+ ],
649
+ "source": [
650
+ "df_trials_filtered = pd.read_csv(\"diseases.csv\")\n",
651
+ "df_trials_filtered"
652
+ ]
653
+ },
654
+ {
655
+ "cell_type": "code",
656
+ "execution_count": 3,
657
+ "id": "c89e3cf6-a376-4029-9c04-0f5e664a2237",
658
+ "metadata": {
659
+ "notebookRunGroups": {
660
+ "groupValue": "1"
661
+ }
662
+ },
663
+ "outputs": [
664
+ {
665
+ "data": {
666
+ "text/plain": [
667
+ "(440517, 1)"
668
+ ]
669
+ },
670
+ "execution_count": 3,
671
+ "metadata": {},
672
+ "output_type": "execute_result"
673
+ }
674
+ ],
675
+ "source": [
676
+ "df2 = df_trials_filtered # [:100]\n",
677
+ "df2.shape"
678
+ ]
679
+ },
680
+ {
681
+ "cell_type": "code",
682
+ "execution_count": 14,
683
+ "id": "c5012bcf-3e25-4f21-a29c-6bdbdafbb8c7",
684
+ "metadata": {},
685
+ "outputs": [],
686
+ "source": [
687
+ "from openai import OpenAI\n",
688
+ "\n",
689
+ "client = OpenAI()"
690
+ ]
691
+ },
692
+ {
693
+ "cell_type": "code",
694
+ "execution_count": 15,
695
+ "id": "40a480bd-6754-40b6-870c-42d10ce9a960",
696
+ "metadata": {},
697
+ "outputs": [],
698
+ "source": [
699
+ "def get_embeddings(text):\n",
700
+ " response = client.embeddings.create(\n",
701
+ " input=text, dimensions=128, model=\"text-embedding-3-small\"\n",
702
+ " )\n",
703
+ " return response.data[0].embedding"
704
+ ]
705
+ },
706
+ {
707
+ "cell_type": "code",
708
+ "execution_count": 4,
709
+ "id": "ef6d6b62-de0b-4bc6-a6eb-847ab8e99da5",
710
+ "metadata": {
711
+ "notebookRunGroups": {
712
+ "groupValue": "1"
713
+ }
714
+ },
715
+ "outputs": [
716
+ {
717
+ "name": "stderr",
718
+ "output_type": "stream",
719
+ "text": [
720
+ "/home/aldan/miniconda3/envs/hackupc/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
721
+ " from .autonotebook import tqdm as notebook_tqdm\n",
722
+ "/home/aldan/miniconda3/envs/hackupc/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
723
+ " warnings.warn(\n",
724
+ "/home/aldan/miniconda3/envs/hackupc/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
725
+ " warnings.warn(\n",
726
+ "Batches: 100%|██████████| 6884/6884 [04:32<00:00, 25.25it/s]\n"
727
+ ]
728
+ },
729
+ {
730
+ "name": "stdout",
731
+ "output_type": "stream",
732
+ "text": [
733
+ "CPU times: user 14min 6s, sys: 1min 31s, total: 15min 37s\n",
734
+ "Wall time: 4min 48s\n"
735
+ ]
736
+ },
737
+ {
738
+ "data": {
739
+ "text/plain": [
740
+ "(440517, 768)"
741
+ ]
742
+ },
743
+ "execution_count": 4,
744
+ "metadata": {},
745
+ "output_type": "execute_result"
746
+ }
747
+ ],
748
+ "source": [
749
+ "%%time\n",
750
+ "from sentence_transformers import SentenceTransformer\n",
751
+ "\n",
752
+ "encoder= SentenceTransformer(\"allenai-specter\")\n",
753
+ "vectors= encoder.encode(df2.desease_condition, show_progress_bar=True, batch_size=64)\n",
754
+ "vectors.shape"
755
+ ]
756
+ },
757
+ {
758
+ "cell_type": "code",
759
+ "execution_count": 17,
760
+ "id": "7966d754-56d7-4555-a6c6-6a13772fb000",
761
+ "metadata": {},
762
+ "outputs": [
763
+ {
764
+ "name": "stdout",
765
+ "output_type": "stream",
766
+ "text": [
767
+ "CPU times: user 261 ms, sys: 13 ms, total: 274 ms\n",
768
+ "Wall time: 26.8 s\n"
769
+ ]
770
+ },
771
+ {
772
+ "name": "stderr",
773
+ "output_type": "stream",
774
+ "text": [
775
+ "<timed exec>:1: SettingWithCopyWarning: \n",
776
+ "A value is trying to be set on a copy of a slice from a DataFrame.\n",
777
+ "Try using .loc[row_indexer,col_indexer] = value instead\n",
778
+ "\n",
779
+ "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n"
780
+ ]
781
+ },
782
+ {
783
+ "data": {
784
+ "text/html": [
785
+ "<div>\n",
786
+ "<style scoped>\n",
787
+ " .dataframe tbody tr th:only-of-type {\n",
788
+ " vertical-align: middle;\n",
789
+ " }\n",
790
+ "\n",
791
+ " .dataframe tbody tr th {\n",
792
+ " vertical-align: top;\n",
793
+ " }\n",
794
+ "\n",
795
+ " .dataframe thead th {\n",
796
+ " text-align: right;\n",
797
+ " }\n",
798
+ "</style>\n",
799
+ "<table border=\"1\" class=\"dataframe\">\n",
800
+ " <thead>\n",
801
+ " <tr style=\"text-align: right;\">\n",
802
+ " <th></th>\n",
803
+ " <th>desease_condition</th>\n",
804
+ " <th>embeddings</th>\n",
805
+ " </tr>\n",
806
+ " </thead>\n",
807
+ " <tbody>\n",
808
+ " <tr>\n",
809
+ " <th>0</th>\n",
810
+ " <td>marijuana abuse, substance-related disorders, ...</td>\n",
811
+ " <td>[-0.05811865255236626, -0.023393018171191216, ...</td>\n",
812
+ " </tr>\n",
813
+ " <tr>\n",
814
+ " <th>1</th>\n",
815
+ " <td>marijuana abuse, substance-related disorders, ...</td>\n",
816
+ " <td>[-0.0581701435148716, -0.023382455110549927, 0...</td>\n",
817
+ " </tr>\n",
818
+ " <tr>\n",
819
+ " <th>2</th>\n",
820
+ " <td>tuberculosis, latent tuberculosis, infections,...</td>\n",
821
+ " <td>[-0.03460180386900902, -0.084668830037117, 0.2...</td>\n",
822
+ " </tr>\n",
823
+ " <tr>\n",
824
+ " <th>3</th>\n",
825
+ " <td>heart failure, heart diseases, cardiovascular ...</td>\n",
826
+ " <td>[-0.08236236125230789, -0.1235777735710144, 0....</td>\n",
827
+ " </tr>\n",
828
+ " <tr>\n",
829
+ " <th>4</th>\n",
830
+ " <td>lymphoma, neoplasms by histologic type, neopla...</td>\n",
831
+ " <td>[-0.1227850392460823, 0.07155642658472061, 0.1...</td>\n",
832
+ " </tr>\n",
833
+ " </tbody>\n",
834
+ "</table>\n",
835
+ "</div>"
836
+ ],
837
+ "text/plain": [
838
+ " desease_condition \\\n",
839
+ "0 marijuana abuse, substance-related disorders, ... \n",
840
+ "1 marijuana abuse, substance-related disorders, ... \n",
841
+ "2 tuberculosis, latent tuberculosis, infections,... \n",
842
+ "3 heart failure, heart diseases, cardiovascular ... \n",
843
+ "4 lymphoma, neoplasms by histologic type, neopla... \n",
844
+ "\n",
845
+ " embeddings \n",
846
+ "0 [-0.05811865255236626, -0.023393018171191216, ... \n",
847
+ "1 [-0.0581701435148716, -0.023382455110549927, 0... \n",
848
+ "2 [-0.03460180386900902, -0.084668830037117, 0.2... \n",
849
+ "3 [-0.08236236125230789, -0.1235777735710144, 0.... \n",
850
+ "4 [-0.1227850392460823, 0.07155642658472061, 0.1... "
851
+ ]
852
+ },
853
+ "execution_count": 17,
854
+ "metadata": {},
855
+ "output_type": "execute_result"
856
+ }
857
+ ],
858
+ "source": [
859
+ "%%time\n",
860
+ "df2['embeddings']= df2['desease_condition'].apply(get_embeddings)\n",
861
+ "df2.head()"
862
+ ]
863
+ },
864
+ {
865
+ "cell_type": "code",
866
+ "execution_count": 18,
867
+ "id": "c2f99031",
868
+ "metadata": {},
869
+ "outputs": [],
870
+ "source": [
871
+ "df2['embeddings'] = vectors.astype('float32',casting='same_kind').tolist()"
872
+ ]
873
+ },
874
+ {
875
+ "cell_type": "code",
876
+ "execution_count": 19,
877
+ "id": "952d69c7",
878
+ "metadata": {},
879
+ "outputs": [],
880
+ "source": [
881
+ "# Remove duplicate rows based on the 'nct_id' column\n",
882
+ "df2_without_duplicates = df2.drop_duplicates(subset='nct_id', keep='first')"
883
+ ]
884
+ },
885
+ {
886
+ "cell_type": "code",
887
+ "execution_count": 24,
888
+ "id": "2711980a-d1c0-441e-ae9a-531500b7b7cd",
889
+ "metadata": {},
890
+ "outputs": [],
891
+ "source": [
892
+ "df2_without_duplicates[:130077].to_csv(\n",
893
+ " \"diseases_embeddings.csv\", index=False, header=True\n",
894
+ ")"
895
+ ]
896
+ },
897
+ {
898
+ "cell_type": "code",
899
+ "execution_count": 21,
900
+ "id": "fccd4f0e",
901
+ "metadata": {},
902
+ "outputs": [
903
+ {
904
+ "data": {
905
+ "text/html": [
906
+ "<div>\n",
907
+ "<style scoped>\n",
908
+ " .dataframe tbody tr th:only-of-type {\n",
909
+ " vertical-align: middle;\n",
910
+ " }\n",
911
+ "\n",
912
+ " .dataframe tbody tr th {\n",
913
+ " vertical-align: top;\n",
914
+ " }\n",
915
+ "\n",
916
+ " .dataframe thead th {\n",
917
+ " text-align: right;\n",
918
+ " }\n",
919
+ "</style>\n",
920
+ "<table border=\"1\" class=\"dataframe\">\n",
921
+ " <thead>\n",
922
+ " <tr style=\"text-align: right;\">\n",
923
+ " <th></th>\n",
924
+ " <th>desease_condition</th>\n",
925
+ " <th>embeddings</th>\n",
926
+ " <th>nct_id</th>\n",
927
+ " </tr>\n",
928
+ " </thead>\n",
929
+ " <tbody>\n",
930
+ " <tr>\n",
931
+ " <th>0</th>\n",
932
+ " <td>marijuana abuse, substance-related disorders, ...</td>\n",
933
+ " <td>[-0.8323991298675537, 1.47855544090271, 0.0013...</td>\n",
934
+ " <td>NCT03055377</td>\n",
935
+ " </tr>\n",
936
+ " <tr>\n",
937
+ " <th>2</th>\n",
938
+ " <td>tuberculosis, latent tuberculosis, infections,...</td>\n",
939
+ " <td>[-0.43443307280540466, 0.9625586271286011, -0....</td>\n",
940
+ " <td>NCT03042754</td>\n",
941
+ " </tr>\n",
942
+ " <tr>\n",
943
+ " <th>3</th>\n",
944
+ " <td>heart failure, heart diseases, cardiovascular ...</td>\n",
945
+ " <td>[-0.5791705250740051, 0.13008448481559753, 0.1...</td>\n",
946
+ " <td>NCT03035123</td>\n",
947
+ " </tr>\n",
948
+ " <tr>\n",
949
+ " <th>4</th>\n",
950
+ " <td>lymphoma, neoplasms by histologic type, neopla...</td>\n",
951
+ " <td>[-0.1608569175004959, 0.8489153981208801, -0.5...</td>\n",
952
+ " <td>NCT02272751</td>\n",
953
+ " </tr>\n",
954
+ " <tr>\n",
955
+ " <th>6</th>\n",
956
+ " <td>anemia, hematologic diseases</td>\n",
957
+ " <td>[0.21379394829273224, 0.17073844373226166, -0....</td>\n",
958
+ " <td>NCT00931606</td>\n",
959
+ " </tr>\n",
960
+ " <tr>\n",
961
+ " <th>...</th>\n",
962
+ " <td>...</td>\n",
963
+ " <td>...</td>\n",
964
+ " <td>...</td>\n",
965
+ " </tr>\n",
966
+ " <tr>\n",
967
+ " <th>440506</th>\n",
968
+ " <td>scoliosis, spinal curvatures, spinal diseases,...</td>\n",
969
+ " <td>[-1.20807683467865, 0.19357842206954956, 0.314...</td>\n",
970
+ " <td>NCT03641469</td>\n",
971
+ " </tr>\n",
972
+ " <tr>\n",
973
+ " <th>440507</th>\n",
974
+ " <td>asphyxia neonatorum, asphyxia, death, patholog...</td>\n",
975
+ " <td>[-0.7226205468177795, 1.0146900415420532, -0.2...</td>\n",
976
+ " <td>NCT03621956</td>\n",
977
+ " </tr>\n",
978
+ " <tr>\n",
979
+ " <th>440509</th>\n",
980
+ " <td>tuberculosis, helminthiasis, malnutrition, myc...</td>\n",
981
+ " <td>[-0.7196142673492432, 0.9588190913200378, 0.08...</td>\n",
982
+ " <td>NCT03598842</td>\n",
983
+ " </tr>\n",
984
+ " <tr>\n",
985
+ " <th>440512</th>\n",
986
+ " <td>obesity, overweight, overnutrition, nutrition ...</td>\n",
987
+ " <td>[-1.159234642982483, 0.5251776576042175, 0.237...</td>\n",
988
+ " <td>NCT03574103</td>\n",
989
+ " </tr>\n",
990
+ " <tr>\n",
991
+ " <th>440515</th>\n",
992
+ " <td>autistic disorder, autism spectrum disorder, c...</td>\n",
993
+ " <td>[-0.8618993759155273, 0.7515497803688049, 0.08...</td>\n",
994
+ " <td>NCT03570372</td>\n",
995
+ " </tr>\n",
996
+ " </tbody>\n",
997
+ "</table>\n",
998
+ "<p>233077 rows × 3 columns</p>\n",
999
+ "</div>"
1000
+ ],
1001
+ "text/plain": [
1002
+ " desease_condition \\\n",
1003
+ "0 marijuana abuse, substance-related disorders, ... \n",
1004
+ "2 tuberculosis, latent tuberculosis, infections,... \n",
1005
+ "3 heart failure, heart diseases, cardiovascular ... \n",
1006
+ "4 lymphoma, neoplasms by histologic type, neopla... \n",
1007
+ "6 anemia, hematologic diseases \n",
1008
+ "... ... \n",
1009
+ "440506 scoliosis, spinal curvatures, spinal diseases,... \n",
1010
+ "440507 asphyxia neonatorum, asphyxia, death, patholog... \n",
1011
+ "440509 tuberculosis, helminthiasis, malnutrition, myc... \n",
1012
+ "440512 obesity, overweight, overnutrition, nutrition ... \n",
1013
+ "440515 autistic disorder, autism spectrum disorder, c... \n",
1014
+ "\n",
1015
+ " embeddings nct_id \n",
1016
+ "0 [-0.8323991298675537, 1.47855544090271, 0.0013... NCT03055377 \n",
1017
+ "2 [-0.43443307280540466, 0.9625586271286011, -0.... NCT03042754 \n",
1018
+ "3 [-0.5791705250740051, 0.13008448481559753, 0.1... NCT03035123 \n",
1019
+ "4 [-0.1608569175004959, 0.8489153981208801, -0.5... NCT02272751 \n",
1020
+ "6 [0.21379394829273224, 0.17073844373226166, -0.... NCT00931606 \n",
1021
+ "... ... ... \n",
1022
+ "440506 [-1.20807683467865, 0.19357842206954956, 0.314... NCT03641469 \n",
1023
+ "440507 [-0.7226205468177795, 1.0146900415420532, -0.2... NCT03621956 \n",
1024
+ "440509 [-0.7196142673492432, 0.9588190913200378, 0.08... NCT03598842 \n",
1025
+ "440512 [-1.159234642982483, 0.5251776576042175, 0.237... NCT03574103 \n",
1026
+ "440515 [-0.8618993759155273, 0.7515497803688049, 0.08... NCT03570372 \n",
1027
+ "\n",
1028
+ "[233077 rows x 3 columns]"
1029
+ ]
1030
+ },
1031
+ "execution_count": 21,
1032
+ "metadata": {},
1033
+ "output_type": "execute_result"
1034
+ }
1035
+ ],
1036
+ "source": [
1037
+ "df2_without_duplicates"
1038
+ ]
1039
+ },
1040
+ {
1041
+ "cell_type": "code",
1042
+ "execution_count": 106,
1043
+ "id": "8ed985f4-9402-431f-bfba-1236ba16b895",
1044
+ "metadata": {},
1045
+ "outputs": [
1046
+ {
1047
+ "data": {
1048
+ "text/plain": [
1049
+ "<faiss.swigfaiss_avx2.IndexFlatL2; proxy of <Swig Object of type 'faiss::IndexFlatL2 *' at 0x0000023493347480> >"
1050
+ ]
1051
+ },
1052
+ "execution_count": 106,
1053
+ "metadata": {},
1054
+ "output_type": "execute_result"
1055
+ }
1056
+ ],
1057
+ "source": [
1058
+ "import faiss\n",
1059
+ "\n",
1060
+ "index = faiss.IndexFlatL2(dim)\n",
1061
+ "index"
1062
+ ]
1063
+ },
1064
+ {
1065
+ "cell_type": "code",
1066
+ "execution_count": 108,
1067
+ "id": "71dad860-1f19-4166-a309-c9ce15f24792",
1068
+ "metadata": {
1069
+ "scrolled": true
1070
+ },
1071
+ "outputs": [
1072
+ {
1073
+ "name": "stdout",
1074
+ "output_type": "stream",
1075
+ "text": [
1076
+ "(768,)\n"
1077
+ ]
1078
+ },
1079
+ {
1080
+ "data": {
1081
+ "text/plain": [
1082
+ "array([-8.82369652e-03, 8.50650743e-02, 2.08267733e-03, 6.77651772e-03,\n",
1083
+ " -2.86661759e-02, -8.71188380e-03, 6.99447095e-02, 5.04214764e-02,\n",
1084
+ " 3.58386151e-02, 5.29594952e-03, -1.40875215e-02, 1.99297220e-02,\n",
1085
+ " 2.27009598e-03, 2.10810862e-02, 2.66138893e-02, 1.90623086e-02,\n",
1086
+ " 4.44708914e-02, 2.96202525e-02, 5.42085357e-02, -2.34859088e-03,\n",
1087
+ " -9.87798795e-02, -5.00183590e-02, -3.42465192e-02, 2.08440255e-02,\n",
1088
+ " 5.31156994e-02, -1.37044629e-02, 2.92537250e-02, -2.61334293e-02,\n",
1089
+ " -1.21854078e-04, -2.36813519e-02, -3.81283499e-02, -1.79494768e-02,\n",
1090
+ " -6.29265187e-03, 1.27150817e-02, 1.19849676e-06, -7.78729608e-03,\n",
1091
+ " -1.28973828e-04, 4.01791967e-02, 4.21229303e-02, -8.72302521e-03,\n",
1092
+ " 7.44823692e-03, 7.68032745e-02, 6.50246907e-03, 3.40298638e-02,\n",
1093
+ " -1.80711355e-02, -2.71878559e-02, 5.74751608e-02, 3.67745496e-02,\n",
1094
+ " -3.34868580e-02, 1.05205458e-02, 2.08975170e-02, 4.36686277e-02,\n",
1095
+ " 3.47612537e-02, -4.99080680e-02, 4.44446988e-02, 5.57280704e-03,\n",
1096
+ " -2.31200755e-02, -4.60692644e-02, -1.39789237e-02, -3.79957110e-02,\n",
1097
+ " 4.67903316e-02, 1.91651955e-02, -5.12171052e-02, 2.46807020e-02,\n",
1098
+ " -5.52081019e-02, 3.50596346e-02, -7.01438356e-03, -3.36519890e-02,\n",
1099
+ " -1.41502097e-02, -1.37693482e-02, 4.11427952e-02, 6.94046309e-03,\n",
1100
+ " -1.30138136e-02, 5.91567121e-02, 3.37168351e-02, 3.01467292e-02,\n",
1101
+ " -4.59552221e-02, 1.37365120e-03, 1.00179566e-02, 6.98126853e-04,\n",
1102
+ " 3.58139984e-02, 1.18174301e-02, 1.33722462e-02, -1.35893077e-02,\n",
1103
+ " 4.75908853e-02, 5.48331346e-03, -6.41460950e-03, -1.23906611e-02,\n",
1104
+ " 5.82688041e-02, -1.60842277e-02, -2.95833423e-04, 6.97355811e-03,\n",
1105
+ " 2.48331465e-02, -2.35959496e-02, -1.24989869e-02, -1.36585534e-02,\n",
1106
+ " 1.52637456e-02, 7.01832073e-03, 5.50601333e-02, 4.35096538e-03,\n",
1107
+ " 2.36319732e-02, -1.38118947e-02, -7.24233836e-02, 9.39742289e-03,\n",
1108
+ " -2.66901590e-02, 2.96042152e-02, 1.28761679e-02, 2.23339219e-02,\n",
1109
+ " 3.08373477e-03, 7.12765753e-02, 7.13613164e-03, 3.62721197e-02,\n",
1110
+ " -4.53250594e-02, 2.54001115e-02, -2.54253373e-02, -1.23151275e-03,\n",
1111
+ " -1.34750446e-02, -2.70653702e-02, -1.02220355e-02, 2.07683407e-02,\n",
1112
+ " -7.31003610e-03, 2.65329964e-02, -2.79857730e-03, 4.20840643e-02,\n",
1113
+ " 3.20205763e-02, -1.19518824e-02, -5.77116087e-02, 9.88688134e-03,\n",
1114
+ " 1.86814573e-02, -5.10204993e-02, -6.77110278e-04, 9.40234493e-03,\n",
1115
+ " -3.33383717e-02, -5.52933291e-02, 5.64148054e-02, 4.92153503e-02,\n",
1116
+ " 3.33690383e-02, -3.92963700e-02, -6.91099390e-02, 3.79911740e-03,\n",
1117
+ " -1.74410697e-02, -1.60171147e-02, 4.89675067e-02, 2.67119659e-03,\n",
1118
+ " 2.61192098e-02, -2.74193864e-02, 6.92490395e-03, -4.64810384e-03,\n",
1119
+ " -8.99862905e-04, 1.02159111e-02, -4.81114909e-02, 1.22787328e-02,\n",
1120
+ " -9.32844076e-03, -2.00431682e-02, -1.36102587e-02, -3.67914373e-03,\n",
1121
+ " -1.60810221e-02, -2.20200215e-02, 2.32051890e-02, -5.07331975e-02,\n",
1122
+ " -1.01248249e-02, 5.62567115e-02, -2.60966737e-03, 9.27545596e-03,\n",
1123
+ " 5.32410555e-02, 4.81746234e-02, -9.83138476e-03, 1.81230865e-02,\n",
1124
+ " -2.12969314e-02, 9.82244611e-02, -2.47648880e-02, 7.06253499e-02,\n",
1125
+ " 8.71159416e-03, -2.73140483e-02, 5.59884915e-03, -2.14829091e-02,\n",
1126
+ " -6.67077005e-02, 2.48677693e-02, -8.29503238e-02, -7.96182230e-02,\n",
1127
+ " 3.77488993e-02, -1.37352264e-02, -2.85069812e-02, 1.81708820e-02,\n",
1128
+ " -4.07746173e-02, -4.71230270e-03, -1.59605164e-02, -1.25815195e-03,\n",
1129
+ " -6.59954594e-03, -1.51611334e-02, 7.87123516e-02, -4.09705602e-02,\n",
1130
+ " 3.07933297e-02, 1.27626080e-02, -4.34489138e-02, 9.91576444e-03,\n",
1131
+ " 1.25470785e-02, -8.67356583e-02, 1.26097840e-03, 3.24709825e-02,\n",
1132
+ " -6.92409948e-02, -4.35011238e-02, -2.79313605e-02, -3.37213017e-02,\n",
1133
+ " -2.35359464e-02, -2.95022167e-02, 2.88009271e-02, -3.26618887e-02,\n",
1134
+ " 7.09307985e-03, 3.09435464e-03, -5.09097055e-02, 3.54242921e-02,\n",
1135
+ " 5.37336655e-02, 1.55867739e-02, 2.09988486e-02, -4.38529663e-02,\n",
1136
+ " 2.93767708e-03, 2.27999203e-02, 1.02668423e-02, 3.35033536e-02,\n",
1137
+ " -8.28316063e-02, -4.17127199e-02, -1.23034064e-02, 2.38543525e-02,\n",
1138
+ " -3.72257493e-02, 2.97443867e-02, 3.35034318e-02, -5.21336049e-02,\n",
1139
+ " 5.74519299e-03, 2.89844945e-02, -2.21337453e-02, 2.34603398e-02,\n",
1140
+ " 6.33142609e-03, -2.24104542e-02, 1.47326495e-02, 1.98041964e-02,\n",
1141
+ " 3.05697713e-02, -9.37094465e-02, -6.84579164e-02, 4.63523576e-03,\n",
1142
+ " 3.88860740e-02, -3.97440195e-02, -4.70216498e-02, 1.02172708e-02,\n",
1143
+ " -3.37972888e-03, -8.54947045e-03, 4.81354557e-02, 4.99849804e-02,\n",
1144
+ " 7.11378129e-03, -2.54327375e-02, -1.14872465e-02, -3.54485810e-02,\n",
1145
+ " 5.24284095e-02, 2.16708388e-02, -4.00698110e-02, 5.15380092e-02,\n",
1146
+ " -6.03203699e-02, -6.50304696e-03, -1.03860423e-02, -7.47132823e-02,\n",
1147
+ " 3.59848235e-03, -4.68364358e-02, -4.23019789e-02, -1.86387468e-02,\n",
1148
+ " -2.88047381e-02, -2.81904116e-02, 1.52729014e-02, -1.55570190e-02,\n",
1149
+ " 1.34619148e-02, 2.34364290e-02, 3.10326237e-02, -4.70464528e-02,\n",
1150
+ " -2.43550166e-02, -7.20657408e-03, -1.16065536e-02, -3.42444591e-02,\n",
1151
+ " -5.30204549e-03, 5.52049950e-02, 4.50828709e-02, -7.30262510e-03,\n",
1152
+ " 5.56289777e-02, -9.46066808e-03, -3.37345451e-02, -1.87659152e-02,\n",
1153
+ " 3.57284099e-02, 4.20488343e-02, 1.66770478e-03, -5.27675785e-02,\n",
1154
+ " 2.96422077e-04, 4.22447585e-02, 4.97253910e-02, 6.03130311e-02,\n",
1155
+ " 1.32281650e-02, 2.35939436e-02, -1.59284715e-02, 4.46444489e-02,\n",
1156
+ " -1.68315917e-02, 1.34740606e-01, -3.54593806e-02, 4.79029641e-02,\n",
1157
+ " 8.99049267e-03, 4.74606343e-02, 6.70041004e-03, -1.15184486e-03,\n",
1158
+ " 2.69540539e-03, -2.77549177e-02, -1.33260442e-02, 2.60788556e-02,\n",
1159
+ " 4.35438640e-02, -2.55859867e-02, 2.76670083e-02, 3.37177999e-02,\n",
1160
+ " 2.93240137e-02, 1.82274636e-03, -1.40310880e-02, -1.91633645e-02,\n",
1161
+ " 1.18790809e-02, -4.65121269e-02, -4.19883654e-02, -2.69681774e-02,\n",
1162
+ " -3.23035605e-02, -6.84630498e-02, 6.26784265e-02, 1.37511576e-02,\n",
1163
+ " -2.55833156e-02, -5.73152229e-02, 3.30126472e-02, -7.90146552e-03,\n",
1164
+ " -1.08651863e-02, 1.10474667e-02, 3.03509296e-03, 1.55274626e-02,\n",
1165
+ " 1.05599947e-02, -7.16960803e-03, -5.01419827e-02, -3.34469602e-02,\n",
1166
+ " 3.77239436e-02, 9.44003314e-02, -4.80610691e-02, 4.73537892e-02,\n",
1167
+ " 3.40655483e-02, 7.88806472e-03, -2.84915343e-02, 7.96849206e-02,\n",
1168
+ " 1.57442074e-02, -4.15650755e-02, 7.51048513e-03, 3.66957486e-02,\n",
1169
+ " -1.72730908e-01, -8.72075930e-02, 2.86346450e-02, 2.16962174e-02,\n",
1170
+ " -4.80199270e-02, 6.49317261e-03, 1.67240556e-02, -2.56227311e-02,\n",
1171
+ " 2.19670162e-02, -6.10647202e-02, -2.65449155e-02, 6.17929082e-03,\n",
1172
+ " -2.89566331e-02, 1.19498251e-02, -2.33849231e-02, -2.69133616e-02,\n",
1173
+ " -1.46602485e-02, 1.18886270e-02, 1.64973717e-02, -3.90495770e-02,\n",
1174
+ " -3.45575088e-03, 5.12249060e-02, -8.63745401e-04, 5.59820198e-02,\n",
1175
+ " 2.10017413e-02, 2.74998210e-02, 3.03551817e-04, -1.15796946e-01,\n",
1176
+ " -4.66962112e-03, -4.80118394e-02, -3.55160870e-02, -4.72528581e-03,\n",
1177
+ " -4.29739058e-02, -1.07347388e-02, -1.32423071e-02, -2.34632343e-02,\n",
1178
+ " 1.98413953e-02, -7.27679394e-03, 2.27117930e-02, -2.59338003e-02,\n",
1179
+ " 4.31442596e-02, 1.07885078e-02, -2.47129947e-02, -4.14506458e-02,\n",
1180
+ " 4.40958813e-02, 6.65106403e-04, -2.26945560e-02, -4.76796739e-02,\n",
1181
+ " 1.13289580e-02, -5.57265691e-02, 1.71151303e-03, -1.24145029e-02,\n",
1182
+ " -3.57853901e-03, -4.86295968e-02, -5.14956787e-02, 4.79425713e-02,\n",
1183
+ " -3.24050151e-02, 7.39779174e-02, 2.67242044e-02, 1.16365692e-02,\n",
1184
+ " 8.20766483e-03, -6.27530292e-02, -1.30661400e-02, -3.52081768e-02,\n",
1185
+ " 4.83807474e-02, 9.81860235e-03, 1.14539362e-01, -1.88471414e-02,\n",
1186
+ " 6.07751869e-02, -1.75345445e-03, 3.13236266e-02, -1.94595556e-03,\n",
1187
+ " 2.64345529e-03, 3.07400171e-02, -4.31060083e-02, -6.19985871e-02,\n",
1188
+ " 5.50477020e-03, 1.62547994e-02, -8.26352183e-03, 7.56437238e-03,\n",
1189
+ " -4.79784003e-03, 6.93615247e-03, 3.59064825e-02, 2.08517518e-02,\n",
1190
+ " 1.41595434e-02, 5.31185642e-02, 6.78585656e-03, 6.56357184e-02,\n",
1191
+ " -5.06135784e-02, -3.05179805e-02, 7.06539825e-02, -3.55644710e-02,\n",
1192
+ " -4.92612133e-03, 9.91953164e-02, 1.00235650e-02, -2.22671125e-02,\n",
1193
+ " -1.86746120e-02, 2.49281265e-02, -4.92450967e-03, 1.66887734e-02,\n",
1194
+ " 4.62210961e-02, 4.07794118e-02, 2.52511259e-02, -2.83305068e-02,\n",
1195
+ " -2.78001893e-02, -1.69764105e-02, 1.79186705e-02, 1.09842177e-02,\n",
1196
+ " 1.09969089e-02, 1.69700030e-02, -8.59475043e-03, 4.70476560e-02,\n",
1197
+ " 3.64770554e-02, 2.09835749e-02, 1.01236468e-02, 2.75151283e-02,\n",
1198
+ " 4.33402918e-02, -4.30559181e-02, -3.53547297e-02, 7.77268112e-02,\n",
1199
+ " -6.10819347e-02, -2.86280159e-02, 4.68054451e-02, 1.29892454e-02,\n",
1200
+ " -1.71940885e-02, -2.52429228e-02, 3.86423096e-02, -1.35919163e-02,\n",
1201
+ " -5.27431667e-02, 6.45831088e-03, 2.96409409e-02, 5.97442053e-02,\n",
1202
+ " 3.23252901e-02, 5.03172688e-02, -4.45654802e-02, 2.90075876e-02,\n",
1203
+ " -1.35373492e-02, 6.78209821e-03, -5.89249916e-02, 4.28890549e-02,\n",
1204
+ " -2.36034058e-02, -5.30969724e-03, 3.85405980e-02, -1.82616734e-03,\n",
1205
+ " 1.45543357e-02, 1.07806427e-02, -6.06855676e-02, -4.95252907e-02,\n",
1206
+ " 1.02004781e-02, 4.60227691e-02, -1.08090881e-02, 4.42408510e-02,\n",
1207
+ " 4.15152796e-02, 1.23609398e-02, 5.11957100e-03, 1.17597533e-02,\n",
1208
+ " -2.70090066e-02, 2.68773828e-02, -1.97812133e-02, 2.25932393e-02,\n",
1209
+ " -1.33560598e-02, -1.50896851e-02, -3.14053567e-03, 1.54051669e-02,\n",
1210
+ " 1.86488125e-02, -1.71708278e-02, -3.95283476e-03, 7.68053811e-04,\n",
1211
+ " -2.37891261e-04, 1.84722953e-02, 3.60381305e-02, -5.85213909e-03,\n",
1212
+ " 4.44293395e-02, -1.11264118e-03, -4.79441285e-02, 3.46464328e-02,\n",
1213
+ " -2.53370814e-02, -3.26901935e-02, -2.28975322e-02, -1.96164921e-02,\n",
1214
+ " -4.38152434e-04, 4.08602282e-02, -2.29470823e-02, -1.89938806e-02,\n",
1215
+ " -1.52037974e-04, 1.05516789e-02, 2.08601039e-02, -6.98119551e-02,\n",
1216
+ " 3.66246551e-02, -1.26779894e-03, -4.03217562e-02, -5.35424761e-02,\n",
1217
+ " 6.51817098e-02, 4.29646857e-02, 2.56071109e-02, -3.28080021e-02,\n",
1218
+ " 1.20534413e-02, 3.56224040e-03, -1.01593453e-02, -1.96505673e-04,\n",
1219
+ " 4.33485657e-02, -4.25680764e-02, 9.73126665e-03, 3.76882474e-03,\n",
1220
+ " -1.40319867e-02, -3.63940969e-02, -3.09983976e-02, -4.19548260e-33,\n",
1221
+ " 7.11604580e-02, 4.78382297e-02, 1.89297704e-03, -1.60731785e-02,\n",
1222
+ " 2.53787991e-02, -3.15741785e-02, -4.27713171e-02, -7.53164338e-03,\n",
1223
+ " 1.68679946e-03, 1.92391127e-02, -2.20667192e-04, 1.32907527e-02,\n",
1224
+ " 5.99487219e-03, 2.75156219e-02, -5.06000873e-03, -3.58465910e-02,\n",
1225
+ " 8.20948277e-03, -2.11624149e-02, -7.07996823e-03, -4.23992332e-03,\n",
1226
+ " -1.09853260e-01, -3.66037302e-02, 3.55480015e-02, 4.23291475e-02,\n",
1227
+ " 1.48312682e-02, 5.68749309e-02, 3.57767567e-02, 1.40728084e-02,\n",
1228
+ " -4.00471613e-02, 1.01988176e-02, 2.83056553e-02, -1.55737845e-03,\n",
1229
+ " 1.24238459e-02, 1.20237898e-02, -7.69484974e-03, -3.30727436e-02,\n",
1230
+ " -1.45808076e-02, 3.43246050e-02, 3.21143419e-02, -4.96741422e-02,\n",
1231
+ " -5.27968369e-02, 2.51889303e-02, -1.11904610e-02, 5.64832352e-02,\n",
1232
+ " 2.77636852e-02, 5.90689071e-02, -2.61273161e-02, -6.95008039e-02,\n",
1233
+ " -3.15576978e-02, -5.62214339e-03, -7.93884136e-03, -3.62196900e-02,\n",
1234
+ " -8.26047733e-03, 8.05249214e-02, -4.16241921e-02, -2.01846119e-02,\n",
1235
+ " -2.52235290e-02, -3.88054736e-02, -2.00710595e-02, 1.50789914e-03,\n",
1236
+ " -5.51338419e-02, -8.35673045e-03, -1.61523875e-02, -8.79513845e-02,\n",
1237
+ " -5.28004877e-02, -2.88654189e-03, -1.11697149e-02, 7.10910782e-02,\n",
1238
+ " 4.44932319e-02, 8.69598426e-03, -1.14432694e-02, 4.47212979e-02,\n",
1239
+ " 2.70624813e-02, -3.86100151e-02, -3.07358261e-02, 2.75634117e-02,\n",
1240
+ " 1.48464069e-02, -1.00845508e-02, 6.45884350e-02, 4.28387662e-03,\n",
1241
+ " 8.05836394e-02, -1.69498641e-02, 4.44465503e-02, -2.09145956e-02,\n",
1242
+ " -3.37407738e-02, 3.85780074e-02, -7.44559616e-02, 1.17512364e-02,\n",
1243
+ " 1.01964204e-02, -3.02421930e-03, 4.80608828e-02, -1.49494391e-02,\n",
1244
+ " 2.54592765e-02, -1.46158040e-02, 5.46646416e-02, 1.43051194e-03,\n",
1245
+ " 2.99116820e-02, 2.24273186e-02, -5.79927117e-03, -1.33864526e-02,\n",
1246
+ " -2.52460372e-02, -2.69225910e-02, 1.64003875e-02, 1.20901112e-02,\n",
1247
+ " 3.38429734e-02, -2.11539529e-02, 7.17787817e-02, -7.78904185e-02,\n",
1248
+ " -4.04084288e-02, 4.90567498e-02, -2.61603445e-02, 1.97753590e-02,\n",
1249
+ " 4.97209951e-02, -4.88655381e-02, -4.52128090e-02, 3.63065898e-02,\n",
1250
+ " 2.68440694e-02, 3.29160057e-02, -8.24410375e-03, -1.33646047e-02,\n",
1251
+ " -6.22822754e-02, -1.13362661e-02, -3.79339382e-02, -6.56360280e-05,\n",
1252
+ " -1.08087100e-02, 2.67575700e-02, 1.33866509e-02, 5.89998253e-02,\n",
1253
+ " -2.54666172e-02, -3.05371322e-02, -1.53249800e-02, -9.87035502e-03,\n",
1254
+ " 1.95337094e-07, -1.76476724e-02, 5.71432859e-02, -2.49180794e-02,\n",
1255
+ " 5.85253723e-02, 4.49808314e-02, -5.99673577e-02, -9.97425616e-03,\n",
1256
+ " 4.07801419e-02, 4.13940698e-02, 2.55707726e-02, 2.18985360e-02,\n",
1257
+ " -3.04434425e-03, -3.77355106e-02, -6.24866784e-02, -1.17468778e-02,\n",
1258
+ " -4.82194684e-02, -7.78659210e-02, -1.48841189e-02, -1.75396129e-02,\n",
1259
+ " -2.48471629e-02, 8.05181568e-04, -4.85844910e-03, -5.16015477e-03,\n",
1260
+ " 7.53483502e-03, -9.46175400e-03, -2.39896346e-02, -3.14654633e-02,\n",
1261
+ " 1.50111094e-02, -1.22348899e-02, 3.00448518e-02, 3.55701670e-02,\n",
1262
+ " 3.08971256e-02, 1.72299352e-02, 5.93419448e-02, -5.74274361e-02,\n",
1263
+ " -8.16087723e-02, -4.80572283e-02, -2.68838424e-02, -1.96331330e-02,\n",
1264
+ " -9.15831141e-03, 1.07509056e-02, 2.35639680e-02, -2.62569580e-02,\n",
1265
+ " 9.21937004e-02, 1.37132118e-02, -1.19096776e-02, -4.09874134e-02,\n",
1266
+ " 3.37628126e-02, -4.64820908e-03, -2.50304434e-02, 6.25852346e-02,\n",
1267
+ " -1.24449311e-02, 3.82654071e-02, -2.35330854e-02, 8.68125912e-03,\n",
1268
+ " 5.08641489e-02, 2.53822445e-03, 5.25634140e-02, 1.14882430e-02,\n",
1269
+ " 5.01894541e-02, -3.55215147e-02, -3.31749097e-02, -3.02003417e-03,\n",
1270
+ " -5.36288768e-02, -2.80938316e-02, -7.51279444e-02, -4.71623316e-02,\n",
1271
+ " 9.56887701e-35, 2.55127084e-02, -1.44770980e-04, 1.96710341e-02,\n",
1272
+ " -1.33620016e-02, -1.51910949e-02, -3.28495577e-02, -1.52465852e-03,\n",
1273
+ " -2.65272055e-02, -4.35708016e-02, -1.75950192e-02, -2.20594816e-02],\n",
1274
+ " dtype=float32)"
1275
+ ]
1276
+ },
1277
+ "execution_count": 108,
1278
+ "metadata": {},
1279
+ "output_type": "execute_result"
1280
+ }
1281
+ ],
1282
+ "source": [
1283
+ "search_query = \"clinical trials related to alzheimers\"\n",
1284
+ "vec = encoder.encode(search_query)\n",
1285
+ "print(vec.shape)\n",
1286
+ "vec"
1287
+ ]
1288
+ },
1289
+ {
1290
+ "cell_type": "code",
1291
+ "execution_count": 109,
1292
+ "id": "613fd415-4194-45e6-b9f3-9a7707845ad5",
1293
+ "metadata": {
1294
+ "scrolled": true
1295
+ },
1296
+ "outputs": [
1297
+ {
1298
+ "name": "stdout",
1299
+ "output_type": "stream",
1300
+ "text": [
1301
+ "(1, 768)\n"
1302
+ ]
1303
+ },
1304
+ {
1305
+ "data": {
1306
+ "text/plain": [
1307
+ "array([[-8.82369652e-03, 8.50650743e-02, 2.08267733e-03,\n",
1308
+ " 6.77651772e-03, -2.86661759e-02, -8.71188380e-03,\n",
1309
+ " 6.99447095e-02, 5.04214764e-02, 3.58386151e-02,\n",
1310
+ " 5.29594952e-03, -1.40875215e-02, 1.99297220e-02,\n",
1311
+ " 2.27009598e-03, 2.10810862e-02, 2.66138893e-02,\n",
1312
+ " 1.90623086e-02, 4.44708914e-02, 2.96202525e-02,\n",
1313
+ " 5.42085357e-02, -2.34859088e-03, -9.87798795e-02,\n",
1314
+ " -5.00183590e-02, -3.42465192e-02, 2.08440255e-02,\n",
1315
+ " 5.31156994e-02, -1.37044629e-02, 2.92537250e-02,\n",
1316
+ " -2.61334293e-02, -1.21854078e-04, -2.36813519e-02,\n",
1317
+ " -3.81283499e-02, -1.79494768e-02, -6.29265187e-03,\n",
1318
+ " 1.27150817e-02, 1.19849676e-06, -7.78729608e-03,\n",
1319
+ " -1.28973828e-04, 4.01791967e-02, 4.21229303e-02,\n",
1320
+ " -8.72302521e-03, 7.44823692e-03, 7.68032745e-02,\n",
1321
+ " 6.50246907e-03, 3.40298638e-02, -1.80711355e-02,\n",
1322
+ " -2.71878559e-02, 5.74751608e-02, 3.67745496e-02,\n",
1323
+ " -3.34868580e-02, 1.05205458e-02, 2.08975170e-02,\n",
1324
+ " 4.36686277e-02, 3.47612537e-02, -4.99080680e-02,\n",
1325
+ " 4.44446988e-02, 5.57280704e-03, -2.31200755e-02,\n",
1326
+ " -4.60692644e-02, -1.39789237e-02, -3.79957110e-02,\n",
1327
+ " 4.67903316e-02, 1.91651955e-02, -5.12171052e-02,\n",
1328
+ " 2.46807020e-02, -5.52081019e-02, 3.50596346e-02,\n",
1329
+ " -7.01438356e-03, -3.36519890e-02, -1.41502097e-02,\n",
1330
+ " -1.37693482e-02, 4.11427952e-02, 6.94046309e-03,\n",
1331
+ " -1.30138136e-02, 5.91567121e-02, 3.37168351e-02,\n",
1332
+ " 3.01467292e-02, -4.59552221e-02, 1.37365120e-03,\n",
1333
+ " 1.00179566e-02, 6.98126853e-04, 3.58139984e-02,\n",
1334
+ " 1.18174301e-02, 1.33722462e-02, -1.35893077e-02,\n",
1335
+ " 4.75908853e-02, 5.48331346e-03, -6.41460950e-03,\n",
1336
+ " -1.23906611e-02, 5.82688041e-02, -1.60842277e-02,\n",
1337
+ " -2.95833423e-04, 6.97355811e-03, 2.48331465e-02,\n",
1338
+ " -2.35959496e-02, -1.24989869e-02, -1.36585534e-02,\n",
1339
+ " 1.52637456e-02, 7.01832073e-03, 5.50601333e-02,\n",
1340
+ " 4.35096538e-03, 2.36319732e-02, -1.38118947e-02,\n",
1341
+ " -7.24233836e-02, 9.39742289e-03, -2.66901590e-02,\n",
1342
+ " 2.96042152e-02, 1.28761679e-02, 2.23339219e-02,\n",
1343
+ " 3.08373477e-03, 7.12765753e-02, 7.13613164e-03,\n",
1344
+ " 3.62721197e-02, -4.53250594e-02, 2.54001115e-02,\n",
1345
+ " -2.54253373e-02, -1.23151275e-03, -1.34750446e-02,\n",
1346
+ " -2.70653702e-02, -1.02220355e-02, 2.07683407e-02,\n",
1347
+ " -7.31003610e-03, 2.65329964e-02, -2.79857730e-03,\n",
1348
+ " 4.20840643e-02, 3.20205763e-02, -1.19518824e-02,\n",
1349
+ " -5.77116087e-02, 9.88688134e-03, 1.86814573e-02,\n",
1350
+ " -5.10204993e-02, -6.77110278e-04, 9.40234493e-03,\n",
1351
+ " -3.33383717e-02, -5.52933291e-02, 5.64148054e-02,\n",
1352
+ " 4.92153503e-02, 3.33690383e-02, -3.92963700e-02,\n",
1353
+ " -6.91099390e-02, 3.79911740e-03, -1.74410697e-02,\n",
1354
+ " -1.60171147e-02, 4.89675067e-02, 2.67119659e-03,\n",
1355
+ " 2.61192098e-02, -2.74193864e-02, 6.92490395e-03,\n",
1356
+ " -4.64810384e-03, -8.99862905e-04, 1.02159111e-02,\n",
1357
+ " -4.81114909e-02, 1.22787328e-02, -9.32844076e-03,\n",
1358
+ " -2.00431682e-02, -1.36102587e-02, -3.67914373e-03,\n",
1359
+ " -1.60810221e-02, -2.20200215e-02, 2.32051890e-02,\n",
1360
+ " -5.07331975e-02, -1.01248249e-02, 5.62567115e-02,\n",
1361
+ " -2.60966737e-03, 9.27545596e-03, 5.32410555e-02,\n",
1362
+ " 4.81746234e-02, -9.83138476e-03, 1.81230865e-02,\n",
1363
+ " -2.12969314e-02, 9.82244611e-02, -2.47648880e-02,\n",
1364
+ " 7.06253499e-02, 8.71159416e-03, -2.73140483e-02,\n",
1365
+ " 5.59884915e-03, -2.14829091e-02, -6.67077005e-02,\n",
1366
+ " 2.48677693e-02, -8.29503238e-02, -7.96182230e-02,\n",
1367
+ " 3.77488993e-02, -1.37352264e-02, -2.85069812e-02,\n",
1368
+ " 1.81708820e-02, -4.07746173e-02, -4.71230270e-03,\n",
1369
+ " -1.59605164e-02, -1.25815195e-03, -6.59954594e-03,\n",
1370
+ " -1.51611334e-02, 7.87123516e-02, -4.09705602e-02,\n",
1371
+ " 3.07933297e-02, 1.27626080e-02, -4.34489138e-02,\n",
1372
+ " 9.91576444e-03, 1.25470785e-02, -8.67356583e-02,\n",
1373
+ " 1.26097840e-03, 3.24709825e-02, -6.92409948e-02,\n",
1374
+ " -4.35011238e-02, -2.79313605e-02, -3.37213017e-02,\n",
1375
+ " -2.35359464e-02, -2.95022167e-02, 2.88009271e-02,\n",
1376
+ " -3.26618887e-02, 7.09307985e-03, 3.09435464e-03,\n",
1377
+ " -5.09097055e-02, 3.54242921e-02, 5.37336655e-02,\n",
1378
+ " 1.55867739e-02, 2.09988486e-02, -4.38529663e-02,\n",
1379
+ " 2.93767708e-03, 2.27999203e-02, 1.02668423e-02,\n",
1380
+ " 3.35033536e-02, -8.28316063e-02, -4.17127199e-02,\n",
1381
+ " -1.23034064e-02, 2.38543525e-02, -3.72257493e-02,\n",
1382
+ " 2.97443867e-02, 3.35034318e-02, -5.21336049e-02,\n",
1383
+ " 5.74519299e-03, 2.89844945e-02, -2.21337453e-02,\n",
1384
+ " 2.34603398e-02, 6.33142609e-03, -2.24104542e-02,\n",
1385
+ " 1.47326495e-02, 1.98041964e-02, 3.05697713e-02,\n",
1386
+ " -9.37094465e-02, -6.84579164e-02, 4.63523576e-03,\n",
1387
+ " 3.88860740e-02, -3.97440195e-02, -4.70216498e-02,\n",
1388
+ " 1.02172708e-02, -3.37972888e-03, -8.54947045e-03,\n",
1389
+ " 4.81354557e-02, 4.99849804e-02, 7.11378129e-03,\n",
1390
+ " -2.54327375e-02, -1.14872465e-02, -3.54485810e-02,\n",
1391
+ " 5.24284095e-02, 2.16708388e-02, -4.00698110e-02,\n",
1392
+ " 5.15380092e-02, -6.03203699e-02, -6.50304696e-03,\n",
1393
+ " -1.03860423e-02, -7.47132823e-02, 3.59848235e-03,\n",
1394
+ " -4.68364358e-02, -4.23019789e-02, -1.86387468e-02,\n",
1395
+ " -2.88047381e-02, -2.81904116e-02, 1.52729014e-02,\n",
1396
+ " -1.55570190e-02, 1.34619148e-02, 2.34364290e-02,\n",
1397
+ " 3.10326237e-02, -4.70464528e-02, -2.43550166e-02,\n",
1398
+ " -7.20657408e-03, -1.16065536e-02, -3.42444591e-02,\n",
1399
+ " -5.30204549e-03, 5.52049950e-02, 4.50828709e-02,\n",
1400
+ " -7.30262510e-03, 5.56289777e-02, -9.46066808e-03,\n",
1401
+ " -3.37345451e-02, -1.87659152e-02, 3.57284099e-02,\n",
1402
+ " 4.20488343e-02, 1.66770478e-03, -5.27675785e-02,\n",
1403
+ " 2.96422077e-04, 4.22447585e-02, 4.97253910e-02,\n",
1404
+ " 6.03130311e-02, 1.32281650e-02, 2.35939436e-02,\n",
1405
+ " -1.59284715e-02, 4.46444489e-02, -1.68315917e-02,\n",
1406
+ " 1.34740606e-01, -3.54593806e-02, 4.79029641e-02,\n",
1407
+ " 8.99049267e-03, 4.74606343e-02, 6.70041004e-03,\n",
1408
+ " -1.15184486e-03, 2.69540539e-03, -2.77549177e-02,\n",
1409
+ " -1.33260442e-02, 2.60788556e-02, 4.35438640e-02,\n",
1410
+ " -2.55859867e-02, 2.76670083e-02, 3.37177999e-02,\n",
1411
+ " 2.93240137e-02, 1.82274636e-03, -1.40310880e-02,\n",
1412
+ " -1.91633645e-02, 1.18790809e-02, -4.65121269e-02,\n",
1413
+ " -4.19883654e-02, -2.69681774e-02, -3.23035605e-02,\n",
1414
+ " -6.84630498e-02, 6.26784265e-02, 1.37511576e-02,\n",
1415
+ " -2.55833156e-02, -5.73152229e-02, 3.30126472e-02,\n",
1416
+ " -7.90146552e-03, -1.08651863e-02, 1.10474667e-02,\n",
1417
+ " 3.03509296e-03, 1.55274626e-02, 1.05599947e-02,\n",
1418
+ " -7.16960803e-03, -5.01419827e-02, -3.34469602e-02,\n",
1419
+ " 3.77239436e-02, 9.44003314e-02, -4.80610691e-02,\n",
1420
+ " 4.73537892e-02, 3.40655483e-02, 7.88806472e-03,\n",
1421
+ " -2.84915343e-02, 7.96849206e-02, 1.57442074e-02,\n",
1422
+ " -4.15650755e-02, 7.51048513e-03, 3.66957486e-02,\n",
1423
+ " -1.72730908e-01, -8.72075930e-02, 2.86346450e-02,\n",
1424
+ " 2.16962174e-02, -4.80199270e-02, 6.49317261e-03,\n",
1425
+ " 1.67240556e-02, -2.56227311e-02, 2.19670162e-02,\n",
1426
+ " -6.10647202e-02, -2.65449155e-02, 6.17929082e-03,\n",
1427
+ " -2.89566331e-02, 1.19498251e-02, -2.33849231e-02,\n",
1428
+ " -2.69133616e-02, -1.46602485e-02, 1.18886270e-02,\n",
1429
+ " 1.64973717e-02, -3.90495770e-02, -3.45575088e-03,\n",
1430
+ " 5.12249060e-02, -8.63745401e-04, 5.59820198e-02,\n",
1431
+ " 2.10017413e-02, 2.74998210e-02, 3.03551817e-04,\n",
1432
+ " -1.15796946e-01, -4.66962112e-03, -4.80118394e-02,\n",
1433
+ " -3.55160870e-02, -4.72528581e-03, -4.29739058e-02,\n",
1434
+ " -1.07347388e-02, -1.32423071e-02, -2.34632343e-02,\n",
1435
+ " 1.98413953e-02, -7.27679394e-03, 2.27117930e-02,\n",
1436
+ " -2.59338003e-02, 4.31442596e-02, 1.07885078e-02,\n",
1437
+ " -2.47129947e-02, -4.14506458e-02, 4.40958813e-02,\n",
1438
+ " 6.65106403e-04, -2.26945560e-02, -4.76796739e-02,\n",
1439
+ " 1.13289580e-02, -5.57265691e-02, 1.71151303e-03,\n",
1440
+ " -1.24145029e-02, -3.57853901e-03, -4.86295968e-02,\n",
1441
+ " -5.14956787e-02, 4.79425713e-02, -3.24050151e-02,\n",
1442
+ " 7.39779174e-02, 2.67242044e-02, 1.16365692e-02,\n",
1443
+ " 8.20766483e-03, -6.27530292e-02, -1.30661400e-02,\n",
1444
+ " -3.52081768e-02, 4.83807474e-02, 9.81860235e-03,\n",
1445
+ " 1.14539362e-01, -1.88471414e-02, 6.07751869e-02,\n",
1446
+ " -1.75345445e-03, 3.13236266e-02, -1.94595556e-03,\n",
1447
+ " 2.64345529e-03, 3.07400171e-02, -4.31060083e-02,\n",
1448
+ " -6.19985871e-02, 5.50477020e-03, 1.62547994e-02,\n",
1449
+ " -8.26352183e-03, 7.56437238e-03, -4.79784003e-03,\n",
1450
+ " 6.93615247e-03, 3.59064825e-02, 2.08517518e-02,\n",
1451
+ " 1.41595434e-02, 5.31185642e-02, 6.78585656e-03,\n",
1452
+ " 6.56357184e-02, -5.06135784e-02, -3.05179805e-02,\n",
1453
+ " 7.06539825e-02, -3.55644710e-02, -4.92612133e-03,\n",
1454
+ " 9.91953164e-02, 1.00235650e-02, -2.22671125e-02,\n",
1455
+ " -1.86746120e-02, 2.49281265e-02, -4.92450967e-03,\n",
1456
+ " 1.66887734e-02, 4.62210961e-02, 4.07794118e-02,\n",
1457
+ " 2.52511259e-02, -2.83305068e-02, -2.78001893e-02,\n",
1458
+ " -1.69764105e-02, 1.79186705e-02, 1.09842177e-02,\n",
1459
+ " 1.09969089e-02, 1.69700030e-02, -8.59475043e-03,\n",
1460
+ " 4.70476560e-02, 3.64770554e-02, 2.09835749e-02,\n",
1461
+ " 1.01236468e-02, 2.75151283e-02, 4.33402918e-02,\n",
1462
+ " -4.30559181e-02, -3.53547297e-02, 7.77268112e-02,\n",
1463
+ " -6.10819347e-02, -2.86280159e-02, 4.68054451e-02,\n",
1464
+ " 1.29892454e-02, -1.71940885e-02, -2.52429228e-02,\n",
1465
+ " 3.86423096e-02, -1.35919163e-02, -5.27431667e-02,\n",
1466
+ " 6.45831088e-03, 2.96409409e-02, 5.97442053e-02,\n",
1467
+ " 3.23252901e-02, 5.03172688e-02, -4.45654802e-02,\n",
1468
+ " 2.90075876e-02, -1.35373492e-02, 6.78209821e-03,\n",
1469
+ " -5.89249916e-02, 4.28890549e-02, -2.36034058e-02,\n",
1470
+ " -5.30969724e-03, 3.85405980e-02, -1.82616734e-03,\n",
1471
+ " 1.45543357e-02, 1.07806427e-02, -6.06855676e-02,\n",
1472
+ " -4.95252907e-02, 1.02004781e-02, 4.60227691e-02,\n",
1473
+ " -1.08090881e-02, 4.42408510e-02, 4.15152796e-02,\n",
1474
+ " 1.23609398e-02, 5.11957100e-03, 1.17597533e-02,\n",
1475
+ " -2.70090066e-02, 2.68773828e-02, -1.97812133e-02,\n",
1476
+ " 2.25932393e-02, -1.33560598e-02, -1.50896851e-02,\n",
1477
+ " -3.14053567e-03, 1.54051669e-02, 1.86488125e-02,\n",
1478
+ " -1.71708278e-02, -3.95283476e-03, 7.68053811e-04,\n",
1479
+ " -2.37891261e-04, 1.84722953e-02, 3.60381305e-02,\n",
1480
+ " -5.85213909e-03, 4.44293395e-02, -1.11264118e-03,\n",
1481
+ " -4.79441285e-02, 3.46464328e-02, -2.53370814e-02,\n",
1482
+ " -3.26901935e-02, -2.28975322e-02, -1.96164921e-02,\n",
1483
+ " -4.38152434e-04, 4.08602282e-02, -2.29470823e-02,\n",
1484
+ " -1.89938806e-02, -1.52037974e-04, 1.05516789e-02,\n",
1485
+ " 2.08601039e-02, -6.98119551e-02, 3.66246551e-02,\n",
1486
+ " -1.26779894e-03, -4.03217562e-02, -5.35424761e-02,\n",
1487
+ " 6.51817098e-02, 4.29646857e-02, 2.56071109e-02,\n",
1488
+ " -3.28080021e-02, 1.20534413e-02, 3.56224040e-03,\n",
1489
+ " -1.01593453e-02, -1.96505673e-04, 4.33485657e-02,\n",
1490
+ " -4.25680764e-02, 9.73126665e-03, 3.76882474e-03,\n",
1491
+ " -1.40319867e-02, -3.63940969e-02, -3.09983976e-02,\n",
1492
+ " -4.19548260e-33, 7.11604580e-02, 4.78382297e-02,\n",
1493
+ " 1.89297704e-03, -1.60731785e-02, 2.53787991e-02,\n",
1494
+ " -3.15741785e-02, -4.27713171e-02, -7.53164338e-03,\n",
1495
+ " 1.68679946e-03, 1.92391127e-02, -2.20667192e-04,\n",
1496
+ " 1.32907527e-02, 5.99487219e-03, 2.75156219e-02,\n",
1497
+ " -5.06000873e-03, -3.58465910e-02, 8.20948277e-03,\n",
1498
+ " -2.11624149e-02, -7.07996823e-03, -4.23992332e-03,\n",
1499
+ " -1.09853260e-01, -3.66037302e-02, 3.55480015e-02,\n",
1500
+ " 4.23291475e-02, 1.48312682e-02, 5.68749309e-02,\n",
1501
+ " 3.57767567e-02, 1.40728084e-02, -4.00471613e-02,\n",
1502
+ " 1.01988176e-02, 2.83056553e-02, -1.55737845e-03,\n",
1503
+ " 1.24238459e-02, 1.20237898e-02, -7.69484974e-03,\n",
1504
+ " -3.30727436e-02, -1.45808076e-02, 3.43246050e-02,\n",
1505
+ " 3.21143419e-02, -4.96741422e-02, -5.27968369e-02,\n",
1506
+ " 2.51889303e-02, -1.11904610e-02, 5.64832352e-02,\n",
1507
+ " 2.77636852e-02, 5.90689071e-02, -2.61273161e-02,\n",
1508
+ " -6.95008039e-02, -3.15576978e-02, -5.62214339e-03,\n",
1509
+ " -7.93884136e-03, -3.62196900e-02, -8.26047733e-03,\n",
1510
+ " 8.05249214e-02, -4.16241921e-02, -2.01846119e-02,\n",
1511
+ " -2.52235290e-02, -3.88054736e-02, -2.00710595e-02,\n",
1512
+ " 1.50789914e-03, -5.51338419e-02, -8.35673045e-03,\n",
1513
+ " -1.61523875e-02, -8.79513845e-02, -5.28004877e-02,\n",
1514
+ " -2.88654189e-03, -1.11697149e-02, 7.10910782e-02,\n",
1515
+ " 4.44932319e-02, 8.69598426e-03, -1.14432694e-02,\n",
1516
+ " 4.47212979e-02, 2.70624813e-02, -3.86100151e-02,\n",
1517
+ " -3.07358261e-02, 2.75634117e-02, 1.48464069e-02,\n",
1518
+ " -1.00845508e-02, 6.45884350e-02, 4.28387662e-03,\n",
1519
+ " 8.05836394e-02, -1.69498641e-02, 4.44465503e-02,\n",
1520
+ " -2.09145956e-02, -3.37407738e-02, 3.85780074e-02,\n",
1521
+ " -7.44559616e-02, 1.17512364e-02, 1.01964204e-02,\n",
1522
+ " -3.02421930e-03, 4.80608828e-02, -1.49494391e-02,\n",
1523
+ " 2.54592765e-02, -1.46158040e-02, 5.46646416e-02,\n",
1524
+ " 1.43051194e-03, 2.99116820e-02, 2.24273186e-02,\n",
1525
+ " -5.79927117e-03, -1.33864526e-02, -2.52460372e-02,\n",
1526
+ " -2.69225910e-02, 1.64003875e-02, 1.20901112e-02,\n",
1527
+ " 3.38429734e-02, -2.11539529e-02, 7.17787817e-02,\n",
1528
+ " -7.78904185e-02, -4.04084288e-02, 4.90567498e-02,\n",
1529
+ " -2.61603445e-02, 1.97753590e-02, 4.97209951e-02,\n",
1530
+ " -4.88655381e-02, -4.52128090e-02, 3.63065898e-02,\n",
1531
+ " 2.68440694e-02, 3.29160057e-02, -8.24410375e-03,\n",
1532
+ " -1.33646047e-02, -6.22822754e-02, -1.13362661e-02,\n",
1533
+ " -3.79339382e-02, -6.56360280e-05, -1.08087100e-02,\n",
1534
+ " 2.67575700e-02, 1.33866509e-02, 5.89998253e-02,\n",
1535
+ " -2.54666172e-02, -3.05371322e-02, -1.53249800e-02,\n",
1536
+ " -9.87035502e-03, 1.95337094e-07, -1.76476724e-02,\n",
1537
+ " 5.71432859e-02, -2.49180794e-02, 5.85253723e-02,\n",
1538
+ " 4.49808314e-02, -5.99673577e-02, -9.97425616e-03,\n",
1539
+ " 4.07801419e-02, 4.13940698e-02, 2.55707726e-02,\n",
1540
+ " 2.18985360e-02, -3.04434425e-03, -3.77355106e-02,\n",
1541
+ " -6.24866784e-02, -1.17468778e-02, -4.82194684e-02,\n",
1542
+ " -7.78659210e-02, -1.48841189e-02, -1.75396129e-02,\n",
1543
+ " -2.48471629e-02, 8.05181568e-04, -4.85844910e-03,\n",
1544
+ " -5.16015477e-03, 7.53483502e-03, -9.46175400e-03,\n",
1545
+ " -2.39896346e-02, -3.14654633e-02, 1.50111094e-02,\n",
1546
+ " -1.22348899e-02, 3.00448518e-02, 3.55701670e-02,\n",
1547
+ " 3.08971256e-02, 1.72299352e-02, 5.93419448e-02,\n",
1548
+ " -5.74274361e-02, -8.16087723e-02, -4.80572283e-02,\n",
1549
+ " -2.68838424e-02, -1.96331330e-02, -9.15831141e-03,\n",
1550
+ " 1.07509056e-02, 2.35639680e-02, -2.62569580e-02,\n",
1551
+ " 9.21937004e-02, 1.37132118e-02, -1.19096776e-02,\n",
1552
+ " -4.09874134e-02, 3.37628126e-02, -4.64820908e-03,\n",
1553
+ " -2.50304434e-02, 6.25852346e-02, -1.24449311e-02,\n",
1554
+ " 3.82654071e-02, -2.35330854e-02, 8.68125912e-03,\n",
1555
+ " 5.08641489e-02, 2.53822445e-03, 5.25634140e-02,\n",
1556
+ " 1.14882430e-02, 5.01894541e-02, -3.55215147e-02,\n",
1557
+ " -3.31749097e-02, -3.02003417e-03, -5.36288768e-02,\n",
1558
+ " -2.80938316e-02, -7.51279444e-02, -4.71623316e-02,\n",
1559
+ " 9.56887701e-35, 2.55127084e-02, -1.44770980e-04,\n",
1560
+ " 1.96710341e-02, -1.33620016e-02, -1.51910949e-02,\n",
1561
+ " -3.28495577e-02, -1.52465852e-03, -2.65272055e-02,\n",
1562
+ " -4.35708016e-02, -1.75950192e-02, -2.20594816e-02]], dtype=float32)"
1563
+ ]
1564
+ },
1565
+ "execution_count": 109,
1566
+ "metadata": {},
1567
+ "output_type": "execute_result"
1568
+ }
1569
+ ],
1570
+ "source": [
1571
+ "import numpy as np\n",
1572
+ "\n",
1573
+ "svec = np.array(vec).reshape(1, -1)\n",
1574
+ "print(svec.shape)\n",
1575
+ "svec"
1576
+ ]
1577
+ },
1578
+ {
1579
+ "cell_type": "code",
1580
+ "execution_count": 110,
1581
+ "id": "fef30d70-6958-4259-abb6-09f8c1870a2b",
1582
+ "metadata": {},
1583
+ "outputs": [
1584
+ {
1585
+ "name": "stdout",
1586
+ "output_type": "stream",
1587
+ "text": [
1588
+ "[[0.7731663 0.79433584]] [[330 331]]\n"
1589
+ ]
1590
+ }
1591
+ ],
1592
+ "source": [
1593
+ "distances, I = index.search(svec, k=2)\n",
1594
+ "print(distances, I)"
1595
+ ]
1596
+ },
1597
+ {
1598
+ "cell_type": "code",
1599
+ "execution_count": 111,
1600
+ "id": "eb00598c-9799-4697-b2a3-356bb5aae0f1",
1601
+ "metadata": {},
1602
+ "outputs": [
1603
+ {
1604
+ "data": {
1605
+ "text/html": [
1606
+ "<div>\n",
1607
+ "<style scoped>\n",
1608
+ " .dataframe tbody tr th:only-of-type {\n",
1609
+ " vertical-align: middle;\n",
1610
+ " }\n",
1611
+ "\n",
1612
+ " .dataframe tbody tr th {\n",
1613
+ " vertical-align: top;\n",
1614
+ " }\n",
1615
+ "\n",
1616
+ " .dataframe thead th {\n",
1617
+ " text-align: right;\n",
1618
+ " }\n",
1619
+ "</style>\n",
1620
+ "<table border=\"1\" class=\"dataframe\">\n",
1621
+ " <thead>\n",
1622
+ " <tr style=\"text-align: right;\">\n",
1623
+ " <th></th>\n",
1624
+ " <th>desease_condition</th>\n",
1625
+ " <th>text</th>\n",
1626
+ " </tr>\n",
1627
+ " </thead>\n",
1628
+ " <tbody>\n",
1629
+ " <tr>\n",
1630
+ " <th>330</th>\n",
1631
+ " <td>['alzheimer disease', 'dementia', 'brain disea...</td>\n",
1632
+ " <td>nct_id: NCT02164643\\nsummary: A Multicenter na...</td>\n",
1633
+ " </tr>\n",
1634
+ " <tr>\n",
1635
+ " <th>331</th>\n",
1636
+ " <td>['alzheimer disease', 'dementia', 'brain disea...</td>\n",
1637
+ " <td>nct_id: NCT02164643\\nsummary: A Multicenter na...</td>\n",
1638
+ " </tr>\n",
1639
+ " </tbody>\n",
1640
+ "</table>\n",
1641
+ "</div>"
1642
+ ],
1643
+ "text/plain": [
1644
+ " desease_condition \\\n",
1645
+ "330 ['alzheimer disease', 'dementia', 'brain disea... \n",
1646
+ "331 ['alzheimer disease', 'dementia', 'brain disea... \n",
1647
+ "\n",
1648
+ " text \n",
1649
+ "330 nct_id: NCT02164643\\nsummary: A Multicenter na... \n",
1650
+ "331 nct_id: NCT02164643\\nsummary: A Multicenter na... "
1651
+ ]
1652
+ },
1653
+ "execution_count": 111,
1654
+ "metadata": {},
1655
+ "output_type": "execute_result"
1656
+ }
1657
+ ],
1658
+ "source": [
1659
+ "df2 = df.iloc[I[0]]\n",
1660
+ "df2"
1661
+ ]
1662
+ },
1663
+ {
1664
+ "cell_type": "code",
1665
+ "execution_count": 113,
1666
+ "id": "af5bf8e2-43b6-47af-affa-5111789371ad",
1667
+ "metadata": {},
1668
+ "outputs": [
1669
+ {
1670
+ "data": {
1671
+ "text/plain": [
1672
+ "'nct_id: NCT02164643\\nsummary: A Multicenter national longitudinal cohort study including at least 800 individuals consecutively recruited from French Research Memory Centers and followed-up over 24 month and included in Memento.\\nintervention_type: Drug\\nintervention_name: Florbetapir (18F)\\nintervention_description: nan\\nkeywords: [\"Alzheimer\\'s disease\", \\'Mild Cognitive Impairment\\']'"
1673
+ ]
1674
+ },
1675
+ "execution_count": 113,
1676
+ "metadata": {},
1677
+ "output_type": "execute_result"
1678
+ }
1679
+ ],
1680
+ "source": [
1681
+ "df2.iloc[1].text"
1682
+ ]
1683
+ },
1684
+ {
1685
+ "cell_type": "code",
1686
+ "execution_count": null,
1687
+ "id": "f3899f81-e120-475c-97ed-080cb7f46510",
1688
+ "metadata": {},
1689
+ "outputs": [],
1690
+ "source": []
1691
+ }
1692
+ ],
1693
+ "metadata": {
1694
+ "kernelspec": {
1695
+ "display_name": "Python 3 (ipykernel)",
1696
+ "language": "python",
1697
+ "name": "python3"
1698
+ },
1699
+ "language_info": {
1700
+ "codemirror_mode": {
1701
+ "name": "ipython",
1702
+ "version": 3
1703
+ },
1704
+ "file_extension": ".py",
1705
+ "mimetype": "text/x-python",
1706
+ "name": "python",
1707
+ "nbconvert_exporter": "python",
1708
+ "pygments_lexer": "ipython3",
1709
+ "version": "3.10.14"
1710
+ }
1711
+ },
1712
+ "nbformat": 4,
1713
+ "nbformat_minor": 5
1714
+ }
database.ipynb ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "## Vector Search "
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": null,
13
+ "metadata": {},
14
+ "outputs": [],
15
+ "source": [
16
+ "import os, pandas as pd\n",
17
+ "from sqlalchemy import create_engine, text"
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "code",
22
+ "execution_count": null,
23
+ "metadata": {},
24
+ "outputs": [],
25
+ "source": [
26
+ "username = 'demo'\n",
27
+ "password = 'demo'\n",
28
+ "hostname = os.getenv('IRIS_HOSTNAME', 'localhost')\n",
29
+ "port = '1972' \n",
30
+ "namespace = 'USER'\n",
31
+ "CONNECTION_STRING = f\"iris://{username}:{password}@{hostname}:{port}/{namespace}\"\n",
32
+ "\n",
33
+ "engine = create_engine(CONNECTION_STRING)"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": null,
39
+ "metadata": {},
40
+ "outputs": [],
41
+ "source": [
42
+ "# Load knowledge graph\n",
43
+ "entity_embeddings = pd.read_csv('./data/entity_embeddings.csv', index_col=0)\n",
44
+ "entity_embeddings[\"embedding\"] = entity_embeddings[\"embedding\"].apply(\n",
45
+ " lambda x: x[1:-1])\n",
46
+ "\n",
47
+ "len_label = entity_embeddings['label'].str.len().max()\n",
48
+ "len_uri = entity_embeddings['uri'].str.len().max()\n",
49
+ "# TODO: set varchar length dynamically as above\n",
50
+ "with engine.connect() as conn:\n",
51
+ " with conn.begin(): \n",
52
+ " result = conn.execute(text('DROP TABLE IF EXISTS Test.EntityEmbeddings'))\n",
53
+ " sql = f\"\"\"\n",
54
+ " CREATE TABLE Test.EntityEmbeddings (\n",
55
+ " embedding VECTOR(DOUBLE, 50),\n",
56
+ " label VARCHAR(143),\n",
57
+ " uri VARCHAR(38)\n",
58
+ " )\n",
59
+ " \"\"\"\n",
60
+ " result = conn.execute(text(sql))\n",
61
+ "\n",
62
+ "with engine.connect() as conn:\n",
63
+ " with conn.begin():\n",
64
+ " for index, row in entity_embeddings.iterrows():\n",
65
+ " sql = text(\"\"\"\n",
66
+ " INSERT INTO Test.EntityEmbeddings \n",
67
+ " (embedding, label, uri) \n",
68
+ " VALUES (TO_VECTOR(:embedding), :label, :uri)\n",
69
+ " \"\"\")\n",
70
+ " conn.execute(sql, {\n",
71
+ " 'embedding': str(row['embedding']),\n",
72
+ " 'label': row['label'], \n",
73
+ " 'uri': row['uri']\n",
74
+ " })\n"
75
+ ]
76
+ },
77
+ {
78
+ "cell_type": "code",
79
+ "execution_count": null,
80
+ "metadata": {},
81
+ "outputs": [],
82
+ "source": [
83
+ "# Calculate distance between entities\n",
84
+ "with engine.connect() as conn:\n",
85
+ " with conn.begin():\n",
86
+ " sql = f\"\"\"\n",
87
+ " SELECT TOP 10 e1.uri AS uri1, e2.uri AS uri2, e1.label AS label1, e2.label AS label2,\n",
88
+ " VECTOR_COSINE(e1.embedding, e2.embedding) AS distance\n",
89
+ " FROM Test.EntityEmbeddings e1, Test.EntityEmbeddings e2\n",
90
+ " WHERE e1.uri = 'http://identifiers.org/medgen/C0002395'\n",
91
+ " ORDER BY distance DESC\n",
92
+ " \"\"\"\n",
93
+ " result = conn.execute(text(sql))\n",
94
+ " data = result.fetchall()\n",
95
+ " display(data)"
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "code",
100
+ "execution_count": null,
101
+ "metadata": {},
102
+ "outputs": [],
103
+ "source": [
104
+ "# Load relation embeddings\n",
105
+ "\n",
106
+ "relation_embeddings = pd.read_csv('./data/relation_embeddings.csv', index_col=0)\n",
107
+ "relation_embeddings[\"embedding\"] = relation_embeddings[\"embedding\"].apply(\n",
108
+ " lambda x: x[1:-1])\n",
109
+ "\n",
110
+ "len_label = relation_embeddings['label'].str.len().max()\n",
111
+ "len_uri = relation_embeddings['uri'].str.len().max()\n",
112
+ "# TODO: set varchar length dynamically as above\n",
113
+ "with engine.connect() as conn:\n",
114
+ " with conn.begin():# Load \n",
115
+ " result = conn.execute(text('DROP TABLE IF EXISTS Test.RelationEmbeddings'))\n",
116
+ " sql = f\"\"\"\n",
117
+ " CREATE TABLE Test.RelationEmbeddings (\n",
118
+ " embedding VECTOR(DOUBLE, 50),\n",
119
+ " label VARCHAR(10),\n",
120
+ " uri VARCHAR(38)\n",
121
+ " )\n",
122
+ " \"\"\"\n",
123
+ " result = conn.execute(text(sql))\n",
124
+ "\n",
125
+ "with engine.connect() as conn:\n",
126
+ " with conn.begin():\n",
127
+ " for index, row in relation_embeddings.iterrows():\n",
128
+ " sql = text(\"\"\"\n",
129
+ " INSERT INTO Test.RelationEmbeddings \n",
130
+ " (embedding, label, uri) \n",
131
+ " VALUES (TO_VECTOR(:embedding), :label, :uri)\n",
132
+ " \"\"\")\n",
133
+ " conn.execute(sql, {\n",
134
+ " 'embedding': str(row['embedding']),\n",
135
+ " 'label': row['label'], \n",
136
+ " 'uri': row['uri']\n",
137
+ " })\n"
138
+ ]
139
+ },
140
+ {
141
+ "cell_type": "code",
142
+ "execution_count": null,
143
+ "metadata": {},
144
+ "outputs": [],
145
+ "source": [
146
+ "# Load clinical trials\n",
147
+ "clinical_trials = pd.read_csv(\"clinical_trials_embeddings.csv\")\n",
148
+ "clinical_trials[\"embeddings\"] = clinical_trials[\"embeddings\"].apply(lambda x: x[1:-1])\n",
149
+ "display(clinical_trials.head())\n",
150
+ "\n",
151
+ "# TODO: set varchar length dynamically as above\n",
152
+ "with engine.connect() as conn:\n",
153
+ " with conn.begin():\n",
154
+ " result = conn.execute(text(\"DROP TABLE IF EXISTS Test.ClinicalTrials\"))\n",
155
+ " sql = f\"\"\"\n",
156
+ " CREATE TABLE Test.ClinicalTrials (\n",
157
+ " nct_id VARCHAR(11) PRIMARY KEY,\n",
158
+ " diseases TEXT,\n",
159
+ " embedding VECTOR(DOUBLE, 768)\n",
160
+ " )\n",
161
+ " \"\"\"\n",
162
+ " result = conn.execute(text(sql))\n",
163
+ "\n",
164
+ "with engine.connect() as conn:\n",
165
+ " with conn.begin():\n",
166
+ " for index, row in clinical_trials.iterrows():\n",
167
+ "\n",
168
+ " sql = text(\n",
169
+ " \"\"\"\n",
170
+ " INSERT INTO Test.ClinicalTrials \n",
171
+ " (nct_id, diseases, embedding)\n",
172
+ " VALUES (:nct_id, :diseases, TO_VECTOR(:embedding))\n",
173
+ " \"\"\"\n",
174
+ " )\n",
175
+ " conn.execute(\n",
176
+ " sql,\n",
177
+ " {\n",
178
+ " \"nct_id\": row[\"nct_id\"],\n",
179
+ " \"diseases\": row[\"desease_condition\"],\n",
180
+ " \"embedding\": str(row[\"embeddings\"]),\n",
181
+ " },\n",
182
+ " )"
183
+ ]
184
+ },
185
+ {
186
+ "cell_type": "code",
187
+ "execution_count": null,
188
+ "metadata": {},
189
+ "outputs": [],
190
+ "source": [
191
+ "# %%\n",
192
+ "import pandas as pd\n",
193
+ "import rdflib\n",
194
+ "\n",
195
+ "# Load the disease descriptions from MGDEF.RRF\n",
196
+ "df_disease_descriptions = pd.read_csv(\"MGDEF.RRF\", sep=\"|\", header=0)\n",
197
+ "# Rename the column '#CUI' to 'CUI'\n",
198
+ "df_disease_descriptions.rename(columns={\"#CUI\": \"CUI\"}, inplace=True)\n",
199
+ "# Remove the last column, it's empty\n",
200
+ "df_disease_descriptions = df_disease_descriptions.iloc[:, :-1]\n",
201
+ "# Filter out the rows where the SUPPRESS field is equal to 'Y'\n",
202
+ "df_disease_descriptions = df_disease_descriptions[df_disease_descriptions[\"SUPPRESS\"] != \"Y\"]\n",
203
+ "# Some of the rows include a \\n character, so we need to remove the rows where the CUI field contains spaces or doesn't start with 'C'\n",
204
+ "df_disease_descriptions = df_disease_descriptions[df_disease_descriptions[\"CUI\"].str.startswith(\"C\") & ~df_disease_descriptions[\"CUI\"].str.contains(\" \")]\n",
205
+ "# Remove the rows where the DEF field is empty\n",
206
+ "df_disease_descriptions = df_disease_descriptions[df_disease_descriptions[\"DEF\"].notnull()]\n",
207
+ "df_disease_descriptions['uri'] = df_disease_descriptions['CUI'].apply(lambda x: f'http://identifiers.org/medgen/{x}')\n",
208
+ "\n",
209
+ "with engine.connect() as conn:\n",
210
+ " with conn.begin(): \n",
211
+ " result = conn.execute(text('DROP TABLE IF EXISTS Test.DiseaseDescriptions'))\n",
212
+ " sql = f\"\"\"\n",
213
+ " CREATE TABLE Test.DiseaseDescriptions (\n",
214
+ " uri VARCHAR(50),\n",
215
+ " description TEXT\n",
216
+ " )\n",
217
+ " \"\"\"\n",
218
+ " result = conn.execute(text(sql))\n",
219
+ "\n",
220
+ "with engine.connect() as conn:\n",
221
+ " with conn.begin():\n",
222
+ " for index, row in df_disease_descriptions.iterrows():\n",
223
+ " print(row['DEF'])\n",
224
+ " print(row['uri'])\n",
225
+ " sql = text(\"\"\"\n",
226
+ " INSERT INTO Test.DiseaseDescriptions \n",
227
+ " (uri, description) \n",
228
+ " VALUES ( :uri, :description)\n",
229
+ " \"\"\")\n",
230
+ " conn.execute(sql, {\n",
231
+ " 'uri': row['uri'],\n",
232
+ " 'description': row['DEF'], \n",
233
+ " })"
234
+ ]
235
+ },
236
+ {
237
+ "cell_type": "code",
238
+ "execution_count": null,
239
+ "metadata": {},
240
+ "outputs": [],
241
+ "source": []
242
+ }
243
+ ],
244
+ "metadata": {
245
+ "kernelspec": {
246
+ "display_name": "treehacks",
247
+ "language": "python",
248
+ "name": "python3"
249
+ },
250
+ "language_info": {
251
+ "codemirror_mode": {
252
+ "name": "ipython",
253
+ "version": 3
254
+ },
255
+ "file_extension": ".py",
256
+ "mimetype": "text/x-python",
257
+ "name": "python",
258
+ "nbconvert_exporter": "python",
259
+ "pygments_lexer": "ipython3",
260
+ "version": "3.11.9"
261
+ }
262
+ },
263
+ "nbformat": 4,
264
+ "nbformat_minor": 2
265
+ }
disease_descriptions_with_embeddings.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:098d7a2172812d9eeaaa5f6be94d356d88e2338ee80425d300866e59ea5008db
3
+ size 1075089337
docker-compose.yml ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ services:
2
+ iris:
3
+ image: intersystemsdc/iris-community:latest
4
+ environment:
5
+ IRIS_USERNAME: demo
6
+ IRIS_PASSWORD: demo
7
+ restart: always
8
+ hostname: iris
9
+ ports:
10
+ - 1972:1972
11
+ - 52773:52773
entity_embeddings.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b2d933ddb9bf7777abcc3e6aaf95c8e4212c6b2d3e145e8a5508d79a6fe01818
3
+ size 86750825
get_embeddings_of_disease_descriptions.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# %%
import pandas as pd

# Read the pipe-separated MedGen definitions file (MGDEF.RRF).
raw_definitions = pd.read_csv("MGDEF.RRF", sep="|", header=0)
# Normalise the column names ('#CUI' -> 'CUI', 'DEF' -> 'definition') and
# drop the trailing column, which is empty.
raw_definitions = raw_definitions.rename(
    columns={"#CUI": "CUI", "DEF": "definition"}
).iloc[:, :-1]

# Discard the rows flagged as suppressed.
unsuppressed = raw_definitions[raw_definitions["SUPPRESS"] != "Y"]
# Some rows include a \n character, which leaves fragments in the CUI field;
# keep only the rows whose CUI starts with 'C' and contains no spaces.
cui_column = unsuppressed["CUI"]
well_formed = unsuppressed[
    cui_column.str.startswith("C") & ~cui_column.str.contains(" ")
]
# Keep only the rows that actually carry a definition.
with_definition = well_formed[well_formed["definition"].notnull()]

df_disease_descriptions = (
    with_definition.assign(
        uri=[f"http://identifiers.org/medgen/{x}" for x in with_definition["CUI"]]
    )
    # The helper columns are not needed downstream.
    .drop(columns=["source", "SUPPRESS", "CUI"])
    # Identical definitions are kept only once.
    .drop_duplicates(subset=["definition"])
    # Renumber the rows 0..n-1.
    .reset_index(drop=True)
)

# %%
from sentence_transformers import SentenceTransformer

# Encode every definition with the "allenai-specter" sentence-transformer.
encoder = SentenceTransformer("allenai-specter")
vectors = encoder.encode(
    df_disease_descriptions.definition, show_progress_bar=True, batch_size=64
)
vectors.shape

# %%
import numpy as np

# Store one embedding (as a plain list of float32 values) per row.
as_float32 = vectors.astype("float32", casting="same_kind")
df_disease_descriptions["embeddings"] = as_float32.tolist()
# %%
# Persist definitions, URIs and embeddings to a CSV file.
df_disease_descriptions.to_csv(
    "disease_descriptions_with_embeddings.csv", index=False, header=True
)

# %%
graph.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %%
2
+ import rdflib
3
+ import pandas as pd
4
+
5
+
6
def get_graph(concepts_path="MGCONSO.RRF", relations_path="MGREL.RRF"):
    """Build an rdflib graph from the MedGen concept and relation files.

    For every non-suppressed concept row that is the preferred name
    (ISPREF == "Y", STT == "PF", TS == "P") an ``rdfs:label`` triple is added.
    For every non-suppressed relation row a triple ``CUI1 <predicate> CUI2``
    is added, where the predicate is the URI ``related`` when REL is "RL" and
    ``http://identifiers.org/medgen/<REL>`` otherwise.

    The head of each loaded table is printed as a quick sanity check.

    Args:
        concepts_path: Path to the pipe-separated MGCONSO.RRF file.
        relations_path: Path to the pipe-separated MGREL.RRF file.

    Returns:
        The populated ``rdflib.Graph``.
    """
    medgen = "http://identifiers.org/medgen/"

    # File with the concepts: MGCONSO.RRF
    df_concepts = pd.read_csv(concepts_path, sep="|", header=0)
    # Rename the column '#CUI' to 'CUI'
    df_concepts.rename(columns={"#CUI": "CUI"}, inplace=True)
    # Remove the last column, it's empty
    df_concepts = df_concepts.iloc[:, :-1]
    print(df_concepts.head())

    g = rdflib.Graph()
    g.bind("medgen", medgen)

    # Select the rows of interest with boolean masks instead of testing every
    # row inside a Python-level iterrows() loop.
    preferred = df_concepts[
        (df_concepts["SUPPRESS"] != "Y")
        & (df_concepts["ISPREF"] == "Y")
        & (df_concepts["STT"] == "PF")
        & (df_concepts["TS"] == "P")
    ]
    for cui, label in zip(preferred["CUI"], preferred["STR"]):
        g.add((rdflib.URIRef(f"{medgen}{cui}"), rdflib.RDFS.label, rdflib.Literal(label)))

    # Now, load MGREL.RRF
    df_relations = pd.read_csv(relations_path, sep="|", header=0)
    # Rename the column '#CUI1' to 'CUI1'
    df_relations.rename(columns={"#CUI1": "CUI1"}, inplace=True)
    # Remove the last column, it's empty
    df_relations = df_relations.iloc[:, :-1]
    print(df_relations.head())

    related = rdflib.URIRef("related")
    visible = df_relations[df_relations["SUPPRESS"] != "Y"]
    for cui1, cui2, rel in zip(visible["CUI1"], visible["CUI2"], visible["REL"]):
        # "RL" rows map onto the generic `related` predicate; every other
        # relation keeps its own MedGen-namespaced predicate.
        predicate = related if rel == "RL" else rdflib.URIRef(f"{medgen}{rel}")
        g.add((rdflib.URIRef(f"{medgen}{cui1}"), predicate, rdflib.URIRef(f"{medgen}{cui2}")))

    return g
49
+
50
def apply_rules_to_graph(g):
    """Add symmetric ``related`` edges between nodes that share a parent.

    Rule: if ``parent medgen:RN child1`` and ``parent medgen:RN child2`` with
    ``child1 != child2``, then ``child1 related child2`` and
    ``child2 related child1`` are added to the graph.

    Args:
        g: The rdflib graph to extend; it is modified in place.

    Returns:
        The same graph object, for call chaining.
    """
    # Query the graph to get every pair of distinct children of each parent
    query = """
    PREFIX medgen: <http://identifiers.org/medgen/>
    SELECT DISTINCT ?parent ?child1 ?child2 WHERE {
        ?parent medgen:RN ?child1 .
        ?parent medgen:RN ?child2 .
        FILTER (?child1 != ?child2)
    }
    """
    related = rdflib.URIRef("related")
    # Consume the whole result set BEFORE touching the graph: adding triples
    # while the query result is still being iterated mutates the store that
    # the (possibly lazy) result is reading from.
    sibling_pairs = [(row.child1, row.child2) for row in g.query(query)]
    for child1, child2 in sibling_pairs:
        g.add((child1, related, child2))
        g.add((child2, related, child1))
    return g
66
+
67
+
68
def get_labels_of_entities(concepts_path="MGCONSO.RRF"):
    """Return a dictionary mapping each entity URI to its preferred label.

    Only non-suppressed rows that are the preferred name (ISPREF == "Y",
    STT == "PF", TS == "P") are used. If a CUI has several such rows, the
    last one in file order wins.

    Args:
        concepts_path: Path to the pipe-separated MGCONSO.RRF file.

    Returns:
        dict of ``"http://identifiers.org/medgen/<CUI>"`` -> label (STR).
    """
    df_concepts = pd.read_csv(concepts_path, sep="|", header=0)
    # Rename the column '#CUI' to 'CUI'
    df_concepts.rename(columns={"#CUI": "CUI"}, inplace=True)
    # Remove the last column, it's empty
    df_concepts = df_concepts.iloc[:, :-1]

    # Filter with boolean masks rather than testing each row in Python.
    preferred = df_concepts[
        (df_concepts["SUPPRESS"] != "Y")
        & (df_concepts["ISPREF"] == "Y")
        & (df_concepts["STT"] == "PF")
        & (df_concepts["TS"] == "P")
    ]
    return {
        f"http://identifiers.org/medgen/{cui}": label
        for cui, label in zip(preferred["CUI"], preferred["STR"])
    }
87
+
88
+
89
def generate_triples_file(graph: rdflib.Graph):
    """Write selected triples of ``graph`` to ``triples_medgen.tsv``.

    One tab-separated ``subject<TAB>predicate<TAB>object`` line is written per
    triple, grouped by predicate in this order: ``related``, then the MedGen
    ``RN``, ``RB``, ``PAR`` and ``CHD`` relations.
    """
    exported_predicates = (
        "related",
        "http://identifiers.org/medgen/RN",
        "http://identifiers.org/medgen/RB",
        "http://identifiers.org/medgen/PAR",
        "http://identifiers.org/medgen/CHD",
    )
    with open("triples_medgen.tsv", "w") as out:
        for predicate in exported_predicates:
            # Match every triple carrying this predicate.
            pattern = (None, rdflib.URIRef(predicate), None)
            out.writelines(
                f"{s}\t{p}\t{o}\n" for s, p, o in graph.triples(pattern)
            )
106
+
107
+
108
def save_adjacency_matrix(
    triples_path="triples_medgen.tsv", output_path="adjacency_matrix.mat"
):
    """Convert a triples TSV file into a 0/1 adjacency matrix and save it.

    Rows are the unique subjects (first column) and columns are the unique
    objects (third column), both in order of first appearance. A cell is 1
    if at least one triple links that subject to that object, else 0. The
    predicate column is ignored.

    Args:
        triples_path: Headerless, tab-separated file of (s, p, o) triples.
        output_path: Destination of the tab-separated matrix.

    Note:
        The matrix is dense (subjects x objects), so memory use grows with
        the product of both counts. Rows with a missing subject or object
        are skipped.
    """
    import numpy as np  # local import: numpy is only needed by this helper

    df = pd.read_csv(triples_path, sep="\t", header=None)
    # A triple without a subject or an object cannot be placed in the matrix.
    df = df.dropna(subset=[0, 2])

    # factorize() yields, for every row, the position of its value among the
    # unique values in order of first appearance.
    row_codes, subjects = pd.factorize(df[0])
    col_codes, objects = pd.factorize(df[2])

    # Fill all cells in one vectorised assignment instead of one `.loc`
    # write per triple.
    matrix = np.zeros((len(subjects), len(objects)), dtype="int64")
    matrix[row_codes, col_codes] = 1

    adj_matrix = pd.DataFrame(
        matrix, index=np.asarray(subjects), columns=np.asarray(objects)
    )
    adj_matrix.to_csv(output_path, sep="\t")
123
+
124
+
125
# %%
# Build the graph from the MedGen concept and relation files.
g = get_graph()
# %%
# Add the derived `related` edges between nodes that share a parent.
g = apply_rules_to_graph(g)
# %%
# URI -> preferred label lookup, used later when exporting the embeddings.
labels_of_entities = get_labels_of_entities()
# %%
# Dump the selected triples to triples_medgen.tsv, which is read back below.
generate_triples_file(g)
# %%
from pykeen.triples import TriplesFactory
from pykeen.models import TuckER, TransE, TransH
from pykeen.pipeline import pipeline

# Load the triples written by generate_triples_file().
tf = TriplesFactory.from_path("triples_medgen.tsv")
print(f"Triples count: {tf.num_triples}")
# 80% / 10% / 10% split with a fixed seed so the split is reproducible.
training, testing, validation = tf.split([0.8, 0.1, 0.1], random_state=42, randomize_cleanup=False)
# Train a TransE model with early stopping enabled.
result = pipeline(
    training=training,
    testing=testing,
    validation=validation,
    model=TransE,
    stopper="early",
    epochs=500,  # short epochs for testing - you should go
    # higher, especially with early stopper enabled
)
# Persist the pipeline result (directory name is unrelated to this project's
# model choice — NOTE(review): consider renaming).
result.save_to_directory("doctests/test_unstratified_stopped_complex")
151
# %%
import torch

alzheimers = "http://identifiers.org/medgen/C1843013"
# What does the model predict as tails for (alzheimers, related, ?)
model = result.model
alzheimers_id = tf.entity_to_id[alzheimers]
relation_id = tf.relation_to_id["related"]

# One (head, relation) pair -> one row of scores over the candidate tails.
batch_to_predict = torch.tensor([[alzheimers_id, relation_id]])

alzheimers_pred = model.predict_t(hr_batch=batch_to_predict)

print(alzheimers_pred.shape)
# Get the indices of the top 10 predictions
top10 = torch.topk(alzheimers_pred, 10, largest=True)
# Entity id -> entity label (here: the entity URI) lookup
entities = tf.entity_id_to_label
print(top10.indices)
# The outer loop variable used to be `i` and was re-bound by the inner
# enumerate() loop; it now has its own name.
for entity_idx in top10.indices[0]:
    # Ask the graph, what is the label for this entity?
    query = f"""
    PREFIX medgen: <http://identifiers.org/medgen/>
    SELECT ?label WHERE {{
        <{entities[entity_idx.item()]}> <http://www.w3.org/2000/01/rdf-schema#label> ?label
    }}
    """
    res = g.query(query)
    for i, row in enumerate(res):
        print(f"{i}: {row}")
181
# %%
from pykeen.nn.representation import Embedding

# Run the lookups on whatever device the trained model lives on instead of
# hard-coding CUDA, so the export also works on CPU-only machines.
device = next(model.parameters()).device

# Get the embeddings of all the entities. The ids are sorted so that row k of
# the result is the embedding of entity id k, matching the label/uri columns
# built from range(...) below.
entity_ids = torch.as_tensor(
    sorted(tf.entity_to_id.values()), dtype=torch.long, device=device
)
entity_embeddings: torch.Tensor = model.entity_representations[0]._embeddings(entity_ids)
# Get the embeddings of the relations (same ordering guarantee).
relation_ids = torch.as_tensor(
    sorted(tf.relation_to_id.values()), dtype=torch.long, device=device
)
relation_embeddings: torch.Tensor = model.relation_representations[0]._embeddings(
    relation_ids
)

print(f"Entity embeddings shape: {entity_embeddings.shape}")
print(f"Relation embeddings shape: {relation_embeddings.shape}")

# Store the entity embeddings in a DataFrame; entities without a known
# preferred label get an empty string.
entity_uris = [tf.entity_id_to_label[i] for i in range(len(tf.entity_id_to_label))]
df = pd.DataFrame(
    {
        "embedding": entity_embeddings.detach().cpu().tolist(),
        "label": [labels_of_entities.get(uri, "") for uri in entity_uris],
        "uri": [f"{uri}" for uri in entity_uris],
    },
    index=range(len(entity_embeddings)),
)
## Save the DataFrame
df.to_csv("entity_embeddings.csv")

# Store the relation embeddings in a DataFrame; a relation's label and uri
# are both its name in the triples file.
relation_names = [
    tf.relation_id_to_label[i] for i in range(len(tf.relation_id_to_label))
]
df = pd.DataFrame(
    {
        "embedding": relation_embeddings.detach().cpu().tolist(),
        "label": list(relation_names),
        "uri": [f"{name}" for name in relation_names],
    },
    index=range(len(relation_embeddings)),
)
## Save the DataFrame
df.to_csv("relation_embeddings.csv")
228
+
229
+ # %%
230
+ import pyobo
231
+
232
+ pyobo.get_name("mesh", "16793")
233
+
234
+ # %%
graph_analysis.m ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
% Exploratory analysis of the MedGen relation file: builds a directed graph
% from numeric node ids and reports weakly connected components and PageRank.

% Read the pipe-delimited relations file, keeping the original column names
data = readtable('MGREL.RRF', Delimiter='|', FileType='text', NumHeaderLines=0, VariableNamingRule='preserve');
data = renamevars(data,"#CUI1","CUI1");
% Only the first 1000 rows are analysed
data = data(1:1000,:);
ids_1 = data.CUI1;
for k = 1 : length(ids_1)
    cellContents = ids_1{k};
    % Drop the first character of the id and stick the rest back into the cell
    ids_1{k} = cellContents(2:end);
end
% Convert the remaining text to numbers (non-numeric text becomes NaN)
ids_1 = str2double(ids_1);
ids_2 = data.CUI2;
% NOTE(review): this immediately overwrites the CUI2 values read on the
% previous line with a shifted copy of CUI1, so CUI2 is never used. Confirm
% whether the edges were meant to be CUI1 -> CUI2.
ids_2 = data.CUI1(2:end);
for k = 1 : length(ids_2)
    cellContents = ids_2{k};
    % Drop the first character of the id and stick the rest back into the cell
    ids_2{k} = cellContents(2:end);
end
ids_2 = str2double(ids_2);


% NOTE(review): ids_2 was already one element shorter than ids_1 (2:end
% above); after these two lines ids_1 appears to have 999 elements and ids_2
% 998. Confirm the lengths match before they are passed to digraph below.
ids_1 = ids_1(1:end-1);
ids_2 = ids_2(2:end);


% Get the number of unique nodes
%nodes = unique([ids_1; ids_2]);
%num_nodes = length(nodes);

% Initialize sparse adjacency matrix
%A = sparse(ids_1, ids_2, 1, max(ids_2), max(ids_2));
% Display adjacency matrix
%disp(A);


%G = digraph(A);
% Build the directed graph from source/target node ids
G = digraph(ids_1, ids_2);
% Weakly connected components: component index per node and component sizes
[bin,binsize] = conncomp(G,'Type','weak');
bin(1:100)
size(unique(bin))
max(binsize)
% PageRank centrality, stored as a node attribute
pg_ranks = centrality(G,'pagerank');
G.Nodes.PageRank = pg_ranks;
%hub_ranks = centrality(G,'hubs');
%auth_ranks = centrality(G,'authorities');
%G.Nodes.Hubs = hub_ranks;
%G.Nodes.Authorities = auth_ranks;
G.Nodes
%plot(G);
graph_visualization.mlapp ADDED
Binary file (30.2 kB). View file
 
main.ipynb ADDED
@@ -0,0 +1,417 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 12,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "data": {
10
+ "text/html": [
11
+ "<div>\n",
12
+ "<style scoped>\n",
13
+ " .dataframe tbody tr th:only-of-type {\n",
14
+ " vertical-align: middle;\n",
15
+ " }\n",
16
+ "\n",
17
+ " .dataframe tbody tr th {\n",
18
+ " vertical-align: top;\n",
19
+ " }\n",
20
+ "\n",
21
+ " .dataframe thead th {\n",
22
+ " text-align: right;\n",
23
+ " }\n",
24
+ "</style>\n",
25
+ "<table border=\"1\" class=\"dataframe\">\n",
26
+ " <thead>\n",
27
+ " <tr style=\"text-align: right;\">\n",
28
+ " <th></th>\n",
29
+ " <th>id</th>\n",
30
+ " <th>nct_id</th>\n",
31
+ " <th>mesh_term</th>\n",
32
+ " <th>downcase_mesh_term</th>\n",
33
+ " <th>mesh_type</th>\n",
34
+ " </tr>\n",
35
+ " </thead>\n",
36
+ " <tbody>\n",
37
+ " <tr>\n",
38
+ " <th>0</th>\n",
39
+ " <td>336369685</td>\n",
40
+ " <td>NCT04016870</td>\n",
41
+ " <td>Infections</td>\n",
42
+ " <td>infections</td>\n",
43
+ " <td>mesh-ancestor</td>\n",
44
+ " </tr>\n",
45
+ " <tr>\n",
46
+ " <th>1</th>\n",
47
+ " <td>336369788</td>\n",
48
+ " <td>NCT03266874</td>\n",
49
+ " <td>Necrosis</td>\n",
50
+ " <td>necrosis</td>\n",
51
+ " <td>mesh-list</td>\n",
52
+ " </tr>\n",
53
+ " <tr>\n",
54
+ " <th>2</th>\n",
55
+ " <td>336369897</td>\n",
56
+ " <td>NCT02743455</td>\n",
57
+ " <td>Fever</td>\n",
58
+ " <td>fever</td>\n",
59
+ " <td>mesh-list</td>\n",
60
+ " </tr>\n",
61
+ " <tr>\n",
62
+ " <th>3</th>\n",
63
+ " <td>336370004</td>\n",
64
+ " <td>NCT01683877</td>\n",
65
+ " <td>Neoplasms</td>\n",
66
+ " <td>neoplasms</td>\n",
67
+ " <td>mesh-ancestor</td>\n",
68
+ " </tr>\n",
69
+ " <tr>\n",
70
+ " <th>4</th>\n",
71
+ " <td>336370095</td>\n",
72
+ " <td>NCT01268579</td>\n",
73
+ " <td>Carcinoma</td>\n",
74
+ " <td>carcinoma</td>\n",
75
+ " <td>mesh-list</td>\n",
76
+ " </tr>\n",
77
+ " </tbody>\n",
78
+ "</table>\n",
79
+ "</div>"
80
+ ],
81
+ "text/plain": [
82
+ " id nct_id mesh_term downcase_mesh_term mesh_type\n",
83
+ "0 336369685 NCT04016870 Infections infections mesh-ancestor\n",
84
+ "1 336369788 NCT03266874 Necrosis necrosis mesh-list\n",
85
+ "2 336369897 NCT02743455 Fever fever mesh-list\n",
86
+ "3 336370004 NCT01683877 Neoplasms neoplasms mesh-ancestor\n",
87
+ "4 336370095 NCT01268579 Carcinoma carcinoma mesh-list"
88
+ ]
89
+ },
90
+ "execution_count": 12,
91
+ "metadata": {},
92
+ "output_type": "execute_result"
93
+ }
94
+ ],
95
+ "source": [
96
+ "import pandas as pd\n",
97
+ "\n",
98
+ "df = pd.read_csv('file_db/browse_conditions.txt', delimiter='|') # Use the appropriate delimiter if not tab-separated\n",
99
+ "\n",
100
+ "df.head()"
101
+ ]
102
+ },
103
+ {
104
+ "cell_type": "code",
105
+ "execution_count": 13,
106
+ "metadata": {},
107
+ "outputs": [],
108
+ "source": [
109
+ "files_to_keep = [\"brief_summaries\", \"interventions\", \"keywords\", \"browse_conditions\"]\n",
110
+ "\n",
111
+ "# maybe \"study_references\" \"sponsors\" \"overall_officials\" \"pending_results\" \"outcome_analyses\" \"provided_documents\" \"reported_event_totals\" \"responsible_parties\"\n",
112
+ "\n"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": 14,
118
+ "metadata": {},
119
+ "outputs": [
120
+ {
121
+ "data": {
122
+ "text/html": [
123
+ "<div>\n",
124
+ "<style scoped>\n",
125
+ " .dataframe tbody tr th:only-of-type {\n",
126
+ " vertical-align: middle;\n",
127
+ " }\n",
128
+ "\n",
129
+ " .dataframe tbody tr th {\n",
130
+ " vertical-align: top;\n",
131
+ " }\n",
132
+ "\n",
133
+ " .dataframe thead th {\n",
134
+ " text-align: right;\n",
135
+ " }\n",
136
+ "</style>\n",
137
+ "<table border=\"1\" class=\"dataframe\">\n",
138
+ " <thead>\n",
139
+ " <tr style=\"text-align: right;\">\n",
140
+ " <th></th>\n",
141
+ " <th>nct_id</th>\n",
142
+ " <th>summary</th>\n",
143
+ " <th>intervention_name</th>\n",
144
+ " <th>intervention_type</th>\n",
145
+ " <th>intervention_description</th>\n",
146
+ " <th>keywords</th>\n",
147
+ " <th>desease_condition</th>\n",
148
+ " </tr>\n",
149
+ " </thead>\n",
150
+ " <tbody>\n",
151
+ " <tr>\n",
152
+ " <th>0</th>\n",
153
+ " <td>NCT03569293</td>\n",
154
+ " <td>The objective of this study is to assess the e...</td>\n",
155
+ " <td>[Placebo for Upadacitinib, Upadacitinib]</td>\n",
156
+ " <td>Drug</td>\n",
157
+ " <td>Tablets taken orally once a day</td>\n",
158
+ " <td>[Atopic Dermatitis, Upadacitinib]</td>\n",
159
+ " <td>[dermatitis, atopic, dermatitis, eczema, skin ...</td>\n",
160
+ " </tr>\n",
161
+ " <tr>\n",
162
+ " <th>2</th>\n",
163
+ " <td>NCT03556839</td>\n",
164
+ " <td>The study will integrate the efficacy of combi...</td>\n",
165
+ " <td>[Atezolizumab, Bevacizumab, Cisplatin/Carbopla...</td>\n",
166
+ " <td>Drug</td>\n",
167
+ " <td>Intravenous Infusion</td>\n",
168
+ " <td>[Cervix, Carcinoma, Atezolizumab]</td>\n",
169
+ " <td>[carcinoma, neoplasms, glandular and epithelia...</td>\n",
170
+ " </tr>\n",
171
+ " <tr>\n",
172
+ " <th>6</th>\n",
173
+ " <td>NCT03526874</td>\n",
174
+ " <td>Migraine affects 10-28% of children and adoles...</td>\n",
175
+ " <td>[Lidocaine 4% Topical Application Cream [LMX 4...</td>\n",
176
+ " <td>Drug</td>\n",
177
+ " <td>Run-in Step: All subjects receive 32 mg (4 cm ...</td>\n",
178
+ " <td>[Episodic Migraine, Headache, Nerve Block, Pai...</td>\n",
179
+ " <td>[pain, migraine disorders, headache, headache ...</td>\n",
180
+ " </tr>\n",
181
+ " <tr>\n",
182
+ " <th>9</th>\n",
183
+ " <td>NCT03526835</td>\n",
184
+ " <td>This is a Phase 1/2 open-label, multi-center, ...</td>\n",
185
+ " <td>[MCLA-158, MCLA-158 +Pembrolizumab]</td>\n",
186
+ " <td>Drug</td>\n",
187
+ " <td>full-length IgG1 bispecific antibody targeting...</td>\n",
188
+ " <td>[Bispecific antibody, First-in-human, MCLA-158...</td>\n",
189
+ " <td>[squamous cell carcinoma of head and neck, neo...</td>\n",
190
+ " </tr>\n",
191
+ " <tr>\n",
192
+ " <th>11</th>\n",
193
+ " <td>NCT02272751</td>\n",
194
+ " <td>This study will aim to compare the effects of ...</td>\n",
195
+ " <td>[Exercise, Relaxation]</td>\n",
196
+ " <td>Behavioral</td>\n",
197
+ " <td>The Exercise intervention will consist of aero...</td>\n",
198
+ " <td>[cancer survivorship, exercise, relaxation, mi...</td>\n",
199
+ " <td>[lymphoma, neoplasms by histologic type, neopl...</td>\n",
200
+ " </tr>\n",
201
+ " </tbody>\n",
202
+ "</table>\n",
203
+ "</div>"
204
+ ],
205
+ "text/plain": [
206
+ " nct_id summary \\\n",
207
+ "0 NCT03569293 The objective of this study is to assess the e... \n",
208
+ "2 NCT03556839 The study will integrate the efficacy of combi... \n",
209
+ "6 NCT03526874 Migraine affects 10-28% of children and adoles... \n",
210
+ "9 NCT03526835 This is a Phase 1/2 open-label, multi-center, ... \n",
211
+ "11 NCT02272751 This study will aim to compare the effects of ... \n",
212
+ "\n",
213
+ " intervention_name intervention_type \\\n",
214
+ "0 [Placebo for Upadacitinib, Upadacitinib] Drug \n",
215
+ "2 [Atezolizumab, Bevacizumab, Cisplatin/Carbopla... Drug \n",
216
+ "6 [Lidocaine 4% Topical Application Cream [LMX 4... Drug \n",
217
+ "9 [MCLA-158, MCLA-158 +Pembrolizumab] Drug \n",
218
+ "11 [Exercise, Relaxation] Behavioral \n",
219
+ "\n",
220
+ " intervention_description \\\n",
221
+ "0 Tablets taken orally once a day \n",
222
+ "2 Intravenous Infusion \n",
223
+ "6 Run-in Step: All subjects receive 32 mg (4 cm ... \n",
224
+ "9 full-length IgG1 bispecific antibody targeting... \n",
225
+ "11 The Exercise intervention will consist of aero... \n",
226
+ "\n",
227
+ " keywords \\\n",
228
+ "0 [Atopic Dermatitis, Upadacitinib] \n",
229
+ "2 [Cervix, Carcinoma, Atezolizumab] \n",
230
+ "6 [Episodic Migraine, Headache, Nerve Block, Pai... \n",
231
+ "9 [Bispecific antibody, First-in-human, MCLA-158... \n",
232
+ "11 [cancer survivorship, exercise, relaxation, mi... \n",
233
+ "\n",
234
+ " desease_condition \n",
235
+ "0 [dermatitis, atopic, dermatitis, eczema, skin ... \n",
236
+ "2 [carcinoma, neoplasms, glandular and epithelia... \n",
237
+ "6 [pain, migraine disorders, headache, headache ... \n",
238
+ "9 [squamous cell carcinoma of head and neck, neo... \n",
239
+ "11 [lymphoma, neoplasms by histologic type, neopl... "
240
+ ]
241
+ },
242
+ "execution_count": 14,
243
+ "metadata": {},
244
+ "output_type": "execute_result"
245
+ }
246
+ ],
247
+ "source": [
248
+ "df_summary = pd.read_csv('file_db/brief_summaries.txt', delimiter='|')\n",
249
+ "df_summary = df_summary.rename(columns={'description': 'summary'})\n",
250
+ "\n",
251
+ "### create and merge intervention ###\n",
252
+ "df_intervention = pd.read_csv('file_db/interventions.txt', delimiter='|')\n",
253
+ "\n",
254
+ "intervention_grouped = df_intervention.groupby('nct_id')['name'].apply(list).reset_index()\n",
255
+ "intervention_grouped = intervention_grouped.rename(columns={'name': 'intervention_name'})\n",
256
+ "merged_df = pd.merge(\n",
257
+ " df_summary[['nct_id', 'summary']], \n",
258
+ " intervention_grouped[['nct_id', 'intervention_name']], \n",
259
+ " on='nct_id')\n",
260
+ "\n",
261
+ "df_intervention = df_intervention.rename(columns={'description': 'intervention_description'})\n",
262
+ "\n",
263
+ "merged_df = pd.merge(\n",
264
+ " merged_df,\n",
265
+ " df_intervention[['nct_id', 'intervention_type', 'intervention_description']], \n",
266
+ " on='nct_id')\n",
267
+ "\n",
268
+ "### create and merge keywords ###\n",
269
+ "df_keyword = pd.read_csv('file_db/keywords.txt', delimiter='|')\n",
270
+ "keywords_grouped = df_keyword.groupby('nct_id')['name'].apply(list).reset_index()\n",
271
+ "keywords_grouped = keywords_grouped.rename(columns={'name': 'keywords'})\n",
272
+ "\n",
273
+ "merged_df = pd.merge(\n",
274
+ " merged_df,\n",
275
+ " keywords_grouped,\n",
276
+ " on='nct_id'\n",
277
+ ")\n",
278
+ "\n",
279
+ "### create and merge browse conditions\n",
280
+ "df_condition = pd.read_csv('file_db/browse_conditions.txt', delimiter='|')\n",
281
+ "conditions_grouped = df_condition.groupby('nct_id')['downcase_mesh_term'].apply(list).reset_index()\n",
282
+ "conditions_grouped = conditions_grouped.rename(columns={'downcase_mesh_term': 'desease_condition'})\n",
283
+ "\n",
284
+ "merged_df = pd.merge(\n",
285
+ " merged_df,\n",
286
+ " conditions_grouped,\n",
287
+ " on='nct_id'\n",
288
+ ")\n",
289
+ "\n",
290
+ "merged_df = merged_df.drop_duplicates(subset='nct_id')\n",
291
+ "\n",
292
+ "merged_df.head()\n",
293
+ "\n"
294
+ ]
295
+ },
296
+ {
297
+ "cell_type": "code",
298
+ "execution_count": 15,
299
+ "metadata": {},
300
+ "outputs": [
301
+ {
302
+ "data": {
303
+ "text/html": [
304
+ "<div>\n",
305
+ "<style scoped>\n",
306
+ " .dataframe tbody tr th:only-of-type {\n",
307
+ " vertical-align: middle;\n",
308
+ " }\n",
309
+ "\n",
310
+ " .dataframe tbody tr th {\n",
311
+ " vertical-align: top;\n",
312
+ " }\n",
313
+ "\n",
314
+ " .dataframe thead th {\n",
315
+ " text-align: right;\n",
316
+ " }\n",
317
+ "</style>\n",
318
+ "<table border=\"1\" class=\"dataframe\">\n",
319
+ " <thead>\n",
320
+ " <tr style=\"text-align: right;\">\n",
321
+ " <th></th>\n",
322
+ " <th>desease_condition</th>\n",
323
+ " <th>text</th>\n",
324
+ " </tr>\n",
325
+ " </thead>\n",
326
+ " <tbody>\n",
327
+ " <tr>\n",
328
+ " <th>0</th>\n",
329
+ " <td>[dermatitis, atopic, dermatitis, eczema, skin ...</td>\n",
330
+ " <td>nct_id: NCT03569293\\nsummary: The objective of...</td>\n",
331
+ " </tr>\n",
332
+ " <tr>\n",
333
+ " <th>2</th>\n",
334
+ " <td>[carcinoma, neoplasms, glandular and epithelia...</td>\n",
335
+ " <td>nct_id: NCT03556839\\nsummary: The study will i...</td>\n",
336
+ " </tr>\n",
337
+ " <tr>\n",
338
+ " <th>6</th>\n",
339
+ " <td>[pain, migraine disorders, headache, headache ...</td>\n",
340
+ " <td>nct_id: NCT03526874\\nsummary: Migraine affects...</td>\n",
341
+ " </tr>\n",
342
+ " <tr>\n",
343
+ " <th>9</th>\n",
344
+ " <td>[squamous cell carcinoma of head and neck, neo...</td>\n",
345
+ " <td>nct_id: NCT03526835\\nsummary: This is a Phase ...</td>\n",
346
+ " </tr>\n",
347
+ " <tr>\n",
348
+ " <th>11</th>\n",
349
+ " <td>[lymphoma, neoplasms by histologic type, neopl...</td>\n",
350
+ " <td>nct_id: NCT02272751\\nsummary: This study will ...</td>\n",
351
+ " </tr>\n",
352
+ " </tbody>\n",
353
+ "</table>\n",
354
+ "</div>"
355
+ ],
356
+ "text/plain": [
357
+ " desease_condition \\\n",
358
+ "0 [dermatitis, atopic, dermatitis, eczema, skin ... \n",
359
+ "2 [carcinoma, neoplasms, glandular and epithelia... \n",
360
+ "6 [pain, migraine disorders, headache, headache ... \n",
361
+ "9 [squamous cell carcinoma of head and neck, neo... \n",
362
+ "11 [lymphoma, neoplasms by histologic type, neopl... \n",
363
+ "\n",
364
+ " text \n",
365
+ "0 nct_id: NCT03569293\\nsummary: The objective of... \n",
366
+ "2 nct_id: NCT03556839\\nsummary: The study will i... \n",
367
+ "6 nct_id: NCT03526874\\nsummary: Migraine affects... \n",
368
+ "9 nct_id: NCT03526835\\nsummary: This is a Phase ... \n",
369
+ "11 nct_id: NCT02272751\\nsummary: This study will ... "
370
+ ]
371
+ },
372
+ "execution_count": 15,
373
+ "metadata": {},
374
+ "output_type": "execute_result"
375
+ }
376
+ ],
377
+ "source": [
378
+ "# Concatenate all columns into one written text\n",
379
+ "merged_df['text'] = merged_df.drop(columns=['desease_condition']).apply(lambda row: '\\n'.join([f\"{col}: {val}\" for col, val in row.items()]), axis=1)\n",
380
+ "\n",
381
+ "# Save the DataFrame to a new CSV file\n",
382
+ "merged_df = merged_df[['desease_condition', 'text']]\n",
383
+ "merged_df.to_csv('clinical_trials.csv', index=False)\n",
384
+ "\n",
385
+ "merged_df.head()"
386
+ ]
387
+ },
388
+ {
389
+ "cell_type": "code",
390
+ "execution_count": null,
391
+ "metadata": {},
392
+ "outputs": [],
393
+ "source": []
394
+ }
395
+ ],
396
+ "metadata": {
397
+ "kernelspec": {
398
+ "display_name": "env",
399
+ "language": "python",
400
+ "name": "python3"
401
+ },
402
+ "language_info": {
403
+ "codemirror_mode": {
404
+ "name": "ipython",
405
+ "version": 3
406
+ },
407
+ "file_extension": ".py",
408
+ "mimetype": "text/x-python",
409
+ "name": "python",
410
+ "nbconvert_exporter": "python",
411
+ "pygments_lexer": "ipython3",
412
+ "version": "3.12.3"
413
+ }
414
+ },
415
+ "nbformat": 4,
416
+ "nbformat_minor": 2
417
+ }
mock_trial.json ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "protocolSection": {
3
+ "identificationModule": {
4
+ "nctId": "NCT00841061",
5
+ "orgStudyIdInfo": {
6
+ "id": "B530500"
7
+ },
8
+ "secondaryIdInfos": [
9
+ {
10
+ "id": "HD40315-01"
11
+ }
12
+ ],
13
+ "organization": {
14
+ "fullName": "Eunice Kennedy Shriver National Institute of Child Health and Human Development (NICHD)",
15
+ "class": "NIH"
16
+ },
17
+ "briefTitle": "Cereals as a Source of Iron for Breastfed Infants",
18
+ "officialTitle": "Breast Feeding and Iron: Comparison of Cereals Fortified With Different Forms of Iron",
19
+ "acronym": "Bfe03B"
20
+ },
21
+ "statusModule": {
22
+ "statusVerifiedDate": "2009-01",
23
+ "overallStatus": "COMPLETED",
24
+ "expandedAccessInfo": {
25
+ "hasExpandedAccess": false
26
+ },
27
+ "startDateStruct": {
28
+ "date": "2003-07"
29
+ },
30
+ "primaryCompletionDateStruct": {
31
+ "date": "2006-05",
32
+ "type": "ACTUAL"
33
+ },
34
+ "completionDateStruct": {
35
+ "date": "2006-05",
36
+ "type": "ACTUAL"
37
+ },
38
+ "studyFirstSubmitDate": "2009-02-09",
39
+ "studyFirstSubmitQcDate": "2009-02-09",
40
+ "studyFirstPostDateStruct": {
41
+ "date": "2009-02-11",
42
+ "type": "ESTIMATED"
43
+ },
44
+ "lastUpdateSubmitDate": "2009-02-09",
45
+ "lastUpdatePostDateStruct": {
46
+ "date": "2009-02-11",
47
+ "type": "ESTIMATED"
48
+ }
49
+ },
50
+ "sponsorCollaboratorsModule": {
51
+ "responsibleParty": {
52
+ "oldNameTitle": "Dr. Ekhard E. Ziegler",
53
+ "oldOrganization": "University of Iowa"
54
+ },
55
+ "leadSponsor": {
56
+ "name": "National Institutes of Health (NIH)",
57
+ "class": "NIH"
58
+ }
59
+ },
60
+ "oversightModule": {
61
+ "oversightHasDmc": false
62
+ },
63
+ "descriptionModule": {
64
+ "briefSummary": "The purpose of this research study is to determine whether the type of iron in infant cereals makes a differance in how well the cereal helps infants remain free of iron deficiency."
65
+ },
66
+ "conditionsModule": {
67
+ "conditions": [
68
+ "Iron Deficiency"
69
+ ]
70
+ },
71
+ "designModule": {
72
+ "studyType": "INTERVENTIONAL",
73
+ "phases": [
74
+ "NA"
75
+ ],
76
+ "designInfo": {
77
+ "allocation": "RANDOMIZED",
78
+ "interventionModel": "PARALLEL",
79
+ "primaryPurpose": "PREVENTION",
80
+ "maskingInfo": {
81
+ "masking": "QUADRUPLE",
82
+ "whoMasked": [
83
+ "PARTICIPANT",
84
+ "CARE_PROVIDER",
85
+ "INVESTIGATOR",
86
+ "OUTCOMES_ASSESSOR"
87
+ ]
88
+ }
89
+ },
90
+ "enrollmentInfo": {
91
+ "count": 111,
92
+ "type": "ACTUAL"
93
+ }
94
+ },
95
+ "armsInterventionsModule": {
96
+ "armGroups": [
97
+ {
98
+ "label": "Cereal L",
99
+ "type": "ACTIVE_COMPARATOR",
100
+ "description": "Rice cereal with electrolytic iron",
101
+ "interventionNames": [
102
+ "Dietary Supplement: electrolytic iron"
103
+ ]
104
+ },
105
+ {
106
+ "label": "Cereal M",
107
+ "type": "ACTIVE_COMPARATOR",
108
+ "description": "Rice cereal with ferrous fumarate",
109
+ "interventionNames": [
110
+ "Dietary Supplement: ferrous fumarate"
111
+ ]
112
+ }
113
+ ],
114
+ "interventions": [
115
+ {
116
+ "type": "DIETARY_SUPPLEMENT",
117
+ "name": "electrolytic iron",
118
+ "description": "1/4 a cup of cereal fortified with electrolytic iron per day between the ages of 112 days and 280 days of age",
119
+ "armGroupLabels": [
120
+ "Cereal L"
121
+ ]
122
+ },
123
+ {
124
+ "type": "DIETARY_SUPPLEMENT",
125
+ "name": "ferrous fumarate",
126
+ "description": "1/4 cup of cereal fortified with ferrous fumarate to be fed per day between ages of 112 days and 280 days of age",
127
+ "armGroupLabels": [
128
+ "Cereal M"
129
+ ]
130
+ }
131
+ ]
132
+ },
133
+ "outcomesModule": {
134
+ "primaryOutcomes": [
135
+ {
136
+ "measure": "plasma ferritin",
137
+ "timeFrame": "280 days"
138
+ }
139
+ ],
140
+ "secondaryOutcomes": [
141
+ {
142
+ "measure": "hemoglobin",
143
+ "timeFrame": "280"
144
+ }
145
+ ]
146
+ },
147
+ "eligibilityModule": {
148
+ "eligibilityCriteria": "Inclusion Criteria:\n\n* exclusively breastfed\n* birth weight between 2500 and 4200g\n* gestational age \\>36 weeks\n\nExclusion Criteria:\n\n* supplementing formula\n* no iron drops",
149
+ "healthyVolunteers": true,
150
+ "sex": "ALL",
151
+ "minimumAge": "28 Days",
152
+ "maximumAge": "1 Year",
153
+ "stdAges": [
154
+ "CHILD"
155
+ ]
156
+ },
157
+ "contactsLocationsModule": {
158
+ "overallOfficials": [
159
+ {
160
+ "name": "Ekhard E Ziegler, MD",
161
+ "affiliation": "University of Iowa",
162
+ "role": "PRINCIPAL_INVESTIGATOR"
163
+ }
164
+ ],
165
+ "locations": [
166
+ {
167
+ "facility": "University of Iowa",
168
+ "city": "Iowa City",
169
+ "state": "Iowa",
170
+ "zip": "52242",
171
+ "country": "United States",
172
+ "geoPoint": {
173
+ "lat": 41.66113,
174
+ "lon": -91.53017
175
+ }
176
+ }
177
+ ]
178
+ },
179
+ "referencesModule": {
180
+ "references": [
181
+ {
182
+ "pmid": "21178077",
183
+ "type": "DERIVED",
184
+ "citation": "Ziegler EE, Fomon SJ, Nelson SE, Jeter JM, Theuer RC. Dry cereals fortified with electrolytic iron or ferrous fumarate are equally effective in breast-fed infants. J Nutr. 2011 Feb;141(2):243-8. doi: 10.3945/jn.110.127266. Epub 2010 Dec 22."
185
+ }
186
+ ]
187
+ }
188
+ },
189
+ "derivedSection": {
190
+ "miscInfoModule": {
191
+ "versionHolder": "2024-05-03"
192
+ },
193
+ "conditionBrowseModule": {
194
+ "meshes": [
195
+ {
196
+ "id": "D000090463",
197
+ "term": "Iron Deficiencies"
198
+ }
199
+ ],
200
+ "ancestors": [
201
+ {
202
+ "id": "D000019189",
203
+ "term": "Iron Metabolism Disorders"
204
+ },
205
+ {
206
+ "id": "D000008659",
207
+ "term": "Metabolic Diseases"
208
+ }
209
+ ],
210
+ "browseLeaves": [
211
+ {
212
+ "id": "M20857",
213
+ "name": "Anemia, Iron-Deficiency",
214
+ "relevance": "LOW"
215
+ },
216
+ {
217
+ "id": "M2781",
218
+ "name": "Iron Deficiencies",
219
+ "asFound": "Iron Deficiency",
220
+ "relevance": "HIGH"
221
+ },
222
+ {
223
+ "id": "M11639",
224
+ "name": "Metabolic Diseases",
225
+ "relevance": "LOW"
226
+ },
227
+ {
228
+ "id": "M21177",
229
+ "name": "Iron Metabolism Disorders",
230
+ "relevance": "LOW"
231
+ }
232
+ ],
233
+ "browseBranches": [
234
+ {
235
+ "abbrev": "BC15",
236
+ "name": "Blood and Lymph Conditions"
237
+ },
238
+ {
239
+ "abbrev": "BC18",
240
+ "name": "Nutritional and Metabolic Diseases"
241
+ },
242
+ {
243
+ "abbrev": "All",
244
+ "name": "All Conditions"
245
+ }
246
+ ]
247
+ },
248
+ "interventionBrowseModule": {
249
+ "meshes": [
250
+ {
251
+ "id": "C000031621",
252
+ "term": "Ferrous fumarate"
253
+ }
254
+ ],
255
+ "ancestors": [
256
+ {
257
+ "id": "D000014131",
258
+ "term": "Trace Elements"
259
+ },
260
+ {
261
+ "id": "D000018977",
262
+ "term": "Micronutrients"
263
+ },
264
+ {
265
+ "id": "D000045505",
266
+ "term": "Physiological Effects of Drugs"
267
+ }
268
+ ],
269
+ "browseLeaves": [
270
+ {
271
+ "id": "M10533",
272
+ "name": "Iron",
273
+ "relevance": "LOW"
274
+ },
275
+ {
276
+ "id": "M225448",
277
+ "name": "Ferrous fumarate",
278
+ "asFound": "Mouse",
279
+ "relevance": "HIGH"
280
+ },
281
+ {
282
+ "id": "M21009",
283
+ "name": "Micronutrients",
284
+ "relevance": "LOW"
285
+ },
286
+ {
287
+ "id": "M16885",
288
+ "name": "Trace Elements",
289
+ "relevance": "LOW"
290
+ }
291
+ ],
292
+ "browseBranches": [
293
+ {
294
+ "abbrev": "Micro",
295
+ "name": "Micronutrients"
296
+ },
297
+ {
298
+ "abbrev": "All",
299
+ "name": "All Drugs and Chemicals"
300
+ }
301
+ ]
302
+ }
303
+ },
304
+ "hasResults": false
305
+ }
relation_embeddings.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:81f5c180ff1a488b185fb73bb512b3e4402d0fcc6b483d1592a768a6a376a261
3
+ size 5747
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ sqlalchemy-iris==0.13.3
2
+ datasets==2.19.0
3
+ pandas==2.2.0
4
+ pykeen==1.10.2
5
+ rdflib==7.0.0
6
+ scipy==1.13.0
7
+ pyobo==0.10.11
8
+ langchain==0.1.17
9
+ openai==1.25.1
10
+ sentence_transformers==2.7.0
11
+ streamlit-agraph
12
+ streamlit==1.34.0
utils.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %%
2
+ from typing import List, Dict, Any
3
+ import os
4
+ from sqlalchemy import create_engine, text
5
+ import requests
6
+
7
def get_all_diseases_name(engine) -> List[str]:
    """Return one name per row of ``Test.EntityEmbeddings``.

    Args:
        engine: SQLAlchemy engine used to open the connection.

    Returns:
        The second column of every row, skipping rows where that value is
        the literal string ``"nan"``.
    """
    with engine.connect() as conn:
        with conn.begin():
            sql = """
                SELECT * FROM Test.EntityEmbeddings
            """
            result = conn.execute(text(sql))
            data = result.fetchall()

    # NOTE(review): the name is read by position (second column of SELECT *),
    # presumably the `label` column -- confirm against the table definition
    # before reordering or adding columns.
    all_diseases = [row[1] for row in data if row[1] != "nan"]
    return all_diseases
18
+
19
def get_uri_from_name(engine, name: str) -> str:
    """Return the last path segment of the URI stored for the given label.

    The label is sent as a bound parameter, so values containing quotes
    (or any other SQL syntax) are matched literally instead of being
    spliced into the statement.

    Args:
        engine: SQLAlchemy engine used to open the connection.
        name: Exact value to match against the ``label`` column.

    Returns:
        Everything after the final ``/`` of the first matching row's ``uri``.

    Raises:
        IndexError: If no row in ``Test.EntityEmbeddings`` has that label.
    """
    sql = """
        SELECT uri FROM Test.EntityEmbeddings
        WHERE label = :name
    """
    with engine.connect() as conn:
        with conn.begin():
            row = conn.execute(text(sql), {"name": name}).first()

    if row is None:
        # Same exception type an empty result produced before, now with a
        # message that says which label was missing.
        raise IndexError(f"no entity labelled {name!r} in Test.EntityEmbeddings")
    return row[0].split('/')[-1]
52
+
53
def get_most_similar_diseases_from_uri(engine, original_disease_uri: str, threshold: float = 0.8) -> List[tuple]:
    """Find the entities whose embedding is closest to a given entity.

    Args:
        engine: SQLAlchemy engine used to open the connection.
        original_disease_uri: Identifier appended to
            ``http://identifiers.org/medgen/`` to form the URI of the
            reference entity.
        threshold: Only rows whose ``VECTOR_COSINE`` value against the
            reference entity is strictly greater than this are returned.

    Returns:
        At most 10 tuples ``(identifier, label, score)`` ordered by score,
        highest first, where ``identifier`` is the last path segment of the
        matched entity's URI and ``score`` is the ``VECTOR_COSINE`` value as
        returned by the driver. Rows whose label is the literal string
        ``"nan"`` are dropped after the ``TOP 10`` cut, so fewer than 10
        results may come back even when more candidates exist.
    """
    # Both user-controlled values are bound parameters rather than being
    # interpolated into the SQL text.
    sql = """
        SELECT TOP 10 e1.uri AS uri1, e2.uri AS uri2, e1.label AS label1, e2.label AS label2,
                VECTOR_COSINE(e1.embedding, e2.embedding) AS distance
        FROM Test.EntityEmbeddings e1, Test.EntityEmbeddings e2
        WHERE e1.uri = :uri
            AND VECTOR_COSINE(e1.embedding, e2.embedding) > :threshold
            AND e1.uri != e2.uri
        ORDER BY distance DESC
    """
    params = {
        "uri": f"http://identifiers.org/medgen/{original_disease_uri}",
        "threshold": threshold,
    }
    with engine.connect() as conn:
        with conn.begin():
            data = conn.execute(text(sql), params).fetchall()

    # Column positions: 1 = uri2, 3 = label2, 4 = cosine score.
    similar_diseases = [(row[1].split('/')[-1], row[3], row[4]) for row in data if row[3] != "nan"]
    return similar_diseases
70
+
71
def get_clinical_record_info(clinical_record_id: str, timeout: float = 30.0) -> Dict[str, Any]:
    """Fetch one study from the ClinicalTrials.gov v2 API as parsed JSON.

    Args:
        clinical_record_id: Study identifier placed in the URL path,
            e.g. ``"NCT00841061"``.
        timeout: Seconds to wait for the server before giving up. Without
            a timeout ``requests`` may block indefinitely.

    Returns:
        The decoded JSON body of the response.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If the server does not answer within ``timeout``.
    """
    # Equivalent request:
    #   curl -X GET "https://clinicaltrials.gov/api/v2/studies/NCT00841061" \
    #        -H "accept: application/json"
    request_url = f"https://clinicaltrials.gov/api/v2/studies/{clinical_record_id}"
    response = requests.get(request_url, headers={"accept": "application/json"}, timeout=timeout)
    # Fail loudly on an error status instead of trying to decode an error page.
    response.raise_for_status()
    return response.json()
78
+
79
def get_clinical_records_by_ids(clinical_record_ids: List[str]) -> List[Dict[str, Any]]:
    """Fetch every listed study, one request per identifier.

    Args:
        clinical_record_ids: Study identifiers to look up.

    Returns:
        The result of ``get_clinical_record_info`` for each identifier, in
        the same order as the input.
    """
    return [get_clinical_record_info(record_id) for record_id in clinical_record_ids]
85
+
86
+
87
if __name__ == "__main__":
    # Manual smoke test against an IRIS instance; only the host name is
    # taken from the environment, everything else is fixed.
    username = 'demo'
    password = 'demo'
    hostname = os.getenv('IRIS_HOSTNAME', 'localhost')
    port = '1972'
    namespace = 'USER'
    CONNECTION_STRING = f"iris://{username}:{password}@{hostname}:{port}/{namespace}"

    # Bound before the try block so the calls after it can always use it.
    engine = create_engine(CONNECTION_STRING)

    try:
        # The engine is the first positional argument; calling with only the
        # identifier raised a TypeError that the handler below hid.
        diseases = get_most_similar_diseases_from_uri(engine, 'C1843013')
        for disease in diseases:
            print(disease)
    except Exception as e:
        # Best-effort demo: report the failure and continue with the
        # remaining checks.
        print(e)

    print(get_uri_from_name(engine, 'Alzheimer disease 3'))

    clinical_record_info = get_clinical_records_by_ids(['NCT00841061'])
    print(clinical_record_info)
107
+