"""Streamlit demo GUI for SynPlanner retrosynthetic planning.

The page lets the user enter or draw a target molecule, tune a few MCTS
planning options, run the search and inspect/download the results.
"""

import base64
import html
import pickle
import re
import uuid

import pandas as pd
import streamlit as st
from CGRtools.files import SMILESRead
from streamlit_ketcher import st_ketcher
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import disable_progress_bars

from synplan.mcts.expansion import PolicyNetworkFunction
from synplan.mcts.search import extract_tree_stats
from synplan.mcts.tree import Tree
from synplan.chem.utils import mol_from_smiles
from synplan.utils.config import TreeConfig, PolicyNetworkConfig
from synplan.utils.loading import load_reaction_rules, load_building_blocks
from synplan.utils.visualisation import generate_results_html, get_route_svg

disable_progress_bars("huggingface_hub")

smiles_parser = SMILESRead.create_parser(ignore=True)


def download_button(object_to_download, download_filename, button_text, pickle_it=False):
    """Build HTML markup for a styled link that downloads ``object_to_download``.

    The payload is embedded in the page as a base64 ``data:`` URI, so the
    returned markup must be rendered with
    ``st.markdown(..., unsafe_allow_html=True)``.

    NOTE(review): adapted from an external snippet whose attribution link was
    lost ("Issued from" with no URL) -- restore the source link if known.

    Params:
    ------
    object_to_download: The object to be downloaded. ``bytes`` are embedded
        as-is, a ``pd.DataFrame`` is serialised to CSV (without the index) and
        a ``str`` is UTF-8 encoded.
    download_filename (str): filename and extension of file,
        e.g. mydata.csv, some_txt_output.txt
    button_text (str): Text to display on download button
        (e.g. 'click here to download file')
    pickle_it (bool): If True, pickle the object before embedding it.

    Returns:
    -------
    (str | None): the ``<style>`` + ``<a>`` markup for the download button,
        or None if pickling failed (the error is written to the app).

    Examples:
    --------
    download_button(your_df, 'YOUR_DF.csv', 'Click to download data!')
    download_button(your_str, 'YOUR_STRING.txt', 'Click to download text!')
    """
    if pickle_it:
        try:
            object_to_download = pickle.dumps(object_to_download)
        except (pickle.PicklingError, TypeError, AttributeError) as e:
            # pickle.dumps reports unpicklable objects with any of these types.
            st.write(e)
            return None
    elif isinstance(object_to_download, pd.DataFrame):
        object_to_download = object_to_download.to_csv(index=False).encode("utf-8")

    # base64 works on bytes; text payloads are encoded first.
    if isinstance(object_to_download, str):
        object_to_download = object_to_download.encode("utf-8")
    b64 = base64.b64encode(object_to_download).decode()

    # CSS ids must not start with a digit, so strip digits from the random hex id.
    button_uuid = str(uuid.uuid4()).replace("-", "")
    button_id = re.sub(r"\d+", "", button_uuid)

    # Styling is scoped to this button's id so several buttons can coexist.
    custom_css = (
        "<style>"
        f"#{button_id} {{"
        "background-color: rgb(255, 255, 255); color: rgb(38, 39, 48); "
        "padding: 0.25em 0.38em; position: relative; text-decoration: none; "
        "border-radius: 4px; border: 1px solid rgb(230, 234, 241);"
        "}"
        f"#{button_id}:hover {{border-color: rgb(246, 51, 102); color: rgb(246, 51, 102);}}"
        f"#{button_id}:active {{box-shadow: none; background-color: rgb(246, 51, 102); "
        "color: white;}"
        "</style>"
    )
    dl_link = (
        f'<a download="{html.escape(download_filename)}" id="{button_id}" '
        f'href="data:file/txt;base64,{b64}">{html.escape(button_text)}</a><br><br>'
    )
    return custom_css + dl_link


st.set_page_config(page_title="SynPlanner GUI", page_icon="🧪", layout="wide")

HF_REPO_ID = "Laboratoire-De-Chemoinformatique/SynPlanner"


@st.cache_resource(show_spinner="Downloading SynPlanner data...")
def _download_from_hub(filename, subfolder):
    """Download ``subfolder/filename`` from the SynPlanner Hub repo once per server process."""
    return hf_hub_download(
        repo_id=HF_REPO_ID, filename=filename, subfolder=subfolder, local_dir="."
    )


# NOTE(review): st.cache_resource shares the returned objects between all
# sessions; this assumes Tree only reads building blocks, rules and the
# policy network -- confirm before relying on it.
@st.cache_resource(show_spinner=False)
def _load_building_blocks(path):
    """Load the building-block set once per server process."""
    return load_building_blocks(path, standardize=False)


@st.cache_resource(show_spinner=False)
def _load_reaction_rules(path):
    """Load the reaction rules once per server process."""
    return load_reaction_rules(path)


@st.cache_resource(show_spinner=False)
def _load_policy_function(weights_path):
    """Build the expansion policy from its checkpoint once per server process."""
    return PolicyNetworkFunction(policy_config=PolicyNetworkConfig(weights_path=weights_path))


intro_text = """
This is a demo of the graphical user interface of
[SynPlanner](https://github.com/Laboratoire-de-Chemoinformatique/SynPlanner/).
SynPlanner is a comprehensive tool for reaction data curation, rule extraction, model training and retrosynthetic planning.

More information on SynPlanner is available in the [official docs](https://synplanner.readthedocs.io/en/latest/index.html).
"""

st.title("`SynPlanner GUI`")
st.write(intro_text)

# --- Molecule input -------------------------------------------------------
st.header("Molecule input")
st.markdown(
    """
You can provide a molecular structure by either providing:
* SMILES string + Enter
* Draw it + Apply
"""
)
DEFAULT_MOL = "c1cc(ccc1Cl)C(CCO)NC(C2(CCN(CC2)c3c4cc[nH]c4ncn3)N)=O"
molecule = st.text_input("SMILES:", DEFAULT_MOL)
smile_code = st_ketcher(molecule)

if not smile_code:
    st.warning("Please provide a molecule (SMILES string or drawing).")
    st.stop()
try:
    target_molecule = mol_from_smiles(smile_code)
except Exception as err:  # UI boundary: report unreadable input instead of a traceback
    st.error(f"Could not read the molecule from SMILES `{smile_code}`: {err}")
    st.stop()

building_blocks_path = _download_from_hub("building_blocks_em_sa_ln.smi", "building_blocks")
ranking_policy_weights_path = _download_from_hub("ranking_policy_network.ckpt", "uspto/weights")
reaction_rules_path = _download_from_hub("uspto_reaction_rules.pickle", "uspto")

# --- Planning options -----------------------------------------------------
st.header("Launch calculation")
st.markdown(
    """If you modified the structure, please ensure you clicked on `Apply` (bottom right of the molecular editor)."""
)
st.markdown(f"The molecule SMILES is currently: ``{smile_code}``")

st.subheader("Planning options")
st.markdown(
    """
The description of each option can be found in the
[Retrosynthetic Planning Tutorial](https://synplanner.readthedocs.io/en/latest/tutorial_files/retrosynthetic_planning.html#Configuring-search-tree).
"""
)

col_options_1, col_options_2 = st.columns(2, gap="medium")
with col_options_1:
    search_strategy_input = st.selectbox(
        label="Search strategy", options=("Expansion first", "Evaluation first"), index=0
    )
    ucb_type = st.selectbox(label="UCB type", options=("uct", "puct", "value"), index=0)
    c_ucb = st.number_input("C coefficient of UCB", value=0.1, placeholder="Type a number...")
with col_options_2:
    max_iterations = st.slider(
        "Total number of MCTS iterations", min_value=50, max_value=300, value=100
    )
    max_depth = st.slider("Maximal number of reaction steps", min_value=3, max_value=9, value=6)
    min_mol_size = st.slider(
        "Minimum size of a molecule to be precursor", min_value=0, max_value=7, value=0
    )

search_strategy_translator = {
    "Expansion first": "expansion_first",
    "Evaluation first": "evaluation_first",
}
search_strategy = search_strategy_translator[search_strategy_input]

submit_planning = st.button("Start retrosynthetic planning")

# --- Planning and results -------------------------------------------------
if submit_planning:
    with st.status("Loading data"):
        st.write("Loading building blocks")
        building_blocks = _load_building_blocks(building_blocks_path)
        st.write("Loading reaction rules")
        reaction_rules = _load_reaction_rules(reaction_rules_path)
        st.write("Loading policy network")
        policy_function = _load_policy_function(ranking_policy_weights_path)

    tree_config = TreeConfig(
        search_strategy=search_strategy,
        evaluation_type="rollout",
        max_iterations=max_iterations,
        max_depth=max_depth,
        min_mol_size=min_mol_size,
        init_node_value=0.5,
        ucb_type=ucb_type,
        c_ucb=c_ucb,
        silent=True,
    )

    tree = Tree(
        target=target_molecule,
        config=tree_config,
        reaction_rules=reaction_rules,
        building_blocks=building_blocks,
        expansion_function=policy_function,
        evaluation_function=None,
    )

    mcts_progress_text = "Running retrosynthetic planning"
    mcts_bar = st.progress(0, text=mcts_progress_text)
    for step, _ in enumerate(tree, start=1):
        # st.progress takes floats in [0.0, 1.0]; clamp in case the tree
        # yields more steps than max_iterations.
        mcts_bar.progress(min(step / max_iterations, 1.0), text=mcts_progress_text)

    res = extract_tree_stats(tree, target_molecule)

    st.header("Results")
    if res["solved"]:
        st.balloons()
        st.subheader("Examples of found retrosynthetic routes")
        # At most three routes: every second distinct winning node id, in sorted order.
        for node_id in sorted(set(tree.winning_nodes))[::2][:3]:
            num_steps = len(tree.synthesis_route(node_id))
            route_score = round(tree.route_score(node_id), 3)
            st.image(
                get_route_svg(tree, node_id),
                caption=f"Route {node_id}; {num_steps} steps; Route score: {route_score}",
            )

        df = pd.DataFrame(res, index=[0])
        stat_col, download_col = st.columns(2, gap="medium")
        with stat_col:
            st.subheader("Statistics")
            st.write(df[["target_smiles", "num_routes", "num_nodes", "num_iter", "search_time"]])

        with download_col:
            st.subheader("Downloads")
            html_body = generate_results_html(tree, html_path=None, extended=True)
            dl_html = download_button(
                html_body, "results_synplanner.html", "Download results as a HTML file"
            )
            dl_csv = download_button(
                df, "results_synplanner.csv", "Download statistics as a csv file"
            )
            st.markdown(dl_html + dl_csv, unsafe_allow_html=True)
    else:
        st.write("Found no reaction path.")

st.divider()
st.header("Restart from the beginning?")
if st.button("Restart"):
    st.rerun()