hlnicholls's picture
Update app.py
a8d51d4
raw
history blame
2.67 kB
import streamlit as st
import re
import numpy as np
import pandas as pd
import sklearn
import xgboost
import shap
import streamlit.components.v1 as components
# Fixed seed so the XGBoost model (and therefore every score) is reproducible.
seed = 42

# Per-gene annotation features to be scored, indexed by gene symbol.
data = pd.read_csv("annotations_dataset.csv")
data = data.set_index("Gene")

# Labelled training genes: feature columns plus the "BPlabel" class column.
training_data = pd.read_csv("./selected_features_training_data.csv", header=0)

# Replace "[", "]" and "<" in column names with "_" (presumably because XGBoost
# rejects feature names containing those characters). `regex` was previously
# used here without ever being defined, raising NameError on such columns.
regex = re.compile(r"\[|\]|<")
training_data.columns = [
    regex.sub("_", col) if isinstance(col, str) else col
    for col in training_data.columns.values
]

# Encode the ordinal class labels as numeric regression targets.
# NOTE(review): any label not in this mapping becomes NaN via Series.map.
training_data["BPlabel_encoded"] = training_data["BPlabel"].map(
    {"most likely": 1, "probable": 0.75, "least likely": 0.1}
)
Y = training_data["BPlabel_encoded"]
X = training_data.drop(columns=["BPlabel_encoded", "BPlabel"])
# Gradient-boosted regressor trained on the encoded labels (Y) from the
# feature matrix (X); the shared `seed` keeps training reproducible.
xgb = xgboost.XGBRegressor(
    objective="reg:squarederror",
    n_estimators=40,
    max_depth=4,
    learning_rate=0.2,
    reg_alpha=1,
    reg_lambda=1,
    random_state=seed,
)
xgb.fit(X, Y)
# Score every annotated gene with the trained model, rounded to 2 decimals.
predictions = np.round(xgb.predict(data), 2)
output = pd.Series(data=predictions, index=data.index, name="XGB_Score")

# Attach the score as a column. `data` is already indexed by "Gene", so the
# combined frame keeps gene symbols as its index, which the gene lookups in
# the app rely on.
df_total = pd.concat([data, output], axis=1)

# Columns shown in the app. Every column except XGB_Score is also handed to
# the SHAP explainer. NOTE(review): these are assumed to be exactly the
# model's training features; confirm against selected_features_training_data.csv.
df_total = df_total[
    [
        "XGB_Score",
        "mousescore_Exomiser",
        "SDI",
        "Liver_GTExTPM",
        "pLI_ExAC",
        "HIPred",
        "Cells - EBV-transformed lymphocytes_GTExTPM",
        "Pituitary_GTExTPM",
        "IPA_BP_annotation",
    ]
]
# Page header: app title followed by a one-line description of the pipeline.
st.title("Blood Pressure Gene Prioritisation Post-GWAS")
st.markdown(
    "\nA machine learning pipeline for predicting disease-causing genes "
    "post-genome-wide association study in blood pressure.\n"
)
def collect_genes(x):
    """Split a comma-separated gene string into a list of gene symbols.

    Surrounding whitespace is stripped from each entry and empty entries
    (e.g. from a trailing comma or ",,") are dropped, so "A, B,C," gives
    ["A", "B", "C"]. ``None`` or an empty string gives an empty list.
    """
    if not x:
        return []
    stripped = (part.strip() for part in str(x).split(","))
    return [gene for gene in stripped if gene]
# Multi-gene section: show the selected genes and a SHAP summary plot.
input_gene_list = st.text_input("Input list of HGNC genes (enter comma separated):")
gene_list = collect_genes(input_gene_list)

# One TreeExplainer for the fitted model; the single-gene section reuses it.
explainer = shap.TreeExplainer(xgb)

if len(gene_list) > 1:
    df = df_total[df_total.index.isin(gene_list)]
    st.dataframe(df)
    if df.empty:
        st.warning("None of the input genes were found in the annotation dataset.")
    else:
        # Build a new frame: dropping in place on a boolean-filtered slice
        # triggers pandas' SettingWithCopyWarning.
        features = df.drop(columns="XGB_Score")
        shap_values = explainer.shap_values(features)
        # summary_plot draws onto matplotlib's current figure and returns None,
        # so there is no HTML to pass to components.html; render the figure
        # with st.pyplot instead.
        shap.summary_plot(shap_values, features, show=False)
        st.caption("SHAP Summary Plot of All Input Genes")
        # NOTE(review): st.pyplot() without a figure falls back to matplotlib's
        # global figure (deprecated by Streamlit); pass an explicit Figure if
        # matplotlib gets imported directly.
        st.pyplot()
# Single-gene section: show the gene's row and an interactive SHAP force plot.
input_gene = st.text_input("Input individual HGNC gene:")
# Strip surrounding whitespace so a stray space does not prevent a match.
df2 = df_total[df_total.index == (input_gene or "").strip()]
st.dataframe(df2)

# Explain only when exactly one annotated gene matched. Checking the length of
# the input *string* instead would only fire for a one-character symbol.
if len(df2) == 1:
    # New frame rather than an in-place drop on a filtered slice.
    features = df2.drop(columns="XGB_Score")
    # For a single-output regressor, shap_values() returns an ndarray with one
    # row per gene (no `.values` attribute); take the only row.
    shap_values = explainer.shap_values(features)
    force_plot = shap.force_plot(
        explainer.expected_value,
        shap_values[0],
        features.iloc[0],
    )
    # force_plot returns a JS visualizer object, not an HTML string, so embed
    # shap's JS bundle together with the visualizer's HTML in the component.
    components.html(
        f"<head>{shap.getjs()}</head><body>{force_plot.html()}</body>",
        scrolling=True,
    )
# Finally, list the scores and displayed features for every annotated gene.
st.markdown("\nTotal Gene Prioritisation Results:\n")
st.dataframe(df_total)