# Hugging Face Spaces status: Running on CPU Upgrade
import gradio as gr | |
import numpy as np | |
from functools import partial | |
from matplotlib import pyplot as plt | |
from scipy.cluster.hierarchy import dendrogram | |
from sklearn.datasets import load_iris | |
from sklearn.cluster import AgglomerativeClustering | |
# Colour palette for the Monochrome Gradio theme used by the whole demo.
_theme_hues = {
    "primary_hue": "indigo",
    "secondary_hue": "blue",
    "neutral_hue": "slate",
}
theme = gr.themes.Monochrome(**_theme_hues)
# Markdown description rendered near the top of the demo page.
# A plain string is enough here: the text contains no placeholders.
model_card = """
## Description
This demo shows the plot of the corresponding **Dendrogram of Hierarchical Clustering** using **AgglomerativeClustering** and the dendrogram method on the Iris dataset.
There are several metrics that use to compute the distance like `euclidean`, `l1`, `l2`, `manhattan`
You can play around with different ``linkage criterion``. The linkage criterion determines which distance to use between sets of observations.
Note: If `linkage criterion` is **ward**, only **euclidean** can use
## Dataset
Iris dataset
"""
# Load the Iris dataset once at module level; X is the data array that
# plot_dendrogram clusters on every call.
iris = load_iris()
X = iris["data"]
def iter_grid(n_rows, n_cols):
    """Yield once per cell of an ``n_rows`` x ``n_cols`` Gradio layout grid.

    Each yield happens while a ``gr.Row`` and a nested ``gr.Column`` context
    are open, so components the caller creates between two yields are placed
    in that cell. The yielded value itself is always ``None``.
    """
    for _row_idx in range(n_rows):
        with gr.Row():
            for _col_idx in range(n_cols):
                with gr.Column():
                    yield None
def plot_dendrogram(linkage_name, metric_name):
    """Cluster the Iris data hierarchically and draw the resulting dendrogram.

    Args:
        linkage_name: Linkage criterion forwarded to
            ``AgglomerativeClustering(linkage=...)``.
        metric_name: Distance metric forwarded to
            ``AgglomerativeClustering(metric=...)``.

    Returns:
        The matplotlib figure containing the dendrogram, or ``None`` when
        ``linkage_name`` is "ward" and ``metric_name`` is not "euclidean"
        (that combination is skipped rather than fitted).
    """
    # Ward linkage is only offered together with the euclidean metric;
    # returning None leaves the corresponding plot empty.
    if linkage_name == "ward" and metric_name != "euclidean":
        return None

    # distance_threshold=0 together with n_clusters=None ensures the full
    # merge tree is computed (and that model.distances_ is available).
    model = AgglomerativeClustering(
        distance_threshold=0,
        n_clusters=None,
        metric=metric_name,
        linkage=linkage_name,
    ).fit(X)

    # For every merge, count how many original samples sit below it.
    # Child indices below n_samples are leaves; larger indices refer to an
    # earlier merge, whose count has already been filled in.
    n_samples = len(model.labels_)
    counts = np.zeros(model.children_.shape[0])
    for i, merge in enumerate(model.children_):
        current_count = 0
        for child_idx in merge:
            if child_idx < n_samples:
                current_count += 1  # leaf node
            else:
                current_count += counts[child_idx - n_samples]
        counts[i] = current_count

    # Assemble the (children, distance, count) rows that scipy's dendrogram
    # expects as a linkage matrix.
    linkage_matrix = np.column_stack(
        [model.children_, model.distances_, counts]
    ).astype(float)

    fig, axes = plt.subplots()
    # Plot only the top three levels of the dendrogram.
    dendrogram(linkage_matrix, ax=axes, truncate_mode="level", p=3)
    # The title previously printed the metric under the "Linkage criterion"
    # label; report both settings under their correct names.
    axes.set_title(
        "Hierarchical Clustering Dendrogram. "
        f"Linkage criterion: {linkage_name}, metric: {metric_name}"
    )
    axes.set_xlabel("Number of points in node (or index of point if no parenthesis).")
    # Detach the figure from pyplot's global registry so repeated callbacks
    # do not accumulate open figures; the Figure object stays renderable.
    plt.close(fig)
    return fig
# Assemble the page: title, description, credits, the linkage selector and
# one dendrogram plot per distance metric laid out on a 2x2 grid.
with gr.Blocks(theme=theme) as demo:
    gr.Markdown('''
<div>
<h1 style='text-align: center'>Hierarchical Clustering Dendrogram</h1>
</div>
''')
    gr.Markdown(model_card)
    gr.Markdown('Author: <a href="https://huggingface.co/vumichien">Vu Minh Chien</a>. Based on the example from <a href="https://scikit-learn.org/stable/auto_examples/cluster/plot_agglomerative_dendrogram.html#sphx-glr-auto-examples-cluster-plot-agglomerative-dendrogram-py">scikit-learn</a>')
    input_linkage = gr.Radio(
        choices=["ward", "complete", "average", "single"],
        value="average",
        label="Linkage criterion to use",
    )
    metrics = ["euclidean", "l1", "l2", "manhattan"]
    # The grid generator comes first in zip() so it is advanced (and its
    # Row/Column contexts are closed) before the metric list is consulted.
    for _, input_metric in zip(iter_grid(2, 2), metrics):
        plot = gr.Plot(label=input_metric)
        # Each plot is redrawn with its own fixed metric whenever the
        # selected linkage changes.
        input_linkage.change(
            fn=partial(plot_dendrogram, metric_name=input_metric),
            inputs=[input_linkage],
            outputs=plot,
        )

demo.launch()