import gradio as gr
import numpy as np
from functools import partial
from matplotlib import pyplot as plt
from scipy.cluster.hierarchy import dendrogram
from sklearn.datasets import load_iris
from sklearn.cluster import AgglomerativeClustering
theme = gr.themes.Monochrome(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
)
model_card = f"""
## Description

This demo plots the **dendrogram of a hierarchical clustering** built with **AgglomerativeClustering** and SciPy's `dendrogram` function on the Iris dataset.
Several metrics can be used to compute the distances between observations, such as `euclidean`, `l1`, `l2`, and `manhattan`.
You can also play around with the different `linkage` criteria. The linkage criterion determines which distance to use between sets of observations.
Note: if the `linkage` criterion is **ward**, only the **euclidean** metric can be used.

## Dataset

Iris dataset
"""
# Load the Iris data; only the feature matrix is needed for unsupervised clustering
iris = load_iris()
X = iris.data
def iter_grid(n_rows, n_cols):
    # Lay out an n_rows x n_cols grid of gradio Rows/Columns and yield once per cell,
    # so the caller can place one component inside each cell
    for _ in range(n_rows):
        with gr.Row():
            for _ in range(n_cols):
                with gr.Column():
                    yield
def plot_dendrogram(linkage_name, metric_name):
    # Ward linkage is only defined for euclidean distances, so skip invalid combinations
    if linkage_name == "ward" and metric_name != "euclidean":
        return None

    # Setting distance_threshold=0 ensures we compute the full tree
    model = AgglomerativeClustering(
        distance_threshold=0, n_clusters=None, metric=metric_name, linkage=linkage_name
    )
    model = model.fit(X)

    # Count the samples under each node, needed to build the SciPy-style linkage matrix
    counts = np.zeros(model.children_.shape[0])
    n_samples = len(model.labels_)
    for i, merge in enumerate(model.children_):
        current_count = 0
        for child_idx in merge:
            if child_idx < n_samples:
                current_count += 1  # leaf node
            else:
                current_count += counts[child_idx - n_samples]
        counts[i] = current_count

    # Each row of the linkage matrix is (child_1, child_2, distance, number of samples)
    linkage_matrix = np.column_stack(
        [model.children_, model.distances_, counts]
    ).astype(float)
    # Plot the top three levels of the dendrogram
    fig, axes = plt.subplots()
    dendrogram(linkage_matrix, ax=axes, truncate_mode="level", p=3)
    axes.set_title(f"Hierarchical Clustering Dendrogram (linkage: {linkage_name}, metric: {metric_name})")
    axes.set_xlabel("Number of points in node (or index of point if no parenthesis).")
    return fig
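# Standalone usage sketch (assumption: called outside the Gradio UI, e.g. for a quick check):
#   fig = plot_dendrogram("average", "manhattan")
#   fig.savefig("dendrogram_average_manhattan.png")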
with gr.Blocks(theme=theme) as demo:
    gr.Markdown('''
        <div>
            <h1 style='text-align: center'>Hierarchical Clustering Dendrogram</h1>
        </div>
    ''')
    gr.Markdown(model_card)
    gr.Markdown("Author: <a href=\"https://huggingface.co/vumichien\">Vu Minh Chien</a>. Based on the example from <a href=\"https://scikit-learn.org/stable/auto_examples/cluster/plot_agglomerative_dendrogram.html#sphx-glr-auto-examples-cluster-plot-agglomerative-dendrogram-py\">scikit-learn</a>")
    input_linkage = gr.Radio(choices=["ward", "complete", "average", "single"], value="average", label="Linkage criterion to use")
    metrics = ["euclidean", "l1", "l2", "manhattan"]
    counter = 0
    # One plot per metric on a 2x2 grid; each plot refreshes when the linkage selection changes
    for _ in iter_grid(2, 2):
        if counter >= len(metrics):
            break
        input_metric = metrics[counter]
        plot = gr.Plot(label=input_metric)
        fn = partial(plot_dendrogram, metric_name=input_metric)
        input_linkage.change(fn=fn, inputs=[input_linkage], outputs=plot)
        counter += 1

demo.launch()