hlnicholls's picture
Update app.py
d5fefc1
raw
history blame
3.2 kB
import streamlit as st
import re
import numpy as np
import pandas as pd
import sklearn
import xgboost
import shap
st.set_option('deprecation.showPyplotGlobalUse', False)
seed = 42

# XGBoost refuses feature names containing '[', ']' or '<'; this pattern is
# used to replace each such character with '_'.
regex = re.compile(r"\[|\]|<")


def _sanitize_feature_names(columns):
    """Return *columns* with every '[', ']' and '<' replaced by '_'.

    Labels that contain none of those characters are returned unchanged
    (same object), so already-valid names are untouched.
    """
    return [
        regex.sub("_", str(col)) if any(ch in str(col) for ch in ("[", "]", "<")) else col
        for col in columns
    ]


annotations = pd.read_csv("annotations_dataset.csv")
annotations = annotations.set_index("Gene")
# Apply the same renaming as the training data so the feature names passed to
# xgb.predict() match the ones the model was fitted with.
annotations.columns = _sanitize_feature_names(annotations.columns)

training_data = pd.read_csv("./selected_features_training_data.csv", header=0)
training_data.columns = _sanitize_feature_names(training_data.columns.values)

# Map the ordinal text label onto a numeric regression target.
training_data["BPlabel_encoded"] = training_data["BPlabel"].map(
    {"most likely": 1, "probable": 0.75, "least likely": 0.1}
)
Y = training_data["BPlabel_encoded"]
X = training_data.drop(columns=["BPlabel_encoded", "BPlabel"])
# Gradient-boosted regressor fitted on the encoded BP labels.
xgb = xgboost.XGBRegressor(
    n_estimators=40,
    learning_rate=0.2,
    max_depth=4,
    reg_alpha=1,
    reg_lambda=1,
    random_state=seed,
    objective="reg:squarederror",
)
xgb.fit(X, Y)

# Score every annotated gene. Selecting X.columns feeds the features in exactly
# the order the model was trained on, independent of the CSV's column order.
predictions = np.round(xgb.predict(annotations[X.columns]), 2)
output = pd.Series(data=predictions, index=annotations.index, name="XGB_Score")
df_total = pd.concat([annotations, output], axis=1)
# The index should already be named 'Gene' (set_index at load time), but
# rename_axis returns a new frame, so its result must be assigned to take effect.
df_total = df_total.rename_axis('Gene')
# Column layout of the results table: the model score first, then the
# annotation columns shown in the app.
df_total = df_total[[
    'XGB_Score',
    'mousescore_Exomiser',
    'SDI',
    'Liver_GTExTPM',
    'pLI_ExAC',
    'HIPred',
    'Cells - EBV-transformed lymphocytes_GTExTPM',
    'Pituitary_GTExTPM',
    'IPA_BP_annotation',
]]
# Page heading and a one-line description of the tool.
st.title('Blood Pressure Gene Prioritisation Post-GWAS')
st.markdown(
    "\n"
    "A machine learning pipeline for predicting disease-causing genes "
    "post-genome-wide association study in blood pressure.\n"
)
def collect_genes(x):
    """Split free-text gene input into a list of gene symbols.

    Commas and whitespace, in any mix or run, act as separators; empty
    fragments (e.g. from leading/trailing separators) are discarded.

    >>> collect_genes("AGT, ACE  NOS3")
    ['AGT', 'ACE', 'NOS3']
    """
    # Raw string: "\s" in a plain literal is an invalid escape sequence.
    return [token for token in re.split(r"[,\s]+", x) if token]
# Tree-based SHAP explainer over the fitted model; shared by both plots below.
explainer = shap.TreeExplainer(xgb)

# Multi-gene input box, parsed into individual gene symbols.
input_gene_list = st.text_input(
    "Input list of HGNC genes (enter comma separated):"
)
gene_list = collect_genes(input_gene_list)
@st.experimental_memo
def convert_df(df):
    """Serialise *df* to CSV without its row index and return UTF-8 bytes.

    Wrapped in st.experimental_memo so Streamlit can cache the result.
    """
    csv_text = df.to_csv(index=False)
    return csv_text.encode('utf-8')
if len(gene_list) > 1:
    df = df_total[df_total.index.isin(gene_list)]
    st.dataframe(df)

    # Download table (gene symbol + score). Built on an explicit copy so no
    # column is written into a filtered slice of df_total
    # (avoids SettingWithCopyWarning).
    output = df[['XGB_Score']].copy()
    output.insert(0, 'Gene', output.index)
    csv = convert_df(output)
    st.download_button(
        "Download Gene Prioritisation",
        csv,
        "bp_gene_prioritisation.csv",
        "text/csv",
        key='download-csv',
    )

    if df.empty:
        # No matching genes: there are no rows for SHAP to explain.
        st.warning("None of the input genes were found in the annotation dataset.")
    else:
        # Select the model's features in training order (X.columns) so the
        # SHAP input matches what xgb was fitted on, regardless of the display
        # order chosen for df_total.
        df_shap = df[X.columns]
        shap_values = explainer.shap_values(df_shap)
        # summary_plot returns None and draws on the current pyplot figure;
        # st.pyplot(fig=None) renders that global figure.
        shap.summary_plot(shap_values, df_shap, show=False)
        st.caption("SHAP Summary Plot of All Input Genes")
        st.pyplot(fig=None)
input_gene = st.text_input("Input individual HGNC gene:")
gene_symbol = input_gene.strip()
df2 = df_total[df_total.index == gene_symbol]
st.dataframe(df2)
# Drop the score on a new frame rather than in place on a slice of df_total.
df2 = df2.drop(columns='XGB_Score')

if gene_symbol and not df2.empty:
    # Features in the model's training order (X.columns), as used for fitting.
    df2_features = df2[X.columns]
    shap_values = explainer.shap_values(df2_features)
    # With matplotlib=True and show=False, force_plot's return value is passed
    # to st.pyplot as the figure to render; no JS bundle (shap.getjs) is needed.
    force_plot = shap.force_plot(
        explainer.expected_value,
        shap_values,
        df2_features,
        matplotlib=True,
        show=False,
    )
    st.pyplot(fig=force_plot)
elif gene_symbol:
    # Unknown gene: nothing to explain, so report it instead of crashing SHAP.
    st.warning(f"{gene_symbol} was not found in the annotation dataset.")
# Full results table covering every annotated gene.
st.markdown("\nTotal Gene Prioritisation Results:\n")
st.dataframe(df_total)