# Streamlit app: blood pressure gene prioritisation post-GWAS.
#
# Trains an XGBoost regressor on the labelled training data, scores every gene
# in the annotation dataset, and lets the user inspect the scores together
# with SHAP explanations for a list of genes or for a single gene.
#
# NOTE(review): everything below runs at module top level, so presumably the
# CSVs are re-read and the model re-fitted on every Streamlit rerun — consider
# caching the fitted model if that becomes slow.
import re

import dash_core_components as dcc
import matplotlib
# These were previously wrapped in `try/except ImportError: pass`, but
# LinearSegmentedColormap is used unconditionally below, so a failed import
# only resurfaced later as a confusing NameError. Fail at the real cause.
import matplotlib.pyplot as pl
import numpy as np
import pandas as pd
import plotly.graph_objs as go
import plotly.tools as tls
import shap
import sklearn
import streamlit as st
import xgboost
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.ticker import MaxNLocator

from shap_plots import shap_summary_plot

# shap.summary_plot draws on the global pyplot figure, which st.pyplot warns
# about by default; the warning is silenced on purpose.
st.set_option('deprecation.showPyplotGlobalUse', False)

seed = 42

# Number of features shown in the interactive SHAP summary plot.
MAX_DISPLAY = 8

# Columns shown in every results table (a 'Gene' column is prepended).
DISPLAY_COLUMNS = [
    'XGB_Score',
    'mousescore_Exomiser',
    'SDI',
    'Liver_GTExTPM',
    'pLI_ExAC',
    'HIPred',
    'Cells - EBV-transformed lymphocytes_GTExTPM',
    'Pituitary_GTExTPM',
    'IPA_BP_annotation',
]

# Characters replaced with "_" in training column names.
_FORBIDDEN_FEATURE_CHARS = re.compile(r"[\[\]<]")


def _sanitise_feature_name(col):
    """Return *col* with '[', ']' and '<' replaced by '_'.

    Columns that contain none of those characters are returned unchanged
    (same object), matching the original conditional rename.
    """
    text = str(col)
    if _FORBIDDEN_FEATURE_CHARS.search(text):
        return _FORBIDDEN_FEATURE_CHARS.sub("_", text)
    return col


def with_gene_column(frame):
    """Return a copy of *frame* for display.

    The index is moved into a leading 'Gene' column, the index is reset to a
    plain range, and the columns are restricted to DISPLAY_COLUMNS. The input
    frame is never modified.
    """
    table = frame.copy()
    table['Gene'] = table.index
    return table.reset_index(drop=True)[['Gene', *DISPLAY_COLUMNS]]


def collect_genes(x):
    """Split free-text input on commas and/or whitespace, dropping empties."""
    return [token for token in re.split(r"[,\s]+", x) if token]


@st.experimental_memo
def convert_df(df):
    """Serialise *df* to UTF-8 encoded CSV bytes without the index."""
    return df.to_csv(index=False).encode('utf-8')


def matplotlib_to_plotly(cmap, pl_entries):
    """Sample a matplotlib colormap into a Plotly colorscale.

    Args:
        cmap: callable mapping a float in [0, 1] to an RGBA tuple of floats.
        pl_entries: number of evenly spaced samples to take (at least 2).

    Returns:
        A list of ``[position, 'rgb(r, g, b)']`` pairs.

    Raises:
        ValueError: if *pl_entries* is smaller than 2 (the step size would
            require a division by zero).
    """
    if pl_entries < 2:
        raise ValueError("pl_entries must be at least 2")
    h = 1.0 / (pl_entries - 1)
    pl_colorscale = []
    for k in range(pl_entries):
        # Convert to plain ints: str() of NumPy scalars is not guaranteed to
        # be a bare number, which would produce an invalid colour string.
        r, g, b = (int(channel * 255) for channel in cmap(k * h)[:3])
        pl_colorscale.append([k * h, f"rgb({r}, {g}, {b})"])
    return pl_colorscale


# ---------------------------------------------------------------------------
# Data loading and model training
# ---------------------------------------------------------------------------
annotations = pd.read_csv("annotations_dataset.csv")
annotations = annotations.set_index("Gene")

training_data = pd.read_csv("./selected_features_training_data.csv", header=0)
# NOTE(review): only the training columns are renamed; if a rename ever
# happens, the annotation columns would presumably no longer match — confirm.
training_data.columns = [
    _sanitise_feature_name(col) for col in training_data.columns.values
]

training_data["BPlabel_encoded"] = training_data["BPlabel"].map(
    {"most likely": 1, "probable": 0.75, "least likely": 0.1}
)
Y = training_data["BPlabel_encoded"]
X = training_data.drop(columns=["BPlabel_encoded", "BPlabel"])

xgb = xgboost.XGBRegressor(
    n_estimators=40,
    learning_rate=0.2,
    max_depth=4,
    reg_alpha=1,
    reg_lambda=1,
    random_state=seed,
    objective="reg:squarederror",
)
xgb.fit(X, Y)

predictions = np.round(xgb.predict(annotations), 2)
scores = pd.Series(data=predictions, index=annotations.index, name="XGB_Score")

df_total = pd.concat([annotations, scores], axis=1)
df_total = df_total[DISPLAY_COLUMNS]

explainer = shap.TreeExplainer(xgb)

# Colour map used for the feature-value colour bar: #1E88E5 -> #ff0052
cdict1 = {
    'red': ((0.0, 0.11764705882352941, 0.11764705882352941),
            (1.0, 0.9607843137254902, 0.9607843137254902)),

    'green': ((0.0, 0.5333333333333333, 0.5333333333333333),
              (1.0, 0.15294117647058825, 0.15294117647058825)),

    'blue': ((0.0, 0.8980392156862745, 0.8980392156862745),
             (1.0, 0.3411764705882353, 0.3411764705882353)),

    'alpha': ((0.0, 1, 1),
              (0.5, 1, 1),
              (1.0, 1, 1))
}
red_blue_cmap = LinearSegmentedColormap('RedBlue', cdict1)
red_blue = matplotlib_to_plotly(red_blue_cmap, 255)

# ---------------------------------------------------------------------------
# Page header and multi-gene lookup
# ---------------------------------------------------------------------------
st.title('Blood Pressure Gene Prioritisation Post-GWAS')
st.markdown("""
A machine learning pipeline for predicting disease-causing genes post-genome-wide association study in blood pressure.
""")

input_gene_list = st.text_input("Input list of HGNC genes (enter comma separated):")
gene_list = collect_genes(input_gene_list)

if len(gene_list) > 1:
    selected = df_total[df_total.index.isin(gene_list)]

    if selected.empty:
        # Explaining an empty frame would fail inside SHAP; tell the user why
        # nothing is shown instead.
        st.warning("None of the input genes were found in the annotation dataset.")
    else:
        df = with_gene_column(selected)
        st.dataframe(df)

        csv = convert_df(df[['Gene', 'XGB_Score']])
        st.download_button(
            "Download Gene Prioritisation",
            csv,
            "bp_gene_prioritisation.csv",
            "text/csv",
            key='download-csv'
        )

        # Non-inplace drop: `selected` is a filtered view of df_total.
        df_shap = selected.drop(columns='XGB_Score')
        shap_values = explainer.shap_values(df_shap)

        # NOTE(review): summary_plot is called without show=False, so its
        # return value is presumably None and st.pyplot renders the global
        # pyplot figure — confirm before changing.
        summary_plot = shap.summary_plot(shap_values, df_shap)
        st.pyplot(fig=summary_plot)
        st.caption("SHAP Summary Plot of All Input Genes")

        # Rank features by total absolute SHAP value and keep the top ones.
        # NOTE(review): the [:-1] excludes the last column from the ranking;
        # kept as in the original — confirm this is intended.
        feature_order = np.argsort(np.sum(np.abs(shap_values), axis=0)[:-1])
        feature_order = feature_order[-min(MAX_DISPLAY, len(feature_order)):]
        col_order = [df_shap.columns[i] for i in feature_order]

        st.caption("Interactive SHAP Summary Plot of All Input Genes")
        mpl_fig = shap_summary_plot(
            shap_values,
            df_shap,
            max_display=MAX_DISPLAY,
            show=False,
            feature_names=col_order,
        )
        plotly_fig = tls.mpl_to_plotly(mpl_fig)
        plotly_fig['layout'] = {'xaxis': {'title': 'SHAP value (impact on model output)'}}

        # Every second trace (starting at index 1) gets the gene names as
        # hover text.
        gene_index = df_shap.index
        for trace in plotly_fig['data'][1::2]:
            trace['name'] = ''
            trace['text'] = gene_index
            trace['hoverinfo'] = 'text'

        # Invisible-point trace whose only purpose is to carry the colour bar.
        colorbar_trace = go.Scatter(
            x=[None],
            y=col_order,
            visible=True,
            mode='markers',
            marker=dict(
                colorscale=red_blue,
                showscale=True,
                cmin=-5,
                cmax=5,
                colorbar=dict(
                    thickness=5,
                    tickvals=[-5, 5],
                    ticktext=['Low', 'High'],
                    outlinewidth=0,
                ),
            ),
            hoverinfo='none',
        )

        plotly_fig['layout']['showlegend'] = False
        plotly_fig['layout']['hovermode'] = 'closest'
        plotly_fig['layout']['height'] = 600
        plotly_fig['layout']['width'] = 500
        plotly_fig['layout']['xaxis'].update(
            zeroline=True, showline=True, ticklen=4, showgrid=False
        )
        plotly_fig['layout']['yaxis'].update(dict(visible=True))
        plotly_fig.add_trace(colorbar_trace)
        plotly_fig.layout.update(
            annotations=[dict(
                x=1.18,
                align="right",
                valign="top",
                text='Feature value',
                showarrow=False,
                xref="paper",
                yref="paper",
                xanchor="right",
                yanchor="middle",
                textangle=-90,
                font=dict(family='Calibri', size=14),
            )],
            margin=dict(t=20)
        )
        st.plotly_chart(plotly_fig)

# ---------------------------------------------------------------------------
# Single-gene lookup
# ---------------------------------------------------------------------------
# Surrounding whitespace is stripped so " GENE " still matches the index.
input_gene = st.text_input("Input individual HGNC gene:").strip()

gene_rows = df_total[df_total.index == input_gene]
st.dataframe(with_gene_column(gene_rows))

if input_gene:
    if gene_rows.empty:
        # Nothing to explain; avoid handing an empty frame to SHAP.
        st.warning("Gene not found in the annotation dataset.")
    else:
        df2_shap = gene_rows.drop(columns='XGB_Score')
        shap_values = explainer.shap_values(df2_shap)
        shap.getjs()
        force_plot = shap.force_plot(
            explainer.expected_value,
            shap_values,
            df2_shap,
            matplotlib=True,
            show=False,
        )
        st.pyplot(fig=force_plot)

# ---------------------------------------------------------------------------
# Full results table
# ---------------------------------------------------------------------------
st.markdown("""
Total Gene Prioritisation Results:
""")

# with_gene_column works on a copy, so df_total itself is left untouched.
df_total_output = with_gene_column(df_total)
st.dataframe(df_total_output)