hgbdt-viz / train_classifier.py
import joblib
import pandas as pd
from sklearn.preprocessing import OrdinalEncoder, LabelEncoder
from sklearn.ensemble import RandomForestClassifier, HistGradientBoostingClassifier
from sklearn.metrics import classification_report


def main():
    train_df = pd.read_csv('train_data.csv', na_values='-')
    # `service` is about half empty; the rest of the columns are completely full.
    # One of the rows has `no` for `state`, which isn't listed as an option in the
    # description of the fields, so I'm just going to delete that row.
    train_df = train_df.drop(columns=['id'])
    train_df = train_df.drop(index=train_df[train_df['state'] == 'no'].index)
    # The model can predict `label` really well: ~0.95 accuracy/f1/whatever other stat you care about.
    # It does a lot worse trying to predict `attack_cat` because there are 10 classes
    # and some of them are not well represented, so that might be more interesting to visualize.
    # pop `attack_cat` so it can't leak into the features when predicting `label`
    cheating = train_df.pop('attack_cat')
    y_enc = LabelEncoder().fit(train_df['label'])
    train_y = y_enc.transform(train_df.pop('label'))
    x_enc = OrdinalEncoder().fit(train_df)
    train_df = x_enc.transform(train_df)
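    # An aside, not something this script does: OrdinalEncoder.transform raises on values it
    # never saw during fit, so if the test file ever contains an unseen category, fitting the
    # encoder with handle_unknown='use_encoded_value' and an unknown_value (e.g. -1) is one
    # way to keep the transform from blowing up. A sketch, commented out:
    # x_enc = OrdinalEncoder(handle_unknown='use_encoded_value', unknown_value=-1).fit(train_df)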
    # Random forest doesn't handle NaNs.
    # I could drop the `service` column, or I can use the HistGradientBoostingClassifier.
    # A super helpful error message from sklearn pointed me to this list:
    # https://scikit-learn.org/stable/modules/impute.html#estimators-that-handle-nan-values
    # rf = RandomForestClassifier()
    # rf.fit(train_df, train_y)
    # max_iter is the number of times it builds a gradient-boosted tree,
    # so it's effectively the number of estimators.
    hgb = HistGradientBoostingClassifier(max_iter=10).fit(train_df, train_y)
    joblib.dump(hgb, 'hgb_classifier.joblib', compress=9)
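    # Whatever script does the plotting can load this back later with
    # hgb = joblib.load('hgb_classifier.joblib')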

    test_df = pd.read_csv('test_data.csv', na_values='-')
    test_df = test_df.drop(columns=['id', 'attack_cat'])
    test_y = y_enc.transform(test_df.pop('label'))
    test_df = x_enc.transform(test_df)
    test_preds = hgb.predict(test_df)
    print(classification_report(test_y, test_preds))
    # I guess they took out the RF-style feature importance, or maybe that's only in XGBoost.
    # You can still kind of get to it with RandomForestClassifier.feature_importances_
    # or like this:
    # https://scikit-learn.org/stable/auto_examples/ensemble/plot_forest_importances.html
    # but there's really nothing like that for the HistGradientBoostingClassifier.
    # You can get to the actual nodes for each predictor/estimator, though, like this:
    # hgb._predictors[i][0].nodes
    # and that has an information-gain metric for each node, which might be viz-able
    # (see the sketch below), so that might be an interesting viz:
    # like plot the whole forest.
    # Maybe only do like 10 estimators to keep it smaller,
    # or stick with 100 and figure out a good way to viz big models.
    # The first two estimators are almost identical,
    # so maybe plot the first estimator
    # and then fuzz the nodes by how much the other estimators differ,
    # assuming there are some things they all agree on exactly and others where they differ a little bit.
    # idk, I don't really know how the algorithm works.
    # The 96th estimator looks pretty different (I'm assuming from the boosting),
    # so maybe an evolution animation from the first to the last
    # to see the effect of the boosting:
    # like plot the points and show how the decision boundary shifts with each generation,
    # alongside an animation of the actual decision tree morphing each step.
    # That might look too much like an animation of the model being trained, though,
    # which I guess is sort of what it is, so idk.
    # https://scikit-learn.org/stable/modules/ensemble.html#interpretation-with-feature-importance
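    # A minimal sketch of that node access, assuming sklearn's private layout stays the way it
    # is now: `_predictors` is a list with one entry per boosting iteration, each entry a list
    # of tree predictor objects (one per class for multiclass, one for binary), and `.nodes` is
    # a NumPy structured array whose field names ('gain', 'count', 'feature_idx', 'is_leaf', ...)
    # come from sklearn internals and could change between versions.
    first_tree = pd.DataFrame(hgb._predictors[0][0].nodes)
    print(first_tree[['feature_idx', 'gain', 'count', 'depth', 'is_leaf']].head())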
    # Also: you can see what path a data point takes through the forest
    # with RandomForestClassifier.decision_path(),
    # which might be really cool:
    # see like 10 trees, the path through each tree, and what each tree predicted
    # (rough sketch below).
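    # A rough, untested sketch of that (it can't run here because the random forest above was
    # never fit, thanks to the NaNs): decision_path returns a sparse node-indicator matrix plus
    # the offsets of each tree's nodes, so the node ids one sample visits in every tree are
    # recoverable like this:
    # indicator, n_nodes_ptr = rf.decision_path(test_df)
    # sample_0_nodes = indicator[0].indices  # node ids (concatenated across trees) hit by row 0
    # per_tree = [sample_0_nodes[(sample_0_nodes >= lo) & (sample_0_nodes < hi)]
    #             for lo, hi in zip(n_nodes_ptr[:-1], n_nodes_ptr[1:])]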
if __name__ == '__main__':
    main()