hgbdt-viz / train_classifier.py
none
Working version of the streamlit animation
045d7d4
import joblib
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier, HistGradientBoostingClassifier
from sklearn.metrics import classification_report
from sklearn.preprocessing import OrdinalEncoder, LabelEncoder
def _read_split(path):
    """Load one CSV split and drop the columns that must not be used as features.

    '-' is read as missing. ``id`` is a row identifier. ``attack_cat`` is the
    10-class attack category and presumably gives away the binary ``label``
    target, so it is dropped from both splits.
    """
    frame = pd.read_csv(path, na_values='-')
    return frame.drop(columns=['id', 'attack_cat'])


def _encode_categoricals(frame, encoder, cat_cols):
    """Return a copy of *frame* with *cat_cols* replaced by *encoder*'s ordinal codes."""
    encoded = frame.copy()
    if cat_cols:
        encoded[cat_cols] = encoder.transform(frame[cat_cols])
    return encoded


def main():
    """Train a HistGradientBoostingClassifier on the binary ``label`` target.

    Reads ``train_data.csv`` and ``test_data.csv`` from the working directory,
    writes the fitted model to ``hgb_classifier.joblib`` and prints a
    classification report for the test split.
    """
    # `service` is about half-empty ('-' is read as NaN); the rest of the
    # columns are completely full.
    train_df = _read_split('train_data.csv')
    # One training row has `no` for `state`, which isn't listed as an option
    # in the description of the fields, so drop it.
    train_df = train_df[train_df['state'] != 'no']

    # Predicting `label` works really well (~0.95 accuracy/F1/whatever other
    # stat you care about). Predicting `attack_cat` does a lot worse because
    # there are 10 classes and some are not well-represented, so that might
    # be more interesting to visualize.
    y_enc = LabelEncoder().fit(train_df['label'])
    train_y = y_enc.transform(train_df['label'])
    train_x = train_df.drop(columns=['label'])
    feature_cols = train_x.columns

    # Only the non-numeric columns are ordinal-encoded; numeric columns are
    # passed through unchanged. Ordinal-encoding numeric columns would turn
    # every distinct value into a "category", and any value not seen during
    # training would make transform() raise on the test split. Categories
    # unseen in training (e.g. a test row with `state == 'no'`) are mapped to
    # NaN, which HistGradientBoostingClassifier handles as a missing value.
    cat_cols = train_x.select_dtypes(exclude='number').columns.tolist()
    x_enc = OrdinalEncoder(handle_unknown='use_encoded_value', unknown_value=np.nan)
    if cat_cols:
        x_enc.fit(train_x[cat_cols])
    train_x = _encode_categoricals(train_x, x_enc, cat_cols)

    # RandomForestClassifier did not accept NaNs in the sklearn version this was
    # written against (the missing `service` values), so either drop `service`
    # or use an estimator with native NaN support, listed here:
    # https://scikit-learn.org/stable/modules/impute.html#estimators-that-handle-nan-values
    # max_iter is the number of boosting iterations, i.e. the number of trees
    # built for this binary target.
    hgb = HistGradientBoostingClassifier(max_iter=10).fit(train_x, train_y)
    # NOTE: only the model is saved. The label/ordinal encoders are not, so
    # anything that feeds new rows into this model must reproduce the
    # preprocessing above.
    joblib.dump(hgb, 'hgb_classifier.joblib', compress=9)

    test_df = _read_split('test_data.csv')
    test_y = y_enc.transform(test_df['label'])
    # Use the training columns in the training order so the features line up.
    test_x = _encode_categoricals(test_df[feature_cols], x_enc, cat_cols)
    test_preds = hgb.predict(test_x)
    print(classification_report(test_y, test_preds))

    # Visualization notes:
    # - Feature importance: RandomForestClassifier.feature_importances_ exists
    #   (see https://scikit-learn.org/stable/auto_examples/ensemble/plot_forest_importances.html
    #   and https://scikit-learn.org/stable/modules/ensemble.html#interpretation-with-feature-importance),
    #   but there's nothing equivalent for HistGradientBoostingClassifier.
    # - The nodes of each HGB tree are reachable via the private attribute
    #   `hgb._predictors[i][0].nodes` (private, so it may change between sklearn
    #   versions). Those nodes include a gain value per node, which might be
    #   viz-able, e.g. plotting the whole forest. Keep it to ~10 estimators to
    #   stay small, or stick with 100 and find a good way to visualize big models.
    # - The first two estimators looked almost identical, so maybe plot the
    #   first estimator and "fuzz" the nodes by how much the other estimators
    #   differ (assuming some nodes agree exactly and others differ a little).
    # - The 96th estimator (with 100 estimators) looked quite different,
    #   presumably because of boosting. So maybe animate the evolution from the
    #   first tree to the last to show the effect of boosting: plot the points,
    #   show how the decision boundary shifts each generation, alongside the
    #   tree itself morphing step by step. (That may look a lot like an
    #   animation of training, which is more or less what it is.)
    # - RandomForestClassifier.decision_path() shows the path a data point takes
    #   through the forest; it could be cool to show ~10 trees, the path through
    #   each one, and what each tree predicted.
if __name__ == "__main__":
    # Train and evaluate only when run as a script, not on import.
    main()