Spaces:
Sleeping
Sleeping
File size: 6,004 Bytes
1760662 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
import base64
import html
import json
import pickle
import re
import uuid

import pandas as pd
import streamlit as st
from streamlit_ketcher import st_ketcher

from SynTool.mcts.tree import Tree, TreeConfig
from SynTool.mcts.expansion import PolicyFunction
from SynTool.mcts.search import extract_tree_stats
from SynTool.utils.config import PolicyNetworkConfig
from SynTool.interfaces.visualisation import to_table, extract_routes
def download_button(object_to_download, download_filename, button_text, pickle_it=False):
    """Build an HTML snippet (a styled ``<a>`` link) that downloads ``object_to_download``.

    The payload is embedded in the link as a base64 ``data:`` URI, so the
    returned string must be rendered with
    ``st.markdown(..., unsafe_allow_html=True)``.

    NOTE(review): the original docstring started with "Issued from" but did
    not name the source this helper was adapted from; restore the
    attribution if it is known.

    Parameters
    ----------
    object_to_download : object
        Content to download. It is converted to bytes as follows:
        - ``pickle_it=True``: serialized with ``pickle.dumps``;
        - ``bytes`` / ``bytearray`` / ``memoryview``: used as-is;
        - ``pandas.DataFrame``: written as UTF-8 CSV without the index;
        - ``str``: UTF-8 encoded;
        - anything else: serialized with ``json.dumps`` (raises ``TypeError``
          if the object is not JSON-serializable).
    download_filename : str
        File name (with extension) proposed to the browser, e.g. ``mydata.csv``.
        It is HTML-escaped before being placed in the ``download`` attribute.
    button_text : str
        Plain text shown on the button; it is HTML-escaped.
    pickle_it : bool, optional
        If True, pickle ``object_to_download`` before encoding it.

    Returns
    -------
    str or None
        A ``<style>`` block followed by the ``<a download=...>`` anchor, or
        ``None`` if pickling failed (the error is displayed with ``st.write``).

    Examples
    --------
    download_button(your_df, 'YOUR_DF.csv', 'Click to download data!')
    download_button(your_str, 'YOUR_STRING.txt', 'Click to download text!')
    """
    if pickle_it:
        try:
            payload = pickle.dumps(object_to_download)
        # Depending on the object and Python version, unpicklable objects raise
        # PicklingError, TypeError (e.g. locks) or AttributeError (local objects).
        except (pickle.PicklingError, TypeError, AttributeError) as e:
            st.write(e)
            return None
    elif isinstance(object_to_download, (bytes, bytearray, memoryview)):
        payload = bytes(object_to_download)
    elif isinstance(object_to_download, pd.DataFrame):
        payload = object_to_download.to_csv(index=False).encode('utf-8')
    elif isinstance(object_to_download, str):
        payload = object_to_download.encode('utf-8')
    else:
        payload = json.dumps(object_to_download).encode('utf-8')

    b64 = base64.b64encode(payload).decode('ascii')

    # Random id so that several buttons on the same page get independent CSS
    # rules; the "dl-" prefix keeps it a valid CSS id (it must not start with
    # a digit and must not be empty).
    button_id = f"dl-{uuid.uuid4().hex}"
    custom_css = f"""
<style>
    #{button_id} {{
        background-color: rgb(255, 255, 255);
        color: rgb(38, 39, 48);
        text-decoration: none;
        border-radius: 4px;
        border-width: 1px;
        border-style: solid;
        border-color: rgb(230, 234, 241);
        border-image: initial;
    }}
    #{button_id}:hover {{
        border-color: rgb(246, 51, 102);
        color: rgb(246, 51, 102);
    }}
    #{button_id}:active {{
        box-shadow: none;
        background-color: rgb(246, 51, 102);
        color: white;
    }}
</style> """
    dl_link = custom_css + (
        f'<a download="{html.escape(download_filename)}" id="{button_id}" '
        f'href="data:application/octet-stream;base64,{b64}">'
        f'{html.escape(button_text)}</a><br></br>'
    )
    return dl_link
# --- Page setup and molecule input --------------------------------------
st.set_page_config(
    page_title="SynTool GUI",
    page_icon="🧪",
)
st.title("`SynTool GUI`")
st.write("*{Introduction text to be inserted here}*")

st.header('Molecule input')
st.write(
    "You can provide a molecular structure either by typing its SMILES string and pressing Enter, "
    "or by drawing it and clicking Apply."
)
DEFAULT_MOL = 'NC(CCCCB(O)O)(CCN1CCC(CO)C1)C(=O)O'
molecule = st.text_input("Molecule", DEFAULT_MOL)
# The Ketcher editor is seeded with the text input; `smile_code` holds the
# structure returned by the editor and is what the search below uses.
smile_code = st_ketcher(molecule)

# --- Search parameters and launch button ---------------------------------
st.header('Launch calculation')
st.write("If you modified the structure, please ensure you clicked on 'Apply' (bottom right of the molecular editor).")
st.markdown(f"The current molecule SMILES is: ``{smile_code}``")
max_depth = st.slider('Maximal number of reaction steps', min_value=2, max_value=9, value=9)
run_default = st.button('Launch and search a reaction path')
# Data files used by the search (paths relative to the working directory).
ranking_policy_weights_path = 'data/policy_network.ckpt'
reaction_rules_path = 'data/reaction_rules.pickle'
building_blocks_path = 'data/building_blocks.smi'


@st.cache_resource
def _load_policy_function(weights_path):
    """Build the ranking PolicyFunction from ``weights_path`` once and reuse it.

    Streamlit re-executes this whole script on every widget interaction;
    without caching, the policy would be rebuilt from the checkpoint on each
    rerun. The cache is keyed on ``weights_path``.
    """
    policy_config = PolicyNetworkConfig(weights_path=weights_path)
    return PolicyFunction(policy_config=policy_config)


policy_function = _load_policy_function(ranking_policy_weights_path)
# --- Run the retrosynthesis search and display the results ----------------
if run_default and not smile_code:
    # Nothing to search for (e.g. the molecule field was cleared); building a
    # Tree with an empty target would only fail further down.
    st.error("Please provide a molecule before launching the search.")
elif run_default:
    # The search runs with silent=True, so no progress is printed; only the
    # spinner is shown until it finishes.
    st.toast('Search started. Results will be shown below when it finishes.')
    tree_config = TreeConfig(
        search_strategy="expansion_first",
        evaluation_type="rollout",
        max_iterations=100,
        max_depth=max_depth,
        min_mol_size=0,
        init_node_value=0.5,
        ucb_type="uct",
        c_ucb=0.1,
        silent=True,
    )
    with st.spinner(text="Running with default parameters..."):
        tree = Tree(
            target=smile_code,
            tree_config=tree_config,
            reaction_rules_path=reaction_rules_path,
            building_blocks_path=building_blocks_path,
            policy_function=policy_function,
            value_function=None,
        )
        # Exhausting the tree iterator drives the search; the yielded items
        # are not needed, so they are discarded instead of collected in a list.
        for _ in tree:
            pass
        res = extract_tree_stats(tree, smile_code)

    st.header('Results')
    if res['found_paths']:
        st.write("Success!")
        st.subheader("Retrosynthetic Routes Report")
        st.markdown(to_table(tree, None, extended=True, integration=True), unsafe_allow_html=True)

        # One-row table of the search statistics, shown and offered as CSV.
        stats_df = pd.DataFrame(res, index=[0])
        st.subheader("Statistics")
        st.write(stats_df)

        st.subheader("Downloads")
        dl_html = download_button(to_table(tree, None, extended=True, integration=False),
                                  'results_syntool.html',
                                  'Download results as a HTML file')
        dl_csv = download_button(stats_df,
                                 'results_syntool.csv',
                                 'Download statistics as an Excel csv file')
        st.markdown(dl_html + dl_csv, unsafe_allow_html=True)
    else:
        st.write("Found no reaction path.")
# --- Restart ---------------------------------------------------------------
# Rerunning the script shows the page without results, because the launch
# button is not reported as pressed on that run.
st.divider()
st.header('Restart from the beginning?')
restart_requested = st.button("Restart")
if restart_requested:
    st.rerun()
|