Update app.py
Browse files
app.py
CHANGED
@@ -50,54 +50,65 @@ def get_pred():
|
|
50 |
"bmi": bmi
|
51 |
}
|
52 |
|
53 |
-
|
54 |
if st.button("Predict"):
|
55 |
# Convert input data to a DataFrame
|
56 |
X = pd.DataFrame([data])
|
57 |
|
58 |
-
|
59 |
encoded_features = encoder.transform(X[categorical_features])
|
60 |
|
61 |
-
|
62 |
feature_names = encoder.get_feature_names_out(input_features=categorical_features)
|
63 |
|
64 |
-
|
65 |
encoded_df = pd.DataFrame(encoded_features, columns=feature_names)
|
66 |
X_encoded = pd.concat([X.drop(columns=categorical_features), encoded_df], axis=1)
|
67 |
|
68 |
-
|
69 |
prediction_proba = lgb1.predict_proba(X_encoded)
|
70 |
|
71 |
-
|
72 |
explainer = shap.TreeExplainer(lgb1)
|
73 |
shap_values = explainer.shap_values(X_encoded)
|
74 |
|
75 |
-
|
76 |
probability = prediction_proba[0, 1] # Assuming binary classification
|
77 |
st.subheader(f"The predicted probability of stroke is {probability}.")
|
78 |
st.subheader("IF you see result , higher than 0.3, we advice you to see a doctor")
|
79 |
st.header("Shap forceplot")
|
80 |
st.subheader("Features values impact on model made prediction")
|
81 |
|
82 |
-
|
83 |
shap.force_plot(explainer.expected_value[1], shap_values[1], features=X_encoded.iloc[0, :], matplotlib=True)
|
84 |
|
85 |
-
|
86 |
buf = io.BytesIO()
|
87 |
plt.savefig(buf, format="png", dpi=800)
|
88 |
buf.seek(0)
|
89 |
|
90 |
-
|
91 |
st.image(buf, width=1100)
|
92 |
|
93 |
-
|
94 |
shap.summary_plot(shap_values[1], X_encoded)
|
95 |
|
96 |
-
|
97 |
shap_interaction_values = explainer.shap_interaction_values(X_encoded)
|
98 |
shap.summary_plot(shap_interaction_values, X_encoded)
|
99 |
|
100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
if option == "Stroke prediction":
|
102 |
get_pred()
|
103 |
|
@@ -105,6 +116,9 @@ if option == "Model information":
|
|
105 |
st.header("Light gradient boosting model")
|
106 |
st.subheader("First tree of light gradient boosting model and how it makes decisions")
|
107 |
st.image(r'lgbm_tree.png')
|
|
|
|
|
|
|
108 |
|
109 |
st.subheader("Shap values visualization of how features contribute to model prediction")
|
110 |
st.image(r'lgbm_model_shap_evaluation.png')
|
|
|
50 |
"bmi": bmi
|
51 |
}
|
52 |
|
53 |
+
|
54 |
if st.button("Predict"):
|
55 |
# Convert input data to a DataFrame
|
56 |
X = pd.DataFrame([data])
|
57 |
|
58 |
+
|
59 |
encoded_features = encoder.transform(X[categorical_features])
|
60 |
|
61 |
+
|
62 |
feature_names = encoder.get_feature_names_out(input_features=categorical_features)
|
63 |
|
64 |
+
|
65 |
encoded_df = pd.DataFrame(encoded_features, columns=feature_names)
|
66 |
X_encoded = pd.concat([X.drop(columns=categorical_features), encoded_df], axis=1)
|
67 |
|
68 |
+
|
69 |
prediction_proba = lgb1.predict_proba(X_encoded)
|
70 |
|
71 |
+
|
72 |
explainer = shap.TreeExplainer(lgb1)
|
73 |
shap_values = explainer.shap_values(X_encoded)
|
74 |
|
75 |
+
|
76 |
probability = prediction_proba[0, 1] # Assuming binary classification
|
77 |
st.subheader(f"The predicted probability of stroke is {probability}.")
|
78 |
st.subheader("IF you see result , higher than 0.3, we advice you to see a doctor")
|
79 |
st.header("Shap forceplot")
|
80 |
st.subheader("Features values impact on model made prediction")
|
81 |
|
82 |
+
|
83 |
shap.force_plot(explainer.expected_value[1], shap_values[1], features=X_encoded.iloc[0, :], matplotlib=True)
|
84 |
|
85 |
+
|
86 |
buf = io.BytesIO()
|
87 |
plt.savefig(buf, format="png", dpi=800)
|
88 |
buf.seek(0)
|
89 |
|
90 |
+
|
91 |
st.image(buf, width=1100)
|
92 |
|
93 |
+
|
94 |
shap.summary_plot(shap_values[1], X_encoded)
|
95 |
|
96 |
+
|
97 |
shap_interaction_values = explainer.shap_interaction_values(X_encoded)
|
98 |
shap.summary_plot(shap_interaction_values, X_encoded)
|
99 |
|
100 |
+
|
101 |
+
if option == "Information about training data":
|
102 |
+
st.header("Stroke Prediction Dataset")
|
103 |
+
st.subheader("According to the World Health Organization (WHO) stroke is the 2nd leading cause of death globally, responsible for approximately 11% of total deaths. This dataset is used to predict whether a patient is likely to get stroke based on the input parameters like gender, age, various diseases, and smoking status. Each row in the data provides relavant information about the patient.")
|
104 |
+
st.subheader(" Stroke dataset has 5110 records and 12 features.")
|
105 |
+
st.subheader(" Correlation between features:.")
|
106 |
+
st.image(r'lgbm_tree.png')
|
107 |
+
st.subheader("Features Shap values and how it effects Target variable "Stroke")
|
108 |
+
st.image(r'lgbm_tree.png')
|
109 |
+
|
110 |
+
|
111 |
+
|
112 |
if option == "Stroke prediction":
|
113 |
get_pred()
|
114 |
|
|
|
116 |
st.header("Light gradient boosting model")
|
117 |
st.subheader("First tree of light gradient boosting model and how it makes decisions")
|
118 |
st.image(r'lgbm_tree.png')
|
119 |
+
|
120 |
+
|
121 |
+
|
122 |
|
123 |
st.subheader("Shap values visualization of how features contribute to model prediction")
|
124 |
st.image(r'lgbm_model_shap_evaluation.png')
|