"""Streamlit app for blood pressure gene prioritisation post-GWAS.

Trains an XGBoost regressor on the labelled training genes at startup, scores
every gene in the annotation table, and exposes the scores and SHAP
explanations through a simple web interface.
"""
import streamlit as st
import re
import numpy as np
import pandas as pd
import sklearn
import xgboost
import shap
import matplotlib.pyplot as plt

# SHAP plots draw on matplotlib's global figure; silence the related Streamlit warning.
st.set_option('deprecation.showPyplotGlobalUse', False)
seed = 42

# Per-gene annotation features (one row per HGNC gene symbol) to be scored by the model.
annotations = pd.read_csv("annotations_dataset.csv")
annotations = annotations.set_index("Gene")

# XGBoost rejects feature names containing "[", "]" or "<", so replace those
# characters in the training-data column names.
training_data = pd.read_csv("./selected_features_training_data.csv", header=0)
regex = re.compile(r"\[|\]|<")
training_data.columns = [
    regex.sub("_", str(col)) if any(x in str(col) for x in ("[", "]", "<")) else col
    for col in training_data.columns.values
]

training_data["BPlabel_encoded"] = training_data["BPlabel"].map(
    {"most likely": 1, "probable": 0.75, "least likely": 0.1}
)
Y = training_data["BPlabel_encoded"]
X = training_data.drop(columns=["BPlabel_encoded","BPlabel"])
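
# Gradient-boosted tree regressor used to score candidate genes.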
xgb = xgboost.XGBRegressor(
    n_estimators=40,
    learning_rate=0.2,
    max_depth=4,
    reg_alpha=1,
    reg_lambda=1,
    random_state=seed,
    objective="reg:squarederror")


# Fit on the labelled training genes, then score every gene in the annotation table.
xgb.fit(X, Y)

prediction_list = list(xgb.predict(annotations))
predictions = [round(prediction, 2) for prediction in prediction_list]

# Attach the rounded scores to the annotation table as a new "XGB_Score" column.
output = pd.Series(data=predictions, index=annotations.index, name="XGB_Score")
df_total = pd.concat([annotations, output], axis=1)


# Put the model score first, followed by the annotation features displayed in the app.
df_total = df_total[[
    'XGB_Score', 'mousescore_Exomiser', 'SDI', 'Liver_GTExTPM', 'pLI_ExAC',
    'HIPred', 'Cells - EBV-transformed lymphocytes_GTExTPM',
    'Pituitary_GTExTPM', 'IPA_BP_annotation',
]]


st.title('Blood Pressure Gene Prioritisation Post-GWAS')
st.markdown("""
A machine learning pipeline for prioritising likely disease-causing genes following a genome-wide association study (GWAS) of blood pressure.
""")


# Split a comma- and/or whitespace-separated string of gene symbols into a list.
collect_genes = lambda x: [str(i) for i in re.split(r",|,\s+|\s+", x) if i != ""]

input_gene_list = st.text_input("Input list of HGNC genes (enter comma separated):")
gene_list = collect_genes(input_gene_list)
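# SHAP explainer for the fitted XGBoost model, reused for both plot types below.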
explainer = shap.TreeExplainer(xgb)

@st.experimental_memo
def convert_df(df):
    """Cache the CSV encoding of a results dataframe for the download button."""
    return df.to_csv(index=False).encode('utf-8')

# Multi-gene mode: show the table, offer a CSV download, and explain the
# scores with a SHAP summary plot.
if len(gene_list) > 1:
    df = df_total[df_total.index.isin(gene_list)].copy()
    st.dataframe(df)
    df['Gene'] = df.index
    output = df[['Gene', 'XGB_Score']]
    csv = convert_df(output)
    st.download_button(
        "Download Gene Prioritisation",
        csv,
        "bp_gene_prioritisation.csv",
        "text/csv",
        key='download-csv',
    )
    # The model score itself is not a feature, so drop it before computing SHAP values.
    df_shap = df_total[df_total.index.isin(gene_list)].drop(columns='XGB_Score')
    shap_values = explainer.shap_values(df_shap)
    shap.summary_plot(shap_values, df_shap, show=False)
    st.caption("SHAP Summary Plot of All Input Genes")
    st.pyplot(plt.gcf())  # summary_plot draws on the global matplotlib figure


input_gene = st.text_input("Input individual HGNC gene:")
df2 = df_total[df_total.index == input_gene]
st.dataframe(df2)
df2 = df2.drop(columns='XGB_Score')

if input_gene:
    # Single-gene mode: explain the gene's score with a SHAP force plot.
    shap_values = explainer.shap_values(df2)
    force_plot = shap.force_plot(
        explainer.expected_value,
        shap_values,
        df2,
        matplotlib=True,
        show=False,
    )
    st.pyplot(fig=force_plot)
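
# Full prioritisation results for every gene in the annotation set.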

st.markdown("""
Total Gene Prioritisation Results:
""")

st.dataframe(df_total)