CYF200127 committed on
Commit 3edb646 · verified · 1 parent: 4ff6287

Upload 164 files

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitattributes +7 -0
  2. README.md +80 -14
  3. __init__.py +3 -0
  4. __pycache__/get_R_group_sub_agent.cpython-310.pyc +0 -0
  5. __pycache__/get_molecular_agent.cpython-310.pyc +0 -0
  6. __pycache__/get_reaction_agent.cpython-310.pyc +0 -0
  7. __pycache__/main.cpython-310.pyc +0 -0
  8. app.ipynb +251 -0
  9. app.py +202 -0
  10. apt.txt +2 -0
  11. chemiener/__init__.py +1 -0
  12. chemiener/__pycache__/__init__.cpython-310.pyc +0 -0
  13. chemiener/__pycache__/__init__.cpython-38.pyc +0 -0
  14. chemiener/__pycache__/dataset.cpython-310.pyc +0 -0
  15. chemiener/__pycache__/dataset.cpython-38.pyc +0 -0
  16. chemiener/__pycache__/interface.cpython-310.pyc +0 -0
  17. chemiener/__pycache__/interface.cpython-38.pyc +0 -0
  18. chemiener/__pycache__/model.cpython-310.pyc +0 -0
  19. chemiener/__pycache__/model.cpython-38.pyc +0 -0
  20. chemiener/__pycache__/utils.cpython-310.pyc +0 -0
  21. chemiener/__pycache__/utils.cpython-38.pyc +0 -0
  22. chemiener/dataset.py +172 -0
  23. chemiener/interface.py +124 -0
  24. chemiener/main.py +345 -0
  25. chemiener/model.py +14 -0
  26. chemiener/utils.py +23 -0
  27. chemietoolkit/__init__.py +1 -0
  28. chemietoolkit/__pycache__/__init__.cpython-310.pyc +0 -0
  29. chemietoolkit/__pycache__/__init__.cpython-38.pyc +0 -0
  30. chemietoolkit/__pycache__/chemrxnextractor.cpython-310.pyc +0 -0
  31. chemietoolkit/__pycache__/chemrxnextractor.cpython-38.pyc +0 -0
  32. chemietoolkit/__pycache__/interface.cpython-310.pyc +0 -0
  33. chemietoolkit/__pycache__/interface.cpython-38.pyc +0 -0
  34. chemietoolkit/__pycache__/tableextractor.cpython-310.pyc +0 -0
  35. chemietoolkit/__pycache__/utils.cpython-310.pyc +0 -0
  36. chemietoolkit/chemrxnextractor.py +107 -0
  37. chemietoolkit/interface.py +749 -0
  38. chemietoolkit/tableextractor.py +340 -0
  39. chemietoolkit/utils.py +1018 -0
  40. examples/exp.png +3 -0
  41. examples/image.webp +0 -0
  42. examples/molecules1.png +0 -0
  43. examples/molecules2.png +0 -0
  44. examples/rdkit.png +0 -0
  45. examples/reaction0.png +3 -0
  46. examples/reaction1.jpg +0 -0
  47. examples/reaction2.png +0 -0
  48. examples/reaction3.png +0 -0
  49. examples/reaction4.png +3 -0
  50. examples/reaction5.png +0 -0
.gitattributes CHANGED
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/exp.png filter=lfs diff=lfs merge=lfs -text
37
+ examples/reaction0.png filter=lfs diff=lfs merge=lfs -text
38
+ examples/reaction4.png filter=lfs diff=lfs merge=lfs -text
39
+ examples/template1.png filter=lfs diff=lfs merge=lfs -text
40
+ molnextr/indigo/lib/Linux/x64/libbingo.so filter=lfs diff=lfs merge=lfs -text
41
+ molnextr/indigo/lib/Linux/x64/libindigo-renderer.so filter=lfs diff=lfs merge=lfs -text
42
+ molnextr/indigo/lib/Linux/x64/libindigo.so filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,14 +1,80 @@
1
- ---
2
- title: ChemEagle
3
- emoji: 📉
4
- colorFrom: gray
5
- colorTo: indigo
6
- sdk: gradio
7
- sdk_version: 5.38.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- short_description: A Multi-Agent System for Multimodal Chemical IE
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # ChemEagle
2
+
3
+
4
+ ## :sparkles: Highlights
5
+ <p align="justify">
6
+ In this work, we present ChemEagle, a multimodal large language model (MLLM)-based multi-agent system that integrates diverse chemical information extraction tools to extract chemical reactions from multimodal sources. Built on 7 expert-designed tools and 6 chemical information extraction agents, ChemEagle not only processes each modality individually but also uses the reasoning capabilities of MLLMs to unify the extracted data, yielding more accurate and comprehensive reaction representations. By bridging multimodal gaps, our approach substantially improves automated chemical knowledge extraction and supports more robust AI-driven chemical research.
7
+
8
+ [comment]: <> ()
9
+ ![visualization](examples/chemeagle.png)
10
+ <div align="center">
11
+ An example workflow of our ChemEagle. It illustrates how ChemEagle extracts and structures multimodal chemical reaction data. Each agent handles specific tasks, from reaction image parsing and molecular recognition to SMILES reconstruction and condition role interpretation, ensuring accurate and structured data integration.
12
+ </div>
13
+
14
+
15
+ ## 🤗 Multimodal chemical information extraction using [ChemEagle.Web](https://huggingface.co/spaces/CYF200127/ChemEagle)
16
+
17
+ Go to our [ChemEagle.Web demo](https://huggingface.co/spaces/CYF200127/ChemEagle) to use the tool directly in your browser!
18
+
19
+ The input is a multimodal chemical reaction image:
20
+ ![visualization](examples/1.png)
21
+ <div align="center" width="100">
22
+ </div>
23
+
24
+ The output dictionary is a complete reaction list with reactant SMILES, product SMILES, detailed conditions and additional information for every reaction in the image:
25
+
26
+ ```
27
+ {
28
+ "reactions":[
29
+ {"reaction_id":"0_1", ### Reaction template
30
+ "reactants":[{"smiles":"*C(*)=O","label":"1"},{"smiles":"Cc1ccc(S(=O)(=O)N2OC2c2ccccc2Cl)cc1","label":"2"}],
31
+ "conditions":[{"role":"reagent","text":"10 mol% B17 or B27","smiles":"C(C=CC=C1)=C1C[N+]2=CN3[C@H](C(C4=CC=CC=C4)(C5=CC=CC=C5)O[Si](C)(C)C(C)(C)C)CCC3=N2.F[B-](F)(F)F","label":"B17"},{"role":"reagent","text":"10 mol% B17 or B27","smiles":"CCCC(C=CC=C1)=C1[N+]2=CN3[C@H](C(C1=CC(=CC(=C1C(F)(F)F)C(F)(F)F))(C1=CC(=CC(=C1C(F)(F)F)C(F)(F)F))O)CCC3=N2.F[B-](F)(F)F","label":"B27"},{"role":"reagent","text":"10 mol% Cs2CO3","smiles":"[Cs+].[Cs+].[O-]C(=O)[O-]"},{"role":"solvent","text":"PhMe","smiles":"Cc1ccccc1"},{"role":"temperature","text":"rt"},{"role":"yield","text":"38-78%"}],
32
+ "products":[{"smiles":"*C1*O[C@H](c2ccccc2Cl)N(S(=O)(=O)c2ccc(C)cc2)C1=O","label":"3"}]},
33
+
34
+ {"reaction_id":"1_1", ### Detailed reaction
35
+ "reactants":[{"smiles":"CCC(=O)c1ccccc1","label":"1a"},{"smiles":"Cc1ccc(S(=O)(=O)N2OC2c2ccccc2Cl)cc1","label":"2"}],
36
+ "conditions":[{"role":"reagent","text":"10 mol% B17 or B27","smiles":"C(C=CC=C1)=C1C[N+]2=CN3[C@H](C(C4=CC=CC=C4)(C5=CC=CC=C5)O[Si](C)(C)C(C)(C)C)CCC3=N2.F[B-](F)(F)F","label":"B17"},{"role":"reagent","text":"10 mol% B17 or B27","smiles":"CCCC(C=CC=C1)=C1[N+]2=CN3[C@H](C(C1=CC(=CC(=C1C(F)(F)F)C(F)(F)F))(C1=CC(=CC(=C1C(F)(F)F)C(F)(F)F))O)CCC3=N2.F[B-](F)(F)F","label":"B27"},{"role":"reagent","text":"10 mol% Cs2CO3","smiles":"[Cs+].[Cs+].[O-]C(=O)[O-]"},{"role":"solvent","text":"PhMe","smiles":"Cc1ccccc1"},{"role":"temperature","text":"rt"},{"role":"yield","text":"71%"}],
37
+ "products":[{"smiles":"CC[C@]1(c2ccccc2)O[C@H](c2ccccc2Cl)N(S(=O)(=O)c2ccc(C)cc2)C1=O","label":"3a"}],
38
+ "additional_info":[{"text":"14:1 dr, 91% ee"}]},
39
+
40
+ {"reaction_id":"2_1", ... ### more detailed reactions}
41
+ ]
42
+ }
43
+ ```
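For downstream use, this dictionary can be flattened into plain reaction SMILES strings with a few lines of Python. This is a minimal sketch that assumes only the key names shown in the example output above; the `example` dictionary below is an abridged illustration, not real tool output.

```python
# Minimal sketch: flatten ChemEagle-style output into reaction SMILES strings.
# Assumes only the key names shown in the example output above.
import json

def to_reaction_smiles(output):
    """Convert a ChemEagle result dict (or its JSON string) to reaction SMILES."""
    data = json.loads(output) if isinstance(output, str) else output
    smiles = []
    for rxn in data.get("reactions", []):
        reactants = ".".join(r["smiles"] for r in rxn.get("reactants", []))
        products = ".".join(p["smiles"] for p in rxn.get("products", []))
        smiles.append(f"{reactants}>>{products}")
    return smiles

# Abridged example in the same shape as the output above (hypothetical values).
example = {
    "reactions": [
        {
            "reaction_id": "1_1",
            "reactants": [{"smiles": "CCC(=O)c1ccccc1", "label": "1a"}],
            "products": [{"smiles": "CC[C@]1(c2ccccc2)O[C@H](c2ccccc2Cl)N(S(=O)(=O)c2ccc(C)cc2)C1=O", "label": "3a"}],
        }
    ]
}
print(to_reaction_smiles(example))
```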
44
+
45
+ ## :rocket: Using the code and the model
46
+ ### Using the code
47
+ Clone the repository:
48
+ ```
49
+ git clone https://github.com/CYF2000127/ChemEagle
50
+ ```
51
+ ### Example usage of the model
52
+ 1. First, create and activate a [conda](https://numdifftools.readthedocs.io/en/stable/how-to/create_virtual_env_with_conda.html) environment with the following commands on Linux, Windows, or macOS (Linux is recommended):
53
+ ```
54
+ conda create -n chemeagle python=3.10
55
+ conda activate chemeagle
56
+ ```
57
+
58
+ 2. Then install requirements:
59
+ ```
60
+ pip install -r requirements.txt
61
+ ```
62
+
63
+ 3. Set your API key in your environment:
64
+ ```
65
+ export API_KEY=your-openai-api-key
66
+ ```
67
+ Alternatively, add your API key to [api_key.txt](./api_key.txt).
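The lookup order above can be mimicked with a small helper. This is a minimal sketch: the `"Please set API_KEY"` error matches the message raised by the package, but treating `api_key.txt` as a fallback is an assumption based on this README, not a verified code path.

```python
# Sketch of the key lookup described above. Reading api_key.txt as a
# fallback is an assumption from the README, not a verified code path.
import os

def load_api_key(txt_path="api_key.txt"):
    key = os.getenv("API_KEY")  # the environment variable takes precedence
    if not key and os.path.exists(txt_path):
        with open(txt_path) as f:
            key = f.read().strip()
    if not key:
        raise ValueError("Please set API_KEY")  # mirrors the repo's error message
    return key
```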
68
+
69
+ 4. Run the following code to extract chemical reactions from a multimodal reaction image:
70
+ ```python
71
+ from main import ChemEagle
72
+ image_path = './examples/1.png'
73
+ results = ChemEagle(image_path)
74
+ print(results)
75
+ ```
76
+
77
+ ## :warning: Acknowledgement
78
+ Our code is based on [MolNexTR](https://github.com/CYF2000127/MolNexTR), [MolScribe](https://github.com/thomas0809/MolScribe), [RxnIM](https://github.com/CYF2000127/RxnIM), [RxnScribe](https://github.com/thomas0809/RxNScribe), [ChemNER](https://github.com/Ozymandias314/ChemIENER), [ChemRxnExtractor](https://github.com/jiangfeng1124/ChemRxnExtractor), [AutoAgents](https://github.com/Link-AGI/AutoAgents), and [OpenAI](https://openai.com/). Thanks for their great work!
79
+
80
+
__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ __version__ = "0.1.0"
2
+ __author__ = 'Alex Wang'
3
+ __credits__ = 'CSAIL'
__pycache__/get_R_group_sub_agent.cpython-310.pyc ADDED
Binary file (13.5 kB). View file
 
__pycache__/get_molecular_agent.cpython-310.pyc ADDED
Binary file (8.85 kB). View file
 
__pycache__/get_reaction_agent.cpython-310.pyc ADDED
Binary file (7.69 kB). View file
 
__pycache__/main.cpython-310.pyc ADDED
Binary file (4.54 kB). View file
 
app.ipynb ADDED
@@ -0,0 +1,251 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "d13d3631",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stderr",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "Using CPU. Note: This module is much faster with a GPU.\n"
14
+ ]
15
+ },
16
+ {
17
+ "ename": "ValueError",
18
+ "evalue": "Please set API_KEY",
19
+ "output_type": "error",
20
+ "traceback": [
21
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
22
+ "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
23
+ "Cell \u001b[0;32mIn[5], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mgradio\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mgr\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mjson\u001b[39;00m\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmain\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m ChemEagle \u001b[38;5;66;03m# 假设内部已经管理 API Key\u001b[39;00m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mrdkit\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Chem\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mrdkit\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mChem\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m rdChemReactions, Draw, AllChem\n",
24
+ "File \u001b[0;32m/media/chenyufan/F/ChemEagle-hf/main.py:10\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mPIL\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Image\n\u001b[1;32m 9\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mjson\u001b[39;00m\n\u001b[0;32m---> 10\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mget_molecular_agent\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m process_reaction_image_with_multiple_products_and_text_correctR\n\u001b[1;32m 11\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mget_reaction_agent\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m get_reaction_withatoms_correctR\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mget_R_group_sub_agent\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m process_reaction_image_with_table_R_group, process_reaction_image_with_product_variant_R_group,get_full_reaction,get_multi_molecular_full\n",
25
+ "File \u001b[0;32m/media/chenyufan/F/ChemEagle-hf/get_molecular_agent.py:35\u001b[0m\n\u001b[1;32m 33\u001b[0m API_KEY \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mgetenv(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAPI_KEY\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 34\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m API_KEY:\n\u001b[0;32m---> 35\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPlease set API_KEY\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 36\u001b[0m AZURE_ENDPOINT \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mgetenv(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAZURE_ENDPOINT\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 39\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_multi_molecular\u001b[39m(image_path: \u001b[38;5;28mstr\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mlist\u001b[39m:\n",
26
+ "\u001b[0;31mValueError\u001b[0m: Please set API_KEY"
27
+ ]
28
+ }
29
+ ],
30
+ "source": [
31
+ "import gradio as gr\n",
32
+ "import json\n",
33
+ "from main import ChemEagle # assumes the API key is managed internally\n",
34
+ "from rdkit import Chem\n",
35
+ "from rdkit.Chem import rdChemReactions, Draw, AllChem\n",
36
+ "from rdkit.Chem.Draw import rdMolDraw2D\n",
37
+ "import cairosvg\n",
38
+ "import re\n",
39
+ "\n",
40
+ "example_diagram = \"examples/exp.png\"\n",
41
+ "rdkit_image = \"examples/rdkit.png\"\n",
42
+ "\n",
43
+ "# Parse the structured data returned by ChemEagle\n",
44
+ "def parse_reactions(output_json):\n",
45
+ " if isinstance(output_json, str):\n",
46
+ " reactions_data = json.loads(output_json)\n",
47
+ " else:\n",
48
+ " reactions_data = output_json\n",
49
+ " reactions_list = reactions_data.get(\"reactions\", [])\n",
50
+ " detailed_output = []\n",
51
+ " smiles_output = []\n",
52
+ "\n",
53
+ " for reaction in reactions_list:\n",
54
+ " reaction_id = reaction.get(\"reaction_id\", \"Unknown ID\")\n",
55
+ " reactants = [r.get(\"smiles\", \"Unknown\") for r in reaction.get(\"reactants\", [])]\n",
56
+ " conditions = [\n",
57
+ " f\"<span style='color:red'>{c.get('smiles', c.get('text', 'Unknown'))}[{c.get('role', 'Unknown')}]</span>\"\n",
58
+ " for c in reaction.get(\"condition\", [])\n",
59
+ " ]\n",
60
+ " conditions_1 = [\n",
61
+ " f\"<span style='color:black'>{c.get('smiles', c.get('text', 'Unknown'))}[{c.get('role', 'Unknown')}]</span>\"\n",
62
+ " for c in reaction.get(\"condition\", [])\n",
63
+ " ]\n",
64
+ " products = [f\"<span style='color:orange'>{p.get('smiles', 'Unknown')}</span>\" for p in reaction.get(\"products\", [])]\n",
65
+ " products_1 = [f\"<span style='color:black'>{p.get('smiles', 'Unknown')}</span>\" for p in reaction.get(\"products\", [])]\n",
66
+ " products_2 = [r.get(\"smiles\", \"Unknown\") for r in reaction.get(\"products\", [])]\n",
67
+ " additional = reaction.get(\"additional_info\", [])\n",
68
+ " additional_str = [str(x) for x in additional if x]\n",
69
+ "\n",
70
+ " tail = conditions_1 + additional_str\n",
71
+ " tail_str = \", \".join(tail)\n",
72
+ " full_reaction = f\"{'.'.join(reactants)}>>{'.'.join(products_1)} | {tail_str}\"\n",
73
+ " full_reaction = f\"<span style='color:black'>{full_reaction}</span>\"\n",
74
+ "\n",
75
+ " reaction_output = f\"<b>Reaction: </b> {reaction_id}<br>\"\n",
76
+ " reaction_output += f\" Reactants: <span style='color:blue'>{', '.join(reactants)}</span><br>\"\n",
77
+ " reaction_output += f\" Conditions: {', '.join(conditions)}<br>\"\n",
78
+ " reaction_output += f\" Products: {', '.join(products)}<br>\"\n",
79
+ " reaction_output += f\" additional_info: {', '.join(additional_str)}<br>\"\n",
80
+ " reaction_output += f\" <b>Full Reaction:</b> {full_reaction}<br><br>\"\n",
81
+ " detailed_output.append(reaction_output)\n",
82
+ "\n",
83
+ " reaction_smiles = f\"{'.'.join(reactants)}>>{'.'.join(products_2)}\"\n",
84
+ " smiles_output.append(reaction_smiles)\n",
85
+ "\n",
86
+ " return detailed_output, smiles_output\n",
87
+ "\n",
88
+ "def process_chem_image(image):\n",
89
+ " image_path = \"temp_image.png\"\n",
90
+ " image.save(image_path)\n",
91
+ "\n",
92
+ " chemeagle_result = ChemEagle(image_path)\n",
93
+ " detailed, smiles = parse_reactions(chemeagle_result)\n",
94
+ "\n",
95
+ " json_path = \"output.json\"\n",
96
+ " with open(json_path, 'w') as jf:\n",
97
+ " json.dump(chemeagle_result, jf, indent=2)\n",
98
+ "\n",
99
+ " return \"\\n\\n\".join(detailed), smiles, example_diagram, json_path\n",
100
+ "\n",
101
+ "with gr.Blocks() as demo:\n",
102
+ " gr.Markdown(\n",
103
+ " \"\"\"\n",
104
+ " <center><h1>ChemEagle: A Multi-Agent System for Multimodal Chemical Information Extraction</h1></center>\n",
105
+ " Upload a multimodal reaction image to extract multimodal chemical information.\n",
106
+ " \"\"\"\n",
107
+ " )\n",
108
+ "\n",
109
+ " with gr.Row():\n",
110
+ " with gr.Column(scale=1):\n",
111
+ " image_input = gr.Image(type=\"pil\", label=\"Upload a multimodal reaction image\")\n",
112
+ " with gr.Row():\n",
113
+ " clear_btn = gr.Button(\"Clear\")\n",
114
+ " run_btn = gr.Button(\"Run\", elem_id=\"submit-btn\")\n",
115
+ "\n",
116
+ " with gr.Column(scale=1):\n",
117
+ " gr.Markdown(\"### Parsed Reactions\")\n",
118
+ " reaction_output = gr.HTML(label=\"Detailed Reaction Output\")\n",
119
+ " gr.Markdown(\"### Schematic Diagram\")\n",
120
+ "            schematic_diagram = gr.Image(value=example_diagram, label=\"Schematic Diagram\")\n",
121
+ "\n",
122
+ " with gr.Column(scale=1):\n",
123
+ " gr.Markdown(\"### Machine-readable Output\")\n",
124
+ " smiles_output = gr.Textbox(\n",
125
+ " label=\"Reaction SMILES\",\n",
126
+ " show_copy_button=True,\n",
127
+ " interactive=False,\n",
128
+ " visible=False\n",
129
+ " )\n",
130
+ "\n",
131
+ " @gr.render(inputs=smiles_output)\n",
132
+ " def show_split(inputs):\n",
133
+ " if not inputs or (isinstance(inputs, str) and inputs.strip() == \"\"):\n",
134
+ " return gr.Textbox(label=\"SMILES of Reaction i\"), gr.Image(value=rdkit_image, label=\"RDKit Image of Reaction i\", height=100)\n",
135
+ " smiles_list = inputs.split(\",\")\n",
136
+ " smiles_list = [re.sub(r\"^\\s*\\[?'?|']?\\s*$\", \"\", item) for item in smiles_list]\n",
137
+ " components = []\n",
138
+ " for i, smiles in enumerate(smiles_list):\n",
139
+ " smiles_clean = smiles.replace('\"', '').replace(\"'\", \"\").replace(\"[\", \"\").replace(\"]\", \"\")\n",
140
+ "            # Always add the SMILES textbox\n",
141
+ " components.append(gr.Textbox(value=smiles_clean, label=f\"SMILES of Reaction {i+1}\", show_copy_button=True, interactive=False))\n",
142
+ " try:\n",
143
+ " rxn = rdChemReactions.ReactionFromSmarts(smiles_clean, useSmiles=True)\n",
144
+ " if not rxn:\n",
145
+ " continue\n",
146
+ " new_rxn = AllChem.ChemicalReaction()\n",
147
+ " for mol in rxn.GetReactants():\n",
148
+ " mol = Chem.MolFromMolBlock(Chem.MolToMolBlock(mol))\n",
149
+ " new_rxn.AddReactantTemplate(mol)\n",
150
+ " for mol in rxn.GetProducts():\n",
151
+ " mol = Chem.MolFromMolBlock(Chem.MolToMolBlock(mol))\n",
152
+ " new_rxn.AddProductTemplate(mol)\n",
153
+ " cleaned_rxn = new_rxn\n",
154
+ "\n",
155
+ "            # Remove atom mapping\n",
156
+ " for react in cleaned_rxn.GetReactants():\n",
157
+ " for atom in react.GetAtoms(): atom.SetAtomMapNum(0)\n",
158
+ " for prod in cleaned_rxn.GetProducts():\n",
159
+ " for atom in prod.GetAtoms(): atom.SetAtomMapNum(0)\n",
160
+ "\n",
161
+ "            # Compute a reference bond length\n",
162
+ " ref_rxn = cleaned_rxn\n",
163
+ " react0 = ref_rxn.GetReactantTemplate(0)\n",
164
+ " react1 = ref_rxn.GetReactantTemplate(1) if ref_rxn.GetNumReactantTemplates() > 1 else None\n",
165
+ " if react0.GetNumBonds() > 0:\n",
166
+ " bond_len = Draw.MeanBondLength(react0)\n",
167
+ " elif react1 and react1.GetNumBonds() > 0:\n",
168
+ " bond_len = Draw.MeanBondLength(react1)\n",
169
+ " else:\n",
170
+ " bond_len = 1.0\n",
171
+ "\n",
172
+ "            # Draw the reaction\n",
173
+ " drawer = rdMolDraw2D.MolDraw2DSVG(-1, -1)\n",
174
+ " dopts = drawer.drawOptions()\n",
175
+ " dopts.padding = 0.1\n",
176
+ " dopts.includeRadicals = True\n",
177
+ " Draw.SetACS1996Mode(dopts, bond_len * 0.55)\n",
178
+ " dopts.bondLineWidth = 1.5\n",
179
+ " drawer.DrawReaction(cleaned_rxn)\n",
180
+ " drawer.FinishDrawing()\n",
181
+ " svg = drawer.GetDrawingText()\n",
182
+ " svg_file = f\"reaction_{i+1}.svg\"\n",
183
+ " with open(svg_file, \"w\") as f: f.write(svg)\n",
184
+ " png_file = f\"reaction_{i+1}.png\"\n",
185
+ " cairosvg.svg2png(url=svg_file, write_to=png_file)\n",
186
+ " components.append(gr.Image(value=png_file, label=f\"RDKit Image of Reaction {i+1}\"))\n",
187
+ " except Exception as e:\n",
188
+ " print(f\"Failed to draw reaction {i+1} for SMILES '{smiles_clean}': {e}\")\n",
189
+ "            # Skip the image if drawing fails\n",
190
+ " return components\n",
191
+ "\n",
192
+ " download_json = gr.File(label=\"Download JSON File\")\n",
193
+ "\n",
194
+ " gr.Examples(\n",
195
+ " examples=[\n",
196
+ " [\"examples/reaction1.jpg\"],\n",
197
+ " [\"examples/reaction2.png\"],\n",
198
+ " [\"examples/reaction3.png\"],\n",
199
+ " [\"examples/reaction4.png\"],\n",
200
+ " ],\n",
201
+ " inputs=[image_input],\n",
202
+ " outputs=[reaction_output, smiles_output, schematic_diagram, download_json],\n",
203
+ " cache_examples=False,\n",
204
+ " examples_per_page=4,\n",
205
+ " )\n",
206
+ "\n",
207
+ " clear_btn.click(\n",
208
+ " lambda: (None, None, None, None),\n",
209
+ " inputs=[],\n",
210
+ " outputs=[image_input, reaction_output, smiles_output, download_json]\n",
211
+ " )\n",
212
+ " run_btn.click(\n",
213
+ " process_chem_image,\n",
214
+ " inputs=[image_input],\n",
215
+ " outputs=[reaction_output, smiles_output, schematic_diagram, download_json]\n",
216
+ " )\n",
217
+ "\n",
218
+ " demo.css = \"\"\"\n",
219
+ " #submit-btn {\n",
220
+ " background-color: #FF914D;\n",
221
+ " color: white;\n",
222
+ " font-weight: bold;\n",
223
+ " }\n",
224
+ " \"\"\"\n",
225
+ "\n",
226
+ " demo.launch()\n"
227
+ ]
228
+ }
229
+ ],
230
+ "metadata": {
231
+ "kernelspec": {
232
+ "display_name": "openchemie",
233
+ "language": "python",
234
+ "name": "python3"
235
+ },
236
+ "language_info": {
237
+ "codemirror_mode": {
238
+ "name": "ipython",
239
+ "version": 3
240
+ },
241
+ "file_extension": ".py",
242
+ "mimetype": "text/x-python",
243
+ "name": "python",
244
+ "nbconvert_exporter": "python",
245
+ "pygments_lexer": "ipython3",
246
+ "version": "3.10.14"
247
+ }
248
+ },
249
+ "nbformat": 4,
250
+ "nbformat_minor": 5
251
+ }
app.py ADDED
@@ -0,0 +1,202 @@
1
+ import gradio as gr
2
+ import json
3
+ from main import ChemEagle # assumes the API key is managed internally
4
+ from rdkit import Chem
5
+ from rdkit.Chem import rdChemReactions, Draw, AllChem
6
+ from rdkit.Chem.Draw import rdMolDraw2D
7
+ import cairosvg
8
+ import re
9
+
10
+ example_diagram = "examples/exp.png"
11
+ rdkit_image = "examples/rdkit.png"
12
+
13
+ # Parse the structured data returned by ChemEagle
14
+ def parse_reactions(output_json):
15
+ if isinstance(output_json, str):
16
+ reactions_data = json.loads(output_json)
17
+ else:
18
+ reactions_data = output_json
19
+ reactions_list = reactions_data.get("reactions", [])
20
+ detailed_output = []
21
+ smiles_output = []
22
+
23
+ for reaction in reactions_list:
24
+ reaction_id = reaction.get("reaction_id", "Unknown ID")
25
+ reactants = [r.get("smiles", "Unknown") for r in reaction.get("reactants", [])]
26
+ conditions = [
27
+ f"<span style='color:red'>{c.get('smiles', c.get('text', 'Unknown'))}[{c.get('role', 'Unknown')}]</span>"
28
+ for c in reaction.get("condition", [])
29
+ ]
30
+ conditions_1 = [
31
+ f"<span style='color:black'>{c.get('smiles', c.get('text', 'Unknown'))}[{c.get('role', 'Unknown')}]</span>"
32
+ for c in reaction.get("condition", [])
33
+ ]
34
+ products = [f"<span style='color:orange'>{p.get('smiles', 'Unknown')}</span>" for p in reaction.get("products", [])]
35
+ products_1 = [f"<span style='color:black'>{p.get('smiles', 'Unknown')}</span>" for p in reaction.get("products", [])]
36
+ products_2 = [r.get("smiles", "Unknown") for r in reaction.get("products", [])]
37
+ additional = reaction.get("additional_info", [])
38
+ additional_str = [str(x) for x in additional if x]
39
+
40
+ tail = conditions_1 + additional_str
41
+ tail_str = ", ".join(tail)
42
+ full_reaction = f"{'.'.join(reactants)}>>{'.'.join(products_1)} | {tail_str}"
43
+ full_reaction = f"<span style='color:black'>{full_reaction}</span>"
44
+
45
+ reaction_output = f"<b>Reaction: </b> {reaction_id}<br>"
46
+ reaction_output += f" Reactants: <span style='color:blue'>{', '.join(reactants)}</span><br>"
47
+ reaction_output += f" Conditions: {', '.join(conditions)}<br>"
48
+ reaction_output += f" Products: {', '.join(products)}<br>"
49
+ reaction_output += f" additional_info: {', '.join(additional_str)}<br>"
50
+ reaction_output += f" <b>Full Reaction:</b> {full_reaction}<br><br>"
51
+ detailed_output.append(reaction_output)
52
+
53
+ reaction_smiles = f"{'.'.join(reactants)}>>{'.'.join(products_2)}"
54
+ smiles_output.append(reaction_smiles)
55
+
56
+ return detailed_output, smiles_output
57
+
58
+ def process_chem_image(image):
59
+ image_path = "temp_image.png"
60
+ image.save(image_path)
61
+
62
+ chemeagle_result = ChemEagle(image_path)
63
+ detailed, smiles = parse_reactions(chemeagle_result)
64
+
65
+ json_path = "output.json"
66
+ with open(json_path, 'w') as jf:
67
+ json.dump(chemeagle_result, jf, indent=2)
68
+
69
+ return "\n\n".join(detailed), smiles, example_diagram, json_path
70
+
71
+ with gr.Blocks() as demo:
72
+ gr.Markdown(
73
+ """
74
+ <center><h1>ChemEAGLE: A Multi-Agent System for Information Extraction from the Chemical Literature</h1></center>
75
+ Upload a chemical graphic to extract machine-readable chemical data.
76
+ """
77
+ )
78
+
79
+ with gr.Row():
80
+ with gr.Column(scale=1):
81
+ image_input = gr.Image(type="pil", label="Upload a multimodal reaction image")
82
+ with gr.Row():
83
+ clear_btn = gr.Button("Clear")
84
+ run_btn = gr.Button("Run", elem_id="submit-btn")
85
+
86
+ with gr.Column(scale=1):
87
+ gr.Markdown("### Parsed Reactions")
88
+ reaction_output = gr.HTML(label="Detailed Reaction Output")
89
+ gr.Markdown("### Schematic Diagram")
90
+ schematic_diagram = gr.Image(value=example_diagram, label="Schematic Diagram")
91
+
92
+ with gr.Column(scale=1):
93
+ gr.Markdown("### Machine-readable Output")
94
+ smiles_output = gr.Textbox(
95
+ label="Reaction SMILES",
96
+ show_copy_button=True,
97
+ interactive=False,
98
+ visible=False
99
+ )
100
+
101
+ @gr.render(inputs=smiles_output)
102
+ def show_split(inputs):
103
+ if not inputs or (isinstance(inputs, str) and inputs.strip() == ""):
104
+ return gr.Textbox(label="SMILES of Reaction i"), gr.Image(value=rdkit_image, label="RDKit Image of Reaction i", height=100)
105
+ smiles_list = inputs.split(",")
106
+ smiles_list = [re.sub(r"^\s*\[?'?|']?\s*$", "", item) for item in smiles_list]
107
+ components = []
108
+ for i, smiles in enumerate(smiles_list):
109
+ smiles_clean = smiles.replace('"', '').replace("'", "").replace("[", "").replace("]", "")
110
+ # Always add the SMILES textbox
111
+ components.append(gr.Textbox(value=smiles_clean, label=f"SMILES of Reaction {i+1}", show_copy_button=True, interactive=False))
112
+ try:
113
+ rxn = rdChemReactions.ReactionFromSmarts(smiles_clean, useSmiles=True)
114
+ if not rxn:
115
+ continue
116
+ new_rxn = AllChem.ChemicalReaction()
117
+ for mol in rxn.GetReactants():
118
+ mol = Chem.MolFromMolBlock(Chem.MolToMolBlock(mol))
119
+ new_rxn.AddReactantTemplate(mol)
120
+ for mol in rxn.GetProducts():
121
+ mol = Chem.MolFromMolBlock(Chem.MolToMolBlock(mol))
122
+ new_rxn.AddProductTemplate(mol)
123
+ cleaned_rxn = new_rxn
124
+
125
+ # Remove atom mapping
126
+ for react in cleaned_rxn.GetReactants():
127
+ for atom in react.GetAtoms(): atom.SetAtomMapNum(0)
128
+ for prod in cleaned_rxn.GetProducts():
129
+ for atom in prod.GetAtoms(): atom.SetAtomMapNum(0)
130
+
131
+ # Compute a reference bond length
132
+ ref_rxn = cleaned_rxn
133
+ react0 = ref_rxn.GetReactantTemplate(0)
134
+ react1 = ref_rxn.GetReactantTemplate(1) if ref_rxn.GetNumReactantTemplates() > 1 else None
135
+ if react0.GetNumBonds() > 0:
136
+ bond_len = Draw.MeanBondLength(react0)
137
+ elif react1 and react1.GetNumBonds() > 0:
138
+ bond_len = Draw.MeanBondLength(react1)
139
+ else:
140
+ bond_len = 1.0
141
+
142
+ # Draw the reaction
143
+ drawer = rdMolDraw2D.MolDraw2DSVG(-1, -1)
144
+ dopts = drawer.drawOptions()
145
+ dopts.padding = 0.1
146
+ dopts.includeRadicals = True
147
+ Draw.SetACS1996Mode(dopts, bond_len * 0.55)
148
+ dopts.bondLineWidth = 1.5
149
+ drawer.DrawReaction(cleaned_rxn)
150
+ drawer.FinishDrawing()
151
+ svg = drawer.GetDrawingText()
152
+ svg_file = f"reaction_{i}.svg"
153
+ with open(svg_file, "w") as f: f.write(svg)
154
+ png_file = f"reaction_{i}.png"
155
+ cairosvg.svg2png(url=svg_file, write_to=png_file)
156
+ components.append(gr.Image(value=png_file, label=f"RDKit Image of Reaction {i+1}"))
157
+ except Exception as e:
158
+ print(f"Failed to draw reaction {i} for SMILES '{smiles_clean}': {e}")
159
+ # Skip the image if drawing fails
160
+ return components
161
+
162
+ download_json = gr.File(label="Download JSON File")
163
+
164
+ gr.Examples(
165
+ examples=[
166
+ ["examples/reaction0.png"],
167
+ ["examples/reaction1.jpg"],
168
+ ["examples/reaction2.png"],
169
+ ["examples/reaction3.png"],
170
+ ["examples/reaction4.png"],
171
+ ["examples/reaction5.png"],
172
+ ["examples/template1.png"],
173
+ ["examples/molecules1.png"],
174
+ ["examples/molecules2.png"],
175
+
176
+ ],
177
+ inputs=[image_input],
178
+ outputs=[reaction_output, smiles_output, schematic_diagram, download_json],
179
+ cache_examples=False,
180
+ examples_per_page=4,
181
+ )
182
+
183
+ clear_btn.click(
184
+ lambda: (None, None, None, None),
185
+ inputs=[],
186
+ outputs=[image_input, reaction_output, smiles_output, download_json]
187
+ )
188
+ run_btn.click(
189
+ process_chem_image,
190
+ inputs=[image_input],
191
+ outputs=[reaction_output, smiles_output, schematic_diagram, download_json]
192
+ )
193
+
194
+ demo.css = """
195
+ #submit-btn {
196
+ background-color: #FF914D;
197
+ color: white;
198
+ font-weight: bold;
199
+ }
200
+ """
201
+
202
+ demo.launch()
apt.txt ADDED
@@ -0,0 +1,2 @@
1
+ libpoppler-cpp-dev
2
+ pkg-config
chemiener/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .interface import ChemNER
chemiener/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (187 Bytes). View file
 
chemiener/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (185 Bytes). View file
 
chemiener/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (5.37 kB). View file
 
chemiener/__pycache__/dataset.cpython-38.pyc ADDED
Binary file (5.35 kB). View file
 
chemiener/__pycache__/interface.cpython-310.pyc ADDED
Binary file (4.46 kB). View file
 
chemiener/__pycache__/interface.cpython-38.pyc ADDED
Binary file (4.47 kB). View file
 
chemiener/__pycache__/model.cpython-310.pyc ADDED
Binary file (684 Bytes). View file
 
chemiener/__pycache__/model.cpython-38.pyc ADDED
Binary file (680 Bytes). View file
 
chemiener/__pycache__/utils.cpython-310.pyc ADDED
Binary file (1.67 kB). View file
 
chemiener/__pycache__/utils.cpython-38.pyc ADDED
Binary file (1.53 kB). View file
 
chemiener/dataset.py ADDED
@@ -0,0 +1,172 @@
+ import os
+ import cv2
+ import copy
+ import random
+ import json
+ import contextlib
+ import numpy as np
+ import pandas as pd
+ import torch
+ import torch.nn.functional as F
+ from torch.utils.data import DataLoader, Dataset
+ from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
+
+ from transformers import BertTokenizerFast, AutoTokenizer, RobertaTokenizerFast
+
+ from .utils import get_class_to_index
+
+
+ class NERDataset(Dataset):
+     def __init__(self, args, data_file, split='train'):
+         super().__init__()
+         self.args = args
+         if data_file:
+             data_path = os.path.join(args.data_path, data_file)
+             with open(data_path) as f:
+                 self.data = json.load(f)
+             self.name = os.path.basename(data_file).split('.')[0]
+         self.split = split
+         self.is_train = (split == 'train')
+         self.tokenizer = AutoTokenizer.from_pretrained(self.args.roberta_checkpoint, cache_dir=self.args.cache_dir)
+         self.class_to_index = get_class_to_index(self.args.corpus)
+         self.index_to_class = {self.class_to_index[key]: key for key in self.class_to_index}
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         text_tokenized = self.tokenizer(self.data[str(idx)]['text'], truncation=True, max_length=self.args.max_seq_length)
+         if len(text_tokenized['input_ids']) > 512:
+             print(len(text_tokenized['input_ids']))
+         text_tokenized_untruncated = self.tokenizer(self.data[str(idx)]['text'])
+         return (text_tokenized,
+                 self.align_labels(text_tokenized, self.data[str(idx)]['entities'], len(self.data[str(idx)]['text'])),
+                 self.align_labels(text_tokenized_untruncated, self.data[str(idx)]['entities'], len(self.data[str(idx)]['text'])))
+
+     def align_labels(self, text_tokenized, entities, length):
+         # Map each character position to a BIO class index, then project onto tokens.
+         char_to_class = {}
+         for entity in entities:
+             for span in entities[entity]["span"]:
+                 for i in range(span[0], span[1]):
+                     char_to_class[i] = self.class_to_index[('B-' if i == span[0] else 'I-') + str(entities[entity]["type"])]
+
+         for i in range(length):
+             if i not in char_to_class:
+                 char_to_class[i] = 0
+
+         classes = []
+         for i in range(len(text_tokenized[0])):
+             span = text_tokenized.token_to_chars(i)
+             if span is not None:
+                 classes.append(char_to_class[span.start])
+             else:
+                 classes.append(-100)  # special tokens are ignored by the loss
+
+         return torch.LongTensor(classes)
+
+     def make_html(self, word_tokens, predictions):
+         toreturn = '''<!DOCTYPE html>
+ <html>
+ <head>
+ <title>Named Entity Recognition Visualization</title>
+ <style>
+ .EXAMPLE_LABEL {
+     color: red;
+     text-decoration: underline red;
+ }
+ .REACTION_PRODUCT {
+     color: orange;
+     text-decoration: underline orange;
+ }
+ .STARTING_MATERIAL {
+     color: gold;
+     text-decoration: underline gold;
+ }
+ .REAGENT_CATALYST {
+     color: green;
+     text-decoration: underline green;
+ }
+ .SOLVENT {
+     color: cyan;
+     text-decoration: underline cyan;
+ }
+ .OTHER_COMPOUND {
+     color: blue;
+     text-decoration: underline blue;
+ }
+ .TIME {
+     color: purple;
+     text-decoration: underline purple;
+ }
+ .TEMPERATURE {
+     color: magenta;
+     text-decoration: underline magenta;
+ }
+ .YIELD_OTHER {
+     color: palegreen;
+     text-decoration: underline palegreen;
+ }
+ .YIELD_PERCENT {
+     color: pink;
+     text-decoration: underline pink;
+ }
+ </style>
+ </head>
+ <body>
+ <p>'''
+         last_label = None
+         for idx, item in enumerate(word_tokens):
+             decoded = self.tokenizer.decode(item, skip_special_tokens=True)
+             if len(decoded) > 0:
+                 if idx != 0 and decoded[0] != '#':
+                     toreturn += " "
+                 label = predictions[idx]
+                 if label == last_label:
+                     toreturn += decoded if decoded[0] != "#" else decoded[2:]
+                 else:
+                     if last_label is not None and last_label > 0:
+                         toreturn += "</u>"
+                     if label > 0:
+                         toreturn += "<u class=\""
+                         toreturn += self.index_to_class[label]
+                         toreturn += "\">"
+                         toreturn += decoded if decoded[0] != "#" else decoded[2:]
+                     if label == 0:
+                         toreturn += decoded if decoded[0] != "#" else decoded[2:]
+                 if idx == len(word_tokens) - 1 and label > 0:
+                     toreturn += "</u>"
+                 last_label = label
+
+         toreturn += ''' </p>
+ </body>
+ </html>'''
+         return toreturn
+
+
+ def get_collate_fn():
+     def collate(batch):
+         sentences = []
+         masks = []
+         refs = []
+
+         for ex in batch:
+             sentences.append(torch.LongTensor(ex[0]['input_ids']))
+             masks.append(torch.Tensor(ex[0]['attention_mask']))
+             refs.append(ex[1])
+
+         sentences = pad_sequence(sentences, batch_first=True, padding_value=0)
+         masks = pad_sequence(masks, batch_first=True, padding_value=0)
+         refs = pad_sequence(refs, batch_first=True, padding_value=-100)
+         return sentences, masks, refs
+
+     return collate
chemiener/interface.py ADDED
@@ -0,0 +1,124 @@
+ import os
+ import argparse
+ from typing import List
+ import torch
+ import numpy as np
+
+ from .model import build_model
+ from .dataset import NERDataset, get_collate_fn
+ from huggingface_hub import hf_hub_download
+ from .utils import get_class_to_index
+
+
+ class ChemNER:
+
+     def __init__(self, model_path, device=None, cache_dir=None):
+         self.args = self._get_args(cache_dir)
+         states = torch.load(model_path, map_location=torch.device('cpu'))
+         if device is None:
+             device = torch.device('cpu')
+         self.device = device
+         self.model = self.get_model(self.args, device, states['state_dict'])
+         self.collate = get_collate_fn()
+         self.dataset = NERDataset(self.args, data_file=None)
+         self.class_to_index = get_class_to_index(self.args.corpus)
+         self.index_to_class = {self.class_to_index[key]: key for key in self.class_to_index}
+
+     def _get_args(self, cache_dir):
+         parser = argparse.ArgumentParser()
+         parser.add_argument('--roberta_checkpoint', default='dmis-lab/biobert-large-cased-v1.1', type=str, help='which roberta config to use')
+         parser.add_argument('--corpus', default="chemdner", type=str, help="which corpus should the tags be from")
+         args = parser.parse_args([])
+         args.cache_dir = cache_dir
+         return args
+
+     def get_model(self, args, device, model_states):
+         model = build_model(args)
+
+         def remove_prefix(state_dict):
+             return {k.replace('model.', ''): v for k, v in state_dict.items()}
+
+         model.load_state_dict(remove_prefix(model_states), strict=False)
+         model.to(device)
+         model.eval()
+         return model
+
+     def predict_strings(self, strings: List, batch_size=8):
+         device = self.device
+
+         def prepare_output(char_span, prediction):
+             # Merge token-level BIO predictions into (entity type, [start, end]) character spans.
+             toreturn = []
+             i = 0
+             while i < len(char_span):
+                 if prediction[i][0] == 'B':
+                     toreturn.append((prediction[i][2:], [char_span[i].start, char_span[i].end]))
+                 elif len(toreturn) > 0 and prediction[i][2:] == toreturn[-1][0]:
+                     toreturn[-1] = (toreturn[-1][0], [toreturn[-1][1][0], char_span[i].end])
+                 i += 1
+             return toreturn
+
+         output = []
+         for idx in range(0, len(strings), batch_size):
+             batch_strings = strings[idx:idx + batch_size]
+             batch_strings_tokenized = [(self.dataset.tokenizer(s, truncation=True, max_length=512), torch.Tensor([-1]), torch.Tensor([-1])) for s in batch_strings]
+
+             sentences, masks, refs = self.collate(batch_strings_tokenized)
+
+             predictions = self.model(input_ids=sentences.to(device), attention_mask=masks.to(device))[0].argmax(dim=2).to('cpu')
+
+             sentences_list = list(sentences)
+             predictions_list = list(predictions)
+
+             char_spans = []
+             for j, sentence in enumerate(sentences_list):
+                 to_add = [batch_strings_tokenized[j][0].token_to_chars(i) for i, word in enumerate(sentence) if len(self.dataset.tokenizer.decode(int(word.item()), skip_special_tokens=True)) > 0]
+                 char_spans.append(to_add)
+
+             class_predictions = [[self.index_to_class[int(pred.item())] for (pred, word) in zip(sentence_p, sentence_w) if len(self.dataset.tokenizer.decode(int(word.item()), skip_special_tokens=True)) > 0] for (sentence_p, sentence_w) in zip(predictions_list, sentences_list)]
+
+             output += [prepare_output(char_span, prediction) for char_span, prediction in zip(char_spans, class_predictions)]
+
+         return output
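The BIO span-merging step in `prepare_output` can be illustrated standalone. This is a minimal sketch, not repo code: `merge_bio` and `CharSpan` are illustrative names, with `CharSpan` standing in for the tokenizer's `token_to_chars()` return value.

```python
from collections import namedtuple

# A "B-" tag opens a new entity span; a following "I-" tag of the same type
# extends the previous span's end offset. Everything else is ignored.
CharSpan = namedtuple("CharSpan", ["start", "end"])

def merge_bio(char_spans, tags):
    entities = []
    for span, tag in zip(char_spans, tags):
        if tag.startswith("B-"):
            entities.append((tag[2:], [span.start, span.end]))
        elif tag.startswith("I-") and entities and tag[2:] == entities[-1][0]:
            entities[-1] = (entities[-1][0], [entities[-1][1][0], span.end])
    return entities

spans = [CharSpan(0, 4), CharSpan(5, 12), CharSpan(13, 15), CharSpan(16, 23)]
tags = ["O", "B-SOLVENT", "I-SOLVENT", "O"]
print(merge_bio(spans, tags))  # → [('SOLVENT', [5, 15])]
```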
chemiener/main.py ADDED
@@ -0,0 +1,345 @@
+ import os
+ import math
+ import json
+ import random
+ import argparse
+ import numpy as np
+
+ import time
+
+ import torch
+ from torch.profiler import profile, record_function, ProfilerActivity
+ import torch.distributed as dist
+ import pytorch_lightning as pl
+ from pytorch_lightning import LightningModule, LightningDataModule
+ from pytorch_lightning.callbacks import LearningRateMonitor
+ from pytorch_lightning.strategies.ddp import DDPStrategy
+ from transformers import get_scheduler
+ import transformers
+
+ from dataset import NERDataset, get_collate_fn
+ from model import build_model
+ from utils import get_class_to_index
+
+ import evaluate
+
+ from seqeval.metrics import accuracy_score
+ from seqeval.metrics import classification_report
+ from seqeval.metrics import f1_score
+ from seqeval.scheme import IOB2
+
+
+ def get_args(notebook=False):
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--do_train', action='store_true')
+     parser.add_argument('--do_valid', action='store_true')
+     parser.add_argument('--do_test', action='store_true')
+     parser.add_argument('--fp16', action='store_true')
+     parser.add_argument('--seed', type=int, default=42)
+     parser.add_argument('--gpus', type=int, default=1)
+     parser.add_argument('--print_freq', type=int, default=200)
+     parser.add_argument('--debug', action='store_true')
+     parser.add_argument('--no_eval', action='store_true')
+
+     # Data
+     parser.add_argument('--data_path', type=str, default=None)
+     parser.add_argument('--image_path', type=str, default=None)
+     parser.add_argument('--train_file', type=str, default=None)
+     parser.add_argument('--valid_file', type=str, default=None)
+     parser.add_argument('--test_file', type=str, default=None)
+     parser.add_argument('--vocab_file', type=str, default=None)
+     parser.add_argument('--format', type=str, default='reaction')
+     parser.add_argument('--num_workers', type=int, default=8)
+     parser.add_argument('--input_size', type=int, default=224)
+
+     # Training
+     parser.add_argument('--epochs', type=int, default=8)
+     parser.add_argument('--batch_size', type=int, default=256)
+     parser.add_argument('--lr', type=float, default=1e-4)
+     parser.add_argument('--weight_decay', type=float, default=0.05)
+     parser.add_argument('--max_grad_norm', type=float, default=5.)
+     parser.add_argument('--scheduler', type=str, choices=['cosine', 'constant'], default='cosine')
+     parser.add_argument('--warmup_ratio', type=float, default=0)
+     parser.add_argument('--gradient_accumulation_steps', type=int, default=1)
+     parser.add_argument('--load_path', type=str, default=None)
+     parser.add_argument('--load_encoder_only', action='store_true')
+     parser.add_argument('--train_steps_per_epoch', type=int, default=-1)
+     parser.add_argument('--eval_per_epoch', type=int, default=10)
+     parser.add_argument('--save_path', type=str, default='output/')
+     parser.add_argument('--save_mode', type=str, default='best', choices=['best', 'all', 'last'])
+     parser.add_argument('--load_ckpt', type=str, default='best')
+     parser.add_argument('--resume', action='store_true')
+     parser.add_argument('--num_train_example', type=int, default=None)
+
+     parser.add_argument('--roberta_checkpoint', type=str, default="roberta-base")
+     parser.add_argument('--corpus', type=str, default="chemu")
+     parser.add_argument('--cache_dir')
+     parser.add_argument('--eval_truncated', action='store_true')
+     parser.add_argument('--max_seq_length', type=int, default=512)
+
+     args = parser.parse_args([]) if notebook else parser.parse_args()
+     return args
+
+
+ class ChemIENERecognizer(LightningModule):
+
+     def __init__(self, args):
+         super().__init__()
+         self.args = args
+         self.model = build_model(args)
+         self.validation_step_outputs = []
+
+     def training_step(self, batch, batch_idx):
+         sentences, masks, refs, _ = batch
+         loss, logits = self.model(input_ids=sentences, attention_mask=masks, labels=refs)
+         self.log('train/loss', loss)
+         self.log('lr', self.lr_schedulers().get_lr()[0], prog_bar=True, logger=False)
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         sentences, masks, refs, untruncated = batch
+         logits = self.model(input_ids=sentences, attention_mask=masks)[0]
+         self.validation_step_outputs.append((sentences.to("cpu"), logits.argmax(dim=2).to("cpu"), refs.to('cpu'), untruncated.to("cpu")))
+
+     def on_validation_epoch_end(self):
+         if self.trainer.num_devices > 1:
+             gathered_outputs = [None for i in range(self.trainer.num_devices)]
+             dist.all_gather_object(gathered_outputs, self.validation_step_outputs)
+             gathered_outputs = sum(gathered_outputs, [])
+         else:
+             gathered_outputs = self.validation_step_outputs
+
+         sentences = [list(output[0]) for output in gathered_outputs]
+
+         class_to_index = get_class_to_index(self.args.corpus)
+         index_to_class = {class_to_index[key]: key for key in class_to_index}
+         predictions = [list(output[1]) for output in gathered_outputs]
+         labels = [list(output[2]) for output in gathered_outputs]
+
+         untruncateds = [list(output[3]) for output in gathered_outputs]
+         untruncateds = [[index_to_class[int(label.item())] for label in sentence if int(label.item()) != -100] for batched in untruncateds for sentence in batched]
+
+         output = {"sentences": [[int(word.item()) for (word, label) in zip(sentence_w, sentence_l) if label != -100] for (batched_w, batched_l) in zip(sentences, labels) for (sentence_w, sentence_l) in zip(batched_w, batched_l)],
+                   "predictions": [[index_to_class[int(pred.item())] for (pred, label) in zip(sentence_p, sentence_l) if label != -100] for (batched_p, batched_l) in zip(predictions, labels) for (sentence_p, sentence_l) in zip(batched_p, batched_l)],
+                   "groundtruth": [[index_to_class[int(label.item())] for label in sentence if label != -100] for batched in labels for sentence in batched]}
+
+         name = self.eval_dataset.name
+         scores = [0]
+
+         if self.trainer.is_global_zero:
+             if not self.args.no_eval:
+                 epoch = self.trainer.current_epoch
+                 metric = evaluate.load("seqeval", cache_dir=self.args.cache_dir)
+                 # Pad truncated predictions with 'O' so they align with the untruncated references.
+                 predictions = [preds + ['O'] * (len(full_groundtruth) - len(preds)) for (preds, full_groundtruth) in zip(output['predictions'], untruncateds)]
+                 all_metrics = metric.compute(predictions=predictions, references=untruncateds)
+                 if self.args.eval_truncated:
+                     report = classification_report(output['groundtruth'], output['predictions'], mode='strict', scheme=IOB2, output_dict=True)
+                 else:
+                     report = classification_report(predictions, untruncateds, mode='strict', scheme=IOB2, output_dict=True)
+                 self.print(report)
+                 scores = [report['micro avg']['f1-score']]
+                 with open(os.path.join(self.trainer.default_root_dir, f'prediction_{name}.json'), 'w') as f:
+                     json.dump(output, f)
+
+         dist.broadcast_object_list(scores)
+         self.log('val/score', scores[0], prog_bar=True, rank_zero_only=True)
+         self.validation_step_outputs.clear()
+
+     def configure_optimizers(self):
+         num_training_steps = self.trainer.num_training_steps
+         self.print(f'Num training steps: {num_training_steps}')
+         num_warmup_steps = int(num_training_steps * self.args.warmup_ratio)
+         optimizer = torch.optim.AdamW(self.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)
+         scheduler = get_scheduler(self.args.scheduler, optimizer, num_warmup_steps, num_training_steps)
+         return {'optimizer': optimizer, 'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'}}
+
+
+ class NERDataModule(LightningDataModule):
+
+     def __init__(self, args):
+         super().__init__()
+         self.args = args
+         self.collate_fn = get_collate_fn()
+
+     def prepare_data(self):
+         args = self.args
+         if args.do_train:
+             self.train_dataset = NERDataset(args, args.train_file, split='train')
+         if self.args.do_train or self.args.do_valid:
+             self.val_dataset = NERDataset(args, args.valid_file, split='valid')
+         if self.args.do_test:
+             self.test_dataset = NERDataset(args, args.test_file, split='valid')
+
+     def print_stats(self):
+         if self.args.do_train:
+             print(f'Train dataset: {len(self.train_dataset)}')
+         if self.args.do_train or self.args.do_valid:
+             print(f'Valid dataset: {len(self.val_dataset)}')
+         if self.args.do_test:
+             print(f'Test dataset: {len(self.test_dataset)}')
+
+     def train_dataloader(self):
+         return torch.utils.data.DataLoader(
+             self.train_dataset, batch_size=self.args.batch_size, num_workers=self.args.num_workers,
+             collate_fn=self.collate_fn)
+
+     def val_dataloader(self):
+         return torch.utils.data.DataLoader(
+             self.val_dataset, batch_size=self.args.batch_size, num_workers=self.args.num_workers,
+             collate_fn=self.collate_fn)
+
+     def test_dataloader(self):
+         return torch.utils.data.DataLoader(
+             self.test_dataset, batch_size=self.args.batch_size, num_workers=self.args.num_workers,
+             collate_fn=self.collate_fn)
+
+
+ class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
+     def _get_metric_interpolated_filepath_name(self, monitor_candidates, trainer, del_filepath=None) -> str:
+         filepath = self.format_checkpoint_name(monitor_candidates)
+         return filepath
+
+
+ def main():
+     transformers.utils.logging.set_verbosity_error()
+     args = get_args()
+     pl.seed_everything(args.seed, workers=True)
+
+     if args.do_train:
+         model = ChemIENERecognizer(args)
+     else:
+         model = ChemIENERecognizer.load_from_checkpoint(
+             os.path.join(args.save_path, 'checkpoints/best.ckpt'), strict=False, args=args)
+
+     dm = NERDataModule(args)
+     dm.prepare_data()
+     dm.print_stats()
+
+     checkpoint = ModelCheckpoint(monitor='val/score', mode='max', save_top_k=1, filename='best', save_last=True)
+     lr_monitor = LearningRateMonitor(logging_interval='step')
+     logger = pl.loggers.TensorBoardLogger(args.save_path, name='', version='')
+
+     trainer = pl.Trainer(
+         strategy=DDPStrategy(find_unused_parameters=False),
+         accelerator='gpu',
+         precision=16,
+         devices=args.gpus,
+         logger=logger,
+         default_root_dir=args.save_path,
+         callbacks=[checkpoint, lr_monitor],
+         max_epochs=args.epochs,
+         gradient_clip_val=args.max_grad_norm,
+         accumulate_grad_batches=args.gradient_accumulation_steps,
+         check_val_every_n_epoch=args.eval_per_epoch,
+         log_every_n_steps=10,
+         deterministic='warn')
+
+     if args.do_train:
+         trainer.num_training_steps = math.ceil(
+             len(dm.train_dataset) / (args.batch_size * args.gpus * args.gradient_accumulation_steps)) * args.epochs
+         model.eval_dataset = dm.val_dataset
+         ckpt_path = os.path.join(args.save_path, 'checkpoints/last.ckpt') if args.resume else None
+         trainer.fit(model, datamodule=dm, ckpt_path=ckpt_path)
+         model = ChemIENERecognizer.load_from_checkpoint(checkpoint.best_model_path, args=args)
+
+     if args.do_valid:
+         model.eval_dataset = dm.val_dataset
+         trainer.validate(model, datamodule=dm)
+
+     if args.do_test:
+         model.test_dataset = dm.test_dataset
+         trainer.test(model, datamodule=dm)
+
+
+ if __name__ == "__main__":
+     main()
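The alignment step in `on_validation_epoch_end` pads each truncated prediction sequence with `'O'` so it lines up with the untruncated reference before seqeval scoring. A minimal sketch with made-up tag lists:

```python
# Predictions only cover the truncated input; the reference covers the full
# text. Padding with the outside tag 'O' makes both sequences equal length.
preds = ['B-MOL', 'I-MOL']
reference = ['B-MOL', 'I-MOL', 'O', 'B-MOL']
padded = preds + ['O'] * (len(reference) - len(preds))
print(padded)  # → ['B-MOL', 'I-MOL', 'O', 'O']
```

Note that any entity falling entirely in the truncated tail (the final `B-MOL` here) is necessarily scored as a miss.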
chemiener/model.py ADDED
@@ -0,0 +1,14 @@
+ import torch
+ from torch import nn
+
+ from transformers import BertForTokenClassification, RobertaForTokenClassification, AutoModelForTokenClassification
+
+
+ def build_model(args):
+     # Label counts match the tag sets returned by utils.get_class_to_index for each corpus.
+     if args.corpus == "chemu":
+         return AutoModelForTokenClassification.from_pretrained(args.roberta_checkpoint, num_labels=21, cache_dir=args.cache_dir, return_dict=False)
+     elif args.corpus == "chemdner":
+         return AutoModelForTokenClassification.from_pretrained(args.roberta_checkpoint, num_labels=17, cache_dir=args.cache_dir, return_dict=False)
+     elif args.corpus == "chemdner-mol":
+         return AutoModelForTokenClassification.from_pretrained(args.roberta_checkpoint, num_labels=3, cache_dir=args.cache_dir, return_dict=False)
chemiener/utils.py ADDED
@@ -0,0 +1,23 @@
+ def merge_predictions(results):
+     if len(results) == 0:
+         return []
+     predictions = {}
+     for batch_preds in results:
+         for idx, preds in enumerate(batch_preds):
+             predictions[idx] = preds
+     predictions = [predictions[i] for i in range(len(predictions))]
+     return predictions
+
+
+ def get_class_to_index(corpus):
+     if corpus == "chemu":
+         return {'B-EXAMPLE_LABEL': 1, 'B-REACTION_PRODUCT': 2, 'B-STARTING_MATERIAL': 3, 'B-REAGENT_CATALYST': 4, 'B-SOLVENT': 5, 'B-OTHER_COMPOUND': 6, 'B-TIME': 7, 'B-TEMPERATURE': 8, 'B-YIELD_OTHER': 9, 'B-YIELD_PERCENT': 10, 'O': 0,
+                 'I-EXAMPLE_LABEL': 11, 'I-REACTION_PRODUCT': 12, 'I-STARTING_MATERIAL': 13, 'I-REAGENT_CATALYST': 14, 'I-SOLVENT': 15, 'I-OTHER_COMPOUND': 16, 'I-TIME': 17, 'I-TEMPERATURE': 18, 'I-YIELD_OTHER': 19, 'I-YIELD_PERCENT': 20}
+     elif corpus == "chemdner":
+         return {'O': 0, 'B-ABBREVIATION': 1, 'B-FAMILY': 2, 'B-FORMULA': 3, 'B-IDENTIFIER': 4, 'B-MULTIPLE': 5, 'B-SYSTEMATIC': 6, 'B-TRIVIAL': 7, 'B-NO CLASS': 8, 'I-ABBREVIATION': 9, 'I-FAMILY': 10, 'I-FORMULA': 11, 'I-IDENTIFIER': 12, 'I-MULTIPLE': 13, 'I-SYSTEMATIC': 14, 'I-TRIVIAL': 15, 'I-NO CLASS': 16}
+     elif corpus == "chemdner-mol":
+         return {'O': 0, 'B-MOL': 1, 'I-MOL': 2}
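`ChemNER` inverts the corpus tag map above to decode model outputs back into BIO tags. A minimal sketch using the `chemdner-mol` tag set (the variable names here are illustrative):

```python
# The chemdner-mol tag set and the inverse index-to-tag mapping that is built
# from it; model outputs are argmax indices, which the inverse map decodes.
class_to_index = {'O': 0, 'B-MOL': 1, 'I-MOL': 2}
index_to_class = {class_to_index[key]: key for key in class_to_index}

predicted_indices = [0, 1, 2, 0]
tags = [index_to_class[i] for i in predicted_indices]
print(tags)  # → ['O', 'B-MOL', 'I-MOL', 'O']
```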
chemietoolkit/__init__.py ADDED
@@ -0,0 +1 @@
+ from .interface import ChemIEToolkit
chemietoolkit/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (198 Bytes). View file
 
chemietoolkit/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (189 Bytes). View file
 
chemietoolkit/__pycache__/chemrxnextractor.cpython-310.pyc ADDED
Binary file (3.66 kB). View file
 
chemietoolkit/__pycache__/chemrxnextractor.cpython-38.pyc ADDED
Binary file (3.62 kB). View file
 
chemietoolkit/__pycache__/interface.cpython-310.pyc ADDED
Binary file (29.3 kB). View file
 
chemietoolkit/__pycache__/interface.cpython-38.pyc ADDED
Binary file (30 kB). View file
 
chemietoolkit/__pycache__/tableextractor.cpython-310.pyc ADDED
Binary file (10.3 kB). View file
 
chemietoolkit/__pycache__/utils.cpython-310.pyc ADDED
Binary file (25 kB). View file
 
chemietoolkit/chemrxnextractor.py ADDED
@@ -0,0 +1,107 @@
+ from PyPDF2 import PdfReader, PdfWriter
+ import pdfminer.high_level
+ import pdfminer.layout
+ from operator import itemgetter
+ import os
+ import pdftotext
+ from chemrxnextractor import RxnExtractor
+
+
+ class ChemRxnExtractor(object):
+     def __init__(self, pdf, pn, model_dir, device):
+         self.pdf_file = pdf
+         self.pages = pn
+         self.model_dir = os.path.join(model_dir, "cre_models_v0.1")  # directory containing both prod and role models
+         use_cuda = (device == 'cuda')
+         self.rxn_extractor = RxnExtractor(self.model_dir, use_cuda=use_cuda)
+         self.text_file = "info.txt"
+         self.pdf_text = ""
+         if len(self.pdf_file) > 0:
+             with open(self.pdf_file, "rb") as f:
+                 self.pdf_text = pdftotext.PDF(f)
+
+     def set_pdf_file(self, pdf):
+         self.pdf_file = pdf
+         with open(self.pdf_file, "rb") as f:
+             self.pdf_text = pdftotext.PDF(f)
+
+     def set_pages(self, pn):
+         self.pages = pn
+
+     def set_model_dir(self, md):
+         self.model_dir = md
+         self.rxn_extractor = RxnExtractor(self.model_dir)
+
+     def set_text_file(self, tf):
+         self.text_file = tf
+
+     def extract_reactions_from_text(self):
+         if self.pages is None:
+             return self.extract_all(len(self.pdf_text))
+         else:
+             return self.extract_all(self.pages)
+
+     def extract_all(self, pages):
+         ans = []
+         text = self.get_paragraphs_from_pdf(pages)
+         for data in text:
+             L = [sent for paragraph in data['paragraphs'] for sent in paragraph]
+             reactions = self.get_reactions(L, page_number=data['page'])
+             ans.append(reactions)
+         return ans
+
+     def get_reactions(self, sents, page_number=None):
+         rxns = self.rxn_extractor.get_reactions(sents)
+         ret = []
+         for r in rxns:
+             if len(r['reactions']) != 0:
+                 ret.append(r)
+         ans = {}
+         ans.update({'page': page_number})
+         ans.update({'reactions': ret})
+         return ans
+
+     def get_paragraphs_from_pdf(self, pages):
+         current_page_num = 1
+         if pages is None:
+             pages = len(self.pdf_text)
+         result = []
+         for page in range(pages):
+             content = self.pdf_text[page]
+             pg = content.split("\n\n")
+             L = []
+             for line in pg:
+                 paragraph = []
+                 if '\x0c' in line:
+                     continue
+                 text = line
+                 text = text.replace("\n", " ")
+                 text = text.replace("- ", "-")
+                 curind = 0
+                 i = 0
+                 while i < len(text):
+                     if text[i] == '.':
+                         if i != 0 and not text[i-1].isdigit() or i != len(text) - 1 and (text[i+1] == " " or text[i+1] == "\n"):
+                             paragraph.append(text[curind:i+1] + "\n")
+                             while i < len(text) and text[i] != " ":
+                                 i += 1
+                             curind = i + 1
+                     i += 1
+                 if curind != i:
+                     if text[i - 1] == " ":
+                         if i != 1:
+                             i -= 1
+                         else:
+                             break
+                     if text[i - 1] != '.':
+                         paragraph.append(text[curind:i] + ".\n")
+                     else:
+                         paragraph.append(text[curind:i] + "\n")
+                 L.append(paragraph)
+
+             result.append({
+                 'paragraphs': L,
+                 'page': current_page_num
+             })
+             current_page_num += 1
+         return result
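The text normalization at the start of `get_paragraphs_from_pdf` flattens PDF line breaks into spaces and rejoins words hyphenated across line breaks. A minimal sketch with an invented input string:

```python
# pdftotext output wraps lines, so "cyclo-\nhexane" first becomes
# "cyclo- hexane"; the second replace rejoins the hyphenated word.
raw = "The product, cyclo-\nhexane, was isolated."
text = raw.replace("\n", " ").replace("- ", "-")
print(text)  # → The product, cyclo-hexane, was isolated.
```

One caveat of this simple rule: a genuine trailing hyphen followed by a space (e.g. in a compound range) is also collapsed.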
chemietoolkit/interface.py ADDED
@@ -0,0 +1,749 @@
1
+ import torch
2
+ import re
3
+ from functools import lru_cache
4
+ import layoutparser as lp
5
+ import pdf2image
6
+ from PIL import Image
7
+ from huggingface_hub import hf_hub_download, snapshot_download
8
+ from molnextr import MolScribe
9
+ from rxnim import RxnScribe, MolDetect
10
+ from chemiener import ChemNER
11
+ from .chemrxnextractor import ChemRxnExtractor
12
+ from .tableextractor import TableExtractor
13
+ from .utils import *
14
+
15
+ class ChemIEToolkit:
16
+ def __init__(self, device=None):
17
+ if device is None:
18
+ self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
19
+ else:
20
+ self.device = torch.device(device)
21
+
22
+ self._molscribe = None
23
+ self._rxnscribe = None
24
+ self._pdfparser = None
25
+ self._moldet = None
26
+ self._chemrxnextractor = None
27
+ self._chemner = None
28
+ self._coref = None
29
+
30
+ @property
31
+ def molscribe(self):
32
+ if self._molscribe is None:
33
+ self.init_molscribe()
34
+ return self._molscribe
35
+
36
+ @lru_cache(maxsize=None)
37
+ def init_molscribe(self, ckpt_path=None):
38
+ """
39
+ Set model to custom checkpoint
40
+ Parameters:
41
+ ckpt_path: path to checkpoint to use, if None then will use default
42
+ """
43
+ if ckpt_path is None:
44
+ ckpt_path = hf_hub_download("yujieq/MolScribe", "swin_base_char_aux_1m680k.pth")
45
+ self._molscribe = MolScribe(ckpt_path, device=self.device)
46
+
47
+
48
+ @property
49
+ def rxnscribe(self):
50
+ if self._rxnscribe is None:
51
+ self.init_rxnscribe()
52
+ return self._rxnscribe
53
+
54
+ @lru_cache(maxsize=None)
55
+ def init_rxnscribe(self, ckpt_path=None):
56
+ """
57
+ Set model to custom checkpoint
58
+ Parameters:
59
+ ckpt_path: path to checkpoint to use, if None then will use default
60
+ """
61
+ if ckpt_path is None:
62
+ ckpt_path = hf_hub_download("yujieq/RxnScribe", "pix2seq_reaction_full.ckpt")
63
+ self._rxnscribe = RxnScribe(ckpt_path, device=self.device)
64
+
65
+
66
+ @property
67
+ def pdfparser(self):
68
+ if self._pdfparser is None:
69
+ self.init_pdfparser()
70
+ return self._pdfparser
71
+
72
+ @lru_cache(maxsize=None)
73
+ def init_pdfparser(self, ckpt_path=None):
74
+ """
75
+ Set model to custom checkpoint
76
+ Parameters:
77
+ ckpt_path: path to checkpoint to use, if None then will use default
78
+ """
79
+ config_path = "lp://efficientdet/PubLayNet/tf_efficientdet_d1"
80
+ self._pdfparser = lp.AutoLayoutModel(config_path, model_path=ckpt_path, device=self.device.type)
81
+
82
+
83
+ @property
84
+ def moldet(self):
85
+ if self._moldet is None:
86
+ self.init_moldet()
87
+ return self._moldet
88
+
89
+ @lru_cache(maxsize=None)
90
+ def init_moldet(self, ckpt_path=None):
91
+ """
92
+ Set model to custom checkpoint
93
+ Parameters:
94
+ ckpt_path: path to checkpoint to use, if None then will use default
95
+ """
96
+ if ckpt_path is None:
97
+ ckpt_path = hf_hub_download("Ozymandias314/MolDetectCkpt", "best_hf.ckpt")
98
+ self._moldet = MolDetect(ckpt_path, device=self.device)
99
+
100
+
101
+ @property
102
+ def coref(self):
103
+ if self._coref is None:
104
+ self.init_coref()
105
+ return self._coref
106
+
107
+ @lru_cache(maxsize=None)
108
+ def init_coref(self, ckpt_path=None):
109
+ """
110
+ Set model to custom checkpoint
111
+ Parameters:
112
+ ckpt_path: path to checkpoint to use, if None then will use default
113
+ """
114
+ if ckpt_path is None:
115
+ ckpt_path = hf_hub_download("Ozymandias314/MolDetectCkpt", "coref_best_hf.ckpt")
116
+ self._coref = MolDetect(ckpt_path, device=self.device, coref=True)
117
+
118
+
119
+ @property
120
+ def chemrxnextractor(self):
121
+ if self._chemrxnextractor is None:
122
+ self.init_chemrxnextractor()
123
+ return self._chemrxnextractor
124
+
125
+ @lru_cache(maxsize=None)
126
+ def init_chemrxnextractor(self, ckpt_path=None):
127
+ """
128
+ Set model to custom checkpoint
129
+ Parameters:
130
+ ckpt_path: path to checkpoint to use, if None then will use default
131
+ """
132
+ if ckpt_path is None:
133
+ ckpt_path = snapshot_download(repo_id="amberwang/chemrxnextractor-training-modules")
134
+ self._chemrxnextractor = ChemRxnExtractor("", None, ckpt_path, self.device.type)
135
+
136
+
137
+ @property
138
+ def chemner(self):
139
+ if self._chemner is None:
140
+ self.init_chemner()
141
+ return self._chemner
142
+
143
+ @lru_cache(maxsize=None)
144
+ def init_chemner(self, ckpt_path=None):
145
+ """
146
+ Set model to custom checkpoint
147
+ Parameters:
148
+ ckpt_path: path to checkpoint to use, if None then will use default
149
+ """
150
+ if ckpt_path is None:
151
+ ckpt_path = hf_hub_download("Ozymandias314/ChemNERckpt", "best.ckpt")
152
+ self._chemner = ChemNER(ckpt_path, device=self.device)
153
+
154
+
155
+ @property
156
+ def tableextractor(self):
157
+ return TableExtractor()
158
+
159
+
160
+ def extract_figures_from_pdf(self, pdf, num_pages=None, output_bbox=False, output_image=True):
161
+ """
162
+ Find and return all figures from a pdf page
163
+ Parameters:
164
+ pdf: path to pdf
165
+ num_pages: process only first `num_pages` pages, if `None` then process all
166
+ output_bbox: whether to output bounding boxes for each individual entry of a table
167
+ output_image: whether to include PIL image for figures. default is True
168
+ Returns:
169
+ list of content in the following format
170
+ [
171
+ { # first figure
172
+ 'title': str,
173
+ 'figure': {
174
+ 'image': PIL image or None,
175
+ 'bbox': list in form [x1, y1, x2, y2],
176
+ }
177
+ 'table': {
178
+ 'bbox': list in form [x1, y1, x2, y2] or empty list,
179
+ 'content': {
180
+ 'columns': list of column headers,
181
+ 'rows': list of list of row content,
182
+ } or None
183
+ }
184
+ 'footnote': str or empty,
185
+ 'page': int
186
+ }
187
+ # more figures
188
+ ]
189
+ """
190
+ pages = pdf2image.convert_from_path(pdf, last_page=num_pages)
191
+
192
+ table_ext = self.tableextractor
193
+ table_ext.set_pdf_file(pdf)
194
+ table_ext.set_output_image(output_image)
195
+
196
+ table_ext.set_output_bbox(output_bbox)
197
+
198
+ return table_ext.extract_all_tables_and_figures(pages, self.pdfparser, content='figures')
199
+
200
+ def extract_tables_from_pdf(self, pdf, num_pages=None, output_bbox=False, output_image=True):
201
+ """
202
+ Find and return all tables from a pdf page
203
+ Parameters:
204
+ pdf: path to pdf
205
+ num_pages: process only first `num_pages` pages, if `None` then process all
206
+ output_bbox: whether to include bboxes for individual entries of the table
207
+ output_image: whether to include PIL image for figures. default is True
208
+ Returns:
209
+ list of content in the following format
210
+ [
211
+ { # first table
212
+ 'title': str,
213
+ 'figure': {
214
+ 'image': PIL image or None,
215
+ 'bbox': list in form [x1, y1, x2, y2] or empty list,
216
+ }
217
+ 'table': {
218
+ 'bbox': list in form [x1, y1, x2, y2] or empty list,
219
+ 'content': {
220
+ 'columns': list of column headers,
221
+ 'rows': list of list of row content,
222
+ }
223
+ }
224
+ 'footnote': str or empty,
225
+ 'page': int
226
+ }
227
+ # more tables
228
+ ]
229
+ """
230
+ pages = pdf2image.convert_from_path(pdf, last_page=num_pages)
231
+
232
+ table_ext = self.tableextractor
233
+ table_ext.set_pdf_file(pdf)
234
+ table_ext.set_output_image(output_image)
235
+
236
+ table_ext.set_output_bbox(output_bbox)
237
+
238
+ return table_ext.extract_all_tables_and_figures(pages, self.pdfparser, content='tables')
239
+
240
+ def extract_molecules_from_figures_in_pdf(self, pdf, batch_size=16, num_pages=None):
241
+ """
242
+ Get all molecules and their information from a pdf
243
+ Parameters:
244
+ pdf: path to pdf, or byte file
245
+ batch_size: batch size for inference in all models
246
+ num_pages: process only first `num_pages` pages, if `None` then process all
247
+ Returns:
248
+ list of figures and corresponding molecule info in the following format
249
+ [
250
+ { # first figure
251
+ 'image': ndarray of the figure image,
252
+ 'molecules': [
253
+ { # first molecule
254
+ 'bbox': tuple in the form (x1, y1, x2, y2),
255
+ 'score': float,
256
+ 'image': ndarray of cropped molecule image,
257
+ 'smiles': str,
258
+ 'molfile': str
259
+ },
260
+ # more molecules
261
+ ],
262
+ 'page': int
263
+ },
264
+ # more figures
265
+ ]
266
+ """
267
+ figures = self.extract_figures_from_pdf(pdf, num_pages=num_pages, output_bbox=True)
268
+ images = [figure['figure']['image'] for figure in figures]
269
+ results = self.extract_molecules_from_figures(images, batch_size=batch_size)
270
+ for figure, result in zip(figures, results):
271
+ result['page'] = figure['page']
272
+ return results
273
+
274
+ def extract_molecule_bboxes_from_figures(self, figures, batch_size=16):
275
+ """
276
+ Return bounding boxes of molecules in images
277
+ Parameters:
278
+ figures: list of PIL or ndarray images
279
+ batch_size: batch size for inference
280
+ Returns:
281
+ list of results for each figure in the following format
282
+ [
283
+ [ # first figure
284
+ { # first bounding box
285
+ 'category': str,
286
+ 'bbox': tuple in the form (x1, y1, x2, y2),
287
+ 'category_id': int,
288
+ 'score': float
289
+ },
290
+ # more bounding boxes
291
+ ],
292
+ # more figures
293
+ ]
294
+ """
295
+ figures = [convert_to_pil(figure) for figure in figures]
296
+ return self.moldet.predict_images(figures, batch_size=batch_size)
297
+
298
+ def extract_molecules_from_figures(self, figures, batch_size=16):
299
+ """
300
+ Get all molecules and their information from list of figures
301
+ Parameters:
302
+ figures: list of PIL or ndarray images
303
+ batch_size: batch size for inference
304
+ Returns:
305
+ list of results for each figure in the following format
306
+ [
307
+ { # first figure
308
+ 'image': ndarray of the figure image,
309
+ 'molecules': [
310
+ { # first molecule
311
+ 'bbox': tuple in the form (x1, y1, x2, y2),
312
+ 'score': float,
313
+ 'image': ndarray of cropped molecule image,
314
+ 'smiles': str,
315
+ 'molfile': str
316
+ },
317
+ # more molecules
318
+ ],
319
+ },
320
+ # more figures
321
+ ]
322
+ """
323
+ bboxes = self.extract_molecule_bboxes_from_figures(figures, batch_size=batch_size)
324
+ figures = [convert_to_cv2(figure) for figure in figures]
325
+ results, cropped_images, refs = clean_bbox_output(figures, bboxes)
326
+ mol_info = self.molscribe.predict_images(cropped_images, batch_size=batch_size)
327
+ for info, ref in zip(mol_info, refs):
328
+ ref.update(info)
329
+ return results
330
+
331
+ def extract_molecule_corefs_from_figures_in_pdf(self, pdf, batch_size=16, num_pages=None, molscribe=True, ocr=True):
332
+ """
333
+ Get all molecule bboxes and corefs from figures in pdf
334
+ Parameters:
335
+ pdf: path to pdf, or byte file
336
+ batch_size: batch size for inference in all models
337
+ num_pages: process only first `num_pages` pages, if `None` then process all
338
+ Returns:
339
+ list of results for each figure in the following format:
340
+ [
341
+ {
342
+ 'bboxes': [
343
+ { # first bbox
344
+ 'category': '[Sup]',
345
+ 'bbox': (0.0050025012506253125, 0.38273870663142223, 0.9934967483741871, 0.9450094869920168),
346
+ 'category_id': 4,
347
+ 'score': -0.07593922317028046
348
+ },
349
+ # More bounding boxes
350
+ ],
351
+ 'corefs': [
352
+ [0, 1], # molecule bbox index, identifier bbox index
353
+ [3, 4],
354
+ # More coref pairs
355
+ ],
356
+ 'page': int
357
+ },
358
+ # More figures
359
+ ]
360
+ """
361
+ figures = self.extract_figures_from_pdf(pdf, num_pages=num_pages, output_bbox=True)
362
+ images = [figure['figure']['image'] for figure in figures]
363
+ results = self.extract_molecule_corefs_from_figures(images, batch_size=batch_size, molscribe=molscribe, ocr=ocr)
364
+ for figure, result in zip(figures, results):
365
+ result['page'] = figure['page']
366
+ return results
367
+
368
+ def extract_molecule_corefs_from_figures(self, figures, batch_size=16, molscribe=True, ocr=True):
369
+ """
370
+ Get all molecule bboxes and corefs from list of figures
371
+ Parameters:
372
+ figures: list of PIL or ndarray images
373
+ batch_size: batch size for inference
+ molscribe: whether to predict and return smiles and molfile info
+ ocr: whether to predict and return text of labels
374
+ Returns:
375
+ list of results for each figure in the following format:
376
+ [
377
+ {
378
+ 'bboxes': [
379
+ { # first bbox
380
+ 'category': '[Sup]',
381
+ 'bbox': (0.0050025012506253125, 0.38273870663142223, 0.9934967483741871, 0.9450094869920168),
382
+ 'category_id': 4,
383
+ 'score': -0.07593922317028046
384
+ },
385
+ # More bounding boxes
386
+ ],
387
+ 'corefs': [
388
+ [0, 1], # molecule bbox index, identifier bbox index
389
+ [3, 4],
390
+ # More coref pairs
391
+ ],
392
+ },
393
+ # More figures
394
+ ]
395
+ """
396
+ figures = [convert_to_pil(figure) for figure in figures]
397
+ return self.coref.predict_images(figures, batch_size=batch_size, coref=True, molscribe=molscribe, ocr=ocr)
398
+
399
+ def extract_reactions_from_figures_in_pdf(self, pdf, batch_size=16, num_pages=None, molscribe=True, ocr=True):
400
+ """
401
+ Get reaction information from figures in pdf
402
+ Parameters:
403
+ pdf: path to pdf, or byte file
404
+ batch_size: batch size for inference in all models
405
+ num_pages: process only first `num_pages` pages, if `None` then process all
406
+ molscribe: whether to predict and return smiles and molfile info
407
+ ocr: whether to predict and return text of conditions
408
+ Returns:
409
+ list of figures and corresponding molecule info in the following format
410
+ [
411
+ {
412
+ 'figure': PIL image
413
+ 'reactions': [
414
+ {
415
+ 'reactants': [
416
+ {
417
+ 'category': str,
418
+ 'bbox': tuple (x1, y1, x2, y2),
419
+ 'category_id': int,
420
+ 'smiles': str,
421
+ 'molfile': str,
422
+ },
423
+ # more reactants
424
+ ],
425
+ 'conditions': [
426
+ {
427
+ 'category': str,
428
+ 'bbox': tuple (x1, y1, x2, y2),
429
+ 'category_id': int,
430
+ 'text': list of str,
431
+ },
432
+ # more conditions
433
+ ],
434
+ 'products': [
435
+ # same structure as reactants
436
+ ]
437
+ },
438
+ # more reactions
439
+ ],
440
+ 'page': int
441
+ },
442
+ # more figures
443
+ ]
444
+ """
445
+ figures = self.extract_figures_from_pdf(pdf, num_pages=num_pages, output_bbox=True)
446
+ images = [figure['figure']['image'] for figure in figures]
447
+ results = self.extract_reactions_from_figures(images, batch_size=batch_size, molscribe=molscribe, ocr=ocr)
448
+ for figure, result in zip(figures, results):
449
+ result['page'] = figure['page']
450
+ return results
451
+
452
+ def extract_reactions_from_figures(self, figures, batch_size=16, molscribe=True, ocr=True):
453
+ """
454
+ Get reaction information from list of figures
455
+ Parameters:
456
+ figures: list of PIL or ndarray images
457
+ batch_size: batch size for inference in all models
458
+ molscribe: whether to predict and return smiles and molfile info
459
+ ocr: whether to predict and return text of conditions
460
+ Returns:
461
+ list of figures and corresponding molecule info in the following format
462
+ [
463
+ {
464
+ 'figure': PIL image
465
+ 'reactions': [
466
+ {
467
+ 'reactants': [
468
+ {
469
+ 'category': str,
470
+ 'bbox': tuple (x1, y1, x2, y2),
471
+ 'category_id': int,
472
+ 'smiles': str,
473
+ 'molfile': str,
474
+ },
475
+ # more reactants
476
+ ],
477
+ 'conditions': [
478
+ {
479
+ 'category': str,
480
+ 'bbox': tuple (x1, y1, x2, y2),
481
+ 'category_id': int,
482
+ 'text': list of str,
483
+ },
484
+ # more conditions
485
+ ],
486
+ 'products': [
487
+ # same structure as reactants
488
+ ]
489
+ },
490
+ # more reactions
491
+ ],
492
+ },
493
+ # more figures
494
+ ]
495
+
496
+ """
497
+ pil_figures = [convert_to_pil(figure) for figure in figures]
498
+ results = []
499
+ reactions = self.rxnscribe.predict_images(pil_figures, batch_size=batch_size, molscribe=molscribe, ocr=ocr)
500
+ for figure, rxn in zip(figures, reactions):
501
+ data = {
502
+ 'figure': figure,
503
+ 'reactions': rxn,
504
+ }
505
+ results.append(data)
506
+ return results
507
+
508
+ def extract_molecules_from_text_in_pdf(self, pdf, batch_size=16, num_pages=None):
509
+ """
510
+ Get molecules in text of given pdf
511
+
512
+ Parameters:
513
+ pdf: path to pdf, or byte file
514
+ batch_size: batch size for inference in all models
515
+ num_pages: process only first `num_pages` pages, if `None` then process all
516
+ Returns:
517
+ list of sentences and found molecules in the following format
518
+ [
519
+ {
520
+ 'molecules': [
521
+ { # first paragraph
522
+ 'text': str,
523
+ 'labels': [
524
+ (str, int, int), # tuple of label, range start (inclusive), range end (exclusive)
525
+ # more labels
526
+ ]
527
+ },
528
+ # more paragraphs
529
+ ]
530
+ 'page': int
531
+ },
532
+ # more pages
533
+ ]
534
+ """
535
+ self.chemrxnextractor.set_pdf_file(pdf)
536
+ self.chemrxnextractor.set_pages(num_pages)
537
+ text = self.chemrxnextractor.get_paragraphs_from_pdf(num_pages)
538
+ result = []
539
+ for data in text:
540
+ model_inp = []
541
+ for paragraph in data['paragraphs']:
542
+ model_inp.append(' '.join(paragraph).replace('\n', ''))
543
+ output = self.chemner.predict_strings(model_inp, batch_size=batch_size)
544
+ to_add = {
545
+ 'molecules': [{
546
+ 'text': t,
547
+ 'labels': labels,
548
+ } for t, labels in zip(model_inp, output)],
549
+ 'page': data['page']
550
+ }
551
+ result.append(to_add)
552
+ return result
553
+
554
+
555
+ def extract_reactions_from_text_in_pdf(self, pdf, num_pages=None):
556
+ """
557
+ Get reaction information from text in pdf
558
+ Parameters:
559
+ pdf: path to pdf
560
+ num_pages: process only first `num_pages` pages, if `None` then process all
561
+ Returns:
562
+ list of pages and corresponding reaction info in the following format
563
+ [
564
+ {
565
+ 'page': page number
566
+ 'reactions': [
567
+ {
568
+ 'tokens': list of words in relevant sentence,
569
+ 'reactions' : [
570
+ {
571
+ # key, value pairs where key is the label and value is a tuple
572
+ # or list of tuples of the form (tokens, start index, end index)
573
+ # where indices are for the corresponding token list and start and end are inclusive
574
+ }
575
+ # more reactions
576
+ ]
577
+ }
578
+ # more reactions in other sentences
579
+ ]
580
+ },
581
+ # more pages
582
+ ]
583
+ """
584
+ self.chemrxnextractor.set_pdf_file(pdf)
585
+ self.chemrxnextractor.set_pages(num_pages)
586
+ return self.chemrxnextractor.extract_reactions_from_text()
587
+
588
+ def extract_reactions_from_text_in_pdf_combined(self, pdf, num_pages=None):
589
+ """
590
+ Get reaction information from text in pdf and combine it with corefs from figures
591
+ Parameters:
592
+ pdf: path to pdf
593
+ num_pages: process only first `num_pages` pages, if `None` then process all
594
+ Returns:
595
+ list of pages and corresponding reaction info in the following format
596
+ [
597
+ {
598
+ 'page': page number
599
+ 'reactions': [
600
+ {
601
+ 'tokens': list of words in relevant sentence,
602
+ 'reactions' : [
603
+ {
604
+ # key, value pairs where key is the label and value is a tuple
605
+ # or list of tuples of the form (tokens, start index, end index)
606
+ # where indices are for the corresponding token list and start and end are inclusive
607
+ }
608
+ # more reactions
609
+ ]
610
+ }
611
+ # more reactions in other sentences
612
+ ]
613
+ },
614
+ # more pages
615
+ ]
616
+ """
617
+ results = self.extract_reactions_from_text_in_pdf(pdf, num_pages=num_pages)
618
+ results_coref = self.extract_molecule_corefs_from_figures_in_pdf(pdf, num_pages=num_pages)
619
+ return associate_corefs(results, results_coref)
620
+
621
+ def extract_reactions_from_figures_and_tables_in_pdf(self, pdf, num_pages=None, batch_size=16, molscribe=True, ocr=True):
622
+ """
623
+ Get reaction information from figures and combine with table information in pdf
624
+ Parameters:
625
+ pdf: path to pdf, or byte file
626
+ batch_size: batch size for inference in all models
627
+ num_pages: process only first `num_pages` pages, if `None` then process all
628
+ molscribe: whether to predict and return smiles and molfile info
629
+ ocr: whether to predict and return text of conditions
630
+ Returns:
631
+ list of figures and corresponding molecule info in the following format
632
+ [
633
+ {
634
+ 'figure': PIL image
635
+ 'reactions': [
636
+ {
637
+ 'reactants': [
638
+ {
639
+ 'category': str,
640
+ 'bbox': tuple (x1, y1, x2, y2),
641
+ 'category_id': int,
642
+ 'smiles': str,
643
+ 'molfile': str,
644
+ },
645
+ # more reactants
646
+ ],
647
+ 'conditions': [
648
+ {
649
+ 'category': str,
650
+ 'text': list of str,
651
+ },
652
+ # more conditions
653
+ ],
654
+ 'products': [
655
+ # same structure as reactants
656
+ ]
657
+ },
658
+ # more reactions
659
+ ],
660
+ 'page': int
661
+ },
662
+ # more figures
663
+ ]
664
+ """
665
+ figures = self.extract_figures_from_pdf(pdf, num_pages=num_pages, output_bbox=True)
666
+ images = [figure['figure']['image'] for figure in figures]
667
+ results = self.extract_reactions_from_figures(images, batch_size=batch_size, molscribe=molscribe, ocr=ocr)
668
+ results = process_tables(figures, results, self.molscribe, batch_size=batch_size)
669
+ results_coref = self.extract_molecule_corefs_from_figures_in_pdf(pdf, num_pages=num_pages)
670
+ results = replace_rgroups_in_figure(figures, results, results_coref, self.molscribe, batch_size=batch_size)
671
+ results = expand_reactions_with_backout(results, results_coref, self.molscribe)
672
+ return results
673
+
674
+ def extract_reactions_from_pdf(self, pdf, num_pages=None, batch_size=16):
675
+ """
676
+ Get reaction information from figures, tables, and text in pdf
+ Parameters:
+ pdf: path to pdf, or byte file
+ num_pages: process only first `num_pages` pages, if `None` then process all
+ batch_size: batch size for inference in all models
+ Returns:
677
+ dictionary of reactions from multimodal sources
678
+ {
679
+ 'figures': [
680
+ {
681
+ 'figure': PIL image
682
+ 'reactions': [
683
+ {
684
+ 'reactants': [
685
+ {
686
+ 'category': str,
687
+ 'bbox': tuple (x1, y1, x2, y2),
688
+ 'category_id': int,
689
+ 'smiles': str,
690
+ 'molfile': str,
691
+ },
692
+ # more reactants
693
+ ],
694
+ 'conditions': [
695
+ {
696
+ 'category': str,
697
+ 'text': list of str,
698
+ },
699
+ # more conditions
700
+ ],
701
+ 'products': [
702
+ # same structure as reactants
703
+ ]
704
+ },
705
+ # more reactions
706
+ ],
707
+ 'page': int
708
+ },
709
+ # more figures
710
+ ]
711
+ 'text': [
712
+ {
713
+ 'page': page number
714
+ 'reactions': [
715
+ {
716
+ 'tokens': list of words in relevant sentence,
717
+ 'reactions' : [
718
+ {
719
+ # key, value pairs where key is the label and value is a tuple
720
+ # or list of tuples of the form (tokens, start index, end index)
721
+ # where indices are for the corresponding token list and start and end are inclusive
722
+ }
723
+ # more reactions
724
+ ]
725
+ }
726
+ # more reactions in other sentences
727
+ ]
728
+ },
729
+ # more pages
730
+ ]
731
+ }
732
+
733
+ """
734
+ figures = self.extract_figures_from_pdf(pdf, num_pages=num_pages, output_bbox=True)
735
+ images = [figure['figure']['image'] for figure in figures]
736
+ results = self.extract_reactions_from_figures(images, batch_size=batch_size, molscribe=True, ocr=True)
737
+ table_expanded_results = process_tables(figures, results, self.molscribe, batch_size=batch_size)
738
+ text_results = self.extract_reactions_from_text_in_pdf(pdf, num_pages=num_pages)
739
+ results_coref = self.extract_molecule_corefs_from_figures_in_pdf(pdf, num_pages=num_pages)
740
+ figure_results = replace_rgroups_in_figure(figures, table_expanded_results, results_coref, self.molscribe, batch_size=batch_size)
741
+ table_expanded_results = expand_reactions_with_backout(figure_results, results_coref, self.molscribe)
742
+ coref_expanded_results = associate_corefs(text_results, results_coref)
743
+ return {
744
+ 'figures': table_expanded_results,
745
+ 'text': coref_expanded_results,
746
+ }
747
+
748
+ if __name__ == "__main__":
749
+ model = ChemIEToolkit()
chemietoolkit/tableextractor.py ADDED
@@ -0,0 +1,340 @@
1
+ import pdf2image
2
+ import numpy as np
3
+ from PIL import Image
4
+ import matplotlib.pyplot as plt
5
+ import layoutparser as lp
6
+ import cv2
7
+
8
+ from PyPDF2 import PdfReader, PdfWriter
9
+ import pandas as pd
10
+
11
+ import pdfminer.high_level
12
+ import pdfminer.layout
13
+ from operator import itemgetter
14
+
15
+ # inputs: pdf_file, page #, bounding box (optional) (llur or ullr), output_bbox
16
+ class TableExtractor(object):
17
+ def __init__(self, output_bbox=True):
18
+ self.pdf_file = ""
19
+ self.page = ""
20
+ self.image_dpi = 200
21
+ self.pdf_dpi = 72
22
+ self.output_bbox = output_bbox
23
+ self.blocks = {}
24
+ self.title_y = 0
25
+ self.column_header_y = 0
26
+ self.model = None
27
+ self.img = None
28
+ self.output_image = True
29
+ self.tagging = {
30
+ 'substance': ['compound', 'salt', 'base', 'solvent', 'CBr4', 'collidine', 'InX3', 'substrate', 'ligand', 'PPh3', 'PdL2', 'Cu', 'compd', 'reagent', 'reagant', 'acid', 'aldehyde', 'amine', 'Ln', 'H2O', 'enzyme', 'cofactor', 'oxidant', 'Pt(COD)Cl2', 'CuBr2', 'additive'],
31
+ 'ratio': [':'],
32
+ 'measurement': ['μM', 'nM', 'IC50', 'CI', 'excitation', 'emission', 'Φ', 'φ', 'shift', 'ee', 'ΔG', 'ΔH', 'TΔS', 'Δ', 'distance', 'trajectory', 'V', 'eV'],
33
+ 'temperature': ['temp', 'temperature', 'T', '°C'],
34
+ 'time': ['time', 't(', 't ('],
35
+ 'result': ['yield', 'aa', 'result', 'product', 'conversion', '(%)'],
36
+ 'alkyl group': ['R', 'Ar', 'X', 'Y'],
37
+ 'solvent': ['solvent'],
38
+ 'counter': ['entry', 'no.'],
39
+ 'catalyst': ['catalyst', 'cat.'],
40
+ 'conditions': ['condition'],
41
+ 'reactant': ['reactant'],
42
+ }
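The `tagging` dictionary above drives a first-match keyword search over each column header (used later in `extract_singular_table`). The lookup reduces to this sketch, shown with a trimmed copy of the dictionary; because matching is by substring, broad keys such as `'T'` can shadow later entries, so dictionary order matters:

```python
# Trimmed copy of the extractor's tagging table, for illustration
TAGGING = {
    'temperature': ['temp', 'temperature', 'T', '°C'],
    'time': ['time', 't(', 't ('],
    'result': ['yield', 'aa', 'result', 'product', 'conversion', '(%)'],
    'counter': ['entry', 'no.'],
}

def tag_column(header: str) -> str:
    """Return the first tag whose keyword appears in the header text."""
    for tag, keywords in TAGGING.items():
        if any(word in header for word in keywords):
            return tag
    return 'unknown'
```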
43
+
44
+ def set_output_image(self, oi):
45
+ self.output_image = oi
46
+
47
+ def set_pdf_file(self, pdf):
48
+ self.pdf_file = pdf
49
+
50
+ def set_page_num(self, pn):
51
+ self.page = pn
52
+
53
+ def set_output_bbox(self, ob):
54
+ self.output_bbox = ob
55
+
56
+ def run_model(self, page_info):
57
+ #img = np.asarray(pdf2image.convert_from_path(self.pdf_file, dpi=self.image_dpi)[self.page])
58
+
59
+ #model = lp.Detectron2LayoutModel('lp://PubLayNet/mask_rcnn_X_101_32x8d_FPN_3x/config', extra_config=["MODEL.ROI_HEADS.SCORE_THRESH_TEST", 0.5], label_map={0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"})
60
+
61
+ img = np.asarray(page_info)
62
+ self.img = img
63
+
64
+ layout_result = self.model.detect(img)
65
+
66
+ text_blocks = lp.Layout([b for b in layout_result if b.type == 'Text'])
67
+ title_blocks = lp.Layout([b for b in layout_result if b.type == 'Title'])
68
+ list_blocks = lp.Layout([b for b in layout_result if b.type == 'List'])
69
+ table_blocks = lp.Layout([b for b in layout_result if b.type == 'Table'])
70
+ figure_blocks = lp.Layout([b for b in layout_result if b.type == 'Figure'])
71
+
72
+ self.blocks.update({'text': text_blocks})
73
+ self.blocks.update({'title': title_blocks})
74
+ self.blocks.update({'list': list_blocks})
75
+ self.blocks.update({'table': table_blocks})
76
+ self.blocks.update({'figure': figure_blocks})
77
+
78
+ # type is what coordinates you want to get. it comes in text, title, list, table, and figure
79
+ def convert_to_pdf_coordinates(self, type):
80
+ # scale coordinates
81
+
82
+ blocks = self.blocks[type]
83
+ coordinates = [blocks[a].scale(self.pdf_dpi/self.image_dpi) for a in range(len(blocks))]
84
+
85
+ reader = PdfReader(self.pdf_file)
86
+
87
+ writer = PdfWriter()
88
+ p = reader.pages[self.page]
89
+ a = p.mediabox.upper_left
90
+ new_coords = []
91
+ for new_block in coordinates:
92
+ new_coords.append((new_block.block.x_1, pd.to_numeric(a[1]) - new_block.block.y_2, new_block.block.x_2, pd.to_numeric(a[1]) - new_block.block.y_1))
93
+
94
+ return new_coords
95
+ # output: list of bounding boxes for tables but in pdf coordinates
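`convert_to_pdf_coordinates` rescales detections from image space (200 dpi, top-left origin) into PDF points (72 dpi, bottom-left origin). The arithmetic is just a scale plus a y-flip; a minimal sketch under those assumptions (the function name is illustrative):

```python
def image_bbox_to_pdf(bbox, page_height, image_dpi=200, pdf_dpi=72):
    """Scale a top-left-origin image bbox to a bottom-left-origin PDF bbox."""
    s = pdf_dpi / image_dpi
    x1, y1, x2, y2 = (v * s for v in bbox)
    # Flip the y-axis: PDF y grows upward from the bottom of the page
    return (x1, page_height - y2, x2, page_height - y1)
```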
96
+
97
+ # input: new_coords is singular table bounding box in pdf coordinates
98
+ def extract_singular_table(self, new_coords):
99
+ for page_layout in pdfminer.high_level.extract_pages(self.pdf_file, page_numbers=[self.page]):
100
+ elements = []
101
+ for element in page_layout:
102
+ if isinstance(element, pdfminer.layout.LTTextBox):
103
+ for e in element._objs:
104
+ temp = e.bbox
105
+ if temp[0] > min(new_coords[0], new_coords[2]) and temp[0] < max(new_coords[0], new_coords[2]) and temp[1] > min(new_coords[1], new_coords[3]) and temp[1] < max(new_coords[1], new_coords[3]) and temp[2] > min(new_coords[0], new_coords[2]) and temp[2] < max(new_coords[0], new_coords[2]) and temp[3] > min(new_coords[1], new_coords[3]) and temp[3] < max(new_coords[1], new_coords[3]) and isinstance(e, pdfminer.layout.LTTextLineHorizontal):
106
+ elements.append([e.bbox[0], e.bbox[1], e.bbox[2], e.bbox[3], e.get_text()])
107
+
108
+ elements = sorted(elements, key=itemgetter(0))
109
+ w = sorted(elements, key=itemgetter(3), reverse=True)
110
+ if len(w) <= 1:
111
+ continue
112
+
113
+ ret = {}
114
+ i = 1
115
+ g = [w[0]]
116
+
117
+ while i < len(w) and w[i][3] > w[i-1][1]:
118
+ g.append(w[i])
119
+ i += 1
120
+ g = sorted(g, key=itemgetter(0))
121
+ # check for overlaps
122
+ for a in range(len(g)-1, 0, -1):
123
+ if g[a][0] < g[a-1][2]:
124
+ g[a-1][0] = min(g[a][0], g[a-1][0])
125
+ g[a-1][1] = min(g[a][1], g[a-1][1])
126
+ g[a-1][2] = max(g[a][2], g[a-1][2])
127
+ g[a-1][3] = max(g[a][3], g[a-1][3])
128
+ g[a-1][4] = g[a-1][4].strip() + " " + g[a][4]
129
+ g.pop(a)
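The overlap check above merges cells whose x-ranges intersect, walking the sorted row from right to left so `pop` does not disturb earlier indices. The same logic in isolation (cell format `[x1, y1, x2, y2, text]`, as in the extractor):

```python
def merge_overlapping(cells):
    """Merge horizontally overlapping cells, assuming they are sorted by x1."""
    cells = [list(c) for c in cells]
    for a in range(len(cells) - 1, 0, -1):
        if cells[a][0] < cells[a - 1][2]:  # left edge falls inside previous cell
            prev, cur = cells[a - 1], cells[a]
            prev[0] = min(prev[0], cur[0])
            prev[1] = min(prev[1], cur[1])
            prev[2] = max(prev[2], cur[2])
            prev[3] = max(prev[3], cur[3])
            prev[4] = prev[4].strip() + " " + cur[4]
            cells.pop(a)
    return cells
```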
130
+
131
+
132
+ ret.update({"columns":[]})
133
+ for t in g:
134
+ temp_bbox = t[:4]
135
+
136
+ column_text = t[4].strip()
137
+ tag = 'unknown'
138
+ tagged = False
139
+ for key in self.tagging.keys():
140
+ for word in self.tagging[key]:
141
+ if word in column_text:
142
+ tag = key
143
+ tagged = True
144
+ break
145
+ if tagged:
146
+ break
147
+
148
+ if self.output_bbox:
149
+ ret["columns"].append({'text':column_text,'tag': tag, 'bbox':temp_bbox})
150
+ else:
151
+ ret["columns"].append({'text':column_text,'tag': tag})
152
+ self.column_header_y = max(t[1], t[3])
153
+ ret.update({"rows":[]})
154
+
155
+ g.insert(0, [0, 0, new_coords[0], 0, ''])
156
+ g.append([new_coords[2], 0, 0, 0, ''])
157
+ while i < len(w):
158
+ group = [w[i]]
159
+ i += 1
160
+ while i < len(w) and w[i][3] > w[i-1][1]:
161
+ group.append(w[i])
162
+ i += 1
163
+ group = sorted(group, key=itemgetter(0))
164
+
165
+ for a in range(len(group)-1, 0, -1):
166
+ if group[a][0] < group[a-1][2]:
167
+ group[a-1][0] = min(group[a][0], group[a-1][0])
168
+ group[a-1][1] = min(group[a][1], group[a-1][1])
169
+ group[a-1][2] = max(group[a][2], group[a-1][2])
170
+ group[a-1][3] = max(group[a][3], group[a-1][3])
171
+ group[a-1][4] = group[a-1][4].strip() + " " + group[a][4]
172
+ group.pop(a)
173
+
174
+ a = 1
175
+ while a < len(g) - 1:
176
+ if a > len(group):
177
+ group.append([0, 0, 0, 0, '\n'])
178
+ a += 1
179
+ continue
180
+ if group[a-1][0] >= g[a-1][2] and group[a-1][2] <= g[a+1][0]:
181
+ pass
182
+ """
183
+ if a < len(group) and group[a][0] >= g[a-1][2] and group[a][2] <= g[a+1][0]:
184
+ g.insert(1, [g[0][2], 0, group[a-1][2], 0, ''])
185
+ #ret["columns"].insert(0, '')
186
+ else:
187
+ a += 1
188
+ continue
189
+ """
190
+ else: group.insert(a-1, [0, 0, 0, 0, '\n'])
191
+ a += 1
192
+
193
+
194
+ added_row = []
195
+ for t in group:
196
+ temp_bbox = t[:4]
197
+ if self.output_bbox:
198
+ added_row.append({'text':t[4].strip(), 'bbox':temp_bbox})
199
+ else:
200
+ added_row.append(t[4].strip())
201
+ ret["rows"].append(added_row)
202
+ if ret["rows"] and len(ret["rows"][0]) != len(ret["columns"]):
203
+ ret["columns"] = ret["rows"][0]
204
+ ret["rows"] = ret["rows"][1:]
205
+ for col in ret['columns']:
206
+ tag = 'unknown'
207
+ tagged = False
208
+ for key in self.tagging.keys():
209
+ for word in self.tagging[key]:
210
+ if word in col['text']:
211
+ tag = key
212
+ tagged = True
213
+ break
214
+ if tagged:
215
+ break
216
+ col['tag'] = tag
217
+
218
+ return ret
219
+
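The backward merge loop above (applied to both `g` and `group`) collapses word boxes that overlap horizontally into a single cell. A standalone sketch of that pass, assuming boxes are `[x0, y0, x1, y1, text]` lists already sorted by `x0` (the helper name is ours, not from the module):

```python
def merge_overlapping(boxes):
    """Merge x-sorted [x0, y0, x1, y1, text] boxes whose x-ranges touch.

    Mirrors the reverse loop in extract_singular_table: walk right to left
    and fold box a into box a-1 whenever a starts before a-1 ends.
    """
    boxes = [b[:] for b in boxes]  # work on copies; don't mutate the caller's list
    for a in range(len(boxes) - 1, 0, -1):
        if boxes[a][0] < boxes[a - 1][2]:
            boxes[a - 1][0] = min(boxes[a][0], boxes[a - 1][0])
            boxes[a - 1][1] = min(boxes[a][1], boxes[a - 1][1])
            boxes[a - 1][2] = max(boxes[a][2], boxes[a - 1][2])
            boxes[a - 1][3] = max(boxes[a][3], boxes[a - 1][3])
            boxes[a - 1][4] = boxes[a - 1][4].strip() + " " + boxes[a][4]
            boxes.pop(a)
    return boxes

merged = merge_overlapping([[0, 0, 10, 10, 'time'],
                            [8, 1, 20, 9, '(h)'],
                            [30, 0, 40, 10, 'yield']])
```

Here the first two boxes overlap in x, so they merge into one `'time (h)'` cell while the third stays separate.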
+ def get_title_and_footnotes(self, tb_coords):
+
+ for page_layout in pdfminer.high_level.extract_pages(self.pdf_file, page_numbers=[self.page]):
+ title = (0, 0, 0, 0, '')
+ footnote = (0, 0, 0, 0, '')
+ title_gap = 30
+ footnote_gap = 30
+ for element in page_layout:
+ if isinstance(element, pdfminer.layout.LTTextBoxHorizontal):
+ if (element.bbox[0] >= tb_coords[0] and element.bbox[0] <= tb_coords[2]) or (element.bbox[2] >= tb_coords[0] and element.bbox[2] <= tb_coords[2]) or (tb_coords[0] >= element.bbox[0] and tb_coords[0] <= element.bbox[2]) or (tb_coords[2] >= element.bbox[0] and tb_coords[2] <= element.bbox[2]):
+ #print(element)
+ if 'Table' in element.get_text():
+ if abs(element.bbox[1] - tb_coords[3]) < title_gap:
+ title = tuple(element.bbox) + (element.get_text()[element.get_text().index('Table'):].replace('\n', ' '),)
+ title_gap = abs(element.bbox[1] - tb_coords[3])
+ if 'Scheme' in element.get_text():
+ if abs(element.bbox[1] - tb_coords[3]) < title_gap:
+ title = tuple(element.bbox) + (element.get_text()[element.get_text().index('Scheme'):].replace('\n', ' '),)
+ title_gap = abs(element.bbox[1] - tb_coords[3])
+ if element.bbox[1] >= tb_coords[1] and element.bbox[3] <= tb_coords[3]: continue
+ #print(element)
+ temp = ['aA', 'aB', 'aC', 'aD', 'aE', 'aF', 'aG', 'aH', 'aI', 'aJ', 'aK', 'aL', 'aM', 'aN', 'aO', 'aP', 'aQ', 'aR', 'aS', 'aT', 'aU', 'aV', 'aW', 'aX', 'aY', 'aZ', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6', 'a7', 'a8', 'a9', 'a0']
+ for segment in temp:
+ if segment in element.get_text():
+ if abs(element.bbox[3] - tb_coords[1]) < footnote_gap:
+ footnote = tuple(element.bbox) + (element.get_text()[element.get_text().index(segment):].replace('\n', ' '),)
+ footnote_gap = abs(element.bbox[3] - tb_coords[1])
+ break
+ self.title_y = min(title[1], title[3])
+ if self.output_bbox:
+ return ({'text': title[4], 'bbox': list(title[:4])}, {'text': footnote[4], 'bbox': list(footnote[:4])})
+ else:
+ return (title[4], footnote[4])
+
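The long four-clause condition in `get_title_and_footnotes` tests whether the text box and the table overlap horizontally. For two closed intervals, that chained check reduces to a two-clause test: each interval must start before the other ends. A sketch with a hypothetical helper name:

```python
def x_overlap(box, table):
    """True when the x-ranges of two (x0, y0, x1, y1) bboxes intersect.

    Equivalent to the four-way OR used in get_title_and_footnotes:
    intervals [box[0], box[2]] and [table[0], table[2]] overlap iff
    each starts no later than the other ends.
    """
    return box[0] <= table[2] and table[0] <= box[2]
```

For example, `x_overlap((10, 0, 50, 0), (40, 0, 90, 0))` is true because the ranges 10..50 and 40..90 share 40..50, while disjoint ranges return false.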
+ def extract_table_information(self):
+ #self.run_model(page_info) # changed
+ table_coordinates = self.blocks['table'] #should return a list of layout objects
+ table_coordinates_in_pdf = self.convert_to_pdf_coordinates('table') #should return a list of lists
+
+ ans = []
+ i = 0
+ for coordinate in table_coordinates_in_pdf:
+ ret = {}
+ pad = 20
+ coordinate = [coordinate[0] - pad, coordinate[1], coordinate[2] + pad, coordinate[3]]
+ ullr_coord = [coordinate[0], coordinate[3], coordinate[2], coordinate[1]]
+
+ table_results = self.extract_singular_table(coordinate)
+ tf = self.get_title_and_footnotes(coordinate)
+ figure = Image.fromarray(table_coordinates[i].crop_image(self.img))
+ ret.update({'title': tf[0]})
+ ret.update({'figure': {
+ 'image': None,
+ 'bbox': []
+ }})
+ if self.output_image:
+ ret['figure']['image'] = figure
+ ret.update({'table': {'bbox': list(coordinate), 'content': table_results}})
+ ret.update({'footnote': tf[1]})
+ if abs(self.title_y - self.column_header_y) > 50:
+ ret['figure']['bbox'] = list(coordinate)
+
+ ret.update({'page':self.page})
+
+ ans.append(ret)
+ i += 1
+
+ return ans
+
+ def extract_figure_information(self):
+ figure_coordinates = self.blocks['figure']
+ figure_coordinates_in_pdf = self.convert_to_pdf_coordinates('figure')
+
+ ans = []
+ for i in range(len(figure_coordinates)):
+ ret = {}
+ coordinate = figure_coordinates_in_pdf[i]
+ ullr_coord = [coordinate[0], coordinate[3], coordinate[2], coordinate[1]]
+
+ tf = self.get_title_and_footnotes(coordinate)
+ figure = Image.fromarray(figure_coordinates[i].crop_image(self.img))
+ ret.update({'title':tf[0]})
+ ret.update({'figure': {
+ 'image': None,
+ 'bbox': []
+ }})
+ if self.output_image:
+ ret['figure']['image'] = figure
+ ret.update({'table': {
+ 'bbox': [],
+ 'content': None
+ }})
+ ret.update({'footnote': tf[1]})
+ ret['figure']['bbox'] = list(coordinate)
+
+ ret.update({'page':self.page})
+
+ ans.append(ret)
+
+ return ans
+
+
+ def extract_all_tables_and_figures(self, pages, pdfparser, content=None):
+ self.model = pdfparser
+ ret = []
+ for i in range(len(pages)):
+ self.set_page_num(i)
+ self.run_model(pages[i])
+ table_info = self.extract_table_information()
+ figure_info = self.extract_figure_information()
+ if content == 'tables':
+ ret += table_info
+ elif content == 'figures':
+ ret += figure_info
+ for table in table_info:
+ if table['figure']['bbox'] != []:
+ ret.append(table)
+ else:
+ ret += table_info
+ ret += figure_info
+ return ret
chemietoolkit/utils.py ADDED
@@ -0,0 +1,1018 @@
+ import numpy as np
+ from PIL import Image
+ import cv2
+ import layoutparser as lp
+ from rdkit import Chem
+ from rdkit.Chem import Draw
+ from rdkit.Chem import rdDepictor
+ rdDepictor.SetPreferCoordGen(True)
+ from rdkit.Chem.Draw import IPythonConsole
+ from rdkit.Chem import AllChem
+ import re
+ import copy
+
+ BOND_TO_INT = {
+ "": 0,
+ "single": 1,
+ "double": 2,
+ "triple": 3,
+ "aromatic": 4,
+ "solid wedge": 5,
+ "dashed wedge": 6
+ }
+
+ RGROUP_SYMBOLS = ['R', 'R1', 'R2', 'R3', 'R4', 'R5', 'R6', 'R7', 'R8', 'R9', 'R10', 'R11', 'R12',
+ 'Ra', 'Rb', 'Rc', 'Rd', 'Rf', 'X', 'Y', 'Z', 'Q', 'A', 'E', 'Ar', 'Ar1', 'Ar2', 'Ari', "R'",
+ '1*', '2*','3*', '4*','5*', '6*','7*', '8*','9*', '10*','11*', '12*','[a*]', '[b*]','[c*]', '[d*]']
+
+ RGROUP_SYMBOLS = RGROUP_SYMBOLS + [f'[{i}]' for i in RGROUP_SYMBOLS]
+
+ RGROUP_SMILES = ['[1*]', '[2*]','[3*]', '[4*]','[5*]', '[6*]','[7*]', '[8*]','[9*]', '[10*]','[11*]', '[12*]','[a*]', '[b*]','[c*]', '[d*]','*', '[Rf]']
+
+ def get_figures_from_pages(pages, pdfparser):
+ figures = []
+ for i in range(len(pages)):
+ img = np.asarray(pages[i])
+ layout = pdfparser.detect(img)
+ blocks = lp.Layout([b for b in layout if b.type == "Figure"])
+ for block in blocks:
+ figure = Image.fromarray(block.crop_image(img))
+ figures.append({
+ 'image': figure,
+ 'page': i
+ })
+ return figures
+
+ def clean_bbox_output(figures, bboxes):
+ results = []
+ cropped = []
+ references = []
+ for i, output in enumerate(bboxes):
+ mol_bboxes = [elt['bbox'] for elt in output if elt['category'] == '[Mol]']
+ mol_scores = [elt['score'] for elt in output if elt['category'] == '[Mol]']
+ data = {}
+ results.append(data)
+ data['image'] = figures[i]
+ data['molecules'] = []
+ for bbox, score in zip(mol_bboxes, mol_scores):
+ x1, y1, x2, y2 = bbox
+ height, width, _ = figures[i].shape
+ cropped_img = figures[i][int(y1*height):int(y2*height),int(x1*width):int(x2*width)]
+ cur_mol = {
+ 'bbox': bbox,
+ 'score': score,
+ 'image': cropped_img,
+ #'info': None,
+ }
+ cropped.append(cropped_img)
+ data['molecules'].append(cur_mol)
+ references.append(cur_mol)
+ return results, cropped, references
+
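`clean_bbox_output` crops each molecule with a bbox expressed as fractions of the image size rather than pixels. A minimal numpy sketch of that coordinate math, separated from the detection pipeline:

```python
import numpy as np

def crop_relative(img, bbox):
    """Crop an HxWxC array using an (x1, y1, x2, y2) bbox given in [0, 1] fractions."""
    x1, y1, x2, y2 = bbox
    height, width, _ = img.shape
    # scale each fractional coordinate to pixels before slicing
    return img[int(y1 * height):int(y2 * height), int(x1 * width):int(x2 * width)]

img = np.zeros((100, 200, 3), dtype=np.uint8)
patch = crop_relative(img, (0.25, 0.1, 0.75, 0.5))
```

On a 100x200 image, the fractional box (0.25, 0.1, 0.75, 0.5) becomes rows 10..50 and columns 50..150, so the crop is 40x100 pixels.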
+ def convert_to_pil(image):
+ if type(image) == np.ndarray:
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+ image = Image.fromarray(image)
+ return image
+
+ def convert_to_cv2(image):
+ if type(image) != np.ndarray:
+ image = cv2.cvtColor(np.array(image), cv2.COLOR_BGR2RGB)
+ return image
+
+ def replace_rgroups_in_figure(figures, results, coref_results, molscribe, batch_size=16):
+ pattern = re.compile(r'(?P<name>[RXY]\d?)[ ]*=[ ]*(?P<group>\w+)')
+ for figure, result, corefs in zip(figures, results, coref_results):
+ r_groups = []
+ seen_r_groups = set()
+ for bbox in corefs['bboxes']:
+ if bbox['category'] == '[Idt]':
+ for text in bbox['text']:
+ res = pattern.search(text)
+ if res is None:
+ continue
+ name = res.group('name')
+ group = res.group('group')
+ if (name, group) in seen_r_groups:
+ continue
+ seen_r_groups.add((name, group))
+ r_groups.append({name: res.group('group')})
+ if r_groups and result['reactions']:
+ seen_r_groups = set([pair[0] for pair in seen_r_groups])
+ orig_reaction = result['reactions'][0]
+ graphs = get_atoms_and_bonds(figure['figure']['image'], orig_reaction, molscribe, batch_size=batch_size)
+ relevant_locs = {}
+ for i, graph in enumerate(graphs):
+ to_add = []
+ for j, atom in enumerate(graph['chartok_coords']['symbols']):
+ if atom[1:-1] in seen_r_groups:
+ to_add.append((atom[1:-1], j))
+ relevant_locs[i] = to_add
+
+ for r_group in r_groups:
+ reaction = get_replaced_reaction(orig_reaction, graphs, relevant_locs, r_group, molscribe)
+ to_add = {
+ 'reactants': reaction['reactants'][:],
+ 'conditions': orig_reaction['conditions'][:],
+ 'products': reaction['products'][:]
+ }
+ result['reactions'].append(to_add)
+ return results
+
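`replace_rgroups_in_figure` pulls substituent definitions such as `R1 = OMe` out of identifier text with the regex above. A quick standalone check of that pattern in isolation:

```python
import re

# Same pattern as in replace_rgroups_in_figure: an R/X/Y label with an
# optional digit, then '=', then a word-character substituent name.
pattern = re.compile(r'(?P<name>[RXY]\d?)[ ]*=[ ]*(?P<group>\w+)')

res = pattern.search('R1 = OMe')
```

`res.group('name')` yields the label and `res.group('group')` the substituent; text with no `label = group` pair gives no match, which is why the caller skips `None` results.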
+ def process_tables(figures, results, molscribe, batch_size=16):
+ r_group_pattern = re.compile(r'^(\w+-)?(?P<group>[\w-]+)( \(\w+\))?$')
+ for figure, result in zip(figures, results):
+ result['page'] = figure['page']
+ if figure['table']['content'] is not None:
+ content = figure['table']['content']
+ if len(result['reactions']) > 1:
+ print("Warning: multiple reactions detected for table")
+ elif len(result['reactions']) == 0:
+ continue
+ orig_reaction = result['reactions'][0]
+ graphs = get_atoms_and_bonds(figure['figure']['image'], orig_reaction, molscribe, batch_size=batch_size)
+ relevant_locs = find_relevant_groups(graphs, content['columns'])
+ conditions_to_extend = []
+ for row in content['rows']:
+ r_groups = {}
+ expanded_conditions = orig_reaction['conditions'][:]
+ replaced = False
+ for col, entry in zip(content['columns'], row):
+ if col['tag'] != 'alkyl group':
+ expanded_conditions.append({
+ 'category': '[Table]',
+ 'text': entry['text'],
+ 'tag': col['tag'],
+ 'header': col['text'],
+ })
+ else:
+ found = r_group_pattern.match(entry['text'])
+ if found is not None:
+ r_groups[col['text']] = found.group('group')
+ replaced = True
+ reaction = get_replaced_reaction(orig_reaction, graphs, relevant_locs, r_groups, molscribe)
+ if replaced:
+ to_add = {
+ 'reactants': reaction['reactants'][:],
+ 'conditions': expanded_conditions,
+ 'products': reaction['products'][:]
+ }
+ result['reactions'].append(to_add)
+ else:
+ conditions_to_extend.append(expanded_conditions)
+ orig_reaction['conditions'] = [orig_reaction['conditions']]
+ orig_reaction['conditions'].extend(conditions_to_extend)
+ return results
+
+
+ def get_atoms_and_bonds(image, reaction, molscribe, batch_size=16):
+ image = convert_to_cv2(image)
+ cropped_images = []
+ results = []
+ for key, molecules in reaction.items():
+ for i, elt in enumerate(molecules):
+ if type(elt) != dict or elt['category'] != '[Mol]':
+ continue
+ x1, y1, x2, y2 = elt['bbox']
+ height, width, _ = image.shape
+ cropped_images.append(image[int(y1*height):int(y2*height),int(x1*width):int(x2*width)])
+ to_add = {
+ 'image': cropped_images[-1],
+ 'chartok_coords': {
+ 'coords': [],
+ 'symbols': [],
+ },
+ 'edges': [],
+ 'key': (key, i)
+ }
+ results.append(to_add)
+ outputs = molscribe.predict_images(cropped_images, return_atoms_bonds=True, batch_size=batch_size)
+ for mol, result in zip(outputs, results):
+ for atom in mol['atoms']:
+ result['chartok_coords']['coords'].append((atom['x'], atom['y']))
+ result['chartok_coords']['symbols'].append(atom['atom_symbol'])
+ result['edges'] = [[0] * len(mol['atoms']) for _ in range(len(mol['atoms']))]
+ for bond in mol['bonds']:
+ i, j = bond['endpoint_atoms']
+ result['edges'][i][j] = BOND_TO_INT[bond['bond_type']]
+ result['edges'][j][i] = BOND_TO_INT[bond['bond_type']]
+ return results
+
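`get_atoms_and_bonds` re-encodes MolScribe's bond list as a symmetric adjacency matrix whose entries are `BOND_TO_INT` codes. The core of that conversion in isolation (a sketch; only the dict and bond format come from the module):

```python
BOND_TO_INT = {"": 0, "single": 1, "double": 2, "triple": 3,
               "aromatic": 4, "solid wedge": 5, "dashed wedge": 6}

def bonds_to_edges(num_atoms, bonds):
    """Build the symmetric num_atoms x num_atoms bond-type matrix."""
    edges = [[0] * num_atoms for _ in range(num_atoms)]
    for bond in bonds:
        i, j = bond['endpoint_atoms']
        # write both (i, j) and (j, i) so the matrix stays symmetric
        edges[i][j] = BOND_TO_INT[bond['bond_type']]
        edges[j][i] = BOND_TO_INT[bond['bond_type']]
    return edges

edges = bonds_to_edges(3, [{'endpoint_atoms': [0, 1], 'bond_type': 'single'},
                           {'endpoint_atoms': [1, 2], 'bond_type': 'double'}])
```

For a three-atom chain with a single and a double bond, the matrix has 1 at (0,1)/(1,0) and 2 at (1,2)/(2,1).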
+ def find_relevant_groups(graphs, columns):
+ results = {}
+ r_groups = set([f"[{col['text']}]" for col in columns if col['tag'] == 'alkyl group'])
+ for i, graph in enumerate(graphs):
+ to_add = []
+ for j, atom in enumerate(graph['chartok_coords']['symbols']):
+ if atom in r_groups:
+ to_add.append((atom[1:-1], j))
+ results[i] = to_add
+ return results
+
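Only columns tagged `alkyl group` matter for this lookup: their headers are bracketed and matched against atom symbols. A self-contained copy of the function with toy inputs shows the shape of its output (graph index mapped to `(R-group name, atom index)` pairs):

```python
def find_relevant_groups(graphs, columns):
    """Map graph index -> [(R-group name, atom index)] for table substitution."""
    results = {}
    # headers of 'alkyl group' columns, wrapped in brackets to match atom symbols
    r_groups = set([f"[{col['text']}]" for col in columns if col['tag'] == 'alkyl group'])
    for i, graph in enumerate(graphs):
        to_add = []
        for j, atom in enumerate(graph['chartok_coords']['symbols']):
            if atom in r_groups:
                to_add.append((atom[1:-1], j))  # strip brackets, keep atom position
        results[i] = to_add
    return results

columns = [{'text': 'R', 'tag': 'alkyl group'}, {'text': 'yield', 'tag': 'unknown'}]
graphs = [{'chartok_coords': {'symbols': ['C', '[R]', 'O']}}]
```

With these inputs only the `[R]` symbol at position 1 of graph 0 is reported; the `yield` column contributes nothing.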
+ def get_replaced_reaction(orig_reaction, graphs, relevant_locs, mappings, molscribe):
+ graph_copy = []
+ for graph in graphs:
+ graph_copy.append({
+ 'image': graph['image'],
+ 'chartok_coords': {
+ 'coords': graph['chartok_coords']['coords'][:],
+ 'symbols': graph['chartok_coords']['symbols'][:],
+ },
+ 'edges': graph['edges'][:],
+ 'key': graph['key'],
+ })
+ for graph_idx, atoms in relevant_locs.items():
+ for atom, atom_idx in atoms:
+ if atom in mappings:
+ graph_copy[graph_idx]['chartok_coords']['symbols'][atom_idx] = mappings[atom]
+ reaction_copy = {}
+ def append_copy(copy_list, entity):
+ if entity['category'] == '[Mol]':
+ copy_list.append({
+ k1: v1 for k1, v1 in entity.items()
+ })
+ else:
+ copy_list.append(entity)
+
+ for k, v in orig_reaction.items():
+ reaction_copy[k] = []
+ for entity in v:
+ if type(entity) == list:
+ sub_list = []
+ for e in entity:
+ append_copy(sub_list, e)
+ reaction_copy[k].append(sub_list)
+ else:
+ append_copy(reaction_copy[k], entity)
+
+ for graph in graph_copy:
+ output = molscribe.convert_graph_to_output([graph], [graph['image']])
+ molecule = reaction_copy[graph['key'][0]][graph['key'][1]]
+ molecule['smiles'] = output[0]['smiles']
+ molecule['molfile'] = output[0]['molfile']
+ return reaction_copy
+
+ def get_sites(tar, ref, ref_site = False):
+ rdDepictor.Compute2DCoords(ref)
+ rdDepictor.Compute2DCoords(tar)
+ idx_pair = rdDepictor.GenerateDepictionMatching2DStructure(tar, ref)
+
+ in_template = [i[1] for i in idx_pair]
+ sites = []
+ for i in range(tar.GetNumAtoms()):
+ if i not in in_template:
+ for j in tar.GetAtomWithIdx(i).GetNeighbors():
+ if j.GetIdx() in in_template and j.GetIdx() not in sites:
+
+ # report the attachment atom's index in the reference when ref_site is set,
+ # otherwise its index in the target (both branches previously returned [0])
+ if ref_site: sites.append(idx_pair[in_template.index(j.GetIdx())][0])
+ else: sites.append(idx_pair[in_template.index(j.GetIdx())][1])
+ return sites
+
+ def get_atom_mapping(prod_mol, prod_smiles, r_sites_reversed = None):
+ # returns prod_mol_to_query, the mapping of atom indices in prod_mol to the atom indices of the molecule represented by prod_smiles
+ prod_template_intermediate = Chem.MolToSmiles(prod_mol)
+ prod_template = prod_smiles
+
+ for r in RGROUP_SMILES:
+ if r!='*' and r!='(*)':
+ prod_template = prod_template.replace(r, '*')
+ prod_template_intermediate = prod_template_intermediate.replace(r, '*')
+
+ prod_template_intermediate_mol = Chem.MolFromSmiles(prod_template_intermediate)
+ prod_template_mol = Chem.MolFromSmiles(prod_template)
+
+ p = Chem.AdjustQueryParameters.NoAdjustments()
+ p.makeDummiesQueries = True
+
+ prod_template_mol_query = Chem.AdjustQueryProperties(prod_template_mol, p)
+ prod_template_intermediate_mol_query = Chem.AdjustQueryProperties(prod_template_intermediate_mol, p)
+ rdDepictor.Compute2DCoords(prod_mol)
+ rdDepictor.Compute2DCoords(prod_template_mol_query)
+ rdDepictor.Compute2DCoords(prod_template_intermediate_mol_query)
+ idx_pair = rdDepictor.GenerateDepictionMatching2DStructure(prod_mol, prod_template_intermediate_mol_query)
+
+ intermediate_to_prod_mol = {a:b for a,b in idx_pair}
+ prod_mol_to_intermediate = {b:a for a,b in idx_pair}
+
+
+ #idx_pair_2 = rdDepictor.GenerateDepictionMatching2DStructure(prod_template_mol_query, prod_template_intermediate_mol_query)
+
+ #intermediate_to_query = {a:b for a,b in idx_pair_2}
+ #query_to_intermediate = {b:a for a,b in idx_pair_2}
+
+ #prod_mol_to_query = {a:intermediate_to_query[prod_mol_to_intermediate[a]] for a in prod_mol_to_intermediate}
+
+
+ substructs = prod_template_mol_query.GetSubstructMatches(prod_template_intermediate_mol_query, uniquify = False)
+
+ #idx_pair_2 = rdDepictor.GenerateDepictionMatching2DStructure(prod_template_mol_query, prod_template_intermediate_mol_query)
+ for substruct in substructs:
+
+
+ intermediate_to_query = {a:b for a, b in enumerate(substruct)}
+ query_to_intermediate = {intermediate_to_query[i]: i for i in intermediate_to_query}
+
+ prod_mol_to_query = {a:intermediate_to_query[prod_mol_to_intermediate[a]] for a in prod_mol_to_intermediate}
+
+ good_map = True
+ for i in r_sites_reversed:
+ if prod_template_mol_query.GetAtomWithIdx(prod_mol_to_query[i]).GetSymbol() not in RGROUP_SMILES:
+ good_map = False
+ if good_map:
+ break
+
+ return prod_mol_to_query, prod_template_mol_query
+
+ def clean_corefs(coref_results_dict, idx):
+ label_pattern = rf'{re.escape(idx)}[a-zA-Z]+'
+ #unclean_pattern = re.escape(idx) + r'\d(?![\d% ])'
+ toreturn = {}
+ for prod in coref_results_dict:
+ has_good_label = False
+ for parsed in coref_results_dict[prod]:
+ if re.search(label_pattern, parsed):
+ has_good_label = True
+ if not has_good_label:
+ for parsed in coref_results_dict[prod]:
+ if idx+'1' in parsed:
+ coref_results_dict[prod].append(idx+'l')
+ elif idx+'0' in parsed:
+ coref_results_dict[prod].append(idx+'o')
+ elif idx+'5' in parsed:
+ coref_results_dict[prod].append(idx+'s')
+ elif idx+'9' in parsed:
+ coref_results_dict[prod].append(idx+'g')
+
+
+
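`clean_corefs` patches common OCR confusions in compound labels: when no label matches the expected `<index><letters>` pattern, a digit that was likely a misread letter (1→l, 0→o, 5→s, 9→g) triggers an appended corrected label. A lightly restructured copy makes this testable on its own (it iterates over a snapshot of each list, since the original appends to the list it is iterating):

```python
import re

def clean_corefs(coref_results_dict, idx):
    """Append letter-corrected labels when only digit-garbled ones were parsed."""
    label_pattern = rf'{re.escape(idx)}[a-zA-Z]+'
    for prod in coref_results_dict:
        has_good_label = any(re.search(label_pattern, parsed)
                             for parsed in coref_results_dict[prod])
        if not has_good_label:
            # iterate over a copy; we append corrected labels to the same list
            for parsed in list(coref_results_dict[prod]):
                if idx + '1' in parsed:
                    coref_results_dict[prod].append(idx + 'l')
                elif idx + '0' in parsed:
                    coref_results_dict[prod].append(idx + 'o')
                elif idx + '5' in parsed:
                    coref_results_dict[prod].append(idx + 's')
                elif idx + '9' in parsed:
                    coref_results_dict[prod].append(idx + 'g')

labels = {'CCO': ['10']}
clean_corefs(labels, '1')
```

Here the parsed label `10` has no letter suffix, so the suspected misread `0` yields an extra candidate label `1o` alongside the original.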
+ def expand_r_group_label_helper(res, coref_smiles_to_graphs, other_prod, molscribe):
+ name = res.group('name')
+ group = res.group('group')
+ #print(other_prod)
+ atoms = coref_smiles_to_graphs[other_prod]['atoms']
+ bonds = coref_smiles_to_graphs[other_prod]['bonds']
+
+ #print(atoms, bonds)
+
+ graph = {
+ 'image': None,
+ 'chartok_coords': {
+ 'coords': [],
+ 'symbols': [],
+ },
+ 'edges': [],
+ 'key': None
+ }
+ for atom in atoms:
+ graph['chartok_coords']['coords'].append((atom['x'], atom['y']))
+ graph['chartok_coords']['symbols'].append(atom['atom_symbol'])
+ graph['edges'] = [[0] * len(atoms) for _ in range(len(atoms))]
+ for bond in bonds:
+ i, j = bond['endpoint_atoms']
+ graph['edges'][i][j] = BOND_TO_INT[bond['bond_type']]
+ graph['edges'][j][i] = BOND_TO_INT[bond['bond_type']]
+ for i, symbol in enumerate(graph['chartok_coords']['symbols']):
+ if symbol[1:-1] == name:
+ graph['chartok_coords']['symbols'][i] = group
+
+ #print(graph)
+ o = molscribe.convert_graph_to_output([graph], [graph['image']])
+ return Chem.MolFromSmiles(o[0]['smiles'])
+
+ def get_r_group_frags_and_substitute(other_prod_mol, query, reactant_mols, reactant_information, parsed, toreturn):
+ prod_template_mol_query, r_sites_reversed_new, h_sites, num_r_groups = query
+ # we get the substruct matches. note that we set uniquify to false since the order matters for our method
+ substructs = other_prod_mol.GetSubstructMatches(prod_template_mol_query, uniquify = False)
+
+
+ #for r in r_sites_reversed:
+ # print(prod_template_mol_query.GetAtomWithIdx(prod_mol_to_query[r]).GetSymbol())
+
+ # for each substruct we create the mapping of the substruct onto the other_mol,
+ # delete all the atoms in other_mol that correspond to the substruct,
+ # and check whether the number of mol frags equals the number of R groups;
+ # we do this to make sure we have the correct substruct
+ if len(substructs) >= 1:
+ for substruct in substructs:
+
+ query_to_other = {a:b for a,b in enumerate(substruct)}
+ other_to_query = {query_to_other[i]:i for i in query_to_other}
+
+ editable = Chem.EditableMol(other_prod_mol)
+ r_site_correspondence = []
+ for r in r_sites_reversed_new:
+ #get its id in substruct
+ substruct_id = query_to_other[r]
+ r_site_correspondence.append([substruct_id, r_sites_reversed_new[r]])
+
+ for idx in tuple(sorted(substruct, reverse = True)):
+ if idx not in [query_to_other[i] for i in r_sites_reversed_new]:
+ editable.RemoveAtom(idx)
+ for r_site in r_site_correspondence:
+ if idx < r_site[0]:
+ r_site[0]-=1
+ other_prod_removed = editable.GetMol()
+
+ if len(Chem.GetMolFrags(other_prod_removed, asMols = False)) == num_r_groups:
+ break
+
+ # need to compute the sites that correspond to each r_site_reversed
+
+ r_site_correspondence.sort(key = lambda x: x[0])
+
+
+ f = []
+ ff = []
+ frags = Chem.GetMolFrags(other_prod_removed, asMols = True, frags = f, fragsMolAtomMapping = ff)
+
+ # r_group_information maps r group name --> the fragment/molecule corresponding to the r group and the atom index it should be connected at
+ r_group_information = {}
+ #tosubtract = 0
+ for idx, r_site in enumerate(r_site_correspondence):
+
+ r_group_information[r_site[1]]= (frags[f[r_site[0]]], ff[f[r_site[0]]].index(r_site[0]))
+ #tosubtract += len(ff[idx])
+ for r_site in h_sites:
+ r_group_information[r_site] = (Chem.MolFromSmiles('[H]'), 0)
+
+ # now we modify all of the reactants according to the R groups we have found
+ # for every reactant we disconnect its r group symbol, and connect it to the r group
+ modify_reactants = copy.deepcopy(reactant_mols)
+ modified_reactant_smiles = []
+ for reactant_idx in reactant_information:
+ if len(reactant_information[reactant_idx]) == 0:
+ modified_reactant_smiles.append(Chem.MolToSmiles(modify_reactants[reactant_idx]))
+ else:
+ combined = reactant_mols[reactant_idx]
+ if combined.GetNumAtoms() == 1:
+ r_group, _, _ = reactant_information[reactant_idx][0]
+ modified_reactant_smiles.append(Chem.MolToSmiles(r_group_information[r_group][0]))
+ else:
+ for r_group, r_index, connect_index in reactant_information[reactant_idx]:
+ combined = Chem.CombineMols(combined, r_group_information[r_group][0])
+
+ editable = Chem.EditableMol(combined)
+ atomIdxAdder = reactant_mols[reactant_idx].GetNumAtoms()
+ for r_group, r_index, connect_index in reactant_information[reactant_idx]:
+ Chem.EditableMol.RemoveBond(editable, r_index, connect_index)
+ Chem.EditableMol.AddBond(editable, connect_index, atomIdxAdder + r_group_information[r_group][1], Chem.BondType.SINGLE)
+ atomIdxAdder += r_group_information[r_group][0].GetNumAtoms()
+ r_indices = [i[1] for i in reactant_information[reactant_idx]]
+
+ r_indices.sort(reverse = True)
+
+ for r_index in r_indices:
+ Chem.EditableMol.RemoveAtom(editable, r_index)
+
+ modified_reactant_smiles.append(Chem.MolToSmiles(Chem.MolFromSmiles(Chem.MolToSmiles(editable.GetMol()))))
+
+ toreturn.append((modified_reactant_smiles, [Chem.MolToSmiles(other_prod_mol)], parsed))
+ return True
+ else:
+ return False
+
+ def query_enumeration(prod_template_mol_query, r_sites_reversed_new, num_r_groups):
+ subsets = generate_subsets(num_r_groups)
+
+ toreturn = []
+
+ for subset in subsets:
+ r_sites_list = [[i, r_sites_reversed_new[i]] for i in r_sites_reversed_new]
+ r_sites_list.sort(key = lambda x: x[0])
+ to_edit = Chem.EditableMol(prod_template_mol_query)
+
+ for entry in subset:
+ pos = r_sites_list[entry][0]
+ Chem.EditableMol.RemoveBond(to_edit, r_sites_list[entry][0], prod_template_mol_query.GetAtomWithIdx(r_sites_list[entry][0]).GetNeighbors()[0].GetIdx())
+ for entry in subset:
+ pos = r_sites_list[entry][0]
+ Chem.EditableMol.RemoveAtom(to_edit, pos)
+
+ edited = to_edit.GetMol()
+ for entry in subset:
+ for i in range(entry + 1, num_r_groups):
+ r_sites_list[i][0]-=1
+
+ new_r_sites = {}
+ new_h_sites = set()
+ for i in range(num_r_groups):
+ if i not in subset:
+ new_r_sites[r_sites_list[i][0]] = r_sites_list[i][1]
+ else:
+ new_h_sites.add(r_sites_list[i][1])
+ toreturn.append((edited, new_r_sites, new_h_sites, num_r_groups - len(subset)))
+ return toreturn
+
+ def generate_subsets(n):
+ def backtrack(start, subset):
+ result.append(subset[:])
+ for i in range(start, -1, -1): # Iterate in reverse order
+ subset.append(i)
+ backtrack(i - 1, subset)
+ subset.pop()
+
+ result = []
+ backtrack(n - 1, [])
+ return sorted(result, key=lambda x: (-len(x), x), reverse=True)
+
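`query_enumeration` walks these subsets in order, treating each chosen R-site as an implicit hydrogen. A copy of the helper makes the ordering concrete: smaller subsets come first, so templates with fewer R-sites removed are tried before more aggressive deletions.

```python
def generate_subsets(n):
    """All subsets of {0, ..., n-1}, ordered smallest first."""
    def backtrack(start, subset):
        result.append(subset[:])
        for i in range(start, -1, -1):  # iterate indices in reverse order
            subset.append(i)
            backtrack(i - 1, subset)
            subset.pop()

    result = []
    backtrack(n - 1, [])
    # sort key (-len, elements) with reverse=True puts shorter subsets first
    return sorted(result, key=lambda x: (-len(x), x), reverse=True)
```

For `n=2` the sequence starts with the empty subset (no R-sites converted to H), then the singletons, then the full set.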
+ def backout(results, coref_results, molscribe):
+
+ toreturn = []
+
+ if not results or not results[0]['reactions'] or not coref_results:
+ return toreturn
+
+ try:
+ reactants = results[0]['reactions'][0]['reactants']
+ products = [i['smiles'] for i in results[0]['reactions'][0]['products']]
+ coref_results_dict = {coref_results[0]['bboxes'][coref[0]]['smiles']: coref_results[0]['bboxes'][coref[1]]['text'] for coref in coref_results[0]['corefs']}
+ coref_smiles_to_graphs = {coref_results[0]['bboxes'][coref[0]]['smiles']: coref_results[0]['bboxes'][coref[0]] for coref in coref_results[0]['corefs']}
+
+
+ if len(products) == 1:
+ if products[0] not in coref_results_dict:
+ print("Warning: No Label Parsed")
+ return
+ product_labels = coref_results_dict[products[0]]
+ prod = products[0]
+ label_idx = product_labels[0]
+ '''
+ if len(product_labels) == 1:
+ # get the coreference label of the product molecule
+ label_idx = product_labels[0]
+ else:
+ print("Warning: Malformed Label Parsed.")
+ return
+ '''
+ else:
+ print("Warning: More than one product detected")
+ return
+
+ # format the regular expression for labels that correspond to the product label
+ numbers = re.findall(r'\d+', label_idx)
+ label_idx = numbers[0] if len(numbers) > 0 else ""
+ label_pattern = rf'{re.escape(label_idx)}[a-zA-Z]+'
+
+
+ prod_smiles = prod
+ prod_mol = Chem.MolFromMolBlock(results[0]['reactions'][0]['products'][0]['molfile'])
+
+ # identify the atom indices of the R groups in the product template
+ h_counter = 0
+ r_sites = {}
+ for idx, atom in enumerate(results[0]['reactions'][0]['products'][0]['atoms']):
+ sym = atom['atom_symbol']
+ if sym == '[H]':
+ h_counter += 1
+ if sym[0] == '[':
+ sym = sym[1:-1]
+ if sym[0] == 'R' and sym[1:].isdigit():
+ sym = sym[1:]+"*"
+ sym = f'[{sym}]'
+ if sym in RGROUP_SYMBOLS:
+ if sym not in r_sites:
+ r_sites[sym] = [idx-h_counter]
+ else:
+ r_sites[sym].append(idx-h_counter)
+
+ r_sites_reversed = {}
+ for sym in r_sites:
+ for pos in r_sites[sym]:
+ r_sites_reversed[pos] = sym
+
+ num_r_groups = len(r_sites_reversed)
+
+ #prepare the product template and get the associated mapping
+
+ prod_mol_to_query, prod_template_mol_query = get_atom_mapping(prod_mol, prod_smiles, r_sites_reversed = r_sites_reversed)
+
+ reactant_mols = []
+
+
+ #--------------process the reactants-----------------
+
+ reactant_information = {} #index of relevant reaction --> [[R group name, atom index of R group, atom index of R group connection], ...]
+
+ for idx, reactant in enumerate(reactants):
+ reactant_information[idx] = []
+ reactant_mols.append(Chem.MolFromSmiles(reactant['smiles']))
+ has_r = False
+
+ r_sites_reactant = {}
+
+ h_counter = 0
+
+ for a_idx, atom in enumerate(reactant['atoms']):
+
+ #go through all atoms and check if they are an R group, if so add it to reactant information
+ sym = atom['atom_symbol']
+ if sym == '[H]':
+ h_counter += 1
+ if sym[0] == '[':
+ sym = sym[1:-1]
+ if sym[0] == 'R' and sym[1:].isdigit():
+ sym = sym[1:]+"*"
+ sym = f'[{sym}]'
+ if sym in r_sites:
+ if reactant_mols[-1].GetNumAtoms()==1:
+ reactant_information[idx].append([sym, -1, -1])
+ else:
+ has_r = True
+ reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile'])
+ reactant_information[idx].append([sym, a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]])
+ r_sites_reactant[sym] = a_idx-h_counter
+ elif sym == '[1*]' and '[7*]' in r_sites:
+ if reactant_mols[-1].GetNumAtoms()==1:
+ reactant_information[idx].append(['[7*]', -1, -1])
+ else:
+ has_r = True
+ reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile'])
+ reactant_information[idx].append(['[7*]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]])
+ r_sites_reactant['[7*]'] = a_idx-h_counter
+ elif sym == '[7*]' and '[1*]' in r_sites:
+ if reactant_mols[-1].GetNumAtoms()==1:
+ reactant_information[idx].append(['[1*]', -1, -1])
+ else:
+ has_r = True
+ reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile'])
+ reactant_information[idx].append(['[1*]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]])
+ r_sites_reactant['[1*]'] = a_idx-h_counter
+
+
+ elif sym == '[1*]' and '[Rf]' in r_sites:
+ if reactant_mols[-1].GetNumAtoms()==1:
+ reactant_information[idx].append(['[Rf]', -1, -1])
+ else:
+ has_r = True
+ reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile'])
+ reactant_information[idx].append(['[Rf]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]])
+ r_sites_reactant['[Rf]'] = a_idx-h_counter
+
+ elif sym == '[Rf]' and '[1*]' in r_sites:
+ if reactant_mols[-1].GetNumAtoms()==1:
+ reactant_information[idx].append(['[1*]', -1, -1])
+ else:
+ has_r = True
+ reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile'])
+ reactant_information[idx].append(['[1*]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]])
+ r_sites_reactant['[1*]'] = a_idx-h_counter
+
+
+ r_sites_reversed_reactant = {r_sites_reactant[i]: i for i in r_sites_reactant}
+ # if the reactant had r groups, we had to use the molecule generated from the MolBlock.
+ # but the molblock may have unexpanded elements that are not R groups
665
+ # so we have to map back the r group indices in the molblock version to the full molecule generated by the smiles
666
+ # and adjust the indices of the r groups accordingly
667
+ if has_r:
668
+ #get the mapping
669
+ reactant_mol_to_query, _ = get_atom_mapping(reactant_mols[-1], reactant['smiles'], r_sites_reversed = r_sites_reversed_reactant)
670
+
671
+ #make the adjustment
672
+ for info in reactant_information[idx]:
673
+ info[1] = reactant_mol_to_query[info[1]]
674
+ info[2] = reactant_mol_to_query[info[2]]
675
+ reactant_mols[-1] = Chem.MolFromSmiles(reactant['smiles'])
676
+
677
+ #go through all the molecules in the coreference
678
+
679
+ clean_corefs(coref_results_dict, label_idx)
680
+
681
+ for other_prod in coref_results_dict:
682
+
683
+ #check if they match the product label regex
684
+ found_good_label = False
685
+ for parsed in coref_results_dict[other_prod]:
686
+ if re.search(label_pattern, parsed) and not found_good_label:
687
+ found_good_label = True
688
+ other_prod_mol = Chem.MolFromSmiles(other_prod)
689
+
690
+ if other_prod != prod_smiles and other_prod_mol is not None:
691
+
692
+ #check if there are R groups to be resolved in the target product
693
+
694
+ all_other_prod_mols = []
695
+
696
+ r_group_sub_pattern = re.compile('(?P<name>[RXY]\d?)[ ]*=[ ]*(?P<group>\w+)')
697
+
698
+ for parsed_labels in coref_results_dict[other_prod]:
699
+ res = r_group_sub_pattern.search(parsed_labels)
700
+
701
+ if res is not None:
702
+ all_other_prod_mols.append((expand_r_group_label_helper(res, coref_smiles_to_graphs, other_prod, molscribe), parsed + parsed_labels))
703
+
704
+ if len(all_other_prod_mols) == 0:
705
+ if other_prod_mol is not None:
706
+ all_other_prod_mols.append((other_prod_mol, parsed))
707
+
708
+
709
+
710
+
711
+ for other_prod_mol, parsed in all_other_prod_mols:
712
+
713
+ other_prod_frags = Chem.GetMolFrags(other_prod_mol, asMols = True)
714
+
715
+ for other_prod_frag in other_prod_frags:
716
+ substructs = other_prod_frag.GetSubstructMatches(prod_template_mol_query, uniquify = False)
717
+
718
+ if len(substructs)>0:
719
+ other_prod_mol = other_prod_frag
720
+ break
721
+ r_sites_reversed_new = {prod_mol_to_query[r]: r_sites_reversed[r] for r in r_sites_reversed}
722
+
723
+ queries = query_enumeration(prod_template_mol_query, r_sites_reversed_new, num_r_groups)
724
+
725
+ matched = False
726
+
727
+ for query in queries:
728
+ if not matched:
729
+ try:
730
+ matched = get_r_group_frags_and_substitute(other_prod_mol, query, reactant_mols, reactant_information, parsed, toreturn)
731
+ except:
732
+ pass
733
+
734
+ except:
735
+ pass
736
+
737
+
738
+ return toreturn
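The bracket-stripping and `R<n>` rewriting applied to every atom symbol above can be sketched in isolation. This is a minimal, self-contained sketch; `normalize_r_symbol` is a hypothetical helper name, not a function in this diff:

```python
def normalize_r_symbol(sym: str) -> str:
    # Strip surrounding brackets, e.g. '[R1]' -> 'R1'.
    if sym.startswith('[') and sym.endswith(']'):
        sym = sym[1:-1]
    # Rewrite numbered R labels to RDKit dummy-atom notation, e.g. 'R1' -> '1*'.
    if sym and sym[0] == 'R' and sym[1:].isdigit():
        sym = sym[1:] + '*'
    # Re-wrap in brackets so the result matches entries such as '[1*]' or '[Rf]'.
    return f'[{sym}]'
```

After normalization, the symbol can be compared directly against `RGROUP_SYMBOLS` or the `r_sites` keys, as the loops above do.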
739
+
740
+
741
+ def backout_without_coref(results, coref_results, coref_results_dict, coref_smiles_to_graphs, molscribe):
742
+
743
+ toreturn = []
744
+
745
+ if not results or not results[0]['reactions'] or not coref_results:
746
+ return toreturn
747
+
748
+ try:
749
+ reactants = results[0]['reactions'][0]['reactants']
750
+ products = [i['smiles'] for i in results[0]['reactions'][0]['products']]
751
+ coref_results_dict = coref_results_dict
752
+ coref_smiles_to_graphs = coref_smiles_to_graphs
753
+
754
+
755
+ if len(products) == 1:
756
+ if products[0] not in coref_results_dict:
757
+ print("Warning: No Label Parsed")
758
+ return
759
+ product_labels = coref_results_dict[products[0]]
760
+ prod = products[0]
761
+ label_idx = product_labels[0]
762
+ '''
763
+ if len(product_labels) == 1:
764
+ # get the coreference label of the product molecule
765
+ label_idx = product_labels[0]
766
+ else:
767
+ print("Warning: Malformed Label Parsed.")
768
+ return
769
+ '''
770
+ else:
771
+ print("Warning: More than one product detected")
772
+ return
773
+
774
+ # format the regular expression for labels that correspond to the product label
775
+ numbers = re.findall(r'\d+', label_idx)
776
+ label_idx = numbers[0] if len(numbers) > 0 else ""
777
+ label_pattern = rf'{re.escape(label_idx)}[a-zA-Z]+'
778
+
779
+
780
+ prod_smiles = prod
781
+ prod_mol = Chem.MolFromMolBlock(results[0]['reactions'][0]['products'][0]['molfile'])
782
+
783
+ # identify the atom indices of the R groups in the product tempalte
784
+ h_counter = 0
785
+ r_sites = {}
786
+ for idx, atom in enumerate(results[0]['reactions'][0]['products'][0]['atoms']):
787
+ sym = atom['atom_symbol']
788
+ if sym == '[H]':
789
+ h_counter += 1
790
+ if sym[0] == '[':
791
+ sym = sym[1:-1]
792
+ if sym[0] == 'R' and sym[1:].isdigit():
793
+ sym = sym[1:]+"*"
794
+ sym = f'[{sym}]'
795
+ if sym in RGROUP_SYMBOLS:
796
+ if sym not in r_sites:
797
+ r_sites[sym] = [idx-h_counter]
798
+ else:
799
+ r_sites[sym].append(idx-h_counter)
800
+
801
+ r_sites_reversed = {}
802
+ for sym in r_sites:
803
+ for pos in r_sites[sym]:
804
+ r_sites_reversed[pos] = sym
805
+
806
+ num_r_groups = len(r_sites_reversed)
807
+
808
+ #prepare the product template and get the associated mapping
809
+
810
+ prod_mol_to_query, prod_template_mol_query = get_atom_mapping(prod_mol, prod_smiles, r_sites_reversed = r_sites_reversed)
811
+
812
+ reactant_mols = []
813
+
814
+
815
+ #--------------process the reactants-----------------
816
+
817
+ reactant_information = {} #index of relevant reaction --> [[R group name, atom index of R group, atom index of R group connection], ...]
818
+
819
+ for idx, reactant in enumerate(reactants):
820
+ reactant_information[idx] = []
821
+ reactant_mols.append(Chem.MolFromSmiles(reactant['smiles']))
822
+ has_r = False
823
+
824
+ r_sites_reactant = {}
825
+
826
+ h_counter = 0
827
+
828
+ for a_idx, atom in enumerate(reactant['atoms']):
829
+
830
+ #go through all atoms and check if they are an R group, if so add it to reactant information
831
+ sym = atom['atom_symbol']
832
+ if sym == '[H]':
833
+ h_counter += 1
834
+ if sym[0] == '[':
835
+ sym = sym[1:-1]
836
+ if sym[0] == 'R' and sym[1:].isdigit():
837
+ sym = sym[1:]+"*"
838
+ sym = f'[{sym}]'
839
+ if sym in r_sites:
840
+ if reactant_mols[-1].GetNumAtoms()==1:
841
+ reactant_information[idx].append([sym, -1, -1])
842
+ else:
843
+ has_r = True
844
+ reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile'])
845
+ reactant_information[idx].append([sym, a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]])
846
+ r_sites_reactant[sym] = a_idx-h_counter
847
+ elif sym == '[1*]' and '[7*]' in r_sites:
848
+ if reactant_mols[-1].GetNumAtoms()==1:
849
+ reactant_information[idx].append(['[7*]', -1, -1])
850
+ else:
851
+ has_r = True
852
+ reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile'])
853
+ reactant_information[idx].append(['[7*]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]])
854
+ r_sites_reactant['[7*]'] = a_idx-h_counter
855
+ elif sym == '[7*]' and '[1*]' in r_sites:
856
+ if reactant_mols[-1].GetNumAtoms()==1:
857
+ reactant_information[idx].append(['[1*]', -1, -1])
858
+ else:
859
+ has_r = True
860
+ reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile'])
861
+ reactant_information[idx].append(['[1*]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]])
862
+ r_sites_reactant['[1*]'] = a_idx-h_counter
863
+
864
+ elif sym == '[1*]' and '[Rf]' in r_sites:
865
+ if reactant_mols[-1].GetNumAtoms()==1:
866
+ reactant_information[idx].append(['[Rf]', -1, -1])
867
+ else:
868
+ has_r = True
869
+ reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile'])
870
+ reactant_information[idx].append(['[Rf]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]])
871
+ r_sites_reactant['[Rf]'] = a_idx-h_counter
872
+
873
+ elif sym == '[Rf]' and '[1*]' in r_sites:
874
+ if reactant_mols[-1].GetNumAtoms()==1:
875
+ reactant_information[idx].append(['[1*]', -1, -1])
876
+ else:
877
+ has_r = True
878
+ reactant_mols[-1] = Chem.MolFromMolBlock(reactant['molfile'])
879
+ reactant_information[idx].append(['[1*]', a_idx-h_counter, [i.GetIdx() for i in reactant_mols[-1].GetAtomWithIdx(a_idx-h_counter).GetNeighbors()][0]])
880
+ r_sites_reactant['[1*]'] = a_idx-h_counter
881
+
882
+ r_sites_reversed_reactant = {r_sites_reactant[i]: i for i in r_sites_reactant}
883
+ # if the reactant had r groups, we had to use the molecule generated from the MolBlock.
884
+ # but the molblock may have unexpanded elemeents that are not R groups
885
+ # so we have to map back the r group indices in the molblock version to the full molecule generated by the smiles
886
+ # and adjust the indices of the r groups accordingly
887
+ if has_r:
888
+ #get the mapping
889
+ reactant_mol_to_query, _ = get_atom_mapping(reactant_mols[-1], reactant['smiles'], r_sites_reversed = r_sites_reversed_reactant)
890
+
891
+ #make the adjustment
892
+ for info in reactant_information[idx]:
893
+ info[1] = reactant_mol_to_query[info[1]]
894
+ info[2] = reactant_mol_to_query[info[2]]
895
+ reactant_mols[-1] = Chem.MolFromSmiles(reactant['smiles'])
896
+
897
+ #go through all the molecules in the coreference
898
+
899
+ clean_corefs(coref_results_dict, label_idx)
900
+
901
+ for other_prod in coref_results_dict:
902
+
903
+ #check if they match the product label regex
904
+ found_good_label = False
905
+ for parsed in coref_results_dict[other_prod]:
906
+ if re.search(label_pattern, parsed) and not found_good_label:
907
+ found_good_label = True
908
+ other_prod_mol = Chem.MolFromSmiles(other_prod)
909
+
910
+ if other_prod != prod_smiles and other_prod_mol is not None:
911
+
912
+ #check if there are R groups to be resolved in the target product
913
+
914
+ all_other_prod_mols = []
915
+
916
+ r_group_sub_pattern = re.compile('(?P<name>[RXY]\d?)[ ]*=[ ]*(?P<group>\w+)')
917
+
918
+ for parsed_labels in coref_results_dict[other_prod]:
919
+ res = r_group_sub_pattern.search(parsed_labels)
920
+
921
+ if res is not None:
922
+ all_other_prod_mols.append((expand_r_group_label_helper(res, coref_smiles_to_graphs, other_prod, molscribe), parsed + parsed_labels))
923
+
924
+ if len(all_other_prod_mols) == 0:
925
+ if other_prod_mol is not None:
926
+ all_other_prod_mols.append((other_prod_mol, parsed))
927
+
928
+
929
+
930
+
931
+ for other_prod_mol, parsed in all_other_prod_mols:
932
+
933
+ other_prod_frags = Chem.GetMolFrags(other_prod_mol, asMols = True)
934
+
935
+ for other_prod_frag in other_prod_frags:
936
+ substructs = other_prod_frag.GetSubstructMatches(prod_template_mol_query, uniquify = False)
937
+
938
+ if len(substructs)>0:
939
+ other_prod_mol = other_prod_frag
940
+ break
941
+ r_sites_reversed_new = {prod_mol_to_query[r]: r_sites_reversed[r] for r in r_sites_reversed}
942
+
943
+ queries = query_enumeration(prod_template_mol_query, r_sites_reversed_new, num_r_groups)
944
+
945
+ matched = False
946
+
947
+ for query in queries:
948
+ if not matched:
949
+ try:
950
+ matched = get_r_group_frags_and_substitute(other_prod_mol, query, reactant_mols, reactant_information, parsed, toreturn)
951
+ except:
952
+ pass
953
+
954
+ except:
955
+ pass
956
+
957
+
958
+ return toreturn
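The label handling above reduces the parsed product label to its numeric stem and builds a regex that matches sibling labels (e.g. product "5" matching derivatives "5a", "5b"). A minimal sketch with an assumed example label:

```python
import re

# Hypothetical example label; in the function above, label_idx comes from
# the product's coreference labels.
label_idx = '5'

# Keep only the numeric stem, then match the stem followed by letter suffixes.
numbers = re.findall(r'\d+', label_idx)
label_idx = numbers[0] if len(numbers) > 0 else ""
label_pattern = rf'{re.escape(label_idx)}[a-zA-Z]+'
# label_pattern is now '5[a-zA-Z]+', which matches '5a' but not '6a'.
```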
+
+
+ def associate_corefs(results, results_coref):
+     coref_smiles = {}
+     idx_pattern = r'\b\d+[a-zA-Z]{0,2}\b'
+     for result_coref in results_coref:
+         bboxes, corefs = result_coref['bboxes'], result_coref['corefs']
+         for coref in corefs:
+             mol, idt = coref[0], coref[1]
+             if len(bboxes[idt]['text']) > 0:
+                 for text in bboxes[idt]['text']:
+                     matches = re.findall(idx_pattern, text)
+                     for match in matches:
+                         coref_smiles[match] = bboxes[mol]['smiles']
+
+     for page in results:
+         for reactions in page['reactions']:
+             for reaction in reactions['reactions']:
+                 if 'Reactants' in reaction:
+                     if isinstance(reaction['Reactants'], tuple):
+                         if reaction['Reactants'][0] in coref_smiles:
+                             reaction['Reactants'] = (f'{reaction["Reactants"][0]} ({coref_smiles[reaction["Reactants"][0]]})', reaction['Reactants'][1], reaction['Reactants'][2])
+                     else:
+                         for idx, compound in enumerate(reaction['Reactants']):
+                             if compound[0] in coref_smiles:
+                                 reaction['Reactants'][idx] = (f'{compound[0]} ({coref_smiles[compound[0]]})', compound[1], compound[2])
+                 if 'Product' in reaction:
+                     if isinstance(reaction['Product'], tuple):
+                         if reaction['Product'][0] in coref_smiles:
+                             reaction['Product'] = (f'{reaction["Product"][0]} ({coref_smiles[reaction["Product"][0]]})', reaction['Product'][1], reaction['Product'][2])
+                     else:
+                         for idx, compound in enumerate(reaction['Product']):
+                             if compound[0] in coref_smiles:
+                                 reaction['Product'][idx] = (f'{compound[0]} ({coref_smiles[compound[0]]})', compound[1], compound[2])
+     return results
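The `idx_pattern` regex in `associate_corefs` pulls compound identifiers such as "3a" or "12bc" out of free text. A minimal illustration (the sentence is made up):

```python
import re

idx_pattern = r'\b\d+[a-zA-Z]{0,2}\b'
text = 'Compounds 3a and 12bc were obtained in 85% yield'
matches = re.findall(idx_pattern, text)
# Note that bare numbers such as '85' also match, since the letter
# suffix is optional ({0,2}).
```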
+
+
+ def expand_reactions_with_backout(initial_results, results_coref, molscribe):
+     idx_pattern = r'^\d+[a-zA-Z]{0,2}$'
+     for reactions, result_coref in zip(initial_results, results_coref):
+         if not reactions['reactions']:
+             continue
+         try:
+             backout_results = backout([reactions], [result_coref], molscribe)
+         except Exception:
+             continue
+         conditions = reactions['reactions'][0]['conditions']
+         idt_to_smiles = {}
+         if not backout_results:
+             continue
+
+         for reactants, products, idt in backout_results:
+             reactions['reactions'].append({
+                 'reactants': [{'category': '[Mol]', 'molfile': None, 'smiles': reactant} for reactant in reactants],
+                 'conditions': conditions[:],
+                 'products': [{'category': '[Mol]', 'molfile': None, 'smiles': product} for product in products]
+             })
+     return initial_results
+
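How `expand_reactions_with_backout` folds each `(reactants, products, label)` tuple from `backout` into the reactions list can be illustrated with dummy SMILES (the values below are made up for the example):

```python
# One hypothetical backout result: two reactants, one product, label '3a'.
backout_results = [(['CCO', 'CC(=O)O'], ['CC(=O)OCC'], '3a')]
conditions = []  # copied from the template reaction in the real function
reactions = {'reactions': []}

for reactants, products, idt in backout_results:
    reactions['reactions'].append({
        'reactants': [{'category': '[Mol]', 'molfile': None, 'smiles': r} for r in reactants],
        'conditions': conditions[:],  # shallow copy so entries are not shared
        'products': [{'category': '[Mol]', 'molfile': None, 'smiles': p} for p in products],
    })
```

Each expanded reaction reuses the template's conditions, so only the reactant and product SMILES differ between entries.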
examples/exp.png ADDED

Git LFS Details

  • SHA256: 3ce344ed33ff77f45d6e87a29e91426c3444ee9b58a8b10086ce3483a1ad2a2e
  • Pointer size: 131 Bytes
  • Size of remote file: 696 kB
examples/image.webp ADDED
examples/molecules1.png ADDED
examples/molecules2.png ADDED
examples/rdkit.png ADDED
examples/reaction0.png ADDED

Git LFS Details

  • SHA256: 84485f2bdd393b6195a0f904f720b5240dfe3c952bcbad8ba92de833f67ac2e6
  • Pointer size: 131 Bytes
  • Size of remote file: 769 kB
examples/reaction1.jpg ADDED
examples/reaction2.png ADDED
examples/reaction3.png ADDED
examples/reaction4.png ADDED

Git LFS Details

  • SHA256: 341a3b9f6b24b3fe3793186ec198cf1171ffb84bca6c0316052f25e17c0eeb55
  • Pointer size: 131 Bytes
  • Size of remote file: 232 kB
examples/reaction5.png ADDED