|
import spaces |
|
import gradio as gr |
|
import torch |
|
import numpy as np |
|
import pandas as pd |
|
import random |
|
import io |
|
import imageio |
|
import os |
|
import tempfile |
|
import atexit |
|
import glob |
|
import csv |
|
from datetime import datetime |
|
import json |
|
|
|
from rdkit import Chem |
|
from rdkit.Chem import Draw |
|
|
|
from evaluator import Evaluator |
|
from loader import load_graph_decoder |
|
|
|
|
|
# Labeled training data: sampled by the "Randomize" button and used to
# derive the upper bounds of the property sliders.
known_labels = pd.read_csv('data/known_labels.csv')
# Reference polymers for the novelty check. (The misspelled name is kept
# because other code in this module refers to it.)
knwon_smiles = pd.read_csv('data/known_polymers.csv')

# Gas targets, in the order used throughout the app.
all_properties = ['CH4', 'CO2', 'H2', 'N2', 'O2']

# One property predictor per gas, loaded from evaluators/<gas>.joblib.
evaluators = {
    gas: Evaluator(f'evaluators/{gas}.joblib', gas)
    for gas in all_properties
}

# (min, max) of each gas column in the labeled data.
property_ranges = {
    gas: (known_labels[gas].min(), known_labels[gas].max())
    for gas in all_properties
}
|
|
|
|
|
# Scratch directory for generation GIFs; removed again at interpreter exit.
temp_dir = tempfile.mkdtemp(prefix="polymer_gifs_")


def cleanup_temp_files():
    """Delete every ``*.gif`` in ``temp_dir`` and then the directory itself.

    Errors are printed instead of raised, so a failure here never interrupts
    interpreter shutdown. Only GIFs are removed; if any other file is left
    behind, the directory removal fails and is reported.
    """
    gif_pattern = os.path.join(temp_dir, "*.gif")
    for gif_file in glob.glob(gif_pattern):
        try:
            os.remove(gif_file)
        except Exception as err:
            print(f"Error deleting {gif_file}: {err}")

    try:
        os.rmdir(temp_dir)
    except Exception as err:
        print(f"Error deleting temporary directory {temp_dir}: {err}")


atexit.register(cleanup_temp_files)
|
|
|
def random_properties():
    """Pick one labeled polymer at random and return its gas values.

    Returns:
        A plain Python list of the values in ``all_properties`` order.
    """
    sampled_row = known_labels.loc[:, all_properties].sample(n=1)
    return sampled_row.to_numpy()[0].tolist()
|
|
|
def load_model(model_choice):
    """Load a Graph DiT checkpoint and pair it with the preferred torch device.

    Args:
        model_choice: Internal model key (e.g. ``"model_labeled"``), passed
            to ``load_graph_decoder`` as ``path``.

    Returns:
        ``(model, device)``. The model is not moved to ``device`` here;
        ``generate_graph`` calls ``model.to(device)`` before sampling.
    """
    use_cuda = torch.cuda.is_available()
    target_device = torch.device("cuda" if use_cuda else "cpu")
    decoder = load_graph_decoder(path=model_choice)
    return decoder, target_device
|
|
|
|
|
# Folder holding the user-flagged polymer log; created at import time.
flagged_folder = "flagged"
os.makedirs(flagged_folder, exist_ok=True)


def save_interesting_log(smiles, properties, suggested_properties):
    """Append one user-flagged polymer as a row of ``flagged/log.csv``.

    The header row is written only when the log file does not exist yet.

    Args:
        smiles: SMILES string of the flagged polymer.
        properties: Requested target values in ``all_properties`` order
            (individual entries may be ``None``). ``None`` logs empty target
            columns instead of raising.
        suggested_properties: Mapping of gas name -> predicted value; each
            key is written to the ``suggested_<gas>`` column. ``None`` or a
            non-mapping value (e.g. the ``[]`` placeholder ``on_generate``
            emits after a failed generation) logs empty suggestion columns
            instead of raising.
    """
    from collections.abc import Mapping

    log_file = os.path.join(flagged_folder, "log.csv")
    fieldnames = (
        ['timestamp', 'smiles']
        + all_properties
        + [f'suggested_{prop}' for prop in all_properties]
    )

    log_data = {
        'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        'smiles': smiles,
    }
    # Explicit None check: `properties or []` would raise on numpy arrays.
    if properties is not None:
        log_data.update(zip(all_properties, properties))
    if isinstance(suggested_properties, Mapping):
        log_data.update(
            {f'suggested_{prop}': value for prop, value in suggested_properties.items()}
        )

    file_exists = os.path.isfile(log_file)
    # newline='' is required by the csv module; explicit UTF-8 keeps the log
    # independent of the host's locale encoding.
    with open(log_file, 'a', newline='', encoding='utf-8') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        if not file_exists:
            writer.writeheader()
        # Columns missing from log_data are written as '' (DictWriter restval).
        writer.writerow(log_data)
|
|
|
def _write_generation_gif(img_list, fps):
    """Save diffusion-chain frames as a looping GIF inside ``temp_dir``.

    Args:
        img_list: Sequence of PIL images returned by ``model.generate``
            (may be empty or ``None``).
        fps: Playback speed forwarded to ``imageio.mimsave``.

    Returns:
        Path of the written GIF, or ``None`` when there are no frames.
    """
    if not img_list:
        return None
    frames = [np.array(pil_img) for pil_img in img_list]
    # Repeat the final frame 10 times so the finished molecule stays on
    # screen before the animation loops.
    frames.extend([frames[-1]] * 10)
    # mkstemp reserves a unique file name atomically. A random 6-digit
    # suffix could collide between concurrent sessions and overwrite
    # another user's GIF.
    fd, gif_path = tempfile.mkstemp(prefix="polymer_gen_", suffix=".gif", dir=temp_dir)
    os.close(fd)
    imageio.mimsave(gif_path, frames, format='GIF', fps=fps, loop=0)
    return gif_path


@spaces.GPU
def generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps):
    """Generate a polymer conditioned on the requested gas properties.

    Args:
        CH4, CO2, H2, N2, O2: Target value for each gas. ``0``, ``''`` or a
            float NaN marks the gas as unconditioned (passed to the model as
            ``None``).
        guidance_scale: Forwarded to ``model.generate`` as ``guide_scale``.
        num_nodes: Graph size; ``0`` lets the model choose (passed as ``None``).
        repeating_time: Maximum number of generation attempts. An attempt
            that raises, yields no molecule, or yields a SMILES RDKit cannot
            parse is retried.
        model_state: ``(model, device)`` tuple produced by ``load_model``.
        num_chain_steps: Forwarded to ``model.generate`` as
            ``number_chain_steps`` (frames used for the GIF).
        fps: Frame rate of the generation GIF.

    Returns:
        Always a 5-tuple ``(markdown, image, gif_path, properties,
        suggested_properties)``.

        - On success, ``markdown`` starts with
          ``"**Generated polymer SMILES:**"`` (``on_generate`` relies on that
          prefix), ``image`` is the RDKit drawing, and
          ``suggested_properties`` maps each gas to its evaluator output.
        - On failure, ``image`` and ``suggested_properties`` are ``None``, and
          ``gif_path`` is the GIF of the most recent attempt that produced
          frames (or ``None``).
    """
    model, device = model_state

    def is_nan_like(x):
        return x == 0 or x == '' or (isinstance(x, float) and np.isnan(x))

    properties = [None if is_nan_like(prop) else prop for prop in (CH4, CO2, H2, N2, O2)]

    nan_gases = [gas for gas, prop in zip(all_properties, properties) if prop is None]
    nan_message = "The following gas properties were treated as NaN: " + (
        ", ".join(nan_gases) if nan_gases else "None"
    )

    num_nodes = None if num_nodes == 0 else num_nodes

    last_gif_path = None
    # int(): tolerate a float coming from the slider.
    for _ in range(int(repeating_time)):
        try:
            model.to(device)
            generated_molecule, img_list = model.generate(
                properties,
                guide_scale=guidance_scale,
                num_nodes=num_nodes,
                number_chain_steps=num_chain_steps,
            )

            attempt_gif = _write_generation_gif(img_list, fps)
            if attempt_gif is not None:
                last_gif_path = attempt_gif

            mol = Chem.MolFromSmiles(generated_molecule) if generated_molecule is not None else None
            if mol is None:
                # No molecule, or a SMILES RDKit cannot parse: count it as a
                # failed attempt and retry ("Repetition Until Success").
                continue

            standardized_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
            # NOTE(review): membership is tested against the raw 'SMILES'
            # column; exact only if that column is already RDKit-canonical.
            is_novel = standardized_smiles not in knwon_smiles['SMILES'].values
            novelty_status = "Novel (Not in Labeled Set)" if is_novel else "Not Novel (Exists in Labeled Set)"
            img = Draw.MolToImage(mol)

            suggested_properties = {
                prop: evaluator([standardized_smiles])[0]
                for prop, evaluator in evaluators.items()
            }
            suggested_properties_text = "\n".join(
                f"**Suggested {prop}:** {value:.2f}" for prop, value in suggested_properties.items()
            )

            return (
                f"**Generated polymer SMILES:** `{standardized_smiles}`\n\n"
                f"**{nan_message}**\n\n"
                f"**{novelty_status}**\n\n"
                f"**Suggested Properties:**\n{suggested_properties_text}",
                img,
                attempt_gif,
                properties,
                suggested_properties,
            )
        except Exception as e:
            # Best-effort retry: report the error and try the next attempt.
            print(f"Error in generation: {e}")

    return (
        f"**Generation failed:** Could not generate a valid molecule after {repeating_time} attempts.\n\n**{nan_message}**",
        None,
        last_gif_path,
        properties,
        None,
    )
|
|
|
def set_random_properties():
    """Randomize-button callback: one labeled row's gas values for the sliders."""
    sampled_values = random_properties()
    return sampled_values
|
|
|
|
|
# Internal checkpoint key -> label shown in the "Model Zoo" radio.
# Insertion order determines the order of the radio choices.
model_name_mapping = dict(
    model_all="Graph DiT (trained on labeled + unlabeled)",
    model_labeled="Graph DiT (trained on labeled)",
)
|
|
|
def numpy_to_python(obj):
    """Recursively convert NumPy values into plain Python objects for ``json.dumps``.

    Conversions:
        - ``np.bool_`` -> ``bool``, ``np.integer`` -> ``int``,
          ``np.floating`` -> ``float``
        - ``np.ndarray`` -> nested ``list`` (via ``tolist``)
        - ``list`` / ``tuple`` / ``dict`` are converted element-wise; tuples
          stay tuples, and dict keys are converted as well as values
        - anything else is returned unchanged.
    """
    # np.bool_ is not a subclass of np.integer and is not JSON-serializable.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, list):
        return [numpy_to_python(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(numpy_to_python(item) for item in obj)
    if isinstance(obj, dict):
        return {numpy_to_python(k): numpy_to_python(v) for k, v in obj.items()}
    return obj
|
|
|
def on_generate(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps):
    """Generate-button handler: run ``generate_graph`` and fan out its result.

    Returns a list of seven values:

    1. The first three values from ``generate_graph`` (result text, molecule
       image, GIF path).
    2. The SMILES, then the target properties and suggested properties, both
       JSON-encoded, for the hidden components. After a failure these are
       ``""``, ``"[]"``, ``"[]"``.
    3. The feedback button, which is enabled only after a successful
       generation.
    """
    outcome = generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps)
    visible_outputs = list(outcome[:3])
    message = outcome[0]

    if not message.startswith("**Generated polymer SMILES:**"):
        return visible_outputs + ["", "[]", "[]", gr.Button(interactive=False)]

    # The SMILES is the backtick-quoted span right after the success marker.
    smiles_marker = "**Generated polymer SMILES:** `"
    smiles = message.split(smiles_marker)[1].split("`")[0]
    targets_json = json.dumps(numpy_to_python(outcome[3]))
    suggestions_json = json.dumps(numpy_to_python(outcome[4]))
    return visible_outputs + [smiles, targets_json, suggestions_json, gr.Button(interactive=True)]
|
|
|
def process_feedback(checkbox_value, smiles, properties, suggested_properties):
    """Handle a click on the "interesting polymer" button.

    When ``checkbox_value`` is truthy, the polymer is appended to the flagged
    log via ``save_interesting_log``. ``properties`` and
    ``suggested_properties`` are accepted either already decoded or as JSON
    strings. Returns a visible Textbox carrying a thank-you message.
    """
    if not checkbox_value:
        return gr.Textbox(value="Thank you for your feedback!", visible=True)

    def _decoded(payload):
        return json.loads(payload) if isinstance(payload, str) else payload

    save_interesting_log(smiles, _decoded(properties), _decoded(suggested_properties))
    return gr.Textbox(
        value="Thank you for your feedback! This polymer has been saved to our interesting polymers log.",
        visible=True,
    )
|
|
|
|
|
def reset_feedback_button():
    """Return a disabled feedback button; wired to input-change events."""
    disabled_button = gr.Button(interactive=False)
    return disabled_button
|
|
|
|
|
with gr.Blocks(title="Polymer Design with GraphDiT") as iface:
    # Header: title plus links to the code and the paper.
    with gr.Row(elem_id="navbar"):
        gr.Markdown("""
        <div style="text-align: center;">
        <h1>🔗🔬 Polymer Design with GraphDiT</h1>
        <div style="display: flex; gap: 20px; justify-content: center; align-items: center; margin-top: 10px;">
        <a href="https://github.com/liugangcode/Graph-DiT" target="_blank" style="display: flex; align-items: center; gap: 5px; text-decoration: none; color: inherit;">
        <img src="https://img.icons8.com/ios-glyphs/30/000000/github.png" alt="GitHub" />
        <span>View Code</span>
        </a>
        <a href="https://arxiv.org/abs/2401.13858" target="_blank" style="text-decoration: none; color: inherit;">
        📄 View Paper
        </a>
        </div>
        </div>
        """)

    gr.Markdown("""
    ## Introduction

    Input the desired gas barrier properties for CH₄, CO₂, H₂, N₂, and O₂ to generate novel polymer structures. The results are visualized as molecular graphs and represented by SMILES strings if they are successfully generated. Note: Gas barrier values set to 0 will be treated as `NaN` (unconditionally). If the generation fails, please retry or increase the number of repetition attempts.
    """)

    model_choice = gr.Radio(
        choices=list(model_name_mapping.values()),
        label="Model Zoo",
        # Derived from the mapping so it cannot drift from the checkpoint
        # loaded into model_state below.
        value=model_name_mapping["model_labeled"],
    )

    with gr.Accordion("🔍 Model Description", open=False):
        gr.Markdown("""
        ### GraphDiT: Graph Diffusion Transformer

        GraphDiT is a graph diffusion model designed for targeted molecular generation. It employs a conditional diffusion process to iteratively refine molecular structures based on user-specified properties.

        We have collected a labeled polymer database for gas permeability from [Membrane Database](https://research.csiro.au/virtualscreening/membrane-database-polymer-gas-separation-membranes/). Additionally, we utilize unlabeled polymer structures from [PolyInfo](https://polymer.nims.go.jp/).

        The gas permeability ranges from 0 to over ten thousand, with only hundreds of labeled data points, making this task particularly challenging.

        We are actively working on improving the model. We welcome any feedback regarding model usage or suggestions for improvement.

        #### Currently, we have two variants of Graph DiT:

        - **Graph DiT (trained on labeled + unlabeled)**: This model uses both labeled and unlabeled data for training, potentially leading to more diverse/novel polymer generation.
        - **Graph DiT (trained on labeled)**: This model is trained exclusively on labeled data, which may result in higher validity but potentially less diverse/novel outputs.
        """)

    with gr.Accordion("📄 Citation", open=False):
        gr.Markdown("""
        If you find this model or interface useful, please cite the following paper:

        ```bibtex
        @article{graphdit2024,
          title={Graph Diffusion Transformers for Multi-Conditional Molecular Generation},
          author={Liu, Gang and Xu, Jiaxin and Luo, Tengfei and Jiang, Meng},
          journal={NeurIPS},
          year={2024},
        }
        ```
        """)

    # Initial model: the labeled-only checkpoint, matching the radio default.
    model_state = gr.State(lambda: load_model("model_labeled"))

    # Target-property sliders; each upper bound is the max seen in the labeled data.
    with gr.Row():
        CH4_input = gr.Slider(0, property_ranges['CH4'][1], value=2.5, label=f"CH₄ (Barrier) [0-{property_ranges['CH4'][1]:.1f}]")
        CO2_input = gr.Slider(0, property_ranges['CO2'][1], value=15.4, label=f"CO₂ (Barrier) [0-{property_ranges['CO2'][1]:.1f}]")
        H2_input = gr.Slider(0, property_ranges['H2'][1], value=21.0, label=f"H₂ (Barrier) [0-{property_ranges['H2'][1]:.1f}]")
        N2_input = gr.Slider(0, property_ranges['N2'][1], value=1.5, label=f"N₂ (Barrier) [0-{property_ranges['N2'][1]:.1f}]")
        O2_input = gr.Slider(0, property_ranges['O2'][1], value=2.8, label=f"O₂ (Barrier) [0-{property_ranges['O2'][1]:.1f}]")

    # Sampling controls.
    with gr.Row():
        guidance_scale = gr.Slider(1, 3, value=2, label="Guidance Scale from Properties")
        num_nodes = gr.Slider(0, 50, step=1, value=0, label="Number of Nodes (0 for Random, Larger Graphs Take More Time)")
        repeating_time = gr.Slider(1, 10, step=1, value=3, label="Repetition Until Success")
        num_chain_steps = gr.Slider(0, 499, step=1, value=50, label="Number of Diffusion Steps to Visualize (Larger Numbers Take More Time)")
        fps = gr.Slider(0.25, 10, step=0.25, value=5, label="Frames Per Second")

    with gr.Row():
        random_btn = gr.Button("🔀 Randomize Properties (from Labeled Data)")
        generate_btn = gr.Button("🚀 Generate Polymer")

    with gr.Row():
        result_text = gr.Textbox(label="📝 Generation Result")
        result_image = gr.Image(label="Final Molecule Visualization", type="pil")
        result_gif = gr.Image(label="Generation Process Visualization", type="filepath", format="gif")

    with gr.Row() as feedback_row:
        feedback_btn = gr.Button("🌟 I think this polymer is interesting!", visible=True, interactive=False)
        feedback_result = gr.Textbox(label="Feedback Result", visible=False)

    # The radio shows display labels; map them back to internal checkpoint keys.
    display_to_key = {label: key for key, label in model_name_mapping.items()}

    def switch_model(choice):
        return load_model(display_to_key[choice])

    model_choice.change(switch_model, inputs=[model_choice], outputs=[model_state])

    # Hidden carriers for the last successful generation, read back on feedback.
    hidden_smiles = gr.Textbox(visible=False)
    hidden_properties = gr.JSON(visible=False)
    hidden_suggested_properties = gr.JSON(visible=False)

    random_btn.click(
        set_random_properties,
        outputs=[CH4_input, CO2_input, H2_input, N2_input, O2_input],
    )

    generate_btn.click(
        on_generate,
        inputs=[CH4_input, CO2_input, H2_input, N2_input, O2_input, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps],
        outputs=[result_text, result_image, result_gif, hidden_smiles, hidden_properties, hidden_suggested_properties, feedback_btn],
    )

    # The feedback button always means "interesting", so the flag passed to
    # process_feedback is a constant hidden checkbox.
    flag_interesting = gr.Checkbox(value=True, visible=False)
    feedback_btn.click(
        process_feedback,
        inputs=[flag_interesting, hidden_smiles, hidden_properties, hidden_suggested_properties],
        outputs=[feedback_result],
    ).then(
        # Disable after one click so the same polymer is not logged twice.
        lambda: gr.Button(interactive=False),
        outputs=[feedback_btn],
    )

    # Changing a target property invalidates the last result, so disable flagging.
    for property_slider in (CH4_input, CO2_input, H2_input, N2_input, O2_input):
        property_slider.change(reset_feedback_button, outputs=[feedback_btn])
    random_btn.click(reset_feedback_button, outputs=[feedback_btn])


if __name__ == "__main__":
    iface.launch(share=False)