ybelkada's picture
Update app.py
97db185 verified
raw
history blame contribute delete
No virus
3.25 kB
import altair as alt
import gradio as gr
import pandas as pd
from functools import partial
from datasets import load_dataset
def get_data(dataset_id="ybelkada/model_cards_correct_tag"):
    """Load the tag-statistics dataset and derive the two frames plotted by the app.

    Args:
        dataset_id: Hugging Face Hub dataset repo whose ``train`` split is loaded.
            Defaults to the repo the app was written against.

    Returns:
        A ``(ratio_df, melted_df)`` tuple, both ordered by ``commit_dates``:

        * ``ratio_df`` has columns ``commit_dates`` and ``ratio``, where
          ``ratio = (1 - missing_library_name / total_transformers_model) * 100``.
          Rows whose ``total_transformers_model`` is 0 get ``NaN`` rather than
          an infinite/undefined value.
        * ``melted_df`` is the long-format view of ``total_transformers_model``
          and ``missing_library_name`` with columns ``commit_dates``, ``type``
          and ``value``.
    """
    # ``to_pandas()`` already yields a DataFrame; no extra wrapping is needed.
    df = load_dataset(dataset_id, split="train").to_pandas()
    df["commit_dates"] = pd.to_datetime(df["commit_dates"])
    df = df.sort_values(by="commit_dates")

    melted_df = pd.melt(
        df,
        id_vars=["commit_dates"],
        value_vars=["total_transformers_model", "missing_library_name"],
        var_name="type",
    )

    # Mask zero totals so the division gives NaN instead of -inf/NaN mixes; an
    # infinite ratio would also corrupt the min/max y-axis domain used when plotting.
    total = df["total_transformers_model"]
    total = total.where(total != 0)
    ratio_df = df[["commit_dates"]].copy()
    ratio_df["ratio"] = (1 - df["missing_library_name"] / total) * 100
    return ratio_df, melted_df
# Fetch the data once at import time; make_plot(refresh=True) later rebinds both names.
(ratio_df, melted_df) = get_data()
def _hover_line_chart(data, y_field, y_title, highlight_field, color=None):
    """Build a layered circle + line chart whose line width reacts to mouse hover.

    Args:
        data: DataFrame with a ``commit_dates`` column and a numeric ``y_field`` column.
        y_field: Column plotted on the y axis; the axis domain is pinned to its min/max.
        y_title: Y-axis title.
        highlight_field: Field the mouseover selection is keyed on.
        color: Optional color encoding shorthand (e.g. ``"type:N"``); omitted when ``None``.

    Returns:
        The layered chart ``points + lines``.
    """
    # Nearest-point mouseover selection; it drives the line-size condition below.
    # NOTE(review): alt.selection(type="single") / add_selection are deprecated in
    # Altair 5 in favour of selection_point / add_params. Kept as-is for
    # compatibility with the installed Altair version — confirm before migrating.
    highlight = alt.selection(
        type="single", on="mouseover", fields=[highlight_field], nearest=True
    )
    encodings = {
        "x": alt.X("commit_dates:T", title="Date"),
        "y": alt.Y(
            f"{y_field}:Q",
            # Pin the y domain to the observed data range.
            scale=alt.Scale(domain=(data[y_field].min(), data[y_field].max())),
            title=y_title,
        ),
    }
    if color is not None:
        encodings["color"] = color
    base = alt.Chart(data).encode(**encodings)

    points = (
        base.mark_circle()
        .encode(opacity=alt.value(1))
        .add_selection(highlight)
        .properties(width=1200, height=800)
    )
    # Size 1 for data outside the highlight selection, size 3 for data inside it.
    lines = base.mark_line().encode(
        size=alt.condition(~highlight, alt.value(1), alt.value(3))
    )
    return points + lines


def make_plot(plot_type, refresh=False):
    """Return the Altair chart for the selected plot type.

    Args:
        plot_type: Radio label chosen in the UI. ``"Total models with missing
            'transformers' tag"`` selects the raw-count chart; any other value
            selects the tagged-ratio chart.
        refresh: When true, re-fetch the data via ``get_data()`` and replace the
            module-level ``ratio_df`` / ``melted_df`` before plotting.

    Returns:
        A layered Altair chart (circles + lines) sized 1200x800.
    """
    global ratio_df, melted_df
    if refresh:
        ratio_df, melted_df = get_data()

    if plot_type == "Total models with missing 'transformers' tag":
        return _hover_line_chart(
            melted_df,
            y_field="value",
            y_title="Count",
            highlight_field="type",
            color="type:N",
        )
    return _hover_line_chart(
        ratio_df,
        y_field="ratio",
        y_title="(1 - missing_library_name / total_transformers_model) * 100 - Higher is better",
        highlight_field="ratio",
    )
with gr.Blocks() as demo:
    # Plot selector; its current value is passed to make_plot as ``plot_type``.
    button = gr.Radio(
        label="Plot type",
        choices=[
            "Total models with missing 'transformers' tag",
            "Proportion of models correctly tagged with 'transformers' tag",
        ],
        value="Total models with missing 'transformers' tag",
    )
    refresh_button = gr.Button(value="Fetch latest data")
    plot = gr.Plot(label="Plot")

    # Redraw when the selection changes, when the user asks for fresh data
    # (make_plot re-fetches first), and once on page load.
    button.change(fn=make_plot, inputs=[button], outputs=[plot])
    refresh_button.click(
        fn=partial(make_plot, refresh=True),
        inputs=[button],
        outputs=[plot],
    )
    demo.load(fn=make_plot, inputs=[button], outputs=[plot])

if __name__ == "__main__":
    demo.launch()