ArturG9 committed on
Commit
c52a337
1 Parent(s): 5b5bce3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +108 -215
app.py CHANGED
@@ -1,217 +1,110 @@
 
 
1
  import streamlit as st
2
- import pandas as pd
3
  import joblib
4
- from enum import Enum
5
- from pydantic import BaseModel, Field, confloat, constr, conlist, ValidationError
6
- from typing import Optional
7
-
8
- # Load the model
9
- model = joblib.load('lgb_model_main.joblib')
10
-
11
- categorical_features = [
12
- 'NAME_CONTRACT_TYPE',
13
- 'CODE_GENDER',
14
- 'NAME_INCOME_TYPE',
15
- 'NAME_EDUCATION_TYPE',
16
- 'NAME_FAMILY_STATUS',
17
- 'OCCUPATION_TYPE',
18
- 'ORGANIZATION_TYPE',
19
- ]
20
-
21
- class ContractType(str, Enum):
22
- Cash_loans = "Cash loans"
23
- Revolving_loans = "Revolving loans"
24
-
25
- class Gender(str, Enum):
26
- Male = "M"
27
- Female = "F"
28
- XNA = "XNA"
29
-
30
- class IncomeType(str, Enum):
31
- Working = "Working"
32
- Other = "Other"
33
- Commercial_associate = "Commercial associate"
34
- Pensioner = "Pensioner"
35
-
36
- class EducationType(str, Enum):
37
- Other = "Other"
38
- Higher_education = "Higher education"
39
- Secondary = "Secondary / secondary special"
40
-
41
- class FamilyStatus(str, Enum):
42
- Civil_marriage = "Civil marriage"
43
- Married = "Married"
44
- Single = "Single / not married"
45
- Other = "Other"
46
-
47
- class OccupationType(str, Enum):
48
- Laborers = "Laborers"
49
- Sales_staff = "Sales staff"
50
- Core_staff = "Core staff"
51
- Managers = "Managers"
52
- Drivers = "Drivers"
53
- Other = "Other"
54
-
55
- class OrganizationType(str, Enum):
56
- Business_Entity = "Business Entity Type 3"
57
- Other = "Other"
58
- XNA = "XNA"
59
- Self_employed = "Self-employed"
60
-
61
- class PredictionInput(BaseModel):
62
- AMT_INCOME_TOTAL: confloat(ge=0)
63
- AMT_CREDIT: confloat(ge=0)
64
- REGION_POPULATION_RELATIVE: confloat(ge=0)
65
- DAYS_REGISTRATION: int
66
- DAYS_BIRTH: int
67
- DAYS_ID_PUBLISH: int
68
- FLAG_WORK_PHONE: int
69
- FLAG_PHONE: int
70
- REGION_RATING_CLIENT_W_CITY: int
71
- REG_CITY_NOT_WORK_CITY: int
72
- FLAG_DOCUMENT_3: int
73
- NAME_CONTRACT_TYPE: ContractType
74
- CODE_GENDER: Gender
75
- FLAG_OWN_CAR: int
76
- NAME_INCOME_TYPE: IncomeType
77
- NAME_EDUCATION_TYPE: EducationType
78
- NAME_FAMILY_STATUS: FamilyStatus
79
- OCCUPATION_TYPE: OccupationType
80
- ORGANIZATION_TYPE: OrganizationType
81
- CREDIT_ACTIVE_Active_count_Bureau: Optional[int] = None
82
- CREDIT_ACTIVE_Closed_count_Bureau: Optional[int] = None
83
- DAYS_CREDIT_Bureau: Optional[int] = None
84
- AMT_INSTALMENT_mean_HCredit_installments: Optional[int] = None
85
- DAYS_INSTALMENT_mean_HCredit_installments: Optional[int] = None
86
- NUM_INSTALMENT_NUMBER_mean_HCredit_installments: Optional[int] = None
87
- NUM_INSTALMENT_VERSION_mean_HCredit_installments: Optional[int] = None
88
- NAME_CONTRACT_STATUS_Active_count_pos_cash: Optional[int] = None
89
- NAME_CONTRACT_STATUS_Completed_count_pos_cash: Optional[int] = None
90
- SK_DPD_DEF_pos_cash: Optional[int] = None
91
- NAME_CONTRACT_STATUS_Refused_count_HCredit_PApp: Optional[int] = None
92
- NAME_GOODS_CATEGORY_Other_count_HCredit_PApp: Optional[int] = None
93
- NAME_PORTFOLIO_Cash_count_HCredit_PApp: Optional[int] = None
94
- NAME_PRODUCT_TYPE_walk_in_count_HCredit_PApp: Optional[int] = None
95
- NAME_SELLER_INDUSTRY_Other_count_HCredit_PApp: Optional[int] = None
96
- NAME_YIELD_GROUP_high_count_HCredit_PApp: Optional[int] = None
97
- NAME_YIELD_GROUP_low_action_count_HCredit_PApp: Optional[int] = None
98
- AMT_CREDIT_HCredit_PApp: Optional[int] = None
99
- SELLERPLACE_AREA_HCredit_PApp: Optional[int] = None
100
-
101
- def make_prediction(input_data: dict):
102
- try:
103
- # Convert dictionary to a pandas DataFrame
104
- input_df = pd.DataFrame([input_data])
105
-
106
- # Convert categorical features to 'category' type
107
- for feature in categorical_features:
108
- input_df[feature] = input_df[feature].astype('category')
109
-
110
- # Make predictions using the loaded model
111
- predictions = model.predict_proba(input_df, categorical_feature=categorical_features)[:, 1]
112
-
113
- # Placeholder response for demonstration
114
- response = {"Probability for this credit to be defaulted is: ": predictions[0]} # Extract the probability for class 1
115
-
116
- return response
117
- except Exception as e:
118
- return {"error": str(e)}
119
-
120
- def main():
121
- st.title("Credit Default Prediction")
122
-
123
- st.header("Input Data")
124
- with st.form(key='input_form'):
125
- AMT_INCOME_TOTAL = st.number_input("AMT_INCOME_TOTAL", min_value=0.0, format="%f")
126
- AMT_CREDIT = st.number_input("AMT_CREDIT", min_value=0.0, format="%f")
127
- REGION_POPULATION_RELATIVE = st.number_input("REGION_POPULATION_RELATIVE", min_value=0.0, format="%f")
128
- DAYS_REGISTRATION = st.number_input("DAYS_REGISTRATION", min_value=-100000, max_value=100000, format="%d")
129
- DAYS_BIRTH = st.number_input("DAYS_BIRTH", min_value=-100000, max_value=100000, format="%d")
130
- DAYS_ID_PUBLISH = st.number_input("DAYS_ID_PUBLISH", min_value=-100000, max_value=100000, format="%d")
131
- FLAG_WORK_PHONE = st.number_input("FLAG_WORK_PHONE", min_value=0, max_value=1, format="%d")
132
- FLAG_PHONE = st.number_input("FLAG_PHONE", min_value=0, max_value=1, format="%d")
133
- REGION_RATING_CLIENT_W_CITY = st.number_input("REGION_RATING_CLIENT_W_CITY", min_value=0, max_value=10, format="%d")
134
- REG_CITY_NOT_WORK_CITY = st.number_input("REG_CITY_NOT_WORK_CITY", min_value=0, max_value=1, format="%d")
135
- FLAG_DOCUMENT_3 = st.number_input("FLAG_DOCUMENT_3", min_value=0, max_value=1, format="%d")
136
- NAME_CONTRACT_TYPE = st.selectbox("NAME_CONTRACT_TYPE", list(ContractType))
137
- CODE_GENDER = st.selectbox("CODE_GENDER", list(Gender))
138
- FLAG_OWN_CAR = st.number_input("FLAG_OWN_CAR", min_value=0, max_value=1, format="%d")
139
- NAME_INCOME_TYPE = st.selectbox("NAME_INCOME_TYPE", list(IncomeType))
140
- NAME_EDUCATION_TYPE = st.selectbox("NAME_EDUCATION_TYPE", list(EducationType))
141
- NAME_FAMILY_STATUS = st.selectbox("NAME_FAMILY_STATUS", list(FamilyStatus))
142
- OCCUPATION_TYPE = st.selectbox("OCCUPATION_TYPE", list(OccupationType))
143
- ORGANIZATION_TYPE = st.selectbox("ORGANIZATION_TYPE", list(OrganizationType))
144
-
145
- CREDIT_ACTIVE_Active_count_Bureau = st.number_input("CREDIT_ACTIVE_Active_count_Bureau", min_value=0, format="%d", value=0)
146
- CREDIT_ACTIVE_Closed_count_Bureau = st.number_input("CREDIT_ACTIVE_Closed_count_Bureau", min_value=0, format="%d", value=0)
147
- DAYS_CREDIT_Bureau = st.number_input("DAYS_CREDIT_Bureau", min_value=-100000, max_value=100000, format="%d", value=0)
148
- AMT_INSTALMENT_mean_HCredit_installments = st.number_input("AMT_INSTALMENT_mean_HCredit_installments", min_value=0, format="%f", value=0.0)
149
- DAYS_INSTALMENT_mean_HCredit_installments = st.number_input("DAYS_INSTALMENT_mean_HCredit_installments", min_value=-100000, max_value=100000, format="%d", value=0)
150
- NUM_INSTALMENT_NUMBER_mean_HCredit_installments = st.number_input("NUM_INSTALMENT_NUMBER_mean_HCredit_installments", min_value=0, format="%d", value=0)
151
- NUM_INSTALMENT_VERSION_mean_HCredit_installments = st.number_input("NUM_INSTALMENT_VERSION_mean_HCredit_installments", min_value=0, format="%d", value=0)
152
- NAME_CONTRACT_STATUS_Active_count_pos_cash = st.number_input("NAME_CONTRACT_STATUS_Active_count_pos_cash", min_value=0, format="%d", value=0)
153
- NAME_CONTRACT_STATUS_Completed_count_pos_cash = st.number_input("NAME_CONTRACT_STATUS_Completed_count_pos_cash", min_value=0, format="%d", value=0)
154
- SK_DPD_DEF_pos_cash = st.number_input("SK_DPD_DEF_pos_cash", min_value=0, format="%d", value=0)
155
- NAME_CONTRACT_STATUS_Refused_count_HCredit_PApp = st.number_input("NAME_CONTRACT_STATUS_Refused_count_HCredit_PApp", min_value=0, format="%d", value=0)
156
- NAME_GOODS_CATEGORY_Other_count_HCredit_PApp = st.number_input("NAME_GOODS_CATEGORY_Other_count_HCredit_PApp", min_value=0, format="%d", value=0)
157
- NAME_PORTFOLIO_Cash_count_HCredit_PApp = st.number_input("NAME_PORTFOLIO_Cash_count_HCredit_PApp", min_value=0, format="%d", value=0)
158
- NAME_PRODUCT_TYPE_walk_in_count_HCredit_PApp = st.number_input("NAME_PRODUCT_TYPE_walk_in_count_HCredit_PApp", min_value=0, format="%d", value=0)
159
- NAME_SELLER_INDUSTRY_Other_count_HCredit_PApp = st.number_input("NAME_SELLER_INDUSTRY_Other_count_HCredit_PApp", min_value=0, format="%d", value=0)
160
- NAME_YIELD_GROUP_high_count_HCredit_PApp = st.number_input("NAME_YIELD_GROUP_high_count_HCredit_PApp", min_value=0, format="%d", value=0)
161
- NAME_YIELD_GROUP_low_action_count_HCredit_PApp = st.number_input("NAME_YIELD_GROUP_low_action_count_HCredit_PApp", min_value=0, format="%d", value=0)
162
- AMT_CREDIT_HCredit_PApp = st.number_input("AMT_CREDIT_HCredit_PApp", min_value=0, format="%f", value=0.0)
163
- SELLERPLACE_AREA_HCredit_PApp = st.number_input("SELLERPLACE_AREA_HCredit_PApp", min_value=0, format="%d", value=0)
164
-
165
- submit_button = st.form_submit_button(label='Predict')
166
-
167
- if submit_button:
168
- input_data = {
169
- "AMT_INCOME_TOTAL": AMT_INCOME_TOTAL,
170
- "AMT_CREDIT": AMT_CREDIT,
171
- "REGION_POPULATION_RELATIVE": REGION_POPULATION_RELATIVE,
172
- "DAYS_REGISTRATION": DAYS_REGISTRATION,
173
- "DAYS_BIRTH": DAYS_BIRTH,
174
- "DAYS_ID_PUBLISH": DAYS_ID_PUBLISH,
175
- "FLAG_WORK_PHONE": FLAG_WORK_PHONE,
176
- "FLAG_PHONE": FLAG_PHONE,
177
- "REGION_RATING_CLIENT_W_CITY": REGION_RATING_CLIENT_W_CITY,
178
- "REG_CITY_NOT_WORK_CITY": REG_CITY_NOT_WORK_CITY,
179
- "FLAG_DOCUMENT_3": FLAG_DOCUMENT_3,
180
- "NAME_CONTRACT_TYPE": NAME_CONTRACT_TYPE,
181
- "CODE_GENDER": CODE_GENDER,
182
- "FLAG_OWN_CAR": FLAG_OWN_CAR,
183
- "NAME_INCOME_TYPE": NAME_INCOME_TYPE,
184
- "NAME_EDUCATION_TYPE": NAME_EDUCATION_TYPE,
185
- "NAME_FAMILY_STATUS": NAME_FAMILY_STATUS,
186
- "OCCUPATION_TYPE": OCCUPATION_TYPE,
187
- "ORGANIZATION_TYPE": ORGANIZATION_TYPE,
188
- "CREDIT_ACTIVE_Active_count_Bureau": CREDIT_ACTIVE_Active_count_Bureau,
189
- "CREDIT_ACTIVE_Closed_count_Bureau": CREDIT_ACTIVE_Closed_count_Bureau,
190
- "DAYS_CREDIT_Bureau": DAYS_CREDIT_Bureau,
191
- "AMT_INSTALMENT_mean_HCredit_installments": AMT_INSTALMENT_mean_HCredit_installments,
192
- "DAYS_INSTALMENT_mean_HCredit_installments": DAYS_INSTALMENT_mean_HCredit_installments,
193
- "NUM_INSTALMENT_NUMBER_mean_HCredit_installments": NUM_INSTALMENT_NUMBER_mean_HCredit_installments,
194
- "NUM_INSTALMENT_VERSION_mean_HCredit_installments": NUM_INSTALMENT_VERSION_mean_HCredit_installments,
195
- "NAME_CONTRACT_STATUS_Active_count_pos_cash": NAME_CONTRACT_STATUS_Active_count_pos_cash,
196
- "NAME_CONTRACT_STATUS_Completed_count_pos_cash": NAME_CONTRACT_STATUS_Completed_count_pos_cash,
197
- "SK_DPD_DEF_pos_cash": SK_DPD_DEF_pos_cash,
198
- "NAME_CONTRACT_STATUS_Refused_count_HCredit_PApp": NAME_CONTRACT_STATUS_Refused_count_HCredit_PApp,
199
- "NAME_GOODS_CATEGORY_Other_count_HCredit_PApp": NAME_GOODS_CATEGORY_Other_count_HCredit_PApp,
200
- "NAME_PORTFOLIO_Cash_count_HCredit_PApp": NAME_PORTFOLIO_Cash_count_HCredit_PApp,
201
- "NAME_PRODUCT_TYPE_walk_in_count_HCredit_PApp": NAME_PRODUCT_TYPE_walk_in_count_HCredit_PApp,
202
- "NAME_SELLER_INDUSTRY_Other_count_HCredit_PApp": NAME_SELLER_INDUSTRY_Other_count_HCredit_PApp,
203
- "NAME_YIELD_GROUP_high_count_HCredit_PApp": NAME_YIELD_GROUP_high_count_HCredit_PApp,
204
- "NAME_YIELD_GROUP_low_action_count_HCredit_PApp": NAME_YIELD_GROUP_low_action_count_HCredit_PApp,
205
- "AMT_CREDIT_HCredit_PApp": AMT_CREDIT_HCredit_PApp,
206
- "SELLERPLACE_AREA_HCredit_PApp": SELLERPLACE_AREA_HCredit_PApp
207
- }
208
-
209
- try:
210
- input_data_validated = PredictionInput(**input_data)
211
- prediction = make_prediction(input_data_validated.dict())
212
- st.write(prediction)
213
- except ValidationError as e:
214
- st.error(f"Validation error: {e}")
215
-
216
- if __name__ == "__main__":
217
- main()
 
1
+ import io
2
+ import pickle
3
  import streamlit as st
 
4
  import joblib
5
+ import shap
6
+ import pandas as pd
7
+ import matplotlib.pyplot as plt
8
+
9
+
10
+ # Load the LightGBM model and other necessary objects
11
+ with open('lgb1_model.pkl', 'rb') as f:
12
+ lgb1 = pickle.load(f)
13
+
14
+ categorical_features = joblib.load("categorical_features.joblib")
15
+ encoder = joblib.load("encoder.joblib")
16
+
17
+ # Sidebar option to select the dashboard
18
+ option = st.sidebar.selectbox("Which dashboard?", ("Model information", "Stroke prediction"))
19
+ st.title(option)
20
+
21
+ def get_pred():
22
+ """
23
+ Function to display the stroke probability calculator and Shap force plot.
24
+ """
25
+ st.header("Stroke probability calculator ")
26
+
27
+ # User input for prediction
28
+ gender = st.selectbox("Select gender: ", ["Male", "Female", 'Other'])
29
+ work_type = st.selectbox("Work type: ", ["Private", "Self_employed", 'children', 'Govt_job', 'Never_worked'])
30
+ residence_status = st.selectbox("Residence status: ", ["Urban", "Rural"])
31
+ smoking_status = st.selectbox("Smoking status: ", ["Unknown", "formerly smoked", 'never smoked', 'smokes'])
32
+ age = st.slider("Input age: ", 0, 120)
33
+ hypertension = st.select_slider("Do you have hypertension: ", [0, 1])
34
+ heart_disease = st.select_slider("Do you have heart disease: ", [0, 1])
35
+ ever_married = st.select_slider("Have you ever married? ", [0, 1])
36
+ avg_glucosis_lvl = st.slider("Average glucosis level: ", 50, 280)
37
+ bmi = st.slider("Input Bmi: ", 10, 100)
38
+
39
+ # User input data
40
+ data = {
41
+ "gender": gender,
42
+ "work_type": work_type,
43
+ "Residence_type": residence_status,
44
+ "smoking_status": smoking_status,
45
+ "age": age,
46
+ "hypertension": hypertension,
47
+ "heart_disease": heart_disease,
48
+ "ever_married": ever_married,
49
+ "avg_glucose_level": avg_glucosis_lvl,
50
+ "bmi": bmi
51
+ }
52
+
53
+ # Prediction button
54
+ if st.button("Predict"):
55
+ # Convert input data to a DataFrame
56
+ X = pd.DataFrame([data])
57
+
58
+ # Encode categorical features
59
+ encoded_features = encoder.transform(X[categorical_features])
60
+
61
+ # Get the feature names from the encoder
62
+ feature_names = encoder.get_feature_names_out(input_features=categorical_features)
63
+
64
+ # Create a DataFrame with the encoded features and feature names
65
+ encoded_df = pd.DataFrame(encoded_features, columns=feature_names)
66
+ X_encoded = pd.concat([X.drop(columns=categorical_features), encoded_df], axis=1)
67
+
68
+ # Make predictions
69
+ prediction_proba = lgb1.predict_proba(X_encoded)
70
+
71
+ # Get SHAP values
72
+ explainer = shap.TreeExplainer(lgb1)
73
+ shap_values = explainer.shap_values(X_encoded)
74
+
75
+ # Extract prediction probability and display it to the user
76
+ probability = prediction_proba[0, 1] # Assuming binary classification
77
+ st.subheader(f"The predicted probability of stroke is {probability}.")
78
+ st.subheader("IF you see result , higher than 0.3, we advice you to see a doctor")
79
+ st.header("Shap forceplot")
80
+ st.subheader("Features values impact on model made prediction")
81
+
82
+ # Display SHAP force plot using Matplotlib
83
+ shap.force_plot(explainer.expected_value[1], shap_values[1], features=X_encoded.iloc[0, :], matplotlib=True)
84
+
85
+ # Save the figure to a BytesIO buffer
86
+ buf = io.BytesIO()
87
+ plt.savefig(buf, format="png", dpi=800)
88
+ buf.seek(0)
89
+
90
+ # Display the image in Streamlit
91
+ st.image(buf, width=1100)
92
+
93
+ # Display summary plot of feature importance
94
+ shap.summary_plot(shap_values[1], X_encoded)
95
+
96
+ # Display interaction summary plot
97
+ shap_interaction_values = explainer.shap_interaction_values(X_encoded)
98
+ shap.summary_plot(shap_interaction_values, X_encoded)
99
+
100
+ # Execute get_pred() only if the option is "Stroke prediction"
101
+ if option == "Stroke prediction":
102
+ get_pred()
103
+
104
+ if option == "Model information":
105
+ st.header("Light gradient boosting model")
106
+ st.subheader("First tree of light gradient boosting model and how it makes decisions")
107
+ st.image(r'lgbm_tree.png')
108
+
109
+ st.subheader("Shap values visualization of how features contribute to model prediction")
110
+ st.image(r'lgbm_model_shap_evaluation.png')