File size: 3,508 Bytes
75e0a65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import gradio as gr
import numpy as np
from functools import partial

from matplotlib import pyplot as plt
from scipy.cluster.hierarchy import dendrogram
from sklearn.datasets import load_iris
from sklearn.cluster import AgglomerativeClustering


# Visual theme shared by the whole Blocks app.
theme = gr.themes.Monochrome(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
)
# Markdown description rendered at the top of the page. This is a plain string:
# it has no placeholders, so the former f-prefix served no purpose.
model_card = """
## Description

This demo shows the **Dendrogram of Hierarchical Clustering** computed with **AgglomerativeClustering** and drawn with the dendrogram method on the Iris dataset.
One plot is shown for each of the distance metrics `euclidean`, `l1`, `l2` and `manhattan`.
You can play around with different `linkage criterion` values. The linkage criterion determines which distance to use between sets of observations.
Note: the **ward** linkage criterion only supports the **euclidean** metric, so the other plots stay empty when **ward** is selected.


## Dataset

Iris dataset
"""
# Iris feature matrix; every dendrogram in the app is fitted on this data.
iris = load_iris()
X = iris.data

def iter_grid(n_rows, n_cols):
    """Lay out an ``n_rows`` x ``n_cols`` grid of Gradio cells, one yield per cell.

    Every yield happens while a ``gr.Column`` is open inside a ``gr.Row``, so
    any components the caller creates in its loop body are placed in that cell.
    Cells are visited row by row, left to right.
    """

    def _cells_of_current_row():
        # One open Column per cell of the row that is currently open.
        for _col in range(n_cols):
            with gr.Column():
                yield

    for _row in range(n_rows):
        with gr.Row():
            yield from _cells_of_current_row()

def _linkage_matrix(model):
    """Convert a fitted ``AgglomerativeClustering`` model into a SciPy linkage matrix.

    Each row is ``[child_a, child_b, merge_distance, n_samples_under_node]``,
    built from ``model.children_``, ``model.distances_`` and leaf counts.
    """
    n_samples = len(model.labels_)
    counts = np.zeros(model.children_.shape[0])
    for i, merge in enumerate(model.children_):
        # Child indices below n_samples are original points (leaves); larger
        # indices refer to the node created by merge ``child_idx - n_samples``,
        # whose count was already filled in by an earlier iteration.
        counts[i] = sum(
            1 if child_idx < n_samples else counts[child_idx - n_samples]
            for child_idx in merge
        )
    return np.column_stack(
        [model.children_, model.distances_, counts]
    ).astype(float)


def plot_dendrogram(linkage_name, metric_name):
    """Fit a full agglomerative clustering of ``X`` and plot its dendrogram.

    Args:
        linkage_name: linkage criterion passed to ``AgglomerativeClustering``
            (e.g. ``"ward"``, ``"complete"``, ``"average"``, ``"single"``).
        metric_name: distance metric passed to ``AgglomerativeClustering``.

    Returns:
        A matplotlib figure with the top three levels of the dendrogram, or
        ``None`` for the unsupported ``ward`` + non-euclidean combination.
    """
    if linkage_name == "ward" and metric_name != "euclidean":
        # scikit-learn only accepts the euclidean metric with ward linkage.
        return None
    # distance_threshold=0 with n_clusters=None makes the model compute the full tree.
    model = AgglomerativeClustering(
        distance_threshold=0, n_clusters=None, metric=metric_name, linkage=linkage_name
    ).fit(X)

    fig, axes = plt.subplots()
    # Plot only the top three levels of the dendrogram.
    dendrogram(_linkage_matrix(model), ax=axes, truncate_mode="level", p=3)
    axes.set_title(f"Hierarchical Clustering Dendrogram. Linkage criterion: {linkage_name}")
    axes.set_xlabel("Number of points in node (or index of point if no parenthesis).")
    # Release pyplot's global reference; otherwise every UI event leaks a figure.
    # The returned Figure object can still be rendered/saved after closing.
    plt.close(fig)
    return fig



with gr.Blocks(theme=theme) as demo:
    gr.Markdown('''
            <div>
            <h1 style='text-align: center'>Hierarchical Clustering Dendrogram</h1>
            </div>
        ''')
    gr.Markdown(model_card)
    gr.Markdown("Author: <a href=\"https://huggingface.co/vumichien\">Vu Minh Chien</a>. Based on the example from <a href=\"https://scikit-learn.org/stable/auto_examples/cluster/plot_agglomerative_dendrogram.html#sphx-glr-auto-examples-cluster-plot-agglomerative-dendrogram-py\">scikit-learn</a>")
    input_linkage = gr.Radio(choices=["ward", "complete", "average", "single"], value="average", label="Linkage criterion to use")
    metrics = ["euclidean", "l1", "l2", "manhattan"]
    # One plot per metric in a 2x2 grid. iter_grid is listed first in zip so
    # that, once every cell is used, the generator is advanced to its end and
    # exits its Row/Column contexts cleanly.
    for _, input_metric in zip(iter_grid(2, 2), metrics):
        plot = gr.Plot(label=input_metric)
        # partial binds the metric now, avoiding late binding of the loop variable.
        fn = partial(plot_dendrogram, metric_name=input_metric)
        # Redraw whenever the selected linkage criterion changes...
        input_linkage.change(fn=fn, inputs=[input_linkage], outputs=plot)
        # ...and once when the page loads, so the default selection is plotted.
        demo.load(fn=fn, inputs=[input_linkage], outputs=plot)

demo.launch()