tagirshin commited on
Commit
1adc128
1 Parent(s): ca8cff5

first version of app

Browse files
Files changed (1) hide show
  1. app.py +213 -131
app.py CHANGED
@@ -4,7 +4,6 @@ import pandas as pd
4
  import pickle
5
  import pygad
6
 
7
- from tqdm.auto import tqdm
8
  from VQGAE.models import VQGAE, OrderingNetwork
9
  from CGRtools.containers import QueryContainer
10
  from VQGAE.utils import frag_counts_to_inds, restore_order, decode_molecules
@@ -55,44 +54,6 @@ def tanimoto_kernel(x, y):
55
  return result
56
 
57
 
58
- def rescoring(vqgae_latents):
59
- frag_counts = np.array(vqgae_latents)
60
- rf_scores = rf_model.predict_proba(frag_counts)[:, 1]
61
- similarity_scores = tanimoto_kernel(frag_counts, X).max(-1)
62
-
63
- frag_inds = frag_counts_to_inds(frag_counts, max_atoms=51)
64
- _, ordering_scores = restore_order(frag_inds, ordering_model)
65
- return rf_scores.tolist(), similarity_scores.tolist(), ordering_scores
66
-
67
-
68
- def fitness_func_batch(ga_instance, solutions, solutions_indices):
69
- frag_counts = np.array(solutions)
70
-
71
- # prediction of activity by random forest
72
- rf_score = rf_model.predict_proba(frag_counts)[:, 1]
73
-
74
- # size penalty if molecule too small
75
- mol_size = frag_counts.sum(-1).astype(np.int64)
76
- size_penalty = np.where(mol_size < 18, -1.0, 0.)
77
-
78
- # adding dissimilarity so it generates different solutions
79
- dissimilarity_score = 1 - tanimoto_kernel(frag_counts, X).max(-1)
80
- dissimilarity_score += np.where(dissimilarity_score == 0, -5, 0)
81
-
82
- # prediction of ordering score
83
- frag_inds = frag_counts_to_inds(frag_counts, max_atoms=51)
84
- _, ordering_scores = restore_order(frag_inds, ordering_model)
85
- ordering_scores = np.array(ordering_scores)
86
-
87
- # full fitness function
88
- fitness = 0.5 * rf_score + 0.3 * dissimilarity_score + size_penalty + 0.2 * ordering_scores
89
- return fitness.tolist()
90
-
91
-
92
- def on_generation_progress(ga):
93
- pbar.update(1)
94
-
95
-
96
  @st.cache_data
97
  def load_data(batch_size):
98
  X = np.load("saved_model/tubulin_qsar_class_train_data_vqgae.npz")["x"]
@@ -122,97 +83,218 @@ st.title('Inverse QSAR of Tubulin inhibitors in colchicine site with VQGAE')
122
  data_load_state = st.text('Loading data...')
123
  batch_size = 500
124
  X, Y, rf_model, vqgae_model, ordering_model = load_data(batch_size)
125
-
126
  data_load_state.text("Done! (using st.cache_data)")
127
 
128
- # initial_pop = X
129
- #
130
- # num_parents_mating = int(initial_pop.shape[0] * 0.33 // 10 * 10)
131
- # keep_parents = int(num_parents_mating * 0.66 // 10 * 10)
132
- # print(num_parents_mating, keep_parents)
133
- #
134
- # num_generations = 30
135
- # with tqdm(total=num_generations) as pbar:
136
- # ga_instance = pygad.GA(
137
- # fitness_func=fitness_func_batch,
138
- # on_generation=on_generation_progress,
139
- # initial_population=initial_pop,
140
- # num_genes=initial_pop.shape[-1],
141
- # fitness_batch_size=batch_size,
142
- # num_generations=num_generations,
143
- # num_parents_mating=num_parents_mating,
144
- # parent_selection_type="rws",
145
- # crossover_type="single_point",
146
- # mutation_type="adaptive",
147
- # mutation_percent_genes=[10, 5],
148
- # # https://pygad.readthedocs.io/en/latest/pygad.html#use-adaptive-mutation-in-pygad
149
- # save_best_solutions=False,
150
- # save_solutions=True,
151
- # keep_elitism=0, # turn it off to make keep_parents work
152
- # keep_parents=keep_parents, # 2/3 of num_parents_mating
153
- # # parallel_processing=['process', 5],
154
- # suppress_warnings=True,
155
- # random_seed=42,
156
- # gene_type=int
157
- # )
158
- # ga_instance.run()
159
- #
160
- # solutions = ga_instance.solutions
161
- # solutions = list(set(tuple(s) for s in solutions))
162
- # print(len(solutions))
163
- #
164
- # scores = {"rf_score": [], "similarity_score": [], "ordering_score": []}
165
- # for i in tqdm(range(len(solutions) // 100 + 1)):
166
- # solution = solutions[i * 100: (i + 1) * 100]
167
- # rf_score, similarity_score, ordering_score = rescoring(solution)
168
- # scores["rf_score"].extend(rf_score)
169
- # scores["similarity_score"].extend(similarity_score)
170
- # scores["ordering_score"].extend(ordering_score)
171
- #
172
- # sc_df = pd.DataFrame(scores)
173
- #
174
- # chosen_gen = sc_df[(sc_df["similarity_score"] < 0.95) & (sc_df["rf_score"] > 0.5) & (sc_df["ordering_score"] > 0.7)]
175
- #
176
- # chosen_ids = chosen_gen.index.to_list()
177
- # chosen_solutions = np.array([solutions[ind] for ind in chosen_ids])
178
- # gen_frag_inds = frag_counts_to_inds(chosen_solutions, max_atoms=51)
179
- #
180
- # gen_molecules = []
181
- # results = {"score": [], "valid": []}
182
- # for i in tqdm(range(gen_frag_inds.shape[0] // batch_size + 1)):
183
- # inputs = gen_frag_inds[i * batch_size: (i + 1) * batch_size]
184
- # canon_order_inds, scores = restore_order(
185
- # frag_inds=inputs,
186
- # ordering_model=ordering_model,
187
- # )
188
- # molecules, validity = decode_molecules(
189
- # ordered_frag_inds=canon_order_inds,
190
- # vqgae_model=vqgae_model
191
- # )
192
- # gen_molecules.extend(molecules)
193
- # results["score"].extend(scores)
194
- # results["valid"].extend([1 if i else 0 for i in validity])
195
- #
196
- # gen_stats = pd.DataFrame(results)
197
- # full_stats = pd.concat([chosen_gen.reset_index(), gen_stats], axis=1, ignore_index=False)
198
- # valid_gen_stats = full_stats[full_stats.valid == 1]
199
- # valid_gen_mols = []
200
- # for i, record in zip(list(valid_gen_stats.index), valid_gen_stats.to_dict("records")):
201
- # mol = gen_molecules[i]
202
- # mol.meta.update({
203
- # "rf_score": record["rf_score"],
204
- # "similarity_score": record["similarity_score"],
205
- # "ordering_score": record["ordering_score"],
206
- # })
207
- # valid_gen_mols.append(mol)
208
- #
209
- # filtered_gen_mols = []
210
- # for mol in valid_gen_mols:
211
- # is_frag = allene < mol or peroxide_charge < mol or peroxide < mol
212
- # is_macro = False
213
- # for ring in mol.sssr:
214
- # if len(ring) > 8 or len(ring) < 4:
215
- # is_macro = True
216
- # break
217
- # if not is_frag and not is_macro:
218
- # filtered_gen_mols.append(mol)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import pickle
5
  import pygad
6
 
 
7
  from VQGAE.models import VQGAE, OrderingNetwork
8
  from CGRtools.containers import QueryContainer
9
  from VQGAE.utils import frag_counts_to_inds, restore_order, decode_molecules
 
54
  return result
55
 
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  @st.cache_data
58
  def load_data(batch_size):
59
  X = np.load("saved_model/tubulin_qsar_class_train_data_vqgae.npz")["x"]
 
83
  data_load_state = st.text('Loading data...')
84
  batch_size = 500
85
  X, Y, rf_model, vqgae_model, ordering_model = load_data(batch_size)
 
86
  data_load_state.text("Done! (using st.cache_data)")
87
 
88
+ num_generations = st.slider(
89
+ 'Number of generations for GA',
90
+ min_value=3,
91
+ max_value=40,
92
+ value=5
93
+ )
94
+
95
+ parent_selection_type = st.selectbox(
96
+ label='Parent selection type',
97
+ options=(
98
+ 'Steady-state selection',
99
+ 'Roulette wheel selection',
100
+ 'Stochastic universal selection',
101
+ 'Rank selection',
102
+ 'Random selection',
103
+ 'Tournament selection'
104
+ ),
105
+ index=1
106
+ )
107
+
108
+ parent_selection_translator = {
109
+ "Steady-state selection": "sss",
110
+ "Roulette wheel selection": "rws",
111
+ "Stochastic universal selection": "sus",
112
+ "Rank selection": "rank",
113
+ "Random selection": "random",
114
+ "Tournament selection": "tournament",
115
+ }
116
+
117
+ parent_selection_type = parent_selection_translator[parent_selection_type]
118
+
119
+ crossover_type = st.selectbox(
120
+ label='Crossover type',
121
+ options=(
122
+ 'Single point',
123
+ 'Two points',
124
+ ),
125
+ index=0
126
+ )
127
+
128
+ crossover_translator = {
129
+ "Single point": "single_point",
130
+ "Two points": "two_points",
131
+ }
132
+
133
+ crossover_type = crossover_translator[crossover_type]
134
+
135
+ num_parents_mating = st.slider(
136
+ 'Number of generations for GA',
137
+ min_value=1,
138
+ max_value=X.shape[0],
139
+ value=int(X.shape[0] * 0.33 // 10 * 10)
140
+ )
141
+
142
+ keep_parents = st.slider(
143
+ 'Number of generations for GA',
144
+ min_value=1,
145
+ max_value=num_parents_mating,
146
+ value=int(num_parents_mating * 0.66 // 10 * 10) # 2/3 of num_parents_mating
147
+ )
148
+
149
+ use_ordering_score = st.toggle('Use ordering score', value=True)
150
+
151
+ random_seed = int(st.number_input("Random seed", value=42, placeholder="Type a number..."))
152
+
153
+
154
+ def fitness_func_batch(ga_instance, solutions, solutions_indices):
155
+ frag_counts = np.array(solutions)
156
+
157
+ # prediction of activity by random forest
158
+ rf_score = rf_model.predict_proba(frag_counts)[:, 1]
159
+
160
+ # size penalty if molecule too small
161
+ mol_size = frag_counts.sum(-1).astype(np.int64)
162
+ size_penalty = np.where(mol_size < 18, -1.0, 0.)
163
+
164
+ # adding dissimilarity so it generates different solutions
165
+ dissimilarity_score = 1 - tanimoto_kernel(frag_counts, X).max(-1)
166
+ dissimilarity_score += np.where(dissimilarity_score == 0, -5, 0)
167
+
168
+ # full fitness function
169
+ fitness = 0.5 * rf_score + 0.3 * dissimilarity_score + size_penalty
170
+
171
+ # prediction of ordering score
172
+ if use_ordering_score:
173
+ frag_inds = frag_counts_to_inds(frag_counts, max_atoms=51)
174
+ _, ordering_scores = restore_order(frag_inds, ordering_model)
175
+ ordering_scores = np.array(ordering_scores)
176
+ fitness += 0.2 * ordering_scores
177
+
178
+ return fitness.tolist()
179
+
180
+
181
+ def on_generation_progress(ga):
182
+ global ga_progress
183
+ ga_progress = ga_progress + 1
184
+ ga_bar.progress(ga_progress // num_generations * 100, text=ga_progress_text)
185
+
186
+
187
+ if st.button("Start optimisation"):
188
+ ga_instance = pygad.GA(
189
+ fitness_func=fitness_func_batch,
190
+ on_generation=on_generation_progress,
191
+ initial_population=X,
192
+ num_genes=X.shape[-1],
193
+ fitness_batch_size=batch_size,
194
+ num_generations=num_generations,
195
+ num_parents_mating=num_parents_mating,
196
+ parent_selection_type=parent_selection_type,
197
+ crossover_type=crossover_type,
198
+ mutation_type="adaptive",
199
+ mutation_percent_genes=[10, 5],
200
+ # https://pygad.readthedocs.io/en/latest/pygad.html#use-adaptive-mutation-in-pygad
201
+ save_best_solutions=False,
202
+ save_solutions=True,
203
+ keep_elitism=0, # turn it off to make keep_parents work
204
+ keep_parents=keep_parents,
205
+ suppress_warnings=True,
206
+ random_seed=random_seed,
207
+ gene_type=int
208
+ )
209
+
210
+ ga_progress = 0
211
+ ga_progress_text = "Genetic optimisation in progress. Please wait."
212
+ ga_bar = st.progress(ga_progress // num_generations * 100, text=ga_progress_text)
213
+ ga_instance.run()
214
+
215
+ with st.spinner('Getting unique solutions'):
216
+ unique_solutions = list(set(tuple(s) for s in ga_instance.solutions))
217
+ st.success(f'{len(unique_solutions)} solutions were obtained')
218
+
219
+ scores = {
220
+ "rf_score": [],
221
+ "similarity_score": []
222
+ }
223
+
224
+ if use_ordering_score:
225
+ scores["ordering_score"] = []
226
+
227
+ rescoring_progress = 0
228
+ rescoring_progress_text = "Rescoring obtained solutions"
229
+ rescoring_bar = st.progress(0, text=rescoring_progress_text)
230
+ total_rescoring_steps = len(unique_solutions) // batch_size + 1
231
+ for i in range(total_rescoring_steps):
232
+ vqgae_latents = unique_solutions[i * batch_size: (i + 1) * batch_size]
233
+ frag_counts = np.array(vqgae_latents)
234
+ rf_scores = rf_model.predict_proba(frag_counts)[:, 1]
235
+ similarity_scores = tanimoto_kernel(frag_counts, X).max(-1)
236
+ scores["rf_score"].extend(rf_scores.tolist())
237
+ scores["similarity_score"].extend(similarity_scores.tolist())
238
+ if use_ordering_score:
239
+ frag_inds = frag_counts_to_inds(frag_counts, max_atoms=51)
240
+ _, ordering_scores = restore_order(frag_inds, ordering_model)
241
+ scores["ordering_score"].extend(ordering_scores)
242
+ rescoring_bar.progress(i // total_rescoring_steps * 100, text=rescoring_progress_text)
243
+
244
+ sc_df = pd.DataFrame(scores)
245
+
246
+ if use_ordering_score:
247
+ chosen_gen = sc_df[(sc_df["similarity_score"] < 0.95) & (sc_df["rf_score"] > 0.5) & (sc_df["ordering_score"] > 0.7)]
248
+ else:
249
+ chosen_gen = sc_df[
250
+ (sc_df["similarity_score"] < 0.95) & (sc_df["rf_score"] > 0.5)]
251
+
252
+ chosen_ids = chosen_gen.index.to_list()
253
+ chosen_solutions = np.array([unique_solutions[ind] for ind in chosen_ids])
254
+ gen_frag_inds = frag_counts_to_inds(chosen_solutions, max_atoms=51)
255
+ st.info(f'The number of chosen solutions is {gen_frag_inds.shape[0]}', icon="ℹ️")
256
+
257
+ gen_molecules = []
258
+ results = {"smiles": [], "ordering_score": [], "validity": []}
259
+ decoding_progress = 0
260
+ decoding_progress_text = "Decoding chosen solutions"
261
+ decoding_bar = st.progress(0, text=decoding_progress_text)
262
+ total_decoding_steps = gen_frag_inds.shape[0] // batch_size + 1
263
+ for i in range(total_decoding_steps):
264
+ inputs = gen_frag_inds[i * batch_size: (i + 1) * batch_size]
265
+ canon_order_inds, scores = restore_order(
266
+ frag_inds=inputs,
267
+ ordering_model=ordering_model,
268
+ )
269
+ molecules, validity = decode_molecules(
270
+ ordered_frag_inds=canon_order_inds,
271
+ vqgae_model=vqgae_model
272
+ )
273
+ gen_molecules.extend(molecules)
274
+ results["smiles"].extend([str(molecule) for molecule in molecules])
275
+ results["ordering_score"].extend(scores)
276
+ results["validity"].extend([1 if i else 0 for i in validity])
277
+ decoding_bar.progress(i // total_decoding_steps * 100, text=rescoring_progress_text)
278
+
279
+ gen_stats = pd.DataFrame(results)
280
+ full_stats = pd.concat([gen_stats, chosen_gen[["similarity_score", "rf_score"]].reset_index(), ], axis=1, ignore_index=False)
281
+
282
+ st.dataframe(full_stats)
283
+
284
+ # valid_gen_stats = full_stats[full_stats.valid == 1]
285
+ #
286
+ # valid_gen_mols = []
287
+ # for i, record in zip(list(valid_gen_stats.index), valid_gen_stats.to_dict("records")):
288
+ # mol = gen_molecules[i]
289
+ # valid_gen_mols.append(mol)
290
+ #
291
+ # filtered_gen_mols = []
292
+ # for mol in valid_gen_mols:
293
+ # is_frag = allene < mol or peroxide_charge < mol or peroxide < mol
294
+ # is_macro = False
295
+ # for ring in mol.sssr:
296
+ # if len(ring) > 8 or len(ring) < 4:
297
+ # is_macro = True
298
+ # break
299
+ # if not is_frag and not is_macro:
300
+ # filtered_gen_mols.append(mol)