"""Streamlit app for blood-pressure gene prioritisation after a GWAS.

Trains an XGBoost regressor on the curated training set, scores every gene in
the annotation dataset, and lets the user inspect SHAP explanations for a
comma-separated list of genes or for a single gene.
"""
import re

import numpy as np
import pandas as pd
import shap
import sklearn
import streamlit as st
import streamlit.components.v1 as components
import xgboost

seed = 42

# Characters replaced with "_" in training column names before fitting
# (XGBoost does not accept them in feature names).
_ILLEGAL_FEATURE_CHARS = re.compile(r"[\[\]<]")

data = pd.read_csv("annotations_dataset.csv")
data = data.set_index("Gene")

training_data = pd.read_csv("./selected_features_training_data.csv", header=0)
training_data.columns = [
    _ILLEGAL_FEATURE_CHARS.sub("_", str(col)) for col in training_data.columns
]

# Encode the ordinal text label as a numeric regression target.
training_data["BPlabel_encoded"] = training_data["BPlabel"].map(
    {"most likely": 1, "probable": 0.75, "least likely": 0.1}
)
Y = training_data["BPlabel_encoded"]
X = training_data.drop(columns=["BPlabel_encoded", "BPlabel"])
feature_columns = list(X.columns)

xgb = xgboost.XGBRegressor(
    n_estimators=40,
    learning_rate=0.2,
    max_depth=4,
    reg_alpha=1,
    reg_lambda=1,
    random_state=seed,
    objective="reg:squarederror",
)
xgb.fit(X, Y)

# Score and explain exactly the columns the model was fitted on, in the same
# order, so predictions and SHAP attributions line up with the right features.
feature_data = data[feature_columns]
predictions = np.round(xgb.predict(feature_data), 2)
output = pd.Series(data=predictions, index=data.index, name="XGB_Score")

# The index stays as "Gene": the lookups below filter on it.
df_total = pd.concat([data, output], axis=1)
df_total = df_total[
    [
        "XGB_Score",
        "mousescore_Exomiser",
        "SDI",
        "Liver_GTExTPM",
        "pLI_ExAC",
        "HIPred",
        "Cells - EBV-transformed lymphocytes_GTExTPM",
        "Pituitary_GTExTPM",
        "IPA_BP_annotation",
    ]
]

st.title("Blood Pressure Gene Prioritisation Post-GWAS")
st.markdown(
    " A machine learning pipeline for predicting disease-causing genes "
    "post-genome-wide association study in blood pressure. \n"
)


def collect_genes(x):
    """Split comma-separated text into gene symbols.

    Surrounding whitespace is trimmed from every symbol and empty entries
    (e.g. from a trailing comma) are dropped.
    """
    return [symbol for symbol in (part.strip() for part in x.split(",")) if symbol]


def _render_shap_html(plot):
    """Embed a SHAP JavaScript visualisation in the Streamlit page."""
    # components.html expects an HTML string, so serialise the plot object
    # and ship SHAP's JavaScript bundle alongside it.
    shap_html = f"<head>{shap.getjs()}</head><body>{plot.html()}</body>"
    components.html(shap_html, scrolling=True)


input_gene_list = st.text_input("Input list of HGNC genes (enter comma separated):")
gene_list = collect_genes(input_gene_list)
explainer = shap.TreeExplainer(xgb)

if len(gene_list) > 1:
    selected = data.index.isin(gene_list)
    df = df_total[selected]
    st.dataframe(df)
    if selected.any():
        selected_features = feature_data[selected]
        shap_values = explainer.shap_values(selected_features)
        # summary_plot draws on the current matplotlib figure; show=False
        # keeps that figure open so Streamlit can render it.
        shap.summary_plot(shap_values, selected_features, show=False)
        st.caption("SHAP Summary Plot of All Input Genes")
        # NOTE(review): st.pyplot() without a figure is deprecated; pass an
        # explicit figure once matplotlib is imported by this file.
        st.pyplot()
    else:
        st.warning("None of the entered genes were found in the dataset.")

input_gene = st.text_input("Input individual HGNC gene:").strip()
matches = data.index == input_gene
df2 = df_total[matches]
st.dataframe(df2)
if matches.any():
    gene_features = feature_data[matches]
    shap_values = explainer.shap_values(gene_features)
    force_plot = shap.force_plot(
        explainer.expected_value, shap_values, gene_features
    )
    _render_shap_html(force_plot)
elif input_gene:
    st.warning("The entered gene was not found in the dataset.")

st.markdown(" Total Gene Prioritisation Results: ")
st.dataframe(df_total)