import streamlit as st | |
import numpy as np | |
import pandas as pd | |
import pickle | |
import pygad | |
from tqdm.auto import tqdm | |
from VQGAE.models import VQGAE, OrderingNetwork | |
from CGRtools.containers import QueryContainer | |
from VQGAE.utils import frag_counts_to_inds, restore_order, decode_molecules | |
# Substructure query patterns used to filter generated molecules.
# NOTE(review): the integers passed to add_bond are presumably atom numbers
# assigned in insertion order (1, 2, 3, ...) followed by the bond order --
# confirm against the CGRtools QueryContainer documentation.

# Two oxygen atoms joined by a bond of type 1.
peroxide = QueryContainer()
peroxide.add_atom("O")
peroxide.add_atom("O")
peroxide.add_bond(1, 2, 1)

# Same O-O pattern, but the first oxygen carries a charge of -1.
peroxide_charge = QueryContainer()
peroxide_charge.add_atom("O", charge=-1)
peroxide_charge.add_atom("O")
peroxide_charge.add_bond(1, 2, 1)

# A carbon connected by bonds of type 2 to two "A" atoms
# ("A" is presumably the any-atom wildcard -- TODO confirm).
allene = QueryContainer()
allene.add_atom("C")
allene.add_atom("A")
allene.add_atom("A")
allene.add_bond(1, 2, 2)
allene.add_bond(1, 3, 2)
def tanimoto_kernel(x, y):
    """
    Compute the pairwise Tanimoto (Jaccard) coefficient between rows of two arrays.

    "The Tanimoto coefficient is a measure of the similarity between two sets.
    It is defined as the size of the intersection divided by the size of the
    union of the sample sets." For row vectors ``a`` and ``b`` this is computed
    as ``a.b / (|a|^2 + |b|^2 - a.b)``. Entries that evaluate to NaN (e.g. two
    all-zero rows, which give 0/0) are set to 0.

    Adapted from https://github.com/cimm-kzn/CIMtools/blob/master/CIMtools/metrics/pairwise.py

    :param x: 2D array of features, one sample per row.
    :param y: 2D array of features, one sample per row (same number of columns as x).
    :return: 2D array of shape (len(x), len(y)) with the Tanimoto coefficients.
    """
    x_dot = np.dot(x, y.T)
    x2 = (x ** 2).sum(axis=1)
    y2 = (y ** 2).sum(axis=1)
    # Broadcasting a column of x norms against a row of y norms yields the
    # (len(x), len(y)) denominator without building Python lists of arrays.
    denominator = x2[:, np.newaxis] + y2[np.newaxis, :] - x_dot
    # 0/0 is expected for all-zero rows; silence the warning and patch the
    # resulting NaN values below instead of spamming the log.
    with np.errstate(divide="ignore", invalid="ignore"):
        result = x_dot / denominator
    result[np.isnan(result)] = 0
    return result
def rescoring(vqgae_latents):
    """
    Score a batch of fragment-count vectors with the three metrics used for selection.

    Relies on the module-level ``rf_model``, ``X`` and ``ordering_model``.

    :param vqgae_latents: sequence of fragment-count vectors (converted to a NumPy array).
    :return: tuple ``(rf_scores, similarity_scores, ordering_scores)`` where the
        first two are Python lists and the third is whatever ``restore_order``
        returns as its second value.
    """
    counts = np.array(vqgae_latents)

    # Column 1 of predict_proba (presumably the "active" class -- TODO confirm).
    activity = rf_model.predict_proba(counts)[:, 1]

    # Highest Tanimoto similarity of each vector to any row of X.
    nearest_similarity = tanimoto_kernel(counts, X).max(-1)

    _, ordering_scores = restore_order(
        frag_counts_to_inds(counts, max_atoms=51),
        ordering_model,
    )
    return activity.tolist(), nearest_similarity.tolist(), ordering_scores
def fitness_func_batch(ga_instance, solutions, solutions_indices):
    """
    Batch fitness callback for the genetic algorithm.

    Combines four terms per solution: random-forest probability (weight 0.5),
    dissimilarity to the rows of ``X`` (weight 0.3), a size penalty, and the
    ordering score (weight 0.2). Relies on the module-level ``rf_model``,
    ``X`` and ``ordering_model``.

    :param ga_instance: GA instance passed by the caller (unused here).
    :param solutions: batch of fragment-count vectors.
    :param solutions_indices: indices of the solutions (unused here).
    :return: list with one fitness value per solution.
    """
    counts = np.array(solutions)

    # Random-forest probability, column 1 of predict_proba.
    activity = rf_model.predict_proba(counts)[:, 1]

    # Solutions whose counts sum to fewer than 18 get a penalty of -1.
    totals = counts.sum(-1).astype(np.int64)
    too_small = np.where(totals < 18, -1.0, 0.)

    # Reward distance from the closest row of X; a solution with zero
    # dissimilarity gets -5 instead so the search moves away from it.
    novelty = 1 - tanimoto_kernel(counts, X).max(-1)
    novelty[novelty == 0] = -5

    # Ordering score from the ordering network.
    _, raw_ordering = restore_order(
        frag_counts_to_inds(counts, max_atoms=51),
        ordering_model,
    )
    ordering = np.array(raw_ordering)

    # Weighted sum; term order is kept as-is so floating-point results match.
    fitness = 0.5 * activity + 0.3 * novelty + too_small + 0.2 * ordering
    return fitness.tolist()
def on_generation_progress(ga):
    """
    Advance the module-level progress bar ``pbar`` by one step.

    :param ga: GA instance supplied by the caller (unused here).
    """
    pbar.update(n=1)
# Cached so that Streamlit reruns reuse the already-loaded arrays and models
# instead of re-reading the checkpoints on every interaction. The returned
# objects are shared between reruns, so callers must not mutate them.
# show_spinner=False because the caller already displays its own status text.
@st.cache_resource(show_spinner=False)
def load_data(batch_size):
    """
    Load the training data and the three models from ``saved_model/``.

    :param batch_size: batch size forwarded to both checkpoint loaders.
    :return: tuple ``(X, Y, rf_model, vqgae_model, ordering_model)`` where
        ``X`` and ``Y`` are the ``"x"`` and ``"y"`` arrays from the training
        ``.npz`` archive and both network models are in eval mode on the CPU.
    """
    # Open the archive once and close it when done; the original opened it
    # twice and never released the file handle.
    with np.load("saved_model/tubulin_qsar_class_train_data_vqgae.npz") as train_data:
        X = train_data["x"]
        Y = train_data["y"]

    # NOTE(security): pickle.load can execute arbitrary code; only ship
    # trusted model files alongside this app.
    with open("saved_model/rf_class_train_tubulin.pickle", "rb") as inp:
        rf_model = pickle.load(inp)

    vqgae_model = VQGAE.load_from_checkpoint(
        "saved_model/vqgae.ckpt",
        task="decode",
        batch_size=batch_size,
        map_location="cpu"
    )
    vqgae_model = vqgae_model.eval()

    ordering_model = OrderingNetwork.load_from_checkpoint(
        "saved_model/ordering_network.ckpt",
        batch_size=batch_size,
        map_location="cpu"
    )
    ordering_model = ordering_model.eval()

    return X, Y, rf_model, vqgae_model, ordering_model
# Batch size handed to the model loaders.
batch_size = 500

# Page header plus a status placeholder that is overwritten once loading ends.
st.title('Inverse QSAR of Tubulin inhibitors in colchicine site with VQGAE')
data_load_state = st.text('Loading data...')

# Training arrays and models used by the scoring functions in this module.
(
    X,
    Y,
    rf_model,
    vqgae_model,
    ordering_model,
) = load_data(batch_size)

data_load_state.text("Done! (using st.cache_data)")
# initial_pop = X | |
# | |
# num_parents_mating = int(initial_pop.shape[0] * 0.33 // 10 * 10) | |
# keep_parents = int(num_parents_mating * 0.66 // 10 * 10) | |
# print(num_parents_mating, keep_parents) | |
# | |
# num_generations = 30 | |
# with tqdm(total=num_generations) as pbar: | |
# ga_instance = pygad.GA( | |
# fitness_func=fitness_func_batch, | |
# on_generation=on_generation_progress, | |
# initial_population=initial_pop, | |
# num_genes=initial_pop.shape[-1], | |
# fitness_batch_size=batch_size, | |
# num_generations=num_generations, | |
# num_parents_mating=num_parents_mating, | |
# parent_selection_type="rws", | |
# crossover_type="single_point", | |
# mutation_type="adaptive", | |
# mutation_percent_genes=[10, 5], | |
# # https://pygad.readthedocs.io/en/latest/pygad.html#use-adaptive-mutation-in-pygad | |
# save_best_solutions=False, | |
# save_solutions=True, | |
# keep_elitism=0, # turn it off to make keep_parents work | |
# keep_parents=keep_parents, # 2/3 of num_parents_mating | |
# # parallel_processing=['process', 5], | |
# suppress_warnings=True, | |
# random_seed=42, | |
# gene_type=int | |
# ) | |
# ga_instance.run() | |
# | |
# solutions = ga_instance.solutions | |
# solutions = list(set(tuple(s) for s in solutions)) | |
# print(len(solutions)) | |
# | |
# scores = {"rf_score": [], "similarity_score": [], "ordering_score": []} | |
# for i in tqdm(range(len(solutions) // 100 + 1)): | |
# solution = solutions[i * 100: (i + 1) * 100] | |
# rf_score, similarity_score, ordering_score = rescoring(solution) | |
# scores["rf_score"].extend(rf_score) | |
# scores["similarity_score"].extend(similarity_score) | |
# scores["ordering_score"].extend(ordering_score) | |
# | |
# sc_df = pd.DataFrame(scores) | |
# | |
# chosen_gen = sc_df[(sc_df["similarity_score"] < 0.95) & (sc_df["rf_score"] > 0.5) & (sc_df["ordering_score"] > 0.7)] | |
# | |
# chosen_ids = chosen_gen.index.to_list() | |
# chosen_solutions = np.array([solutions[ind] for ind in chosen_ids]) | |
# gen_frag_inds = frag_counts_to_inds(chosen_solutions, max_atoms=51) | |
# | |
# gen_molecules = [] | |
# results = {"score": [], "valid": []} | |
# for i in tqdm(range(gen_frag_inds.shape[0] // batch_size + 1)): | |
# inputs = gen_frag_inds[i * batch_size: (i + 1) * batch_size] | |
# canon_order_inds, scores = restore_order( | |
# frag_inds=inputs, | |
# ordering_model=ordering_model, | |
# ) | |
# molecules, validity = decode_molecules( | |
# ordered_frag_inds=canon_order_inds, | |
# vqgae_model=vqgae_model | |
# ) | |
# gen_molecules.extend(molecules) | |
# results["score"].extend(scores) | |
# results["valid"].extend([1 if i else 0 for i in validity]) | |
# | |
# gen_stats = pd.DataFrame(results) | |
# full_stats = pd.concat([chosen_gen.reset_index(), gen_stats], axis=1, ignore_index=False) | |
# valid_gen_stats = full_stats[full_stats.valid == 1] | |
# valid_gen_mols = [] | |
# for i, record in zip(list(valid_gen_stats.index), valid_gen_stats.to_dict("records")): | |
# mol = gen_molecules[i] | |
# mol.meta.update({ | |
# "rf_score": record["rf_score"], | |
# "similarity_score": record["similarity_score"], | |
# "ordering_score": record["ordering_score"], | |
# }) | |
# valid_gen_mols.append(mol) | |
# | |
# filtered_gen_mols = [] | |
# for mol in valid_gen_mols: | |
# is_frag = allene < mol or peroxide_charge < mol or peroxide < mol | |
# is_macro = False | |
# for ring in mol.sssr: | |
# if len(ring) > 8 or len(ring) < 4: | |
# is_macro = True | |
# break | |
# if not is_frag and not is_macro: | |
# filtered_gen_mols.append(mol) | |