tagirshin commited on
Commit
3ae9d07
1 Parent(s): 39706b3

added form to keep results

Browse files
Files changed (1) hide show
  1. app.py +178 -176
app.py CHANGED
@@ -56,7 +56,8 @@ def tanimoto_kernel(x, y):
56
 
57
  def fitness_func_batch(ga_instance, solutions, solutions_indices):
58
  frag_counts = np.array(solutions)
59
- st.write(frag_counts.shape)
 
60
 
61
  # prediction of activity by random forest
62
  rf_score = rf_model.predict_proba(frag_counts)[:, 1]
@@ -84,6 +85,7 @@ def fitness_func_batch(ga_instance, solutions, solutions_indices):
84
 
85
  def on_generation_progress(ga):
86
  global ga_progress
 
87
  ga_progress = ga_progress + 1
88
  ga_bar.progress(ga_progress // num_generations * 100, text=ga_progress_text)
89
 
@@ -119,184 +121,184 @@ X, Y, rf_model, vqgae_model, ordering_model = load_data(batch_size)
119
  assert X.shape == (603, 4096)
120
 
121
  with st.sidebar:
122
- num_generations = st.slider(
123
- 'Number of generations for GA',
124
- min_value=3,
125
- max_value=40,
126
- value=5
127
- )
128
-
129
- parent_selection_type = st.selectbox(
130
- label='Parent selection type',
131
- options=(
132
- 'Steady-state selection',
133
- 'Roulette wheel selection',
134
- 'Stochastic universal selection',
135
- 'Rank selection',
136
- 'Random selection',
137
- 'Tournament selection'
138
- ),
139
- index=1
140
- )
141
-
142
- parent_selection_translator = {
143
- "Steady-state selection": "sss",
144
- "Roulette wheel selection": "rws",
145
- "Stochastic universal selection": "sus",
146
- "Rank selection": "rank",
147
- "Random selection": "random",
148
- "Tournament selection": "tournament",
149
- }
150
-
151
- parent_selection_type = parent_selection_translator[parent_selection_type]
152
-
153
- crossover_type = st.selectbox(
154
- label='Crossover type',
155
- options=(
156
- 'Single point',
157
- 'Two points',
158
- ),
159
- index=0
160
- )
161
 
162
- crossover_translator = {
163
- "Single point": "single_point",
164
- "Two points": "two_points",
165
- }
 
 
 
 
 
 
 
 
166
 
167
- crossover_type = crossover_translator[crossover_type]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
- num_parents_mating = st.slider(
170
- 'Number of parents mating',
171
- min_value=1,
172
- max_value=X.shape[0],
173
- value=int(X.shape[0] * 0.33 // 10 * 10)
174
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
- keep_parents = st.slider(
177
- 'Number of parents kept',
178
- min_value=1,
179
- max_value=num_parents_mating,
180
- value=int(num_parents_mating * 0.66 // 10 * 10) # 2/3 of num_parents_mating
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  )
182
-
183
- use_ordering_score = st.toggle('Use ordering score', value=True)
184
-
185
- random_seed = int(st.number_input("Random seed", value=42, placeholder="Type a number..."))
186
- start_optimisation = st.button("Start optimisation")
187
-
188
-
189
- if start_optimisation:
190
- ga_instance = pygad.GA(
191
- fitness_func=fitness_func_batch,
192
- on_generation=on_generation_progress,
193
- initial_population=X,
194
- num_genes=X.shape[-1],
195
- fitness_batch_size=batch_size,
196
- num_generations=num_generations,
197
- num_parents_mating=num_parents_mating,
198
- parent_selection_type=parent_selection_type,
199
- crossover_type=crossover_type,
200
- mutation_type="adaptive",
201
- mutation_percent_genes=[10, 5],
202
- # https://pygad.readthedocs.io/en/latest/pygad.html#use-adaptive-mutation-in-pygad
203
- save_best_solutions=False,
204
- save_solutions=True,
205
- keep_elitism=0, # turn it off to make keep_parents work
206
- keep_parents=keep_parents,
207
- suppress_warnings=True,
208
- random_seed=random_seed,
209
- gene_type=int
210
  )
211
-
212
- ga_progress = 0
213
- ga_progress_text = "Genetic optimisation in progress. Please wait."
214
- ga_bar = st.progress(ga_progress // num_generations * 100, text=ga_progress_text)
215
- ga_instance.run()
216
-
217
- with st.spinner('Getting unique solutions'):
218
- unique_solutions = list(set(tuple(s) for s in ga_instance.solutions))
219
- st.success(f'{len(unique_solutions)} solutions were obtained')
220
-
221
- scores = {
222
- "rf_score": [],
223
- "similarity_score": []
224
- }
225
-
226
- if use_ordering_score:
227
- scores["ordering_score"] = []
228
-
229
- rescoring_progress = 0
230
- rescoring_progress_text = "Rescoring obtained solutions"
231
- rescoring_bar = st.progress(0, text=rescoring_progress_text)
232
- total_rescoring_steps = len(unique_solutions) // batch_size + 1
233
- for i in range(total_rescoring_steps):
234
- vqgae_latents = unique_solutions[i * batch_size: (i + 1) * batch_size]
235
- frag_counts = np.array(vqgae_latents)
236
- rf_scores = rf_model.predict_proba(frag_counts)[:, 1]
237
- similarity_scores = tanimoto_kernel(frag_counts, X).max(-1)
238
- scores["rf_score"].extend(rf_scores.tolist())
239
- scores["similarity_score"].extend(similarity_scores.tolist())
240
- if use_ordering_score:
241
- frag_inds = frag_counts_to_inds(frag_counts, max_atoms=51)
242
- _, ordering_scores = restore_order(frag_inds, ordering_model)
243
- scores["ordering_score"].extend(ordering_scores)
244
- rescoring_bar.progress(i // total_rescoring_steps * 100, text=rescoring_progress_text)
245
-
246
- sc_df = pd.DataFrame(scores)
247
-
248
- if use_ordering_score:
249
- chosen_gen = sc_df[(sc_df["similarity_score"] < 0.95) & (sc_df["rf_score"] > 0.5) & (sc_df["ordering_score"] > 0.7)]
250
- else:
251
- chosen_gen = sc_df[
252
- (sc_df["similarity_score"] < 0.95) & (sc_df["rf_score"] > 0.5)]
253
-
254
- chosen_ids = chosen_gen.index.to_list()
255
- chosen_solutions = np.array([unique_solutions[ind] for ind in chosen_ids])
256
- gen_frag_inds = frag_counts_to_inds(chosen_solutions, max_atoms=51)
257
- st.info(f'The number of chosen solutions is {gen_frag_inds.shape[0]}', icon="ℹ️")
258
-
259
- gen_molecules = []
260
- results = {"smiles": [], "ordering_score": [], "validity": []}
261
- decoding_progress = 0
262
- decoding_progress_text = "Decoding chosen solutions"
263
- decoding_bar = st.progress(0, text=decoding_progress_text)
264
- total_decoding_steps = gen_frag_inds.shape[0] // batch_size + 1
265
- for i in range(total_decoding_steps):
266
- inputs = gen_frag_inds[i * batch_size: (i + 1) * batch_size]
267
- canon_order_inds, scores = restore_order(
268
- frag_inds=inputs,
269
- ordering_model=ordering_model,
270
- )
271
- molecules, validity = decode_molecules(
272
- ordered_frag_inds=canon_order_inds,
273
- vqgae_model=vqgae_model
274
- )
275
- gen_molecules.extend(molecules)
276
- results["smiles"].extend([str(molecule) for molecule in molecules])
277
- results["ordering_score"].extend(scores)
278
- results["validity"].extend([1 if i else 0 for i in validity])
279
- decoding_bar.progress(i // total_decoding_steps * 100, text=rescoring_progress_text)
280
-
281
- gen_stats = pd.DataFrame(results)
282
- full_stats = pd.concat([gen_stats, chosen_gen[["similarity_score", "rf_score"]].reset_index(), ], axis=1, ignore_index=False)
283
-
284
- st.dataframe(full_stats)
285
-
286
- # valid_gen_stats = full_stats[full_stats.valid == 1]
287
- #
288
- # valid_gen_mols = []
289
- # for i, record in zip(list(valid_gen_stats.index), valid_gen_stats.to_dict("records")):
290
- # mol = gen_molecules[i]
291
- # valid_gen_mols.append(mol)
292
- #
293
- # filtered_gen_mols = []
294
- # for mol in valid_gen_mols:
295
- # is_frag = allene < mol or peroxide_charge < mol or peroxide < mol
296
- # is_macro = False
297
- # for ring in mol.sssr:
298
- # if len(ring) > 8 or len(ring) < 4:
299
- # is_macro = True
300
- # break
301
- # if not is_frag and not is_macro:
302
- # filtered_gen_mols.append(mol)
 
56
 
57
  def fitness_func_batch(ga_instance, solutions, solutions_indices):
58
  frag_counts = np.array(solutions)
59
+ if len(frag_counts.shape) == 1:
60
+ frag_counts = frag_counts[np.newaxis, :]
61
 
62
  # prediction of activity by random forest
63
  rf_score = rf_model.predict_proba(frag_counts)[:, 1]
 
85
 
86
  def on_generation_progress(ga):
87
  global ga_progress
88
+ global ga_bar
89
  ga_progress = ga_progress + 1
90
  ga_bar.progress(ga_progress // num_generations * 100, text=ga_progress_text)
91
 
 
121
  assert X.shape == (603, 4096)
122
 
123
  with st.sidebar:
124
+ with st.form("my_form"):
125
+ num_generations = st.slider(
126
+ 'Number of generations for GA',
127
+ min_value=3,
128
+ max_value=40,
129
+ value=5
130
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
+ parent_selection_type = st.selectbox(
133
+ label='Parent selection type',
134
+ options=(
135
+ 'Steady-state selection',
136
+ 'Roulette wheel selection',
137
+ 'Stochastic universal selection',
138
+ 'Rank selection',
139
+ 'Random selection',
140
+ 'Tournament selection'
141
+ ),
142
+ index=1
143
+ )
144
 
145
+ parent_selection_translator = {
146
+ "Steady-state selection": "sss",
147
+ "Roulette wheel selection": "rws",
148
+ "Stochastic universal selection": "sus",
149
+ "Rank selection": "rank",
150
+ "Random selection": "random",
151
+ "Tournament selection": "tournament",
152
+ }
153
+
154
+ parent_selection_type = parent_selection_translator[parent_selection_type]
155
+
156
+ crossover_type = st.selectbox(
157
+ label='Crossover type',
158
+ options=(
159
+ 'Single point',
160
+ 'Two points',
161
+ ),
162
+ index=0
163
+ )
164
 
165
+ crossover_translator = {
166
+ "Single point": "single_point",
167
+ "Two points": "two_points",
168
+ }
169
+
170
+ crossover_type = crossover_translator[crossover_type]
171
+
172
+ num_parents_mating = st.slider(
173
+ 'Pecentage of parents mating taken from initial population',
174
+ min_value=0,
175
+ max_value=X.shape[0],
176
+ step=0.01,
177
+ value=0.33,
178
+ ) * X.shape[0] * 10 // 10
179
+
180
+ keep_parents = st.slider(
181
+ 'Percentage of parents kept taken from number of parents mating',
182
+ min_value=1,
183
+ max_value=num_parents_mating,
184
+ value=int(num_parents_mating * 0.66 // 10 * 10) # 2/3 of num_parents_mating
185
+ )
186
 
187
+ use_ordering_score = st.toggle('Use ordering score', value=True)
188
+
189
+ random_seed = int(st.number_input("Random seed", value=42, placeholder="Type a number..."))
190
+ st.form_submit_button('Start optimisation')
191
+
192
+ ga_instance = pygad.GA(
193
+ fitness_func=fitness_func_batch,
194
+ on_generation=on_generation_progress,
195
+ initial_population=X,
196
+ num_genes=X.shape[-1],
197
+ fitness_batch_size=batch_size,
198
+ num_generations=num_generations,
199
+ num_parents_mating=num_parents_mating,
200
+ parent_selection_type=parent_selection_type,
201
+ crossover_type=crossover_type,
202
+ mutation_type="adaptive",
203
+ mutation_percent_genes=[10, 5],
204
+ # https://pygad.readthedocs.io/en/latest/pygad.html#use-adaptive-mutation-in-pygad
205
+ save_best_solutions=False,
206
+ save_solutions=True,
207
+ keep_elitism=0, # turn it off to make keep_parents work
208
+ keep_parents=keep_parents,
209
+ suppress_warnings=True,
210
+ random_seed=random_seed,
211
+ gene_type=int
212
+ )
213
+
214
+ ga_progress = 0
215
+ ga_progress_text = "Genetic optimisation in progress. Please wait."
216
+ ga_bar = st.progress(ga_progress // num_generations * 100, text=ga_progress_text)
217
+ ga_instance.run()
218
+
219
+ with st.spinner('Getting unique solutions'):
220
+ unique_solutions = list(set(tuple(s) for s in ga_instance.solutions))
221
+ st.success(f'{len(unique_solutions)} solutions were obtained')
222
+
223
+ scores = {
224
+ "rf_score": [],
225
+ "similarity_score": []
226
+ }
227
+
228
+ if use_ordering_score:
229
+ scores["ordering_score"] = []
230
+
231
+ rescoring_progress = 0
232
+ rescoring_progress_text = "Rescoring obtained solutions"
233
+ rescoring_bar = st.progress(0, text=rescoring_progress_text)
234
+ total_rescoring_steps = len(unique_solutions) // batch_size + 1
235
+ for i in range(total_rescoring_steps):
236
+ vqgae_latents = unique_solutions[i * batch_size: (i + 1) * batch_size]
237
+ frag_counts = np.array(vqgae_latents)
238
+ rf_scores = rf_model.predict_proba(frag_counts)[:, 1]
239
+ similarity_scores = tanimoto_kernel(frag_counts, X).max(-1)
240
+ scores["rf_score"].extend(rf_scores.tolist())
241
+ scores["similarity_score"].extend(similarity_scores.tolist())
242
+ if use_ordering_score:
243
+ frag_inds = frag_counts_to_inds(frag_counts, max_atoms=51)
244
+ _, ordering_scores = restore_order(frag_inds, ordering_model)
245
+ scores["ordering_score"].extend(ordering_scores)
246
+ rescoring_bar.progress(i // total_rescoring_steps * 100, text=rescoring_progress_text)
247
+
248
+ sc_df = pd.DataFrame(scores)
249
+
250
+ if use_ordering_score:
251
+ chosen_gen = sc_df[(sc_df["similarity_score"] < 0.95) & (sc_df["rf_score"] > 0.5) & (sc_df["ordering_score"] > 0.7)]
252
+ else:
253
+ chosen_gen = sc_df[
254
+ (sc_df["similarity_score"] < 0.95) & (sc_df["rf_score"] > 0.5)]
255
+
256
+ chosen_ids = chosen_gen.index.to_list()
257
+ chosen_solutions = np.array([unique_solutions[ind] for ind in chosen_ids])
258
+ gen_frag_inds = frag_counts_to_inds(chosen_solutions, max_atoms=51)
259
+ st.info(f'The number of chosen solutions is {gen_frag_inds.shape[0]}', icon="ℹ️")
260
+
261
+ gen_molecules = []
262
+ results = {"smiles": [], "ordering_score": [], "validity": []}
263
+ decoding_progress = 0
264
+ decoding_progress_text = "Decoding chosen solutions"
265
+ decoding_bar = st.progress(0, text=decoding_progress_text)
266
+ total_decoding_steps = gen_frag_inds.shape[0] // batch_size + 1
267
+ for i in range(total_decoding_steps):
268
+ inputs = gen_frag_inds[i * batch_size: (i + 1) * batch_size]
269
+ canon_order_inds, scores = restore_order(
270
+ frag_inds=inputs,
271
+ ordering_model=ordering_model,
272
  )
273
+ molecules, validity = decode_molecules(
274
+ ordered_frag_inds=canon_order_inds,
275
+ vqgae_model=vqgae_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
  )
277
+ gen_molecules.extend(molecules)
278
+ results["smiles"].extend([str(molecule) for molecule in molecules])
279
+ results["ordering_score"].extend(scores)
280
+ results["validity"].extend([1 if i else 0 for i in validity])
281
+ decoding_bar.progress(i // total_decoding_steps * 100, text=rescoring_progress_text)
282
+
283
+ gen_stats = pd.DataFrame(results)
284
+ full_stats = pd.concat([gen_stats, chosen_gen[["similarity_score", "rf_score"]].reset_index(), ], axis=1, ignore_index=False)
285
+
286
+ st.dataframe(full_stats)
287
+
288
+ # valid_gen_stats = full_stats[full_stats.valid == 1]
289
+ #
290
+ # valid_gen_mols = []
291
+ # for i, record in zip(list(valid_gen_stats.index), valid_gen_stats.to_dict("records")):
292
+ # mol = gen_molecules[i]
293
+ # valid_gen_mols.append(mol)
294
+ #
295
+ # filtered_gen_mols = []
296
+ # for mol in valid_gen_mols:
297
+ # is_frag = allene < mol or peroxide_charge < mol or peroxide < mol
298
+ # is_macro = False
299
+ # for ring in mol.sssr:
300
+ # if len(ring) > 8 or len(ring) < 4:
301
+ # is_macro = True
302
+ # break
303
+ # if not is_frag and not is_macro:
304
+ # filtered_gen_mols.append(mol)