Upload 164 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +7 -0
- README.md +80 -14
- __init__.py +3 -0
- __pycache__/get_R_group_sub_agent.cpython-310.pyc +0 -0
- __pycache__/get_molecular_agent.cpython-310.pyc +0 -0
- __pycache__/get_reaction_agent.cpython-310.pyc +0 -0
- __pycache__/main.cpython-310.pyc +0 -0
- app.ipynb +251 -0
- app.py +202 -0
- apt.txt +2 -0
- chemiener/__init__.py +1 -0
- chemiener/__pycache__/__init__.cpython-310.pyc +0 -0
- chemiener/__pycache__/__init__.cpython-38.pyc +0 -0
- chemiener/__pycache__/dataset.cpython-310.pyc +0 -0
- chemiener/__pycache__/dataset.cpython-38.pyc +0 -0
- chemiener/__pycache__/interface.cpython-310.pyc +0 -0
- chemiener/__pycache__/interface.cpython-38.pyc +0 -0
- chemiener/__pycache__/model.cpython-310.pyc +0 -0
- chemiener/__pycache__/model.cpython-38.pyc +0 -0
- chemiener/__pycache__/utils.cpython-310.pyc +0 -0
- chemiener/__pycache__/utils.cpython-38.pyc +0 -0
- chemiener/dataset.py +172 -0
- chemiener/interface.py +124 -0
- chemiener/main.py +345 -0
- chemiener/model.py +14 -0
- chemiener/utils.py +23 -0
- chemietoolkit/__init__.py +1 -0
- chemietoolkit/__pycache__/__init__.cpython-310.pyc +0 -0
- chemietoolkit/__pycache__/__init__.cpython-38.pyc +0 -0
- chemietoolkit/__pycache__/chemrxnextractor.cpython-310.pyc +0 -0
- chemietoolkit/__pycache__/chemrxnextractor.cpython-38.pyc +0 -0
- chemietoolkit/__pycache__/interface.cpython-310.pyc +0 -0
- chemietoolkit/__pycache__/interface.cpython-38.pyc +0 -0
- chemietoolkit/__pycache__/tableextractor.cpython-310.pyc +0 -0
- chemietoolkit/__pycache__/utils.cpython-310.pyc +0 -0
- chemietoolkit/chemrxnextractor.py +107 -0
- chemietoolkit/interface.py +749 -0
- chemietoolkit/tableextractor.py +340 -0
- chemietoolkit/utils.py +1018 -0
- examples/exp.png +3 -0
- examples/image.webp +0 -0
- examples/molecules1.png +0 -0
- examples/molecules2.png +0 -0
- examples/rdkit.png +0 -0
- examples/reaction0.png +3 -0
- examples/reaction1.jpg +0 -0
- examples/reaction2.png +0 -0
- examples/reaction3.png +0 -0
- examples/reaction4.png +3 -0
- examples/reaction5.png +0 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
examples/exp.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
examples/reaction0.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
examples/reaction4.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
examples/template1.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
molnextr/indigo/lib/Linux/x64/libbingo.so filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
molnextr/indigo/lib/Linux/x64/libindigo-renderer.so filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
molnextr/indigo/lib/Linux/x64/libindigo.so filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,14 +1,80 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ChemEagle
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
## :sparkles: Highlights
|
| 5 |
+
<p align="justify">
|
| 6 |
+
In this work, we present ChemEagle, a multimodal large language model (MLLM)-based multi-agent system that integrates diverse chemical information extraction tools to extract multimodal chemical reactions. By integrating 7 expert-designed tools and 6 chemical information extraction agents, ChemEagle not only processes individual modalities but also utilizes MLLMs' reasoning capabilities to unify extracted data, ensuring more accurate and comprehensive reaction representations. By bridging multimodal gaps, our approach significantly improves automated chemical knowledge extraction, facilitating more robust AI-driven chemical research.
|
| 7 |
+
|
| 8 |
+
[comment]: <> ()
|
| 9 |
+

|
| 10 |
+
<div align="center">
|
| 11 |
+
An example workflow of our ChemEagle. It illustrates how ChemEagle extracts and structures multimodal chemical reaction data. Each agent handles specific tasks, from reaction image parsing and molecular recognition to SMILES reconstruction and condition role interpretation, ensuring accurate and structured data integration.
|
| 12 |
+
</div>
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
## 🤗 Multimodal chemical information extraction using [ChemEagle.Web](https://huggingface.co/spaces/CYF200127/ChemEagle)
|
| 16 |
+
|
| 17 |
+
Go to our [ChemEagle.Web demo](https://huggingface.co/spaces/CYF200127/ChemEagle) to directly use our tool online!
|
| 18 |
+
|
| 19 |
+
The input is a multimodal chemical reaction image:
|
| 20 |
+

|
| 21 |
+
<div align="center",width="100">
|
| 22 |
+
</div>
|
| 23 |
+
|
| 24 |
+
The output dictionary is a complete reaction list with reactant SMILES, product SMILES, detailed conditions and additional information for every reaction in the image:
|
| 25 |
+
|
| 26 |
+
```
|
| 27 |
+
{
|
| 28 |
+
"reactions":[
|
| 29 |
+
{"reaction_id":"0_1",###Reaction template
|
| 30 |
+
"reactants":[{"smiles":"*C(*)=O","label":"1"},{"smiles":"Cc1ccc(S(=O)(=O)N2OC2c2ccccc2Cl)cc1","label":"2"}],
|
| 31 |
+
"conditions":[{"role":"reagent","text":"10 mol% B17 orB27","smiles":"C(C=CC=C1)=C1C[N+]2=CN3[C@H](C(C4=CC=CC=C4)(C5=CC=CC=C5)O[Si](C)(C)C(C)(C)C)CCC3=N2.F[B-](F)(F)F","label":"B17"},{"role":"reagent","text":"10 mol% B17 or B27","smiles":"CCCC(C=CC=C1)=C1[N+]2=CN3[C@H](C(C1=CC(=CC(=C1C(F)(F)F)C(F)(F)F))(C1=CC(=CC(=C1C(F)(F)F)C(F)(F)F))O)CCC3=N2.F[B-](F)(F)F","label":"B27"},{"role":"reagent","text":"10 mol% Cs2CO3","smiles":"[Cs+].[Cs+].[O-]C(=O)[O-]"},{"role":"solvent","text":"PhMe","smiles":"Cc1ccccc1"},{"role":"temperature","text":"rt"},{"role":"yield","text":"38-78%"}],
|
| 32 |
+
"products":[{"smiles":"*C1*O[C@H](c2ccccc2Cl)N(S(=O)(=O)c2ccc(C)cc2)C1=O","label":"3"}]},
|
| 33 |
+
|
| 34 |
+
{"reaction_id":"1_1",###Detailed reaction
|
| 35 |
+
"reactants":[{"smiles":"CCC(=O)c1ccccc1","label":"1a"},{"smiles":"Cc1ccc(S(=O)(=O)N2OC2c2ccccc2Cl)cc1","label":"2"}],
|
| 36 |
+
"conditions":[{"role":"reagent","text":"10 mol% B17 or B27","smiles":"C(C=CC=C1)=C1C[N+]2=CN3[C@H](C(C4=CC=CC=C4)(C5=CC=CC=C5)O[Si](C)(C)C(C)(C)C)CCC3=N2.F[B-](F)(F)F","label":"B17"},{"role":"reagent","text":"10 mol% B17 or B27","smiles":"CCCC(C=CC=C1)=C1[N+]2=CN3[C@H](C(C1=CC(=CC(=C1C(F)(F)F)C(F)(F)F))(C1=CC(=CC(=C1C(F)(F)F)C(F)(F)F))O)CCC3=N2.F[B-](F)(F)F","label":"B27"},{"role":"reagent","text":"10 mol% Cs2CO3","smiles":"[Cs+].[Cs+].[O-]C(=O)[O-]"},{"role":"solvent","text":"PhMe","smiles":"Cc1ccccc1"},{"role":"temperature","text":"rt"},{"role":"yield","text":"71%"}],
|
| 37 |
+
"products":[{"smiles":"CC[C@]1(c2ccccc2)O[C@H](c2ccccc2Cl)N(S(=O)(=O)c2ccc(C)cc2)C1=O","label":"3a"}],
|
| 38 |
+
"additional_info":[{"text":"14:1 dr, 91% ee"}]},
|
| 39 |
+
|
| 40 |
+
{"reaction_id":"2_1",... ###More detailed reactions}
|
| 41 |
+
]
|
| 42 |
+
}
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
## :rocket: Using the code and the model
|
| 46 |
+
### Using the code
|
| 47 |
+
Clone the following repositories:
|
| 48 |
+
```
|
| 49 |
+
git clone https://github.com/CYF2000127/ChemEagle
|
| 50 |
+
```
|
| 51 |
+
### Example usage of the model
|
| 52 |
+
1. First create and activate a [conda](https://numdifftools.readthedocs.io/en/stable/how-to/create_virtual_env_with_conda.html) environment with the following command in a Linux, Windows, or MacOS environment (Linux is the most recommended):
|
| 53 |
+
```
|
| 54 |
+
conda create -n chemeagle python=3.10
|
| 55 |
+
conda activate chemeagle
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
2. Then install requirements:
|
| 59 |
+
```
|
| 60 |
+
pip install -r requirements.txt
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
3. Set up your API keys in your environment.
|
| 64 |
+
```
|
| 65 |
+
export API_KEY=your-openai-api-key
|
| 66 |
+
```
|
| 67 |
+
Alternatively, add your API keys in the [api_key.txt](./api_key.txt)
|
| 68 |
+
|
| 69 |
+
4. Run the following code to extract multimodal chemical reactions from a multimodal reaction image:
|
| 70 |
+
```python
|
| 71 |
+
from main import ChemEagle
|
| 72 |
+
image_path = './examples/1.png'
|
| 73 |
+
results = ChemEagle(image_path)
|
| 74 |
+
print(results)
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
## :warning: Acknowledgement
|
| 78 |
+
Our code is based on [MolNexTR](https://github.com/CYF2000127/MolNexTR), [MolScribe](https://github.com/thomas0809/MolScribe), [RxnIM](https://github.com/CYF2000127/RxnIM) [RxnScribe](https://github.com/thomas0809/RxNScribe), [ChemNER](https://github.com/Ozymandias314/ChemIENER), [ChemRxnExtractor](https://github.com/jiangfeng1124/ChemRxnExtractor), [AutoAgents](https://github.com/Link-AGI/AutoAgents) and [OpenAI](https://openai.com/) thanks their great jobs!
|
| 79 |
+
|
| 80 |
+
|
__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__version__ = "0.1.0"
|
| 2 |
+
__author__ = 'Alex Wang'
|
| 3 |
+
__credits__ = 'CSAIL'
|
__pycache__/get_R_group_sub_agent.cpython-310.pyc
ADDED
|
Binary file (13.5 kB). View file
|
|
|
__pycache__/get_molecular_agent.cpython-310.pyc
ADDED
|
Binary file (8.85 kB). View file
|
|
|
__pycache__/get_reaction_agent.cpython-310.pyc
ADDED
|
Binary file (7.69 kB). View file
|
|
|
__pycache__/main.cpython-310.pyc
ADDED
|
Binary file (4.54 kB). View file
|
|
|
app.ipynb
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"id": "d13d3631",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"name": "stderr",
|
| 11 |
+
"output_type": "stream",
|
| 12 |
+
"text": [
|
| 13 |
+
"Using CPU. Note: This module is much faster with a GPU.\n"
|
| 14 |
+
]
|
| 15 |
+
},
|
| 16 |
+
{
|
| 17 |
+
"ename": "ValueError",
|
| 18 |
+
"evalue": "Please set API_KEY",
|
| 19 |
+
"output_type": "error",
|
| 20 |
+
"traceback": [
|
| 21 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 22 |
+
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
|
| 23 |
+
"Cell \u001b[0;32mIn[5], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mgradio\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mgr\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mjson\u001b[39;00m\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmain\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m ChemEagle \u001b[38;5;66;03m# 假设内部已经管理 API Key\u001b[39;00m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mrdkit\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Chem\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mrdkit\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mChem\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m rdChemReactions, Draw, AllChem\n",
|
| 24 |
+
"File \u001b[0;32m/media/chenyufan/F/ChemEagle-hf/main.py:10\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mPIL\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Image\n\u001b[1;32m 9\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mjson\u001b[39;00m\n\u001b[0;32m---> 10\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mget_molecular_agent\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m process_reaction_image_with_multiple_products_and_text_correctR\n\u001b[1;32m 11\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mget_reaction_agent\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m get_reaction_withatoms_correctR\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mget_R_group_sub_agent\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m process_reaction_image_with_table_R_group, process_reaction_image_with_product_variant_R_group,get_full_reaction,get_multi_molecular_full\n",
|
| 25 |
+
"File \u001b[0;32m/media/chenyufan/F/ChemEagle-hf/get_molecular_agent.py:35\u001b[0m\n\u001b[1;32m 33\u001b[0m API_KEY \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mgetenv(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAPI_KEY\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 34\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m API_KEY:\n\u001b[0;32m---> 35\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPlease set API_KEY\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 36\u001b[0m AZURE_ENDPOINT \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mgetenv(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAZURE_ENDPOINT\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 39\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_multi_molecular\u001b[39m(image_path: \u001b[38;5;28mstr\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mlist\u001b[39m:\n",
|
| 26 |
+
"\u001b[0;31mValueError\u001b[0m: Please set API_KEY"
|
| 27 |
+
]
|
| 28 |
+
}
|
| 29 |
+
],
|
| 30 |
+
"source": [
|
| 31 |
+
"import gradio as gr\n",
|
| 32 |
+
"import json\n",
|
| 33 |
+
"from main import ChemEagle # 假设内部已经管理 API Key\n",
|
| 34 |
+
"from rdkit import Chem\n",
|
| 35 |
+
"from rdkit.Chem import rdChemReactions, Draw, AllChem\n",
|
| 36 |
+
"from rdkit.Chem.Draw import rdMolDraw2D\n",
|
| 37 |
+
"import cairosvg\n",
|
| 38 |
+
"import re\n",
|
| 39 |
+
"\n",
|
| 40 |
+
"example_diagram = \"examples/exp.png\"\n",
|
| 41 |
+
"rdkit_image = \"examples/rdkit.png\"\n",
|
| 42 |
+
"\n",
|
| 43 |
+
"# 解析 ChemEagle 返回的结构化数据\n",
|
| 44 |
+
"def parse_reactions(output_json):\n",
|
| 45 |
+
" if isinstance(output_json, str):\n",
|
| 46 |
+
" reactions_data = json.loads(output_json)\n",
|
| 47 |
+
" else:\n",
|
| 48 |
+
" reactions_data = output_json\n",
|
| 49 |
+
" reactions_list = reactions_data.get(\"reactions\", [])\n",
|
| 50 |
+
" detailed_output = []\n",
|
| 51 |
+
" smiles_output = []\n",
|
| 52 |
+
"\n",
|
| 53 |
+
" for reaction in reactions_list:\n",
|
| 54 |
+
" reaction_id = reaction.get(\"reaction_id\", \"Unknown ID\")\n",
|
| 55 |
+
" reactants = [r.get(\"smiles\", \"Unknown\") for r in reaction.get(\"reactants\", [])]\n",
|
| 56 |
+
" conditions = [\n",
|
| 57 |
+
" f\"<span style='color:red'>{c.get('smiles', c.get('text', 'Unknown'))}[{c.get('role', 'Unknown')}]</span>\"\n",
|
| 58 |
+
" for c in reaction.get(\"condition\", [])\n",
|
| 59 |
+
" ]\n",
|
| 60 |
+
" conditions_1 = [\n",
|
| 61 |
+
" f\"<span style='color:black'>{c.get('smiles', c.get('text', 'Unknown'))}[{c.get('role', 'Unknown')}]</span>\"\n",
|
| 62 |
+
" for c in reaction.get(\"condition\", [])\n",
|
| 63 |
+
" ]\n",
|
| 64 |
+
" products = [f\"<span style='color:orange'>{p.get('smiles', 'Unknown')}</span>\" for p in reaction.get(\"products\", [])]\n",
|
| 65 |
+
" products_1 = [f\"<span style='color:black'>{p.get('smiles', 'Unknown')}</span>\" for p in reaction.get(\"products\", [])]\n",
|
| 66 |
+
" products_2 = [r.get(\"smiles\", \"Unknown\") for r in reaction.get(\"products\", [])]\n",
|
| 67 |
+
" additional = reaction.get(\"additional_info\", [])\n",
|
| 68 |
+
" additional_str = [str(x) for x in additional if x]\n",
|
| 69 |
+
"\n",
|
| 70 |
+
" tail = conditions_1 + additional_str\n",
|
| 71 |
+
" tail_str = \", \".join(tail)\n",
|
| 72 |
+
" full_reaction = f\"{'.'.join(reactants)}>>{'.'.join(products_1)} | {tail_str}\"\n",
|
| 73 |
+
" full_reaction = f\"<span style='color:black'>{full_reaction}</span>\"\n",
|
| 74 |
+
"\n",
|
| 75 |
+
" reaction_output = f\"<b>Reaction: </b> {reaction_id}<br>\"\n",
|
| 76 |
+
" reaction_output += f\" Reactants: <span style='color:blue'>{', '.join(reactants)}</span><br>\"\n",
|
| 77 |
+
" reaction_output += f\" Conditions: {', '.join(conditions)}<br>\"\n",
|
| 78 |
+
" reaction_output += f\" Products: {', '.join(products)}<br>\"\n",
|
| 79 |
+
" reaction_output += f\" additional_info: {', '.join(additional_str)}<br>\"\n",
|
| 80 |
+
" reaction_output += f\" <b>Full Reaction:</b> {full_reaction}<br><br>\"\n",
|
| 81 |
+
" detailed_output.append(reaction_output)\n",
|
| 82 |
+
"\n",
|
| 83 |
+
" reaction_smiles = f\"{'.'.join(reactants)}>>{'.'.join(products_2)}\"\n",
|
| 84 |
+
" smiles_output.append(reaction_smiles)\n",
|
| 85 |
+
"\n",
|
| 86 |
+
" return detailed_output, smiles_output\n",
|
| 87 |
+
"\n",
|
| 88 |
+
"def process_chem_image(image):\n",
|
| 89 |
+
" image_path = \"temp_image.png\"\n",
|
| 90 |
+
" image.save(image_path)\n",
|
| 91 |
+
"\n",
|
| 92 |
+
" chemeagle_result = ChemEagle(image_path)\n",
|
| 93 |
+
" detailed, smiles = parse_reactions(chemeagle_result)\n",
|
| 94 |
+
"\n",
|
| 95 |
+
" json_path = \"output.json\"\n",
|
| 96 |
+
" with open(json_path, 'w') as jf:\n",
|
| 97 |
+
" json.dump(chemeagle_result, jf, indent=2)\n",
|
| 98 |
+
"\n",
|
| 99 |
+
" return \"\\n\\n\".join(detailed), smiles, example_diagram, json_path\n",
|
| 100 |
+
"\n",
|
| 101 |
+
"with gr.Blocks() as demo:\n",
|
| 102 |
+
" gr.Markdown(\n",
|
| 103 |
+
" \"\"\"\n",
|
| 104 |
+
" <center><h1>ChemEagle: A Multi-Agent System for Multimodal Chemical Information Extraction</h1></center>\n",
|
| 105 |
+
" Upload a multimodal reaction image to extract multimodal chemical information.\n",
|
| 106 |
+
" \"\"\"\n",
|
| 107 |
+
" )\n",
|
| 108 |
+
"\n",
|
| 109 |
+
" with gr.Row():\n",
|
| 110 |
+
" with gr.Column(scale=1):\n",
|
| 111 |
+
" image_input = gr.Image(type=\"pil\", label=\"Upload a multimodal reaction image\")\n",
|
| 112 |
+
" with gr.Row():\n",
|
| 113 |
+
" clear_btn = gr.Button(\"Clear\")\n",
|
| 114 |
+
" run_btn = gr.Button(\"Run\", elem_id=\"submit-btn\")\n",
|
| 115 |
+
"\n",
|
| 116 |
+
" with gr.Column(scale=1):\n",
|
| 117 |
+
" gr.Markdown(\"### Parsed Reactions\")\n",
|
| 118 |
+
" reaction_output = gr.HTML(label=\"Detailed Reaction Output\")\n",
|
| 119 |
+
" gr.Markdown(\"### Schematic Diagram\")\n",
|
| 120 |
+
" schematic_diagram = gr.Image(value=example_diagram, label=\"示意图\")\n",
|
| 121 |
+
"\n",
|
| 122 |
+
" with gr.Column(scale=1):\n",
|
| 123 |
+
" gr.Markdown(\"### Machine-readable Output\")\n",
|
| 124 |
+
" smiles_output = gr.Textbox(\n",
|
| 125 |
+
" label=\"Reaction SMILES\",\n",
|
| 126 |
+
" show_copy_button=True,\n",
|
| 127 |
+
" interactive=False,\n",
|
| 128 |
+
" visible=False\n",
|
| 129 |
+
" )\n",
|
| 130 |
+
"\n",
|
| 131 |
+
" @gr.render(inputs=smiles_output)\n",
|
| 132 |
+
" def show_split(inputs):\n",
|
| 133 |
+
" if not inputs or (isinstance(inputs, str) and inputs.strip() == \"\"):\n",
|
| 134 |
+
" return gr.Textbox(label=\"SMILES of Reaction i\"), gr.Image(value=rdkit_image, label=\"RDKit Image of Reaction i\", height=100)\n",
|
| 135 |
+
" smiles_list = inputs.split(\",\")\n",
|
| 136 |
+
" smiles_list = [re.sub(r\"^\\s*\\[?'?|']?\\s*$\", \"\", item) for item in smiles_list]\n",
|
| 137 |
+
" components = []\n",
|
| 138 |
+
" for i, smiles in enumerate(smiles_list):\n",
|
| 139 |
+
" smiles_clean = smiles.replace('\"', '').replace(\"'\", \"\").replace(\"[\", \"\").replace(\"]\", \"\")\n",
|
| 140 |
+
" # 始终加入 SMILES 文本框\n",
|
| 141 |
+
" components.append(gr.Textbox(value=smiles_clean, label=f\"SMILES of Reaction {i+1}\", show_copy_button=True, interactive=False))\n",
|
| 142 |
+
" try:\n",
|
| 143 |
+
" rxn = rdChemReactions.ReactionFromSmarts(smiles_clean, useSmiles=True)\n",
|
| 144 |
+
" if not rxn:\n",
|
| 145 |
+
" continue\n",
|
| 146 |
+
" new_rxn = AllChem.ChemicalReaction()\n",
|
| 147 |
+
" for mol in rxn.GetReactants():\n",
|
| 148 |
+
" mol = Chem.MolFromMolBlock(Chem.MolToMolBlock(mol))\n",
|
| 149 |
+
" new_rxn.AddReactantTemplate(mol)\n",
|
| 150 |
+
" for mol in rxn.GetProducts():\n",
|
| 151 |
+
" mol = Chem.MolFromMolBlock(Chem.MolToMolBlock(mol))\n",
|
| 152 |
+
" new_rxn.AddProductTemplate(mol)\n",
|
| 153 |
+
" cleaned_rxn = new_rxn\n",
|
| 154 |
+
"\n",
|
| 155 |
+
" # 移除原子映射\n",
|
| 156 |
+
" for react in cleaned_rxn.GetReactants():\n",
|
| 157 |
+
" for atom in react.GetAtoms(): atom.SetAtomMapNum(0)\n",
|
| 158 |
+
" for prod in cleaned_rxn.GetProducts():\n",
|
| 159 |
+
" for atom in prod.GetAtoms(): atom.SetAtomMapNum(0)\n",
|
| 160 |
+
"\n",
|
| 161 |
+
" # 计算键长参考\n",
|
| 162 |
+
" ref_rxn = cleaned_rxn\n",
|
| 163 |
+
" react0 = ref_rxn.GetReactantTemplate(0)\n",
|
| 164 |
+
" react1 = ref_rxn.GetReactantTemplate(1) if ref_rxn.GetNumReactantTemplates() > 1 else None\n",
|
| 165 |
+
" if react0.GetNumBonds() > 0:\n",
|
| 166 |
+
" bond_len = Draw.MeanBondLength(react0)\n",
|
| 167 |
+
" elif react1 and react1.GetNumBonds() > 0:\n",
|
| 168 |
+
" bond_len = Draw.MeanBondLength(react1)\n",
|
| 169 |
+
" else:\n",
|
| 170 |
+
" bond_len = 1.0\n",
|
| 171 |
+
"\n",
|
| 172 |
+
" # 绘图\n",
|
| 173 |
+
" drawer = rdMolDraw2D.MolDraw2DSVG(-1, -1)\n",
|
| 174 |
+
" dopts = drawer.drawOptions()\n",
|
| 175 |
+
" dopts.padding = 0.1\n",
|
| 176 |
+
" dopts.includeRadicals = True\n",
|
| 177 |
+
" Draw.SetACS1996Mode(dopts, bond_len * 0.55)\n",
|
| 178 |
+
" dopts.bondLineWidth = 1.5\n",
|
| 179 |
+
" drawer.DrawReaction(cleaned_rxn)\n",
|
| 180 |
+
" drawer.FinishDrawing()\n",
|
| 181 |
+
" svg = drawer.GetDrawingText()\n",
|
| 182 |
+
" svg_file = f\"reaction_{i+1}.svg\"\n",
|
| 183 |
+
" with open(svg_file, \"w\") as f: f.write(svg)\n",
|
| 184 |
+
" png_file = f\"reaction_{i+1}.png\"\n",
|
| 185 |
+
" cairosvg.svg2png(url=svg_file, write_to=png_file)\n",
|
| 186 |
+
" components.append(gr.Image(value=png_file, label=f\"RDKit Image of Reaction {i+1}\"))\n",
|
| 187 |
+
" except Exception as e:\n",
|
| 188 |
+
" print(f\"Failed to draw reaction {i+1} for SMILES '{smiles_clean}': {e}\")\n",
|
| 189 |
+
" # 绘图失败则跳过\n",
|
| 190 |
+
" return components\n",
|
| 191 |
+
"\n",
|
| 192 |
+
" download_json = gr.File(label=\"Download JSON File\")\n",
|
| 193 |
+
"\n",
|
| 194 |
+
" gr.Examples(\n",
|
| 195 |
+
" examples=[\n",
|
| 196 |
+
" [\"examples/reaction1.jpg\"],\n",
|
| 197 |
+
" [\"examples/reaction2.png\"],\n",
|
| 198 |
+
" [\"examples/reaction3.png\"],\n",
|
| 199 |
+
" [\"examples/reaction4.png\"],\n",
|
| 200 |
+
" ],\n",
|
| 201 |
+
" inputs=[image_input],\n",
|
| 202 |
+
" outputs=[reaction_output, smiles_output, schematic_diagram, download_json],\n",
|
| 203 |
+
" cache_examples=False,\n",
|
| 204 |
+
" examples_per_page=4,\n",
|
| 205 |
+
" )\n",
|
| 206 |
+
"\n",
|
| 207 |
+
" clear_btn.click(\n",
|
| 208 |
+
" lambda: (None, None, None, None),\n",
|
| 209 |
+
" inputs=[],\n",
|
| 210 |
+
" outputs=[image_input, reaction_output, smiles_output, download_json]\n",
|
| 211 |
+
" )\n",
|
| 212 |
+
" run_btn.click(\n",
|
| 213 |
+
" process_chem_image,\n",
|
| 214 |
+
" inputs=[image_input],\n",
|
| 215 |
+
" outputs=[reaction_output, smiles_output, schematic_diagram, download_json]\n",
|
| 216 |
+
" )\n",
|
| 217 |
+
"\n",
|
| 218 |
+
" demo.css = \"\"\"\n",
|
| 219 |
+
" #submit-btn {\n",
|
| 220 |
+
" background-color: #FF914D;\n",
|
| 221 |
+
" color: white;\n",
|
| 222 |
+
" font-weight: bold;\n",
|
| 223 |
+
" }\n",
|
| 224 |
+
" \"\"\"\n",
|
| 225 |
+
"\n",
|
| 226 |
+
" demo.launch()\n"
|
| 227 |
+
]
|
| 228 |
+
}
|
| 229 |
+
],
|
| 230 |
+
"metadata": {
|
| 231 |
+
"kernelspec": {
|
| 232 |
+
"display_name": "openchemie",
|
| 233 |
+
"language": "python",
|
| 234 |
+
"name": "python3"
|
| 235 |
+
},
|
| 236 |
+
"language_info": {
|
| 237 |
+
"codemirror_mode": {
|
| 238 |
+
"name": "ipython",
|
| 239 |
+
"version": 3
|
| 240 |
+
},
|
| 241 |
+
"file_extension": ".py",
|
| 242 |
+
"mimetype": "text/x-python",
|
| 243 |
+
"name": "python",
|
| 244 |
+
"nbconvert_exporter": "python",
|
| 245 |
+
"pygments_lexer": "ipython3",
|
| 246 |
+
"version": "3.10.14"
|
| 247 |
+
}
|
| 248 |
+
},
|
| 249 |
+
"nbformat": 4,
|
| 250 |
+
"nbformat_minor": 5
|
| 251 |
+
}
|
app.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import json
|
| 3 |
+
from main import ChemEagle # 假设内部已经管理 API Key
|
| 4 |
+
from rdkit import Chem
|
| 5 |
+
from rdkit.Chem import rdChemReactions, Draw, AllChem
|
| 6 |
+
from rdkit.Chem.Draw import rdMolDraw2D
|
| 7 |
+
import cairosvg
|
| 8 |
+
import re
|
| 9 |
+
|
| 10 |
+
example_diagram = "examples/exp.png"
|
| 11 |
+
rdkit_image = "examples/rdkit.png"
|
| 12 |
+
|
| 13 |
+
# 解析 ChemEagle 返回的结构化数据
|
| 14 |
+
def parse_reactions(output_json):
|
| 15 |
+
if isinstance(output_json, str):
|
| 16 |
+
reactions_data = json.loads(output_json)
|
| 17 |
+
else:
|
| 18 |
+
reactions_data = output_json
|
| 19 |
+
reactions_list = reactions_data.get("reactions", [])
|
| 20 |
+
detailed_output = []
|
| 21 |
+
smiles_output = []
|
| 22 |
+
|
| 23 |
+
for reaction in reactions_list:
|
| 24 |
+
reaction_id = reaction.get("reaction_id", "Unknown ID")
|
| 25 |
+
reactants = [r.get("smiles", "Unknown") for r in reaction.get("reactants", [])]
|
| 26 |
+
conditions = [
|
| 27 |
+
f"<span style='color:red'>{c.get('smiles', c.get('text', 'Unknown'))}[{c.get('role', 'Unknown')}]</span>"
|
| 28 |
+
for c in reaction.get("condition", [])
|
| 29 |
+
]
|
| 30 |
+
conditions_1 = [
|
| 31 |
+
f"<span style='color:black'>{c.get('smiles', c.get('text', 'Unknown'))}[{c.get('role', 'Unknown')}]</span>"
|
| 32 |
+
for c in reaction.get("condition", [])
|
| 33 |
+
]
|
| 34 |
+
products = [f"<span style='color:orange'>{p.get('smiles', 'Unknown')}</span>" for p in reaction.get("products", [])]
|
| 35 |
+
products_1 = [f"<span style='color:black'>{p.get('smiles', 'Unknown')}</span>" for p in reaction.get("products", [])]
|
| 36 |
+
products_2 = [r.get("smiles", "Unknown") for r in reaction.get("products", [])]
|
| 37 |
+
additional = reaction.get("additional_info", [])
|
| 38 |
+
additional_str = [str(x) for x in additional if x]
|
| 39 |
+
|
| 40 |
+
tail = conditions_1 + additional_str
|
| 41 |
+
tail_str = ", ".join(tail)
|
| 42 |
+
full_reaction = f"{'.'.join(reactants)}>>{'.'.join(products_1)} | {tail_str}"
|
| 43 |
+
full_reaction = f"<span style='color:black'>{full_reaction}</span>"
|
| 44 |
+
|
| 45 |
+
reaction_output = f"<b>Reaction: </b> {reaction_id}<br>"
|
| 46 |
+
reaction_output += f" Reactants: <span style='color:blue'>{', '.join(reactants)}</span><br>"
|
| 47 |
+
reaction_output += f" Conditions: {', '.join(conditions)}<br>"
|
| 48 |
+
reaction_output += f" Products: {', '.join(products)}<br>"
|
| 49 |
+
reaction_output += f" additional_info: {', '.join(additional_str)}<br>"
|
| 50 |
+
reaction_output += f" <b>Full Reaction:</b> {full_reaction}<br><br>"
|
| 51 |
+
detailed_output.append(reaction_output)
|
| 52 |
+
|
| 53 |
+
reaction_smiles = f"{'.'.join(reactants)}>>{'.'.join(products_2)}"
|
| 54 |
+
smiles_output.append(reaction_smiles)
|
| 55 |
+
|
| 56 |
+
return detailed_output, smiles_output
|
| 57 |
+
|
| 58 |
+
def process_chem_image(image):
|
| 59 |
+
image_path = "temp_image.png"
|
| 60 |
+
image.save(image_path)
|
| 61 |
+
|
| 62 |
+
chemeagle_result = ChemEagle(image_path)
|
| 63 |
+
detailed, smiles = parse_reactions(chemeagle_result)
|
| 64 |
+
|
| 65 |
+
json_path = "output.json"
|
| 66 |
+
with open(json_path, 'w') as jf:
|
| 67 |
+
json.dump(chemeagle_result, jf, indent=2)
|
| 68 |
+
|
| 69 |
+
return "\n\n".join(detailed), smiles, example_diagram, json_path
|
| 70 |
+
|
| 71 |
+
# Gradio front-end for ChemEAGLE.
# NOTE(review): reconstructed from a mangled diff view — the exact column
# placement of the @gr.render callback and the download widget is assumed;
# confirm against the running app.
with gr.Blocks() as demo:
    gr.Markdown(
        """
<center><h1>ChemEAGLE: A Multi-Agent System for Information Extraction from the Chemical Literature</h1></center>
Upload a chemical graphics to extract machine-readable chemical data.
"""
    )

    with gr.Row():
        # Left column: the input image plus the action buttons.
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload a multimodal reaction image")
            with gr.Row():
                clear_btn = gr.Button("Clear")
                run_btn = gr.Button("Run", elem_id="submit-btn")

        # Middle column: human-readable parse plus the schematic diagram.
        with gr.Column(scale=1):
            gr.Markdown("### Parsed Reactions")
            reaction_output = gr.HTML(label="Detailed Reaction Output")
            gr.Markdown("### Schematic Diagram")
            schematic_diagram = gr.Image(value=example_diagram, label="示意图")

        # Right column: machine-readable output. The textbox is hidden; it
        # only feeds the @gr.render callback below.
        with gr.Column(scale=1):
            gr.Markdown("### Machine-readable Output")
            smiles_output = gr.Textbox(
                label="Reaction SMILES",
                show_copy_button=True,
                interactive=False,
                visible=False,
            )

            @gr.render(inputs=smiles_output)
            def show_split(inputs):
                """Render one copyable textbox and one RDKit depiction per reaction."""
                if not inputs or (isinstance(inputs, str) and inputs.strip() == ""):
                    # Placeholder widgets while there is nothing to show.
                    return gr.Textbox(label="SMILES of Reaction i"), gr.Image(value=rdkit_image, label="RDKit Image of Reaction i", height=100)

                raw_items = inputs.split(",")
                raw_items = [re.sub(r"^\s*\[?'?|']?\s*$", "", item) for item in raw_items]
                rendered = []
                for i, raw in enumerate(raw_items):
                    smiles_clean = raw.replace('"', '').replace("'", "").replace("[", "").replace("]", "")
                    # Always show the SMILES text, even if depiction fails below.
                    rendered.append(gr.Textbox(value=smiles_clean, label=f"SMILES of Reaction {i}", show_copy_button=True, interactive=False))
                    try:
                        rxn = rdChemReactions.ReactionFromSmarts(smiles_clean, useSmiles=True)
                        if not rxn:
                            continue

                        # Round-trip each template through a MolBlock so RDKit
                        # regenerates consistent 2D coordinates.
                        rebuilt = AllChem.ChemicalReaction()
                        for mol in rxn.GetReactants():
                            rebuilt.AddReactantTemplate(Chem.MolFromMolBlock(Chem.MolToMolBlock(mol)))
                        for mol in rxn.GetProducts():
                            rebuilt.AddProductTemplate(Chem.MolFromMolBlock(Chem.MolToMolBlock(mol)))

                        # Strip atom-map numbers so they do not clutter the drawing.
                        for template in rebuilt.GetReactants():
                            for atom in template.GetAtoms():
                                atom.SetAtomMapNum(0)
                        for template in rebuilt.GetProducts():
                            for atom in template.GetAtoms():
                                atom.SetAtomMapNum(0)

                        # Reference bond length taken from the first non-trivial reactant.
                        first = rebuilt.GetReactantTemplate(0)
                        second = rebuilt.GetReactantTemplate(1) if rebuilt.GetNumReactantTemplates() > 1 else None
                        if first.GetNumBonds() > 0:
                            bond_len = Draw.MeanBondLength(first)
                        elif second and second.GetNumBonds() > 0:
                            bond_len = Draw.MeanBondLength(second)
                        else:
                            bond_len = 1.0

                        # Draw to SVG, then rasterise to PNG for the image widget.
                        drawer = rdMolDraw2D.MolDraw2DSVG(-1, -1)
                        dopts = drawer.drawOptions()
                        dopts.padding = 0.1
                        dopts.includeRadicals = True
                        Draw.SetACS1996Mode(dopts, bond_len * 0.55)
                        dopts.bondLineWidth = 1.5
                        drawer.DrawReaction(rebuilt)
                        drawer.FinishDrawing()
                        svg_file = f"reaction_{i}.svg"
                        with open(svg_file, "w") as f:
                            f.write(drawer.GetDrawingText())
                        png_file = f"reaction_{i}.png"
                        cairosvg.svg2png(url=svg_file, write_to=png_file)
                        rendered.append(gr.Image(value=png_file, label=f"RDKit Image of Reaction {i}"))
                    except Exception as e:
                        # Depiction failure is non-fatal: keep the textbox, skip the image.
                        print(f"Failed to draw reaction {i} for SMILES '{smiles_clean}': {e}")
                return rendered

            download_json = gr.File(label="Download JSON File")

    gr.Examples(
        examples=[
            ["examples/reaction0.png"],
            ["examples/reaction1.jpg"],
            ["examples/reaction2.png"],
            ["examples/reaction3.png"],
            ["examples/reaction4.png"],
            ["examples/reaction5.png"],
            ["examples/template1.png"],
            ["examples/molecules1.png"],
            ["examples/molecules2.png"],
        ],
        inputs=[image_input],
        outputs=[reaction_output, smiles_output, schematic_diagram, download_json],
        cache_examples=False,
        examples_per_page=4,
    )

    clear_btn.click(
        lambda: (None, None, None, None),
        inputs=[],
        outputs=[image_input, reaction_output, smiles_output, download_json],
    )
    run_btn.click(
        process_chem_image,
        inputs=[image_input],
        outputs=[reaction_output, smiles_output, schematic_diagram, download_json],
    )

demo.css = """
#submit-btn {
    background-color: #FF914D;
    color: white;
    font-weight: bold;
}
"""

demo.launch()
|
apt.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
libpoppler-cpp-dev
|
| 2 |
+
pkg-config
|
chemiener/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .interface import ChemNER
|
chemiener/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (187 Bytes). View file
|
|
|
chemiener/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (185 Bytes). View file
|
|
|
chemiener/__pycache__/dataset.cpython-310.pyc
ADDED
|
Binary file (5.37 kB). View file
|
|
|
chemiener/__pycache__/dataset.cpython-38.pyc
ADDED
|
Binary file (5.35 kB). View file
|
|
|
chemiener/__pycache__/interface.cpython-310.pyc
ADDED
|
Binary file (4.46 kB). View file
|
|
|
chemiener/__pycache__/interface.cpython-38.pyc
ADDED
|
Binary file (4.47 kB). View file
|
|
|
chemiener/__pycache__/model.cpython-310.pyc
ADDED
|
Binary file (684 Bytes). View file
|
|
|
chemiener/__pycache__/model.cpython-38.pyc
ADDED
|
Binary file (680 Bytes). View file
|
|
|
chemiener/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (1.67 kB). View file
|
|
|
chemiener/__pycache__/utils.cpython-38.pyc
ADDED
|
Binary file (1.53 kB). View file
|
|
|
chemiener/dataset.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import copy
|
| 4 |
+
import random
|
| 5 |
+
import json
|
| 6 |
+
import contextlib
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from torch.utils.data import DataLoader, Dataset
|
| 12 |
+
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
|
| 13 |
+
|
| 14 |
+
from transformers import BertTokenizerFast, AutoTokenizer, RobertaTokenizerFast
|
| 15 |
+
|
| 16 |
+
from .utils import get_class_to_index
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class NERDataset(Dataset):
    """Character-span NER corpus wrapped for sub-word token classification.

    The backing JSON maps the *string* form of each index to a record holding
    raw ``text`` and ``entities`` with character-level spans; labels are
    projected onto tokenizer output in BIO form by :meth:`align_labels`.
    """

    def __init__(self, args, data_file, split='train'):
        super().__init__()
        self.args = args
        if data_file:
            data_path = os.path.join(args.data_path, data_file)
            with open(data_path) as f:
                self.data = json.load(f)
            # NOTE(review): only set when a data file is supplied (the
            # inference wrapper constructs this class with data_file=None).
            self.name = os.path.basename(data_file).split('.')[0]
        self.split = split
        self.is_train = (split == 'train')
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.args.roberta_checkpoint, cache_dir=self.args.cache_dir)
        self.class_to_index = get_class_to_index(self.args.corpus)
        self.index_to_class = {index: name for name, index in self.class_to_index.items()}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[str(idx)]
        tokenized = self.tokenizer(record['text'], truncation=True,
                                   max_length=self.args.max_seq_length)
        if len(tokenized['input_ids']) > 512:
            print(len(tokenized['input_ids']))  # flag over-long inputs
        untruncated = self.tokenizer(record['text'])
        n_chars = len(record['text'])
        # Return both truncated and untruncated label alignments so the
        # evaluation code can score against the full reference sequence.
        return (tokenized,
                self.align_labels(tokenized, record['entities'], n_chars),
                self.align_labels(untruncated, record['entities'], n_chars))

    def align_labels(self, text_tokenized, entities, length):
        """Project character-level entity spans onto token-level BIO classes.

        Tokens with no character span (special tokens) get -100 so the loss
        ignores them; characters outside every entity map to class 0.
        """
        char_to_class = {}
        for entity in entities.values():
            tag = str(entity["type"])
            for span in entity["span"]:
                for pos in range(span[0], span[1]):
                    prefix = 'B-' if pos == span[0] else 'I-'
                    char_to_class[pos] = self.class_to_index[prefix + tag]
        # Everything not claimed by an entity is the "outside" class.
        for pos in range(length):
            char_to_class.setdefault(pos, 0)

        token_classes = []
        for tok_idx in range(len(text_tokenized[0])):
            char_span = text_tokenized.token_to_chars(tok_idx)
            if char_span is None:
                token_classes.append(-100)
            else:
                token_classes.append(char_to_class[char_span.start])
        return torch.LongTensor(token_classes)
|
| 67 |
+
|
| 68 |
+
def make_html(word_tokens, predictions, tokenizer=None, index_to_class=None):
    """Render per-token NER predictions as a standalone, colour-coded HTML page.

    Each predicted entity run is wrapped in ``<u class="LABEL">...</u>``; the
    page header defines one CSS class per entity type.

    Parameters:
        word_tokens: sequence of token ids to decode, one per prediction.
        predictions: sequence of integer class indices aligned with
            ``word_tokens`` (0 means "outside", >0 is an entity class).
        tokenizer: object with ``decode(id, skip_special_tokens=True)``.
        index_to_class: mapping from class index to class-name string.

    Bug fixes versus the previous revision:
      * the function referenced ``self.tokenizer`` / ``self.index_to_class``
        although it is not a method — every call raised NameError; both are
        now explicit keyword parameters (backward-compatible additions);
      * a trailing entity span was never closed because the old check
        ``idx == len(word_tokens)`` can never be true inside the loop.
    """
    if tokenizer is None or index_to_class is None:
        raise ValueError("make_html requires a tokenizer and an index_to_class mapping")

    page = '''<!DOCTYPE html>
<html>
<head>
<title>Named Entity Recognition Visualization</title>
<style>
.EXAMPLE_LABEL {
color: red;
text-decoration: underline red;
}
.REACTION_PRODUCT {
color: orange;
text-decoration: underline orange;
}
.STARTING_MATERIAL {
color: gold;
text-decoration: underline gold;
}
.REAGENT_CATALYST {
color: green;
text-decoration: underline green;
}
.SOLVENT {
color: cyan;
text-decoration: underline cyan;
}
.OTHER_COMPOUND {
color: blue;
text-decoration: underline blue;
}
.TIME {
color: purple;
text-decoration: underline purple;
}
.TEMPERATURE {
color: magenta;
text-decoration: underline magenta;
}
.YIELD_OTHER {
color: palegreen;
text-decoration: underline palegreen;
}
.YIELD_PERCENT {
color: pink;
text-decoration: underline pink;
}
</style>
</head>
<body>
<p>'''

    last_label = None
    for idx, token_id in enumerate(word_tokens):
        decoded = tokenizer.decode(token_id, skip_special_tokens=True)
        if len(decoded) == 0:
            continue  # special tokens decode to nothing; skip entirely
        # WordPiece continuations start with '##': join them without a space.
        if idx != 0 and decoded[0] != '#':
            page += " "
        text = decoded if decoded[0] != "#" else decoded[2:]
        label = predictions[idx]
        if label == last_label:
            # Same class as the previous token: extend the current run.
            page += text
        else:
            if last_label is not None and last_label > 0:
                page += "</u>"  # close the previous entity span
            if label > 0:
                page += "<u class=\"" + index_to_class[label] + "\">" + text
            elif label == 0:
                page += text
        last_label = label

    # Fix: close a span that is still open at the end of the sequence.
    if last_label is not None and last_label > 0:
        page += "</u>"

    page += ''' </p>
</body>
</html>'''
    return page
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def get_collate_fn():
    """Return a collate function for NER batches.

    Each example is a tuple whose first element is a tokenizer encoding
    (``input_ids`` / ``attention_mask``) and whose second element is the
    label tensor. Sequences are right-padded: ids and masks with 0, labels
    with -100 so padded positions are ignored by the loss.
    """
    def collate(batch):
        token_ids = [torch.LongTensor(example[0]['input_ids']) for example in batch]
        attention = [torch.Tensor(example[0]['attention_mask']) for example in batch]
        labels = [example[1] for example in batch]

        token_ids = pad_sequence(token_ids, batch_first=True, padding_value=0)
        attention = pad_sequence(attention, batch_first=True, padding_value=0)
        labels = pad_sequence(labels, batch_first=True, padding_value=-100)
        return token_ids, attention, labels

    return collate
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
|
chemiener/interface.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
from typing import List
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from .model import build_model
|
| 8 |
+
|
| 9 |
+
from .dataset import NERDataset, get_collate_fn
|
| 10 |
+
|
| 11 |
+
from huggingface_hub import hf_hub_download
|
| 12 |
+
|
| 13 |
+
from .utils import get_class_to_index
|
| 14 |
+
|
| 15 |
+
class ChemNER:
    """Inference wrapper around a trained chemistry NER checkpoint.

    Loads the model weights on construction and exposes
    :meth:`predict_strings`, which maps raw strings to lists of
    ``(entity_type, [char_start, char_end])`` spans.
    """

    def __init__(self, model_path, device=None, cache_dir=None):
        self.args = self._get_args(cache_dir)

        states = torch.load(model_path, map_location=torch.device('cpu'))

        if device is None:
            device = torch.device('cpu')
        self.device = device

        self.model = self.get_model(self.args, device, states['state_dict'])
        self.collate = get_collate_fn()
        # Dataset with no data file: used only for its tokenizer.
        self.dataset = NERDataset(self.args, data_file=None)
        self.class_to_index = get_class_to_index(self.args.corpus)
        self.index_to_class = {index: name for name, index in self.class_to_index.items()}

    def _get_args(self, cache_dir):
        """Build the default argument namespace used at inference time."""
        parser = argparse.ArgumentParser()
        parser.add_argument('--roberta_checkpoint', default='dmis-lab/biobert-large-cased-v1.1',
                            type=str, help='which roberta config to use')
        parser.add_argument('--corpus', default="chemdner", type=str,
                            help="which corpus should the tags be from")
        args = parser.parse_args([])
        args.cache_dir = cache_dir
        return args

    def get_model(self, args, device, model_states):
        """Instantiate the model, load (prefix-stripped) weights, move to device."""
        model = build_model(args)
        # Checkpoints were saved from a Lightning wrapper that prefixed
        # every key with 'model.'; strip that before loading.
        stripped = {k.replace('model.', ''): v for k, v in model_states.items()}
        model.load_state_dict(stripped, strict=False)
        model.to(device)
        model.eval()
        return model

    def predict_strings(self, strings: List, batch_size=8):
        """Run NER over a list of strings; one span list per input string."""
        device = self.device

        def _spans_from_tags(char_span, tags):
            # Collapse BIO tags into (type, [start, end]) spans. An I- tag
            # extends the previous span when the types match.
            entities = []
            for pos, tag in enumerate(tags):
                if tag[0] == 'B':
                    entities.append((tag[2:], [char_span[pos].start, char_span[pos].end]))
                elif entities and tag[2:] == entities[-1][0]:
                    entities[-1] = (entities[-1][0], [entities[-1][1][0], char_span[pos].end])
            return entities

        output = []
        for start in range(0, len(strings), batch_size):
            chunk = strings[start:start + batch_size]
            encoded = [(self.dataset.tokenizer(s, truncation=True, max_length=512),
                        torch.Tensor([-1]), torch.Tensor([-1])) for s in chunk]

            sentences, masks, _ = self.collate(encoded)
            preds = self.model(input_ids=sentences.to(device),
                               attention_mask=masks.to(device))[0].argmax(dim=2).to('cpu')

            for enc, sent, pred in zip(encoded, list(sentences), list(preds)):
                # Keep only tokens that decode to visible text (drops
                # special tokens and padding).
                keep = [len(self.dataset.tokenizer.decode(int(w.item()),
                                                          skip_special_tokens=True)) > 0
                        for w in sent]
                spans = [enc[0].token_to_chars(i) for i, k in enumerate(keep) if k]
                tags = [self.index_to_class[int(p.item())]
                        for p, k in zip(pred, keep) if k]
                output.append(_spans_from_tags(spans, tags))

        return output
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
|
chemiener/main.py
ADDED
|
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import math
|
| 3 |
+
import json
|
| 4 |
+
import random
|
| 5 |
+
import argparse
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
import time
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from torch.profiler import profile, record_function, ProfilerActivity
|
| 12 |
+
import torch.distributed as dist
|
| 13 |
+
import pytorch_lightning as pl
|
| 14 |
+
from pytorch_lightning import LightningModule, LightningDataModule
|
| 15 |
+
from pytorch_lightning.callbacks import LearningRateMonitor
|
| 16 |
+
from pytorch_lightning.strategies.ddp import DDPStrategy
|
| 17 |
+
from transformers import get_scheduler
|
| 18 |
+
import transformers
|
| 19 |
+
|
| 20 |
+
from dataset import NERDataset, get_collate_fn
|
| 21 |
+
|
| 22 |
+
from model import build_model
|
| 23 |
+
|
| 24 |
+
from utils import get_class_to_index
|
| 25 |
+
|
| 26 |
+
import evaluate
|
| 27 |
+
|
| 28 |
+
from seqeval.metrics import accuracy_score
|
| 29 |
+
from seqeval.metrics import classification_report
|
| 30 |
+
from seqeval.metrics import f1_score
|
| 31 |
+
from seqeval.scheme import IOB2
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def get_args(notebook=False):
    """Build the CLI argument namespace for training/evaluation.

    With ``notebook=True`` no real command line is parsed (all defaults),
    which makes the function safe to call from notebooks and tests.
    """
    parser = argparse.ArgumentParser()

    # Run mode
    parser.add_argument('--do_train', action='store_true')
    parser.add_argument('--do_valid', action='store_true')
    parser.add_argument('--do_test', action='store_true')
    parser.add_argument('--fp16', action='store_true')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--gpus', type=int, default=1)
    parser.add_argument('--print_freq', type=int, default=200)
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--no_eval', action='store_true')

    # Data
    parser.add_argument('--data_path', type=str, default=None)
    parser.add_argument('--image_path', type=str, default=None)
    parser.add_argument('--train_file', type=str, default=None)
    parser.add_argument('--valid_file', type=str, default=None)
    parser.add_argument('--test_file', type=str, default=None)
    parser.add_argument('--vocab_file', type=str, default=None)
    parser.add_argument('--format', type=str, default='reaction')
    parser.add_argument('--num_workers', type=int, default=8)
    parser.add_argument('--input_size', type=int, default=224)

    # Training
    parser.add_argument('--epochs', type=int, default=8)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--weight_decay', type=float, default=0.05)
    parser.add_argument('--max_grad_norm', type=float, default=5.)
    parser.add_argument('--scheduler', type=str, choices=['cosine', 'constant'], default='cosine')
    parser.add_argument('--warmup_ratio', type=float, default=0)
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1)
    parser.add_argument('--load_path', type=str, default=None)
    parser.add_argument('--load_encoder_only', action='store_true')
    parser.add_argument('--train_steps_per_epoch', type=int, default=-1)
    parser.add_argument('--eval_per_epoch', type=int, default=10)
    parser.add_argument('--save_path', type=str, default='output/')
    parser.add_argument('--save_mode', type=str, default='best', choices=['best', 'all', 'last'])
    parser.add_argument('--load_ckpt', type=str, default='best')
    parser.add_argument('--resume', action='store_true')
    parser.add_argument('--num_train_example', type=int, default=None)

    # Model / corpus
    parser.add_argument('--roberta_checkpoint', type=str, default="roberta-base")
    parser.add_argument('--corpus', type=str, default="chemu")
    parser.add_argument('--cache_dir')
    parser.add_argument('--eval_truncated', action='store_true')
    parser.add_argument('--max_seq_length', type=int, default=512)

    return parser.parse_args([]) if notebook else parser.parse_args()
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class ChemIENERecognizer(LightningModule):
    """Lightning module wrapping the token-classification NER model."""

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.model = build_model(args)
        # Per-epoch buffer of (sentences, argmax preds, refs, untruncated refs).
        self.validation_step_outputs = []

    def training_step(self, batch, batch_idx):
        # NOTE(review): unpacks 4 items although get_collate_fn() yields 3 —
        # presumably a different collate is used for training; confirm.
        sentences, masks, refs, _ = batch
        loss, logits = self.model(input_ids=sentences, attention_mask=masks, labels=refs)
        self.log('train/loss', loss)
        self.log('lr', self.lr_schedulers().get_lr()[0], prog_bar=True, logger=False)
        return loss

    def validation_step(self, batch, batch_idx):
        sentences, masks, refs, untruncated = batch
        logits = self.model(input_ids=sentences, attention_mask=masks)[0]
        self.validation_step_outputs.append(
            (sentences.to("cpu"), logits.argmax(dim=2).to("cpu"),
             refs.to('cpu'), untruncated.to("cpu")))

    def on_validation_epoch_end(self):
        # Gather step outputs from all ranks when running distributed.
        if self.trainer.num_devices > 1:
            gathered = [None for _ in range(self.trainer.num_devices)]
            dist.all_gather_object(gathered, self.validation_step_outputs)
            gathered = sum(gathered, [])
        else:
            gathered = self.validation_step_outputs

        sentences = [list(out[0]) for out in gathered]
        class_to_index = get_class_to_index(self.args.corpus)
        index_to_class = {class_to_index[key]: key for key in class_to_index}
        predictions = [list(out[1]) for out in gathered]
        labels = [list(out[2]) for out in gathered]

        untruncateds = [list(out[3]) for out in gathered]
        untruncateds = [[index_to_class[int(label.item())] for label in sentence
                         if int(label.item()) != -100]
                        for batched in untruncateds for sentence in batched]

        # Flatten per-batch tensors into per-sentence lists, dropping the
        # -100 positions (special tokens / padding).
        output = {
            "sentences": [[int(word.item()) for (word, label) in zip(sentence_w, sentence_l) if label != -100]
                          for (batched_w, batched_l) in zip(sentences, labels)
                          for (sentence_w, sentence_l) in zip(batched_w, batched_l)],
            "predictions": [[index_to_class[int(pred.item())] for (pred, label) in zip(sentence_p, sentence_l) if label != -100]
                            for (batched_p, batched_l) in zip(predictions, labels)
                            for (sentence_p, sentence_l) in zip(batched_p, batched_l)],
            "groundtruth": [[index_to_class[int(label.item())] for label in sentence if label != -100]
                            for batched in labels for sentence in batched],
        }

        # NOTE(review): self.eval_dataset is not set anywhere in this class;
        # presumably assigned externally before validation — confirm.
        name = self.eval_dataset.name
        scores = [0]

        if self.trainer.is_global_zero:
            if not self.args.no_eval:
                epoch = self.trainer.current_epoch
                metric = evaluate.load("seqeval", cache_dir=self.args.cache_dir)
                # Pad truncated predictions with 'O' so they line up with the
                # untruncated references.
                padded = [preds + ['O'] * (len(full) - len(preds))
                          for (preds, full) in zip(output['predictions'], untruncateds)]
                all_metrics = metric.compute(predictions=padded, references=untruncateds)

                if self.args.eval_truncated:
                    report = classification_report(output['groundtruth'], output['predictions'],
                                                   mode='strict', scheme=IOB2, output_dict=True)
                else:
                    report = classification_report(padded, untruncateds,
                                                   mode='strict', scheme=IOB2, output_dict=True)
                self.print(report)
                scores = [report['micro avg']['f1-score']]
                with open(os.path.join(self.trainer.default_root_dir, f'prediction_{name}.json'), 'w') as f:
                    json.dump(output, f)

        # Collective: every rank must reach this call.
        dist.broadcast_object_list(scores)
        self.log('val/score', scores[0], prog_bar=True, rank_zero_only=True)
        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        num_training_steps = self.trainer.num_training_steps
        self.print(f'Num training steps: {num_training_steps}')
        num_warmup_steps = int(num_training_steps * self.args.warmup_ratio)
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.args.lr,
                                      weight_decay=self.args.weight_decay)
        scheduler = get_scheduler(self.args.scheduler, optimizer,
                                  num_warmup_steps, num_training_steps)
        return {'optimizer': optimizer,
                'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'}}
|
| 237 |
+
|
| 238 |
+
class NERDataModule(LightningDataModule):
    """Lightning data module exposing the train/valid/test NER datasets."""

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.collate_fn = get_collate_fn()

    def prepare_data(self):
        args = self.args
        if args.do_train:
            self.train_dataset = NERDataset(args, args.train_file, split='train')
        if args.do_train or args.do_valid:
            self.val_dataset = NERDataset(args, args.valid_file, split='valid')
        if args.do_test:
            self.test_dataset = NERDataset(args, args.test_file, split='valid')

    def print_stats(self):
        if self.args.do_train:
            print(f'Train dataset: {len(self.train_dataset)}')
        if self.args.do_train or self.args.do_valid:
            print(f'Valid dataset: {len(self.val_dataset)}')
        if self.args.do_test:
            print(f'Test dataset: {len(self.test_dataset)}')

    def _make_loader(self, dataset):
        # All three loaders share the same configuration.
        return torch.utils.data.DataLoader(
            dataset, batch_size=self.args.batch_size,
            num_workers=self.args.num_workers, collate_fn=self.collate_fn)

    def train_dataloader(self):
        return self._make_loader(self.train_dataset)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
    """Checkpoint callback with simplified filepath resolution.

    Overrides the base filepath-interpolation hook to return the plain
    formatted checkpoint name, bypassing the base class's extra
    filepath handling (presumably to keep a stable 'best.ckpt' name —
    confirm against the pl version in use).
    """

    def _get_metric_interpolated_filepath_name(self, monitor_candidates, trainer, del_filepath=None) -> str:
        return self.format_checkpoint_name(monitor_candidates)
|
| 285 |
+
|
| 286 |
+
def main():
    """Entry point: build the model and data, then run train/valid/test as flagged."""
    transformers.utils.logging.set_verbosity_error()
    args = get_args()

    pl.seed_everything(args.seed, workers=True)

    if args.do_train:
        model = ChemIENERecognizer(args)
    else:
        # Eval-only run: restore the best checkpoint from a previous training.
        model = ChemIENERecognizer.load_from_checkpoint(
            os.path.join(args.save_path, 'checkpoints/best.ckpt'),
            strict=False, args=args)

    data_module = NERDataModule(args)
    data_module.prepare_data()
    data_module.print_stats()

    checkpoint_cb = ModelCheckpoint(monitor='val/score', mode='max', save_top_k=1,
                                    filename='best', save_last=True)
    lr_monitor = LearningRateMonitor(logging_interval='step')
    logger = pl.loggers.TensorBoardLogger(args.save_path, name='', version='')

    trainer = pl.Trainer(
        strategy=DDPStrategy(find_unused_parameters=False),
        accelerator='gpu',
        precision=16,
        devices=args.gpus,
        logger=logger,
        default_root_dir=args.save_path,
        callbacks=[checkpoint_cb, lr_monitor],
        max_epochs=args.epochs,
        gradient_clip_val=args.max_grad_norm,
        accumulate_grad_batches=args.gradient_accumulation_steps,
        check_val_every_n_epoch=args.eval_per_epoch,
        log_every_n_steps=10,
        deterministic='warn')

    if args.do_train:
        # Total optimizer steps, consumed by configure_optimizers for the LR schedule.
        trainer.num_training_steps = math.ceil(
            len(data_module.train_dataset)
            / (args.batch_size * args.gpus * args.gradient_accumulation_steps)) * args.epochs
        model.eval_dataset = data_module.val_dataset
        ckpt_path = os.path.join(args.save_path, 'checkpoints/last.ckpt') if args.resume else None
        trainer.fit(model, datamodule=data_module, ckpt_path=ckpt_path)
        # Continue with the best checkpoint for any subsequent valid/test phase.
        model = ChemIENERecognizer.load_from_checkpoint(checkpoint_cb.best_model_path, args=args)

    if args.do_valid:
        model.eval_dataset = data_module.val_dataset
        trainer.validate(model, datamodule=data_module)

    if args.do_test:
        model.test_dataset = data_module.test_dataset
        trainer.test(model, datamodule=data_module)


if __name__ == "__main__":
    main()
|
| 345 |
+
|
chemiener/model.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
from transformers import BertForTokenClassification, RobertaForTokenClassification, AutoModelForTokenClassification
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def build_model(args):
    """Build a token-classification model sized for the configured corpus.

    Parameters:
        args: namespace with `corpus` (one of "chemu", "chemdner",
            "chemdner-mol"), `roberta_checkpoint` (HF model name or path)
            and `cache_dir`.
    Returns:
        an AutoModelForTokenClassification with num_labels matching the
        corpus tag set.
    Raises:
        ValueError: for an unsupported corpus (previously this fell
            through and silently returned None).
    """
    # Label counts must stay in sync with the tag dictionaries in
    # utils.get_class_to_index.
    num_labels_by_corpus = {"chemu": 21, "chemdner": 17, "chemdner-mol": 3}
    try:
        num_labels = num_labels_by_corpus[args.corpus]
    except KeyError:
        raise ValueError(f"Unsupported corpus: {args.corpus!r}") from None
    return AutoModelForTokenClassification.from_pretrained(
        args.roberta_checkpoint, num_labels=num_labels,
        cache_dir=args.cache_dir, return_dict=False)
|
chemiener/utils.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
def merge_predictions(results):
    """Flatten per-batch prediction lists into a single flat list.

    NOTE(review): entries are keyed by their position *within* each batch,
    so a later batch overwrites an earlier batch's prediction at the same
    position; verify this is the intended merge semantics for gathered
    DDP outputs.
    """
    if not results:
        return []
    merged = {}
    for batch in results:
        for position, preds in enumerate(batch):
            merged[position] = preds
    return [merged[i] for i in range(len(merged))]
|
| 12 |
+
|
| 13 |
+
def get_class_to_index(corpus):
    """Return the BIO tag -> integer index mapping for a corpus.

    Parameters:
        corpus: one of "chemu" (21 tags), "chemdner" (17 tags),
            "chemdner-mol" (3 tags).
    Returns:
        dict mapping tag string to class index; 'O' is always 0.
    Raises:
        ValueError: for an unknown corpus (previously this fell through
            and silently returned None).
    """
    if corpus == "chemu":
        return {'B-EXAMPLE_LABEL': 1, 'B-REACTION_PRODUCT': 2, 'B-STARTING_MATERIAL': 3, 'B-REAGENT_CATALYST': 4, 'B-SOLVENT': 5, 'B-OTHER_COMPOUND': 6, 'B-TIME': 7, 'B-TEMPERATURE': 8, 'B-YIELD_OTHER': 9, 'B-YIELD_PERCENT': 10, 'O': 0,
                'I-EXAMPLE_LABEL': 11, 'I-REACTION_PRODUCT': 12, 'I-STARTING_MATERIAL': 13, 'I-REAGENT_CATALYST': 14, 'I-SOLVENT': 15, 'I-OTHER_COMPOUND': 16, 'I-TIME': 17, 'I-TEMPERATURE': 18, 'I-YIELD_OTHER': 19, 'I-YIELD_PERCENT': 20}
    elif corpus == "chemdner":
        return {'O': 0, 'B-ABBREVIATION': 1, 'B-FAMILY': 2, 'B-FORMULA': 3, 'B-IDENTIFIER': 4, 'B-MULTIPLE': 5, 'B-SYSTEMATIC': 6, 'B-TRIVIAL': 7, 'B-NO CLASS': 8, 'I-ABBREVIATION': 9, 'I-FAMILY': 10, 'I-FORMULA': 11, 'I-IDENTIFIER': 12, 'I-MULTIPLE': 13, 'I-SYSTEMATIC': 14, 'I-TRIVIAL': 15, 'I-NO CLASS': 16}
    elif corpus == "chemdner-mol":
        return {'O': 0, 'B-MOL': 1, 'I-MOL': 2}
    raise ValueError(f"Unknown corpus: {corpus!r}")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
|
chemietoolkit/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .interface import ChemIEToolkit
|
chemietoolkit/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (198 Bytes). View file
|
|
|
chemietoolkit/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (189 Bytes). View file
|
|
|
chemietoolkit/__pycache__/chemrxnextractor.cpython-310.pyc
ADDED
|
Binary file (3.66 kB). View file
|
|
|
chemietoolkit/__pycache__/chemrxnextractor.cpython-38.pyc
ADDED
|
Binary file (3.62 kB). View file
|
|
|
chemietoolkit/__pycache__/interface.cpython-310.pyc
ADDED
|
Binary file (29.3 kB). View file
|
|
|
chemietoolkit/__pycache__/interface.cpython-38.pyc
ADDED
|
Binary file (30 kB). View file
|
|
|
chemietoolkit/__pycache__/tableextractor.cpython-310.pyc
ADDED
|
Binary file (10.3 kB). View file
|
|
|
chemietoolkit/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (25 kB). View file
|
|
|
chemietoolkit/chemrxnextractor.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PyPDF2 import PdfReader, PdfWriter
|
| 2 |
+
import pdfminer.high_level
|
| 3 |
+
import pdfminer.layout
|
| 4 |
+
from operator import itemgetter
|
| 5 |
+
import os
|
| 6 |
+
import pdftotext
|
| 7 |
+
from chemrxnextractor import RxnExtractor
|
| 8 |
+
|
| 9 |
+
class ChemRxnExtractor(object):
    """Extract reaction descriptions from the text layer of a PDF.

    Wraps chemrxnextractor's RxnExtractor: the PDF text is pulled with
    pdftotext, split into sentence lists per page, and each page's
    sentences are fed to the reaction extractor.
    """

    def __init__(self, pdf, pn, model_dir, device):
        """
        Parameters:
            pdf: path to the PDF file; may be "" to defer loading until
                set_pdf_file() is called.
            pn: number of pages to process, or None for all pages.
            model_dir: directory containing the cre_models_v0.1 checkpoint.
            device: 'cuda' enables GPU inference; anything else runs on CPU.
        """
        self.pdf_file = pdf
        self.pages = pn
        self.model_dir = os.path.join(model_dir, "cre_models_v0.1") # directory saving both prod and role models
        use_cuda = (device == 'cuda')
        self.rxn_extractor = RxnExtractor(self.model_dir, use_cuda=use_cuda)
        self.text_file = "info.txt"
        self.pdf_text = ""
        if len(self.pdf_file) > 0:
            with open(self.pdf_file, "rb") as f:
                self.pdf_text = pdftotext.PDF(f)

    def set_pdf_file(self, pdf):
        # Replace the current PDF and re-extract its text.
        self.pdf_file = pdf
        with open(self.pdf_file, "rb") as f:
            self.pdf_text = pdftotext.PDF(f)

    def set_pages(self, pn):
        self.pages = pn

    def set_model_dir(self, md):
        # NOTE(review): unlike __init__, this rebuilds RxnExtractor without
        # the use_cuda flag, so it falls back to the extractor's default
        # device — confirm this is intentional.
        self.model_dir = md
        self.rxn_extractor = RxnExtractor(self.model_dir)

    def set_text_file(self, tf):
        self.text_file = tf

    def extract_reactions_from_text(self):
        """Run reaction extraction over self.pages pages (all pages if None)."""
        if self.pages is None:
            return self.extract_all(len(self.pdf_text))
        else:
            return self.extract_all(self.pages)

    def extract_all(self, pages):
        """Return one reactions dict (see get_reactions) per processed page."""
        ans = []
        text = self.get_paragraphs_from_pdf(pages)
        for data in text:
            # Flatten the page's paragraphs into a single sentence list.
            L = [sent for paragraph in data['paragraphs'] for sent in paragraph]
            reactions = self.get_reactions(L, page_number=data['page'])
            ans.append(reactions)
        return ans

    def get_reactions(self, sents, page_number=None):
        """Run the extractor on a sentence list, keeping only sentences that
        yielded at least one reaction.

        Returns:
            {'page': page_number, 'reactions': [per-sentence results]}
        """
        rxns = self.rxn_extractor.get_reactions(sents)

        ret = []
        for r in rxns:
            if len(r['reactions']) != 0: ret.append(r)
        ans = {}
        ans.update({'page' : page_number})
        ans.update({'reactions' : ret})
        return ans


    def get_paragraphs_from_pdf(self, pages):
        """Split each page's text into paragraphs and paragraphs into
        sentences using a period-based heuristic.

        Returns:
            list of {'paragraphs': list[list[str]], 'page': int} dicts,
            pages numbered from 1.
        """
        current_page_num = 1
        if pages is None:
            pages = len(self.pdf_text)
        result = []
        for page in range(pages):
            content = self.pdf_text[page]
            # pdftotext separates paragraphs with blank lines.
            pg = content.split("\n\n")
            L = []
            for line in pg:
                paragraph = []
                # Skip form-feed chunks (page-break artifacts).
                if '\x0c' in line:
                    continue
                text = line
                text = text.replace("\n", " ")
                # Re-join words hyphenated across line breaks.
                text = text.replace("- ", "-")
                curind = 0
                i = 0
                while i < len(text):
                    if text[i] == '.':
                        # Sentence boundary: a '.' not preceded by a digit
                        # (avoids decimals like "3.5"), or one followed by a
                        # space. Note `and` binds tighter than `or` here.
                        if i != 0 and not text[i-1].isdigit() or i != len(text) - 1 and (text[i+1] == " " or text[i+1] == "\n"):
                            paragraph.append(text[curind:i+1] + "\n")
                            # Skip past any trailing non-space run (e.g. "etc.)").
                            while(i < len(text) and text[i] != " "):
                                i += 1
                            curind = i + 1
                    i += 1
                # Flush any trailing fragment, normalizing it to end in '.'.
                if curind != i:
                    if text[i - 1] == " ":
                        if i != 1:
                            i -= 1
                        else:
                            break
                    if text[i - 1] != '.':
                        paragraph.append(text[curind:i] + ".\n")
                    else:
                        paragraph.append(text[curind:i] + "\n")
                L.append(paragraph)

            result.append({
                'paragraphs': L,
                'page': current_page_num
            })
            current_page_num += 1
        return result
|
chemietoolkit/interface.py
ADDED
|
@@ -0,0 +1,749 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import re
|
| 3 |
+
from functools import lru_cache
|
| 4 |
+
import layoutparser as lp
|
| 5 |
+
import pdf2image
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
| 8 |
+
from molnextr import MolScribe
|
| 9 |
+
from rxnim import RxnScribe, MolDetect
|
| 10 |
+
from chemiener import ChemNER
|
| 11 |
+
from .chemrxnextractor import ChemRxnExtractor
|
| 12 |
+
from .tableextractor import TableExtractor
|
| 13 |
+
from .utils import *
|
| 14 |
+
|
| 15 |
+
class ChemIEToolkit:
|
| 16 |
+
def __init__(self, device=None):
    """Create the toolkit; every model is loaded lazily on first use.

    Parameters:
        device: torch device string (e.g. 'cpu', 'cuda'); if None, picks
            'cuda' when available, otherwise 'cpu'.
    """
    if device is not None:
        self.device = torch.device(device)
    else:
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Lazily-initialized model handles (see the matching properties below).
    self._molscribe = None
    self._rxnscribe = None
    self._pdfparser = None
    self._moldet = None
    self._chemrxnextractor = None
    self._chemner = None
    self._coref = None
|
| 29 |
+
|
| 30 |
+
@property
def molscribe(self):
    """MolScribe model, created on first access."""
    model = self._molscribe
    if model is None:
        self.init_molscribe()
        model = self._molscribe
    return model

@lru_cache(maxsize=None)
def init_molscribe(self, ckpt_path=None):
    """Load the MolScribe model.

    Parameters:
        ckpt_path: path to checkpoint to use; if None, downloads the
            default checkpoint from the HuggingFace hub.
    """
    # NOTE(review): lru_cache on an instance method pins `self` for the
    # cache's lifetime — acceptable for this long-lived toolkit object.
    if ckpt_path is None:
        ckpt_path = hf_hub_download("yujieq/MolScribe", "swin_base_char_aux_1m680k.pth")
    self._molscribe = MolScribe(ckpt_path, device=self.device)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@property
def rxnscribe(self):
    """RxnScribe model, created on first access."""
    model = self._rxnscribe
    if model is None:
        self.init_rxnscribe()
        model = self._rxnscribe
    return model

@lru_cache(maxsize=None)
def init_rxnscribe(self, ckpt_path=None):
    """Load the RxnScribe model.

    Parameters:
        ckpt_path: path to checkpoint to use; if None, downloads the
            default checkpoint from the HuggingFace hub.
    """
    if ckpt_path is None:
        ckpt_path = hf_hub_download("yujieq/RxnScribe", "pix2seq_reaction_full.ckpt")
    self._rxnscribe = RxnScribe(ckpt_path, device=self.device)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@property
def pdfparser(self):
    """Layout-parser PDF layout model, created on first access."""
    model = self._pdfparser
    if model is None:
        self.init_pdfparser()
        model = self._pdfparser
    return model

@lru_cache(maxsize=None)
def init_pdfparser(self, ckpt_path=None):
    """Load the layout-detection model (PubLayNet EfficientDet-D1).

    Parameters:
        ckpt_path: path to checkpoint to use; if None, layoutparser
            fetches the model behind the lp:// config URL.
    """
    config_path = "lp://efficientdet/PubLayNet/tf_efficientdet_d1"
    self._pdfparser = lp.AutoLayoutModel(config_path, model_path=ckpt_path, device=self.device.type)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@property
def moldet(self):
    """Molecule-detection model, created on first access."""
    model = self._moldet
    if model is None:
        self.init_moldet()
        model = self._moldet
    return model

@lru_cache(maxsize=None)
def init_moldet(self, ckpt_path=None):
    """Load the MolDetect model.

    Parameters:
        ckpt_path: path to checkpoint to use; if None, downloads the
            default checkpoint from the HuggingFace hub.
    """
    if ckpt_path is None:
        ckpt_path = hf_hub_download("Ozymandias314/MolDetectCkpt", "best_hf.ckpt")
    self._moldet = MolDetect(ckpt_path, device=self.device)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
@property
def coref(self):
    """Coreference (molecule-identifier pairing) model, created on first access."""
    model = self._coref
    if model is None:
        self.init_coref()
        model = self._coref
    return model

@lru_cache(maxsize=None)
def init_coref(self, ckpt_path=None):
    """Load the coreference variant of MolDetect.

    Parameters:
        ckpt_path: path to checkpoint to use; if None, downloads the
            default checkpoint from the HuggingFace hub.
    """
    if ckpt_path is None:
        ckpt_path = hf_hub_download("Ozymandias314/MolDetectCkpt", "coref_best_hf.ckpt")
    self._coref = MolDetect(ckpt_path, device=self.device, coref=True)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
@property
def chemrxnextractor(self):
    """Text reaction extractor, created on first access."""
    model = self._chemrxnextractor
    if model is None:
        self.init_chemrxnextractor()
        model = self._chemrxnextractor
    return model

@lru_cache(maxsize=None)
def init_chemrxnextractor(self, ckpt_path=None):
    """Load the ChemRxnExtractor models.

    Parameters:
        ckpt_path: directory holding the training modules; if None,
            downloads the default snapshot from the HuggingFace hub.
    """
    if ckpt_path is None:
        ckpt_path = snapshot_download(repo_id="amberwang/chemrxnextractor-training-modules")
    # Created with no PDF attached; callers set the PDF before extracting.
    self._chemrxnextractor = ChemRxnExtractor("", None, ckpt_path, self.device.type)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
@property
def chemner(self):
    """Chemical NER model, created on first access."""
    model = self._chemner
    if model is None:
        self.init_chemner()
        model = self._chemner
    return model

@lru_cache(maxsize=None)
def init_chemner(self, ckpt_path=None):
    """Load the ChemNER model.

    Parameters:
        ckpt_path: path to checkpoint to use; if None, downloads the
            default checkpoint from the HuggingFace hub.
    """
    if ckpt_path is None:
        ckpt_path = hf_hub_download("Ozymandias314/ChemNERckpt", "best.ckpt")
    self._chemner = ChemNER(ckpt_path, device=self.device)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
@property
def tableextractor(self):
    """Return a fresh TableExtractor (a new instance on every access, not cached)."""
    return TableExtractor()
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def extract_figures_from_pdf(self, pdf, num_pages=None, output_bbox=False, output_image=True):
    """Find and return all figures in a pdf.

    Parameters:
        pdf: path to pdf
        num_pages: process only the first `num_pages` pages; None means all
        output_bbox: whether to output bounding boxes for each individual entry of a table
        output_image: whether to include a PIL image for each figure (default True)
    Returns:
        list of dicts, one per figure:
        {
            'title': str,
            'figure': {'image': PIL image or None,
                       'bbox': [x1, y1, x2, y2]},
            'table': {'bbox': [x1, y1, x2, y2] or [],
                      'content': {'columns': [...], 'rows': [[...], ...]} or None},
            'footnote': str or '',
            'page': int,
        }
    """
    page_images = pdf2image.convert_from_path(pdf, last_page=num_pages)

    extractor = self.tableextractor
    extractor.set_pdf_file(pdf)
    extractor.set_output_bbox(output_bbox)
    extractor.set_output_image(output_image)

    return extractor.extract_all_tables_and_figures(page_images, self.pdfparser, content='figures')
|
| 199 |
+
|
| 200 |
+
def extract_tables_from_pdf(self, pdf, num_pages=None, output_bbox=False, output_image=True):
    """Find and return all tables in a pdf.

    Parameters:
        pdf: path to pdf
        num_pages: process only the first `num_pages` pages; None means all
        output_bbox: whether to include bboxes for individual entries of the table
        output_image: whether to include a PIL image for each figure (default True)
    Returns:
        list of dicts, one per table:
        {
            'title': str,
            'figure': {'image': PIL image or None,
                       'bbox': [x1, y1, x2, y2] or []},
            'table': {'bbox': [x1, y1, x2, y2] or [],
                      'content': {'columns': [...], 'rows': [[...], ...]}},
            'footnote': str or '',
            'page': int,
        }
    """
    page_images = pdf2image.convert_from_path(pdf, last_page=num_pages)

    extractor = self.tableextractor
    extractor.set_pdf_file(pdf)
    extractor.set_output_bbox(output_bbox)
    extractor.set_output_image(output_image)

    return extractor.extract_all_tables_and_figures(page_images, self.pdfparser, content='tables')
|
| 239 |
+
|
| 240 |
+
def extract_molecules_from_figures_in_pdf(self, pdf, batch_size=16, num_pages=None):
    """Get every molecule (bbox, SMILES, molfile) from the figures of a pdf.

    Parameters:
        pdf: path to pdf, or byte file
        batch_size: batch size for inference in all models
        num_pages: process only the first `num_pages` pages; None means all
    Returns:
        list of dicts, one per figure:
        {
            'image': ndarray of the figure image,
            'molecules': [{'bbox': (x1, y1, x2, y2), 'score': float,
                           'image': ndarray crop, 'smiles': str,
                           'molfile': str}, ...],
            'page': int,
        }
    """
    figure_entries = self.extract_figures_from_pdf(pdf, num_pages=num_pages, output_bbox=True)
    figure_images = [entry['figure']['image'] for entry in figure_entries]
    per_figure = self.extract_molecules_from_figures(figure_images, batch_size=batch_size)
    # Tag each figure's result with the page it came from.
    for entry, figure_result in zip(figure_entries, per_figure):
        figure_result['page'] = entry['page']
    return per_figure
|
| 273 |
+
|
| 274 |
+
def extract_molecule_bboxes_from_figures(self, figures, batch_size=16):
    """Detect molecule bounding boxes in a batch of figure images.

    Parameters:
        figures: list of PIL or ndarray images
        batch_size: batch size for inference
    Returns:
        one list per figure, each entry being
        {'category': str, 'bbox': (x1, y1, x2, y2),
         'category_id': int, 'score': float}
    """
    pil_figures = [convert_to_pil(figure) for figure in figures]
    return self.moldet.predict_images(pil_figures, batch_size=batch_size)
|
| 297 |
+
|
| 298 |
+
def extract_molecules_from_figures(self, figures, batch_size=16):
    """Detect molecules in figures and recognize each one with MolScribe.

    Parameters:
        figures: list of PIL or ndarray images
        batch_size: batch size for inference
    Returns:
        list of dicts, one per figure:
        {
            'image': ndarray of the figure image,
            'molecules': [{'bbox': (x1, y1, x2, y2), 'score': float,
                           'image': ndarray crop, 'smiles': str,
                           'molfile': str}, ...],
        }
    """
    detections = self.extract_molecule_bboxes_from_figures(figures, batch_size=batch_size)
    cv2_figures = [convert_to_cv2(figure) for figure in figures]
    results, crops, molecule_refs = clean_bbox_output(cv2_figures, detections)
    # Run structure recognition on the crops, then merge each prediction
    # back into its molecule dict (refs alias entries inside `results`).
    recognized = self.molscribe.predict_images(crops, batch_size=batch_size)
    for prediction, molecule in zip(recognized, molecule_refs):
        molecule.update(prediction)
    return results
|
| 330 |
+
|
| 331 |
+
def extract_molecule_corefs_from_figures_in_pdf(self, pdf, batch_size=16, num_pages=None, molscribe=True, ocr=True):
    """Get molecule bboxes and molecule-identifier coref pairs from a pdf's figures.

    Parameters:
        pdf: path to pdf, or byte file
        batch_size: batch size for inference in all models
        num_pages: process only the first `num_pages` pages; None means all
        molscribe: whether to also run structure recognition
        ocr: whether to also run text recognition
    Returns:
        list of dicts, one per figure:
        {
            'bboxes': [{'category': str, 'bbox': (x1, y1, x2, y2),
                        'category_id': int, 'score': float}, ...],
            'corefs': [[mol_bbox_idx, identifier_bbox_idx], ...],
            'page': int,
        }
    """
    figure_entries = self.extract_figures_from_pdf(pdf, num_pages=num_pages, output_bbox=True)
    figure_images = [entry['figure']['image'] for entry in figure_entries]
    per_figure = self.extract_molecule_corefs_from_figures(
        figure_images, batch_size=batch_size, molscribe=molscribe, ocr=ocr)
    # Tag each figure's result with the page it came from.
    for entry, figure_result in zip(figure_entries, per_figure):
        figure_result['page'] = entry['page']
    return per_figure
|
| 367 |
+
|
| 368 |
+
def extract_molecule_corefs_from_figures(self, figures, batch_size=16, molscribe=True, ocr=True):
    """Get molecule bboxes and molecule-identifier coref pairs from figures.

    Parameters:
        figures: list of PIL or ndarray images
        batch_size: batch size for inference
        molscribe: whether to also run structure recognition
        ocr: whether to also run text recognition
    Returns:
        list of dicts, one per figure:
        {
            'bboxes': [{'category': str, 'bbox': (x1, y1, x2, y2),
                        'category_id': int, 'score': float}, ...],
            'corefs': [[mol_bbox_idx, identifier_bbox_idx], ...],
        }
    """
    pil_figures = [convert_to_pil(figure) for figure in figures]
    return self.coref.predict_images(pil_figures, batch_size=batch_size, coref=True, molscribe=molscribe, ocr=ocr)
|
| 398 |
+
|
| 399 |
+
def extract_reactions_from_figures_in_pdf(self, pdf, batch_size=16, num_pages=None, molscribe=True, ocr=True):
    """
    Extract reaction diagrams from every figure found in a pdf.

    Parameters:
        pdf: path to pdf, or byte file
        batch_size: batch size for inference in all models
        num_pages: process only first `num_pages` pages, if `None` then process all
        molscribe: whether to predict and return smiles and molfile info
        ocr: whether to predict and return text of conditions
    Returns:
        list of dicts, one per figure; each holds the source 'figure' image,
        parsed 'reactions' (reactants/conditions/products with category,
        bbox, category_id and optional smiles/molfile/text), and the 'page'
        number the figure came from.
    """
    pdf_figures = self.extract_figures_from_pdf(pdf, num_pages=num_pages, output_bbox=True)
    figure_images = [entry['figure']['image'] for entry in pdf_figures]
    reaction_results = self.extract_reactions_from_figures(
        figure_images, batch_size=batch_size, molscribe=molscribe, ocr=ocr)
    # Tag each result with the page its source figure was found on.
    for entry, reaction_result in zip(pdf_figures, reaction_results):
        reaction_result['page'] = entry['page']
    return reaction_results
|
| 451 |
+
|
| 452 |
+
def extract_reactions_from_figures(self, figures, batch_size=16, molscribe=True, ocr=True):
    """
    Extract reaction information from a list of figure images.

    Parameters:
        figures: list of PIL or ndarray images
        batch_size: batch size for inference in all models
        molscribe: whether to predict and return smiles and molfile info
        ocr: whether to predict and return text of conditions
    Returns:
        list of {'figure': original input image, 'reactions': parsed
        reactions} dicts, one per input figure; each reaction has
        'reactants', 'conditions' and 'products' entries carrying
        category, bbox, category_id and optional smiles/molfile/text.
    """
    pil_inputs = [convert_to_pil(figure) for figure in figures]
    predictions = self.rxnscribe.predict_images(
        pil_inputs, batch_size=batch_size, molscribe=molscribe, ocr=ocr)
    # Pair each prediction back with the figure exactly as the caller supplied it.
    return [
        {'figure': figure, 'reactions': prediction}
        for figure, prediction in zip(figures, predictions)
    ]
|
| 507 |
+
|
| 508 |
+
def extract_molecules_from_text_in_pdf(self, pdf, batch_size=16, num_pages=None):
    """
    Find molecule mentions in the running text of a pdf.

    Parameters:
        pdf: path to pdf, or byte file
        batch_size: batch size for inference in all models
        num_pages: process only first `num_pages` pages, if `None` then process all
    Returns:
        list of per-page dicts:
            {'molecules': [{'text': paragraph text,
                            'labels': [(label, start, end), ...]}, ...],
             'page': page number}
        where `start` is inclusive and `end` exclusive into the text.
    """
    self.chemrxnextractor.set_pdf_file(pdf)
    self.chemrxnextractor.set_pages(num_pages)
    pages = self.chemrxnextractor.get_paragraphs_from_pdf(num_pages)
    output_pages = []
    for page_data in pages:
        # Flatten each paragraph into a single newline-free string for the NER model.
        paragraph_texts = [
            ' '.join(paragraph).replace('\n', '')
            for paragraph in page_data['paragraphs']
        ]
        ner_labels = self.chemner.predict_strings(paragraph_texts, batch_size=batch_size)
        output_pages.append({
            'molecules': [
                {'text': text, 'labels': labels}
                for text, labels in zip(paragraph_texts, ner_labels)
            ],
            'page': page_data['page'],
        })
    return output_pages
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
def extract_reactions_from_text_in_pdf(self, pdf, num_pages=None):
    """
    Extract reaction role information (reactants, products, conditions, ...)
    from the running text of a pdf.

    Parameters:
        pdf: path to pdf
        num_pages: process only first `num_pages` pages, if `None` then process all
    Returns:
        list of per-page dicts: {'page': page number, 'reactions': [
            {'tokens': sentence tokens,
             'reactions': [{label: (tokens, start, end) or list of such
                            tuples, indices inclusive}, ...]}, ...]}
    """
    extractor = self.chemrxnextractor
    extractor.set_pdf_file(pdf)
    extractor.set_pages(num_pages)
    return extractor.extract_reactions_from_text()
|
| 587 |
+
|
| 588 |
+
def extract_reactions_from_text_in_pdf_combined(self, pdf, num_pages=None):
    """
    Extract text reactions from a pdf and enrich them with molecule
    corefs recovered from the pdf's figures.

    Parameters:
        pdf: path to pdf
        num_pages: process only first `num_pages` pages, if `None` then process all
    Returns:
        same per-page structure as extract_reactions_from_text_in_pdf,
        after label spans have been associated with figure corefs.
    """
    text_reactions = self.extract_reactions_from_text_in_pdf(pdf, num_pages=num_pages)
    figure_corefs = self.extract_molecule_corefs_from_figures_in_pdf(pdf, num_pages=num_pages)
    return associate_corefs(text_reactions, figure_corefs)
|
| 620 |
+
|
| 621 |
+
def extract_reactions_from_figures_and_tables_in_pdf(self, pdf, num_pages=None, batch_size=16, molscribe=True, ocr=True):
    """
    Extract reactions from the figures of a pdf and expand them with the
    information found in accompanying tables and coref annotations.

    Parameters:
        pdf: path to pdf, or byte file
        num_pages: process only first `num_pages` pages, if `None` then process all
        batch_size: batch size for inference in all models
        molscribe: whether to predict and return smiles and molfile info
        ocr: whether to predict and return text of conditions
    Returns:
        list of per-figure dicts with 'figure', parsed 'reactions'
        (reactants/conditions/products) and 'page', after table expansion,
        R-group substitution and backout-based expansion.
    """
    pdf_figures = self.extract_figures_from_pdf(pdf, num_pages=num_pages, output_bbox=True)
    figure_images = [entry['figure']['image'] for entry in pdf_figures]
    # Pipeline: raw diagram parses -> table expansion -> R-group
    # substitution via corefs -> backout expansion. Order matters.
    parsed = self.extract_reactions_from_figures(
        figure_images, batch_size=batch_size, molscribe=molscribe, ocr=ocr)
    parsed = process_tables(pdf_figures, parsed, self.molscribe, batch_size=batch_size)
    coref_data = self.extract_molecule_corefs_from_figures_in_pdf(pdf, num_pages=num_pages)
    parsed = replace_rgroups_in_figure(pdf_figures, parsed, coref_data, self.molscribe, batch_size=batch_size)
    parsed = expand_reactions_with_backout(parsed, coref_data, self.molscribe)
    return parsed
|
| 673 |
+
|
| 674 |
+
def extract_reactions_from_pdf(self, pdf, num_pages=None, batch_size=16):
    """
    Extract reactions from a pdf using every available modality.

    Parameters:
        pdf: path to pdf, or byte file
        num_pages: process only first `num_pages` pages, if `None` then process all
        batch_size: batch size for inference in all models
    Returns:
        {'figures': per-figure reaction dicts (table-expanded, R-groups
                    resolved, backout-expanded),
         'text': per-page text reactions with figure corefs associated}
    """
    pdf_figures = self.extract_figures_from_pdf(pdf, num_pages=num_pages, output_bbox=True)
    figure_images = [entry['figure']['image'] for entry in pdf_figures]
    # Figure pipeline always runs with structure recognition and OCR enabled.
    figure_parses = self.extract_reactions_from_figures(
        figure_images, batch_size=batch_size, molscribe=True, ocr=True)
    table_expanded = process_tables(pdf_figures, figure_parses, self.molscribe, batch_size=batch_size)
    text_reactions = self.extract_reactions_from_text_in_pdf(pdf, num_pages=num_pages)
    coref_data = self.extract_molecule_corefs_from_figures_in_pdf(pdf, num_pages=num_pages)
    rgroup_resolved = replace_rgroups_in_figure(
        pdf_figures, table_expanded, coref_data, self.molscribe, batch_size=batch_size)
    table_expanded = expand_reactions_with_backout(rgroup_resolved, coref_data, self.molscribe)
    coref_expanded = associate_corefs(text_reactions, coref_data)
    return {
        'figures': table_expanded,
        'text': coref_expanded,
    }
|
| 747 |
+
|
| 748 |
+
if __name__=="__main__":
    # Quick manual smoke test: instantiate the toolkit with default models.
    model = ChemIEToolkit()
|
chemietoolkit/tableextractor.py
ADDED
|
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pdf2image
|
| 2 |
+
import numpy as np
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import layoutparser as lp
|
| 6 |
+
import cv2
|
| 7 |
+
|
| 8 |
+
from PyPDF2 import PdfReader, PdfWriter
|
| 9 |
+
import pandas as pd
|
| 10 |
+
|
| 11 |
+
import pdfminer.high_level
|
| 12 |
+
import pdfminer.layout
|
| 13 |
+
from operator import itemgetter
|
| 14 |
+
|
| 15 |
+
# inputs: pdf_file, page #, bounding box (optional) (llur or ullr), output_bbox
|
| 16 |
+
class TableExtractor(object):
    """
    Detect tables and figures on pdf pages with a layout-detection model and
    reconstruct table structure (columns, rows, title, footnote) from the
    pdf's text boxes via pdfminer.

    Typical usage: set_pdf_file(...), then extract_all_tables_and_figures(
    rendered_pages, layout_model).
    """
    def __init__(self, output_bbox=True):
        self.pdf_file = ""        # path to the pdf being processed
        self.page = ""            # current 0-based page index (set via set_page_num)
        self.image_dpi = 200      # dpi at which pages were rendered to images
        self.pdf_dpi = 72         # native pdf coordinate resolution
        self.output_bbox = output_bbox  # include bounding boxes in outputs
        self.blocks = {}          # layout blocks keyed by category, filled by run_model
        self.title_y = 0          # y of the most recently found table/scheme title
        self.column_header_y = 0  # y of the most recently found column-header row
        self.model = None         # layout detection model (assigned in extract_all_tables_and_figures)
        self.img = None           # rendered page image (ndarray), set by run_model
        self.output_image = True  # include cropped page images in outputs
        # Keyword lists used to assign a semantic tag to each table column
        # (first matching keyword wins; see extract_singular_table).
        self.tagging = {
            'substance': ['compound', 'salt', 'base', 'solvent', 'CBr4', 'collidine', 'InX3', 'substrate', 'ligand', 'PPh3', 'PdL2', 'Cu', 'compd', 'reagent', 'reagant', 'acid', 'aldehyde', 'amine', 'Ln', 'H2O', 'enzyme', 'cofactor', 'oxidant', 'Pt(COD)Cl2', 'CuBr2', 'additive'],
            'ratio': [':'],
            'measurement': ['μM', 'nM', 'IC50', 'CI', 'excitation', 'emission', 'Φ', 'φ', 'shift', 'ee', 'ΔG', 'ΔH', 'TΔS', 'Δ', 'distance', 'trajectory', 'V', 'eV'],
            'temperature': ['temp', 'temperature', 'T', '°C'],
            'time': ['time', 't(', 't ('],
            'result': ['yield', 'aa', 'result', 'product', 'conversion', '(%)'],
            'alkyl group': ['R', 'Ar', 'X', 'Y'],
            'solvent': ['solvent'],
            'counter': ['entry', 'no.'],
            'catalyst': ['catalyst', 'cat.'],
            'conditions': ['condition'],
            'reactant': ['reactant'],
        }

    def set_output_image(self, oi):
        """Toggle whether cropped images are included in results."""
        self.output_image = oi

    def set_pdf_file(self, pdf):
        """Set the path of the pdf to process."""
        self.pdf_file = pdf

    def set_page_num(self, pn):
        """Set the 0-based page index to process."""
        self.page = pn

    def set_output_bbox(self, ob):
        """Toggle whether bounding boxes are included in results."""
        self.output_bbox = ob

    def run_model(self, page_info):
        """
        Run the layout model on one rendered page and cache the detected
        blocks by category ('text', 'title', 'list', 'table', 'figure').

        Parameters:
            page_info: rendered page image (PIL image or array-like)
        """
        #img = np.asarray(pdf2image.convert_from_path(self.pdf_file, dpi=self.image_dpi)[self.page])
        #model = lp.Detectron2LayoutModel('lp://PubLayNet/mask_rcnn_X_101_32x8d_FPN_3x/config', extra_config=["MODEL.ROI_HEADS.SCORE_THRESH_TEST", 0.5], label_map={0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"})
        img = np.asarray(page_info)
        self.img = img

        layout_result = self.model.detect(img)

        # Partition detections by their predicted PubLayNet category.
        text_blocks = lp.Layout([b for b in layout_result if b.type == 'Text'])
        title_blocks = lp.Layout([b for b in layout_result if b.type == 'Title'])
        list_blocks = lp.Layout([b for b in layout_result if b.type == 'List'])
        table_blocks = lp.Layout([b for b in layout_result if b.type == 'Table'])
        figure_blocks = lp.Layout([b for b in layout_result if b.type == 'Figure'])

        self.blocks.update({'text': text_blocks})
        self.blocks.update({'title': title_blocks})
        self.blocks.update({'list': list_blocks})
        self.blocks.update({'table': table_blocks})
        self.blocks.update({'figure': figure_blocks})

    # type is what coordinates you want to get. it comes in text, title, list, table, and figure
    def convert_to_pdf_coordinates(self, type):
        """
        Convert cached image-space blocks of the given category to pdf
        coordinates (origin bottom-left, 72 dpi).

        NOTE: the parameter shadows the builtin `type`; kept for API
        compatibility.

        Returns:
            list of (x1, y1, x2, y2) tuples in pdf coordinates.
        """
        # scale coordinates from image dpi down to pdf dpi

        blocks = self.blocks[type]
        coordinates = [blocks[a].scale(self.pdf_dpi/self.image_dpi) for a in range(len(blocks))]

        reader = PdfReader(self.pdf_file)

        writer = PdfWriter()
        p = reader.pages[self.page]
        # a[1] is the page height; used to flip the y-axis (image y grows
        # downward, pdf y grows upward).
        a = p.mediabox.upper_left
        new_coords = []
        for new_block in coordinates:
            new_coords.append((new_block.block.x_1, pd.to_numeric(a[1]) - new_block.block.y_2, new_block.block.x_2, pd.to_numeric(a[1]) - new_block.block.y_1))

        return new_coords
    # output: list of bounding boxes for tables but in pdf coordinates

    # input: new_coords is singular table bounding box in pdf coordinates
    def extract_singular_table(self, new_coords):
        """
        Reconstruct the structure of one table from pdfminer text lines.

        Parameters:
            new_coords: table bounding box (x1, y1, x2, y2) in pdf coordinates
        Returns:
            {'columns': [{'text', 'tag'[, 'bbox']}, ...],
             'rows': [[cell, ...], ...]} for the first page layout that
            yields more than one text line inside the box; None implicitly
            if nothing qualifies.
        """
        for page_layout in pdfminer.high_level.extract_pages(self.pdf_file, page_numbers=[self.page]):
            elements = []
            for element in page_layout:
                if isinstance(element, pdfminer.layout.LTTextBox):
                    for e in element._objs:
                        temp = e.bbox
                        # keep only horizontal text lines fully inside the table box
                        if temp[0] > min(new_coords[0], new_coords[2]) and temp[0] < max(new_coords[0], new_coords[2]) and temp[1] > min(new_coords[1], new_coords[3]) and temp[1] < max(new_coords[1], new_coords[3]) and temp[2] > min(new_coords[0], new_coords[2]) and temp[2] < max(new_coords[0], new_coords[2]) and temp[3] > min(new_coords[1], new_coords[3]) and temp[3] < max(new_coords[1], new_coords[3]) and isinstance(e, pdfminer.layout.LTTextLineHorizontal):
                            elements.append([e.bbox[0], e.bbox[1], e.bbox[2], e.bbox[3], e.get_text()])

            elements = sorted(elements, key=itemgetter(0))
            # w: lines ordered top-to-bottom (descending top-y)
            w = sorted(elements, key=itemgetter(3), reverse=True)
            if len(w) <= 1:
                continue

            ret = {}
            i = 1
            # g collects the header row: lines whose vertical extents overlap
            # the topmost line.
            g = [w[0]]

            while i < len(w) and w[i][3] > w[i-1][1]:
                g.append(w[i])
                i += 1
            g = sorted(g, key=itemgetter(0))
            # check for overlaps: merge horizontally overlapping header cells
            for a in range(len(g)-1, 0, -1):
                if g[a][0] < g[a-1][2]:
                    g[a-1][0] = min(g[a][0], g[a-1][0])
                    g[a-1][1] = min(g[a][1], g[a-1][1])
                    g[a-1][2] = max(g[a][2], g[a-1][2])
                    g[a-1][3] = max(g[a][3], g[a-1][3])
                    g[a-1][4] = g[a-1][4].strip() + " " + g[a][4]
                    g.pop(a)


            ret.update({"columns":[]})
            for t in g:
                temp_bbox = t[:4]

                column_text = t[4].strip()
                # tag the column with the first matching keyword category
                tag = 'unknown'
                tagged = False
                for key in self.tagging.keys():
                    for word in self.tagging[key]:
                        if word in column_text:
                            tag = key
                            tagged = True
                            break
                    if tagged:
                        break

                if self.output_bbox:
                    ret["columns"].append({'text':column_text,'tag': tag, 'bbox':temp_bbox})
                else:
                    ret["columns"].append({'text':column_text,'tag': tag})
                self.column_header_y = max(t[1], t[3])
            ret.update({"rows":[]})

            # Sentinel pseudo-columns at the table's left/right edges make the
            # alignment test below uniform for the first and last column.
            g.insert(0, [0, 0, new_coords[0], 0, ''])
            g.append([new_coords[2], 0, 0, 0, ''])
            while i < len(w):
                # gather the next visual row: lines with overlapping vertical extent
                group = [w[i]]
                i += 1
                while i < len(w) and w[i][3] > w[i-1][1]:
                    group.append(w[i])
                    i += 1
                group = sorted(group, key=itemgetter(0))

                # merge horizontally overlapping cells within the row
                for a in range(len(group)-1, 0, -1):
                    if group[a][0] < group[a-1][2]:
                        group[a-1][0] = min(group[a][0], group[a-1][0])
                        group[a-1][1] = min(group[a][1], group[a-1][1])
                        group[a-1][2] = max(group[a][2], group[a-1][2])
                        group[a-1][3] = max(group[a][3], group[a-1][3])
                        group[a-1][4] = group[a-1][4].strip() + " " + group[a][4]
                        group.pop(a)

                # Align row cells to header columns, inserting placeholder
                # cells ('\n') where a column has no text in this row.
                a = 1
                while a < len(g) - 1:
                    if a > len(group):
                        group.append([0, 0, 0, 0, '\n'])
                        a += 1
                        continue
                    if group[a-1][0] >= g[a-1][2] and group[a-1][2] <= g[a+1][0]:
                        pass
                        """
                        if a < len(group) and group[a][0] >= g[a-1][2] and group[a][2] <= g[a+1][0]:
                            g.insert(1, [g[0][2], 0, group[a-1][2], 0, ''])
                            #ret["columns"].insert(0, '')
                        else:
                            a += 1
                            continue
                        """
                    else: group.insert(a-1, [0, 0, 0, 0, '\n'])
                    a += 1


                added_row = []
                for t in group:
                    temp_bbox = t[:4]
                    if self.output_bbox:
                        added_row.append({'text':t[4].strip(), 'bbox':temp_bbox})
                    else:
                        added_row.append(t[4].strip())
                ret["rows"].append(added_row)
            # If the detected header width disagrees with the data rows,
            # assume the first data row is the real header and retag it.
            if ret["rows"] and len(ret["rows"][0]) != len(ret["columns"]):
                ret["columns"] = ret["rows"][0]
                ret["rows"] = ret["rows"][1:]
                for col in ret['columns']:
                    tag = 'unknown'
                    tagged = False
                    for key in self.tagging.keys():
                        for word in self.tagging[key]:
                            if word in col['text']:
                                tag = key
                                tagged = True
                                break
                        if tagged:
                            break
                    col['tag'] = tag

            return ret

    def get_title_and_footnotes(self, tb_coords):
        """
        Find the title and footnote text closest to a table/figure box.

        Parameters:
            tb_coords: bounding box (x1, y1, x2, y2) in pdf coordinates
        Returns:
            (title, footnote) — dicts with 'text' and 'bbox' when
            self.output_bbox is set, else plain strings (empty when none
            found within the gap thresholds).
        """

        for page_layout in pdfminer.high_level.extract_pages(self.pdf_file, page_numbers=[self.page]):
            title = (0, 0, 0, 0, '')
            footnote = (0, 0, 0, 0, '')
            # maximum allowed vertical distance between the box and its
            # title/footnote candidates (pdf points)
            title_gap = 30
            footnote_gap = 30
            for element in page_layout:
                if isinstance(element, pdfminer.layout.LTTextBoxHorizontal):
                    # require horizontal overlap with the table box
                    if (element.bbox[0] >= tb_coords[0] and element.bbox[0] <= tb_coords[2]) or (element.bbox[2] >= tb_coords[0] and element.bbox[2] <= tb_coords[2]) or (tb_coords[0] >= element.bbox[0] and tb_coords[0] <= element.bbox[2]) or (tb_coords[2] >= element.bbox[0] and tb_coords[2] <= element.bbox[2]):
                        #print(element)
                        # keep the 'Table ...' / 'Scheme ...' line nearest above the box
                        if 'Table' in element.get_text():
                            if abs(element.bbox[1] - tb_coords[3]) < title_gap:
                                title = tuple(element.bbox) + (element.get_text()[element.get_text().index('Table'):].replace('\n', ' '),)
                                title_gap = abs(element.bbox[1] - tb_coords[3])
                        if 'Scheme' in element.get_text():
                            if abs(element.bbox[1] - tb_coords[3]) < title_gap:
                                title = tuple(element.bbox) + (element.get_text()[element.get_text().index('Scheme'):].replace('\n', ' '),)
                                title_gap = abs(element.bbox[1] - tb_coords[3])
                        # skip text that sits inside the box itself
                        if element.bbox[1] >= tb_coords[1] and element.bbox[3] <= tb_coords[3]: continue
                        #print(element)
                        # footnotes are assumed to start with a superscript-like
                        # 'a' marker followed by a letter/digit — TODO confirm
                        temp = ['aA', 'aB', 'aC', 'aD', 'aE', 'aF', 'aG', 'aH', 'aI', 'aJ', 'aK', 'aL', 'aM', 'aN', 'aO', 'aP', 'aQ', 'aR', 'aS', 'aT', 'aU', 'aV', 'aW', 'aX', 'aY', 'aZ', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6', 'a7', 'a8', 'a9', 'a0']
                        for segment in temp:
                            if segment in element.get_text():
                                if abs(element.bbox[3] - tb_coords[1]) < footnote_gap:
                                    footnote = tuple(element.bbox) + (element.get_text()[element.get_text().index(segment):].replace('\n', ' '),)
                                    footnote_gap = abs(element.bbox[3] - tb_coords[1])
                                break
            self.title_y = min(title[1], title[3])
            if self.output_bbox:
                return ({'text': title[4], 'bbox': list(title[:4])}, {'text': footnote[4], 'bbox': list(footnote[:4])})
            else:
                return (title[4], footnote[4])

    def extract_table_information(self):
        """
        Build result dicts for every table detected on the current page.

        Requires run_model to have been called for this page. Returns a list
        of {'title', 'figure', 'table', 'footnote', 'page'} dicts; the
        'figure' bbox is populated only when the title sits far from the
        column headers (heuristic for a scheme image above the table).
        """
        #self.run_model(page_info) # changed
        table_coordinates = self.blocks['table'] #should return a list of layout objects
        table_coordinates_in_pdf = self.convert_to_pdf_coordinates('table') #should return a list of lists

        ans = []
        i = 0
        for coordinate in table_coordinates_in_pdf:
            ret = {}
            # widen the box horizontally so slightly-overhanging text is kept
            pad = 20
            coordinate = [coordinate[0] - pad, coordinate[1], coordinate[2] + pad, coordinate[3]]
            ullr_coord = [coordinate[0], coordinate[3], coordinate[2], coordinate[1]]

            table_results = self.extract_singular_table(coordinate)
            tf = self.get_title_and_footnotes(coordinate)
            figure = Image.fromarray(table_coordinates[i].crop_image(self.img))
            ret.update({'title': tf[0]})
            ret.update({'figure': {
                'image': None,
                'bbox': []
            }})
            if self.output_image:
                ret['figure']['image'] = figure
            ret.update({'table': {'bbox': list(coordinate), 'content': table_results}})
            ret.update({'footnote': tf[1]})
            # Large title-to-header gap suggests a drawn scheme sits between
            # them; expose the crop bbox in that case.
            if abs(self.title_y - self.column_header_y) > 50:
                ret['figure']['bbox'] = list(coordinate)

            ret.update({'page':self.page})

            ans.append(ret)
            i += 1

        return ans

    def extract_figure_information(self):
        """
        Build result dicts for every figure detected on the current page.

        Returns a list shaped like extract_table_information's output, with
        an empty 'table' entry and the figure bbox always populated.
        """
        figure_coordinates = self.blocks['figure']
        figure_coordinates_in_pdf = self.convert_to_pdf_coordinates('figure')

        ans = []
        for i in range(len(figure_coordinates)):
            ret = {}
            coordinate = figure_coordinates_in_pdf[i]
            ullr_coord = [coordinate[0], coordinate[3], coordinate[2], coordinate[1]]

            tf = self.get_title_and_footnotes(coordinate)
            figure = Image.fromarray(figure_coordinates[i].crop_image(self.img))
            ret.update({'title':tf[0]})
            ret.update({'figure': {
                'image': None,
                'bbox': []
            }})
            if self.output_image:
                ret['figure']['image'] = figure
            ret.update({'table': {
                'bbox': [],
                'content': None
            }})
            ret.update({'footnote': tf[1]})
            ret['figure']['bbox'] = list(coordinate)

            ret.update({'page':self.page})

            ans.append(ret)

        return ans


    def extract_all_tables_and_figures(self, pages, pdfparser, content=None):
        """
        Run layout detection and extraction over all rendered pages.

        Parameters:
            pages: list of rendered page images (index == page number)
            pdfparser: layout detection model to use
            content: 'tables' for tables only, 'figures' for figures plus
                tables that carry a figure bbox, anything else for both
        Returns:
            flat list of per-table/per-figure result dicts across all pages.
        """
        self.model = pdfparser
        ret = []
        for i in range(len(pages)):
            self.set_page_num(i)
            self.run_model(pages[i])
            table_info = self.extract_table_information()
            figure_info = self.extract_figure_information()
            if content == 'tables':
                ret += table_info
            elif content == 'figures':
                ret += figure_info
                for table in table_info:
                    if table['figure']['bbox'] != []:
                        ret.append(table)
            else:
                ret += table_info
                ret += figure_info
        return ret
|
chemietoolkit/utils.py
ADDED
|
@@ -0,0 +1,1018 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from PIL import Image
|
| 3 |
+
import cv2
|
| 4 |
+
import layoutparser as lp
|
| 5 |
+
from rdkit import Chem
|
| 6 |
+
from rdkit.Chem import Draw
|
| 7 |
+
from rdkit.Chem import rdDepictor
|
| 8 |
+
rdDepictor.SetPreferCoordGen(True)
|
| 9 |
+
from rdkit.Chem.Draw import IPythonConsole
|
| 10 |
+
from rdkit.Chem import AllChem
|
| 11 |
+
import re
|
| 12 |
+
import copy
|
| 13 |
+
|
| 14 |
+
# Map from MolScribe bond-type strings to the integer codes used in the
# adjacency-matrix ('edges') representation of a molecular graph.
BOND_TO_INT = {
    "": 0,
    "single": 1,
    "double": 2,
    "triple": 3,
    "aromatic": 4,
    "solid wedge": 5,
    "dashed wedge": 6
}

# Atom symbols treated as R-group placeholders when they appear in a parsed
# molecular graph (plain names; the bracketed forms are appended below).
RGROUP_SYMBOLS = ['R', 'R1', 'R2', 'R3', 'R4', 'R5', 'R6', 'R7', 'R8', 'R9', 'R10', 'R11', 'R12',
                  'Ra', 'Rb', 'Rc', 'Rd', 'Rf', 'X', 'Y', 'Z', 'Q', 'A', 'E', 'Ar', 'Ar1', 'Ar2', 'Ari', "R'",
                  '1*', '2*','3*', '4*','5*', '6*','7*', '8*','9*', '10*','11*', '12*','[a*]', '[b*]','[c*]', '[d*]']

# Also accept every symbol wrapped in square brackets, e.g. '[R1]'.
RGROUP_SYMBOLS = RGROUP_SYMBOLS + [f'[{i}]' for i in RGROUP_SYMBOLS]

# SMILES tokens that denote R-group attachment points / dummy atoms.
RGROUP_SMILES = ['[1*]', '[2*]','[3*]', '[4*]','[5*]', '[6*]','[7*]', '[8*]','[9*]', '[10*]','[11*]', '[12*]','[a*]', '[b*]','[c*]', '[d*]','*', '[Rf]']
|
| 31 |
+
|
| 32 |
+
def get_figures_from_pages(pages, pdfparser):
    """Detect and crop every 'Figure' region from a sequence of page images.

    Args:
        pages: iterable of page images (anything np.asarray accepts).
        pdfparser: layout-detection model exposing .detect(img).

    Returns:
        List of {'image': PIL.Image, 'page': page_index} dicts.
    """
    extracted = []
    for page_idx, page in enumerate(pages):
        page_arr = np.asarray(page)
        detected = pdfparser.detect(page_arr)
        figure_blocks = lp.Layout([blk for blk in detected if blk.type == "Figure"])
        for blk in figure_blocks:
            cropped = Image.fromarray(blk.crop_image(page_arr))
            extracted.append({'image': cropped, 'page': page_idx})
    return extracted
|
| 45 |
+
|
| 46 |
+
def clean_bbox_output(figures, bboxes):
    """Crop every '[Mol]' detection out of its source figure.

    Args:
        figures: list of HxWx3 image arrays, parallel to `bboxes`.
        bboxes: per-figure lists of detections with 'category', 'bbox'
            (normalized x1, y1, x2, y2) and 'score'.

    Returns:
        (results, cropped, references): one result dict per figure, a flat
        list of crops, and a flat list aliasing the per-molecule dicts.
    """
    results = []
    cropped = []
    references = []
    for i, detections in enumerate(bboxes):
        entry = {'image': figures[i], 'molecules': []}
        results.append(entry)
        for det in detections:
            if det['category'] != '[Mol]':
                continue
            x1, y1, x2, y2 = det['bbox']
            h, w, _ = figures[i].shape
            # bbox coordinates are normalized; scale to pixel indices.
            crop = figures[i][int(y1 * h):int(y2 * h), int(x1 * w):int(x2 * w)]
            mol_record = {
                'bbox': det['bbox'],
                'score': det['score'],
                'image': crop,
            }
            cropped.append(crop)
            entry['molecules'].append(mol_record)
            references.append(mol_record)
    return results, cropped, references
|
| 71 |
+
|
| 72 |
+
def convert_to_pil(image):
    """Convert a cv2-style numpy image to a PIL RGB Image.

    Arrays have their channel order swapped (cv2 is BGR, PIL expects RGB);
    any other input is assumed to already be a PIL image and is returned
    unchanged.
    """
    # isinstance instead of `type(...) ==` also accepts ndarray subclasses.
    if isinstance(image, np.ndarray):
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(image)
    return image
|
| 77 |
+
|
| 78 |
+
def convert_to_cv2(image):
    """Convert a PIL image to a cv2-style numpy array.

    Arrays are returned unchanged; anything else is converted via numpy with
    its channel order swapped. (COLOR_BGR2RGB and COLOR_RGB2BGR are the same
    permutation, so the constant used here is equivalent either way.)
    """
    # isinstance instead of `type(...) !=` also accepts ndarray subclasses.
    if not isinstance(image, np.ndarray):
        image = cv2.cvtColor(np.array(image), cv2.COLOR_BGR2RGB)
    return image
|
| 82 |
+
|
| 83 |
+
def replace_rgroups_in_figure(figures, results, coref_results, molscribe, batch_size=16):
    """Expand textual R-group definitions (e.g. 'R1 = Me') into new reactions.

    For each figure, scans '[Idt]' text boxes for definitions of the form
    'R1 = group', substitutes each definition into the first detected
    reaction's molecular graphs, and appends the substituted reaction to
    result['reactions'].

    Args:
        figures: figure records; figure['figure']['image'] is the page crop.
        results: per-figure reaction-extraction results (mutated in place).
        coref_results: per-figure coreference output with 'bboxes'.
        molscribe: MolScribe model used to re-parse substituted graphs.
        batch_size: forwarded to get_atoms_and_bonds.

    Returns:
        The (mutated) `results` list.
    """
    # Raw string: '\d'/'\w' in a plain literal are invalid escape sequences
    # (DeprecationWarning on modern Python).
    pattern = re.compile(r'(?P<name>[RXY]\d?)[ ]*=[ ]*(?P<group>\w+)')
    for figure, result, corefs in zip(figures, results, coref_results):
        r_groups = []
        seen_r_groups = set()
        for bbox in corefs['bboxes']:
            if bbox['category'] == '[Idt]':
                for text in bbox['text']:
                    res = pattern.search(text)
                    if res is None:
                        continue
                    name = res.group('name')
                    group = res.group('group')
                    # Deduplicate identical (name, group) definitions.
                    if (name, group) in seen_r_groups:
                        continue
                    seen_r_groups.add((name, group))
                    r_groups.append({name: group})
        if r_groups and result['reactions']:
            # Keep only the R-group names for matching against atom symbols.
            seen_r_groups = set([pair[0] for pair in seen_r_groups])
            orig_reaction = result['reactions'][0]
            graphs = get_atoms_and_bonds(figure['figure']['image'], orig_reaction, molscribe, batch_size=batch_size)
            # relevant_locs: graph index -> [(r-group name, atom index), ...]
            relevant_locs = {}
            for i, graph in enumerate(graphs):
                to_add = []
                for j, atom in enumerate(graph['chartok_coords']['symbols']):
                    # atom symbols are bracketed, e.g. '[R1]' -> 'R1'.
                    if atom[1:-1] in seen_r_groups:
                        to_add.append((atom[1:-1], j))
                relevant_locs[i] = to_add

            # One substituted reaction per parsed R-group definition.
            for r_group in r_groups:
                reaction = get_replaced_reaction(orig_reaction, graphs, relevant_locs, r_group, molscribe)
                to_add = {
                    'reactants': reaction['reactants'][:],
                    'conditions': orig_reaction['conditions'][:],
                    'products': reaction['products'][:]
                }
                result['reactions'].append(to_add)
    return results
|
| 121 |
+
|
| 122 |
+
def process_tables(figures, results, molscribe, batch_size=16):
    """Expand table rows into per-row reactions or extra condition sets.

    For each figure with a parsed table, columns tagged 'alkyl group' supply
    R-group substitutions (producing a new substituted reaction per row);
    all other columns are appended as '[Table]' condition entries.

    Mutates each `result` in place (adds 'page', appends to 'reactions',
    and rewrites the first reaction's 'conditions' into a list of condition
    sets) and returns `results`.
    """
    # Matches e.g. '4-OMe (3a)': optional prefix, the group itself, optional
    # parenthesized label.
    r_group_pattern = re.compile(r'^(\w+-)?(?P<group>[\w-]+)( \(\w+\))?$')
    for figure, result in zip(figures, results):
        result['page'] = figure['page']
        if figure['table']['content'] is not None:
            content = figure['table']['content']
            if len(result['reactions']) > 1:
                # Ambiguous: the table is applied to reactions[0] below anyway.
                print("Warning: multiple reactions detected for table")
            elif len(result['reactions']) == 0:
                continue
            orig_reaction = result['reactions'][0]
            graphs = get_atoms_and_bonds(figure['figure']['image'], orig_reaction, molscribe, batch_size=batch_size)
            # graph index -> [(r-group name, atom index), ...]
            relevant_locs = find_relevant_groups(graphs, content['columns'])
            conditions_to_extend = []
            for row in content['rows']:
                r_groups = {}
                expanded_conditions = orig_reaction['conditions'][:]
                replaced = False
                for col, entry in zip(content['columns'], row):
                    if col['tag'] != 'alkyl group':
                        # Non-R-group cells become extra condition entries.
                        expanded_conditions.append({
                            'category': '[Table]',
                            'text': entry['text'],
                            'tag': col['tag'],
                            'header': col['text'],
                        })
                    else:
                        found = r_group_pattern.match(entry['text'])
                        if found is not None:
                            r_groups[col['text']] = found.group('group')
                            replaced = True
                # NOTE: computed even when no substitution happened; only used
                # when `replaced` is True.
                reaction = get_replaced_reaction(orig_reaction, graphs, relevant_locs, r_groups, molscribe)
                if replaced:
                    to_add = {
                        'reactants': reaction['reactants'][:],
                        'conditions': expanded_conditions,
                        'products': reaction['products'][:]
                    }
                    result['reactions'].append(to_add)
                else:
                    conditions_to_extend.append(expanded_conditions)
            # Rewrap: conditions becomes a list of condition sets — the
            # original set first, then one per non-substituting table row.
            orig_reaction['conditions'] = [orig_reaction['conditions']]
            orig_reaction['conditions'].extend(conditions_to_extend)
    return results
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def get_atoms_and_bonds(image, reaction, molscribe, batch_size=16):
    """Run MolScribe on every '[Mol]' entity of a reaction.

    Crops each molecule's bbox from `image` and returns one graph dict per
    molecule: atom coords/symbols, an integer adjacency matrix ('edges'),
    and 'key' = (reaction role, index) locating the entity in `reaction`.
    """
    canvas = convert_to_cv2(image)
    crops = []
    graphs = []
    for role, entities in reaction.items():
        for pos, entity in enumerate(entities):
            if type(entity) != dict or entity['category'] != '[Mol]':
                continue
            x1, y1, x2, y2 = entity['bbox']
            h, w, _ = canvas.shape
            # bbox coordinates are normalized; scale to pixel indices.
            crop = canvas[int(y1 * h):int(y2 * h), int(x1 * w):int(x2 * w)]
            crops.append(crop)
            graphs.append({
                'image': crop,
                'chartok_coords': {
                    'coords': [],
                    'symbols': [],
                },
                'edges': [],
                'key': (role, pos),
            })
    predictions = molscribe.predict_images(crops, return_atoms_bonds=True, batch_size=batch_size)
    for prediction, graph in zip(predictions, graphs):
        coords = graph['chartok_coords']['coords']
        symbols = graph['chartok_coords']['symbols']
        for atom in prediction['atoms']:
            coords.append((atom['x'], atom['y']))
            symbols.append(atom['atom_symbol'])
        n = len(prediction['atoms'])
        graph['edges'] = [[0] * n for _ in range(n)]
        for bond in prediction['bonds']:
            a, b = bond['endpoint_atoms']
            code = BOND_TO_INT[bond['bond_type']]
            graph['edges'][a][b] = code
            graph['edges'][b][a] = code
    return graphs
|
| 200 |
+
|
| 201 |
+
def find_relevant_groups(graphs, columns):
    """Locate R-group placeholder atoms in each molecular graph.

    Columns tagged 'alkyl group' name the placeholders; an atom matches when
    its symbol equals the bracketed column header (e.g. header 'R1' matches
    atom '[R1]').

    Returns:
        {graph_index: [(group_name, atom_index), ...]} for every graph.
    """
    targets = set(f"[{col['text']}]" for col in columns if col['tag'] == 'alkyl group')
    matches = {}
    for g_idx, graph in enumerate(graphs):
        matches[g_idx] = [
            (symbol[1:-1], a_idx)
            for a_idx, symbol in enumerate(graph['chartok_coords']['symbols'])
            if symbol in targets
        ]
    return matches
|
| 211 |
+
|
| 212 |
+
def get_replaced_reaction(orig_reaction, graphs, relevant_locs, mappings, molscribe):
    """Return a copy of a reaction with R-group placeholders substituted.

    Placeholder atom symbols listed in `relevant_locs` are replaced per
    `mappings`, then each graph is re-converted to SMILES/molfile and written
    into the copied reaction. Non-'[Mol]' entities are shared by reference;
    '[Mol]' entities are shallow-copied so they can be overwritten safely.
    """
    # Shallow-copy each graph so symbol substitution cannot touch the caller's data.
    duplicated_graphs = [
        {
            'image': g['image'],
            'chartok_coords': {
                'coords': list(g['chartok_coords']['coords']),
                'symbols': list(g['chartok_coords']['symbols']),
            },
            'edges': list(g['edges']),
            'key': g['key'],
        }
        for g in graphs
    ]
    # Swap in the replacement wherever a mapped R-group symbol occurs.
    for g_idx, located in relevant_locs.items():
        for sym, sym_idx in located:
            if sym in mappings:
                duplicated_graphs[g_idx]['chartok_coords']['symbols'][sym_idx] = mappings[sym]

    def _clone_entity(bucket, entity):
        # '[Mol]' entries get a fresh dict (their smiles/molfile are rewritten
        # below); everything else is shared by reference.
        bucket.append(dict(entity) if entity['category'] == '[Mol]' else entity)

    replaced_reaction = {}
    for role, entities in orig_reaction.items():
        bucket = replaced_reaction[role] = []
        for entity in entities:
            if type(entity) == list:
                inner = []
                for member in entity:
                    _clone_entity(inner, member)
                bucket.append(inner)
            else:
                _clone_entity(bucket, entity)

    # Re-run graph -> structure conversion on the substituted graphs and write
    # the results back into the corresponding '[Mol]' entries.
    for g in duplicated_graphs:
        converted = molscribe.convert_graph_to_output([g], [g['image']])
        target = replaced_reaction[g['key'][0]][g['key'][1]]
        target['smiles'] = converted[0]['smiles']
        target['molfile'] = converted[0]['molfile']
    return replaced_reaction
|
| 254 |
+
|
| 255 |
+
def get_sites(tar, ref, ref_site = False):
    """Find template atoms of `tar` that connect to atoms outside the template.

    Aligns `tar` onto the template `ref` and, for each template atom bonded to
    a non-template atom, records the aligned index.

    NOTE(review): both branches of the `ref_site` conditional append the same
    value (`idx_pair[...][0]`), so the flag currently has no effect.
    Presumably the `else` branch was meant to append the target-side index
    (`idx_pair[...][1]`, i.e. `j.GetIdx()`) — confirm against callers before
    changing behavior.
    """
    rdDepictor.Compute2DCoords(ref)
    rdDepictor.Compute2DCoords(tar)
    # Atom-index pairs from the 2D alignment; element [1] of each pair is an
    # index into `tar` (it is compared against tar atom indices below).
    idx_pair = rdDepictor.GenerateDepictionMatching2DStructure(tar, ref)

    in_template = [i[1] for i in idx_pair]  # tar atom indices covered by the template
    sites = []
    for i in range(tar.GetNumAtoms()):
        if i not in in_template:
            # Atom i lies outside the template; each template neighbor marks a site.
            for j in tar.GetAtomWithIdx(i).GetNeighbors():
                if j.GetIdx() in in_template and j.GetIdx() not in sites:

                    if ref_site: sites.append(idx_pair[in_template.index(j.GetIdx())][0])
                    else: sites.append(idx_pair[in_template.index(j.GetIdx())][0])
    return sites
|
| 270 |
+
|
| 271 |
+
def get_atom_mapping(prod_mol, prod_smiles, r_sites_reversed = None):
    """Map atom indices of `prod_mol` onto the query built from `prod_smiles`.

    Builds two R-group-agnostic templates (all R-group SMILES tokens collapsed
    to the '*' dummy): one from `prod_mol`'s own SMILES ("intermediate") and
    one from `prod_smiles` ("query"). Composes
    prod_mol -> intermediate -> query, choosing the substructure match for
    which every index in `r_sites_reversed` lands on a dummy atom.

    Returns:
        (prod_mol_to_query, prod_template_mol_query): the composed index map
        and the query molecule.

    NOTE(review): despite the default, `r_sites_reversed=None` would raise a
    TypeError below; callers are expected to pass a dict of R-site indices.
    If no substructure match exists, `prod_mol_to_query` is never bound and
    the final return raises UnboundLocalError — TODO confirm intended.
    """
    prod_template_intermediate = Chem.MolToSmiles(prod_mol)
    prod_template = prod_smiles

    # Collapse every explicit R-group token to the generic '*' dummy atom.
    for r in RGROUP_SMILES:
        if r!='*' and r!='(*)':
            prod_template = prod_template.replace(r, '*')
            prod_template_intermediate = prod_template_intermediate.replace(r, '*')

    prod_template_intermediate_mol = Chem.MolFromSmiles(prod_template_intermediate)
    prod_template_mol = Chem.MolFromSmiles(prod_template)

    # Make the '*' dummies behave as wildcard queries in substructure matching.
    p = Chem.AdjustQueryParameters.NoAdjustments()
    p.makeDummiesQueries = True

    prod_template_mol_query = Chem.AdjustQueryProperties(prod_template_mol, p)
    prod_template_intermediate_mol_query = Chem.AdjustQueryProperties(prod_template_intermediate_mol, p)
    rdDepictor.Compute2DCoords(prod_mol)
    rdDepictor.Compute2DCoords(prod_template_mol_query)
    rdDepictor.Compute2DCoords(prod_template_intermediate_mol_query)
    # Align prod_mol onto the intermediate template to get index pairs.
    idx_pair = rdDepictor.GenerateDepictionMatching2DStructure(prod_mol, prod_template_intermediate_mol_query)

    intermdiate_to_prod_mol = {a:b for a,b in idx_pair}  # NOTE: unused (typo kept for history)
    prod_mol_to_intermediate = {b:a for a,b in idx_pair}

    # uniquify=False: match order matters when locating the R-group dummies.
    substructs = prod_template_mol_query.GetSubstructMatches(prod_template_intermediate_mol_query, uniquify = False)

    for substruct in substructs:

        intermediate_to_query = {a:b for a, b in enumerate(substruct)}
        query_to_intermediate = {intermediate_to_query[i]: i for i in intermediate_to_query}

        # Compose prod_mol -> intermediate -> query.
        prod_mol_to_query = {a:intermediate_to_query[prod_mol_to_intermediate[a]] for a in prod_mol_to_intermediate}

        # Accept this match only if every known R-site maps onto a dummy atom.
        good_map = True
        for i in r_sites_reversed:
            if prod_template_mol_query.GetAtomWithIdx(prod_mol_to_query[i]).GetSymbol() not in RGROUP_SMILES:
                good_map = False
        if good_map:
            break

    return prod_mol_to_query, prod_template_mol_query
|
| 325 |
+
|
| 326 |
+
def clean_corefs(coref_results_dict, idx):
    """Repair OCR digit/letter confusions in product coreference labels.

    A "good" label is the index followed by letters (e.g. '3a'). For products
    with no good label, common digit-for-letter misreads (1/l, 0/o, 5/s, 9/g)
    are补补 appended as letter variants (e.g. '31' also yields '3l').

    Args:
        coref_results_dict: {product_smiles: [label strings]} — mutated in place.
        idx: the numeric label prefix being searched for.

    Returns:
        None; mutates `coref_results_dict`.
    """
    label_pattern = rf'{re.escape(idx)}[a-zA-Z]+'
    for prod, labels in coref_results_dict.items():
        # Already has a properly-formed label — nothing to repair.
        if any(re.search(label_pattern, parsed) for parsed in labels):
            continue
        # Iterate over a snapshot so the variants we append are not themselves
        # re-examined (the original appended to the list while iterating it).
        for parsed in list(labels):
            if idx + '1' in parsed:
                labels.append(idx + 'l')
            elif idx + '0' in parsed:
                labels.append(idx + 'o')
            elif idx + '5' in parsed:
                labels.append(idx + 's')
            elif idx + '9' in parsed:
                labels.append(idx + 'g')
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def expand_r_group_label_helper(res, coref_smiles_to_graphs, other_prod, molscribe):
    """Substitute an R-group definition into another product's parsed graph.

    `res` is a regex match with 'name' (placeholder, e.g. 'R1') and 'group'
    (replacement). Rebuilds the stored atom/bond graph of `other_prod` with
    the placeholder swapped for the replacement, converts it back to SMILES
    via MolScribe, and returns the resulting RDKit molecule.
    """
    placeholder = res.group('name')
    replacement = res.group('group')
    record = coref_smiles_to_graphs[other_prod]
    atom_list = record['atoms']
    bond_list = record['bonds']

    # Symmetric integer adjacency matrix from the bond list.
    n_atoms = len(atom_list)
    adjacency = [[0] * n_atoms for _ in range(n_atoms)]
    for bond in bond_list:
        u, v = bond['endpoint_atoms']
        code = BOND_TO_INT[bond['bond_type']]
        adjacency[u][v] = code
        adjacency[v][u] = code

    # Atom symbols are bracketed ('[R1]'); strip brackets to compare, and
    # substitute the replacement group where the placeholder matches.
    symbols = [
        replacement if atom['atom_symbol'][1:-1] == placeholder else atom['atom_symbol']
        for atom in atom_list
    ]
    graph = {
        'image': None,
        'chartok_coords': {
            'coords': [(atom['x'], atom['y']) for atom in atom_list],
            'symbols': symbols,
        },
        'edges': adjacency,
        'key': None,
    }
    o = molscribe.convert_graph_to_output([graph], [graph['image']])
    return Chem.MolFromSmiles(o[0]['smiles'])
|
| 381 |
+
|
| 382 |
+
def get_r_group_frags_and_substitute(other_prod_mol, query, reactant_mols, reactant_information, parsed, toreturn):
    """Extract R-group fragments from a labeled product and graft them onto the reactants.

    Matches the product template `query` against `other_prod_mol`; whatever is
    left after deleting the matched core are the concrete R-group fragments.
    Each reactant's placeholder atom is then replaced by its fragment, and the
    resulting (reactant SMILES list, [product SMILES], parsed label) tuple is
    appended to `toreturn`.

    Args:
        other_prod_mol: RDKit mol of the concrete (labeled) product.
        query: (template query mol, {query atom idx: R-group name},
                set of H-substituted site names, number of R-groups).
        reactant_mols: RDKit mols of the template reactants.
        reactant_information: reactant idx -> [[R name, R atom idx, neighbor idx], ...].
        parsed: the coreference label carried through to the output tuple.
        toreturn: output accumulator (mutated in place).

    Returns:
        True if a substructure match was found and a tuple was appended,
        False otherwise.
    """
    prod_template_mol_query, r_sites_reversed_new, h_sites, num_r_groups = query
    # we get the substruct matches. note that we set uniquify to false since the order matters for our method
    substructs = other_prod_mol.GetSubstructMatches(prod_template_mol_query, uniquify = False)

    # for each substruct we create the mapping of the substruct onto the other_mol
    # delete all the molecules in other_mol correspond to the substruct
    # and check if they number of mol frags is equal to number of r groups
    # we do this to make sure we have the correct substruct
    if len(substructs) >= 1:
        for substruct in substructs:

            query_to_other = {a:b for a,b in enumerate(substruct)}
            other_to_query = {query_to_other[i]:i for i in query_to_other}

            editable = Chem.EditableMol(other_prod_mol)
            r_site_correspondence = []
            for r in r_sites_reversed_new:
                #get its id in substruct
                substruct_id = query_to_other[r]
                r_site_correspondence.append([substruct_id, r_sites_reversed_new[r]])

            # Delete matched core atoms in descending index order so pending
            # indices stay valid; keep the R-site positions in sync.
            for idx in tuple(sorted(substruct, reverse = True)):
                if idx not in [query_to_other[i] for i in r_sites_reversed_new]:
                    editable.RemoveAtom(idx)
                    for r_site in r_site_correspondence:
                        if idx < r_site[0]:
                            r_site[0]-=1
            other_prod_removed = editable.GetMol()

            # Correct match <=> exactly one leftover fragment per R-group.
            if len(Chem.GetMolFrags(other_prod_removed, asMols = False)) == num_r_groups:
                break

        # need to compute the sites at which correspond to each r_site_reversed

        r_site_correspondence.sort(key = lambda x: x[0])

        # f: atom idx -> fragment idx; ff: per-fragment original atom indices.
        f = []
        ff = []
        frags = Chem.GetMolFrags(other_prod_removed, asMols = True, frags = f, fragsMolAtomMapping = ff)

        # r_group_information maps r group name --> the fragment/molcule corresponding to the r group and the atom index it should be connected at
        r_group_information = {}
        for idx, r_site in enumerate(r_site_correspondence):

            r_group_information[r_site[1]]= (frags[f[r_site[0]]], ff[f[r_site[0]]].index(r_site[0]))
        # Sites enumerated as hydrogen-substituted get a bare H fragment.
        for r_site in h_sites:
            r_group_information[r_site] = (Chem.MolFromSmiles('[H]'), 0)

        # now we modify all of the reactants according to the R groups we have found
        # for every reactant we disconnect its r group symbol, and connect it to the r group
        modify_reactants = copy.deepcopy(reactant_mols)
        modified_reactant_smiles = []
        for reactant_idx in reactant_information:
            if len(reactant_information[reactant_idx]) == 0:
                # No placeholders in this reactant — keep it as-is.
                modified_reactant_smiles.append(Chem.MolToSmiles(modify_reactants[reactant_idx]))
            else:
                combined = reactant_mols[reactant_idx]
                if combined.GetNumAtoms() == 1:
                    # Reactant is a single placeholder atom: it IS the R group.
                    r_group, _, _ = reactant_information[reactant_idx][0]
                    modified_reactant_smiles.append(Chem.MolToSmiles(r_group_information[r_group][0]))
                else:
                    # Append each fragment as a disconnected component first...
                    for r_group, r_index, connect_index in reactant_information[reactant_idx]:
                        combined = Chem.CombineMols(combined, r_group_information[r_group][0])

                    editable = Chem.EditableMol(combined)
                    # ...then rewire: cut placeholder bonds and bond the old
                    # neighbor to the fragment's attachment atom. atomIdxAdder
                    # tracks each fragment's offset within the combined mol.
                    atomIdxAdder = reactant_mols[reactant_idx].GetNumAtoms()
                    for r_group, r_index, connect_index in reactant_information[reactant_idx]:
                        Chem.EditableMol.RemoveBond(editable, r_index, connect_index)
                        Chem.EditableMol.AddBond(editable, connect_index, atomIdxAdder + r_group_information[r_group][1], Chem.BondType.SINGLE)
                        atomIdxAdder += r_group_information[r_group][0].GetNumAtoms()
                    r_indices = [i[1] for i in reactant_information[reactant_idx]]

                    # Remove the now-disconnected placeholder atoms, highest
                    # index first so the remaining indices stay valid.
                    r_indices.sort(reverse = True)

                    for r_index in r_indices:
                        Chem.EditableMol.RemoveAtom(editable, r_index)

                    # Round-trip through SMILES to canonicalize/sanitize.
                    modified_reactant_smiles.append(Chem.MolToSmiles(Chem.MolFromSmiles(Chem.MolToSmiles(editable.GetMol()))))

        toreturn.append((modified_reactant_smiles, [Chem.MolToSmiles(other_prod_mol)], parsed))
        return True
    else:
        return False
|
| 473 |
+
|
| 474 |
+
def query_enumeration(prod_template_mol_query, r_sites_reversed_new, num_r_groups):
    """Enumerate template variants with each subset of R-group sites removed.

    For every subset of the template's R-group dummy atoms, produces a copy of
    the template with those dummies deleted (deleted sites are treated
    downstream as hydrogen substitutions).

    Returns:
        List of (edited_mol, {atom idx: R name} for surviving sites,
        set of H-substituted site names, number of surviving R-groups) —
        one tuple per subset.
    """
    subsets = generate_subsets(num_r_groups)

    toreturn = []

    for subset in subsets:
        # [atom index, R-group name] pairs, sorted by atom index so that
        # subset entries (positions in this list) line up across iterations.
        r_sites_list = [[i, r_sites_reversed_new[i]] for i in r_sites_reversed_new]
        r_sites_list.sort(key = lambda x: x[0])
        to_edit = Chem.EditableMol(prod_template_mol_query)

        # First detach each selected dummy from its neighbor (assumes each
        # dummy has exactly one neighbor — TODO confirm)...
        for entry in subset:
            pos = r_sites_list[entry][0]
            Chem.EditableMol.RemoveBond(to_edit, r_sites_list[entry][0], prod_template_mol_query.GetAtomWithIdx(r_sites_list[entry][0]).GetNeighbors()[0].GetIdx())
        # ...then delete them. Subsets are generated in descending order, so
        # removing by position never invalidates a position still pending.
        for entry in subset:
            pos = r_sites_list[entry][0]
            Chem.EditableMol.RemoveAtom(to_edit, pos)

        edited = to_edit.GetMol()
        # Shift surviving sites' positions left to account for each deletion
        # that occurred before them.
        for entry in subset:
            for i in range(entry + 1, num_r_groups):
                r_sites_list[i][0]-=1

        new_r_sites = {}
        new_h_sites = set()
        for i in range(num_r_groups):
            if i not in subset:
                new_r_sites[r_sites_list[i][0]] = r_sites_list[i][1]
            else:
                new_h_sites.add(r_sites_list[i][1])
        toreturn.append((edited, new_r_sites, new_h_sites, num_r_groups - len(subset)))
    return toreturn
|
| 505 |
+
|
| 506 |
+
def generate_subsets(n):
    """Return every subset of {0, ..., n-1} as a list of lists.

    Ordering (relied on by query_enumeration): subsets of smaller size first;
    within a size, each subset's elements are in descending order and subsets
    are in descending lexicographic order, e.g.
    generate_subsets(2) == [[], [1], [0], [1, 0]].

    Replaces the original recursive backtracking + full post-hoc sort with a
    direct stdlib enumeration that emits the identical order.
    """
    from itertools import combinations  # local import: file has no top-level itertools import
    # combinations() of a descending range yields, for each size k, exactly
    # the descending-lex order the old sort produced.
    return [
        list(combo)
        for k in range(n + 1)
        for combo in combinations(range(n - 1, -1, -1), k)
    ]
|
| 517 |
+
|
| 518 |
+
def backout(results, coref_results, molscribe):
|
| 519 |
+
|
| 520 |
+
toreturn = []
|
| 521 |
+
|
| 522 |
+
if not results or not results[0]['reactions'] or not coref_results:
|
| 523 |
+
return toreturn
|
| 524 |
+
|
| 525 |
+
try:
|
| 526 |
+
reactants = results[0]['reactions'][0]['reactants']
|
| 527 |
+
products = [i['smiles'] for i in results[0]['reactions'][0]['products']]
|
| 528 |
+
coref_results_dict = {coref_results[0]['bboxes'][coref[0]]['smiles']: coref_results[0]['bboxes'][coref[1]]['text'] for coref in coref_results[0]['corefs']}
|
| 529 |
+
coref_smiles_to_graphs = {coref_results[0]['bboxes'][coref[0]]['smiles']: coref_results[0]['bboxes'][coref[0]] for coref in coref_results[0]['corefs']}
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
if len(products) == 1:
|
| 533 |
+
if products[0] not in coref_results_dict:
|
| 534 |
+
print("Warning: No Label Parsed")
|
| 535 |
+
return
|
| 536 |
+
product_labels = coref_results_dict[products[0]]
|
| 537 |
+
prod = products[0]
|
| 538 |
+
label_idx = product_labels[0]
|
| 539 |
+
'''
|
| 540 |
+
if len(product_labels) == 1:
|
| 541 |
+
# get the coreference label of the product molecule
|
| 542 |
+
label_idx = product_labels[0]
|
| 543 |
+
else:
|
| 544 |
+
print("Warning: Malformed Label Parsed.")
|
| 545 |
+
return
|
| 546 |
+
'''
|
| 547 |
+
else:
|
| 548 |
+
print("Warning: More than one product detected")
|
| 549 |
+
return
|
| 550 |
+
|
| 551 |
+
# format the regular expression for labels that correspond to the product label
|
| 552 |
+
numbers = re.findall(r'\d+', label_idx)
|
| 553 |
+
label_idx = numbers[0] if len(numbers) > 0 else ""
|
| 554 |
+
label_pattern = rf'{re.escape(label_idx)}[a-zA-Z]+'
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
prod_smiles = prod
|
| 558 |
+
prod_mol = Chem.MolFromMolBlock(results[0]['reactions'][0]['products'][0]['molfile'])
|
| 559 |
+
|
| 560 |
+
# identify the atom indices of the R groups in the product tempalte
|
| 561 |
+
h_counter = 0
|
| 562 |
+
r_sites = {}
|
| 563 |
+
for idx, atom in enumerate(results[0]['reactions'][0]['products'][0]['atoms']):
|
| 564 |
+
sym = atom['atom_symbol']
|
| 565 |
+
if sym == '[H]':
|
| 566 |
+
h_counter += 1
|
| 567 |
+
if sym[0] == '[':
|
| 568 |
+
sym = sym[1:-1]
|
| 569 |
+
if sym[0] == 'R' and sym[1:].isdigit():
|
| 570 |
+
sym = sym[1:]+"*"
|
| 571 |
+
sym = f'[{sym}]'
|
| 572 |
+
if sym in RGROUP_SYMBOLS:
|
| 573 |
+
if sym not in r_sites:
|
| 574 |
+
r_sites[sym] = [idx-h_counter]
|
| 575 |
+
else:
|
| 576 |
+
r_sites[sym].append(idx-h_counter)
|
| 577 |
+
|
| 578 |
+
r_sites_reversed = {}
|
| 579 |
+
for sym in r_sites:
|
| 580 |
+
for pos in r_sites[sym]:
|
| 581 |
+
r_sites_reversed[pos] = sym
|
| 582 |
+
|
| 583 |
+
num_r_groups = len(r_sites_reversed)
|
| 584 |
+
|
| 585 |
+
#prepare the product template and get the associated mapping
|
| 586 |
+
|
| 587 |
+
prod_mol_to_query, prod_template_mol_query = get_atom_mapping(prod_mol, prod_smiles, r_sites_reversed = r_sites_reversed)
|
| 588 |
+
|
| 589 |
+
reactant_mols = []
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
#--------------process the reactants-----------------
|
| 593 |
+
|
| 594 |
+
reactant_information = {} #index of relevant reaction --> [[R group name, atom index of R group, atom index of R group connection], ...]
|
| 595 |
+
|
| 596 |
+
for idx, reactant in enumerate(reactants):
|
| 597 |
+
reactant_information[idx] = []
|
| 598 |
+
reactant_mols.append(Chem.MolFromSmiles(reactant['smiles']))
|
| 599 |
+
has_r = False
|
| 600 |
+
|
| 601 |
+
r_sites_reactant = {}
|
| 602 |
+
|
| 603 |
+
h_counter = 0
|
| 604 |
+
|
| 605 |
+
for a_idx, atom in enumerate(reactant['atoms']):
|
| 606 |
+
|
| 607 |
+
#go through all atoms and check if they are an R group, if so add it to reactant information
|
| 608 |
+
sym = atom['atom_symbol']
|
| 609 |
+
if sym == '[H]':
|
| 610 |
+
h_counter += 1
|
| 611 |
+
if sym[0] == '[':
|
| 612 |
+
sym = sym[1:-1]
|
| 613 |
+
if sym[0] == 'R' and sym[1:].isdigit():
|
| 614 |
+
sym = sym[1:]+"*"
|
| 615 |
+
sym = f'[{sym}]'
|
| 616 |
+
if sym in r_sites:
|
| 617 |
+
if reactant_mols[-1].GetNumAtoms()==1:
|
| 618 |
+
reactant_information[idx].append([sym, -1, -1])
|
| 619 |
+
else:
|
| 620 |
+
has_r = True
|
| 621 |
+
reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile'])
|
| 622 |
+
reactant_information[idx].append([sym, a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]])
|
| 623 |
+
r_sites_reactant[sym] = a_idx-h_counter
|
| 624 |
+
elif sym == '[1*]' and '[7*]' in r_sites:
|
| 625 |
+
if reactant_mols[-1].GetNumAtoms()==1:
|
| 626 |
+
reactant_information[idx].append(['[7*]', -1, -1])
|
| 627 |
+
else:
|
| 628 |
+
has_r = True
|
| 629 |
+
reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile'])
|
| 630 |
+
reactant_information[idx].append(['[7*]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]])
|
| 631 |
+
r_sites_reactant['[7*]'] = a_idx-h_counter
|
| 632 |
+
elif sym == '[7*]' and '[1*]' in r_sites:
|
| 633 |
+
if reactant_mols[-1].GetNumAtoms()==1:
|
| 634 |
+
reactant_information[idx].append(['[1*]', -1, -1])
|
| 635 |
+
else:
|
| 636 |
+
has_r = True
|
| 637 |
+
reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile'])
|
| 638 |
+
reactant_information[idx].append(['[1*]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]])
|
| 639 |
+
r_sites_reactant['[1*]'] = a_idx-h_counter
|
| 640 |
+
|
| 641 |
+
|
| 642 |
+
|
| 643 |
+
elif sym == '[1*]' and '[Rf]' in r_sites:
|
| 644 |
+
if reactant_mols[-1].GetNumAtoms()==1:
|
| 645 |
+
reactant_information[idx].append(['[Rf]', -1, -1])
|
| 646 |
+
else:
|
| 647 |
+
has_r = True
|
| 648 |
+
reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile'])
|
| 649 |
+
reactant_information[idx].append(['[Rf]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]])
|
| 650 |
+
r_sites_reactant['[Rf]'] = a_idx-h_counter
|
| 651 |
+
|
| 652 |
+
elif sym == '[Rf]' and '[1*]' in r_sites:
|
| 653 |
+
if reactant_mols[-1].GetNumAtoms()==1:
|
| 654 |
+
reactant_information[idx].append(['[1*]', -1, -1])
|
| 655 |
+
else:
|
| 656 |
+
has_r = True
|
| 657 |
+
reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile'])
|
| 658 |
+
reactant_information[idx].append(['[1*]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]])
|
| 659 |
+
r_sites_reactant['[1*]'] = a_idx-h_counter
|
| 660 |
+
|
| 661 |
+
|
| 662 |
+
r_sites_reversed_reactant = {r_sites_reactant[i]: i for i in r_sites_reactant}
|
| 663 |
+
# if the reactant had r groups, we had to use the molecule generated from the MolBlock.
|
| 664 |
+
# but the molblock may have unexpanded elemeents that are not R groups
|
| 665 |
+
# so we have to map back the r group indices in the molblock version to the full molecule generated by the smiles
|
| 666 |
+
# and adjust the indices of the r groups accordingly
|
| 667 |
+
if has_r:
|
| 668 |
+
#get the mapping
|
| 669 |
+
reactant_mol_to_query, _ = get_atom_mapping(reactant_mols[-1], reactant['smiles'], r_sites_reversed = r_sites_reversed_reactant)
|
| 670 |
+
|
| 671 |
+
#make the adjustment
|
| 672 |
+
for info in reactant_information[idx]:
|
| 673 |
+
info[1] = reactant_mol_to_query[info[1]]
|
| 674 |
+
info[2] = reactant_mol_to_query[info[2]]
|
| 675 |
+
reactant_mols[-1] = Chem.MolFromSmiles(reactant['smiles'])
|
| 676 |
+
|
| 677 |
+
#go through all the molecules in the coreference
|
| 678 |
+
|
| 679 |
+
clean_corefs(coref_results_dict, label_idx)
|
| 680 |
+
|
| 681 |
+
for other_prod in coref_results_dict:
|
| 682 |
+
|
| 683 |
+
#check if they match the product label regex
|
| 684 |
+
found_good_label = False
|
| 685 |
+
for parsed in coref_results_dict[other_prod]:
|
| 686 |
+
if re.search(label_pattern, parsed) and not found_good_label:
|
| 687 |
+
found_good_label = True
|
| 688 |
+
other_prod_mol = Chem.MolFromSmiles(other_prod)
|
| 689 |
+
|
| 690 |
+
if other_prod != prod_smiles and other_prod_mol is not None:
|
| 691 |
+
|
| 692 |
+
#check if there are R groups to be resolved in the target product
|
| 693 |
+
|
| 694 |
+
all_other_prod_mols = []
|
| 695 |
+
|
| 696 |
+
r_group_sub_pattern = re.compile('(?P<name>[RXY]\d?)[ ]*=[ ]*(?P<group>\w+)')
|
| 697 |
+
|
| 698 |
+
for parsed_labels in coref_results_dict[other_prod]:
|
| 699 |
+
res = r_group_sub_pattern.search(parsed_labels)
|
| 700 |
+
|
| 701 |
+
if res is not None:
|
| 702 |
+
all_other_prod_mols.append((expand_r_group_label_helper(res, coref_smiles_to_graphs, other_prod, molscribe), parsed + parsed_labels))
|
| 703 |
+
|
| 704 |
+
if len(all_other_prod_mols) == 0:
|
| 705 |
+
if other_prod_mol is not None:
|
| 706 |
+
all_other_prod_mols.append((other_prod_mol, parsed))
|
| 707 |
+
|
| 708 |
+
|
| 709 |
+
|
| 710 |
+
|
| 711 |
+
for other_prod_mol, parsed in all_other_prod_mols:
|
| 712 |
+
|
| 713 |
+
other_prod_frags = Chem.GetMolFrags(other_prod_mol, asMols = True)
|
| 714 |
+
|
| 715 |
+
for other_prod_frag in other_prod_frags:
|
| 716 |
+
substructs = other_prod_frag.GetSubstructMatches(prod_template_mol_query, uniquify = False)
|
| 717 |
+
|
| 718 |
+
if len(substructs)>0:
|
| 719 |
+
other_prod_mol = other_prod_frag
|
| 720 |
+
break
|
| 721 |
+
r_sites_reversed_new = {prod_mol_to_query[r]: r_sites_reversed[r] for r in r_sites_reversed}
|
| 722 |
+
|
| 723 |
+
queries = query_enumeration(prod_template_mol_query, r_sites_reversed_new, num_r_groups)
|
| 724 |
+
|
| 725 |
+
matched = False
|
| 726 |
+
|
| 727 |
+
for query in queries:
|
| 728 |
+
if not matched:
|
| 729 |
+
try:
|
| 730 |
+
matched = get_r_group_frags_and_substitute(other_prod_mol, query, reactant_mols, reactant_information, parsed, toreturn)
|
| 731 |
+
except:
|
| 732 |
+
pass
|
| 733 |
+
|
| 734 |
+
except:
|
| 735 |
+
pass
|
| 736 |
+
|
| 737 |
+
|
| 738 |
+
return toreturn
|
| 739 |
+
|
| 740 |
+
|
| 741 |
+
def backout_without_coref(results, coref_results, coref_results_dict, coref_smiles_to_graphs, molscribe):
    """Resolve R-group-templated reactions using precomputed coreference data.

    Variant of ``backout`` that receives the coreference label dictionaries as
    arguments instead of computing them itself. It anchors on the single
    product's parsed label (e.g. ``"1"`` from ``"1a"``), builds a query
    template from the product's R-group sites, then scans every other
    labelled molecule whose label matches the same numeric stem, substituting
    matched fragments back into the reactants.

    Args:
        results: reaction-extraction output; only ``results[0]['reactions'][0]``
            (reactants/products with ``smiles``, ``molfile`` and ``atoms``) is read.
        coref_results: raw coreference output; used only as an emptiness guard.
        coref_results_dict: mapping of molecule SMILES -> list of parsed labels.
        coref_smiles_to_graphs: mapping of SMILES -> molecule graph data, passed
            through to ``expand_r_group_label_helper``.
        molscribe: MolScribe model instance, passed through to helpers.

    Returns:
        ``toreturn`` — a list populated by ``get_r_group_frags_and_substitute``
        (presumably resolved reactant/product tuples — TODO confirm against that
        helper); empty on failure. NOTE(review): the warning branches use a bare
        ``return`` and therefore yield ``None`` instead of an empty list.
    """
    toreturn = []

    # Nothing to do without at least one parsed reaction and coref result.
    if not results or not results[0]['reactions'] or not coref_results:
        return toreturn

    try:
        reactants = results[0]['reactions'][0]['reactants']
        products = [i['smiles'] for i in results[0]['reactions'][0]['products']]
        # NOTE(review): the two self-assignments below are no-ops kept verbatim.
        coref_results_dict = coref_results_dict
        coref_smiles_to_graphs = coref_smiles_to_graphs

        # Only single-product reactions are supported, and the product must have
        # a parsed coreference label to anchor the label regex below.
        if len(products) == 1:
            if products[0] not in coref_results_dict:
                print("Warning: No Label Parsed")
                return
            product_labels = coref_results_dict[products[0]]
            prod = products[0]
            label_idx = product_labels[0]
            '''
            if len(product_labels) == 1:
                # get the coreference label of the product molecule
                label_idx = product_labels[0]
            else:
                print("Warning: Malformed Label Parsed.")
                return
            '''
        else:
            print("Warning: More than one product detected")
            return

        # format the regular expression for labels that correspond to the product label
        # e.g. label "1a" -> numeric stem "1" -> pattern "1[a-zA-Z]+" matching "1b", "1c", ...
        numbers = re.findall(r'\d+', label_idx)
        label_idx = numbers[0] if len(numbers) > 0 else ""
        label_pattern = rf'{re.escape(label_idx)}[a-zA-Z]+'

        prod_smiles = prod
        prod_mol = Chem.MolFromMolBlock(results[0]['reactions'][0]['products'][0]['molfile'])

        # identify the atom indices of the R groups in the product template.
        # h_counter tracks explicit [H] entries so that indices stay aligned with
        # the molfile-derived molecule — presumably it drops them; TODO confirm.
        h_counter = 0
        r_sites = {}  # normalized R-group symbol -> list of atom indices
        for idx, atom in enumerate(results[0]['reactions'][0]['products'][0]['atoms']):
            sym = atom['atom_symbol']
            if sym == '[H]':
                h_counter += 1
            if sym[0] == '[':
                # normalize bracketed symbols: "[R1]" -> "[1*]"; others re-bracketed as-is
                sym = sym[1:-1]
                if sym[0] == 'R' and sym[1:].isdigit():
                    sym = sym[1:]+"*"
                sym = f'[{sym}]'
            if sym in RGROUP_SYMBOLS:
                if sym not in r_sites:
                    r_sites[sym] = [idx-h_counter]
                else:
                    r_sites[sym].append(idx-h_counter)

        # invert the mapping: atom index -> R-group symbol
        r_sites_reversed = {}
        for sym in r_sites:
            for pos in r_sites[sym]:
                r_sites_reversed[pos] = sym

        num_r_groups = len(r_sites_reversed)

        #prepare the product template and get the associated mapping
        prod_mol_to_query, prod_template_mol_query = get_atom_mapping(prod_mol, prod_smiles, r_sites_reversed = r_sites_reversed)

        reactant_mols = []

        #--------------process the reactants-----------------

        reactant_information = {} #index of relevant reaction --> [[R group name, atom index of R group, atom index of R group connection], ...]

        for idx, reactant in enumerate(reactants):
            reactant_information[idx] = []
            reactant_mols.append(Chem.MolFromSmiles(reactant['smiles']))
            has_r = False

            r_sites_reactant = {}  # R-group symbol -> atom index within this reactant

            h_counter = 0

            for a_idx, atom in enumerate(reactant['atoms']):

                #go through all atoms and check if they are an R group, if so add it to reactant information
                sym = atom['atom_symbol']
                if sym == '[H]':
                    h_counter += 1
                if sym[0] == '[':
                    sym = sym[1:-1]
                    if sym[0] == 'R' and sym[1:].isdigit():
                        sym = sym[1:]+"*"
                    sym = f'[{sym}]'
                if sym in r_sites:
                    # single-atom reactant: the R group *is* the whole molecule,
                    # so sentinel indices (-1) replace the neighbor lookup
                    if reactant_mols[-1].GetNumAtoms()==1:
                        reactant_information[idx].append([sym, -1, -1])
                    else:
                        has_r = True
                        reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile'])
                        reactant_information[idx].append([sym, a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]])
                        r_sites_reactant[sym] = a_idx-h_counter
                # The elif ladder below tolerates symbol mismatches between the
                # reactant and the product template: [1*] <-> [7*] (presumably a
                # recognition confusion) and [1*] <-> [Rf] — TODO confirm intent.
                elif sym == '[1*]' and '[7*]' in r_sites:
                    if reactant_mols[-1].GetNumAtoms()==1:
                        reactant_information[idx].append(['[7*]', -1, -1])
                    else:
                        has_r = True
                        reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile'])
                        reactant_information[idx].append(['[7*]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]])
                        r_sites_reactant['[7*]'] = a_idx-h_counter
                elif sym == '[7*]' and '[1*]' in r_sites:
                    if reactant_mols[-1].GetNumAtoms()==1:
                        reactant_information[idx].append(['[1*]', -1, -1])
                    else:
                        has_r = True
                        reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile'])
                        reactant_information[idx].append(['[1*]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]])
                        r_sites_reactant['[1*]'] = a_idx-h_counter

                elif sym == '[1*]' and '[Rf]' in r_sites:
                    if reactant_mols[-1].GetNumAtoms()==1:
                        reactant_information[idx].append(['[Rf]', -1, -1])
                    else:
                        has_r = True
                        reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile'])
                        reactant_information[idx].append(['[Rf]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]])
                        r_sites_reactant['[Rf]'] = a_idx-h_counter

                elif sym == '[Rf]' and '[1*]' in r_sites:
                    if reactant_mols[-1].GetNumAtoms()==1:
                        reactant_information[idx].append(['[1*]', -1, -1])
                    else:
                        has_r = True
                        reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile'])
                        reactant_information[idx].append(['[1*]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]])
                        r_sites_reactant['[1*]'] = a_idx-h_counter

            r_sites_reversed_reactant = {r_sites_reactant[i]: i for i in r_sites_reactant}
            # if the reactant had r groups, we had to use the molecule generated from the MolBlock.
            # but the molblock may have unexpanded elements that are not R groups
            # so we have to map back the r group indices in the molblock version to the full molecule generated by the smiles
            # and adjust the indices of the r groups accordingly
            if has_r:
                #get the mapping
                reactant_mol_to_query, _ = get_atom_mapping(reactant_mols[-1], reactant['smiles'], r_sites_reversed = r_sites_reversed_reactant)

                #make the adjustment
                for info in reactant_information[idx]:
                    info[1] = reactant_mol_to_query[info[1]]
                    info[2] = reactant_mol_to_query[info[2]]
                reactant_mols[-1] = Chem.MolFromSmiles(reactant['smiles'])

        #go through all the molecules in the coreference

        clean_corefs(coref_results_dict, label_idx)

        for other_prod in coref_results_dict:

            #check if they match the product label regex
            found_good_label = False
            for parsed in coref_results_dict[other_prod]:
                # only the first label matching the product's stem is processed
                if re.search(label_pattern, parsed) and not found_good_label:
                    found_good_label = True
                    other_prod_mol = Chem.MolFromSmiles(other_prod)

                    # skip the template product itself and unparsable SMILES
                    if other_prod != prod_smiles and other_prod_mol is not None:

                        #check if there are R groups to be resolved in the target product

                        all_other_prod_mols = []

                        # matches textual substitutions like "R1 = Me" or "X=Cl"
                        r_group_sub_pattern = re.compile('(?P<name>[RXY]\d?)[ ]*=[ ]*(?P<group>\w+)')

                        for parsed_labels in coref_results_dict[other_prod]:
                            res = r_group_sub_pattern.search(parsed_labels)

                            if res is not None:
                                all_other_prod_mols.append((expand_r_group_label_helper(res, coref_smiles_to_graphs, other_prod, molscribe), parsed + parsed_labels))

                        # no textual R-group substitutions: fall back to the molecule as-is
                        if len(all_other_prod_mols) == 0:
                            if other_prod_mol is not None:
                                all_other_prod_mols.append((other_prod_mol, parsed))

                        for other_prod_mol, parsed in all_other_prod_mols:

                            # keep only the fragment that actually contains the template
                            other_prod_frags = Chem.GetMolFrags(other_prod_mol, asMols = True)

                            for other_prod_frag in other_prod_frags:
                                substructs = other_prod_frag.GetSubstructMatches(prod_template_mol_query, uniquify = False)

                                if len(substructs)>0:
                                    other_prod_mol = other_prod_frag
                                    break
                            # re-key the R sites into query-atom indices
                            r_sites_reversed_new = {prod_mol_to_query[r]: r_sites_reversed[r] for r in r_sites_reversed}

                            queries = query_enumeration(prod_template_mol_query, r_sites_reversed_new, num_r_groups)

                            matched = False

                            # try each enumerated query until one substitution succeeds;
                            # failures are silently skipped (best-effort matching)
                            for query in queries:
                                if not matched:
                                    try:
                                        matched = get_r_group_frags_and_substitute(other_prod_mol, query, reactant_mols, reactant_information, parsed, toreturn)
                                    except:
                                        pass

    # NOTE(review): broad best-effort guard kept from the original — any failure
    # anywhere above simply yields whatever was accumulated so far.
    except:
        pass

    return toreturn
|
| 959 |
+
|
| 960 |
+
|
| 961 |
+
|
| 962 |
+
def associate_corefs(results, results_coref):
    """Append coreferenced SMILES to compound labels in extracted reactions.

    Builds a label -> SMILES table from ``results_coref`` (labels look like
    ``"12ab"``: digits plus up to two letters), then rewrites every labelled
    ``Reactants``/``Product`` entry ``(label, x, y)`` in ``results`` as
    ``(f"{label} ({smiles})", x, y)`` when the label is known.

    Args:
        results: per-page text-extracted reactions; mutated in place.
        results_coref: per-page coreference output with ``bboxes`` and ``corefs``.

    Returns:
        The same ``results`` object, with labels annotated.
    """
    label_regex = r'\b\d+[a-zA-Z]{0,2}\b'

    # Pass 1: harvest every parsed identifier and remember its molecule SMILES.
    label_to_smiles = {}
    for page_coref in results_coref:
        boxes = page_coref['bboxes']
        for pair in page_coref['corefs']:
            mol_box, text_box = pair[0], pair[1]
            for snippet in boxes[text_box]['text']:
                for label in re.findall(label_regex, snippet):
                    label_to_smiles[label] = boxes[mol_box]['smiles']

    def _annotate(entry):
        # Rebuild a (label, x, y) entry with the SMILES appended, if known.
        if entry[0] in label_to_smiles:
            return (f'{entry[0]} ({label_to_smiles[entry[0]]})', entry[1], entry[2])
        return entry

    # Pass 2: rewrite labelled reactants/products in place.
    for page in results:
        for rxn_group in page['reactions']:
            for rxn in rxn_group['reactions']:
                for role in ('Reactants', 'Product'):
                    if role not in rxn:
                        continue
                    entries = rxn[role]
                    if isinstance(entries, tuple):
                        rxn[role] = _annotate(entries)
                    else:
                        for pos, compound in enumerate(entries):
                            entries[pos] = _annotate(compound)
    return results
|
| 995 |
+
|
| 996 |
+
|
| 997 |
+
def expand_reactions_with_backout(initial_results, results_coref, molscribe):
    """Augment each page's reaction list with reactions recovered by ``backout``.

    For every page that already has at least one extracted reaction, run the
    R-group ``backout`` resolution against that page's coreference results and
    append each recovered reaction, reusing (a copy of) the conditions of the
    page's first reaction.

    Improvements over the original: removed the unused locals ``idx_pattern``
    and ``idt_to_smiles``, and hoisted the empty-result check ahead of the
    conditions lookup so no work is done for pages with nothing to add.

    Args:
        initial_results: per-page reaction extraction results; mutated in place.
        results_coref: per-page coreference results aligned with ``initial_results``.
        molscribe: MolScribe model instance forwarded to ``backout``.

    Returns:
        The same ``initial_results`` object with recovered reactions appended.
    """
    for reactions, result_coref in zip(initial_results, results_coref):
        if not reactions['reactions']:
            continue
        try:
            backout_results = backout([reactions], [result_coref], molscribe)
        except Exception:
            # backout is best-effort; a failure on one page must not abort the rest
            continue
        if not backout_results:
            continue
        # Recovered reactions inherit the first reaction's conditions (copied
        # per appended reaction so later edits don't alias).
        conditions = reactions['reactions'][0]['conditions']
        for reactants, products, idt in backout_results:
            reactions['reactions'].append({
                'reactants': [{'category': '[Mol]', 'molfile': None, 'smiles': reactant} for reactant in reactants],
                'conditions': conditions[:],
                'products': [{'category': '[Mol]', 'molfile': None, 'smiles': product} for product in products]
            })
    return initial_results
|
| 1018 |
+
|
examples/exp.png
ADDED
|
Git LFS Details
|
examples/image.webp
ADDED
|
examples/molecules1.png
ADDED
|
examples/molecules2.png
ADDED
|
examples/rdkit.png
ADDED
|
examples/reaction0.png
ADDED
|
Git LFS Details
|
examples/reaction1.jpg
ADDED
|
examples/reaction2.png
ADDED
|
examples/reaction3.png
ADDED
|
examples/reaction4.png
ADDED
|
Git LFS Details
|
examples/reaction5.png
ADDED
|