ArturG9 commited on
Commit
b984ae4
1 Parent(s): f46abf2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -13
app.py CHANGED
@@ -50,54 +50,65 @@ def get_pred():
50
  "bmi": bmi
51
  }
52
 
53
- # Prediction button
54
  if st.button("Predict"):
55
  # Convert input data to a DataFrame
56
  X = pd.DataFrame([data])
57
 
58
- # Encode categorical features
59
  encoded_features = encoder.transform(X[categorical_features])
60
 
61
- # Get the feature names from the encoder
62
  feature_names = encoder.get_feature_names_out(input_features=categorical_features)
63
 
64
- # Create a DataFrame with the encoded features and feature names
65
  encoded_df = pd.DataFrame(encoded_features, columns=feature_names)
66
  X_encoded = pd.concat([X.drop(columns=categorical_features), encoded_df], axis=1)
67
 
68
- # Make predictions
69
  prediction_proba = lgb1.predict_proba(X_encoded)
70
 
71
- # Get SHAP values
72
  explainer = shap.TreeExplainer(lgb1)
73
  shap_values = explainer.shap_values(X_encoded)
74
 
75
- # Extract prediction probability and display it to the user
76
  probability = prediction_proba[0, 1] # Assuming binary classification
77
  st.subheader(f"The predicted probability of stroke is {probability}.")
78
  st.subheader("IF you see result , higher than 0.3, we advice you to see a doctor")
79
  st.header("Shap forceplot")
80
  st.subheader("Features values impact on model made prediction")
81
 
82
- # Display SHAP force plot using Matplotlib
83
  shap.force_plot(explainer.expected_value[1], shap_values[1], features=X_encoded.iloc[0, :], matplotlib=True)
84
 
85
- # Save the figure to a BytesIO buffer
86
  buf = io.BytesIO()
87
  plt.savefig(buf, format="png", dpi=800)
88
  buf.seek(0)
89
 
90
- # Display the image in Streamlit
91
  st.image(buf, width=1100)
92
 
93
- # Display summary plot of feature importance
94
  shap.summary_plot(shap_values[1], X_encoded)
95
 
96
- # Display interaction summary plot
97
  shap_interaction_values = explainer.shap_interaction_values(X_encoded)
98
  shap.summary_plot(shap_interaction_values, X_encoded)
99
 
100
- # Execute get_pred() only if the option is "Stroke prediction"
 
 
 
 
 
 
 
 
 
 
 
101
  if option == "Stroke prediction":
102
  get_pred()
103
 
@@ -105,6 +116,9 @@ if option == "Model information":
105
  st.header("Light gradient boosting model")
106
  st.subheader("First tree of light gradient boosting model and how it makes decisions")
107
  st.image(r'lgbm_tree.png')
 
 
 
108
 
109
  st.subheader("Shap values visualization of how features contribute to model prediction")
110
  st.image(r'lgbm_model_shap_evaluation.png')
 
50
  "bmi": bmi
51
  }
52
 
53
+
54
  if st.button("Predict"):
55
  # Convert input data to a DataFrame
56
  X = pd.DataFrame([data])
57
 
58
+
59
  encoded_features = encoder.transform(X[categorical_features])
60
 
61
+
62
  feature_names = encoder.get_feature_names_out(input_features=categorical_features)
63
 
64
+
65
  encoded_df = pd.DataFrame(encoded_features, columns=feature_names)
66
  X_encoded = pd.concat([X.drop(columns=categorical_features), encoded_df], axis=1)
67
 
68
+
69
  prediction_proba = lgb1.predict_proba(X_encoded)
70
 
71
+
72
  explainer = shap.TreeExplainer(lgb1)
73
  shap_values = explainer.shap_values(X_encoded)
74
 
75
+
76
  probability = prediction_proba[0, 1] # Assuming binary classification
77
  st.subheader(f"The predicted probability of stroke is {probability}.")
78
  st.subheader("IF you see result , higher than 0.3, we advice you to see a doctor")
79
  st.header("Shap forceplot")
80
  st.subheader("Features values impact on model made prediction")
81
 
82
+
83
  shap.force_plot(explainer.expected_value[1], shap_values[1], features=X_encoded.iloc[0, :], matplotlib=True)
84
 
85
+
86
  buf = io.BytesIO()
87
  plt.savefig(buf, format="png", dpi=800)
88
  buf.seek(0)
89
 
90
+
91
  st.image(buf, width=1100)
92
 
93
+
94
  shap.summary_plot(shap_values[1], X_encoded)
95
 
96
+
97
  shap_interaction_values = explainer.shap_interaction_values(X_encoded)
98
  shap.summary_plot(shap_interaction_values, X_encoded)
99
 
100
+
101
+ if option == "Information about training data":
102
+ st.header("Stroke Prediction Dataset")
103
+ st.subheader("According to the World Health Organization (WHO) stroke is the 2nd leading cause of death globally, responsible for approximately 11% of total deaths. This dataset is used to predict whether a patient is likely to get stroke based on the input parameters like gender, age, various diseases, and smoking status. Each row in the data provides relavant information about the patient.")
104
+ st.subheader(" Stroke dataset has 5110 records and 12 features.")
105
+ st.subheader(" Correlation between features:.")
106
+ st.image(r'lgbm_tree.png')
107
+ st.subheader("Features Shap values and how it effects Target variable "Stroke")
108
+ st.image(r'lgbm_tree.png')
109
+
110
+
111
+
112
  if option == "Stroke prediction":
113
  get_pred()
114
 
 
116
  st.header("Light gradient boosting model")
117
  st.subheader("First tree of light gradient boosting model and how it makes decisions")
118
  st.image(r'lgbm_tree.png')
119
+
120
+
121
+
122
 
123
  st.subheader("Shap values visualization of how features contribute to model prediction")
124
  st.image(r'lgbm_model_shap_evaluation.png')