import joblib
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px

hgb = joblib.load('hgb_classifier.joblib')

FEATS = [
    'srcip', 'sport', 'dstip', 'dsport', 'proto',
    #'state', I dropped this one when I trained the model
    'dur', 'sbytes', 'dbytes', 'sttl', 'dttl', 'sloss', 'dloss',
    'service', 'Sload', 'Dload', 'Spkts', 'Dpkts', 'swin', 'dwin',
    'stcpb', 'dtcpb', 'smeansz', 'dmeansz', 'trans_depth',
    'res_bdy_len', 'Sjit', 'Djit', 'Stime', 'Ltime', 'Sintpkt',
    'Dintpkt', 'tcprtt', 'synack', 'ackdat', 'is_sm_ips_ports',
    'ct_state_ttl', 'ct_flw_http_mthd', 'is_ftp_login', 'ct_ftp_cmd',
    'ct_srv_src', 'ct_srv_dst', 'ct_dst_ltm', 'ct_src_ltm',
    'ct_src_dport_ltm', 'ct_dst_sport_ltm', 'ct_dst_src_ltm',
]

# plotly only has the CSS named colors
# I don't think I can use xkcd colors
# I copied a bunch of CSS colors from somewhere online
# and then deleted whites and things that showed up too close on the tree
# this is not really a general solution, it just works for this specific tree
# I'll have to come up with a better colormap at some point
COLORS = [
    'aliceblue', 'aqua', 'aquamarine', 'azure',
    'bisque', 'black', 'blanchedalmond', 'blue',
    'blueviolet', 'brown', 'burlywood', 'cadetblue',
    'chartreuse', 'chocolate', 'coral', 'cornflowerblue',
    'cornsilk', 'crimson', 'cyan', 'darkblue', 'darkcyan',
    'darkgoldenrod', 'darkgray', 'darkgreen',
    'darkkhaki', 'darkmagenta', 'darkolivegreen', 'darkorange',
    'darkorchid', 'darkred', 'darksalmon', 'darkseagreen',
    'darkslateblue', 'darkslategray',
    'darkturquoise', 'darkviolet', 'deeppink', 'deepskyblue',
    'dimgray', 'dodgerblue',
    'forestgreen', 'fuchsia', 'gainsboro',
    'gold', 'goldenrod', 'gray', 'green',
    'greenyellow', 'honeydew', 'hotpink', 'indianred', 'indigo',
    'ivory', 'khaki', 'lavender', 'lavenderblush', 'lawngreen',
    'lemonchiffon', 'lightblue', 'lightcoral', 'lightcyan',
    'lightgoldenrodyellow', 'lightgray',
    'lightgreen', 'lightpink', 'lightsalmon', 'lightseagreen',
    'lightskyblue', 'lightslategray',
    'lightsteelblue', 'lightyellow', 'lime', 'limegreen',
    'linen', 'magenta', 'maroon', 'mediumaquamarine',
    'mediumblue', 'mediumorchid', 'mediumpurple',
    'mediumseagreen', 'mediumslateblue', 'mediumspringgreen',
    'mediumturquoise', 'mediumvioletred', 'midnightblue',
    'mintcream', 'mistyrose', 'moccasin', 'navy',
    'oldlace', 'olive', 'olivedrab', 'orange', 'orangered',
    'orchid', 'palegoldenrod', 'palegreen', 'paleturquoise',
    'palevioletred', 'papayawhip', 'peachpuff', 'peru', 'pink',
    'plum', 'powderblue', 'purple', 'red', 'rosybrown',
    'royalblue', 'saddlebrown', 'salmon', 'sandybrown',
    'seagreen', 'seashell', 'sienna', 'silver', 'skyblue',
    'slateblue', 'slategray', 'slategrey', 'snow', 'springgreen',
    'steelblue', 'tan', 'teal', 'thistle', 'tomato', 'turquoise',
    'violet', 'wheat', 'yellow', 'yellowgreen',
]

trees = [x[0].nodes for x in hgb._predictors]

# the final tree definitely has a similar structure but is noticeably different
# that's really cool
# I think this will make a cool animation
# if I can figure it out
tree = pd.DataFrame(trees[0])
#tree = pd.DataFrame(trees[9])

# parents is going to be tricky
# I need to get the index of whichever node has the current node listed in either left or right
parents = [None]
# keep track of whether each node is a left or right child of the parent in the list
directions = [None]

# it uses 0 to say "no left/right child"
# so I have to skip searching for node 0
# which is fine b/c node 0 is the root
for i in tree.index[1:]:
    # it seems to make a very even tree
    # so just guess it's in the right side
    # and that will be right half the time
    parent = tree[tree['right'] == i].index
    if parent.empty:
        parents.append(str(tree[tree['left'] == i].index[0]))
        directions.append('l')
    else:
        parents.append(str(parent[0]))
        directions.append('r')

# generate the labels
# and the colors
labels = ['Histogram Gradient-Boosted Decision Tree']
colors = ['white']
for i, node, parent, direction in zip(
        tree.index.to_numpy(),
        tree.iterrows(),
        parents,
        directions,
):
    # skip the first one (the root)
    if i == 0:
        continue
    node = node[1]  # iterrows() yields (index, row) pairs
    feat = FEATS[int(tree.loc[int(parent), 'feature_idx'])]
    thresh = tree.loc[int(parent), 'num_threshold']
    if direction == 'l':
        labels.append(f"[{i}] {feat} <= {thresh}")
    else:
        labels.append(f"[{i}] {feat} > {thresh}")
    # colors
    offset = FEATS.index(feat)
    colors.append(COLORS[offset])

# actual plot
f = go.Figure(
    go.Treemap(
        values=tree['count'].to_numpy(),
        labels=labels,
        ids=tree.index.to_numpy(),
        parents=parents,
        marker_colors=colors,
    )
)
#f.update_layout(
#    treemapcolorway = ['pink']
#)

breakpoint()

# converting the ndarray with column names to a pandas df
# 3284 bytes as an ndarray
# 3300 bytes as a dataframe
# so they're the same size
# do I need to convert it to pandas? idk
# just curious
# (there's a quick size check sketched at the bottom of this file)

# https://linuxtut.com/en/ffb2e319db5545965933/
# https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/ensemble/_hist_gradient_boosting/_predictor.pyx

# figuring out how the thing works
# `value` is the predicted class / value / whatever
# so if it's a leaf node, it returns that value as the prediction
# there are negative values in some of the leaves
# maybe the classes are +/-1 instead of 0/1?
# if the data value is <= `num_threshold` then it goes in the left node
# if it's > `num_threshold` then it goes in the right node
# okay and then all the leaves have feature_idx=0, num_threshold=0, left=0, right=0
# that makes sense
# still kind of annoying that they use 0 instead of np.nan but oh well
# (there's a hand-rolled traversal sketch below that follows this logic)

# also super super hard to figure out what the labels on the tree map should be
# like it has to check the parent's feature_idx and num_threshold
# which I guess isn't too bad once we have the list of parents already built
# except that I don't know whether a node is left or right from its parent
# hmmmm
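
# --- scratch: hand-rolled traversal sketch --------------------------------
# Just to convince myself I understand the split logic described in the notes
# above: <= num_threshold goes left, > goes right, and a node with left == 0
# and right == 0 is a leaf. `row` is assumed to be a single sample keyed by
# the FEATS names (a dict or pandas Series) -- that's my assumption for this
# sketch, not how sklearn's predictor is actually called. The real predictor
# also handles missing values, categorical splits, and binned data, so this
# only covers the happy path. Usage would be something like
# walk_tree(trees[0], some_row) with a hypothetical `some_row`.
def walk_tree(nodes, row):
    i = 0
    while True:
        node = nodes[i]
        if node['left'] == 0 and node['right'] == 0:
            # leaf: `value` is whatever this tree contributes to the prediction
            return node['value']
        feat = FEATS[int(node['feature_idx'])]
        if row[feat] <= node['num_threshold']:
            i = int(node['left'])
        else:
            i = int(node['right'])

# re: the negative leaf values -- my reading (unverified against the sklearn
# source) is that each tree's `value` is a raw additive score, not a class
# label: the model sums the per-tree values, adds a constant baseline, and
# squashes the total through a sigmoid to get the probability of the positive
# class. That would explain negative leaves without the classes being +/-1.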
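
# --- scratch: the ndarray-vs-DataFrame size note above --------------------
# One way to check that size comparison directly: structured arrays expose
# .nbytes, and memory_usage(deep=True) is the closest DataFrame equivalent I
# know of. The numbers won't necessarily match the ones I jotted down above
# (those may have included Python object overhead); this is just a curiosity
# check.
raw_nodes = trees[0]
print('ndarray:  ', raw_nodes.nbytes, 'bytes')
print('DataFrame:', pd.DataFrame(raw_nodes).memory_usage(deep=True).sum(), 'bytes')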