# Spaces: Running  (hosting-page status text captured along with the source; not part of the program)
import copy
import os
import random
import shutil
import time

import gradio as gr
import pandas as pd
import PIL
from PIL import Image
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import IPythonConsole

from inference import Inference
class DrugGENConfig:
    """Default settings shared by every model configuration in this file.

    Subclasses override individual attributes (model directory, atom limit,
    inhibitor data file); everything else falls back to the values below.
    """

    # --- Inference ---------------------------------------------------------
    submodel: str = "DrugGEN"
    inference_model: str = "/home/user/app/experiments/models/DrugGEN/"
    sample_num: int = 100

    # --- Data --------------------------------------------------------------
    inf_smiles: str = "/home/user/app/data/chembl_test.smi"
    train_smiles: str = "/home/user/app/data/chembl_train.smi"
    inf_batch_size: int = 1
    mol_data_dir: str = "/home/user/app/data"
    features: bool = False

    # --- Model -------------------------------------------------------------
    act: str = "relu"
    max_atom: int = 45
    dim: int = 128
    depth: int = 1
    heads: int = 8
    mlp_ratio: int = 3
    dropout: float = 0.0

    # --- Seeding -----------------------------------------------------------
    set_seed: bool = True
    seed: int = 10
    disable_correction: bool = False
class DrugGENAKT1Config(DrugGENConfig):
    """Configuration for the AKT1 model: own model directory and inhibitor file."""

    submodel = "DrugGEN"
    inference_model = "/home/user/app/experiments/models/DrugGEN-akt1/"
    train_drug_smiles = "/home/user/app/data/akt_train.smi"
    max_atom = 45
class DrugGENCDK2Config(DrugGENConfig):
    """Configuration for the CDK2 model: own model directory, inhibitor file and atom limit."""

    submodel = "DrugGEN"
    inference_model = "/home/user/app/experiments/models/DrugGEN-cdk2/"
    train_drug_smiles = "/home/user/app/data/cdk2_train.smi"
    max_atom = 38  # lower than the base-class limit of 45
class NoTargetConfig(DrugGENConfig):
    """Configuration for the model that is not tied to a specific protein target."""

    submodel = "NoTarget"
    inference_model = "/home/user/app/experiments/models/NoTarget/"
    # NOTE(review): points at the same file as the AKT1 configuration -
    # confirm this is the intended reference set for the NoTarget model.
    train_drug_smiles = "/home/user/app/data/akt_train.smi"
    max_atom = 45
# Maps each model label offered in the UI to one ready-made configuration instance.
model_configs = {
    label: config_cls()
    for label, config_cls in (
        ("DrugGEN-AKT1", DrugGENAKT1Config),
        ("DrugGEN-CDK2", DrugGENCDK2Config),
        ("DrugGEN-NoTarget", NoTargetConfig),
    )
}
def run_inference(mode: str, model_name: str, num_molecules: int, seed_num: str, custom_smiles: str):
    """Run a DrugGEN model and collect its metrics, output file and a preview image.

    Args:
        mode: "Custom Input SMILES" to feed the SMILES given in ``custom_smiles``
            to the model; any other value selects classical generation.
        model_name: Key into ``model_configs`` selecting the model.
        num_molecules: Number of molecules to generate (classical mode only,
            at most 200).
        seed_num: Seed as text (classical mode only). Empty or ``None`` means
            a random seed is drawn.
        custom_smiles: Newline-separated SMILES (custom mode only, at most 100
            non-blank lines).

    Returns:
        A tuple ``(molecule_image, new_path, basic_metrics, advanced_metrics)``:
        a grid image of up to 9 generated molecules (``None`` when none of the
        selected SMILES could be parsed), the path of the file holding all
        generated SMILES, and two single-row DataFrames of metrics.

    Raises:
        gr.Error: On an unknown model, empty or oversized input, or a
            non-integer seed.
    """
    if model_name not in model_configs:
        raise gr.Error(f"Unknown model: {model_name}")

    # Work on a per-request copy: the objects in model_configs are shared by
    # every request, so writing seed / sample_num / inf_smiles onto them would
    # leak one request's settings (e.g. a custom input file) into later ones.
    config = copy.copy(model_configs[model_name])
    temp_input_file = None

    if mode == "Custom Input SMILES":
        smiles_list = [s.strip() for s in (custom_smiles or "").splitlines() if s.strip()]
        if not smiles_list:
            raise gr.Error("Please provide at least one SMILES string.")
        if len(smiles_list) > 100:
            raise gr.Error("You have provided more than the allowed limit of 100 molecules. Please provide 100 or fewer.")

        # Custom mode always uses a random seed; it also names the input file.
        config.seed = random.randint(0, 10000)
        temp_input_file = f"custom_input{config.seed}.smi"
        with open(temp_input_file, "w") as f:
            f.writelines(f"{s}\n" for s in smiles_list)
        config.inf_smiles = temp_input_file
        config.sample_num = len(smiles_list)
    else:
        # Classical generation mode
        if num_molecules > 200:
            raise gr.Error("You have requested to generate more than the allowed limit of 200 molecules. Please reduce your request to 200 or fewer.")
        config.sample_num = int(num_molecules)

        seed_text = "" if seed_num is None else str(seed_num).strip()
        if seed_text == "":
            config.seed = random.randint(0, 10000)
        else:
            try:
                config.seed = int(seed_text)
            except ValueError:
                raise gr.Error("The seed must be an integer value!") from None

    # The inference output directory is named after the sub-model, not the UI label.
    target_model_name = "NoTarget" if model_name == "DrugGEN-NoTarget" else "DrugGEN"

    try:
        inferer = Inference(config)
        start_time = time.perf_counter()
        scores = inferer.inference()  # DataFrame with the metric columns read below
        et = time.perf_counter() - start_time
    finally:
        # Best-effort cleanup of the per-request input file.
        if temp_input_file is not None:
            try:
                os.remove(temp_input_file)
            except OSError:
                pass

    def first(column):
        """Return the value in the first row of ``scores[column]``."""
        return scores[column].iloc[0]

    basic_metrics = pd.DataFrame({
        "Validity": [first("validity")],
        "Uniqueness": [first("uniqueness")],
        "Novelty (Train)": [first("novelty")],
        "Novelty (Inference)": [first("novelty_test")],
        "Novelty (Real Inhibitors)": [first("drug_novelty")],
        "Runtime (s)": [round(et, 2)],
    })

    advanced_metrics = pd.DataFrame({
        "QED": [first("qed")],
        "SA Score": [first("sa")],
        "Internal Diversity": [first("IntDiv")],
        "SNN ChEMBL": [first("snn_chembl")],
        "SNN Real Inhibitors": [first("snn_drug")],
        "Average Length": [first("max_len")],
    })

    # Move the inference output next to the app so it can be offered for download.
    output_file_path = f'/home/user/app/experiments/inference/{target_model_name}/inference_drugs.txt'
    new_path = f'{target_model_name}_denovo_mols.smi'
    os.rename(output_file_path, new_path)
    with open(new_path) as f:
        # Keep every non-blank line, whether or not the file ends with a newline.
        generated_molecule_list = [line.strip() for line in f if line.strip()]

    # Pick up to 9 distinct entries for display; seeded so the pick is repeatable.
    rng = random.Random(config.seed)
    if len(generated_molecule_list) > 9:
        selected_smiles = rng.sample(generated_molecule_list, k=9)
    else:
        selected_smiles = generated_molecule_list

    # Parse each SMILES once and drop the ones RDKit cannot read.
    parsed = (Chem.MolFromSmiles(smiles) for smiles in selected_smiles)
    selected_molecules = [mol for mol in parsed if mol is not None]

    molecule_image = None
    if selected_molecules:
        drawOptions = Draw.rdMolDraw2D.MolDrawOptions()
        drawOptions.prepareMolsBeforeDrawing = False
        drawOptions.bondLineWidth = 0.5
        molecule_image = Draw.MolsToGridImage(
            selected_molecules,
            molsPerRow=3,
            subImgSize=(400, 400),
            maxMols=len(selected_molecules),
            returnPNG=False,
            drawOptions=drawOptions,
            highlightAtomLists=None,
            highlightBondLists=None,
        )

    return molecule_image, new_path, basic_metrics, advanced_metrics
# Gradio UI: left column holds documentation and inputs, right column holds results.
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    # Add custom CSS for styling
    gr.HTML("""
        <style>
        #metrics-container {
            border: 1px solid rgba(128, 128, 128, 0.3);
            border-radius: 8px;
            padding: 15px;
            margin-top: 15px;
            margin-bottom: 15px;
            background-color: rgba(255, 255, 255, 0.05);
        }
        </style>
    """)
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("# DrugGEN: Target Centric De Novo Design of Drug Candidate Molecules with Graph Generative Deep Adversarial Networks")
            # Both badges live inside one flex container so they sit side by side.
            gr.HTML("""
                <div style="display: flex; gap: 10px; margin-bottom: 15px;">
                    <!-- Nature Machine Intelligence badge -->
                    <a href="https://www.nature.com/articles/s42256-025-01082-y" target="_blank" style="text-decoration: none;">
                        <div style="
                            display: inline-block;
                            background-color: #006d5b;
                            color: #ffffff !important;
                            padding: 5px 10px;
                            border-radius: 5px;
                            font-size: 14px;">
                            <span style="font-weight: bold;">Nature</span> Machine Intelligence
                        </div>
                    </a>
                    <!-- GitHub badge -->
                    <a href="https://github.com/HUBioDataLab/DrugGEN" target="_blank" style="text-decoration: none;">
                        <div style="
                            display: inline-block;
                            background-color: #24292e;
                            color: #ffffff !important;
                            padding: 5px 10px;
                            border-radius: 5px;
                            font-size: 14px;">
                            <span style="font-weight: bold;">GitHub</span> Repository
                        </div>
                    </a>
                </div>
            """)
            with gr.Accordion("About DrugGEN Models", open=False):
                gr.Markdown("""
                    ### DrugGEN-AKT1
                    This model is designed to generate molecules targeting the human AKT1 protein (UniProt ID: P31749). Trained with [2,607 bioactive compounds](https://drive.google.com/file/d/1B2OOim5wrUJalixeBTDKXLHY8BAIvNh-/view?usp=drive_link).
                    Molecules larger than 45 heavy atoms were excluded.

                    ### DrugGEN-CDK2
                    This model is designed to generate molecules targeting the human CDK2 protein (UniProt ID: P24941). Trained with [1,817 bioactive compounds](https://drive.google.com/file/d/1C0CGFKx0I2gdSfbIEgUO7q3K2S1P9ksT/view?usp=drive_link).
                    Molecules larger than 38 heavy atoms were excluded.

                    ### DrugGEN-NoTarget
                    This is a general-purpose model that generates diverse drug-like molecules without targeting a specific protein. Trained with a general [ChEMBL dataset](https://drive.google.com/file/d/1oyybQ4oXpzrme_n0kbwc0-CFjvTFSlBG/view?usp=drive_link).
                    Molecules larger than 45 heavy atoms were excluded.
                    - Useful for exploring chemical space, generating diverse scaffolds, and creating molecules with drug-like properties.

                    For more details, see our [paper on *Nature Machine Intelligence*](https://www.nature.com/articles/s42256-025-01082-y).
                    You can also access the [full-text online read version (PDF)](https://www.nature.com/articles/s42256-025-01082-y.epdf?sharing_token=bBVnlnZabD8FPXynnlbEa9RgN0jAjWel9jnR3ZoTv0MOagXTn27vLVloLjCccXj0nzhwrmuR31o-E7RYhzLWbxKZbwmIHhjdjLgJG3Q2AOnkRhIv4r0-N4wiBgFqx9A8_Xo76WCGo_0-1AIjN5rizV12fj7TtkJuo_49OxmCcZY%3D).
                """)
            with gr.Accordion("Understanding the Metrics", open=False):
                gr.Markdown("""
                    ### Basic Metrics
                    - **Validity**: Percentage of generated molecules that are chemically valid
                    - **Uniqueness**: Percentage of unique molecules among valid ones
                    - **Runtime**: Time taken to generate or evaluate the molecules

                    ### Novelty Metrics
                    - **Novelty (Train)**: Percentage of molecules not found in the [training set](https://drive.google.com/file/d/1oyybQ4oXpzrme_n0kbwc0-CFjvTFSlBG/view?usp=drive_link). These molecules are used as inputs to
                      the generator during training.
                    - **Novelty (Inference)**: Percentage of molecules not found in the [inference set](https://drive.google.com/file/d/1vMGXqK1SQXB3Od3l80gMWvTEOjJ5MFXP/view?usp=share_link). These molecules are used as inputs
                      to the generator during inference.
                    - **Novelty (Real Inhibitors)**: Percentage of molecules not found in known inhibitors of the target protein (look at About DrugGEN Models for details). These molecules are used as inputs to the
                      discriminator during training.

                    ### Structural Metrics
                    - **Average Length**: Normalized average number of atoms in the generated molecules, normalized by the maximum number of atoms (e.g., 45 for AKT1/NoTarget, 38 for CDK2)
                    - **Mean Atom Type**: Average number of distinct atom types in the generated molecules
                    - **Internal Diversity**: Diversity within the generated set (higher is more diverse)

                    ### Drug-likeness Metrics
                    - **QED (Quantitative Estimate of Drug-likeness)**: Score from 0-1 measuring how drug-like a molecule is (higher is better)
                    - **SA Score (Synthetic Accessibility)**: Score from 1-10 indicating ease of synthesis (lower is better)

                    ### Similarity Metrics
                    - **SNN ChEMBL**: Similarity to [ChEMBL molecules](https://drive.google.com/file/d/1oyybQ4oXpzrme_n0kbwc0-CFjvTFSlBG/view?usp=drive_link) (higher means more similar to known drug-like compounds)
                    - **SNN Real Inhibitors**: Similarity to the real inhibitors of the selected target (higher means more similar to the real inhibitors)
                """)
            model_name = gr.Radio(
                choices=("DrugGEN-AKT1", "DrugGEN-CDK2", "DrugGEN-NoTarget"),
                value="DrugGEN-AKT1",
                label="Select Target Model",
                info="Choose which protein target or general model to use for molecule generation"
            )
            with gr.Tabs():
                with gr.TabItem("Classical Generation"):
                    num_molecules = gr.Slider(
                        minimum=10,
                        maximum=200,
                        value=100,
                        step=10,
                        label="Number of Molecules to Generate",
                        info="This space runs on a CPU, which may result in slower performance. Generating 100 molecules takes approximately 6 minutes. Therefore, we set a 200-molecule cap."
                    )
                    seed_num = gr.Textbox(
                        label="Random Seed (Optional)",
                        value="",
                        info="Set a specific seed for reproducible results, or leave empty for random generation"
                    )
                    classical_submit = gr.Button(
                        value="Generate Molecules",
                        variant="primary",
                        size="lg"
                    )
                with gr.TabItem("Custom Input SMILES"):
                    custom_smiles = gr.Textbox(
                        label="Input SMILES (one per line, maximum 100 molecules)",
                        info="This space runs on a CPU, which may result in slower performance. Generating 100 molecules takes approximately 6 minutes. Therefore, we set a 100-molecule cap.\n\n Molecules larger than allowed maximum length (45 for AKT1/NoTarget and 38 for CDK2) and allowed atom types are going to be filtered.\n\n Novelty (Inference) metric is going to be calculated using these input smiles.",
                        placeholder="Nc1ccccc1-c1nc(N)c2ccccc2n1\nO=C(O)c1ccccc1C(=O)c1cccc(Cl)c1\n...",
                        lines=10
                    )
                    custom_submit = gr.Button(
                        value="Generate Molecules using Custom SMILES",
                        variant="primary",
                        size="lg"
                    )
        with gr.Column(scale=2):
            basic_metrics_df = gr.Dataframe(
                headers=["Validity", "Uniqueness", "Novelty (Train)", "Novelty (Inference)", "Novelty (Real Inhibitors)", "Runtime (s)"],
                elem_id="basic-metrics"
            )
            # Header names match the column names of the DataFrame returned by run_inference.
            advanced_metrics_df = gr.Dataframe(
                headers=["QED", "SA Score", "Internal Diversity", "SNN ChEMBL", "SNN Real Inhibitors", "Average Length"],
                elem_id="advanced-metrics"
            )
            file_download = gr.File(label="Download All Generated Molecules (SMILES format)")
            image_output = gr.Image(label="Structures of Randomly Selected Generated Molecules",
                                    elem_id="molecule_display")
    gr.Markdown("### Created by the HUBioDataLab | [GitHub](https://github.com/HUBioDataLab/DrugGEN) | [Paper](https://arxiv.org/abs/2302.07868)")

    # Set up the click actions for each tab. Inputs that a mode does not use
    # are supplied as constant gr.State placeholders so both buttons can share
    # run_inference's five-argument signature.
    classical_submit.click(
        run_inference,
        inputs=[gr.State("Generate Molecules"), model_name, num_molecules, seed_num, gr.State("")],
        outputs=[
            image_output,
            file_download,
            basic_metrics_df,
            advanced_metrics_df
        ],
        api_name="inference_classical"
    )
    custom_submit.click(
        run_inference,
        inputs=[gr.State("Custom Input SMILES"), model_name, gr.State(0), gr.State(""), custom_smiles],
        outputs=[
            image_output,
            file_download,
            basic_metrics_df,
            advanced_metrics_df
        ],
        api_name="inference_custom"
    )

demo.queue()
demo.launch()