hlnicholls's picture
Update app.py
85ec45f
raw
history blame
No virus
4.57 kB
import streamlit as st
import re
import numpy as np
import pandas as pd
import sklearn
import xgboost
import shap
from shap_plots import shap_summary_plot
import plotly.tools as tls
import dash_core_components as dcc
import matplotlib
import plotly.graph_objs as go
try:
import matplotlib.pyplot as pl
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.ticker import MaxNLocator
except ImportError:
pass
# Turn off Streamlit's "global pyplot use" deprecation notice for st.pyplot().
st.set_option('deprecation.showPyplotGlobalUse', False)

# Fixed seed passed to the regressor as random_state.
seed = 42

# Table of genes to score, indexed by the "Gene" column.
annotations = pd.read_csv("annotations_dataset.csv")
annotations = annotations.set_index("Gene")

training_data = pd.read_csv("./selected_features_training_data.csv", header=0)

# Replace "[", "]" and "<" in column names with "_".
# The pattern is compiled here: the original referenced `regex` without ever
# defining it, which raised NameError as soon as a column needed renaming.
# NOTE(review): presumably needed because XGBoost rejects these characters in
# feature names -- confirm.
regex = re.compile(r"\[|\]|<")
training_data.columns = [
    regex.sub("_", col) if any(x in str(col) for x in ("[", "]", "<")) else col
    for col in training_data.columns.values
]

# Encode the categorical label as a numeric regression target.
training_data["BPlabel_encoded"] = training_data["BPlabel"].map(
    {"most likely": 1, "probable": 0.75, "least likely": 0.1}
)
Y = training_data["BPlabel_encoded"]
X = training_data.drop(columns=["BPlabel_encoded", "BPlabel"])

xgb = xgboost.XGBRegressor(
    n_estimators=40,
    learning_rate=0.2,
    max_depth=4,
    reg_alpha=1,
    reg_lambda=1,
    random_state=seed,
    objective="reg:squarederror",
)
xgb.fit(X, Y)

# Score every annotated gene and round the scores to two decimals.
prediction_list = list(xgb.predict(annotations))
predictions = [round(prediction, 2) for prediction in prediction_list]
output = pd.Series(data=predictions, index=annotations.index, name="XGB_Score")

# Combined table: model score first, followed by the displayed feature columns.
df_total = pd.concat([annotations, output], axis=1)
df_total = df_total[[
    'XGB_Score',
    'mousescore_Exomiser',
    'SDI',
    'Liver_GTExTPM',
    'pLI_ExAC',
    'HIPred',
    'Cells - EBV-transformed lymphocytes_GTExTPM',
    'Pituitary_GTExTPM',
    'IPA_BP_annotation',
]]
# Page heading followed by a one-sentence description of the tool.
st.title("Blood Pressure Gene Prioritisation Post-GWAS")
st.markdown(
    "\n"
    "A machine learning pipeline for predicting disease-causing genes post-genome-wide association study in blood pressure.\n"
)
def collect_genes(x):
    """Split a free-text gene list into individual gene symbols.

    Tokens are separated by commas and/or runs of whitespace. Empty tokens
    (from leading, trailing or repeated separators) are dropped.

    Args:
        x: The raw text typed by the user.

    Returns:
        A list of non-empty gene symbol strings, in input order.
    """
    # Raw string: "\s" in a normal string literal is an invalid escape sequence.
    return [str(i) for i in re.split(r"[,\s]+", x) if i != ""]
# Free-text box for several genes at once; the text is tokenised by collect_genes.
input_gene_list = st.text_input(
    "Input list of HGNC genes (enter comma separated):"
)
gene_list = collect_genes(input_gene_list)

# SHAP explainer built from the fitted model; used by the plots further down.
explainer = shap.TreeExplainer(xgb)
@st.experimental_memo
def convert_df(df):
    """Return ``df`` as UTF-8 encoded CSV bytes, omitting the index column."""
    csv_text = df.to_csv(index=False)
    return csv_text.encode('utf-8')
# Multi-gene report: score table, CSV download and SHAP summary plot.
# Only runs when more than one gene was entered in the list box.
if len(gene_list) > 1:
    # .copy() gives an independent frame, so adding the 'Gene' column below
    # does not write into a slice of df_total.
    df = df_total[df_total.index.isin(gene_list)].copy()

    if df.empty:
        # Without this guard the SHAP code below would be fed an empty frame.
        st.warning("None of the input genes were found in the dataset.")
    else:
        df['Gene'] = df.index
        df.reset_index(drop=True, inplace=True)
        df = df[[
            'Gene',
            'XGB_Score',
            'mousescore_Exomiser',
            'SDI',
            'Liver_GTExTPM',
            'pLI_ExAC',
            'HIPred',
            'Cells - EBV-transformed lymphocytes_GTExTPM',
            'Pituitary_GTExTPM',
            'IPA_BP_annotation',
        ]]
        st.dataframe(df)

        # Downloadable CSV containing just the gene names and their scores.
        output = df[['Gene', 'XGB_Score']]
        csv = convert_df(output)
        st.download_button(
            "Download Gene Prioritisation",
            csv,
            "bp_gene_prioritisation.csv",
            "text/csv",
            key='download-csv'
        )

        # Feature matrix for SHAP: same rows, score column removed.
        # drop() without inplace returns a new frame instead of mutating a slice.
        df_shap = df_total[df_total.index.isin(gene_list)].drop(columns='XGB_Score')
        shap_values = explainer.shap_values(df_shap)
        # NOTE(review): presumably shap.summary_plot draws on the global pyplot
        # figure and st.pyplot picks that figure up -- confirm.
        summary_plot = shap.summary_plot(shap_values, df_shap)
        st.pyplot(fig=summary_plot)
        st.caption("SHAP Summary Plot of All Input Genes")

        # Rank features by total absolute SHAP value and keep at most the top 8.
        # NOTE(review): these two names are not used again in the visible file,
        # and the [:-1] leaves the last feature out of the ranking -- confirm intent.
        feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0)[:-1])
        feature_order = feature_order[-min(8, len(feature_order)):]
        col_order = [df_shap.columns[i] for i in feature_order]
# Single-gene report: one-row score table and a SHAP force plot.
# Surrounding whitespace is stripped so that " AGT " still matches "AGT".
input_gene = st.text_input("Input individual HGNC gene:").strip()

# .copy() gives an independent frame, so adding the 'Gene' column below
# does not write into a slice of df_total.
df2 = df_total[df_total.index == input_gene].copy()
df2['Gene'] = df2.index
df2.reset_index(drop=True, inplace=True)
df2 = df2[[
    'Gene',
    'XGB_Score',
    'mousescore_Exomiser',
    'SDI',
    'Liver_GTExTPM',
    'pLI_ExAC',
    'HIPred',
    'Cells - EBV-transformed lymphocytes_GTExTPM',
    'Pituitary_GTExTPM',
    'IPA_BP_annotation',
]]
# Shown even when nothing has been entered (the table is then empty).
st.dataframe(df2)

if input_gene:
    # Feature row for SHAP: score column removed.
    # drop() without inplace returns a new frame instead of mutating a slice.
    df2_shap = df_total[df_total.index == input_gene].drop(columns='XGB_Score')

    if df2_shap.empty:
        # Without this guard an unknown gene would send an empty frame to SHAP.
        st.warning("Gene '{}' was not found in the dataset.".format(input_gene))
    else:
        shap_values = explainer.shap_values(df2_shap)
        # Return value is discarded; kept as in the original.
        shap.getjs()
        force_plot = shap.force_plot(
            explainer.expected_value,
            shap_values,
            df2_shap,
            matplotlib=True,
            show=False,
        )
        st.pyplot(fig=force_plot)
# Full results table for every gene in the dataset.
st.markdown("""
Total Gene Prioritisation Results:
""")

# Work on a copy: the original aliased df_total, so adding the 'Gene' column
# and resetting the index in place silently modified df_total as well.
df_total_output = df_total.copy()
df_total_output['Gene'] = df_total_output.index
df_total_output = df_total_output.reset_index(drop=True)
df_total_output = df_total_output[[
    'Gene',
    'XGB_Score',
    'mousescore_Exomiser',
    'SDI',
    'Liver_GTExTPM',
    'pLI_ExAC',
    'HIPred',
    'Cells - EBV-transformed lymphocytes_GTExTPM',
    'Pituitary_GTExTPM',
    'IPA_BP_annotation',
]]
st.dataframe(df_total_output)